Skip to content
Snippets Groups Projects
Commit f11381b9 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

wip

parent 32c3731a
Branches
No related tags found
No related merge requests found
......@@ -34,6 +34,7 @@ def arg_parser():
parser.add_argument("--use-sentiment", action="store_true")
parser.add_argument("--emotion-mid", action="store_true")
parser.add_argument("--weight", type=float, default=None)
parser.add_argument("--sample", action="store_true")
return parser.parse_args()
......@@ -47,6 +48,7 @@ class Evaluator:
self.add_persona = kwargs.get("add_persona", False)
self.emotion_mid = kwargs.get("emotion_mid", False)
weight = kwargs.get("weight", None)
self.sample = kwargs.get("sample", False)
self.usr = UserActionPolicy(
model_checkpoint,
......@@ -95,8 +97,11 @@ class Evaluator:
inputs, labels["action"], labels["emotion"])
else:
mode = "max"
if self.sample:
mode = "sample"
output = self.usr._parse_output(
self.usr._generate_action(inputs, emotion_mode=emotion_mode))
self.usr._generate_action(inputs, mode=mode, emotion_mode=emotion_mode))
usr_emo = output["emotion"]
usr_act = self.usr._remove_illegal_action(output["action"])
usr_utt = output["text"]
......@@ -143,9 +148,13 @@ class Evaluator:
self.read_generated_result(generated_file)
else:
print("You must specify the input_file or the generated_file")
mode = "max"
if self.sample:
mode = "sample"
nlg_eval = {
"golden": golden,
"mode": mode,
"metrics": {},
"dialog": self._transform_result()
}
......@@ -336,7 +345,8 @@ def main():
args.model_weight,
use_sentiment=args.use_sentiment,
emotion_mid=args.emotion_mid,
weight=args.weight)
weight=args.weight,
sample=args.sample)
print("model checkpoint", args.model_checkpoint)
print("generated_file", args.generated_file)
print("input_file", args.input_file)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment