diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py index 2204c26dcfbdcf391e4c403719c75807a76afd29..358a4828c7738c4544bdffbfa7108da72a6b08d1 100644 --- a/convlab/policy/emoTUS/evaluate.py +++ b/convlab/policy/emoTUS/evaluate.py @@ -34,6 +34,7 @@ def arg_parser(): parser.add_argument("--use-sentiment", action="store_true") parser.add_argument("--emotion-mid", action="store_true") parser.add_argument("--weight", type=float, default=None) + parser.add_argument("--sample", action="store_true") return parser.parse_args() @@ -47,6 +48,7 @@ class Evaluator: self.add_persona = kwargs.get("add_persona", False) self.emotion_mid = kwargs.get("emotion_mid", False) weight = kwargs.get("weight", None) + self.sample = kwargs.get("sample", False) self.usr = UserActionPolicy( model_checkpoint, @@ -95,8 +97,11 @@ class Evaluator: inputs, labels["action"], labels["emotion"]) else: + mode = "max" + if self.sample: + mode = "sample" output = self.usr._parse_output( - self.usr._generate_action(inputs, emotion_mode=emotion_mode)) + self.usr._generate_action(inputs, mode=mode, emotion_mode=emotion_mode)) usr_emo = output["emotion"] usr_act = self.usr._remove_illegal_action(output["action"]) usr_utt = output["text"] @@ -143,9 +148,13 @@ class Evaluator: self.read_generated_result(generated_file) else: print("You must specify the input_file or the generated_file") + mode = "max" + if self.sample: + mode = "sample" nlg_eval = { "golden": golden, + "mode": mode, "metrics": {}, "dialog": self._transform_result() } @@ -336,7 +345,8 @@ def main(): args.model_weight, use_sentiment=args.use_sentiment, emotion_mid=args.emotion_mid, - weight=args.weight) + weight=args.weight, + sample=args.sample) print("model checkpoint", args.model_checkpoint) print("generated_file", args.generated_file) print("input_file", args.input_file)