Skip to content
Snippets Groups Projects
Commit e90344db authored by zqwerty
Browse files

update evaluate.py

parent 3775f099
No related branches found
No related tags found
No related merge requests found
...@@ -54,8 +54,8 @@ if __name__ == '__main__': ...@@ -54,8 +54,8 @@ if __name__ == '__main__':
for task_name in tqdm(args.tasks, desc='tasks'): for task_name in tqdm(args.tasks, desc='tasks'):
metric = load_metric("metric.py", task_name) metric = load_metric("metric.py", task_name)
dataset_name = task_name if task_name != "nlg" else "multiwoz21" dataset_name = task_name if task_name != "nlg" else "multiwoz21"
for shot in tqdm(args.shots, desc='shots'): for shot in tqdm(args.shots, desc='shots', leave=False):
for output_dir in tqdm(args.output_dirs, desc='models'): for output_dir in tqdm(args.output_dirs, desc='models', leave=False):
model_name = output_dir.split('/')[-1] model_name = output_dir.split('/')[-1]
if task_name == "wow": if task_name == "wow":
test_splits = ["_seen", "_unseen"] test_splits = ["_seen", "_unseen"]
...@@ -63,9 +63,16 @@ if __name__ == '__main__': ...@@ -63,9 +63,16 @@ if __name__ == '__main__':
test_splits = [""] test_splits = [""]
for test_split in test_splits: for test_split in test_splits:
results = [] results = []
for dial_ids_order in tqdm(args.dial_ids_orders, desc='dial_ids_orders'): for dial_ids_order in tqdm(args.dial_ids_orders, desc='dial_ids_orders', leave=False):
result_dir = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen{test_split}")
result_file = os.path.join(result_dir, "result.json")
if not os.path.exists(result_file):
filename = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen{test_split}/generated_predictions.json") filename = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen{test_split}/generated_predictions.json")
results.append(evaluate(filename, metric)) result = evaluate(filename, metric)
json.dump(result, open(result_file, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
else:
result = json.load(open(result_file))
results.append(result)
res = { res = {
"dataset": f"{task_name}{test_split}-{shot}shot", "dataset": f"{task_name}{test_split}-{shot}shot",
"model": f"{model_name}", "model": f"{model_name}",
...@@ -74,5 +81,5 @@ if __name__ == '__main__': ...@@ -74,5 +81,5 @@ if __name__ == '__main__':
tables.append(res) tables.append(res)
# print(res) # print(res)
res = tabulate(tables, headers='keys', tablefmt='github') res = tabulate(tables, headers='keys', tablefmt='github')
with open(f'eval_results.txt', 'a+', encoding='utf-8') as f: with open(f'eval_results.txt', 'w', encoding='utf-8') as f:
print(res, file=f) print(res, file=f)
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment