From f11381b938707013ae064521a6bdca5f17067202 Mon Sep 17 00:00:00 2001 From: Hsien-Chin Lin <linh@hhu.de> Date: Tue, 24 Jan 2023 02:16:01 +0100 Subject: [PATCH] wip --- convlab/policy/emoTUS/evaluate.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py index 2204c26d..358a4828 100644 --- a/convlab/policy/emoTUS/evaluate.py +++ b/convlab/policy/emoTUS/evaluate.py @@ -34,6 +34,7 @@ def arg_parser(): parser.add_argument("--use-sentiment", action="store_true") parser.add_argument("--emotion-mid", action="store_true") parser.add_argument("--weight", type=float, default=None) + parser.add_argument("--sample", action="store_true") return parser.parse_args() @@ -47,6 +48,7 @@ class Evaluator: self.add_persona = kwargs.get("add_persona", False) self.emotion_mid = kwargs.get("emotion_mid", False) weight = kwargs.get("weight", None) + self.sample = kwargs.get("sample", False) self.usr = UserActionPolicy( model_checkpoint, @@ -95,8 +97,11 @@ class Evaluator: inputs, labels["action"], labels["emotion"]) else: + mode = "max" + if self.sample: + mode = "sample" output = self.usr._parse_output( - self.usr._generate_action(inputs, emotion_mode=emotion_mode)) + self.usr._generate_action(inputs, mode=mode, emotion_mode=emotion_mode)) usr_emo = output["emotion"] usr_act = self.usr._remove_illegal_action(output["action"]) usr_utt = output["text"] @@ -143,9 +148,13 @@ class Evaluator: self.read_generated_result(generated_file) else: print("You must specify the input_file or the generated_file") + mode = "max" + if self.sample: + mode = "sample" nlg_eval = { "golden": golden, + "mode": mode, "metrics": {}, "dialog": self._transform_result() } @@ -336,7 +345,8 @@ def main(): args.model_weight, use_sentiment=args.use_sentiment, emotion_mid=args.emotion_mid, - weight=args.weight) + weight=args.weight, + sample=args.sample) print("model checkpoint", args.model_checkpoint) print("generated_file", args.generated_file) print("input_file", args.input_file) -- GitLab