diff --git a/convlab/base_models/t5/key2gen/evaluate.py b/convlab/base_models/t5/key2gen/evaluate.py index 03a8b56254501595fcacf73d1c76bb050f93115e..7acb1118cc857d4cd3e1b401b1d8ecddab2288e9 100644 --- a/convlab/base_models/t5/key2gen/evaluate.py +++ b/convlab/base_models/t5/key2gen/evaluate.py @@ -4,6 +4,7 @@ import json from tqdm import tqdm from datasets import load_metric import numpy as np +import csv def evaluate(filename, metric): """ @@ -50,7 +51,8 @@ if __name__ == '__main__': args = parser.parse_args() print(args) - tables = [] + table = [] + fieldnames = [] for task_name in tqdm(args.tasks, desc='tasks'): metric = load_metric("metric.py", task_name) dataset_name = task_name if task_name != "nlg" else "multiwoz21" @@ -78,8 +80,17 @@ if __name__ == '__main__': "model": f"{model_name}", **avg_result(results) } - tables.append(res) - # print(res) - res = tabulate(tables, headers='keys', tablefmt='github') + table.append(res) + for k in res: + if k not in fieldnames: + fieldnames.append(k) + + res = tabulate(table, headers='keys', tablefmt='github') with open(f'eval_results.txt', 'w', encoding='utf-8') as f: print(res, file=f) + with open('eval_results.csv', 'w', newline='') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + writer.writeheader() + for res in table: + writer.writerow(res)