Skip to content
Snippets Groups Projects
Commit ac553284 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

wip

parent 8749360e
No related branches found
No related tags found
No related merge requests found
...@@ -106,7 +106,7 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -106,7 +106,7 @@ class UserActionPolicy(GenTUSUserActionPolicy):
print("-"*20) print("-"*20)
return action return action
def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True): def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True, emotion_mode="max"):
self.kg.parse_input(raw_inputs) self.kg.parse_input(raw_inputs)
model_input = self.vector.encode(raw_inputs, self.max_in_len) model_input = self.vector.encode(raw_inputs, self.max_in_len)
# start token # start token
...@@ -114,7 +114,7 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -114,7 +114,7 @@ class UserActionPolicy(GenTUSUserActionPolicy):
pos = self._update_seq([0], 0) pos = self._update_seq([0], 0)
pos = self._update_seq(self.token_map.get_id('start_json'), pos) pos = self._update_seq(self.token_map.get_id('start_json'), pos)
emotion = self._get_emotion( emotion = self._get_emotion(
model_input, self.seq[:1, :pos], mode, allow_general_intent) model_input, self.seq[:1, :pos], mode, emotion_mode)
pos = self._update_seq(emotion["token_id"], pos) pos = self._update_seq(emotion["token_id"], pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos) pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
pos = self._update_seq(self.token_map.get_id('start_act'), pos) pos = self._update_seq(self.token_map.get_id('start_act'), pos)
...@@ -214,10 +214,10 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -214,10 +214,10 @@ class UserActionPolicy(GenTUSUserActionPolicy):
raw_output = self._get_text(model_input, pos) raw_output = self._get_text(model_input, pos)
return self._parse_output(raw_output)["text"] return self._parse_output(raw_output)["text"]
def _get_emotion(self, model_input, generated_so_far, mode="max", allow_general_intent=True): def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal"):
next_token_logits = self.model.get_next_token_logits( next_token_logits = self.model.get_next_token_logits(
model_input, generated_so_far) model_input, generated_so_far)
return self.kg.get_emotion(next_token_logits, mode, allow_general_intent) return self.kg.get_emotion(next_token_logits, mode, emotion_mode)
def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True): def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True):
next_token_logits = self.model.get_next_token_logits( next_token_logits = self.model.get_next_token_logits(
......
...@@ -35,6 +35,8 @@ def arg_parser(): ...@@ -35,6 +35,8 @@ def arg_parser():
help="do nlg generation") help="do nlg generation")
parser.add_argument("--do-golden-nlg", action="store_true", parser.add_argument("--do-golden-nlg", action="store_true",
help="do golden nlg generation") help="do golden nlg generation")
parser.add_argument("--no-neutral", action="store_true",
help="skip neutral emotion")
return parser.parse_args() return parser.parse_args()
...@@ -53,7 +55,10 @@ class Evaluator: ...@@ -53,7 +55,10 @@ class Evaluator:
model_checkpoint, only_action=only_action, dataset=self.dataset) model_checkpoint, only_action=only_action, dataset=self.dataset)
self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin")) self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
def generate_results(self, f_eval, golden=False): def generate_results(self, f_eval, golden=False, no_neutral=False):
emotion_mode = "max"
if no_neutral:
emotion_mode = "no_neutral"
in_file = json.load(open(f_eval)) in_file = json.load(open(f_eval))
r = { r = {
"input": [], "input": [],
...@@ -67,14 +72,20 @@ class Evaluator: ...@@ -67,14 +72,20 @@ class Evaluator:
for dialog in tqdm(in_file['dialog']): for dialog in tqdm(in_file['dialog']):
inputs = dialog["in"] inputs = dialog["in"]
labels = self.usr._parse_output(dialog["out"]) labels = self.usr._parse_output(dialog["out"])
if no_neutral:
if labels["emotion"].lower() == "neutral":
print("skip")
continue
print("do", labels["emotion"])
if golden: if golden:
usr_act = labels["action"] usr_act = labels["action"]
usr_utt = self.usr.generate_text_from_give_semantic( usr_utt = self.usr.generate_text_from_give_semantic(
inputs, labels["action"], labels["emotion"]) inputs, labels["action"], labels["emotion"])
else: else:
output = self.usr._parse_output( output = self.usr._parse_output(
self.usr._generate_action(inputs)) self.usr._generate_action(inputs, emotion_mode=emotion_mode))
usr_emo = output["emotion"] usr_emo = output["emotion"]
usr_act = self.usr._remove_illegal_action(output["action"]) usr_act = self.usr._remove_illegal_action(output["action"])
usr_utt = output["text"] usr_utt = output["text"]
...@@ -106,10 +117,10 @@ class Evaluator: ...@@ -106,10 +117,10 @@ class Evaluator:
return r return r
def nlg_evaluation(self, input_file=None, generated_file=None, golden=False): def nlg_evaluation(self, input_file=None, generated_file=None, golden=False, no_neutral=False):
if input_file: if input_file:
print("Force generation") print("Force generation")
gen_r = self.generate_results(input_file, golden) gen_r = self.generate_results(input_file, golden, no_neutral)
elif generated_file: elif generated_file:
gen_r = self.read_generated_result(generated_file) gen_r = self.read_generated_result(generated_file)
...@@ -284,7 +295,8 @@ def main(): ...@@ -284,7 +295,8 @@ def main():
else: else:
nlg_result = eval.nlg_evaluation(input_file=args.input_file, nlg_result = eval.nlg_evaluation(input_file=args.input_file,
generated_file=args.generated_file, generated_file=args.generated_file,
golden=args.do_golden_nlg) golden=args.do_golden_nlg,
no_neutral=args.no_neutral)
generated_file = nlg_result generated_file = nlg_result
eval.evaluation(args.input_file, eval.evaluation(args.input_file,
......
...@@ -32,10 +32,16 @@ class KnowledgeGraph(GenTUSKnowledgeGraph): ...@@ -32,10 +32,16 @@ class KnowledgeGraph(GenTUSKnowledgeGraph):
for emotion in self.emotion: for emotion in self.emotion:
self.kg_map["emotion"].add_token(emotion, emotion) self.kg_map["emotion"].add_token(emotion, emotion)
def get_emotion(self, outputs, mode="max", allow_general_intent=True): def get_emotion(self, outputs, mode="max", emotion_mode="normal"):
canidate_list = self.emotion
if emotion_mode == "normal":
score = self._get_max_score(
outputs, self.emotion, "emotion", weight=self.prior)
elif emotion_mode == "no_neutral":
score = self._get_max_score( score = self._get_max_score(
outputs, canidate_list, "emotion", weight=self.prior) outputs, self.emotion[1:], "emotion", weight=self.prior)
else:
print(f"unknown emotion mode: {emotion_mode}")
s = self._select(score, mode) s = self._select(score, mode)
return score[s] return score[s]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment