def evaluate(model_dir, subtask, gt):
    """Score the first submission file found under ``model_dir``.

    Looks for ``submission1.json`` … ``submission5.json`` in order, evaluates
    the first one that exists against the ground truth ``gt``, prints the
    metrics and persists them via ``dump_result``.

    Args:
        model_dir: directory containing the ``submission{i}.json`` files.
        subtask: subtask identifier, passed through to ``eval_states``.
        gt: ground-truth dialog states, same structure as a submission.

    Raises:
        ValueError: if none of the five submission files exists.
    """
    for i in range(1, 6):
        filepath = os.path.join(model_dir, f'submission{i}.json')
        if not os.path.exists(filepath):
            continue
        # Context manager instead of json.load(open(...)): the original
        # leaked the file handle. Submissions are dumped with
        # ensure_ascii=False, so read them back explicitly as UTF-8.
        with open(filepath, encoding='utf-8') as f:
            pred = json.load(f)
        # eval_states returns (result, errors); the error list is not
        # persisted for file-based evaluation, so discard it explicitly.
        result, _ = eval_states(gt, pred, subtask)
        print(json.dumps(result, indent=4, ensure_ascii=False))
        dump_result(model_dir, 'file-result.json', result)
        return
    raise ValueError('submission file not found')
def evaluate(model_dir, subtask, test_data, gt):
    """Run the model found in ``model_dir`` over ``test_data`` and score it.

    Imports ``Model`` from the package named by ``model_dir`` (path slashes
    become dots), replays every dialog turn-by-turn to collect predicted
    states, then evaluates them against ``gt`` and dumps the metrics,
    per-slot errors and raw predictions via ``dump_result``.

    Args:
        model_dir: path/package of the model; its module root must export a
            ``Model`` class implementing the ``DST`` interface.
        subtask: subtask identifier, passed through to ``eval_states``.
        test_data: mapping dialog_id -> list of (sys_utt, user_utt, gt_turn).
        gt: ground-truth states keyed like ``test_data``.
    """
    module = importlib.import_module(model_dir.replace('/', '.'))
    assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root'
    model_cls = getattr(module, 'Model')
    assert issubclass(model_cls, DST), 'the model must implement DST interface'
    # load weights, set eval() on default
    model = model_cls()

    pred = {}
    total_turns = sum(len(turns) for turns in test_data.values())
    bar = tqdm(total=total_turns, ncols=80, desc='evaluating')
    try:
        for dialog_id, turns in test_data.items():
            model.init_session()
            pred[dialog_id] = []
            for sys_utt, user_utt, gt_turn in turns:
                pred[dialog_id].append(model.update_turn(sys_utt, user_utt))
                bar.update()
    finally:
        # Release the progress bar even if the model raises mid-dialog,
        # so the terminal is left in a clean state (original leaked it).
        bar.close()

    result, errors = eval_states(gt, pred, subtask)
    print(json.dumps(result, indent=4))
    dump_result(model_dir, 'model-result.json', result, errors, pred)
def dump_result(model_dir, filename, result, errors=None, pred=None):
    """Write evaluation artifacts under ``../results/<model_dir>``.

    Creates the output directory if needed, then writes the metrics dict and,
    optionally, the error pairs and the raw predictions.

    Args:
        model_dir: model directory; mirrored below ``../results``.
        filename: name of the json file ``result`` is written to.
        result: metrics dict to serialize.
        errors: optional list of ``[gt_value, pred_value]`` rows; written to
            ``errors.csv`` when non-empty.
        pred: optional predicted states; written to ``pred.json`` when given.
    """
    output_dir = os.path.join('../results', model_dir)
    os.makedirs(output_dir, exist_ok=True)
    # Context managers instead of json.dump(..., open(...)): the original
    # leaked every file handle it opened here.
    with open(os.path.join(output_dir, filename), 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    if errors:
        import csv
        # newline='' is required by the csv module; without it the writer
        # emits spurious blank rows on Windows.
        with open(os.path.join(output_dir, 'errors.csv'), 'w',
                  newline='', encoding='utf-8') as f:
            csv.writer(f).writerows(errors)
    if pred:
        with open(os.path.join(output_dir, 'pred.json'), 'w', encoding='utf-8') as f:
            json.dump(pred, f, indent=4, ensure_ascii=False)