diff --git a/convlab/base_models/t5/key2gen/evaluate.py b/convlab/base_models/t5/key2gen/evaluate.py index 7144223111222a55d8253fd63b6ef8f092e923c1..03a8b56254501595fcacf73d1c76bb050f93115e 100644 --- a/convlab/base_models/t5/key2gen/evaluate.py +++ b/convlab/base_models/t5/key2gen/evaluate.py @@ -54,8 +54,8 @@ if __name__ == '__main__': for task_name in tqdm(args.tasks, desc='tasks'): metric = load_metric("metric.py", task_name) dataset_name = task_name if task_name != "nlg" else "multiwoz21" - for shot in tqdm(args.shots, desc='shots'): - for output_dir in tqdm(args.output_dirs, desc='models'): + for shot in tqdm(args.shots, desc='shots', leave=False): + for output_dir in tqdm(args.output_dirs, desc='models', leave=False): model_name = output_dir.split('/')[-1] if task_name == "wow": test_splits = ["_seen", "_unseen"] @@ -63,9 +63,16 @@ if __name__ == '__main__': test_splits = [""] for test_split in test_splits: results = [] - for dial_ids_order in tqdm(args.dial_ids_orders, desc='dial_ids_orders'): - filename = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen{test_split}/generated_predictions.json") - results.append(evaluate(filename, metric)) + for dial_ids_order in tqdm(args.dial_ids_orders, desc='dial_ids_orders', leave=False): + result_dir = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen{test_split}") + result_file = os.path.join(result_dir, "result.json") + if not os.path.exists(result_file): + filename = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen{test_split}/generated_predictions.json") + result = evaluate(filename, metric) + json.dump(result, open(result_file, 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + else: + result = json.load(open(result_file)) + results.append(result) res = { "dataset": f"{task_name}{test_split}-{shot}shot", "model": f"{model_name}", @@ -74,5 +81,5 @@ if __name__ == '__main__': tables.append(res) # print(res) res = tabulate(tables, headers='keys', tablefmt='github') - with open(f'eval_results.txt', 'a+', encoding='utf-8') as f: + with open(f'eval_results.txt', 'w', encoding='utf-8') as f: print(res, file=f)