diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py index dadfc923e48fdcefbb18c210f250d81f52e7022a..fca0f3a9dc3e44283ebcf959d74bf7eeaeeec61b 100644 --- a/convlab/policy/emoTUS/emoTUS.py +++ b/convlab/policy/emoTUS/emoTUS.py @@ -245,14 +245,15 @@ class UserActionPolicy(GenTUSUserActionPolicy): emotion_list = [emotion] else: emotion_list = self.emotion_list - print(emotion_list) + for emotion in emotion_list: # start token - print("emotion", emotion) self.seq = torch.zeros(1, self.max_out_len, device=self.device).long() pos = self._update_seq([0], 0) pos = self._update_seq(self.token_map.get_id('start_json'), pos) + pos = self._update_seq( + self.token_map.get_id('start_emotion'), pos) pos = self._update_seq(self.kg._get_token_id(emotion), pos) pos = self._update_seq(self.token_map.get_id('sep_token'), pos) diff --git a/convlab/policy/emoTUS/emotion_eval.py b/convlab/policy/emoTUS/emotion_eval.py index c46f7eca18f28aa6698c88454f41bfcadb38f743..8a206ab1d69eef91ba214781a9b794598cc82ae5 100644 --- a/convlab/policy/emoTUS/emotion_eval.py +++ b/convlab/policy/emoTUS/emotion_eval.py @@ -92,7 +92,7 @@ class Evaluator: emotion_mode = "normal" in_file = json.load(open(f_eval)) - for dialog in tqdm(in_file['dialog'][:2]): + for dialog in tqdm(in_file['dialog']): temp = {} inputs = dialog["in"] labels = self.usr._parse_output(dialog["out"]) @@ -230,7 +230,9 @@ def bleu(golden_utts, gen_utts): def SER(gen_utts, gen_acts): missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER( gen_acts, gen_utts) - + if total <= 0: + print("ERROR, total = 0") + return 1 return missing/total