diff --git a/README.md b/README.md index 5433d94db5996b48368df4354bcd795ee9461c24..252a121c39c87e00598ae7fbff313c22af7b6f6e 100755 --- a/README.md +++ b/README.md @@ -208,9 +208,9 @@ evaluation of our pre-trained models are: (joint acc.) | type | CrossWOZ-en | MultiWOZ-zh | | ----- | ----------- | ----------- | -| val | 12.2% | 44.8% | -| test | 12.4% | 42.3% | -| human_val | 10.9% | 48.2% | +| val | 12.4% | 45.1% | +| test | 12.4% | 43.5% | +| human_val | 10.6% | 49.4% | `human_val` option will make the model evaluate on the validation set translated by human. diff --git a/convlab2/__init__.py b/convlab2/__init__.py index e28e3a1879d84c5ed920912e58b6915dcb13d0e8..87a7442310d0d5bad9dbeae9b1b29041d4490067 100755 --- a/convlab2/__init__.py +++ b/convlab2/__init__.py @@ -1,3 +1,5 @@ +import os + from convlab2.nlu import NLU from convlab2.dst import DST from convlab2.policy import Policy @@ -11,6 +13,5 @@ from os.path import abspath, dirname def get_root_path(): return dirname(dirname(abspath(__file__))) -import os -DATA_ROOT = os.path.join(get_root_path(), 'data') \ No newline at end of file +DATA_ROOT = os.path.join(get_root_path(), 'data') diff --git a/convlab2/dst/evaluate.py b/convlab2/dst/evaluate.py index 053140f0108205c4148327bdb2fdba9ee9144b36..842dde702c2402470eee14ad2bbe8d672b10127f 100755 --- a/convlab2/dst/evaluate.py +++ b/convlab2/dst/evaluate.py @@ -11,17 +11,40 @@ from tqdm import tqdm import copy import jieba -multiwoz_slot_list = ['attraction-area', 'attraction-name', 'attraction-type', 'hotel-day', 'hotel-people', 'hotel-stay', 'hotel-area', 'hotel-internet', 'hotel-name', 'hotel-parking', 'hotel-pricerange', 'hotel-stars', 'hotel-type', 'restaurant-day', 'restaurant-people', 'restaurant-time', 'restaurant-area', 'restaurant-food', 'restaurant-name', 'restaurant-pricerange', 'taxi-arriveby', 'taxi-departure', 'taxi-destination', 'taxi-leaveat', 'train-people', 'train-arriveby', 'train-day', 'train-departure', 'train-destination', 'train-leaveat'] -crosswoz_slot_list = ["景点-门票", "景点-评分", "餐馆-名称", "酒店-价格", "酒店-评分", "景点-名称", "景点-地址", "景点-游玩时间", "餐馆-营业时间", "餐馆-评分", "酒店-名称", "酒店-周边景点", "酒店-酒店设施-叫醒服务", "酒店-酒店类型", "餐馆-人均消费", "餐馆-推荐菜", "酒店-酒店设施", "酒店-电话", "景点-电话", "餐馆-周边餐馆", "餐馆-电话", "餐馆-none", "餐馆-地址", "酒店-酒店设施-无烟房", "酒店-地址", "景点-周边景点", "景点-周边酒店", "出租-出发地", "出租-目的地", "地铁-出发地", "地铁-目的地", "景点-周边餐馆", "酒店-周边餐馆", "出租-车型", "餐馆-周边景点", "餐馆-周边酒店", "地铁-出发地附近地铁站", "地铁-目的地附近地铁站", "景点-none", "酒店-酒店设施-商务中心", "餐馆-源领域", "酒店-酒店设施-中式餐厅", "酒店-酒店设施-接站服务", "酒店-酒店设施-国际长途电话", "酒店-酒店设施-吹风机", "酒店-酒店设施-会议室", "酒店-源领域", "酒店-none", "酒店-酒店设施-宽带上网", "酒店-酒店设施-看护小孩服务", "酒店-酒店设施-酒店各处提供wifi", "酒店-酒店设施-暖气", "酒店-酒店设施-spa", "出租-车牌", "景点-源领域", "酒店-酒店设施-行李寄存", "酒店-酒店设施-西式餐厅", "酒店-酒店设施-酒吧", "酒店-酒店设施-早餐服务", "酒店-酒店设施-健身房", "酒店-酒店设施-残疾人设施", "酒店-酒店设施-免费市内电话", "酒店-酒店设施-接待外宾", "酒店-酒店设施-部分房间提供wifi", "酒店-酒店设施-洗衣服务", "酒店-酒店设施-租车", "酒店-酒店设施-公共区域和部分房间提供wifi", "酒店-酒店设施-24小时热水", "酒店-酒店设施-温泉", "酒店-酒店设施-桑拿", "酒店-酒店设施-收费停车位", "酒店-周边酒店", "酒店-酒店设施-接机服务", "酒店-酒店设施-所有房间提供wifi", "酒店-酒店设施-棋牌室", "酒店-酒店设施-免费国内长途电话", "酒店-酒店设施-室内游泳池", "酒店-酒店设施-早餐服务免费", "酒店-酒店设施-公共区域提供wifi", "酒店-酒店设施-室外游泳池"] +multiwoz_slot_list = [ + 'attraction-area', 'attraction-name', 'attraction-type', 'hotel-day', 'hotel-people', + 'hotel-stay', 'hotel-area', 'hotel-internet', 'hotel-name', 'hotel-parking', 'hotel-pricerange', + 'hotel-stars', 'hotel-type', 'restaurant-day', 'restaurant-people', 'restaurant-time', + 'restaurant-area', 'restaurant-food', 'restaurant-name', 'restaurant-pricerange', 'taxi-arriveby', + 'taxi-departure', 'taxi-destination', 'taxi-leaveat', 'train-people', 'train-arriveby', + 'train-day', 'train-departure', 'train-destination', 'train-leaveat' +] +crosswoz_slot_list = [ + "景点-门票", "景点-评分", "餐馆-名称", "酒店-价格", "酒店-评分", "景点-名称", "景点-地址", "景点-游玩时间", "餐馆-营业时间", "餐馆-评分", + "酒店-名称", "酒店-周边景点", "酒店-酒店设施-叫醒服务", "酒店-酒店类型", "餐馆-人均消费", "餐馆-推荐菜", "酒店-酒店设施", "酒店-电话", "景点-电话", + "餐馆-周边餐馆", "餐馆-电话", "餐馆-none", "餐馆-地址", "酒店-酒店设施-无烟房", "酒店-地址", "景点-周边景点", "景点-周边酒店", "出租-出发地", + "出租-目的地", "地铁-出发地", "地铁-目的地", "景点-周边餐馆", "酒店-周边餐馆", "出租-车型", "餐馆-周边景点", "餐馆-周边酒店", "地铁-出发地附近地铁站", + "地铁-目的地附近地铁站", "景点-none", "酒店-酒店设施-商务中心", "餐馆-源领域", "酒店-酒店设施-中式餐厅", "酒店-酒店设施-接站服务", + "酒店-酒店设施-国际长途电话", "酒店-酒店设施-吹风机", "酒店-酒店设施-会议室", "酒店-源领域", "酒店-none", "酒店-酒店设施-宽带上网", + "酒店-酒店设施-看护小孩服务", "酒店-酒店设施-酒店各处提供wifi", "酒店-酒店设施-暖气", "酒店-酒店设施-spa", "出租-车牌", "景点-源领域", + "酒店-酒店设施-行李寄存", "酒店-酒店设施-西式餐厅", "酒店-酒店设施-酒吧", "酒店-酒店设施-早餐服务", "酒店-酒店设施-健身房", "酒店-酒店设施-残疾人设施", + "酒店-酒店设施-免费市内电话", "酒店-酒店设施-接待外宾", "酒店-酒店设施-部分房间提供wifi", "酒店-酒店设施-洗衣服务", "酒店-酒店设施-租车", + "酒店-酒店设施-公共区域和部分房间提供wifi", "酒店-酒店设施-24小时热水", "酒店-酒店设施-温泉", "酒店-酒店设施-桑拿", "酒店-酒店设施-收费停车位", + "酒店-周边酒店", "酒店-酒店设施-接机服务", "酒店-酒店设施-所有房间提供wifi", "酒店-酒店设施-棋牌室", "酒店-酒店设施-免费国内长途电话", + "酒店-酒店设施-室内游泳池", "酒店-酒店设施-早餐服务免费", "酒店-酒店设施-公共区域提供wifi", "酒店-酒店设施-室外游泳池" +] + from convlab2.dst.sumbt.multiwoz_zh.sumbt import multiwoz_zh_slot_list from convlab2.dst.sumbt.crosswoz_en.sumbt import crosswoz_en_slot_list + def format_history(context): history = [] for i in range(len(context)): - history.append(['sys' if i%2==1 else 'usr', context[i]]) + history.append(['sys' if i % 2 == 1 else 'usr', context[i]]) return history + def sentseg(sent): sent = sent.replace('\t', ' ') sent = ' '.join(sent.split()) @@ -215,6 +238,8 @@ if __name__ == '__main__': if model_name == 'sumbt': from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker model = SUMBTTracker() + else: + raise Exception("Available models: sumbt") else: if model_name == 'TRADE': from convlab2.dst.trade.crosswoz.trade import CrossWOZTRADE @@ -228,7 +253,7 @@ if __name__ == '__main__': else: raise Exception("Available models: TRADE") - ## load data + # load data from convlab2.util.dataloader.module_dataloader import CrossWOZAgentDSTDataloader from convlab2.util.dataloader.dataset_dataloader import CrossWOZDataloader diff --git a/convlab2/dst/sumbt/BeliefTrackerSlotQueryMultiSlot.py b/convlab2/dst/sumbt/BeliefTrackerSlotQueryMultiSlot.py index bd3dc57052afb186627ccacc73c6fa679dc99b2d..c33ee6891b31fdf02947e5f298b4fce015401560 100755 --- a/convlab2/dst/sumbt/BeliefTrackerSlotQueryMultiSlot.py +++ b/convlab2/dst/sumbt/BeliefTrackerSlotQueryMultiSlot.py @@ -150,6 +150,9 @@ class BeliefTracker(nn.Module): ### Etc. self.dropout = nn.Dropout(self.hidden_dropout_prob) + # default evaluation mode + self.eval() + def initialize_slot_value_lookup(self, label_ids, slot_ids): self.sv_encoder.eval() diff --git a/data/crosswoz/extract_all_ontology.py b/data/crosswoz/extract_all_ontology.py new file mode 100644 index 0000000000000000000000000000000000000000..4ca4deaee5ae01b850793667b4e6bee8d17064e4 --- /dev/null +++ b/data/crosswoz/extract_all_ontology.py @@ -0,0 +1,203 @@ +""" +extract all value appear in dialog act and state, for translation. +""" +import json +import zipfile +import re +from pprint import pprint + +def read_zipped_json(filepath, filename): + archive = zipfile.ZipFile(filepath, 'r') + return json.load(archive.open(filename)) + + +zh_pattern = re.compile(u'[\u4e00-\u9fa5]+') +multi_name = set() + +def extract_ontology(data): + intent_set = set() + domain_set = set() + slot_set = set() + value_set = {} + for _, sess in data.items(): + for i, turn in enumerate(sess['messages']): + for intent, domain, slot, value in turn['dialog_act']: + intent_set.add(intent) + domain_set.add(domain) + slot_set.add(slot) + if not domain in value_set: + value_set[domain] = {} + elif slot not in value_set[domain]: + value_set[domain][slot] = set() + elif slot in ['推荐菜', '名称', '酒店设施']: + if slot == '名称' and len(value.split()) > 1: + multi_name.add((domain, slot, value)) + elif slot == '推荐菜' and '-' in value: + print((domain, slot, value)) + for dish in value.split(): + value_set[domain][slot].add(dish) + elif slot == 'selectedResults' and domain in ['地铁', '出租']: + if domain == '地铁': + value = value[5:] + value_set[domain][slot].add(value.strip()) + if domain == '出租': + value = value[4:-1] + for v in value.split(' - '): + value_set[domain][slot].add(v.strip()) + else: + value_set[domain][slot].add(value) + if turn['role'] == 'usr': + for _, domain, slot, value, __ in turn['user_state']: + domain_set.add(domain) + slot_set.add(slot) + if isinstance(value, list): + for v in value: + if not domain in value_set: + value_set[domain] = {} + elif slot not in value_set[domain]: + value_set[domain][slot] = set() + elif slot in ['推荐菜', '名称', '酒店设施']: + if slot == '名称' and len(value.split()) > 1: + multi_name.add((domain, slot, value)) + elif slot == '推荐菜' and '-' in value: + print((domain, slot, value)) + for dish in v.split(): + value_set[domain][slot].add(dish) + elif slot == 'selectedResults' and domain in ['地铁', '出租']: + if domain == '地铁': + v = v[5:] + value_set[domain][slot].add(v.strip()) + if domain == '出租': + v = v[4:-1] + for item in v.split(' - '): + value_set[domain][slot].add(item.strip()) + else: + value_set[domain][slot].add(v) + else: + assert isinstance(value, str) + if not domain in value_set: + value_set[domain] = {} + elif slot not in value_set[domain]: + value_set[domain][slot] = set() + elif slot in ['推荐菜', '名称', '酒店设施']: + if slot == '名称' and len(value.split()) > 1: + multi_name.add((domain, slot, value)) + elif slot == '推荐菜' and '-' in value: + print((domain, slot, value)) + for dish in value.split(): + value_set[domain][slot].add(dish) + elif slot == 'selectedResults' and domain in ['地铁', '出租']: + if domain == '地铁': + value = value[5:] + value_set[domain][slot].add(value.strip()) + if domain == '出租': + value = value[4:-1] + for v in value.split(' - '): + value_set[domain][slot].add(v.strip()) + else: + value_set[domain][slot].add(value) + else: + for state_key in ['sys_state', 'sys_state_init']: + for domain, svd in turn[state_key].items(): + domain_set.add(domain) + for slot, value in svd.items(): + slot_set.add(slot) + if isinstance(value, list): + for v in value: + if not domain in value_set: + value_set[domain] = {} + elif slot not in value_set[domain]: + value_set[domain][slot] = set() + elif slot in ['推荐菜', '名称', '酒店设施']: + if slot == '名称' and len(value.split()) > 1: + multi_name.add((domain, slot, value)) + elif slot == '推荐菜' and '-' in value: + print((domain, slot, value)) + for dish in v.split(): + value_set[domain][slot].add(dish) + elif slot == 'selectedResults' and domain in ['地铁', '出租']: + if domain == '地铁': + v = v[5:] + value_set[domain][slot].add(v.strip()) + if domain == '出租': + v = v[4:-1] + for item in v.split(' - '): + value_set[domain][slot].add(item.strip()) + else: + value_set[domain][slot].add(v) + else: + assert isinstance(value, str) + if not domain in value_set: + value_set[domain] = {} + elif slot not in value_set[domain]: + value_set[domain][slot] = set() + elif slot in ['推荐菜', '名称', '酒店设施']: + if slot == '名称' and len(value.split()) > 1: + multi_name.add((domain, slot, value)) + elif slot == '推荐菜' and '-' in value: + print((domain, slot, value)) + for dish in value.split(): + value_set[domain][slot].add(dish) + elif slot == 'selectedResults' and domain in ['地铁', '出租']: + if domain == '地铁': + value = value[5:] + value_set[domain][slot].add(value.strip()) + if domain == '出租': + value = value[4:-1] + for v in value.split(' - '): + value_set[domain][slot].add(v.strip()) + else: + value_set[domain][slot].add(value) + return intent_set, domain_set, slot_set, value_set + + +if __name__ == '__main__': + intent_set = set() + domain_set = set() + slot_set = set() + value_set = {} + + for s in ['train', 'val', 'test', 'dstc9_data']: + print(f'Proceeding {s} set...') + data = read_zipped_json(s+'.json.zip', s+'.json') + output = extract_ontology(data) + intent_set |= output[0] + domain_set |= output[1] + slot_set |= output[2] + for domain in output[3]: + if domain in value_set: + for slot in output[3][domain]: + if slot in value_set[domain]: + value_set[domain][slot] |= output[3][domain][slot] + else: + value_set[domain][slot] = output[3][domain][slot] + else: + value_set[domain] = output[3][domain] + + print(len(domain_set)) + print(len(intent_set)) + print(len(slot_set)) + print(len(value_set)) + + intent_set = list(set([s.lower() for s in intent_set])) + domain_set = list(set([s.lower() for s in domain_set])) + slot_set = list(set([s.lower() for s in slot_set])) + for domain in value_set: + for slot in value_set[domain]: + value_set[domain][slot] = list(set([s.lower() for s in value_set[domain][slot]])) + # json.dump({ + # 'intent_set': intent_set, + # 'domain_set': domain_set, + # 'slot_set': slot_set, + # 'value_set': value_set, + # }, open('all_value_train_val_test.json', 'w'), indent=2, ensure_ascii=False) + print(len(domain_set)) + print(len(intent_set)) + print(len(slot_set)) + print(len(value_set)) + print(len(multi_name)) + pprint(multi_name) + + # 统计一下每个domain-slot对应的value数 + # value_count = {x: len(value_set[x]) for x in value_set.keys()} + # pprint(value_count) diff --git a/data/crosswoz/extract_all_value.py b/data/crosswoz/extract_all_value.py index 9d9db1b8d1b1676122e08e6f91566a9407499bb6..6492aa682a2675ac437c4e81c42438e95977ef6d 100755 --- a/data/crosswoz/extract_all_value.py +++ b/data/crosswoz/extract_all_value.py @@ -33,7 +33,7 @@ def extract_ontology(data): assert isinstance(value, str) value_set.add(value) else: - for domain, svd in turn['sys_state'].items(): + for domain, svd in turn['sys_state_init'].items(): domain_set.add(domain) for slot, value in svd.items(): slot_set.add(slot) @@ -51,7 +51,7 @@ if __name__ == '__main__': domain_set = set() slot_set = set() value_set = set() - for s in ['train', 'val', 'test']: + for s in ['train', 'val', 'test', 'dstc9_data']: data = read_zipped_json(s+'.json.zip', s+'.json') output = extract_ontology(data) intent_set |= output[0] diff --git a/data/crosswoz/gen_ontology.py b/data/crosswoz/gen_ontology.py new file mode 100644 index 0000000000000000000000000000000000000000..f0aa7ad002951c8d12d61500b8c3b545ae2a2013 --- /dev/null +++ b/data/crosswoz/gen_ontology.py @@ -0,0 +1,86 @@ +import json +from zipfile import ZipFile + +import re + +ontology = { + "景点": { + "名称": set(), + "门票": set(), + "游玩时间": set(), + "评分": set(), + "周边景点": set(), + "周边餐馆": set(), + "周边酒店": set(), + }, + "餐馆": { + "名称": set(), + "推荐菜": set(), + "人均消费": set(), + "评分": set(), + "周边景点": set(), + "周边餐馆": set(), + "周边酒店": set(), + }, + "酒店": { + "名称": set(), + "酒店类型": set(), + "酒店设施": set(), + "价格": set(), + "评分": set(), + "周边景点": set(), + "周边餐馆": set(), + "周边酒店": set(), + }, + "地铁": { + "出发地": set(), + "目的地": set(), + }, + "出租": { + "出发地": set(), + "目的地": set(), + } +} + +if __name__ == '__main__': + pattern = re.compile('. .+') + for split in ['train', 'val', 'test', 'dstc9_data']: + print(split) + with ZipFile(f'{split}.json.zip', 'r') as zipfile: + with zipfile.open(f'{split}.json', 'r') as f: + data = json.load(f) + + for dialog in data.values(): + for turn in dialog['messages']: + if turn['role'] == 'sys': + state = turn['sys_state_init'] + for domain_name, domain in state.items(): + for slot_name, value in domain.items(): + + if slot_name == 'selectedResults': + continue + else: + value = value.replace('\t', ' ').strip() + if not value: + continue + values = ontology[domain_name][slot_name] + if slot_name in ['酒店设施', '推荐菜']: + # deal with values contain bothering space like "早 餐 服 务 无 烟 房" + if pattern.match(value): + print(value) + value = value.replace(' ', ';').replace(' ', '').replace(';', ' ') + print(value) + for v in value.split(' '): + if v: + values.add(v) + elif value and value not in values: + # if ',' in value or ',' in value or ' ' in value: + # print(value, slot_name) + values.add(value) + + for domain in ontology.values(): + for slot_name, values in domain.items(): + domain[slot_name] = list(values) + + with open('ontology.json', 'w') as f: + json.dump(ontology, f, indent=4, ensure_ascii=False) diff --git a/data/crosswoz/train.json.zip b/data/crosswoz/train.json.zip index 09d303cdf5525be1ee2c20b98046857622e62972..25b9a5957021db8217b6ea6e7b67f468f04a1b16 100755 Binary files a/data/crosswoz/train.json.zip and b/data/crosswoz/train.json.zip differ diff --git a/data/crosswoz/val.json.zip b/data/crosswoz/val.json.zip index 6c38c8d75e41d632d36e4e239e0dd9dbbadfdad8..3f5d5ba5713d16755dd96c3f9e5779f973d58eef 100755 Binary files a/data/crosswoz/val.json.zip and b/data/crosswoz/val.json.zip differ diff --git a/data/crosswoz_en/.gitignore b/data/crosswoz_en/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c249259bc672859adde39d83d57df40beddab6f9 --- /dev/null +++ b/data/crosswoz_en/.gitignore @@ -0,0 +1 @@ +CrossWOZ_translate \ No newline at end of file