Skip to content
Snippets Groups Projects
Unverified Commit 3733179c authored by 罗崚骁's avatar 罗崚骁 Committed by GitHub
Browse files

dstc9 xldst evaluation (#122)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data
parent a48cff97
No related branches found
No related tags found
No related merge requests found
"""Dialog State Tracker Interface"""
from convlab2.util.module import Module
import copy
from abc import abstractmethod
from convlab2.util.module import Module
class DST(Module):
......@@ -18,6 +20,21 @@ class DST(Module):
"""
pass
@abstractmethod
def update_turn(self, sys_utt, user_utt):
    """Update the internal dialog state with the utterances of the current turn.

    Args:
        sys_utt (str):
            system utterance of current turn, set to `None` for the first turn
        user_utt (str):
            user utterance of current turn
    Returns:
        new_state (dict):
            Updated dialog state, with the same form of previous state.
    """
    pass
def to_cache(self, *args, **kwargs):
    """Return a deep copy of the current dialog state.

    Any positional or keyword arguments are accepted and ignored.
    """
    snapshot = copy.deepcopy(self.state)
    return snapshot
......
"""
evaluate output file
"""
from convlab2.dst.dstc9.utils import prepare_data, eval_states
if __name__ == '__main__':
    import os
    import json
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
    args = parser.parse_args()

    # Only the gold state of each turn is needed for scoring; the utterances
    # returned by prepare_data are dropped here.
    gt = {
        dialog_id: [state for _, _, state in turns]
        for dialog_id, turns in prepare_data(args.subtask).items()
    }
    # json.dump(gt, open('gt-crosswoz.json', 'w'), ensure_ascii=False, indent=4)

    # Score submission1.json .. submission5.json from the working directory;
    # files that do not exist are skipped.
    results = {}
    for i in range(1, 6):
        filename = f'submission{i}.json'
        if not os.path.exists(filename):
            continue
        # Context manager closes the handle; explicit UTF-8 keeps non-ASCII
        # domain/slot names readable regardless of the platform locale.
        with open(filename, encoding='utf-8') as f:
            pred = json.load(f)
        results[filename] = eval_states(gt, pred)

    # ensure_ascii=False writes raw non-ASCII characters, so the output file
    # must be opened with an encoding that can represent them.
    with open('results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4, ensure_ascii=False)
"""
evaluate DST model
"""
import os
import json
import importlib
from convlab2.dst import DST
from convlab2.dst.dstc9.utils import prepare_data, eval_states
def evaluate(model_name, subtask):
    """Run a DST model over the test dialogs and write its scores to disk.

    The model class is looked up as attribute ``Model`` of the module
    ``{model_name}.multiwoz_zh`` (when ``subtask == 'multiwoz'``) or
    ``{model_name}.crosswoz_en`` (otherwise). The scores are printed and
    written to ``{model_name}/{subdir}/result.json``.

    Args:
        model_name (str): importable package name that contains the sub-task modules.
        subtask (str): 'multiwoz' or any other value for the crosswoz sub-task.

    Raises:
        AssertionError: if the module has no ``Model`` attribute or the class
            is not a subclass of ``DST``.
    """
    # Function-scope import so this block does not depend on file-level imports.
    from copy import deepcopy

    subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en'
    module = importlib.import_module(f'{model_name}.{subdir}')
    assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root'
    model_cls = getattr(module, 'Model')
    assert issubclass(model_cls, DST), 'the model must implement DST interface'
    # load weights, set eval() on default
    model = model_cls()

    gt = {}
    pred = {}
    for dialog_id, turns in prepare_data(subtask).items():
        gt_dialog = []
        pred_dialog = []
        model.init_session()
        for sys_utt, user_utt, gt_turn in turns:
            gt_dialog.append(gt_turn)
            # Snapshot the returned state: a model may hand back its live
            # internal dict and keep mutating it on later turns, which would
            # make every stored turn alias the final state.
            pred_dialog.append(deepcopy(model.update_turn(sys_utt, user_utt)))
        gt[dialog_id] = gt_dialog
        pred[dialog_id] = pred_dialog

    result = eval_states(gt, pred)
    print(result)
    # Close the file deterministically and write UTF-8, since
    # ensure_ascii=False may emit non-ASCII characters.
    result_path = os.path.join(model_name, subdir, 'result.json')
    with open(result_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
if __name__ == '__main__':
    from argparse import ArgumentParser

    # The only command-line input is the sub-task; the model package is
    # fixed to 'example'.
    cli = ArgumentParser()
    cli.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
    evaluate('example', cli.parse_args().subtask)
*/result.json
from .model import ExampleModel as Model
from convlab2.dst import DST
class ExampleModel(DST):
    """Example tracker: records the utterances and returns its state unchanged."""

    def init_session(self):
        """Reset the utterance history and set every slot of every domain to ""."""
        nearby = ("nearby attract.", "nearby rest.", "nearby hotels")
        slots_by_domain = {
            "Attraction": ("name", "fee", "duration", "rating") + nearby,
            "Restaurant": ("name", "dishes", "cost", "rating") + nearby,
            "Hotel": ("name", "type", "Hotel Facilities", "price", "rating") + nearby,
            "Metro": ("from", "to"),
            "Taxi": ("from", "to"),
        }
        self.history = []
        self.state = {
            domain: dict.fromkeys(slots, "")
            for domain, slots in slots_by_domain.items()
        }

    def update_turn(self, sys_utt, user_utt):
        """Append the turn's utterances to the history and return the state.

        Args:
            sys_utt: system utterance, or `None` (then only the user utterance is stored).
            user_utt: user utterance of the current turn.

        Returns:
            The tracker's `self.state` object itself (not a copy).
        """
        turn = [user_utt] if sys_utt is None else [sys_utt, user_utt]
        self.history.extend(turn)
        # model can do some modification to state here
        return self.state
from .model import ExampleModel as Model
from convlab2.dst import DST
class ExampleModel(DST):
    """Example tracker: records the utterances and returns its state unchanged."""

    def init_session(self):
        """Reset the utterance history and rebuild the initial state dict."""
        slots_by_domain = {
            "出租车": ("出发时间", "目的地", "出发地", "到达时间"),
            "餐厅": ("时间", "日期", "人数", "食物", "价格范围", "名称", "区域"),
            "公共汽车": ("人数", "出发时间", "目的地", "日期", "到达时间", "出发地"),
            "旅馆": (
                "停留天数", "日期", "人数", "名称", "区域",
                "停车处", "价格范围", "星级", "互联网", "类型",
            ),
            "景点": ("类型", "名称", "区域"),
            "列车": ("票价", "人数", "出发时间", "目的地", "日期", "到达时间", "出发地"),
        }
        self.history = []
        initial = {
            domain: dict.fromkeys(slots, "")
            for domain, slots in slots_by_domain.items()
        }
        # NOTE(review): these two slots start as "未提及" while every other
        # slot starts empty - confirm that this asymmetry is intentional.
        initial["列车"]["到达时间"] = "未提及"
        initial["列车"]["出发地"] = "未提及"
        self.state = initial

    def update_turn(self, sys_utt, user_utt):
        """Append the turn's utterances to the history and return the state.

        Args:
            sys_utt: system utterance, or `None` (then only the user utterance is stored).
            user_utt: user utterance of the current turn.

        Returns:
            The tracker's `self.state` object itself (not a copy).
        """
        turn = [user_utt] if sys_utt is None else [sys_utt, user_utt]
        self.history.extend(turn)
        # model can do some modification to state here
        return self.state
import os
import json
import zipfile
def load_test_data(subtask):
    """Load the released DSTC9 test dialogs for a sub-task.

    Args:
        subtask (str): 'multiwoz' selects the ``multiwoz_zh`` data directory;
            any other value selects ``crosswoz_en``.

    Returns:
        The object parsed from ``data.json`` inside ``dstc9-test-250.zip``.

    Raises:
        AssertionError: if the parsed data does not contain exactly 250 entries.
    """
    from convlab2 import DATA_ROOT

    data_dir = os.path.join(DATA_ROOT, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en')
    # test public data currently
    # to check if this script works properly with your code when label information is
    # not available, you may need to fill the missing fields yourself (with any value)
    zip_filename = os.path.join(data_dir, 'dstc9-test-250.zip')
    # Context managers close both the archive and the member stream instead of
    # leaving them to the garbage collector.
    with zipfile.ZipFile(zip_filename) as archive, archive.open('data.json') as f:
        test_data = json.load(f)
    assert len(test_data) == 250
    return test_data
def prepare_data(subtask):
    """Turn the raw test dialogs into per-turn (sys_utt, user_utt, state) tuples.

    Turns are consumed in pairs: the even-indexed turn supplies the user
    utterance, the preceding turn (if any) the system utterance, and the
    following turn carries the state annotation.

    Args:
        subtask (str): 'multiwoz' uses the ``log`` / ``text`` / ``metadata``
            fields; any other value uses ``messages`` / ``content`` /
            ``sys_state_init``.

    Returns:
        dict: dialog id -> list of ``(sys_utt, user_utt, state)``, where
        ``sys_utt`` is ``None`` for the first turn and ``state`` maps
        domain name -> {slot name: value}.
    """
    def multiwoz_state(metadata):
        # Flatten every slot group of a domain into one dict; the two
        # excluded domains are dropped entirely.
        return {
            domain_name: {
                slot_name: value
                for slots in domain.values()
                for slot_name, value in slots.items()
            }
            for domain_name, domain in metadata.items()
            if domain_name not in ['警察机关', '医院']
        }

    def crosswoz_state(sys_state_init):
        # Copy every slot except the 'selectedResults' entry.
        return {
            domain_name: {
                slot_name: value
                for slot_name, value in domain.items()
                if slot_name != 'selectedResults'
            }
            for domain_name, domain in sys_state_init.items()
        }

    test_data = load_test_data(subtask)
    if subtask == 'multiwoz':
        turns_key, text_key, state_key, extract = 'log', 'text', 'metadata', multiwoz_state
    else:
        turns_key, text_key, state_key, extract = 'messages', 'content', 'sys_state_init', crosswoz_state

    data = {}
    for dialog_id, dialog in test_data.items():
        turns = dialog[turns_key]
        dialog_data = []
        for user_idx in range(0, len(turns), 2):
            sys_utt = turns[user_idx - 1][text_key] if user_idx else None
            user_utt = turns[user_idx][text_key]
            state = extract(turns[user_idx + 1][state_key])
            dialog_data.append((sys_utt, user_utt, state))
        data[dialog_id] = dialog_data
    return data
def eval_states(gt, pred):
    """Score predicted dialog states against the ground truth.

    Args:
        gt (dict): dialog id -> list of per-turn states, each state mapping
            domain name -> {slot name: value}.
        pred (dict): predictions in the same layout. Dialogs, domains and
            slots that do not appear in ``gt`` are ignored.

    Returns:
        dict: ``{'status': 'ok', 'joint accuracy': ..., 'slot accuracy': ...}``
        on success. Joint accuracy is the fraction of turns in which every
        ground-truth slot matches; slot accuracy is the fraction of matching
        slots. If a dialog, domain or slot is missing from ``pred``, the turn
        counts differ, or there is no slot to score, a dict with
        ``'status': 'exception'``, a ``'description'`` and location details
        is returned instead.
    """
    def exception(description, **kargs):
        # Problems are reported as a status dict rather than raised.
        return {'status': 'exception', 'description': description, **kargs}

    joint_acc, joint_tot = 0, 0
    slot_acc, slot_tot = 0, 0
    for dialog_id, gt_states in gt.items():
        if dialog_id not in pred:
            return exception('some dialog not found', dialog_id=dialog_id)

        pred_states = pred[dialog_id]
        if len(gt_states) != len(pred_states):
            return exception(f'turns number incorrect, {len(gt_states)} expected, {len(pred_states)} found', dialog_id=dialog_id)

        for turn_id, (gt_state, pred_state) in enumerate(zip(gt_states, pred_states)):
            joint_tot += 1
            turn_result = True
            for domain_name, gt_domain in gt_state.items():
                if domain_name not in pred_state:
                    return exception('domain missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name)

                pred_domain = pred_state[domain_name]
                for slot_name, gt_value in gt_domain.items():
                    if slot_name not in pred_domain:
                        return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name)

                    slot_tot += 1
                    if gt_value == pred_domain[slot_name]:
                        slot_acc += 1
                    else:
                        turn_result = False
            joint_acc += turn_result

    # Without any slot (which includes the case of no turn at all) the
    # accuracies are undefined; report that instead of dividing by zero.
    if slot_tot == 0:
        return exception('no slot to evaluate')

    return {
        'status': 'ok',
        'joint accuracy': joint_acc / joint_tot,
        'slot accuracy': slot_acc / slot_tot,
    }
File added
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment