Skip to content
Snippets Groups Projects
Unverified Commit d0ec9473 authored by 罗崚骁's avatar 罗崚骁 Committed by GitHub
Browse files

revise xldst evaluation (#124)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example
parent 3733179c
No related branches found
No related tags found
No related merge requests found
**/*.json
from .eval_file import evaluate as eval_file
from .eval_model import evaluate as eval_model
......@@ -2,28 +2,46 @@
evaluate output file
"""
from convlab2.dst.dstc9.utils import prepare_data, eval_states
if __name__ == '__main__':
import os
import json
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
args = parser.parse_args()
gt = {
dialog_id: [state for _, _, state in turns]
for dialog_id, turns in prepare_data(args.subtask).items()
}
# json.dump(gt, open('gt-crosswoz.json', 'w'), ensure_ascii=False, indent=4)
from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir
# Score each `submission{i}.json` (i = 1..5) that exists against `gt` using
# `eval_states`, collect the scores in `results`, and dump them as JSON.
# Missing submission files are skipped.
# NOTE(review): this span comes from a rendered diff whose indentation was lost
# and whose pre-/post-change lines appear interleaved; the notes below mark the
# lines that look superseded -- confirm against the repository.
def evaluate(model_dir, subtask, gt):
subdir = get_subdir(subtask)
results = {}
for i in range(1, 6):
# NOTE(review): presumably the pre-change pair (file looked up in the working directory).
filename = f'submission{i}.json'
if not os.path.exists(filename):
# NOTE(review): presumably the post-change pair (file looked up under model_dir/subdir).
filepath = os.path.join(model_dir, subdir, f'submission{i}.json')
if not os.path.exists(filepath):
continue
# NOTE(review): presumably pre-change -- results keyed by file name.
pred = json.load(open(filename))
results[filename] = eval_states(gt, pred)
# NOTE(review): presumably post-change -- results keyed by the submission index.
pred = json.load(open(filepath))
results[i] = eval_states(gt, pred)
# NOTE(review): presumably post-change -- results written next to the submissions.
json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False)
# NOTE(review): presumably pre-change -- results written to the working directory.
json.dump(results, open('results.json', 'w'), indent=4, ensure_ascii=False)
def dump_example(subtask, split):
    """Write two example submission files for ``subtask``.

    ``submission1.json`` contains the ground-truth states extracted from the
    prepared ``split`` data; ``submission2.json`` contains the same structure
    with every slot value replaced by an empty string.  Both files are written
    to ``example/<subdir>/`` where ``<subdir>`` comes from ``get_subdir``.

    Args:
        subtask: subtask name, forwarded to ``prepare_data`` and ``get_subdir``.
        split: data split name, forwarded to ``prepare_data``.
    """
    example_dir = os.path.join('example', get_subdir(subtask))

    def _dump(obj, name):
        # Use a context manager so the handle is flushed and closed
        # deterministically instead of relying on garbage collection.
        with open(os.path.join(example_dir, name), 'w') as f:
            json.dump(obj, f, ensure_ascii=False, indent=4)

    gt = extract_gt(prepare_data(subtask, split))
    _dump(gt, 'submission1.json')

    # Blank every slot in place; the state dicts are mutated, so ``gt`` no
    # longer holds the real labels after this loop.
    for states in gt.values():
        for state in states:
            for domain in state.values():
                for slot in domain:
                    domain[slot] = ""
    _dump(gt, 'submission2.json')
if __name__ == '__main__':
    # Script entry point: write the example submissions for the chosen
    # subtask/split, then score them against the ground truth.
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
    cli.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val'])
    options = cli.parse_args()

    dump_example(options.subtask, options.split)
    # dump_example blanks the slot values of the ground truth it builds, so the
    # data is prepared again here to obtain untouched labels for scoring.
    fresh_gt = extract_gt(prepare_data(options.subtask, options.split))
    evaluate('example', options.subtask, fresh_gt)
......@@ -7,36 +7,36 @@ import json
import importlib
from convlab2.dst import DST
from convlab2.dst.dstc9.utils import prepare_data, eval_states
from convlab2.dst.dstc9.utils import prepare_data, eval_states, get_subdir
# Import the subtask module of a model package, instantiate its `Model` class
# (which must subclass `DST`), run it turn by turn over the dialogues, score the
# predicted states against the ground truth with `eval_states`, print the result
# and dump it as JSON under the model's subtask directory.
# NOTE(review): this span comes from a rendered diff whose indentation was lost
# and whose pre-/post-change lines appear interleaved; the notes below mark the
# lines that look superseded -- confirm against the repository.
# NOTE(review): the next three lines look like the pre-change header.
def evaluate(model_name, subtask):
subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en'
module = importlib.import_module(f'{model_name}.{subdir}')
# NOTE(review): the next three lines look like the post-change header
# (data and ground truth are passed in by the caller).
def evaluate(model_dir, subtask, test_data, gt):
subdir = get_subdir(subtask)
module = importlib.import_module(f'{model_dir}.{subdir}')
assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root'
model_cls = module.__getattribute__('Model')
assert issubclass(model_cls, DST), 'the model must implement DST interface'
# load weights, set eval() on default
model = model_cls()
# NOTE(review): `gt = {}` presumably belongs to the pre-change version, which
# rebuilt the ground truth inside this function.
gt = {}
pred = {}
# NOTE(review): presumably the pre-change loop header and accumulators.
for dialog_id, turns in prepare_data(subtask).items():
gt_dialog = []
pred_dialog = []
# NOTE(review): presumably the post-change loop header.
for dialog_id, turns in test_data.items():
# A new session is started for every dialogue before its turns are fed in.
model.init_session()
# NOTE(review): presumably the pre-change per-turn loop.
for sys_utt, user_utt, gt_turn in turns:
gt_dialog.append(gt_turn)
pred_dialog.append(model.update_turn(sys_utt, user_utt))
gt[dialog_id] = gt_dialog
pred[dialog_id] = pred_dialog
# NOTE(review): presumably the post-change replacement -- predictions only.
pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns]
result = eval_states(gt, pred)
print(result)
# NOTE(review): presumably pre-change output name ('result.json') followed by
# the post-change one ('model-result.json').
json.dump(result, open(os.path.join(model_name, subdir, 'result.json'), 'w'), indent=4, ensure_ascii=False)
json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False)
# Script entry point: parse the subtask and split from the command line,
# prepare the data, build the per-dialogue ground-truth state lists and
# evaluate the model package named 'example'.
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val'])
args = parser.parse_args()
# NOTE(review): the two-argument call below looks like the pre-change line of
# the diff; the four-argument call at the end appears to replace it -- confirm.
evaluate('example', args.subtask)
subtask = args.subtask
test_data = prepare_data(subtask, args.split)
# Keep only the third element (the state) of every turn, grouped by dialogue id.
gt = {
dialog_id: [state for _, _, state in turns]
for dialog_id, turns in test_data.items()
}
evaluate('example', subtask, test_data, gt)
*/result.json
......@@ -51,8 +51,8 @@ class ExampleModel(DST):
"出发时间": "",
"目的地": "",
"日期": "",
"到达时间": "未提及",
"出发地": "未提及",
"到达时间": "",
"出发地": "",
},
}
......
......@@ -2,21 +2,13 @@ import os
import json
import zipfile
# Read `data.json` from `dstc9-test-250.zip` in the subtask's data directory
# ('multiwoz_zh' for 'multiwoz', otherwise 'crosswoz_en') and return the parsed
# JSON, asserting that it holds exactly 250 entries.
# NOTE(review): this span comes from a rendered diff with indentation lost; the
# function looks like the removed side of the hunk (superseded by the
# split-aware `prepare_data`), and it is unclear whether the import below is
# local to the function or module-level -- confirm against the repository.
def load_test_data(subtask):
from convlab2 import DATA_ROOT
data_dir = os.path.join(DATA_ROOT, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en')
# test public data currently
# to check if this script works properly with your code when label information is
# not available, you may need to fill the missing fields yourself (with any value)
zip_filename = os.path.join(data_dir, 'dstc9-test-250.zip')
test_data = json.load(zipfile.ZipFile(zip_filename).open('data.json'))
assert len(test_data) == 250
return test_data
def prepare_data(subtask):
test_data = load_test_data(subtask)
def prepare_data(subtask, split, data_root=DATA_ROOT):
data_dir = os.path.join(data_root, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en')
zip_filename = os.path.join(data_dir, f'{split}.json.zip')
test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json'))
data = {}
if subtask == 'multiwoz':
for dialog_id, dialog in test_data.items():
......@@ -57,6 +49,14 @@ def prepare_data(subtask):
return data
def extract_gt(test_data):
    """Collect the ground-truth state of every turn, grouped by dialogue.

    Args:
        test_data: mapping from dialogue id to a sequence of 3-item turns;
            the third item of each turn is taken as its state.

    Returns:
        A dict mapping each dialogue id to the list of its turn states, in
        turn order.  The state objects are the same objects found in
        ``test_data`` (no copy is made).
    """
    gt = {}
    for dialog_id, turns in test_data.items():
        gt[dialog_id] = [turn_state for _sys_utt, _user_utt, turn_state in turns]
    return gt
def eval_states(gt, pred):
def exception(description, **kargs):
ret = {
......@@ -116,3 +116,8 @@ def eval_states(gt, pred):
# 'f1': f1,
# }
}
def get_subdir(subtask):
    """Return the directory name used for ``subtask``.

    ``'multiwoz'`` maps to ``'multiwoz_zh'``; any other value maps to
    ``'crosswoz_en'``.
    """
    if subtask == 'multiwoz':
        return 'multiwoz_zh'
    return 'crosswoz_en'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment