Skip to content
Snippets Groups Projects
Commit 7feb0968 authored by function2-llx's avatar function2-llx Committed by zhuqi
Browse files

update sumbt translation train result with evaluation mode set

parent 9d85c597
Branches
No related tags found
No related merge requests found
......@@ -208,9 +208,9 @@ evaluation of our pre-trained models are: (joint acc.)
| type | CrossWOZ-en | MultiWOZ-zh |
| ----- | ----------- | ----------- |
| val | 12.2% | 44.8% |
| test | 12.4% | 42.3% |
| human_val | 10.9% | 48.2% |
| val | 12.4% | 45.1% |
| test | 12.4% | 43.5% |
| human_val | 10.6% | 49.4% |
`human_val` option will make the model evaluate on the validation set translated by human.
......
import os
from convlab2.nlu import NLU
from convlab2.dst import DST
from convlab2.policy import Policy
......@@ -11,6 +13,5 @@ from os.path import abspath, dirname
def get_root_path():
return dirname(dirname(abspath(__file__)))
import os
DATA_ROOT = os.path.join(get_root_path(), 'data')
......@@ -11,17 +11,40 @@ from tqdm import tqdm
import copy
import jieba
multiwoz_slot_list = ['attraction-area', 'attraction-name', 'attraction-type', 'hotel-day', 'hotel-people', 'hotel-stay', 'hotel-area', 'hotel-internet', 'hotel-name', 'hotel-parking', 'hotel-pricerange', 'hotel-stars', 'hotel-type', 'restaurant-day', 'restaurant-people', 'restaurant-time', 'restaurant-area', 'restaurant-food', 'restaurant-name', 'restaurant-pricerange', 'taxi-arriveby', 'taxi-departure', 'taxi-destination', 'taxi-leaveat', 'train-people', 'train-arriveby', 'train-day', 'train-departure', 'train-destination', 'train-leaveat']
crosswoz_slot_list = ["景点-门票", "景点-评分", "餐馆-名称", "酒店-价格", "酒店-评分", "景点-名称", "景点-地址", "景点-游玩时间", "餐馆-营业时间", "餐馆-评分", "酒店-名称", "酒店-周边景点", "酒店-酒店设施-叫醒服务", "酒店-酒店类型", "餐馆-人均消费", "餐馆-推荐菜", "酒店-酒店设施", "酒店-电话", "景点-电话", "餐馆-周边餐馆", "餐馆-电话", "餐馆-none", "餐馆-地址", "酒店-酒店设施-无烟房", "酒店-地址", "景点-周边景点", "景点-周边酒店", "出租-出发地", "出租-目的地", "地铁-出发地", "地铁-目的地", "景点-周边餐馆", "酒店-周边餐馆", "出租-车型", "餐馆-周边景点", "餐馆-周边酒店", "地铁-出发地附近地铁站", "地铁-目的地附近地铁站", "景点-none", "酒店-酒店设施-商务中心", "餐馆-源领域", "酒店-酒店设施-中式餐厅", "酒店-酒店设施-接站服务", "酒店-酒店设施-国际长途电话", "酒店-酒店设施-吹风机", "酒店-酒店设施-会议室", "酒店-源领域", "酒店-none", "酒店-酒店设施-宽带上网", "酒店-酒店设施-看护小孩服务", "酒店-酒店设施-酒店各处提供wifi", "酒店-酒店设施-暖气", "酒店-酒店设施-spa", "出租-车牌", "景点-源领域", "酒店-酒店设施-行李寄存", "酒店-酒店设施-西式餐厅", "酒店-酒店设施-酒吧", "酒店-酒店设施-早餐服务", "酒店-酒店设施-健身房", "酒店-酒店设施-残疾人设施", "酒店-酒店设施-免费市内电话", "酒店-酒店设施-接待外宾", "酒店-酒店设施-部分房间提供wifi", "酒店-酒店设施-洗衣服务", "酒店-酒店设施-租车", "酒店-酒店设施-公共区域和部分房间提供wifi", "酒店-酒店设施-24小时热水", "酒店-酒店设施-温泉", "酒店-酒店设施-桑拿", "酒店-酒店设施-收费停车位", "酒店-周边酒店", "酒店-酒店设施-接机服务", "酒店-酒店设施-所有房间提供wifi", "酒店-酒店设施-棋牌室", "酒店-酒店设施-免费国内长途电话", "酒店-酒店设施-室内游泳池", "酒店-酒店设施-早餐服务免费", "酒店-酒店设施-公共区域提供wifi", "酒店-酒店设施-室外游泳池"]
multiwoz_slot_list = [
'attraction-area', 'attraction-name', 'attraction-type', 'hotel-day', 'hotel-people',
'hotel-stay', 'hotel-area', 'hotel-internet', 'hotel-name', 'hotel-parking', 'hotel-pricerange',
'hotel-stars', 'hotel-type', 'restaurant-day', 'restaurant-people', 'restaurant-time',
'restaurant-area', 'restaurant-food', 'restaurant-name', 'restaurant-pricerange', 'taxi-arriveby',
'taxi-departure', 'taxi-destination', 'taxi-leaveat', 'train-people', 'train-arriveby',
'train-day', 'train-departure', 'train-destination', 'train-leaveat'
]
crosswoz_slot_list = [
"景点-门票", "景点-评分", "餐馆-名称", "酒店-价格", "酒店-评分", "景点-名称", "景点-地址", "景点-游玩时间", "餐馆-营业时间", "餐馆-评分",
"酒店-名称", "酒店-周边景点", "酒店-酒店设施-叫醒服务", "酒店-酒店类型", "餐馆-人均消费", "餐馆-推荐菜", "酒店-酒店设施", "酒店-电话", "景点-电话",
"餐馆-周边餐馆", "餐馆-电话", "餐馆-none", "餐馆-地址", "酒店-酒店设施-无烟房", "酒店-地址", "景点-周边景点", "景点-周边酒店", "出租-出发地",
"出租-目的地", "地铁-出发地", "地铁-目的地", "景点-周边餐馆", "酒店-周边餐馆", "出租-车型", "餐馆-周边景点", "餐馆-周边酒店", "地铁-出发地附近地铁站",
"地铁-目的地附近地铁站", "景点-none", "酒店-酒店设施-商务中心", "餐馆-源领域", "酒店-酒店设施-中式餐厅", "酒店-酒店设施-接站服务",
"酒店-酒店设施-国际长途电话", "酒店-酒店设施-吹风机", "酒店-酒店设施-会议室", "酒店-源领域", "酒店-none", "酒店-酒店设施-宽带上网",
"酒店-酒店设施-看护小孩服务", "酒店-酒店设施-酒店各处提供wifi", "酒店-酒店设施-暖气", "酒店-酒店设施-spa", "出租-车牌", "景点-源领域",
"酒店-酒店设施-行李寄存", "酒店-酒店设施-西式餐厅", "酒店-酒店设施-酒吧", "酒店-酒店设施-早餐服务", "酒店-酒店设施-健身房", "酒店-酒店设施-残疾人设施",
"酒店-酒店设施-免费市内电话", "酒店-酒店设施-接待外宾", "酒店-酒店设施-部分房间提供wifi", "酒店-酒店设施-洗衣服务", "酒店-酒店设施-租车",
"酒店-酒店设施-公共区域和部分房间提供wifi", "酒店-酒店设施-24小时热水", "酒店-酒店设施-温泉", "酒店-酒店设施-桑拿", "酒店-酒店设施-收费停车位",
"酒店-周边酒店", "酒店-酒店设施-接机服务", "酒店-酒店设施-所有房间提供wifi", "酒店-酒店设施-棋牌室", "酒店-酒店设施-免费国内长途电话",
"酒店-酒店设施-室内游泳池", "酒店-酒店设施-早餐服务免费", "酒店-酒店设施-公共区域提供wifi", "酒店-酒店设施-室外游泳池"
]
from convlab2.dst.sumbt.multiwoz_zh.sumbt import multiwoz_zh_slot_list
from convlab2.dst.sumbt.crosswoz_en.sumbt import crosswoz_en_slot_list
def format_history(context):
history = []
for i in range(len(context)):
history.append(['sys' if i % 2 == 1 else 'usr', context[i]])
return history
def sentseg(sent):
sent = sent.replace('\t', ' ')
sent = ' '.join(sent.split())
......@@ -215,6 +238,8 @@ if __name__ == '__main__':
if model_name == 'sumbt':
from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker
model = SUMBTTracker()
else:
raise Exception("Available models: sumbt")
else:
if model_name == 'TRADE':
from convlab2.dst.trade.crosswoz.trade import CrossWOZTRADE
......@@ -228,7 +253,7 @@ if __name__ == '__main__':
else:
raise Exception("Available models: TRADE")
## load data
# load data
from convlab2.util.dataloader.module_dataloader import CrossWOZAgentDSTDataloader
from convlab2.util.dataloader.dataset_dataloader import CrossWOZDataloader
......
......@@ -150,6 +150,9 @@ class BeliefTracker(nn.Module):
### Etc.
self.dropout = nn.Dropout(self.hidden_dropout_prob)
# default evaluation mode
self.eval()
def initialize_slot_value_lookup(self, label_ids, slot_ids):
self.sv_encoder.eval()
......
"""
extract all value appear in dialog act and state, for translation.
"""
import json
import zipfile
import re
from pprint import pprint
def read_zipped_json(filepath, filename):
archive = zipfile.ZipFile(filepath, 'r')
return json.load(archive.open(filename))
zh_pattern = re.compile(u'[\u4e00-\u9fa5]+')
multi_name = set()
def extract_ontology(data):
intent_set = set()
domain_set = set()
slot_set = set()
value_set = {}
for _, sess in data.items():
for i, turn in enumerate(sess['messages']):
for intent, domain, slot, value in turn['dialog_act']:
intent_set.add(intent)
domain_set.add(domain)
slot_set.add(slot)
if not domain in value_set:
value_set[domain] = {}
elif slot not in value_set[domain]:
value_set[domain][slot] = set()
elif slot in ['推荐菜', '名称', '酒店设施']:
if slot == '名称' and len(value.split()) > 1:
multi_name.add((domain, slot, value))
elif slot == '推荐菜' and '-' in value:
print((domain, slot, value))
for dish in value.split():
value_set[domain][slot].add(dish)
elif slot == 'selectedResults' and domain in ['地铁', '出租']:
if domain == '地铁':
value = value[5:]
value_set[domain][slot].add(value.strip())
if domain == '出租':
value = value[4:-1]
for v in value.split(' - '):
value_set[domain][slot].add(v.strip())
else:
value_set[domain][slot].add(value)
if turn['role'] == 'usr':
for _, domain, slot, value, __ in turn['user_state']:
domain_set.add(domain)
slot_set.add(slot)
if isinstance(value, list):
for v in value:
if not domain in value_set:
value_set[domain] = {}
elif slot not in value_set[domain]:
value_set[domain][slot] = set()
elif slot in ['推荐菜', '名称', '酒店设施']:
if slot == '名称' and len(value.split()) > 1:
multi_name.add((domain, slot, value))
elif slot == '推荐菜' and '-' in value:
print((domain, slot, value))
for dish in v.split():
value_set[domain][slot].add(dish)
elif slot == 'selectedResults' and domain in ['地铁', '出租']:
if domain == '地铁':
v = v[5:]
value_set[domain][slot].add(v.strip())
if domain == '出租':
v = v[4:-1]
for item in v.split(' - '):
value_set[domain][slot].add(item.strip())
else:
value_set[domain][slot].add(v)
else:
assert isinstance(value, str)
if not domain in value_set:
value_set[domain] = {}
elif slot not in value_set[domain]:
value_set[domain][slot] = set()
elif slot in ['推荐菜', '名称', '酒店设施']:
if slot == '名称' and len(value.split()) > 1:
multi_name.add((domain, slot, value))
elif slot == '推荐菜' and '-' in value:
print((domain, slot, value))
for dish in value.split():
value_set[domain][slot].add(dish)
elif slot == 'selectedResults' and domain in ['地铁', '出租']:
if domain == '地铁':
value = value[5:]
value_set[domain][slot].add(value.strip())
if domain == '出租':
value = value[4:-1]
for v in value.split(' - '):
value_set[domain][slot].add(v.strip())
else:
value_set[domain][slot].add(value)
else:
for state_key in ['sys_state', 'sys_state_init']:
for domain, svd in turn[state_key].items():
domain_set.add(domain)
for slot, value in svd.items():
slot_set.add(slot)
if isinstance(value, list):
for v in value:
if not domain in value_set:
value_set[domain] = {}
elif slot not in value_set[domain]:
value_set[domain][slot] = set()
elif slot in ['推荐菜', '名称', '酒店设施']:
if slot == '名称' and len(value.split()) > 1:
multi_name.add((domain, slot, value))
elif slot == '推荐菜' and '-' in value:
print((domain, slot, value))
for dish in v.split():
value_set[domain][slot].add(dish)
elif slot == 'selectedResults' and domain in ['地铁', '出租']:
if domain == '地铁':
v = v[5:]
value_set[domain][slot].add(v.strip())
if domain == '出租':
v = v[4:-1]
for item in v.split(' - '):
value_set[domain][slot].add(item.strip())
else:
value_set[domain][slot].add(v)
else:
assert isinstance(value, str)
if not domain in value_set:
value_set[domain] = {}
elif slot not in value_set[domain]:
value_set[domain][slot] = set()
elif slot in ['推荐菜', '名称', '酒店设施']:
if slot == '名称' and len(value.split()) > 1:
multi_name.add((domain, slot, value))
elif slot == '推荐菜' and '-' in value:
print((domain, slot, value))
for dish in value.split():
value_set[domain][slot].add(dish)
elif slot == 'selectedResults' and domain in ['地铁', '出租']:
if domain == '地铁':
value = value[5:]
value_set[domain][slot].add(value.strip())
if domain == '出租':
value = value[4:-1]
for v in value.split(' - '):
value_set[domain][slot].add(v.strip())
else:
value_set[domain][slot].add(value)
return intent_set, domain_set, slot_set, value_set
if __name__ == '__main__':
intent_set = set()
domain_set = set()
slot_set = set()
value_set = {}
for s in ['train', 'val', 'test', 'dstc9_data']:
print(f'Proceeding {s} set...')
data = read_zipped_json(s+'.json.zip', s+'.json')
output = extract_ontology(data)
intent_set |= output[0]
domain_set |= output[1]
slot_set |= output[2]
for domain in output[3]:
if domain in value_set:
for slot in output[3][domain]:
if slot in value_set[domain]:
value_set[domain][slot] |= output[3][domain][slot]
else:
value_set[domain][slot] = output[3][domain][slot]
else:
value_set[domain] = output[3][domain]
print(len(domain_set))
print(len(intent_set))
print(len(slot_set))
print(len(value_set))
intent_set = list(set([s.lower() for s in intent_set]))
domain_set = list(set([s.lower() for s in domain_set]))
slot_set = list(set([s.lower() for s in slot_set]))
for domain in value_set:
for slot in value_set[domain]:
value_set[domain][slot] = list(set([s.lower() for s in value_set[domain][slot]]))
# json.dump({
# 'intent_set': intent_set,
# 'domain_set': domain_set,
# 'slot_set': slot_set,
# 'value_set': value_set,
# }, open('all_value_train_val_test.json', 'w'), indent=2, ensure_ascii=False)
print(len(domain_set))
print(len(intent_set))
print(len(slot_set))
print(len(value_set))
print(len(multi_name))
pprint(multi_name)
# 统计一下每个domain-slot对应的value数
# value_count = {x: len(value_set[x]) for x in value_set.keys()}
# pprint(value_count)
......@@ -33,7 +33,7 @@ def extract_ontology(data):
assert isinstance(value, str)
value_set.add(value)
else:
for domain, svd in turn['sys_state'].items():
for domain, svd in turn['sys_state_init'].items():
domain_set.add(domain)
for slot, value in svd.items():
slot_set.add(slot)
......@@ -51,7 +51,7 @@ if __name__ == '__main__':
domain_set = set()
slot_set = set()
value_set = set()
for s in ['train', 'val', 'test']:
for s in ['train', 'val', 'test', 'dstc9_data']:
data = read_zipped_json(s+'.json.zip', s+'.json')
output = extract_ontology(data)
intent_set |= output[0]
......
import json
from zipfile import ZipFile
import re
ontology = {
"景点": {
"名称": set(),
"门票": set(),
"游玩时间": set(),
"评分": set(),
"周边景点": set(),
"周边餐馆": set(),
"周边酒店": set(),
},
"餐馆": {
"名称": set(),
"推荐菜": set(),
"人均消费": set(),
"评分": set(),
"周边景点": set(),
"周边餐馆": set(),
"周边酒店": set(),
},
"酒店": {
"名称": set(),
"酒店类型": set(),
"酒店设施": set(),
"价格": set(),
"评分": set(),
"周边景点": set(),
"周边餐馆": set(),
"周边酒店": set(),
},
"地铁": {
"出发地": set(),
"目的地": set(),
},
"出租": {
"出发地": set(),
"目的地": set(),
}
}
if __name__ == '__main__':
pattern = re.compile('. .+')
for split in ['train', 'val', 'test', 'dstc9_data']:
print(split)
with ZipFile(f'{split}.json.zip', 'r') as zipfile:
with zipfile.open(f'{split}.json', 'r') as f:
data = json.load(f)
for dialog in data.values():
for turn in dialog['messages']:
if turn['role'] == 'sys':
state = turn['sys_state_init']
for domain_name, domain in state.items():
for slot_name, value in domain.items():
if slot_name == 'selectedResults':
continue
else:
value = value.replace('\t', ' ').strip()
if not value:
continue
values = ontology[domain_name][slot_name]
if slot_name in ['酒店设施', '推荐菜']:
# deal with values contain bothering space like "早 餐 服 务 无 烟 房"
if pattern.match(value):
print(value)
value = value.replace(' ', ';').replace(' ', '').replace(';', ' ')
print(value)
for v in value.split(' '):
if v:
values.add(v)
elif value and value not in values:
# if ',' in value or ',' in value or ' ' in value:
# print(value, slot_name)
values.add(value)
for domain in ontology.values():
for slot_name, values in domain.items():
domain[slot_name] = list(values)
with open('ontology.json', 'w') as f:
json.dump(ontology, f, indent=4, ensure_ascii=False)
No preview for this file type
No preview for this file type
CrossWOZ_translate
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment