from argparse import ArgumentParser

from tqdm import tqdm

from convlab.policy.rule.multiwoz import RulePolicy
from convlab.task.multiwoz.goal_generator import GoalGenerator
from convlab.util.custom_util import (create_goals, data_goals, env_config,
                                      get_config, set_seed)

# Maximum number of system/user exchanges per dialogue before giving up.
MAX_TURNS = 40


def arg_parser():
    """Parse command-line arguments for the dialogue collector."""
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, help="the model path")
    parser.add_argument("-N", "--num", type=int,
                        default=500, help="# of evaluation dialogues")
    # BUG FIX: help text was copy-pasted from --num; it now describes --model.
    parser.add_argument("--model", type=str, default="rule",
                        help="system policy to use ('rule' or 'PPO')")
    return parser.parse_args()


def interact(model_name, config, seed=0, num_goals=500):
    """Run `num_goals` simulated dialogues and collect their logs.

    Args:
        model_name: system policy to load, either "rule" or "PPO".
        config: path to the experiment config file.
        seed: global random seed used while constructing the policy.
        num_goals: number of dialogues (one generated goal each) to run.

    Returns:
        A list of per-dialogue dicts containing the goal, the turn-by-turn
        log (user emotion/semantic act included) and success metrics.
    """
    conversation = []
    set_seed(seed)
    conf = get_config(config, [])

    if model_name == "rule":
        policy_sys = RulePolicy()
    elif model_name == "PPO":
        from convlab.policy.ppo import PPO
        policy_sys = PPO(vectorizer=conf['vectorizer_sys_activated'])

    model_path = conf['model']['load_path']
    if model_path:
        policy_sys.load(model_path)

    env, sess = env_config(conf, policy_sys)
    goal_generator = GoalGenerator()

    goals = create_goals(goal_generator, num_goals=num_goals,
                         single_domains=False, allowed_domains=None)

    # BUG FIX: the loop variable used to shadow the `seed` parameter;
    # `dlg_seed` keeps the per-dialogue seed distinct. Seeds still run
    # 1000 .. 1000 + num_goals - 1, paired with goals[0..num_goals-1].
    for offset, goal in enumerate(tqdm(goals)):
        dlg_seed = 1000 + offset
        dialogue = {"seed": dlg_seed, "log": []}
        set_seed(dlg_seed)
        sess.init_session(goal=goal)
        sys_response = []
        total_return = 0.0
        turns = 0
        task_succ = 0
        task_succ_strict = 0
        complete = 0
        dialogue["goal"] = env.usr.policy.policy.goal.domain_goals
        dialogue["user info"] = env.usr.policy.policy.user_info

        for _ in range(MAX_TURNS):
            sys_response, user_response, session_over, reward = sess.next_turn(
                sys_response)
            dialogue["log"].append(
                {"role": "usr",
                 "utt": user_response,
                 "emotion": env.usr.policy.policy.emotion,
                 "act": env.usr.policy.policy.semantic_action})
            dialogue["log"].append({"role": "sys", "utt": sys_response})

            turns += 1
            total_return += sess.evaluator.get_reward(session_over)

            if session_over:
                # BUG FIX: the original assigned task_success()'s return
                # value and immediately overwrote it with `.success`.
                # task_success() is still called because it (presumably)
                # populates the evaluator's success fields — TODO confirm.
                sess.evaluator.task_success()
                task_succ = sess.evaluator.success
                task_succ_strict = sess.evaluator.success_strict
                complete = sess.evaluator.complete
                break

        dialogue['Complete'] = complete
        dialogue['Success'] = task_succ
        dialogue['Success strict'] = task_succ_strict
        dialogue['total_return'] = total_return
        dialogue['turns'] = turns

        conversation.append(dialogue)
    return conversation


if __name__ == "__main__":
    import json
    import os
    from datetime import datetime

    time = datetime.now().strftime('%y-%m-%d-%H-%M')
    args = arg_parser()
    conversation = interact(model_name=args.model,
                            config=args.config, num_goals=args.num)
    # BUG FIX: the output file is now closed deterministically via `with`
    # instead of relying on the garbage collector to flush and close it.
    out_path = os.path.join("convlab/policy/emoTUS",
                            f"conversation-{time}.json")
    with open(out_path, 'w') as fout:
        json.dump(conversation, fout, indent=2)