Skip to content
Snippets Groups Projects
Commit 7237d5cb authored by linh's avatar linh
Browse files

Merge branch 'genTUS_v2' of gitlab.cs.uni-duesseldorf.de:dsml/convlab/ConvLab3 into genTUS_v2

parents 3e55b7b6 8fc6764b
No related branches found
No related tags found
No related merge requests found
...@@ -245,14 +245,15 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -245,14 +245,15 @@ class UserActionPolicy(GenTUSUserActionPolicy):
emotion_list = [emotion] emotion_list = [emotion]
else: else:
emotion_list = self.emotion_list emotion_list = self.emotion_list
print(emotion_list)
for emotion in emotion_list: for emotion in emotion_list:
# start token # start token
print("emotion", emotion)
self.seq = torch.zeros(1, self.max_out_len, self.seq = torch.zeros(1, self.max_out_len,
device=self.device).long() device=self.device).long()
pos = self._update_seq([0], 0) pos = self._update_seq([0], 0)
pos = self._update_seq(self.token_map.get_id('start_json'), pos) pos = self._update_seq(self.token_map.get_id('start_json'), pos)
pos = self._update_seq(
self.token_map.get_id('start_emotion'), pos)
pos = self._update_seq(self.kg._get_token_id(emotion), pos) pos = self._update_seq(self.kg._get_token_id(emotion), pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos) pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
......
...@@ -92,7 +92,7 @@ class Evaluator: ...@@ -92,7 +92,7 @@ class Evaluator:
emotion_mode = "normal" emotion_mode = "normal"
in_file = json.load(open(f_eval)) in_file = json.load(open(f_eval))
for dialog in tqdm(in_file['dialog'][:2]): for dialog in tqdm(in_file['dialog']):
temp = {} temp = {}
inputs = dialog["in"] inputs = dialog["in"]
labels = self.usr._parse_output(dialog["out"]) labels = self.usr._parse_output(dialog["out"])
...@@ -230,7 +230,9 @@ def bleu(golden_utts, gen_utts): ...@@ -230,7 +230,9 @@ def bleu(golden_utts, gen_utts):
def SER(gen_utts, gen_acts): def SER(gen_utts, gen_acts):
missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER( missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
gen_acts, gen_utts) gen_acts, gen_utts)
if total <= 0:
print("ERROR, total = 0")
return 1
return missing/total return missing/total
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment