diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py index 69878e9e36ee4a76221e3c0a2f3a205895215173..b7456b36b5276815056c39670fc93d50ef4473d7 100644 --- a/convlab/policy/emoTUS/emoTUS.py +++ b/convlab/policy/emoTUS/emoTUS.py @@ -106,7 +106,7 @@ class UserActionPolicy(GenTUSUserActionPolicy): print("-"*20) return action - def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True): + def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True, emotion_mode="max"): self.kg.parse_input(raw_inputs) model_input = self.vector.encode(raw_inputs, self.max_in_len) # start token @@ -114,7 +114,7 @@ class UserActionPolicy(GenTUSUserActionPolicy): pos = self._update_seq([0], 0) pos = self._update_seq(self.token_map.get_id('start_json'), pos) emotion = self._get_emotion( - model_input, self.seq[:1, :pos], mode, allow_general_intent) + model_input, self.seq[:1, :pos], mode, emotion_mode) pos = self._update_seq(emotion["token_id"], pos) pos = self._update_seq(self.token_map.get_id('sep_token'), pos) pos = self._update_seq(self.token_map.get_id('start_act'), pos) @@ -214,10 +214,10 @@ class UserActionPolicy(GenTUSUserActionPolicy): raw_output = self._get_text(model_input, pos) return self._parse_output(raw_output)["text"] - def _get_emotion(self, model_input, generated_so_far, mode="max", allow_general_intent=True): + def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal"): next_token_logits = self.model.get_next_token_logits( model_input, generated_so_far) - return self.kg.get_emotion(next_token_logits, mode, allow_general_intent) + return self.kg.get_emotion(next_token_logits, mode, emotion_mode) def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True): next_token_logits = self.model.get_next_token_logits( diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py index b7fd63a4ea4fa7dd7f03e5aa1f0e262a8699cca7..2a9032069f87a533a5dd53033c6cc2e6e744a6a4 100644 --- a/convlab/policy/emoTUS/evaluate.py +++ b/convlab/policy/emoTUS/evaluate.py @@ -35,6 +35,8 @@ def arg_parser(): help="do nlg generation") parser.add_argument("--do-golden-nlg", action="store_true", help="do golden nlg generation") + parser.add_argument("--no-neutral", action="store_true", + help="skip neutral emotion") return parser.parse_args() @@ -53,7 +55,10 @@ class Evaluator: model_checkpoint, only_action=only_action, dataset=self.dataset) self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin")) - def generate_results(self, f_eval, golden=False): + def generate_results(self, f_eval, golden=False, no_neutral=False): + emotion_mode = "max" + if no_neutral: + emotion_mode = "no_neutral" in_file = json.load(open(f_eval)) r = { "input": [], @@ -67,14 +72,20 @@ class Evaluator: for dialog in tqdm(in_file['dialog']): inputs = dialog["in"] labels = self.usr._parse_output(dialog["out"]) + if no_neutral: + if labels["emotion"].lower() == "neutral": + print("skip") + continue + print("do", labels["emotion"]) if golden: usr_act = labels["action"] usr_utt = self.usr.generate_text_from_give_semantic( inputs, labels["action"], labels["emotion"]) else: + output = self.usr._parse_output( - self.usr._generate_action(inputs)) + self.usr._generate_action(inputs, emotion_mode=emotion_mode)) usr_emo = output["emotion"] usr_act = self.usr._remove_illegal_action(output["action"]) usr_utt = output["text"] @@ -106,10 +117,10 @@ class Evaluator: return r - def nlg_evaluation(self, input_file=None, generated_file=None, golden=False): + def nlg_evaluation(self, input_file=None, generated_file=None, golden=False, no_neutral=False): if input_file: print("Force generation") - gen_r = self.generate_results(input_file, golden) + gen_r = self.generate_results(input_file, golden, no_neutral) elif generated_file: gen_r = self.read_generated_result(generated_file) @@ -284,7 +295,8 @@ def main(): else: nlg_result = eval.nlg_evaluation(input_file=args.input_file, generated_file=args.generated_file, - golden=args.do_golden_nlg) + golden=args.do_golden_nlg, + no_neutral=args.no_neutral) generated_file = nlg_result eval.evaluation(args.input_file, diff --git a/convlab/policy/emoTUS/unify/knowledge_graph.py b/convlab/policy/emoTUS/unify/knowledge_graph.py index 3f99c84cbc24509f810c7b5b605a95b2512dcb84..b6b53c4354a50b87e48b5ea656a6cb578a2e2f13 100644 --- a/convlab/policy/emoTUS/unify/knowledge_graph.py +++ b/convlab/policy/emoTUS/unify/knowledge_graph.py @@ -32,10 +32,16 @@ class KnowledgeGraph(GenTUSKnowledgeGraph): for emotion in self.emotion: self.kg_map["emotion"].add_token(emotion, emotion) - def get_emotion(self, outputs, mode="max", allow_general_intent=True): - canidate_list = self.emotion - score = self._get_max_score( - outputs, canidate_list, "emotion", weight=self.prior) + def get_emotion(self, outputs, mode="max", emotion_mode="normal"): + + if emotion_mode == "normal": + score = self._get_max_score( + outputs, self.emotion, "emotion", weight=self.prior) + elif emotion_mode == "no_neutral": + score = self._get_max_score( + outputs, self.emotion[1:], "emotion", weight=self.prior) + else: + print(f"unknown emotion mode: {emotion_mode}") s = self._select(score, mode) return score[s]