Skip to content
Snippets Groups Projects
Commit 269b24cb authored by Hsien-Chin Lin
Browse files

wip

parent f92cc630
Branches
No related tags found
No related merge requests found
...@@ -25,6 +25,8 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -25,6 +25,8 @@ class UserActionPolicy(GenTUSUserActionPolicy):
for emotion, index in data_emotion.items(): for emotion, index in data_emotion.items():
self.emotion_list[index] = emotion self.emotion_list[index] = emotion
self.use_sentiment = kwargs.get("use_sentiment", False)
self.init_session() self.init_session()
def predict(self, sys_act, mode="max", allow_general_intent=True, emotion=None): def predict(self, sys_act, mode="max", allow_general_intent=True, emotion=None):
...@@ -46,7 +48,14 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -46,7 +48,14 @@ class UserActionPolicy(GenTUSUserActionPolicy):
else: else:
history = self.usr_acts[-1*self.max_history:] history = self.usr_acts[-1*self.max_history:]
# TODO add user info? impolite? # TODO add user info? impolite? -> check self.use_sentiment
if self.use_sentiment:
# TODO how to get event and user politeness?
inputs = json.dumps({"system": sys_act,
"goal": self.goal.get_goal_list(),
"history": history,
"turn": str(int(self.time_step/2))})
else:
inputs = json.dumps({"system": sys_act, inputs = json.dumps({"system": sys_act,
"goal": self.goal.get_goal_list(), "goal": self.goal.get_goal_list(),
"history": history, "history": history,
...@@ -101,6 +110,9 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -101,6 +110,9 @@ class UserActionPolicy(GenTUSUserActionPolicy):
in_str = in_str.replace('<s>', '').replace( in_str = in_str.replace('<s>', '').replace(
'<\\s>', '').replace('o"clock', "o'clock") '<\\s>', '').replace('o"clock', "o'clock")
action = {"emotion": "Neutral", "action": [], "text": ""} action = {"emotion": "Neutral", "action": [], "text": ""}
if self.use_sentiment:
action["sentiment"] = "Neutral"
try: try:
action = json.loads(in_str) action = json.loads(in_str)
except: except:
...@@ -115,6 +127,11 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -115,6 +127,11 @@ class UserActionPolicy(GenTUSUserActionPolicy):
self.seq = torch.zeros(1, self.max_out_len, device=self.device).long() self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
pos = self._update_seq([0], 0) pos = self._update_seq([0], 0)
pos = self._update_seq(self.token_map.get_id('start_json'), pos) pos = self._update_seq(self.token_map.get_id('start_json'), pos)
if self.use_sentiment:
sentiment = self._get_sentiment(
model_input, self.seq[:1, :pos], mode)
pos = self._update_seq(sentiment["token_id"], pos)
else:
emotion = self._get_emotion( emotion = self._get_emotion(
model_input, self.seq[:1, :pos], mode, emotion_mode) model_input, self.seq[:1, :pos], mode, emotion_mode)
pos = self._update_seq(emotion["token_id"], pos) pos = self._update_seq(emotion["token_id"], pos)
...@@ -136,6 +153,13 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -136,6 +153,13 @@ class UserActionPolicy(GenTUSUserActionPolicy):
if self.only_action: if self.only_action:
return self.vector.decode(self.seq[0, :pos]) return self.vector.decode(self.seq[0, :pos])
if self.use_sentiment:
pos = self._update_seq(self.token_map.get_id('start_emotion'), pos)
emotion = self._get_emotion(
model_input, self.seq[:1, :pos], mode, emotion_mode)
pos = self._update_seq(emotion["token_id"], pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
pos = self._update_seq(self.token_map.get_id("start_text"), pos) pos = self._update_seq(self.token_map.get_id("start_text"), pos)
text = self._get_text(model_input, pos) text = self._get_text(model_input, pos)
...@@ -216,6 +240,11 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -216,6 +240,11 @@ class UserActionPolicy(GenTUSUserActionPolicy):
raw_output = self._get_text(model_input, pos) raw_output = self._get_text(model_input, pos)
return self._parse_output(raw_output)["text"] return self._parse_output(raw_output)["text"]
def _get_sentiment(self, model_input, generated_so_far, mode="max"):
    """Pick the sentiment token that follows the sequence generated so far.

    Asks ``self.model`` for the next-token logits given ``model_input`` and
    ``generated_so_far``, then hands those logits and ``mode`` to
    ``self.kg.get_sentiment`` and returns whatever it produces.
    """
    # NOTE(review): callers index the result with "token_id", so the
    # return value is presumably a dict -- confirm against kg.get_sentiment.
    return self.kg.get_sentiment(
        self.model.get_next_token_logits(model_input, generated_so_far),
        mode,
    )
def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal"): def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal"):
next_token_logits = self.model.get_next_token_logits( next_token_logits = self.model.get_next_token_logits(
model_input, generated_so_far) model_input, generated_so_far)
......
...@@ -42,117 +42,111 @@ def arg_parser(): ...@@ -42,117 +42,111 @@ def arg_parser():
class Evaluator: class Evaluator:
def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False): def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False, use_sentiment=False):
self.dataset = dataset self.dataset = dataset
self.model_checkpoint = model_checkpoint self.model_checkpoint = model_checkpoint
self.model_weight = model_weight self.model_weight = model_weight
self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}" self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
# if model_weight: self.use_sentiment = use_sentiment
# self.usr_policy = UserPolicy(
# self.model_checkpoint, only_action=only_action)
# self.usr_policy.load(model_weight)
# self.usr = self.usr_policy.usr
# else:
self.usr = UserActionPolicy( self.usr = UserActionPolicy(
model_checkpoint, only_action=only_action, dataset=self.dataset) model_checkpoint, only_action=only_action, dataset=self.dataset, use_sentiment=use_sentiment)
self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin")) self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
def generate_results(self, f_eval, golden=False, no_neutral=False): self.r = {"input": [],
emotion_mode = "max"
if no_neutral:
emotion_mode = "no_neutral"
in_file = json.load(open(f_eval))
r = {
"input": [],
"golden_acts": [], "golden_acts": [],
"golden_utts": [], "golden_utts": [],
"golden_emotion": [], "golden_emotion": [],
"gen_acts": [], "gen_acts": [],
"gen_utts": [], "gen_utts": [],
"gen_emotion": [] "gen_emotion": []}
} if use_sentiment:
self.r["golden_sentiment"] = []
self.r["gen_sentiment"] = []
def _append_result(self, temp):
    """Append one turn's values to every column of ``self.r``.

    ``temp`` must provide a value for each key currently in ``self.r``;
    a missing key raises ``KeyError``. Extra keys in ``temp`` are ignored.
    """
    for column, values in self.r.items():
        values.append(temp[column])
def generate_results(self, f_eval, golden=False, no_neutral=False):
emotion_mode = "max"
if no_neutral:
emotion_mode = "no_neutral"
in_file = json.load(open(f_eval))
for dialog in tqdm(in_file['dialog']): for dialog in tqdm(in_file['dialog']):
inputs = dialog["in"] inputs = dialog["in"]
labels = self.usr._parse_output(dialog["out"]) labels = self.usr._parse_output(dialog["out"])
if no_neutral: if no_neutral and labels["emotion"].lower() == "neutral":
if labels["emotion"].lower() == "neutral":
print("skip")
continue continue
print("do", labels["emotion"])
if golden: if golden:
usr_act = labels["action"] usr_act = labels["action"]
usr_utt = self.usr.generate_text_from_give_semantic( usr_utt = self.usr.generate_text_from_give_semantic(
inputs, labels["action"], labels["emotion"]) inputs, labels["action"], labels["emotion"])
else: else:
output = self.usr._parse_output( output = self.usr._parse_output(
self.usr._generate_action(inputs, emotion_mode=emotion_mode)) self.usr._generate_action(inputs, emotion_mode=emotion_mode))
usr_emo = output["emotion"] usr_emo = output["emotion"]
usr_act = self.usr._remove_illegal_action(output["action"]) usr_act = self.usr._remove_illegal_action(output["action"])
usr_utt = output["text"] usr_utt = output["text"]
r["input"].append(inputs)
r["golden_acts"].append(labels["action"])
r["golden_utts"].append(labels["text"])
r["golden_emotion"].append(labels["emotion"])
r["gen_acts"].append(usr_act) temp = {}
r["gen_utts"].append(usr_utt) temp["input"] = inputs
r["gen_emotion"].append(usr_emo) temp["golden_acts"] = labels["action"]
temp["golden_utts"] = labels["text"]
temp["golden_emotion"] = labels["emotion"]
return r temp["gen_acts"] = usr_act
temp["gen_utts"] = usr_utt
temp["gen_emotion"] = usr_emo
if self.use_sentiment:
temp["golden_sentiment"] = labels["sentiment"]
temp["gen_sentiment"] = output["sentiment"]
self._append_result(temp)
def read_generated_result(self, f_eval): def read_generated_result(self, f_eval):
in_file = json.load(open(f_eval)) in_file = json.load(open(f_eval))
r = {
"input": [],
"golden_acts": [],
"golden_utts": [],
"golden_emotion": [],
"gen_acts": [],
"gen_utts": [],
"gen_emotion": []
}
for dialog in tqdm(in_file['dialog']): for dialog in tqdm(in_file['dialog']):
for x in dialog: for x in dialog:
r[x].append(dialog[x]) self.r[x].append(dialog[x])
return r def _transform_result(self):
index = [x for x in self.r]
result = []
for i in range(len(self.r[index[0]])):
temp = {}
for x in index:
temp[x] = self.r[x][i]
result.append(temp)
return result
def nlg_evaluation(self, input_file=None, generated_file=None, golden=False, no_neutral=False): def nlg_evaluation(self, input_file=None, generated_file=None, golden=False, no_neutral=False):
if input_file: if input_file:
print("Force generation") print("Force generation")
gen_r = self.generate_results(input_file, golden, no_neutral) self.generate_results(input_file, golden, no_neutral)
elif generated_file: elif generated_file:
gen_r = self.read_generated_result(generated_file) self.read_generated_result(generated_file)
else: else:
print("You must specify the input_file or the generated_file") print("You must specify the input_file or the generated_file")
nlg_eval = { nlg_eval = {
"golden": golden, "golden": golden,
"metrics": {}, "metrics": {},
"dialog": [] "dialog": self._transform_result()
} }
for input, golden_act, golden_utt, golden_emo, gen_act, gen_utt, gen_emo in zip(
gen_r["input"], gen_r["golden_acts"], gen_r["golden_utts"], gen_r["golden_emotion"],
gen_r["gen_acts"], gen_r["gen_utts"], gen_r["gen_emotion"]):
nlg_eval["dialog"].append({
"input": input,
"golden_acts": golden_act,
"golden_utts": golden_utt,
"golden_emotion": golden_emo,
"gen_acts": gen_act,
"gen_utts": gen_utt,
"gen_emotion": gen_emo
})
if golden: if golden:
print("Calculate BLEU") print("Calculate BLEU")
bleu_metric = load_metric("sacrebleu") bleu_metric = load_metric("sacrebleu")
labels = [[utt] for utt in gen_r["golden_utts"]] labels = [[utt] for utt in self.r["golden_utts"]]
bleu_score = bleu_metric.compute(predictions=gen_r["gen_utts"], bleu_score = bleu_metric.compute(predictions=self.r["gen_utts"],
references=labels, references=labels,
force=True) force=True)
print("bleu_metric", bleu_score) print("bleu_metric", bleu_score)
...@@ -161,7 +155,7 @@ class Evaluator: ...@@ -161,7 +155,7 @@ class Evaluator:
else: else:
print("Calculate SER") print("Calculate SER")
missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER( missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
gen_r["gen_acts"], gen_r["gen_utts"]) self.r["gen_acts"], self.r["gen_utts"])
print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format( print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
"genTUSNLG", missing, total, hallucinate, missing/total)) "genTUSNLG", missing, total, hallucinate, missing/total))
...@@ -171,7 +165,8 @@ class Evaluator: ...@@ -171,7 +165,8 @@ class Evaluator:
dir_name = self.model_checkpoint dir_name = self.model_checkpoint
json.dump(nlg_eval, json.dump(nlg_eval,
open(os.path.join(dir_name, f"{self.time}-nlg_eval.json"), 'w'), open(os.path.join(
dir_name, f"{self.time}-nlg_eval.json"), 'w'),
indent=2) indent=2)
return os.path.join(dir_name, f"{self.time}-nlg_eval.json") return os.path.join(dir_name, f"{self.time}-nlg_eval.json")
...@@ -232,8 +227,19 @@ class Evaluator: ...@@ -232,8 +227,19 @@ class Evaluator:
result[metric] = sum(scores[metric])/len(scores[metric]) result[metric] = sum(scores[metric])/len(scores[metric])
print(f"{metric}: {result[metric]}") print(f"{metric}: {result[metric]}")
# TODO no neutral # TODO no neutral
emo_score = emotion_score(golden_emotions, gen_emotions, self.model_checkpoint, emo_score = emotion_score(
time=self.time, no_neutral=True) golden_emotions,
gen_emotions,
self.model_checkpoint,
time=self.time,
no_neutral=False)
if self.use_sentiment:
sent_score = sentiment_score(
gen_file['dialog']["golden_sentiment"],
gen_file['dialog']["gen_sentiment"],
self.model_checkpoint,
time=self.time)
# for metric in emo_score: # for metric in emo_score:
# result[metric] = emo_score[metric] # result[metric] = emo_score[metric]
# print(f"{metric}: {result[metric]}") # print(f"{metric}: {result[metric]}")
...@@ -254,7 +260,8 @@ def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutra ...@@ -254,7 +260,8 @@ def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutra
macro_f1 = metrics.f1_score(golden_emotions, gen_emotions, average="macro") macro_f1 = metrics.f1_score(golden_emotions, gen_emotions, average="macro")
sep_f1 = metrics.f1_score( sep_f1 = metrics.f1_score(
golden_emotions, gen_emotions, average=None, labels=labels) golden_emotions, gen_emotions, average=None, labels=labels)
cm = metrics.confusion_matrix(golden_emotions, gen_emotions, normalize="true", labels=labels) cm = metrics.confusion_matrix(
golden_emotions, gen_emotions, normalize="true", labels=labels)
disp = metrics.ConfusionMatrixDisplay( disp = metrics.ConfusionMatrixDisplay(
confusion_matrix=cm, display_labels=labels) confusion_matrix=cm, display_labels=labels)
disp.plot() disp.plot()
...@@ -265,6 +272,26 @@ def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutra ...@@ -265,6 +272,26 @@ def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutra
return r return r
def sentiment_score(golden_sentiment, gen_sentiment, dirname=".", time=""):
    """Score generated sentiments against the golden ones and plot the result.

    Computes the macro-averaged F1, the per-label F1 and a normalized
    confusion matrix, saves the confusion-matrix plot to
    ``<dirname>/<time>-sentiment.png`` and prints the scores.

    Args:
        golden_sentiment: reference sentiment labels, one per turn.
        gen_sentiment: generated sentiment labels, aligned with
            ``golden_sentiment``.
        dirname: directory that receives the confusion-matrix image.
        time: prefix used in the image file name.

    Returns:
        A dict with ``"macro_f1"`` (float), ``"sep_f1"`` (per-label F1,
        in the order Neutral, Negative, Positive) and ``"cm"`` (the
        confusion matrix as a list of rows).
    """
    labels = ["Neutral", "Negative", "Positive"]
    print(labels)
    # NOTE(review): macro F1 is computed without ``labels=``, so any label
    # outside the three above that occurs in the data is included in the
    # average, while sep_f1 and cm are restricted to ``labels`` -- confirm
    # this asymmetry is intended.
    macro_f1 = metrics.f1_score(
        golden_sentiment, gen_sentiment, average="macro")
    sep_f1 = metrics.f1_score(
        golden_sentiment, gen_sentiment, average=None, labels=labels)
    cm = metrics.confusion_matrix(
        golden_sentiment, gen_sentiment, normalize="true", labels=labels)
    disp = metrics.ConfusionMatrixDisplay(
        confusion_matrix=cm, display_labels=labels)
    disp.plot()
    try:
        plt.savefig(os.path.join(dirname, f"{time}-sentiment.png"))
    finally:
        # disp.plot() opens a new figure on every call; close it so that
        # repeated evaluations do not accumulate open figures in memory.
        plt.close(disp.figure_)
    r = {"macro_f1": float(macro_f1), "sep_f1": list(
        sep_f1), "cm": [list(c) for c in list(cm)]}
    print(r)
    return r
def f1_measure(preds, labels): def f1_measure(preds, labels):
tp = 0 tp = 0
score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0} score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0}
......
...@@ -2,7 +2,7 @@ import json ...@@ -2,7 +2,7 @@ import json
class tokenMap: class tokenMap:
def __init__(self, tokenizer): def __init__(self, tokenizer, use_sentiment=False):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.token_name = {} self.token_name = {}
self.hash_map = {} self.hash_map = {}
...@@ -10,7 +10,6 @@ class tokenMap: ...@@ -10,7 +10,6 @@ class tokenMap:
self.default() self.default()
def default(self, only_action=False): def default(self, only_action=False):
# TODO
self.format_tokens = { self.format_tokens = {
'start_json': '{"emotion": "', # 49643, 10845, 7862, 646 'start_json': '{"emotion": "', # 49643, 10845, 7862, 646
'start_act': 'action": [["', # 49329 'start_act': 'action": [["', # 49329
...@@ -21,6 +20,10 @@ class tokenMap: ...@@ -21,6 +20,10 @@ class tokenMap:
'end_json': '}', # 24303 'end_json': '}', # 24303
'end_json_2': '"}' # 48805 'end_json_2': '"}' # 48805
} }
if self.use_sentiment:
self.format_tokens['start_json'] = '{"sentiment": "'
self.format_tokens['start_emotion'] = 'emotion": "'
if only_action: if only_action:
self.format_tokens['end_act'] = '"]]}' self.format_tokens['end_act'] = '"]]}'
for token_name in self.format_tokens: for token_name in self.format_tokens:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment