Skip to content
Snippets Groups Projects
Commit f11381b9 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

wip

parent 32c3731a
No related branches found
No related tags found
No related merge requests found
...@@ -34,6 +34,7 @@ def arg_parser(): ...@@ -34,6 +34,7 @@ def arg_parser():
parser.add_argument("--use-sentiment", action="store_true") parser.add_argument("--use-sentiment", action="store_true")
parser.add_argument("--emotion-mid", action="store_true") parser.add_argument("--emotion-mid", action="store_true")
parser.add_argument("--weight", type=float, default=None) parser.add_argument("--weight", type=float, default=None)
parser.add_argument("--sample", action="store_true")
return parser.parse_args() return parser.parse_args()
...@@ -47,6 +48,7 @@ class Evaluator: ...@@ -47,6 +48,7 @@ class Evaluator:
self.add_persona = kwargs.get("add_persona", False) self.add_persona = kwargs.get("add_persona", False)
self.emotion_mid = kwargs.get("emotion_mid", False) self.emotion_mid = kwargs.get("emotion_mid", False)
weight = kwargs.get("weight", None) weight = kwargs.get("weight", None)
self.sample = kwargs.get("sample", False)
self.usr = UserActionPolicy( self.usr = UserActionPolicy(
model_checkpoint, model_checkpoint,
...@@ -95,8 +97,11 @@ class Evaluator: ...@@ -95,8 +97,11 @@ class Evaluator:
inputs, labels["action"], labels["emotion"]) inputs, labels["action"], labels["emotion"])
else: else:
mode = "max"
if self.sample:
mode = "sample"
output = self.usr._parse_output( output = self.usr._parse_output(
self.usr._generate_action(inputs, emotion_mode=emotion_mode)) self.usr._generate_action(inputs, mode=mode, emotion_mode=emotion_mode))
usr_emo = output["emotion"] usr_emo = output["emotion"]
usr_act = self.usr._remove_illegal_action(output["action"]) usr_act = self.usr._remove_illegal_action(output["action"])
usr_utt = output["text"] usr_utt = output["text"]
...@@ -143,9 +148,13 @@ class Evaluator: ...@@ -143,9 +148,13 @@ class Evaluator:
self.read_generated_result(generated_file) self.read_generated_result(generated_file)
else: else:
print("You must specify the input_file or the generated_file") print("You must specify the input_file or the generated_file")
mode = "max"
if self.sample:
mode = "sample"
nlg_eval = { nlg_eval = {
"golden": golden, "golden": golden,
"mode": mode,
"metrics": {}, "metrics": {},
"dialog": self._transform_result() "dialog": self._transform_result()
} }
...@@ -336,7 +345,8 @@ def main(): ...@@ -336,7 +345,8 @@ def main():
args.model_weight, args.model_weight,
use_sentiment=args.use_sentiment, use_sentiment=args.use_sentiment,
emotion_mid=args.emotion_mid, emotion_mid=args.emotion_mid,
weight=args.weight) weight=args.weight,
sample=args.sample)
print("model checkpoint", args.model_checkpoint) print("model checkpoint", args.model_checkpoint)
print("generated_file", args.generated_file) print("generated_file", args.generated_file)
print("input_file", args.input_file) print("input_file", args.input_file)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment