evaluate.py
    # -*- coding: utf-8 -*-
    import numpy as np
    import torch
    import random
    from torch import multiprocessing as mp
    from convlab2.dialog_agent.agent import PipelineAgent
    from convlab2.dialog_agent.session import BiSession
    from convlab2.dialog_agent.env import Environment
    from convlab2.dst.rule.multiwoz import RuleDST
    from convlab2.policy.rule.multiwoz import RulePolicy
    from convlab2.policy.rlmodule import Memory
    from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
    import logging
    import os
    import datetime
    import argparse
    
    def init_logging(log_dir_path, path_suffix=None):
        """Set up logging to stderr and to a timestamped file under `log_dir_path`."""
        if not os.path.exists(log_dir_path):
            os.makedirs(log_dir_path)
        current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # avoid ':' in file names
        if path_suffix:
            log_file_path = os.path.join(log_dir_path, f"{current_time}_{path_suffix}.log")
        else:
            log_file_path = os.path.join(log_dir_path, f"{current_time}.log")
    
        stderr_handler = logging.StreamHandler()
        file_handler = logging.FileHandler(log_file_path)
        format_str = "%(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s"
        logging.basicConfig(level=logging.DEBUG, handlers=[stderr_handler, file_handler], format=format_str)
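
    # Illustrative usage (the suffix value below is a made-up example, not taken from this repo):
    #   init_logging(log_dir_path="log", path_suffix="ppo_eval")
    #   # -> logs to stderr and to log/<timestamp>_ppo_eval.log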
    
    
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    
    def sampler(pid, queue, evt, env, policy, batchsz):
        """
        This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
        processes.
        :param pid: process id
        :param queue: multiprocessing.Queue, to collect sampled data
        :param evt: multiprocessing.Event, to keep the process alive
        :param env: environment instance
        :param policy: policy network, to generate action from current policy
        :param batchsz: total sampled items
        :return:
        """
        buff = Memory()
    
        # we need to sample batchsz (state, action, reward, next_state, mask) tuples.
        # each trajectory contains at most `traj_len` items, so roughly `batchsz // traj_len`
        # trajectories are needed; the final sampled number may be slightly larger than batchsz.
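        # e.g. with batchsz = 1024 and dialogues averaging ~10 turns (illustrative numbers),
        # roughly 1024 / 10 ≈ 102 trajectories would be collected.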
    
        sampled_num = 0
        sampled_traj_num = 0
        traj_len = 50
        real_traj_len = 0
    
        while sampled_num < batchsz:
            # for each trajectory, we reset the env and get initial state
            s = env.reset()
    
            for t in range(traj_len):
    
                # [s_dim] => [a_dim]
                s_vec = torch.Tensor(policy.vector.state_vectorize(s))
                a = policy.predict(s)
    
                # interact with env
                next_s, r, done = env.step(a)
    
                # mask flags whether the dialogue has ended (0 = done, 1 = continue)
                mask = 0 if done else 1

                # vectorize the next state
                next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))

                # save the transition to the local buffer (pushed onto the queue after sampling finishes)
                buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
    
                # update per step
                s = next_s
                real_traj_len = t + 1  # number of steps completed so far (t counts from 0)
    
                if done:
                    break
    
            # this is end of one trajectory
            sampled_num += real_traj_len
            sampled_traj_num += 1
            # real_traj_len is the number of valid steps in the trajectory just finished
    
        # sampling of batchsz items is finished.
        # push all buffered data onto the queue and wait until the main process has consumed it
        queue.put([pid, buff])
        evt.wait()
    
    
    def sample(env, policy, batchsz, process_num):
        """
        Given batchsz number of task, the batchsz will be splited equally to each processes
        and when processes return, it merge all data and return
    	:param env:
    	:param policy:
        :param batchsz:
    	:param process_num:
        :return: batch
        """
    
        # batchsz is split across the processes,
        # so the final number of sampled items may be slightly larger than batchsz
        process_batchsz = np.ceil(batchsz / process_num).astype(np.int32)
        # buffer to save all data
        queue = mp.Queue()
    
        # spawn process_num worker processes.
        # when a tensor is put on a Queue, the sending process must stay alive until Queue.get() is called,
        # see: https://discuss.pytorch.org/t/using-torch-tensor-over-multiprocessing-queue-process-fails/2847
        # CUDA tensors on a multiprocessing queue still cause problems,
        # see: https://discuss.pytorch.org/t/cuda-tensors-on-multiprocessing-queue/28626
        # so tensors are converted to numpy arrays before being put on the queue.
        evt = mp.Event()
        processes = []
        for i in range(process_num):
            process_args = (i, queue, evt, env, policy, process_batchsz)
            processes.append(mp.Process(target=sampler, args=process_args))
        for p in processes:
            # set each process as a daemon so it is killed once the main process stops
            p.daemon = True
            p.start()
    
        # get the first Memory object, then merge the remaining ones into it with its append method
        pid0, buff0 = queue.get()
        for _ in range(1, process_num):
            pid, buff_ = queue.get()
            buff0.append(buff_)  # merge current Memory into buff0
        evt.set()
    
        # now buff saves all the sampled data
        buff = buff0
    
        return buff.get_batch()
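
    # Illustrative call, assuming `env` and `policy_sys` are constructed as in evaluate() below;
    # the sizes are example values, not taken from this file:
    #   batch = sample(env, policy_sys, batchsz=1024, process_num=4)
    #   # `batch` is whatever Memory.get_batch() returns for the sampled transitions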
    
    def evaluate(dataset_name, model_name, load_path, calculate_reward=True):
        """Evaluate a dialogue policy against a rule-based user simulator and report task success."""
        seed = 20190827
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
    
        if dataset_name == 'MultiWOZ':
            dst_sys = RuleDST()
            
            if model_name == "PPO":
                from convlab2.policy.ppo import PPO
                if load_path:
                    policy_sys = PPO(False)
                    policy_sys.load(load_path)
                else:
                    policy_sys = PPO.from_pretrained()
            elif model_name == "PG":
                from convlab2.policy.pg import PG
                if load_path:
                    policy_sys = PG(False)
                    policy_sys.load(load_path)
                else:
                    policy_sys = PG.from_pretrained()
            elif model_name == "MLE":
                from convlab2.policy.mle.multiwoz import MLE
                if load_path:
                    policy_sys = MLE()
                    policy_sys.load(load_path)
                else:
                    policy_sys = MLE.from_pretrained()
            elif model_name == "GDPL":
                from convlab2.policy.gdpl import GDPL
                if load_path:
                    policy_sys = GDPL(False)
                    policy_sys.load(load_path)
                else:
                    policy_sys = GDPL.from_pretrained()
            else:
                raise Exception(f"unsupported model name: {model_name}")

            dst_usr = None
    
            policy_usr = RulePolicy(character='usr')
            simulator = PipelineAgent(None, None, policy_usr, None, 'user')
    
            env = Environment(None, simulator, None, dst_sys)
    
            agent_sys = PipelineAgent(None, dst_sys, policy_sys, None, 'sys')
    
            evaluator = MultiWozEvaluator()
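            # BiSession runs the system agent and the user simulator turn by turn
            # and lets the evaluator track the dialogue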
            sess = BiSession(agent_sys, simulator, None, evaluator)
    
            task_success = {'All': []}
            for seed in range(100):
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
                sess.init_session()
                sys_response = []
                logging.info('-'*50)
                logging.info(f'seed {seed}')
                for i in range(40):
                    sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
                    if session_over is True:
                        task_succ = sess.evaluator.task_success()
                        logging.info(f'task success: {task_succ}')
                        logging.info(f'book rate: {sess.evaluator.book_rate()}')
                        logging.info(f'inform precision/recall/f1: {sess.evaluator.inform_F1()}')
                        logging.info(f"percentage of domains that satisfies the database constraints: {sess.evaluator.final_goal_analyze()}")
                        logging.info('-'*50)
                        break
                else:
                    # the session did not end within 40 turns: count it as a failure
                    task_succ = 0

                for key in sess.evaluator.goal:
                    if key not in task_success:
                        task_success[key] = []
                    task_success[key].append(task_succ)
                task_success['All'].append(task_succ)
            
            for key in task_success: 
                logging.info(f'{key} {len(task_success[key])} {np.average(task_success[key]) if len(task_success[key]) > 0 else 0}')
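
            # a summary line might read e.g. "All 100 0.74" (illustrative numbers, not real results)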
    
            if calculate_reward:
                reward_tot = []
                for seed in range(100):
                    s = env.reset()
                    reward = []
                    for t in range(40):
                        a = policy_sys.predict(s)

                        # interact with env
                        next_s, r, done = env.step(a)
                        logging.info(r)
                        reward.append(r)

                        # move to the next state before predicting the next action
                        s = next_s
                        if done:
                            break
                    logging.info(f'{seed} reward: {np.mean(reward)}')
                    reward_tot.append(np.mean(reward))
                logging.info(f'total avg reward: {np.mean(reward_tot)}')
        else:
            raise Exception("currently supported dataset: MultiWOZ")
        
    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--dataset_name", type=str, default="MultiWOZ", help="name of dataset")
        parser.add_argument("--model_name", type=str, default="PPO", help="name of model")
        parser.add_argument("--load_path", type=str, default='', help="path of model")
        parser.add_argument("--log_path_suffix", type=str, default="", help="suffix of path of log file")
        parser.add_argument("--log_dir_path", type=str, default="log", help="path of log directory")
        args = parser.parse_args()
    
        init_logging(log_dir_path=args.log_dir_path, path_suffix=args.log_path_suffix)
        evaluate(
            dataset_name=args.dataset_name,
            model_name=args.model_name,
            load_path=args.load_path,
            calculate_reward=True
        )
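
    # Example invocation (model path and log suffix are hypothetical):
    #   python evaluate.py --dataset_name MultiWOZ --model_name PPO \
    #       --load_path save/ppo_best --log_dir_path log --log_path_suffix ppo_eval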