Skip to content
Snippets Groups Projects
Commit 0e0323a0 authored by function2's avatar function2
Browse files

revise evaluation

parent 9102053a
No related branches found
No related tags found
No related merge requests found
**/*.json
from .eval_file import evaluate as eval_file
from .eval_model import evaluate as eval_model
...@@ -2,7 +2,32 @@ ...@@ -2,7 +2,32 @@
evaluate output file evaluate output file
""" """
from convlab2.dst.dstc9.utils import prepare_data, eval_states from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir
def evaluate(model_dir, subtask, gt):
    """Score every available submission file of a model against ``gt``.

    Looks for ``submission1.json`` ... ``submission5.json`` inside
    ``<model_dir>/<get_subdir(subtask)>``. Each file that exists is loaded and
    scored with ``eval_states(gt, pred)``; missing files are skipped. All
    scores are then written to ``file-results.json`` in that same directory,
    keyed by the submission number.

    Args:
        model_dir: directory that contains the per-subtask submission folder.
        subtask: subtask name, forwarded to ``get_subdir``.
        gt: ground-truth states, passed as the first argument of
            ``eval_states``.
    """
    # Imported locally so the function also works when this module is imported
    # from a package: the only `import os` visible in this file appears to sit
    # under the `if __name__ == '__main__':` section.
    import json
    import os

    submission_dir = os.path.join(model_dir, get_subdir(subtask))
    results = {}
    for i in range(1, 6):
        filepath = os.path.join(submission_dir, f'submission{i}.json')
        if not os.path.exists(filepath):
            continue
        # Context manager closes the handle; explicit UTF-8 because the
        # files are written with ensure_ascii=False (raw non-ASCII text).
        with open(filepath, encoding='utf-8') as f:
            pred = json.load(f)
        results[i] = eval_states(gt, pred)

    results_path = os.path.join(submission_dir, 'file-results.json')
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4, ensure_ascii=False)
def dump_example(test_data, subtask=None):
    """Write two example submission files derived from ``test_data``.

    ``submission1.json`` holds the states returned by
    ``extract_gt(test_data)`` unchanged. ``submission2.json`` has the same
    structure with every slot value replaced by an empty string. Both files
    go to ``example/<get_subdir(subtask)>``.

    Args:
        test_data: prepared dialog data, forwarded to ``extract_gt``.
        subtask: subtask name, forwarded to ``get_subdir``. When omitted, the
            module-level ``subtask`` variable is used, which keeps the
            one-argument call working.

    Raises:
        ValueError: if ``subtask`` is not given and no module-level
            ``subtask`` is defined.
    """
    # Imported locally so the function does not depend on imports that appear
    # to happen only under the `if __name__ == '__main__':` section.
    import json
    import os

    if subtask is None:
        # Legacy path: the original body read a global set by the script.
        subtask = globals().get('subtask')
        if subtask is None:
            raise ValueError(
                'subtask must be passed to dump_example() or defined at module level'
            )

    example_dir = os.path.join('example', get_subdir(subtask))
    gt = extract_gt(test_data)
    with open(os.path.join(example_dir, 'submission1.json'), 'w', encoding='utf-8') as f:
        json.dump(gt, f, ensure_ascii=False, indent=4)

    # Build the blanked copy instead of overwriting slots in place: the state
    # dicts in `gt` are the very objects stored in `test_data`, so in-place
    # blanking would wipe the ground truth for any later extract_gt() call.
    blank = {
        dialog_id: [
            {
                domain_name: {slot: "" for slot in domain}
                for domain_name, domain in state.items()
            }
            for state in states
        ]
        for dialog_id, states in gt.items()
    }
    with open(os.path.join(example_dir, 'submission2.json'), 'w', encoding='utf-8') as f:
        json.dump(blank, f, ensure_ascii=False, indent=4)
if __name__ == '__main__': if __name__ == '__main__':
import os import os
...@@ -10,20 +35,10 @@ if __name__ == '__main__': ...@@ -10,20 +35,10 @@ if __name__ == '__main__':
from argparse import ArgumentParser from argparse import ArgumentParser
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val'])
args = parser.parse_args() args = parser.parse_args()
subtask = args.subtask
gt = { test_data = prepare_data(subtask, args.split)
dialog_id: [state for _, _, state in turns] dump_example(test_data)
for dialog_id, turns in prepare_data(args.subtask).items() gt = extract_gt(test_data)
} evaluate('example', subtask, gt)
# json.dump(gt, open('gt-crosswoz.json', 'w'), ensure_ascii=False, indent=4)
results = {}
for i in range(1, 6):
filename = f'submission{i}.json'
if not os.path.exists(filename):
continue
pred = json.load(open(filename))
results[filename] = eval_states(gt, pred)
json.dump(results, open('results.json', 'w'), indent=4, ensure_ascii=False)
...@@ -7,36 +7,37 @@ import json ...@@ -7,36 +7,37 @@ import json
import importlib import importlib
from convlab2.dst import DST from convlab2.dst import DST
from convlab2.dst.dstc9.utils import prepare_data, eval_states from convlab2.dst.dstc9.utils import prepare_data, eval_states, get_subdir
def evaluate(model_name, subtask): def evaluate(model_dir, subtask, test_data, gt):
subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en' subdir = get_subdir(subtask)
module = importlib.import_module(f'{model_name}.{subdir}') module = importlib.import_module(f'{model_dir}.{subdir}')
assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root' assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root'
model_cls = module.__getattribute__('Model') model_cls = module.__getattribute__('Model')
assert issubclass(model_cls, DST), 'the model must implement DST interface' assert issubclass(model_cls, DST), 'the model must implement DST interface'
# load weights, set eval() on default # load weights, set eval() on default
model = model_cls() model = model_cls()
gt = {}
pred = {} pred = {}
for dialog_id, turns in prepare_data(subtask).items(): for dialog_id, turns in test_data.items():
gt_dialog = []
pred_dialog = []
model.init_session() model.init_session()
for sys_utt, user_utt, gt_turn in turns: pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns]
gt_dialog.append(gt_turn)
pred_dialog.append(model.update_turn(sys_utt, user_utt))
gt[dialog_id] = gt_dialog
pred[dialog_id] = pred_dialog
result = eval_states(gt, pred) result = eval_states(gt, pred)
print(result) print(result)
json.dump(result, open(os.path.join(model_name, subdir, 'result.json'), 'w'), indent=4, ensure_ascii=False) json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False)
if __name__ == '__main__': if __name__ == '__main__':
from argparse import ArgumentParser from argparse import ArgumentParser
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val'])
args = parser.parse_args() args = parser.parse_args()
evaluate('example', args.subtask) subtask = args.subtask
test_data = prepare_data(subtask, args.split)
gt = {
dialog_id: [state for _, _, state in turns]
for dialog_id, turns in test_data.items()
}
evaluate('example', subtask, test_data, gt)
*/result.json
...@@ -51,8 +51,8 @@ class ExampleModel(DST): ...@@ -51,8 +51,8 @@ class ExampleModel(DST):
"出发时间": "", "出发时间": "",
"目的地": "", "目的地": "",
"日期": "", "日期": "",
"到达时间": "未提及", "到达时间": "",
"出发地": "未提及", "出发地": "",
}, },
} }
......
...@@ -2,21 +2,13 @@ import os ...@@ -2,21 +2,13 @@ import os
import json import json
import zipfile import zipfile
def load_test_data(subtask):
from convlab2 import DATA_ROOT from convlab2 import DATA_ROOT
data_dir = os.path.join(DATA_ROOT, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en')
# test public data currently
# to check if this script works properly with your code when label information is
# not available, you may need to fill the missing fields yourself (with any value)
zip_filename = os.path.join(data_dir, 'dstc9-test-250.zip')
test_data = json.load(zipfile.ZipFile(zip_filename).open('data.json'))
assert len(test_data) == 250
return test_data
def prepare_data(subtask): def prepare_data(subtask, split, data_root=DATA_ROOT):
test_data = load_test_data(subtask) data_dir = os.path.join(data_root, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en')
zip_filename = os.path.join(data_dir, f'{split}.json.zip')
test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json'))
data = {} data = {}
if subtask == 'multiwoz': if subtask == 'multiwoz':
for dialog_id, dialog in test_data.items(): for dialog_id, dialog in test_data.items():
...@@ -57,6 +49,14 @@ def prepare_data(subtask): ...@@ -57,6 +49,14 @@ def prepare_data(subtask):
return data return data
def extract_gt(test_data):
    """Collect the per-turn states of every dialog.

    ``test_data`` maps a dialog id to a sequence of 3-item turns; the third
    item of each turn is taken as that turn's state. Returns a dict mapping
    each dialog id to the list of its turn states, in turn order. The state
    objects are not copied.
    """
    states_by_dialog = {}
    for dialog_id, turns in test_data.items():
        dialog_states = []
        for _sys_utt, _user_utt, state in turns:
            dialog_states.append(state)
        states_by_dialog[dialog_id] = dialog_states
    return states_by_dialog
def eval_states(gt, pred): def eval_states(gt, pred):
def exception(description, **kargs): def exception(description, **kargs):
ret = { ret = {
...@@ -116,3 +116,8 @@ def eval_states(gt, pred): ...@@ -116,3 +116,8 @@ def eval_states(gt, pred):
# 'f1': f1, # 'f1': f1,
# } # }
} }
def get_subdir(subtask):
    """Return the sub-directory name used for ``subtask``.

    ``'multiwoz'`` maps to ``'multiwoz_zh'``; any other value maps to
    ``'crosswoz_en'``.
    """
    if subtask == 'multiwoz':
        return 'multiwoz_zh'
    return 'crosswoz_en'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment