From 700b6017ad71ef5f9076a756cf1bfdb6d8fdf6df Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Mon, 25 Jul 2022 11:49:46 +0800 Subject: [PATCH] write finetune result to csv --- convlab/base_models/t5/key2gen/evaluate.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/convlab/base_models/t5/key2gen/evaluate.py b/convlab/base_models/t5/key2gen/evaluate.py index 03a8b562..7acb1118 100644 --- a/convlab/base_models/t5/key2gen/evaluate.py +++ b/convlab/base_models/t5/key2gen/evaluate.py @@ -4,6 +4,7 @@ import json from tqdm import tqdm from datasets import load_metric import numpy as np +import csv def evaluate(filename, metric): """ @@ -50,7 +51,8 @@ if __name__ == '__main__': args = parser.parse_args() print(args) - tables = [] + table = [] + fieldnames = [] for task_name in tqdm(args.tasks, desc='tasks'): metric = load_metric("metric.py", task_name) dataset_name = task_name if task_name != "nlg" else "multiwoz21" @@ -78,8 +80,17 @@ if __name__ == '__main__': "model": f"{model_name}", **avg_result(results) } - tables.append(res) - # print(res) - res = tabulate(tables, headers='keys', tablefmt='github') + table.append(res) + for k in res: + if k not in fieldnames: + fieldnames.append(k) + + res = tabulate(table, headers='keys', tablefmt='github') with open(f'eval_results.txt', 'w', encoding='utf-8') as f: print(res, file=f) + with open('eval_results.csv', 'w', newline='') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + writer.writeheader() + for res in table: + writer.writerow(res) -- GitLab