Skip to content
Snippets Groups Projects
Commit 8fc6764b authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

wip

parent ce5a0f9a
Branches
No related tags found
No related merge requests found
......@@ -245,14 +245,15 @@ class UserActionPolicy(GenTUSUserActionPolicy):
emotion_list = [emotion]
else:
emotion_list = self.emotion_list
print(emotion_list)
for emotion in emotion_list:
# start token
print("emotion", emotion)
self.seq = torch.zeros(1, self.max_out_len,
device=self.device).long()
pos = self._update_seq([0], 0)
pos = self._update_seq(self.token_map.get_id('start_json'), pos)
pos = self._update_seq(
self.token_map.get_id('start_emotion'), pos)
pos = self._update_seq(self.kg._get_token_id(emotion), pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
......
......@@ -92,7 +92,7 @@ class Evaluator:
emotion_mode = "normal"
in_file = json.load(open(f_eval))
for dialog in tqdm(in_file['dialog'][:2]):
for dialog in tqdm(in_file['dialog']):
temp = {}
inputs = dialog["in"]
labels = self.usr._parse_output(dialog["out"])
......@@ -230,7 +230,9 @@ def bleu(golden_utts, gen_utts):
def SER(gen_utts, gen_acts):
missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
gen_acts, gen_utts)
if total <= 0:
print("ERROR, total = 0")
return 1
return missing/total
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment