# -*- coding: utf-8 -*-
import argparse
import datetime
import logging
import os
import random

import numpy as np
import torch
from torch import multiprocessing as mp

from convlab2.dialog_agent.agent import PipelineAgent
from convlab2.dialog_agent.session import BiSession
from convlab2.dialog_agent.env import Environment
from convlab2.dst.rule.multiwoz import RuleDST
from convlab2.policy.rule.multiwoz import RulePolicy
from convlab2.policy.rlmodule import Memory
from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator


def init_logging(log_dir_path, path_suffix=None):
if not os.path.exists(log_dir_path):
os.makedirs(log_dir_path)
    current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # avoid ":", invalid in Windows filenames
    if path_suffix:
        log_file_path = os.path.join(log_dir_path, f"{current_time}_{path_suffix}.log")
    else:
        log_file_path = os.path.join(log_dir_path, f"{current_time}.log")
stderr_handler = logging.StreamHandler()
file_handler = logging.FileHandler(log_file_path)
format_str = "%(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s"
logging.basicConfig(level=logging.DEBUG, handlers=[stderr_handler, file_handler], format=format_str)
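

# Example: init_logging("log", path_suffix="ppo_eval") logs to stderr and to,
# e.g., log/2020-01-01_12-00-00_ppo_eval.log (timestamp taken at call time).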
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def sampler(pid, queue, evt, env, policy, batchsz):
"""
This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
processes.
:param pid: process id
:param queue: multiprocessing.Queue, to collect sampled data
:param evt: multiprocessing.Event, to keep the process alive
:param env: environment instance
:param policy: policy network, to generate action from current policy
:param batchsz: total sampled items
:return:
"""
buff = Memory()
    # We need to sample batchsz (state, action, reward, next_state, mask) items.
    # Each trajectory yields at most `traj_len` items, so about
    # `batchsz // traj_len` trajectories are needed; since the last trajectory
    # is not truncated, the final count may exceed batchsz.
sampled_num = 0
sampled_traj_num = 0
traj_len = 50
real_traj_len = 0
while sampled_num < batchsz:
# for each trajectory, we reset the env and get initial state
s = env.reset()
for t in range(traj_len):
# [s_dim] => [a_dim]
s_vec = torch.Tensor(policy.vector.state_vectorize(s))
a = policy.predict(s)
# interact with env
next_s, r, done = env.step(a)
# a flag indicates ending or not
mask = 0 if done else 1
            # vectorize the next state
            next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))
# save to queue
buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
# update per step
s = next_s
            real_traj_len = t + 1  # number of steps taken so far (t is 0-indexed)
if done:
break
        # end of one trajectory
        sampled_num += real_traj_len
        sampled_traj_num += 1
    # sampling finished: push all buffered data onto the queue
    queue.put([pid, buff])
    evt.wait()  # keep this process alive until the main process drains the queue


def sample(env, policy, batchsz, process_num):
"""
Given batchsz number of task, the batchsz will be splited equally to each processes
and when processes return, it merge all data and return
:param env:
:param policy:
:param batchsz:
:param process_num:
:return: batch
"""
    # batchsz is split across processes; each process rounds up, so the
    # final number of sampled items may exceed the batchsz parameter
process_batchsz = np.ceil(batchsz / process_num).astype(np.int32)
# buffer to save all data
queue = mp.Queue()
    # Start process_num worker processes.
    # A process that puts tensors on a Queue must stay alive until Queue.get()
    # has been called, see:
    #   https://discuss.pytorch.org/t/using-torch-tensor-over-multiprocessing-queue-process-fails/2847
    # CUDA tensors on a multiprocessing queue remain problematic, see:
    #   https://discuss.pytorch.org/t/cuda-tensors-on-multiprocessing-queue/28626
    # so tensors are converted to numpy arrays before being queued.
evt = mp.Event()
processes = []
for i in range(process_num):
process_args = (i, queue, evt, env, policy, process_batchsz)
processes.append(mp.Process(target=sampler, args=process_args))
for p in processes:
        # daemonize so workers are killed once the main process stops
p.daemon = True
p.start()
    # get the first Memory object, then merge the others into it via append()
    _, buff0 = queue.get()
    for _ in range(1, process_num):
        _, buff_ = queue.get()
        buff0.append(buff_)  # merge this worker's Memory into buff0
evt.set()
# now buff saves all the sampled data
buff = buff0
return buff.get_batch()
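

# A minimal usage sketch (hypothetical sizes; assumes Memory.get_batch()
# returns a Transition-style namedtuple of per-field tuples):
#   batch = sample(env, policy, batchsz=1024, process_num=4)
#   states = np.stack(batch.state)  # roughly [batchsz, s_dim]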
def evaluate(dataset_name, model_name, load_path, calculate_reward=True):
seed = 20190827
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if dataset_name == 'MultiWOZ':
dst_sys = RuleDST()
if model_name == "PPO":
from convlab2.policy.ppo import PPO
if load_path:
policy_sys = PPO(False)
policy_sys.load(load_path)
else:
policy_sys = PPO.from_pretrained()
elif model_name == "PG":
from convlab2.policy.pg import PG
if load_path:
policy_sys = PG(False)
policy_sys.load(load_path)
else:
policy_sys = PG.from_pretrained()
elif model_name == "MLE":
from convlab2.policy.mle.multiwoz import MLE
if load_path:
policy_sys = MLE()
policy_sys.load(load_path)
else:
policy_sys = MLE.from_pretrained()
elif model_name == "GDPL":
from convlab2.policy.gdpl import GDPL
if load_path:
policy_sys = GDPL(False)
policy_sys.load(load_path)
else:
policy_sys = GDPL.from_pretrained()
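
        # Build the dialog pipeline: an agenda-based user simulator (RulePolicy
        # with character='usr') and the system agent under evaluation. NLU and
        # NLG are bypassed (None), so the agents exchange dialog acts directly.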
dst_usr = None
policy_usr = RulePolicy(character='usr')
simulator = PipelineAgent(None, None, policy_usr, None, 'user')
env = Environment(None, simulator, None, dst_sys)
agent_sys = PipelineAgent(None, dst_sys, policy_sys, None, 'sys')
evaluator = MultiWozEvaluator()
sess = BiSession(agent_sys, simulator, None, evaluator)
task_success = {'All': []}
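        # Run 100 simulated dialogues of up to 40 turns each, reseeding every
        # episode so results are comparable across models and runs.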
for seed in range(100):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
sess.init_session()
sys_response = []
logging.info('-'*50)
logging.info(f'seed {seed}')
for i in range(40):
sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
                if session_over:
task_succ = sess.evaluator.task_success()
logging.info(f'task success: {task_succ}')
logging.info(f'book rate: {sess.evaluator.book_rate()}')
logging.info(f'inform precision/recall/f1: {sess.evaluator.inform_F1()}')
logging.info(f"percentage of domains that satisfies the database constraints: {sess.evaluator.final_goal_analyze()}")
logging.info('-'*50)
break
            else:
                # for-else: no break, i.e. the session never ended within 40 turns
                task_succ = 0
            for key in sess.evaluator.goal:
                if key not in task_success:
                    task_success[key] = []
                task_success[key].append(task_succ)
task_success['All'].append(task_succ)
for key in task_success:
logging.info(f'{key} {len(task_success[key])} {np.average(task_success[key]) if len(task_success[key]) > 0 else 0}')
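
        # Optionally roll the trained policy directly against the simulator
        # for another 100 episodes and report the average per-turn reward.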
if calculate_reward:
reward_tot = []
for seed in range(100):
                s = env.reset()
                reward = []
                for t in range(40):
                    a = policy_sys.predict(s)
                    # interact with env
                    next_s, r, done = env.step(a)
                    logging.info(r)
                    reward.append(r)
                    s = next_s
                    if done:
                        break
logging.info(f'{seed} reward: {np.mean(reward)}')
reward_tot.append(np.mean(reward))
logging.info(f'total avg reward: {np.mean(reward_tot)}')
else:
raise Exception("currently supported dataset: MultiWOZ")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="MultiWOZ", help="name of dataset")
parser.add_argument("--model_name", type=str, default="PPO", help="name of model")
parser.add_argument("--load_path", type=str, default='', help="path of model")
parser.add_argument("--log_path_suffix", type=str, default="", help="suffix of path of log file")
parser.add_argument("--log_dir_path", type=str, default="log", help="path of log directory")
args = parser.parse_args()
init_logging(log_dir_path=args.log_dir_path, path_suffix=args.log_path_suffix)
evaluate(
dataset_name=args.dataset_name,
model_name=args.model_name,
load_path=args.load_path,
calculate_reward=True
)