diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py
index 50784fe125bf0a72c954448539b12608a5ce07dc..fe33b0244895a66f464dfe73822d6560c984bbd7 100644
--- a/convlab2/dst/dstc9/eval_file.py
+++ b/convlab2/dst/dstc9/eval_file.py
@@ -9,16 +9,16 @@ from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_
 
 
 def evaluate(model_dir, subtask, gt):
-    results = {}
     for i in range(1, 6):
         filepath = os.path.join(model_dir, f'submission{i}.json')
         if not os.path.exists(filepath):
             continue
         pred = json.load(open(filepath))
-        results[i] = eval_states(gt, pred, subtask)
-
-    print(json.dumps(results, indent=4, ensure_ascii=False))
-    dump_result(model_dir, 'file-results.json', results)
+        result, errors = eval_states(gt, pred, subtask)
+        print(json.dumps(result, indent=4, ensure_ascii=False))
+        dump_result(model_dir, 'file-result.json', result)
+        return
+    raise ValueError('submission file not found')
 
 
 # generate submission examples
diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py
index 207eac70e6f28e4b9bcab0e48cf284afba9a8e4a..3a3e0814266947728981d12be38b49ae7b85940b 100644
--- a/convlab2/dst/dstc9/eval_model.py
+++ b/convlab2/dst/dstc9/eval_model.py
@@ -28,9 +28,10 @@ def evaluate(model_dir, subtask, test_data, gt):
             pred[dialog_id].append(model.update_turn(sys_utt, user_utt))
             bar.update()
     bar.close()
-    result = eval_states(gt, pred, subtask)
+
+    result, errors = eval_states(gt, pred, subtask)
     print(json.dumps(result, indent=4))
-    dump_result(model_dir, 'model-result.json', result)
+    dump_result(model_dir, 'model-result.json', result, errors, pred)
 
 
 if __name__ == '__main__':
diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py
index 4e7d967d9d745052570f20e10a642fe3824e59a7..8a37f4bf9b1f50c173bcdf48a64997a3d8207064 100644
--- a/convlab2/dst/dstc9/utils.py
+++ b/convlab2/dst/dstc9/utils.py
@@ -83,8 +83,6 @@ def unify_value(value, subtask):
 
     return ' '.join(value.strip().split())
 
-    return ' '.join(value.strip().split())
-
 
 def eval_states(gt, pred, subtask):
     def exception(description, **kargs):
@@ -94,7 +92,8 @@ def eval_states(gt, pred, subtask):
         }
         for k, v in kargs.items():
             ret[k] = v
-        return ret
+        return ret, None
 
+    errors = []
     joint_acc, joint_tot = 0, 0
     slot_acc, slot_tot = 0, 0
@@ -121,11 +120,13 @@ def eval_states(gt, pred, subtask):
                     gt_value = unify_value(gt_value, subtask)
                     pred_value = unify_value(pred_domain[slot_name], subtask)
                     slot_tot += 1
+
                     if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
                         slot_acc += 1
                         if gt_value:
                             tp += 1
                     else:
+                        errors.append([gt_value, pred_value])
                         turn_result = False
                         if gt_value:
                             fn += 1
@@ -145,10 +146,17 @@ def eval_states(gt, pred, subtask):
             'recall': recall,
             'f1': f1,
         }
-    }
+    }, errors
 
 
-def dump_result(model_dir, filename, result):
+def dump_result(model_dir, filename, result, errors=None, pred=None):
     output_dir = os.path.join('../results', model_dir)
     os.makedirs(output_dir, exist_ok=True)
     json.dump(result, open(os.path.join(output_dir, filename), 'w'), indent=4, ensure_ascii=False)
+    if errors:
+        import csv
+        with open(os.path.join(output_dir, 'errors.csv'), 'w') as f:
+            writer = csv.writer(f)
+            writer.writerows(errors)
+    if pred:
+        json.dump(pred, open(os.path.join(output_dir, 'pred.json'), 'w'), indent=4, ensure_ascii=False)
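
Note: with this change, `eval_states` returns a `(result, errors)` tuple, and `dump_result` can additionally write the collected slot errors and the raw predictions next to the metrics file. Below is a minimal, self-contained sketch of the output layout this produces; the sample values and the `example-model` directory are hypothetical, and only the file names (`model-result.json`, `errors.csv`, `pred.json`) and their formats come from the diff above.

    import csv
    import json
    import os

    # Hypothetical stand-ins for the values dump_result receives:
    # `result` is the metrics dict returned by eval_states, `errors` the
    # list of [gt_value, pred_value] pairs collected for mispredicted
    # slots, and `pred` the raw model predictions.
    result = {'slot': {'precision': 0.8, 'recall': 0.7, 'f1': 0.75}}
    errors = [['north', 'south'], ['18:30', '08:30']]
    pred = {'example-dialog': [{'hotel': {'area': 'south'}}]}

    output_dir = os.path.join('../results', 'example-model')
    os.makedirs(output_dir, exist_ok=True)

    # model-result.json: the metrics, as dump_result already wrote them
    json.dump(result, open(os.path.join(output_dir, 'model-result.json'), 'w'),
              indent=4, ensure_ascii=False)

    # errors.csv: one row per wrongly predicted slot (gt_value, pred_value);
    # newline='' is the idiomatic csv-module invocation, though the patch
    # itself omits it
    with open(os.path.join(output_dir, 'errors.csv'), 'w', newline='') as f:
        csv.writer(f).writerows(errors)

    # pred.json: the raw predictions, kept for later error analysis
    json.dump(pred, open(os.path.join(output_dir, 'pred.json'), 'w'),
              indent=4, ensure_ascii=False)

One caveat on the CSV dump: `gt_value` can itself be a list (the `isinstance(gt_value, list)` branch above), in which case `csv.writer` stringifies it as a Python repr in the cell rather than expanding it.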