Skip to content
Snippets Groups Projects
Commit 2c8dc0d9 authored by function2's avatar function2
Browse files

udpate dstc9 eval

parent 979c7e2f
No related branches found
No related tags found
No related merge requests found
......@@ -9,16 +9,16 @@ from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_
def evaluate(model_dir, subtask, gt):
results = {}
for i in range(1, 6):
filepath = os.path.join(model_dir, f'submission{i}.json')
if not os.path.exists(filepath):
continue
pred = json.load(open(filepath))
results[i] = eval_states(gt, pred, subtask)
print(json.dumps(results, indent=4, ensure_ascii=False))
dump_result(model_dir, 'file-results.json', results)
result, errors = eval_states(gt, pred, subtask)
print(json.dumps(result, indent=4, ensure_ascii=False))
dump_result(model_dir, 'file-result.json', result)
return
raise ValueError('submission file not found')
# generate submission examples
......
......@@ -28,9 +28,10 @@ def evaluate(model_dir, subtask, test_data, gt):
pred[dialog_id].append(model.update_turn(sys_utt, user_utt))
bar.update()
bar.close()
result = eval_states(gt, pred, subtask)
result, errors = eval_states(gt, pred, subtask)
print(json.dumps(result, indent=4))
dump_result(model_dir, 'model-result.json', result)
dump_result(model_dir, 'model-result.json', result, errors, pred)
if __name__ == '__main__':
......
......@@ -83,8 +83,6 @@ def unify_value(value, subtask):
return ' '.join(value.strip().split())
return ' '.join(value.strip().split())
def eval_states(gt, pred, subtask):
def exception(description, **kargs):
......@@ -94,7 +92,8 @@ def eval_states(gt, pred, subtask):
}
for k, v in kargs.items():
ret[k] = v
return ret
return ret, None
errors = []
joint_acc, joint_tot = 0, 0
slot_acc, slot_tot = 0, 0
......@@ -121,11 +120,13 @@ def eval_states(gt, pred, subtask):
gt_value = unify_value(gt_value, subtask)
pred_value = unify_value(pred_domain[slot_name], subtask)
slot_tot += 1
if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
slot_acc += 1
if gt_value:
tp += 1
else:
errors.append([gt_value, pred_value])
turn_result = False
if gt_value:
fn += 1
......@@ -145,10 +146,17 @@ def eval_states(gt, pred, subtask):
'recall': recall,
'f1': f1,
}
}
}, errors
def dump_result(model_dir, filename, result):
def dump_result(model_dir, filename, result, errors=None, pred=None):
output_dir = os.path.join('../results', model_dir)
os.makedirs(output_dir, exist_ok=True)
json.dump(result, open(os.path.join(output_dir, filename), 'w'), indent=4, ensure_ascii=False)
if errors:
import csv
with open(os.path.join(output_dir, 'errors.csv'), 'w') as f:
writer = csv.writer(f)
writer.writerows(errors)
if pred:
json.dump(pred, open(os.path.join(output_dir, 'pred.json'), 'w'), indent=4, ensure_ascii=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment