diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py index d592b27c8de844338b1634bd26abcb6f079f37df..50784fe125bf0a72c954448539b12608a5ce07dc 100644 --- a/convlab2/dst/dstc9/eval_file.py +++ b/convlab2/dst/dstc9/eval_file.py @@ -5,7 +5,7 @@ import json import os -from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir +from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir, dump_result def evaluate(model_dir, subtask, gt): @@ -18,7 +18,7 @@ def evaluate(model_dir, subtask, gt): results[i] = eval_states(gt, pred, subtask) print(json.dumps(results, indent=4, ensure_ascii=False)) - json.dump(results, open(os.path.join(model_dir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False) + dump_result(model_dir, 'file-results.json', results) # generate submission examples diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py index eaba8e6ac46ff711558a57ba8d92916ff1fb79d9..bad4ee79a4ff284efe5a5613b00d59938658a8ce 100644 --- a/convlab2/dst/dstc9/eval_model.py +++ b/convlab2/dst/dstc9/eval_model.py @@ -7,7 +7,7 @@ import json import importlib from convlab2.dst import DST -from convlab2.dst.dstc9.utils import prepare_data, eval_states +from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result def evaluate(model_dir, subtask, test_data, gt): @@ -23,7 +23,7 @@ def evaluate(model_dir, subtask, test_data, gt): pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns] result = eval_states(gt, pred, subtask) print(json.dumps(result, indent=4)) - json.dump(result, open(os.path.join(model_dir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False) + dump_result(model_dir, 'model-result.json', result) if __name__ == '__main__': diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 0987cbe06bf1abfa0edc03d728b450c799369c91..ad10cf5a5e4450dddbc0dacdeecaed0be7b24767 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -144,3 +144,9 @@ def eval_states(gt, pred, subtask): 'f1': f1, } } + + +def dump_result(model_dir, filename, result): + output_dir = os.path.join('../results', model_dir) + os.makedirs(output_dir, exist_ok=True) + json.dump(result, open(os.path.join(output_dir, filename), 'w'), indent=4, ensure_ascii=False) diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f96d680e04cd73cdec1fcc01a6a379d590ef90de --- /dev/null +++ b/data/.gitignore @@ -0,0 +1 @@ +*/dstc9*.json.zip