Skip to content
Snippets Groups Projects
Commit 700b6017 authored by zqwerty's avatar zqwerty
Browse files

write finetune result to csv

parent b0d0851e
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@ import json
from tqdm import tqdm
from datasets import load_metric
import numpy as np
import csv
def evaluate(filename, metric):
"""
......@@ -50,7 +51,8 @@ if __name__ == '__main__':
args = parser.parse_args()
print(args)
tables = []
table = []
fieldnames = []
for task_name in tqdm(args.tasks, desc='tasks'):
metric = load_metric("metric.py", task_name)
dataset_name = task_name if task_name != "nlg" else "multiwoz21"
......@@ -78,8 +80,17 @@ if __name__ == '__main__':
"model": f"{model_name}",
**avg_result(results)
}
tables.append(res)
# print(res)
res = tabulate(tables, headers='keys', tablefmt='github')
table.append(res)
for k in res:
if k not in fieldnames:
fieldnames.append(k)
res = tabulate(table, headers='keys', tablefmt='github')
with open(f'eval_results.txt', 'w', encoding='utf-8') as f:
print(res, file=f)
with open('eval_results.csv', 'w', newline='') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for res in table:
writer.writerow(res)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment