From 22cc51f4515dde684f2255dd6dcf46463577180c Mon Sep 17 00:00:00 2001 From: function2 <function2@qq.com> Date: Sun, 18 Oct 2020 11:08:01 +0800 Subject: [PATCH] update model evaluation --- convlab2/dst/dstc9/eval_model.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py index 3a3e081..c2013c5 100644 --- a/convlab2/dst/dstc9/eval_model.py +++ b/convlab2/dst/dstc9/eval_model.py @@ -9,7 +9,7 @@ import importlib from tqdm import tqdm from convlab2.dst import DST -from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result +from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result, extract_gt def evaluate(model_dir, subtask, test_data, gt): @@ -34,16 +34,28 @@ def evaluate(model_dir, subtask, test_data, gt): dump_result(model_dir, 'model-result.json', result, errors, pred) +def eval_team(team): + for subtask in ['multiwoz', 'crosswoz']: + test_data = prepare_data(subtask, 'dstc9') + gt = extract_gt(test_data) + for i in range(1, 6): + model_dir = os.path.join(team, f'{subtask}-dst', f'submission{i}') + if not os.path.exists(model_dir): + continue + print(model_dir) + evaluate(model_dir, subtask, test_data, gt) + + if __name__ == '__main__': from argparse import ArgumentParser parser = ArgumentParser() - parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) - parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val']) + parser.add_argument('--teams', type=str, nargs='*') args = parser.parse_args() - subtask = args.subtask - test_data = prepare_data(subtask, args.split) - gt = { - dialog_id: [state for _, _, state in turns] - for dialog_id, turns in test_data.items() - } - evaluate('example', subtask, test_data, gt) + if not args.teams: + for team in os.listdir('.'): + if not os.path.isdir(team): + continue + eval_team(team) + else: + for team in args.teams: + eval_team(team) \ No newline at end of file -- GitLab