From c24d5ecb2182b96240f12bd6189a47838d59a8c2 Mon Sep 17 00:00:00 2001 From: function2 <function2@qq.com> Date: Tue, 13 Oct 2020 10:44:37 +0800 Subject: [PATCH] update eval --- convlab2/dst/dstc9/eval_file.py | 15 +++++++-------- convlab2/dst/dstc9/eval_model.py | 9 ++++----- convlab2/dst/dstc9/utils.py | 2 +- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py index 4b3e1b2..d592b27 100644 --- a/convlab2/dst/dstc9/eval_file.py +++ b/convlab2/dst/dstc9/eval_file.py @@ -9,24 +9,24 @@ from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_ def evaluate(model_dir, subtask, gt): - subdir = get_subdir(subtask) results = {} for i in range(1, 6): - filepath = os.path.join(model_dir, subdir, f'submission{i}.json') + filepath = os.path.join(model_dir, f'submission{i}.json') if not os.path.exists(filepath): continue pred = json.load(open(filepath)) results[i] = eval_states(gt, pred, subtask) - print(json.dumps(results, indent=4)) - json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False) + print(json.dumps(results, indent=4, ensure_ascii=False)) + json.dump(results, open(os.path.join(model_dir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False) # generate submission examples def dump_example(subtask, split): test_data = prepare_data(subtask, split) pred = extract_gt(test_data) - json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4) + subdir = get_subdir(subtask) + json.dump(pred, open(os.path.join('example', subdir, 'submission1.json'), 'w'), ensure_ascii=False, indent=4) import random for dialog_id, states in pred.items(): for state in states: @@ -38,13 +38,13 @@ def dump_example(subtask, split): else: if random.randint(0, 4) == 0: domain[slot] = "2333" - json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4) + json.dump(pred, open(os.path.join('example', subdir, 'submission2.json'), 'w'), ensure_ascii=False, indent=4) for dialog_id, states in pred.items(): for state in states: for domain in state.values(): for slot in domain: domain[slot] = "" - json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission3.json'), 'w'), ensure_ascii=False, indent=4) + json.dump(pred, open(os.path.join('example', subdir, 'submission3.json'), 'w'), ensure_ascii=False, indent=4) if __name__ == '__main__': @@ -58,4 +58,3 @@ if __name__ == '__main__': dump_example(subtask, split) test_data = prepare_data(subtask, split) gt = extract_gt(test_data) - evaluate('example', subtask, gt) diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py index 377ecdd..eaba8e6 100644 --- a/convlab2/dst/dstc9/eval_model.py +++ b/convlab2/dst/dstc9/eval_model.py @@ -7,14 +7,13 @@ import json import importlib from convlab2.dst import DST -from convlab2.dst.dstc9.utils import prepare_data, eval_states, get_subdir +from convlab2.dst.dstc9.utils import prepare_data, eval_states def evaluate(model_dir, subtask, test_data, gt): - subdir = get_subdir(subtask) - module = importlib.import_module(f'{model_dir}.{subdir}') + module = importlib.import_module(model_dir.replace('/', '.')) assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root' - model_cls = module.__getattribute__('Model') + model_cls = getattr(module, 'Model') assert issubclass(model_cls, DST), 'the model must implement DST interface' # load weights, set eval() on default model = model_cls() @@ -24,7 +23,7 @@ def evaluate(model_dir, subtask, test_data, gt): pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns] result = eval_states(gt, pred, subtask) print(json.dumps(result, indent=4)) - json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False) + json.dump(result, open(os.path.join(model_dir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False) if __name__ == '__main__': diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index b685623..0987cbe 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -25,7 +25,7 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): user_utt = turns[i]['text'] state = {} for domain_name, domain in turns[i + 1]['metadata'].items(): - if domain_name in ['警察机关', '医院']: + if domain_name in ['警察机关', '医院', '公共汽车']: continue domain_state = {} for slots in domain.values(): -- GitLab