Commit ac553284 authored by Hsien-Chin Lin

wip: add a "no neutral" emotion mode to user action generation and NLG evaluation

parent 8749360e
@@ -106,7 +106,7 @@ class UserActionPolicy(GenTUSUserActionPolicy):
             print("-"*20)
         return action
 
-    def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True):
+    def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True, emotion_mode="max"):
         self.kg.parse_input(raw_inputs)
         model_input = self.vector.encode(raw_inputs, self.max_in_len)
         # start token
@@ -114,7 +114,7 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         pos = self._update_seq([0], 0)
         pos = self._update_seq(self.token_map.get_id('start_json'), pos)
         emotion = self._get_emotion(
-            model_input, self.seq[:1, :pos], mode, allow_general_intent)
+            model_input, self.seq[:1, :pos], mode, emotion_mode)
         pos = self._update_seq(emotion["token_id"], pos)
         pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
         pos = self._update_seq(self.token_map.get_id('start_act'), pos)
@@ -214,10 +214,10 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         raw_output = self._get_text(model_input, pos)
         return self._parse_output(raw_output)["text"]
 
-    def _get_emotion(self, model_input, generated_so_far, mode="max", allow_general_intent=True):
+    def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal"):
         next_token_logits = self.model.get_next_token_logits(
             model_input, generated_so_far)
-        return self.kg.get_emotion(next_token_logits, mode, allow_general_intent)
+        return self.kg.get_emotion(next_token_logits, mode, emotion_mode)
 
     def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True):
         next_token_logits = self.model.get_next_token_logits(
......
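The hunks above thread a new emotion_mode argument from _generate_action through _get_emotion into the knowledge graph, so that emotion decoding can ignore the neutral label. A minimal sketch of that call chain, using simplified stand-in classes and made-up emotion labels rather than the real GenTUS/EmoUS objects:

# Illustrative stand-ins only; the real classes build token sequences and query a
# seq2seq model, which is omitted here.

class SketchKnowledgeGraph:
    def __init__(self, emotions):
        # assumption: the first entry is "Neutral", mirroring self.emotion[1:] above
        self.emotion = emotions

    def get_emotion(self, logits, mode="max", emotion_mode="normal"):
        candidates = self.emotion
        if emotion_mode == "no_neutral":
            candidates = self.emotion[1:]
        # pick the candidate with the highest (toy) logit
        best = max(candidates, key=lambda e: logits.get(e, float("-inf")))
        return {"token_name": best}

class SketchPolicy:
    def __init__(self, kg):
        self.kg = kg

    def _get_emotion(self, logits, mode="max", emotion_mode="normal"):
        # mirrors UserActionPolicy._get_emotion: forward emotion_mode to the KG
        return self.kg.get_emotion(logits, mode, emotion_mode)

    def _generate_action(self, logits, mode="max", emotion_mode="max"):
        # mirrors the start of UserActionPolicy._generate_action: emotion is decoded first
        return self._get_emotion(logits, mode, emotion_mode)

kg = SketchKnowledgeGraph(["Neutral", "Satisfied", "Dissatisfied"])
policy = SketchPolicy(kg)
toy_logits = {"Neutral": 2.0, "Satisfied": 1.5, "Dissatisfied": 0.1}
print(policy._generate_action(toy_logits, emotion_mode="normal"))      # Neutral wins
print(policy._generate_action(toy_logits, emotion_mode="no_neutral"))  # Satisfied wins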
@@ -35,6 +35,8 @@ def arg_parser():
                         help="do nlg generation")
     parser.add_argument("--do-golden-nlg", action="store_true",
                         help="do golden nlg generation")
+    parser.add_argument("--no-neutral", action="store_true",
+                        help="skip neutral emotion")
     return parser.parse_args()
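argparse exposes the new --no-neutral switch as args.no_neutral (dashes become underscores) and defaults it to False. A small self-contained illustration with a stripped-down stand-in for arg_parser():

import argparse

def demo_parser(argv=None):
    # only the flags relevant to this change; not the script's full arg_parser()
    parser = argparse.ArgumentParser()
    parser.add_argument("--do-golden-nlg", action="store_true",
                        help="do golden nlg generation")
    parser.add_argument("--no-neutral", action="store_true",
                        help="skip neutral emotion")
    return parser.parse_args(argv)

args = demo_parser(["--no-neutral"])
print(args.no_neutral)     # True
print(args.do_golden_nlg)  # False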
@@ -53,7 +55,10 @@ class Evaluator:
             model_checkpoint, only_action=only_action, dataset=self.dataset)
         self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
 
-    def generate_results(self, f_eval, golden=False):
+    def generate_results(self, f_eval, golden=False, no_neutral=False):
+        emotion_mode = "max"
+        if no_neutral:
+            emotion_mode = "no_neutral"
         in_file = json.load(open(f_eval))
         r = {
             "input": [],
@@ -67,14 +72,20 @@ class Evaluator:
         for dialog in tqdm(in_file['dialog']):
             inputs = dialog["in"]
             labels = self.usr._parse_output(dialog["out"])
+            if no_neutral:
+                if labels["emotion"].lower() == "neutral":
+                    print("skip")
+                    continue
+                print("do", labels["emotion"])
             if golden:
                 usr_act = labels["action"]
                 usr_utt = self.usr.generate_text_from_give_semantic(
                     inputs, labels["action"], labels["emotion"])
             else:
                 output = self.usr._parse_output(
-                    self.usr._generate_action(inputs))
+                    self.usr._generate_action(inputs, emotion_mode=emotion_mode))
                 usr_emo = output["emotion"]
                 usr_act = self.usr._remove_illegal_action(output["action"])
                 usr_utt = output["text"]
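Because --no-neutral silently drops every turn whose gold emotion is Neutral, it can be worth checking how many turns survive the filter before a long generation run. A hedged sketch under the same JSON layout as above (a top-level "dialog" list with the serialized label under "out"); parse_output stands in for self.usr._parse_output:

import json
from collections import Counter

def emotion_histogram(f_eval, parse_output):
    # count gold emotion labels in the evaluation file
    data = json.load(open(f_eval))
    counts = Counter()
    for dialog in data["dialog"]:
        counts[parse_output(dialog["out"])["emotion"].lower()] += 1
    return counts

# usage (file name and parser are placeholders):
# counts = emotion_histogram("eval_set.json", usr._parse_output)
# kept = sum(n for emo, n in counts.items() if emo != "neutral")
# print(f"--no-neutral keeps {kept} of {sum(counts.values())} turns")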
@@ -106,10 +117,10 @@ class Evaluator:
         return r
 
-    def nlg_evaluation(self, input_file=None, generated_file=None, golden=False):
+    def nlg_evaluation(self, input_file=None, generated_file=None, golden=False, no_neutral=False):
         if input_file:
             print("Force generation")
-            gen_r = self.generate_results(input_file, golden)
+            gen_r = self.generate_results(input_file, golden, no_neutral)
         elif generated_file:
             gen_r = self.read_generated_result(generated_file)
@@ -284,7 +295,8 @@ def main():
     else:
         nlg_result = eval.nlg_evaluation(input_file=args.input_file,
                                          generated_file=args.generated_file,
-                                         golden=args.do_golden_nlg)
+                                         golden=args.do_golden_nlg,
+                                         no_neutral=args.no_neutral)
         generated_file = nlg_result
     eval.evaluation(args.input_file,
......
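End to end, the flag flows from the command line into nlg_evaluation and then into generate_results, where it is translated to an emotion_mode string. A tiny sketch of that translation together with the call chain, arguments abbreviated:

def resolve_emotion_mode(no_neutral: bool) -> str:
    # mirrors the first lines of generate_results() above
    return "no_neutral" if no_neutral else "max"

# call chain added in this commit:
#   nlg_evaluation(input_file=..., golden=..., no_neutral=args.no_neutral)
#     -> generate_results(f_eval, golden, no_neutral)
#       -> usr._generate_action(inputs, emotion_mode=resolve_emotion_mode(no_neutral))

assert resolve_emotion_mode(True) == "no_neutral"
assert resolve_emotion_mode(False) == "max"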
@@ -32,10 +32,16 @@ class KnowledgeGraph(GenTUSKnowledgeGraph):
         for emotion in self.emotion:
             self.kg_map["emotion"].add_token(emotion, emotion)
 
-    def get_emotion(self, outputs, mode="max", allow_general_intent=True):
-        canidate_list = self.emotion
-        score = self._get_max_score(
-            outputs, canidate_list, "emotion", weight=self.prior)
+    def get_emotion(self, outputs, mode="max", emotion_mode="normal"):
+        if emotion_mode == "normal":
+            score = self._get_max_score(
+                outputs, self.emotion, "emotion", weight=self.prior)
+        elif emotion_mode == "no_neutral":
+            score = self._get_max_score(
+                outputs, self.emotion[1:], "emotion", weight=self.prior)
+        else:
+            print(f"unknown emotion mode: {emotion_mode}")
         s = self._select(score, mode)
         return score[s]
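Two assumptions in get_emotion are worth noting: self.emotion[1:] only drops Neutral if Neutral sits at index 0 of the emotion list, and any emotion_mode other than "normal" or "no_neutral" (including the "max" default that generate_results passes when --no-neutral is absent) falls through to the else branch without assigning score. A defensive variant of the candidate selection, sketched as an illustration rather than the project's code:

def select_emotion_candidates(emotions, emotion_mode="normal"):
    # emotions: full candidate list, e.g. ["Neutral", "Satisfied", ...]
    if emotion_mode == "no_neutral":
        # filter by name instead of relying on Neutral being first in the list
        return [e for e in emotions if e.lower() != "neutral"]
    if emotion_mode != "normal":
        print(f"unknown emotion mode: {emotion_mode}, using all emotions")
    return list(emotions)

print(select_emotion_candidates(["Neutral", "Satisfied", "Dissatisfied"], "no_neutral"))
# ['Satisfied', 'Dissatisfied']
print(select_emotion_candidates(["Satisfied", "Neutral"], "no_neutral"))
# ['Satisfied']  (works even if Neutral is not at index 0)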