Commit 2d70d95b authored by linh

Merge branch 'genTUS_v2' of gitlab.cs.uni-duesseldorf.de:dsml/convlab/ConvLab3 into genTUS_v2

parents 7237d5cb 14f05d86
from argparse import ArgumentParser

from tqdm import tqdm

from convlab.policy.rule.multiwoz import RulePolicy
from convlab.task.multiwoz.goal_generator import GoalGenerator
from convlab.util.custom_util import (create_goals, data_goals, env_config,
                                      get_config, set_seed)


def arg_parser():
    parser = ArgumentParser()
    parser.add_argument("--config", type=str,
                        help="path to the environment config file")
    parser.add_argument("-N", "--num", type=int,
                        default=500, help="number of evaluation dialogues")
    parser.add_argument("--model", type=str,
                        default="rule", help="system policy: 'rule' or 'PPO'")
    return parser.parse_args()


def interact(model_name, config, seed=0, num_goals=500):
    conversation = []
    set_seed(seed)
    conf = get_config(config, [])

    if model_name == "rule":
        policy_sys = RulePolicy()
    elif model_name == "PPO":
        from convlab.policy.ppo import PPO
        policy_sys = PPO(vectorizer=conf['vectorizer_sys_activated'])
        model_path = conf['model']['load_path']
        if model_path:
            policy_sys.load(model_path)

    env, sess = env_config(conf, policy_sys)
    goal_generator = GoalGenerator()
    goals = create_goals(goal_generator, num_goals=num_goals,
                         single_domains=False, allowed_domains=None)

    for seed in tqdm(range(1000, 1000 + num_goals)):
        dialogue = {"seed": seed, "log": []}
        set_seed(seed)
        sess.init_session(goal=goals[seed - 1000])
        sys_response = []
        total_return = 0.0
        turns = 0
        task_succ = 0
        task_succ_strict = 0
        complete = 0
        dialogue["goal"] = env.usr.policy.policy.goal.domain_goals
        dialogue["user info"] = env.usr.policy.policy.user_info

        for _ in range(40):
            sys_response, user_response, session_over, reward = sess.next_turn(
                sys_response)
            dialogue["log"].append(
                {"role": "usr",
                 "utt": user_response,
                 "emotion": env.usr.policy.policy.emotion,
                 "act": env.usr.policy.policy.semantic_action})
            dialogue["log"].append({"role": "sys", "utt": sys_response})

            turns += 1
            total_return += sess.evaluator.get_reward(session_over)

            if session_over:
                # task_success() computes the final result and populates the
                # success / success_strict / complete attributes read below
                sess.evaluator.task_success()
                task_succ = sess.evaluator.success
                task_succ_strict = sess.evaluator.success_strict
                complete = sess.evaluator.complete
                break

        dialogue['Complete'] = complete
        dialogue['Success'] = task_succ
        dialogue['Success strict'] = task_succ_strict
        dialogue['total_return'] = total_return
        dialogue['turns'] = turns
        conversation.append(dialogue)

    return conversation


if __name__ == "__main__":
    import json
    import os
    from datetime import datetime

    time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
    args = arg_parser()
    conversation = interact(model_name=args.model,
                            config=args.config, num_goals=args.num)
    json.dump(conversation,
              open(os.path.join("convlab/policy/emoTUS",
                                f"conversation-{time}.json"), 'w'),
              indent=2)
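
A quick way to sanity-check a finished run is to aggregate the per-dialogue metrics the script stores. A minimal sketch, assuming only the keys written above; the file name is a placeholder for whatever timestamped file the run actually produced:

import json

# Placeholder path: substitute the timestamped file written by the run.
with open("convlab/policy/emoTUS/conversation-23-01-01-12-00.json") as f:
    conversation = json.load(f)

n = len(conversation)
for key in ("Complete", "Success", "Success strict", "total_return", "turns"):
    # average each stored metric over all evaluated dialogues
    print(key, sum(d[key] for d in conversation) / n)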
{
    "model": {
        "load_path": "convlab/policy/ppo/pretrained_models/mle",
        "pretrained_load_path": "",
        "use_pretrained_initialisation": false,
        "batchsz": 200,
        "seed": 0,
        "epoch": 100,
        "eval_frequency": 5,
        "process_num": 1,
        "num_eval_dialogues": 20,
        "sys_semantic_to_usr": false
    },
    "vectorizer_sys": {
        "uncertainty_vector_mul": {
            "class_path": "convlab.policy.vector.vector_binary.VectorBinary",
            "ini_params": {
                "use_masking": true,
                "manually_add_entity_names": true,
                "seed": 0
            }
        }
    },
    "nlu_sys": {
        "BertNLU": {
            "class_path": "convlab.nlu.jointBERT.unified_datasets.BERTNLU",
            "ini_params": {
                "mode": "all",
                "config_file": "multiwoz21_all.json",
                "model_file": "https://huggingface.co/ConvLab/bert-base-nlu/resolve/main/bertnlu_unified_multiwoz21_all_context0.zip"
            }
        }
    },
    "dst_sys": {
        "RuleDST": {
            "class_path": "convlab.dst.rule.multiwoz.dst.RuleDST",
            "ini_params": {}
        }
    },
    "sys_nlg": {},
    "nlu_usr": {},
    "dst_usr": {},
    "policy_usr": {
        "emoTUS": {
            "class_path": "convlab.policy.emoTUS.emoTUS.UserPolicy",
            "ini_params": {
                "model_checkpoint": "convlab/policy/emoTUS/unify/experiments/EmoUS_emowoz+dialmage_0_1",
                "use_sentiment": false,
                "add_persona": true,
                "sample": false
            }
        }
    },
    "usr_nlg": {}
}
\ No newline at end of file
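
This config is what the evaluation script loads via get_config(config, []): the "model" block points PPO at its checkpoint, "vectorizer_sys"/"nlu_sys"/"dst_sys" assemble the system side, and "policy_usr" selects the EmoUS user simulator. A minimal sketch of reading the raw file directly, assuming plain JSON access (the real get_config additionally instantiates the class_path entries and exposes e.g. conf['vectorizer_sys_activated']); the file name below is hypothetical:

import json

# Hypothetical file name for the config shown above.
with open("convlab/policy/emoTUS/emoTUS-PPO.json") as f:
    conf = json.load(f)

print(conf["model"]["load_path"])  # PPO checkpoint loaded by policy_sys.load()
print(conf["policy_usr"]["emoTUS"]["ini_params"]["model_checkpoint"])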
@@ -91,7 +91,6 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         raw_output = self._generate_action(
             raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
         output = self._parse_output(raw_output)
-        print(output)
         self.semantic_action = self._remove_illegal_action(output["action"])
         self.utterance = output["text"]
         self.emotion = output["emotion"]
...
@@ -184,23 +184,34 @@ class Evaluator:
         scores = {}
         for emotion in self.emotion_list:
+            # if emotion == "Neutral":
+            #     continue
             scores[emotion] = {"precision": [],
                                "recall": [], "f1": [], "turn_acc": []}
-            for gen_act, golden_act in zip(r[f"{emotion}_acts"], r["golden_acts"]):
+            for gen_act, golden_act in zip(r[f"{emotion}_acts"], r["Neutral_acts"]):
                 s = f1_measure(preds=gen_act, labels=golden_act)
                 for metric in scores[emotion]:
                     scores[emotion][metric].append(s[metric])
 
         result = {}
         for emotion in self.emotion_list:
+            # if emotion == "Neutral":
+            #     continue
             result[emotion] = {}
-            result[emotion]["bleu"] = bleu(golden_utts=r["golden_utts"],
-                                           gen_utts=r[f"{emotion}_utts"])
-            result[emotion]["SER"] = SER(gen_utts=r[f"{emotion}_utts"],
-                                         gen_acts=r[f"{emotion}_acts"])
             for metric in scores[emotion]:
                 result[emotion][metric] = sum(
                     scores[emotion][metric])/len(scores[emotion][metric])
+            result[emotion]["bleu"] = bleu(golden_utts=r["Neutral_utts"],
+                                           gen_utts=r[f"{emotion}_utts"])
+            result[emotion]["SER"] = SER(gen_utts=r[f"{emotion}_utts"],
+                                         gen_acts=r[f"{emotion}_acts"])
+            result[emotion]["len"] = avg_len(gen_utts=r[f"{emotion}_utts"])
+
+            rouge_score = rouge(golden_utts=r["Neutral_utts"],
+                                gen_utts=r[f"{emotion}_utts"])
+            for metric, score in rouge_score.items():
+                result[emotion][metric] = score.mid.fmeasure
+
             print("emotion:", emotion)
             for metric in result[emotion]:
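
The hunk above scores each emotion-conditioned generation against the Neutral run with f1_measure(preds=..., labels=...), which must return per-turn "precision", "recall", "f1", and "turn_acc". The helper itself is defined elsewhere in the repo; the following is only a plausible sketch of its semantics, treating semantic actions as tuples and turn accuracy as exact set match:

def f1_measure(preds, labels):
    # Sketch only: exact-match counting over semantic actions,
    # e.g. [intent, domain, slot, value] lists.
    pred_set = {tuple(a) for a in preds}
    label_set = {tuple(a) for a in labels}
    tp = len(pred_set & label_set)
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(label_set) if label_set else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {"precision": precision, "recall": recall, "f1": f1,
            "turn_acc": float(pred_set == label_set)}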
@@ -217,6 +228,11 @@ class Evaluator:
             self.model_checkpoint, f"{self.time}-{self.dataset}-{basename}.json"), 'w'), indent=2)
 
 
+def avg_len(gen_utts):
+    n = [len(s.split()) for s in gen_utts]
+    return sum(n)/len(n)
+
+
 def bleu(golden_utts, gen_utts):
     bleu_metric = load_metric("sacrebleu")
     labels = [[utt] for utt in golden_utts]
@@ -227,6 +243,13 @@ def bleu(golden_utts, gen_utts):
     return bleu_score["score"]
 
 
+def rouge(golden_utts, gen_utts):
+    rouge_metric = load_metric("rouge")
+    rouge_score = rouge_metric.compute(predictions=gen_utts,
+                                       references=golden_utts)
+    return rouge_score
+
+
 def SER(gen_utts, gen_acts):
     missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
         gen_acts, gen_utts)
...
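
The new rouge helper returns the raw output of HuggingFace's load_metric("rouge"): a dict mapping "rouge1", "rouge2", "rougeL", and "rougeLsum" to AggregateScore objects, which is why the Evaluator reads score.mid.fmeasure. A small usage sketch (assumes an older datasets release where load_metric still exists):

from datasets import load_metric

rouge_metric = load_metric("rouge")
scores = rouge_metric.compute(
    predictions=["a cheap hotel in the north"],
    references=["i need a cheap hotel in the north"])
for name, aggregate in scores.items():
    # each AggregateScore carries low/mid/high Score(precision, recall, fmeasure)
    print(name, round(aggregate.mid.fmeasure, 3))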