From d8f059498b7f1196c8ee79277fc7961a97df6e26 Mon Sep 17 00:00:00 2001 From: Hsien-Chin Lin <linh@hhu.de> Date: Mon, 17 Oct 2022 14:46:05 +0200 Subject: [PATCH] wip --- convlab/policy/genTUS/evaluate.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/convlab/policy/genTUS/evaluate.py b/convlab/policy/genTUS/evaluate.py index ac9f49d7..82779afe 100644 --- a/convlab/policy/genTUS/evaluate.py +++ b/convlab/policy/genTUS/evaluate.py @@ -145,6 +145,7 @@ class Evaluator: json.dump(nlg_eval, open(os.path.join(dir_name, "nlg_eval.json"), 'w'), indent=2) + return os.path.join(dir_name, "nlg_eval.json") def evaluation(self, input_file=None, generated_file=None): force_prediction = True @@ -241,10 +242,15 @@ def main(): if args.do_semantic: eval.evaluation(args.input_file) if args.do_nlg: - eval.nlg_evaluation(input_file=args.input_file, - generated_file=args.generated_file, - golden=args.do_golden_nlg) - eval.evaluation(args.input_file) + nlg_result = eval.nlg_evaluation(input_file=args.input_file, + generated_file=args.generated_file, + golden=args.do_golden_nlg) + if args.generated_file: + generated_file = args.generated_file + else: + generated_file = nlg_result + eval.evaluation(args.input_file, + generated_file) if __name__ == '__main__': -- GitLab