Skip to content
Snippets Groups Projects
Commit 269b24cb authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

wip

parent f92cc630
Branches
No related tags found
No related merge requests found
......@@ -25,6 +25,8 @@ class UserActionPolicy(GenTUSUserActionPolicy):
for emotion, index in data_emotion.items():
self.emotion_list[index] = emotion
self.use_sentiment = kwargs.get("use_sentiment", False)
self.init_session()
def predict(self, sys_act, mode="max", allow_general_intent=True, emotion=None):
......@@ -46,7 +48,14 @@ class UserActionPolicy(GenTUSUserActionPolicy):
else:
history = self.usr_acts[-1*self.max_history:]
# TODO add user info? impolite?
# TODO add user info? impolite? -> check self.use_sentiment
if self.use_sentiment:
# TODO how to get event and user politeness?
inputs = json.dumps({"system": sys_act,
"goal": self.goal.get_goal_list(),
"history": history,
"turn": str(int(self.time_step/2))})
else:
inputs = json.dumps({"system": sys_act,
"goal": self.goal.get_goal_list(),
"history": history,
......@@ -101,6 +110,9 @@ class UserActionPolicy(GenTUSUserActionPolicy):
in_str = in_str.replace('<s>', '').replace(
'<\\s>', '').replace('o"clock', "o'clock")
action = {"emotion": "Neutral", "action": [], "text": ""}
if self.use_sentiment:
action["sentiment"] = "Neutral"
try:
action = json.loads(in_str)
except:
......@@ -115,6 +127,11 @@ class UserActionPolicy(GenTUSUserActionPolicy):
self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
pos = self._update_seq([0], 0)
pos = self._update_seq(self.token_map.get_id('start_json'), pos)
if self.use_sentiment:
sentiment = self._get_sentiment(
model_input, self.seq[:1, :pos], mode)
pos = self._update_seq(sentiment["token_id"], pos)
else:
emotion = self._get_emotion(
model_input, self.seq[:1, :pos], mode, emotion_mode)
pos = self._update_seq(emotion["token_id"], pos)
......@@ -136,6 +153,13 @@ class UserActionPolicy(GenTUSUserActionPolicy):
if self.only_action:
return self.vector.decode(self.seq[0, :pos])
if self.use_sentiment:
pos = self._update_seq(self.token_map.get_id('start_emotion'), pos)
emotion = self._get_emotion(
model_input, self.seq[:1, :pos], mode, emotion_mode)
pos = self._update_seq(emotion["token_id"], pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
pos = self._update_seq(self.token_map.get_id("start_text"), pos)
text = self._get_text(model_input, pos)
......@@ -216,6 +240,11 @@ class UserActionPolicy(GenTUSUserActionPolicy):
raw_output = self._get_text(model_input, pos)
return self._parse_output(raw_output)["text"]
def _get_sentiment(self, model_input, generated_so_far, mode="max"):
next_token_logits = self.model.get_next_token_logits(
model_input, generated_so_far)
return self.kg.get_sentiment(next_token_logits, mode)
def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal"):
next_token_logits = self.model.get_next_token_logits(
model_input, generated_so_far)
......
......@@ -42,117 +42,111 @@ def arg_parser():
class Evaluator:
def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False):
def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False, use_sentiment=False):
self.dataset = dataset
self.model_checkpoint = model_checkpoint
self.model_weight = model_weight
self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
# if model_weight:
# self.usr_policy = UserPolicy(
# self.model_checkpoint, only_action=only_action)
# self.usr_policy.load(model_weight)
# self.usr = self.usr_policy.usr
# else:
self.use_sentiment = use_sentiment
self.usr = UserActionPolicy(
model_checkpoint, only_action=only_action, dataset=self.dataset)
model_checkpoint, only_action=only_action, dataset=self.dataset, use_sentiment=use_sentiment)
self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
def generate_results(self, f_eval, golden=False, no_neutral=False):
emotion_mode = "max"
if no_neutral:
emotion_mode = "no_neutral"
in_file = json.load(open(f_eval))
r = {
"input": [],
self.r = {"input": [],
"golden_acts": [],
"golden_utts": [],
"golden_emotion": [],
"gen_acts": [],
"gen_utts": [],
"gen_emotion": []
}
"gen_emotion": []}
if use_sentiment:
self.r["golden_sentiment"] = []
self.r["gen_sentiment"] = []
def _append_result(self, temp):
for x in self.r:
self.r[x].append(temp[x])
def generate_results(self, f_eval, golden=False, no_neutral=False):
emotion_mode = "max"
if no_neutral:
emotion_mode = "no_neutral"
in_file = json.load(open(f_eval))
for dialog in tqdm(in_file['dialog']):
inputs = dialog["in"]
labels = self.usr._parse_output(dialog["out"])
if no_neutral:
if labels["emotion"].lower() == "neutral":
print("skip")
if no_neutral and labels["emotion"].lower() == "neutral":
continue
print("do", labels["emotion"])
if golden:
usr_act = labels["action"]
usr_utt = self.usr.generate_text_from_give_semantic(
inputs, labels["action"], labels["emotion"])
else:
output = self.usr._parse_output(
self.usr._generate_action(inputs, emotion_mode=emotion_mode))
usr_emo = output["emotion"]
usr_act = self.usr._remove_illegal_action(output["action"])
usr_utt = output["text"]
r["input"].append(inputs)
r["golden_acts"].append(labels["action"])
r["golden_utts"].append(labels["text"])
r["golden_emotion"].append(labels["emotion"])
r["gen_acts"].append(usr_act)
r["gen_utts"].append(usr_utt)
r["gen_emotion"].append(usr_emo)
temp = {}
temp["input"] = inputs
temp["golden_acts"] = labels["action"]
temp["golden_utts"] = labels["text"]
temp["golden_emotion"] = labels["emotion"]
return r
temp["gen_acts"] = usr_act
temp["gen_utts"] = usr_utt
temp["gen_emotion"] = usr_emo
if self.use_sentiment:
temp["golden_sentiment"] = labels["sentiment"]
temp["gen_sentiment"] = output["sentiment"]
self._append_result(temp)
def read_generated_result(self, f_eval):
in_file = json.load(open(f_eval))
r = {
"input": [],
"golden_acts": [],
"golden_utts": [],
"golden_emotion": [],
"gen_acts": [],
"gen_utts": [],
"gen_emotion": []
}
for dialog in tqdm(in_file['dialog']):
for x in dialog:
r[x].append(dialog[x])
return r
self.r[x].append(dialog[x])
def _transform_result(self):
index = [x for x in self.r]
result = []
for i in range(len(self.r[index[0]])):
temp = {}
for x in index:
temp[x] = self.r[x][i]
result.append(temp)
return result
def nlg_evaluation(self, input_file=None, generated_file=None, golden=False, no_neutral=False):
if input_file:
print("Force generation")
gen_r = self.generate_results(input_file, golden, no_neutral)
self.generate_results(input_file, golden, no_neutral)
elif generated_file:
gen_r = self.read_generated_result(generated_file)
self.read_generated_result(generated_file)
else:
print("You must specify the input_file or the generated_file")
nlg_eval = {
"golden": golden,
"metrics": {},
"dialog": []
"dialog": self._transform_result()
}
for input, golden_act, golden_utt, golden_emo, gen_act, gen_utt, gen_emo in zip(
gen_r["input"], gen_r["golden_acts"], gen_r["golden_utts"], gen_r["golden_emotion"],
gen_r["gen_acts"], gen_r["gen_utts"], gen_r["gen_emotion"]):
nlg_eval["dialog"].append({
"input": input,
"golden_acts": golden_act,
"golden_utts": golden_utt,
"golden_emotion": golden_emo,
"gen_acts": gen_act,
"gen_utts": gen_utt,
"gen_emotion": gen_emo
})
if golden:
print("Calculate BLEU")
bleu_metric = load_metric("sacrebleu")
labels = [[utt] for utt in gen_r["golden_utts"]]
labels = [[utt] for utt in self.r["golden_utts"]]
bleu_score = bleu_metric.compute(predictions=gen_r["gen_utts"],
bleu_score = bleu_metric.compute(predictions=self.r["gen_utts"],
references=labels,
force=True)
print("bleu_metric", bleu_score)
......@@ -161,7 +155,7 @@ class Evaluator:
else:
print("Calculate SER")
missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
gen_r["gen_acts"], gen_r["gen_utts"])
self.r["gen_acts"], self.r["gen_utts"])
print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
"genTUSNLG", missing, total, hallucinate, missing/total))
......@@ -171,7 +165,8 @@ class Evaluator:
dir_name = self.model_checkpoint
json.dump(nlg_eval,
open(os.path.join(dir_name, f"{self.time}-nlg_eval.json"), 'w'),
open(os.path.join(
dir_name, f"{self.time}-nlg_eval.json"), 'w'),
indent=2)
return os.path.join(dir_name, f"{self.time}-nlg_eval.json")
......@@ -232,8 +227,19 @@ class Evaluator:
result[metric] = sum(scores[metric])/len(scores[metric])
print(f"{metric}: {result[metric]}")
# TODO no neutral
emo_score = emotion_score(golden_emotions, gen_emotions, self.model_checkpoint,
time=self.time, no_neutral=True)
emo_score = emotion_score(
golden_emotions,
gen_emotions,
self.model_checkpoint,
time=self.time,
no_neutral=False)
if self.use_sentiment:
sent_score = sentiment_score(
gen_file['dialog']["golden_sentiment"],
gen_file['dialog']["gen_sentiment"],
self.model_checkpoint,
time=self.time)
# for metric in emo_score:
# result[metric] = emo_score[metric]
# print(f"{metric}: {result[metric]}")
......@@ -254,7 +260,8 @@ def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutra
macro_f1 = metrics.f1_score(golden_emotions, gen_emotions, average="macro")
sep_f1 = metrics.f1_score(
golden_emotions, gen_emotions, average=None, labels=labels)
cm = metrics.confusion_matrix(golden_emotions, gen_emotions, normalize="true", labels=labels)
cm = metrics.confusion_matrix(
golden_emotions, gen_emotions, normalize="true", labels=labels)
disp = metrics.ConfusionMatrixDisplay(
confusion_matrix=cm, display_labels=labels)
disp.plot()
......@@ -265,6 +272,26 @@ def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutra
return r
def sentiment_score(golden_sentiment, gen_sentiment, dirname=".", time=""):
labels = ["Neutral", "Negative", "Positive"]
print(labels)
macro_f1 = metrics.f1_score(
golden_sentiment, gen_sentiment, average="macro")
sep_f1 = metrics.f1_score(
golden_sentiment, gen_sentiment, average=None, labels=labels)
cm = metrics.confusion_matrix(
golden_sentiment, gen_sentiment, normalize="true", labels=labels)
disp = metrics.ConfusionMatrixDisplay(
confusion_matrix=cm, display_labels=labels)
disp.plot()
plt.savefig(os.path.join(dirname, f"{time}-sentiment.png"))
r = {"macro_f1": float(macro_f1), "sep_f1": list(
sep_f1), "cm": [list(c) for c in list(cm)]}
print(r)
return r
def f1_measure(preds, labels):
tp = 0
score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0}
......
......@@ -2,7 +2,7 @@ import json
class tokenMap:
def __init__(self, tokenizer):
def __init__(self, tokenizer, use_sentiment=False):
self.tokenizer = tokenizer
self.token_name = {}
self.hash_map = {}
......@@ -10,7 +10,6 @@ class tokenMap:
self.default()
def default(self, only_action=False):
# TODO
self.format_tokens = {
'start_json': '{"emotion": "', # 49643, 10845, 7862, 646
'start_act': 'action": [["', # 49329
......@@ -21,6 +20,10 @@ class tokenMap:
'end_json': '}', # 24303
'end_json_2': '"}' # 48805
}
if self.use_sentiment:
self.format_tokens['start_json'] = '{"sentiment": "'
self.format_tokens['start_emotion'] = 'emotion": "'
if only_action:
self.format_tokens['end_act'] = '"]]}'
for token_name in self.format_tokens:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment