Skip to content
Snippets Groups Projects
Commit 1e1d1f9c authored by function2's avatar function2
Browse files

dump dst eval results

parent c24d5ecb
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import json import json
import os import os
from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir, dump_result
def evaluate(model_dir, subtask, gt): def evaluate(model_dir, subtask, gt):
...@@ -18,7 +18,7 @@ def evaluate(model_dir, subtask, gt): ...@@ -18,7 +18,7 @@ def evaluate(model_dir, subtask, gt):
results[i] = eval_states(gt, pred, subtask) results[i] = eval_states(gt, pred, subtask)
print(json.dumps(results, indent=4, ensure_ascii=False)) print(json.dumps(results, indent=4, ensure_ascii=False))
json.dump(results, open(os.path.join(model_dir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False) dump_result(model_dir, 'file-results.json', results)
# generate submission examples # generate submission examples
......
...@@ -7,7 +7,7 @@ import json ...@@ -7,7 +7,7 @@ import json
import importlib import importlib
from convlab2.dst import DST from convlab2.dst import DST
from convlab2.dst.dstc9.utils import prepare_data, eval_states from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result
def evaluate(model_dir, subtask, test_data, gt): def evaluate(model_dir, subtask, test_data, gt):
...@@ -23,7 +23,7 @@ def evaluate(model_dir, subtask, test_data, gt): ...@@ -23,7 +23,7 @@ def evaluate(model_dir, subtask, test_data, gt):
pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns] pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns]
result = eval_states(gt, pred, subtask) result = eval_states(gt, pred, subtask)
print(json.dumps(result, indent=4)) print(json.dumps(result, indent=4))
json.dump(result, open(os.path.join(model_dir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False) dump_result(model_dir, 'model-result.json', result)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -144,3 +144,9 @@ def eval_states(gt, pred, subtask): ...@@ -144,3 +144,9 @@ def eval_states(gt, pred, subtask):
'f1': f1, 'f1': f1,
} }
} }
def dump_result(model_dir, filename, result):
output_dir = os.path.join('../results', model_dir)
os.makedirs(output_dir, exist_ok=True)
json.dump(result, open(os.path.join(output_dir, filename), 'w'), indent=4, ensure_ascii=False)
*/dstc9*.json.zip
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment