diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py
index 3a3e0814266947728981d12be38b49ae7b85940b..c2013c566d5cb0de7248b5488d375cec2fa381fa 100644
--- a/convlab2/dst/dstc9/eval_model.py
+++ b/convlab2/dst/dstc9/eval_model.py
@@ -9,7 +9,7 @@ import importlib
 from tqdm import tqdm
 
 from convlab2.dst import DST
-from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result
+from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result, extract_gt
 
 
 def evaluate(model_dir, subtask, test_data, gt):
@@ -34,16 +34,28 @@ def evaluate(model_dir, subtask, test_data, gt):
     dump_result(model_dir, 'model-result.json', result, errors, pred)
 
 
+def eval_team(team):
+    for subtask in ['multiwoz', 'crosswoz']:
+        test_data = prepare_data(subtask, 'dstc9')
+        gt = extract_gt(test_data)
+        for i in range(1, 6):
+            model_dir = os.path.join(team, f'{subtask}-dst', f'submission{i}')
+            if not os.path.exists(model_dir):
+                continue
+            print(model_dir)
+            evaluate(model_dir, subtask, test_data, gt)
+
+
 if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser()
-    parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
-    parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val'])
+    parser.add_argument('--teams', type=str, nargs='*')
    args = parser.parse_args()
-    subtask = args.subtask
-    test_data = prepare_data(subtask, args.split)
-    gt = {
-        dialog_id: [state for _, _, state in turns]
-        for dialog_id, turns in test_data.items()
-    }
-    evaluate('example', subtask, test_data, gt)
+    if not args.teams:
+        for team in os.listdir('.'):
+            if not os.path.isdir(team):
+                continue
+            eval_team(team)
+    else:
+        for team in args.teams:
+            eval_team(team)
\ No newline at end of file
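
For reference, a minimal usage sketch of the rewritten entry point. It assumes team results are laid out as <team>/<subtask>-dst/submission<1..5>/, which is the convention eval_team reads; the team directory name below is hypothetical and not taken from the patch:

    # Run as a script from the directory holding the team folders, e.g.
    #     python -m convlab2.dst.dstc9.eval_model --teams some_team
    # With no --teams flag, the script instead walks every subdirectory
    # of the current working directory and evaluates each as a team.
    from convlab2.dst.dstc9.eval_model import eval_team

    eval_team('some_team')  # 'some_team' is a hypothetical team directory name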