From 2c8dc0d9a869cb7c38edab4c5d25906a6211b362 Mon Sep 17 00:00:00 2001 From: function2 <function2@qq.com> Date: Sat, 17 Oct 2020 16:08:57 +0800 Subject: [PATCH] update dstc9 eval --- convlab2/dst/dstc9/eval_file.py | 10 +++++----- convlab2/dst/dstc9/eval_model.py | 5 +++-- convlab2/dst/dstc9/utils.py | 18 +++++++++++++----- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py index 50784fe..fe33b02 100644 --- a/convlab2/dst/dstc9/eval_file.py +++ b/convlab2/dst/dstc9/eval_file.py @@ -9,16 +9,16 @@ from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_ def evaluate(model_dir, subtask, gt): - results = {} for i in range(1, 6): filepath = os.path.join(model_dir, f'submission{i}.json') if not os.path.exists(filepath): continue pred = json.load(open(filepath)) - results[i] = eval_states(gt, pred, subtask) - - print(json.dumps(results, indent=4, ensure_ascii=False)) - dump_result(model_dir, 'file-results.json', results) + result, errors = eval_states(gt, pred, subtask) + print(json.dumps(result, indent=4, ensure_ascii=False)) + dump_result(model_dir, 'file-result.json', result) + return + raise ValueError('submission file not found') # generate submission examples diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py index 207eac7..3a3e081 100644 --- a/convlab2/dst/dstc9/eval_model.py +++ b/convlab2/dst/dstc9/eval_model.py @@ -28,9 +28,10 @@ def evaluate(model_dir, subtask, test_data, gt): pred[dialog_id].append(model.update_turn(sys_utt, user_utt)) bar.update() bar.close() - result = eval_states(gt, pred, subtask) + + result, errors = eval_states(gt, pred, subtask) print(json.dumps(result, indent=4)) - dump_result(model_dir, 'model-result.json', result) + dump_result(model_dir, 'model-result.json', result, errors, pred) if __name__ == '__main__': diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 4e7d967..8a37f4b 
100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -83,8 +83,6 @@ def unify_value(value, subtask): return ' '.join(value.strip().split()) - return ' '.join(value.strip().split()) - def eval_states(gt, pred, subtask): def exception(description, **kargs): @@ -94,7 +92,8 @@ def eval_states(gt, pred, subtask): } for k, v in kargs.items(): ret[k] = v - return ret + return ret, None + errors = [] joint_acc, joint_tot = 0, 0 slot_acc, slot_tot = 0, 0 @@ -121,11 +120,13 @@ def eval_states(gt, pred, subtask): gt_value = unify_value(gt_value, subtask) pred_value = unify_value(pred_domain[slot_name], subtask) slot_tot += 1 + if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: slot_acc += 1 if gt_value: tp += 1 else: + errors.append([gt_value, pred_value]) turn_result = False if gt_value: fn += 1 @@ -145,10 +146,17 @@ def eval_states(gt, pred, subtask): 'recall': recall, 'f1': f1, } - } + }, errors -def dump_result(model_dir, filename, result): +def dump_result(model_dir, filename, result, errors=None, pred=None): output_dir = os.path.join('../results', model_dir) os.makedirs(output_dir, exist_ok=True) json.dump(result, open(os.path.join(output_dir, filename), 'w'), indent=4, ensure_ascii=False) + if errors: + import csv + with open(os.path.join(output_dir, 'errors.csv'), 'w') as f: + writer = csv.writer(f) + writer.writerows(errors) + if pred: + json.dump(pred, open(os.path.join(output_dir, 'pred.json'), 'w'), indent=4, ensure_ascii=False) -- GitLab