Skip to content
Snippets Groups Projects
Commit b91364f8 authored by function2's avatar function2
Browse files

dstc9 xldst evaluation

parent dff824fc
Branches
No related tags found
No related merge requests found
"""Dialog State Tracker Interface"""
from convlab2.util.module import Module
import copy
from abc import abstractmethod
from convlab2.util.module import Module
class DST(Module):
......@@ -18,6 +20,21 @@ class DST(Module):
"""
pass
@abstractmethod
def update_turn(self, sys_utt, user_utt):
""" Update the internal dialog state variable with .
Args:
sys_utt (str):
system utterance of current turn, set to `None` for the first turn
user_utt (str):
user utterance of current turn
Returns:
new_state (dict):
Updated dialog state, with the same form of previous state.
"""
pass
def to_cache(self, *args, **kwargs):
return copy.deepcopy(self.state)
......
"""
evaluate output file
"""
from convlab2.dst.dstc9.utils import prepare_data, eval_states
if __name__ == '__main__':
import os
import json
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
args = parser.parse_args()
gt = {
dialog_id: [state for _, _, state in turns]
for dialog_id, turns in prepare_data(args.subtask).items()
}
# json.dump(gt, open('gt-crosswoz.json', 'w'), ensure_ascii=False, indent=4)
results = {}
for i in range(1, 6):
filename = f'submission{i}.json'
if not os.path.exists(filename):
continue
pred = json.load(open(filename))
results[filename] = eval_states(gt, pred)
json.dump(results, open('results.json', 'w'), indent=4, ensure_ascii=False)
"""
evaluate DST model
"""
import os
import json
import importlib
from convlab2.dst import DST
from convlab2.dst.dstc9.utils import prepare_data, eval_states
def evaluate(model_name, subtask):
subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en'
module = importlib.import_module(f'{model_name}.{subdir}')
assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root'
model_cls = module.__getattribute__('Model')
assert issubclass(model_cls, DST), 'the model must implement DST interface'
# load weights, set eval() on default
model = model_cls()
gt = {}
pred = {}
for dialog_id, turns in prepare_data(subtask).items():
gt_dialog = []
pred_dialog = []
model.init_session()
for sys_utt, user_utt, gt_turn in turns:
gt_dialog.append(gt_turn)
pred_dialog.append(model.update_turn(sys_utt, user_utt))
gt[dialog_id] = gt_dialog
pred[dialog_id] = pred_dialog
result = eval_states(gt, pred)
print(result)
json.dump(result, open(os.path.join(model_name, subdir, 'result.json'), 'w'), indent=4, ensure_ascii=False)
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
args = parser.parse_args()
evaluate('example', args.subtask)
from .model import ExampleModel as Model
from convlab2.dst import DST
class ExampleModel(DST):
def update_turn(self, sys_utt, user_utt):
return {
"Attraction": {
"name": "",
"fee": "",
"duration": "",
"rating": "",
"nearby attract.": "",
"nearby rest.": "",
"nearby hotels": ""
},
"Restaurant": {
"name": "",
"dishes": "",
"cost": "",
"rating": "",
"nearby attract.": "",
"nearby rest.": "",
"nearby hotels": ""
},
"Hotel": {
"name": "",
"type": "",
"Hotel Facilities": "",
"price": "",
"rating": "",
"nearby attract.": "",
"nearby rest.": "",
"nearby hotels": ""
},
"Metro": {
"from": "",
"to": ""
},
"Taxi": {
"from": "",
"to": ""
}
}
from .model import ExampleModel as Model
from convlab2.dst import DST
class ExampleModel(DST):
def update_turn(self, sys_utt, user_utt):
return {
"出租车": {
"出发时间": "",
"目的地": "",
"出发地": "",
"到达时间": "",
},
"餐厅": {
"时间": "",
"日期": "",
"人数": "",
"食物": "",
"价格范围": "",
"名称": "",
"区域": "",
},
"公共汽车": {
"人数": "",
"出发时间": "",
"目的地": "",
"日期": "",
"到达时间": "",
"出发地": "",
},
"旅馆": {
"停留天数": "",
"日期": "",
"人数": "",
"名称": "",
"区域": "",
"停车处": "",
"价格范围": "",
"星级": "",
"互联网": "",
"类型": "",
},
"景点": {
"类型": "",
"名称": "",
"区域": "",
},
"列车": {
"票价": "",
"人数": "",
"出发时间": "",
"目的地": "",
"日期": "",
"到达时间": "未提及",
"出发地": "未提及",
},
}
import os
import json
import zipfile
def load_test_data(subtask):
from convlab2 import DATA_ROOT
data_dir = os.path.join(DATA_ROOT, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en')
# the filename will change to test file during testing phase
zip_filename = os.path.join(data_dir, 'human_val.json.zip')
return json.load(zipfile.ZipFile(zip_filename).open('human_val.json'))
def prepare_data(subtask):
test_data = load_test_data(subtask)
data = {}
if subtask == 'multiwoz':
for dialog_id, dialog in test_data.items():
dialog_data = []
turns = dialog['log']
for i in range(0, len(turns), 2):
sys_utt = turns[i - 1]['text'] if i else None
user_utt = turns[i]['text']
state = {}
for domain_name, domain in turns[i + 1]['metadata'].items():
if domain_name in ['警察机关', '医院']:
continue
domain_state = {}
for slots in domain.values():
for slot_name, value in slots.items():
domain_state[slot_name] = value
state[domain_name] = domain_state
dialog_data.append((sys_utt, user_utt, state))
data[dialog_id] = dialog_data
else:
for dialog_id, dialog in test_data.items():
dialog_data = []
turns = dialog['messages']
for i in range(0, len(turns), 2):
sys_utt = turns[i - 1]['content'] if i else None
user_utt = turns[i]['content']
state = {}
for domain_name, domain in turns[i + 1]['sys_state_init'].items():
domain_state = {}
for slot_name, value in domain.items():
if slot_name == 'selectedResults':
continue
domain_state[slot_name] = value
state[domain_name] = domain_state
dialog_data.append((sys_utt, user_utt, state))
data[dialog_id] = dialog_data
return data
def eval_states(gt, pred):
def exception(description, **kargs):
ret = {
'status': 'exception',
'description': description,
}
for k, v in kargs.items():
ret[k] = v
return ret
joint_acc, joint_tot = 0, 0
slot_acc, slot_tot = 0, 0
tp, fp, fn = 0, 0, 0
for dialog_id, gt_states in gt.items():
if dialog_id not in pred:
return exception('some dialog not found', dialog_id=dialog_id)
pred_states = pred[dialog_id]
if len(gt_states) != len(pred_states):
return exception(f'turns number incorrect, {len(gt_states)} expected, {len(pred_states)} found', dialog_id=dialog_id)
for turn_id, (gt_state, pred_state) in enumerate(zip(gt_states, pred_states)):
joint_tot += 1
turn_result = True
for domain_name, gt_domain in gt_state.items():
if domain_name not in pred_state:
return exception('domain missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name)
pred_domain = pred_state[domain_name]
for slot_name, gt_value in gt_domain.items():
if slot_name not in pred_domain:
return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name)
pred_value = pred_domain[slot_name]
slot_tot += 1
if gt_value == pred_value:
slot_acc += 1
tp += 1
else:
turn_result = False
# for class of gt_value
fn += 1
# for class of pred_value
fp += 1
joint_acc += turn_result
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * tp / (2 * tp + fp + fn)
return {
'status': 'ok',
'joint accuracy': joint_acc / joint_tot,
'slot': {
'accuracy': slot_acc / slot_tot,
'precision': precision,
'recall': recall,
'f1': f1,
}
}
import random
from argparse import ArgumentParser
import numpy as np
import torch
parser = ArgumentParser()
parser.add_argument('--seed', type=int, default=23333)
parser.add_argument('--subtask', type=str, required=True, choices=['multiwoz', 'crosswoz'])
args = parser.parse_args()
# make your model's behavior deterministic
seed = args.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
print(f'fix random seed: {seed}')
subtask = args.subtask
if __name__ == '__main__':
pass
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment