diff --git a/convlab/policy/emoTUS/dialogue_collector.py b/convlab/policy/emoTUS/dialogue_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..07707377fb7c2765f53fe8b50cbbe00153e9afc2
--- /dev/null
+++ b/convlab/policy/emoTUS/dialogue_collector.py
@@ -0,0 +1,98 @@
+from argparse import ArgumentParser
+
+from tqdm import tqdm
+
+from convlab.policy.rule.multiwoz import RulePolicy
+from convlab.task.multiwoz.goal_generator import GoalGenerator
+from convlab.util.custom_util import (create_goals, data_goals, env_config,
+                                      get_config, set_seed)
+
+
+def arg_parser():
+    parser = ArgumentParser()
+    parser.add_argument("--config", type=str, help="path to the experiment config file")
+    parser.add_argument("-N", "--num", type=int,
+                        default=500, help="number of evaluation dialogues")
+    parser.add_argument("--model", type=str,
+                        default="rule", help="system policy: rule or PPO")
+    return parser.parse_args()
+
+
+def interact(model_name, config, seed=0, num_goals=500):
+    conversation = []
+    set_seed(seed)
+    conf = get_config(config, [])
+
+    if model_name == "rule":
+        policy_sys = RulePolicy()
+    elif model_name == "PPO":
+        from convlab.policy.ppo import PPO
+        policy_sys = PPO(vectorizer=conf['vectorizer_sys_activated'])
+
+        model_path = conf['model']['load_path']
+        if model_path:
+            policy_sys.load(model_path)
+
+    env, sess = env_config(conf, policy_sys)
+    goal_generator = GoalGenerator()
+
+    goals = create_goals(goal_generator, num_goals=num_goals,
+                         single_domains=False, allowed_domains=None)
+
+    for seed in tqdm(range(1000, 1000 + num_goals)):
+        dialogue = {"seed": seed, "log": []}
+        set_seed(seed)
+        sess.init_session(goal=goals[seed-1000])
+        sys_response = []
+        actions = 0.0
+        total_return = 0.0
+        turns = 0
+        task_succ = 0
+        task_succ_strict = 0
+        complete = 0
+        dialogue["goal"] = env.usr.policy.policy.goal.domain_goals
+        dialogue["user info"] = env.usr.policy.policy.user_info
+
+        for i in range(40):
+            sys_response, user_response, session_over, reward = sess.next_turn(
+                sys_response)
+            dialogue["log"].append(
+                {"role": "usr",
+                 "utt": user_response,
+                 "emotion": env.usr.policy.policy.emotion,
+                 "act": env.usr.policy.policy.semantic_action})
+            dialogue["log"].append({"role": "sys", "utt": sys_response})
+
+            # logging.info(f"Actions in turn: {len(sys_response)}")
+            turns += 1
+            total_return += sess.evaluator.get_reward(session_over)
+
+            if session_over:
+                task_succ = sess.evaluator.task_success()
+                task_succ = sess.evaluator.success
+                task_succ_strict = sess.evaluator.success_strict
+                complete = sess.evaluator.complete
+                break
+
+        dialogue['Complete'] = complete
+        dialogue['Success'] = task_succ
+        dialogue['Success strict'] = task_succ_strict
+        dialogue['total_return'] = total_return
+        dialogue['turns'] = turns
+
+        conversation.append(dialogue)
+    return conversation
+
+
+if __name__ == "__main__":
+    import json
+    from datetime import datetime
+    import os
+    time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
+    args = arg_parser()
+    conversation = interact(model_name=args.model,
+                            config=args.config, num_goals=args.num)
+    json.dump(conversation,
+              open(os.path.join("convlab/policy/emoTUS",
+                                f"conversation-{time}.json"), 'w'),
+              indent=2)
diff --git a/convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json b/convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json
new file mode 100644
index 0000000000000000000000000000000000000000..458b0628f0d387600696511f011ef540f0388380
--- /dev/null
+++ b/convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json
@@ -0,0 +1,55 @@
+{
+    "model": {
+        "load_path": "convlab/policy/ppo/pretrained_models/mle",
+        "pretrained_load_path": "",
+        "use_pretrained_initialisation": false,
+        "batchsz": 200,
+        "seed": 0,
+        "epoch": 100,
+        "eval_frequency": 5,
+        "process_num": 1,
+        "num_eval_dialogues": 20,
+        "sys_semantic_to_usr": false
+    },
+    "vectorizer_sys": {
+        "uncertainty_vector_mul": {
+            "class_path": "convlab.policy.vector.vector_binary.VectorBinary",
+            "ini_params": {
+                "use_masking": true,
+                "manually_add_entity_names": true,
+                "seed": 0
+            }
+        }
+    },
+    "nlu_sys": {
+        "BertNLU": {
+            "class_path": "convlab.nlu.jointBERT.unified_datasets.BERTNLU",
+            "ini_params": {
+                "mode": "all",
+                "config_file": "multiwoz21_all.json",
+                "model_file": "https://huggingface.co/ConvLab/bert-base-nlu/resolve/main/bertnlu_unified_multiwoz21_all_context0.zip"
+            }
+        }
+    },
+    "dst_sys": {
+        "RuleDST": {
+            "class_path": "convlab.dst.rule.multiwoz.dst.RuleDST",
+            "ini_params": {}
+        }
+    },
+    "sys_nlg": {},
+    "nlu_usr": {},
+    "dst_usr": {},
+    "policy_usr": {
+        "emoTUS": {
+            "class_path": "convlab.policy.emoTUS.emoTUS.UserPolicy",
+            "ini_params": {
+                "model_checkpoint": "convlab/policy/emoTUS/unify/experiments/EmoUS_emowoz+dialmage_0_1",
+                "use_sentiment": false,
+                "add_persona": true,
+                "sample": false
+            }
+        }
+    },
+    "usr_nlg": {}
+}
\ No newline at end of file
diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py
index fca0f3a9dc3e44283ebcf959d74bf7eeaeeec61b..2fb0790f19b07c50c97b3aa652db0615ddc8dd60 100644
--- a/convlab/policy/emoTUS/emoTUS.py
+++ b/convlab/policy/emoTUS/emoTUS.py
@@ -91,7 +91,6 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         raw_output = self._generate_action(
             raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
         output = self._parse_output(raw_output)
-        print(output)
         self.semantic_action = self._remove_illegal_action(output["action"])
         self.utterance = output["text"]
         self.emotion = output["emotion"]
diff --git a/convlab/policy/emoTUS/emotion_eval.py b/convlab/policy/emoTUS/emotion_eval.py
index 8a206ab1d69eef91ba214781a9b794598cc82ae5..da1479f14dc955367a5915cfc586d02b7b1d4cea 100644
--- a/convlab/policy/emoTUS/emotion_eval.py
+++ b/convlab/policy/emoTUS/emotion_eval.py
@@ -184,23 +184,34 @@ class Evaluator:
         scores = {}
         for emotion in self.emotion_list:
+            # if emotion == "Neutral":
+            #     continue
             scores[emotion] = {"precision": [], "recall": [],
                                "f1": [], "turn_acc": []}
-            for gen_act, golden_act in zip(r[f"{emotion}_acts"], r["golden_acts"]):
+            for gen_act, golden_act in zip(r[f"{emotion}_acts"], r["Neutral_acts"]):
                 s = f1_measure(preds=gen_act, labels=golden_act)
                 for metric in scores[emotion]:
                     scores[emotion][metric].append(s[metric])
 
         result = {}
         for emotion in self.emotion_list:
+            # if emotion == "Neutral":
+            #     continue
             result[emotion] = {}
-            result[emotion]["bleu"] = bleu(golden_utts=r["golden_utts"],
-                                           gen_utts=r[f"{emotion}_utts"])
-            result[emotion]["SER"] = SER(gen_utts=r[f"{emotion}_utts"],
-                                         gen_acts=r[f"{emotion}_acts"])
 
             for metric in scores[emotion]:
                 result[emotion][metric] = sum(
                     scores[emotion][metric])/len(scores[emotion][metric])
+            result[emotion]["bleu"] = bleu(golden_utts=r["Neutral_utts"],
+                                           gen_utts=r[f"{emotion}_utts"])
+            result[emotion]["SER"] = SER(gen_utts=r[f"{emotion}_utts"],
+                                         gen_acts=r[f"{emotion}_acts"])
+
+            result[emotion]["len"] = avg_len(gen_utts=r[f"{emotion}_utts"])
+
+            rouge_score = rouge(golden_utts=r["Neutral_utts"],
+                                gen_utts=r[f"{emotion}_utts"])
+            for metric, score in rouge_score.items():
+                result[emotion][metric] = score.mid.fmeasure
 
             print("emotion:", emotion)
             for metric in result[emotion]:
@@ -217,6 +228,11 @@ class Evaluator:
             self.model_checkpoint, f"{self.time}-{self.dataset}-{basename}.json"), 'w'), indent=2)
 
 
+def avg_len(gen_utts):
+    n = [len(s.split()) for s in gen_utts]
+    return sum(n)/len(n)
+
+
 def bleu(golden_utts, gen_utts):
     bleu_metric = load_metric("sacrebleu")
     labels = [[utt] for utt in golden_utts]
@@ -227,6 +243,13 @@ def bleu(golden_utts, gen_utts):
     return bleu_score["score"]
 
 
+def rouge(golden_utts, gen_utts):
+    rouge_metric = load_metric("rouge")
+    rouge_score = rouge_metric.compute(predictions=gen_utts,
+                                       references=golden_utts)
+    return rouge_score
+
+
 def SER(gen_utts, gen_acts):
     missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
         gen_acts, gen_utts)
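
Usage note (reviewer sketch, not part of the patch): the new collector can be run as
"python convlab/policy/emoTUS/dialogue_collector.py --config convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json -N 50 --model rule",
or driven from Python as below. This assumes the EmoUS checkpoint and the BERT NLU model referenced in the config above are available locally; the field names come from the collector itself.

    # Minimal sketch: collect 50 dialogues with the rule-based system policy and the
    # EmoUS user simulator configured by the JSON file added in this patch.
    from convlab.policy.emoTUS.dialogue_collector import interact

    conversation = interact(
        model_name="rule",
        config="convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json",
        num_goals=50)

    # Each dialogue records the goal, per-turn user emotion and semantic acts,
    # and the task outcome flags written by the collector.
    first = conversation[0]
    print(first["Complete"], first["Success"], first["turns"])
    print(first["log"][0]["emotion"], first["log"][0]["act"])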