diff --git a/convlab2/dst/dst.py b/convlab2/dst/dst.py index 946f711bdb7a9f5c283749f6d5c3921be78ee09f..b2f45fdd74c0c4d6752ef9605c3066877e3c131e 100755 --- a/convlab2/dst/dst.py +++ b/convlab2/dst/dst.py @@ -1,6 +1,8 @@ """Dialog State Tracker Interface""" -from convlab2.util.module import Module import copy +from abc import abstractmethod + +from convlab2.util.module import Module class DST(Module): @@ -18,6 +20,21 @@ class DST(Module): """ pass + @abstractmethod + def update_turn(self, sys_utt, user_utt): + """ Update the internal dialog state variable with . + + Args: + sys_utt (str): + system utterance of current turn, set to `None` for the first turn + user_utt (str): + user utterance of current turn + Returns: + new_state (dict): + Updated dialog state, with the same form of previous state. + """ + pass + def to_cache(self, *args, **kwargs): return copy.deepcopy(self.state) diff --git a/convlab2/dst/dstc9/__init__.py b/convlab2/dst/dstc9/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py new file mode 100644 index 0000000000000000000000000000000000000000..fa3187d556bc9997cb1bf130f94afdafbe852ff4 --- /dev/null +++ b/convlab2/dst/dstc9/eval_file.py @@ -0,0 +1,29 @@ +""" + evaluate output file +""" + +from convlab2.dst.dstc9.utils import prepare_data, eval_states + +if __name__ == '__main__': + import os + import json + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) + args = parser.parse_args() + + gt = { + dialog_id: [state for _, _, state in turns] + for dialog_id, turns in prepare_data(args.subtask).items() + } + # json.dump(gt, open('gt-crosswoz.json', 'w'), ensure_ascii=False, indent=4) + + results = {} + for i in range(1, 6): + filename = f'submission{i}.json' + if not os.path.exists(filename): + continue + pred = json.load(open(filename)) + results[filename] = eval_states(gt, pred) + + json.dump(results, open('results.json', 'w'), indent=4, ensure_ascii=False) diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9b86d0577be01c4532688ec9782bc184cb0dd23b --- /dev/null +++ b/convlab2/dst/dstc9/eval_model.py @@ -0,0 +1,42 @@ +""" + evaluate DST model +""" + +import os +import json +import importlib + +from convlab2.dst import DST +from convlab2.dst.dstc9.utils import prepare_data, eval_states + + +def evaluate(model_name, subtask): + subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en' + module = importlib.import_module(f'{model_name}.{subdir}') + assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root' + model_cls = module.__getattribute__('Model') + assert issubclass(model_cls, DST), 'the model must implement DST interface' + # load weights, set eval() on default + model = model_cls() + gt = {} + pred = {} + for dialog_id, turns in prepare_data(subtask).items(): + gt_dialog = [] + pred_dialog = [] + model.init_session() + for sys_utt, user_utt, gt_turn in turns: + gt_dialog.append(gt_turn) + pred_dialog.append(model.update_turn(sys_utt, user_utt)) + gt[dialog_id] = gt_dialog + pred[dialog_id] = pred_dialog + result = eval_states(gt, pred) + print(result) + json.dump(result, open(os.path.join(model_name, subdir, 'result.json'), 'w'), indent=4, ensure_ascii=False) + + +if __name__ == '__main__': + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) + args = parser.parse_args() + evaluate('example', args.subtask) diff --git a/convlab2/dst/dstc9/example/.gitignore b/convlab2/dst/dstc9/example/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4b544e7771d6b7257bf1ea3aedbf42519654abad --- /dev/null +++ b/convlab2/dst/dstc9/example/.gitignore @@ -0,0 +1 @@ +*/result.json diff --git a/convlab2/dst/dstc9/example/__init__.py b/convlab2/dst/dstc9/example/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/convlab2/dst/dstc9/example/crosswoz_en/__init__.py b/convlab2/dst/dstc9/example/crosswoz_en/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..30800a8ab77cda37677f4e850ba8cfd3db6c30cc --- /dev/null +++ b/convlab2/dst/dstc9/example/crosswoz_en/__init__.py @@ -0,0 +1 @@ +from .model import ExampleModel as Model diff --git a/convlab2/dst/dstc9/example/crosswoz_en/model.py b/convlab2/dst/dstc9/example/crosswoz_en/model.py new file mode 100644 index 0000000000000000000000000000000000000000..542fd4799ce168db23aaaa699fa1fa42745bff69 --- /dev/null +++ b/convlab2/dst/dstc9/example/crosswoz_en/model.py @@ -0,0 +1,51 @@ +from convlab2.dst import DST + + +class ExampleModel(DST): + def init_session(self): + self.history = [] + self.state = { + "Attraction": { + "name": "", + "fee": "", + "duration": "", + "rating": "", + "nearby attract.": "", + "nearby rest.": "", + "nearby hotels": "" + }, + "Restaurant": { + "name": "", + "dishes": "", + "cost": "", + "rating": "", + "nearby attract.": "", + "nearby rest.": "", + "nearby hotels": "" + }, + "Hotel": { + "name": "", + "type": "", + "Hotel Facilities": "", + "price": "", + "rating": "", + "nearby attract.": "", + "nearby rest.": "", + "nearby hotels": "" + }, + "Metro": { + "from": "", + "to": "" + }, + "Taxi": { + "from": "", + "to": "" + } + } + + def update_turn(self, sys_utt, user_utt): + if sys_utt is not None: + self.history.append(sys_utt) + self.history.append(user_utt) + # model can do some modification to state here + return self.state diff --git a/convlab2/dst/dstc9/example/multiwoz_zh/__init__.py b/convlab2/dst/dstc9/example/multiwoz_zh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..30800a8ab77cda37677f4e850ba8cfd3db6c30cc --- /dev/null +++ b/convlab2/dst/dstc9/example/multiwoz_zh/__init__.py @@ -0,0 +1 @@ +from .model import ExampleModel as Model diff --git a/convlab2/dst/dstc9/example/multiwoz_zh/model.py b/convlab2/dst/dstc9/example/multiwoz_zh/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1a81cdb34909d5ec4b29428c47281ec8c47db9c9 --- /dev/null +++ b/convlab2/dst/dstc9/example/multiwoz_zh/model.py @@ -0,0 +1,64 @@ +from convlab2.dst import DST + + +class ExampleModel(DST): + def init_session(self): + self.history = [] + self.state = { + "出租车": { + "出发时间": "", + "目的地": "", + "出发地": "", + "到达时间": "", + }, + "餐厅": { + "时间": "", + "日期": "", + "人数": "", + "食物": "", + "价格范围": "", + "名称": "", + "区域": "", + }, + "公共汽车": { + "人数": "", + "出发时间": "", + "目的地": "", + "日期": "", + "到达时间": "", + "出发地": "", + }, + "旅馆": { + "停留天数": "", + "日期": "", + "人数": "", + "名称": "", + "区域": "", + "停车处": "", + "价格范围": "", + "星级": "", + "互联网": "", + "类型": "", + }, + "景点": { + "类型": "", + "名称": "", + "区域": "", + }, + "列车": { + "票价": "", + "人数": "", + "出发时间": "", + "目的地": "", + "日期": "", + "到达时间": "未提及", + "出发地": "未提及", + }, + } + + def update_turn(self, sys_utt, user_utt): + if sys_utt is not None: + self.history.append(sys_utt) + self.history.append(user_utt) + # model can do some modification to state here + return self.state diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..75b60774d9f076067ba84e922e18875fecf4bc60 --- /dev/null +++ b/convlab2/dst/dstc9/utils.py @@ -0,0 +1,118 @@ +import os +import json +import zipfile + + +def load_test_data(subtask): + from convlab2 import DATA_ROOT + data_dir = os.path.join(DATA_ROOT, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en') + # test public data currently + # to check if this script works properly with your code when label information is + # not available, you may need to fill the missing fields yourself (with any value) + zip_filename = os.path.join(data_dir, 'dstc9-test-250.zip') + test_data = json.load(zipfile.ZipFile(zip_filename).open('data.json')) + assert len(test_data) == 250 + return test_data + + +def prepare_data(subtask): + test_data = load_test_data(subtask) + data = {} + if subtask == 'multiwoz': + for dialog_id, dialog in test_data.items(): + dialog_data = [] + turns = dialog['log'] + for i in range(0, len(turns), 2): + sys_utt = turns[i - 1]['text'] if i else None + user_utt = turns[i]['text'] + state = {} + for domain_name, domain in turns[i + 1]['metadata'].items(): + if domain_name in ['警察机关', '医院']: + continue + domain_state = {} + for slots in domain.values(): + for slot_name, value in slots.items(): + domain_state[slot_name] = value + state[domain_name] = domain_state + dialog_data.append((sys_utt, user_utt, state)) + data[dialog_id] = dialog_data + else: + for dialog_id, dialog in test_data.items(): + dialog_data = [] + turns = dialog['messages'] + for i in range(0, len(turns), 2): + sys_utt = turns[i - 1]['content'] if i else None + user_utt = turns[i]['content'] + state = {} + for domain_name, domain in turns[i + 1]['sys_state_init'].items(): + domain_state = {} + for slot_name, value in domain.items(): + if slot_name == 'selectedResults': + continue + domain_state[slot_name] = value + state[domain_name] = domain_state + dialog_data.append((sys_utt, user_utt, state)) + data[dialog_id] = dialog_data + + return data + + +def eval_states(gt, pred): + def exception(description, **kargs): + ret = { + 'status': 'exception', + 'description': description, + } + for k, v in kargs.items(): + ret[k] = v + return ret + + joint_acc, joint_tot = 0, 0 + slot_acc, slot_tot = 0, 0 + tp, fp, fn = 0, 0, 0 + for dialog_id, gt_states in gt.items(): + if dialog_id not in pred: + return exception('some dialog not found', dialog_id=dialog_id) + + pred_states = pred[dialog_id] + if len(gt_states) != len(pred_states): + return exception(f'turns number incorrect, {len(gt_states)} expected, {len(pred_states)} found', dialog_id=dialog_id) + + for turn_id, (gt_state, pred_state) in enumerate(zip(gt_states, pred_states)): + joint_tot += 1 + turn_result = True + for domain_name, gt_domain in gt_state.items(): + if domain_name not in pred_state: + return exception('domain missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name) + + pred_domain = pred_state[domain_name] + for slot_name, gt_value in gt_domain.items(): + if slot_name not in pred_domain: + return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name) + pred_value = pred_domain[slot_name] + slot_tot += 1 + if gt_value == pred_value: + slot_acc += 1 + tp += 1 + else: + turn_result = False + # for class of gt_value + fn += 1 + # for class of pred_value + fp += 1 + joint_acc += turn_result + + precision = tp / (tp + fp) + recall = tp / (tp + fn) + f1 = 2 * tp / (2 * tp + fp + fn) + return { + 'status': 'ok', + 'joint accuracy': joint_acc / joint_tot, + 'slot accuracy': slot_acc / slot_tot, + # 'slot': { + # 'accuracy': slot_acc / slot_tot, + # 'precision': precision, + # 'recall': recall, + # 'f1': f1, + # } + } diff --git a/data/crosswoz_en/dstc9-test-250.zip b/data/crosswoz_en/dstc9-test-250.zip new file mode 100644 index 0000000000000000000000000000000000000000..e9237d0b119004bd28a0aa20b17e1f76faadd2e6 Binary files /dev/null and b/data/crosswoz_en/dstc9-test-250.zip differ diff --git a/data/multiwoz_zh/dstc9-test-250.zip b/data/multiwoz_zh/dstc9-test-250.zip new file mode 100644 index 0000000000000000000000000000000000000000..e70cc90ba0be060dcb97e78195f0d0a1cbd984ca Binary files /dev/null and b/data/multiwoz_zh/dstc9-test-250.zip differ