Unverified Commit 4b133a8d authored by 罗崚骁, committed by GitHub

Translation train on MultiWOZ (Chinese) and CrossWOZ (English) of SUMBT (#17)

* multiwoz_zh

* crosswoz_en

* translation train

* test translation train

* update evaluation code

* update evaluation code for crosswoz

* evaluate human val set

* update readme

* evaluate machine val

* extract all ontology, bad result

* update evaluate

* update evaluation result on crosswoz-en
parent 206af602
Showing 2815 additions and 51 deletions
...@@ -66,3 +66,6 @@ convlab2/dst/trade/multiwoz_config/
deploy/bert_multiwoz_all.zip
deploy/templates/dialog_eg.html
test.py
*.egg-info
pre-trained-models/
\ No newline at end of file
...@@ -159,6 +159,64 @@ By running `convlab2/nlg/evaluate.py MultiWOZ $model sys`
| Template | 0.3309 |
| SCLSTM | 0.4884 |

## Translation train with SUMBT
### Train
With ConvLab-2, you can train SUMBT on a translated dataset as follows:
```python
# train.py
from sys import argv

if __name__ == "__main__":
    if len(argv) != 2:
        print('usage: python3 train.py [dataset]')
        exit(1)
    assert argv[1] in ['multiwoz', 'crosswoz']

    # pick the SUMBT tracker built for the corresponding translated dataset
    if argv[1] == 'multiwoz':
        from convlab2.dst.sumbt.multiwoz_zh.sumbt import SUMBTTracker as SUMBT
    elif argv[1] == 'crosswoz':
        from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker as SUMBT

    sumbt = SUMBT()
    sumbt.train(True)
```
### Evaluate
Run `evaluate.py` (under `convlab2/dst/`) with the following command:
```bash
python3 evaluate.py [CrossWOZ-en|MultiWOZ-zh] sumbt [test|human|val]
```
The `human` option evaluates on the validation set translated by humans, while `val` uses the machine-translated validation set. For example, `python3 evaluate.py CrossWOZ-en sumbt human` evaluates the CrossWOZ-en model on the human-translated validation set.
The evaluation results of our pre-trained models are:

| type | CrossWOZ-en | MultiWOZ-zh |
| ----- | ----------- | ----------- |
| test | 12.4% | 42.3% |
| human | 10.9% | 48.2% |
| val | 12.2% | 44.8% |

Note: you may want to download the pre-trained BERT models and the translation-train pre-trained DST models provided by us.
Without modifying any code, you can:
- download a pre-trained BERT model from:
  - [CrossWOZ-en](https://huggingface.co/bert-base-uncased)
  - [MultiWOZ-zh](https://huggingface.co/hfl/chinese-bert-wwm-ext)

  and extract it to `./pre-trained-models` (or fetch it with the sketch below);
- for a pre-trained DST model, e.g. SUMBT trained on CrossWOZ (English), extract the archive, save the model under `./convlab2/dst/sumbt/crosswoz_en/pre-trained`, and name it `pytorch_model.bin`.
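
If you prefer to script the BERT download instead of extracting the archives by hand, the minimal sketch below mirrors both checkpoints into the directory names that `sumbt_config.py` expects. It assumes a recent `huggingface_hub` package (not a dependency of this repo); the script name `download_bert.py` is only illustrative.

```python
# download_bert.py -- illustrative helper, not part of this commit.
# Mirrors the two BERT checkpoints into the local directories referenced by
# crosswoz_en/sumbt_config.py and multiwoz_zh/sumbt_config.py.
from huggingface_hub import snapshot_download

targets = {
    'bert-base-uncased': 'pre-trained-models/bert-base-uncased',            # CrossWOZ-en
    'hfl/chinese-bert-wwm-ext': 'pre-trained-models/bert-chinese-wwm-ext',  # MultiWOZ-zh
}

for repo_id, local_dir in targets.items():
    path = snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print(f'{repo_id} -> {path}')
```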
## Issues
You are welcome to create an issue if you want to request a feature, report a bug or ask a general question.
......
...@@ -10,3 +10,7 @@ from os.path import abspath, dirname
def get_root_path():
return dirname(dirname(abspath(__file__)))
import os
DATA_ROOT = os.path.join(get_root_path(), 'data')
\ No newline at end of file
# -*- coding: gbk -*-
""" """
Evaluate NLU models on specified dataset Evaluate NLU models on specified dataset
Usage: python evaluate.py [MultiWOZ|CrossWOZ] [TRADE|mdbt|sumbt|rule] Usage: python evaluate.py [MultiWOZ|CrossWOZ] [TRADE|mdbt|sumbt|rule]
...@@ -12,8 +11,9 @@ import copy
import jieba
multiwoz_slot_list = ['attraction-area', 'attraction-name', 'attraction-type', 'hotel-day', 'hotel-people', 'hotel-stay', 'hotel-area', 'hotel-internet', 'hotel-name', 'hotel-parking', 'hotel-pricerange', 'hotel-stars', 'hotel-type', 'restaurant-day', 'restaurant-people', 'restaurant-time', 'restaurant-area', 'restaurant-food', 'restaurant-name', 'restaurant-pricerange', 'taxi-arriveby', 'taxi-departure', 'taxi-destination', 'taxi-leaveat', 'train-people', 'train-arriveby', 'train-day', 'train-departure', 'train-destination', 'train-leaveat']
crosswoz_slot_list = ["景点-门票", "景点-评分", "餐馆-名称", "酒店-价格", "酒店-评分", "景点-名称", "景点-地址", "景点-游玩时间", "餐馆-营业时间", "餐馆-评分", "酒店-名称", "酒店-周边景点", "酒店-酒店设施-叫醒服务", "酒店-酒店类型", "餐馆-人均消费", "餐馆-推荐菜", "酒店-酒店设施", "酒店-电话", "景点-电话", "餐馆-周边餐馆", "餐馆-电话", "餐馆-none", "餐馆-地址", "酒店-酒店设施-无烟房", "酒店-地址", "景点-周边景点", "景点-周边酒店", "出租-出发地", "出租-目的地", "地铁-出发地", "地铁-目的地", "景点-周边餐馆", "酒店-周边餐馆", "出租-车型", "餐馆-周边景点", "餐馆-周边酒店", "地铁-出发地附近地铁站", "地铁-目的地附近地铁站", "景点-none", "酒店-酒店设施-商务中心", "餐馆-源领域", "酒店-酒店设施-中式餐厅", "酒店-酒店设施-接站服务", "酒店-酒店设施-国际长途电话", "酒店-酒店设施-吹风机", "酒店-酒店设施-会议室", "酒店-源领域", "酒店-none", "酒店-酒店设施-宽带上网", "酒店-酒店设施-看护小孩服务", "酒店-酒店设施-酒店各处提供wifi", "酒店-酒店设施-暖气", "酒店-酒店设施-spa", "出租-车牌", "景点-源领域", "酒店-酒店设施-行李寄存", "酒店-酒店设施-西式餐厅", "酒店-酒店设施-酒吧", "酒店-酒店设施-早餐服务", "酒店-酒店设施-健身房", "酒店-酒店设施-残疾人设施", "酒店-酒店设施-免费市内电话", "酒店-酒店设施-接待外宾", "酒店-酒店设施-部分房间提供wifi", "酒店-酒店设施-洗衣服务", "酒店-酒店设施-租车", "酒店-酒店设施-公共区域和部分房间提供wifi", "酒店-酒店设施-24小时热水", "酒店-酒店设施-温泉", "酒店-酒店设施-桑拿", "酒店-酒店设施-收费停车位", "酒店-周边酒店", "酒店-酒店设施-接机服务", "酒店-酒店设施-所有房间提供wifi", "酒店-酒店设施-棋牌室", "酒店-酒店设施-免费国内长途电话", "酒店-酒店设施-室内游泳池", "酒店-酒店设施-早餐服务免费", "酒店-酒店设施-公共区域提供wifi", "酒店-酒店设施-室外游泳池"]
from convlab2.dst.sumbt.multiwoz_zh.sumbt import multiwoz_zh_slot_list
from convlab2.dst.sumbt.crosswoz_en.sumbt import crosswoz_en_slot_list
def format_history(context):
history = []
...@@ -37,7 +37,7 @@ def reformat_state(state):
domain_data = domain_data['semi']
for slot in domain_data.keys():
val = domain_data[slot]
if val is not None and val != '' and val != 'not mentioned': if val is not None and val not in ['', 'not mentioned', '未提及', '未提到', '没有提到']:
new_state.append(domain + '-' + slot + '-' + val)
# lower
new_state = [item.lower() for item in new_state]
...@@ -47,19 +47,27 @@ def reformat_state_crosswoz(state):
if 'belief_state' in state:
state = state['belief_state']
new_state = []
# print(state)
for domain in state.keys():
domain_data = state[domain]
for slot in domain_data.keys():
if slot == 'selectedResults': continue
val = domain_data[slot]
if val is not None and val != '': if slot == 'Hotel Facilities' and val not in ['', 'none']:
for facility in val.split(','):
new_state.append(domain + '-' + f'Hotel Facilities - {facility}' + 'yes')
else:
if val is not None and val not in ['', 'none']:
# print(domain, slot, val)
new_state.append(domain + '-' + slot + '-' + val)
return new_state
def compute_acc(gold, pred, slot_temp):
# TODO: not mentioned in gold
miss_gold = 0
miss_slot = []
# print(gold, pred)
for g in gold:
if g not in pred:
miss_gold += 1
...@@ -124,34 +132,44 @@ if __name__ == '__main__':
numpy.random.seed(seed)
torch.manual_seed(seed)
if len(sys.argv) != 3: if len(sys.argv) != 4:
print("usage:") print("usage:")
print("\t python evaluate.py dataset model") print("\t python evaluate.py dataset model")
print("\t dataset=MultiWOZ, CrossWOZ") print("\t dataset=MultiWOZ, MultiWOZ-zh, CrossWOZ, CrossWOZ-en")
print("\t model=TRADE, mdbt, sumbt") print("\t model=TRADE, mdbt, sumbt")
print("\t val=[val|test|human]")
sys.exit()
## init phase
dataset_name = sys.argv[1]
model_name = sys.argv[2]
if dataset_name == 'MultiWOZ': data_key = sys.argv[3]
if model_name == 'TRADE':
if dataset_name.startswith('MultiWOZ'):
if dataset_name.endswith('zh'):
if model_name == 'sumbt':
from convlab2.dst.sumbt.multiwoz_zh.sumbt import SUMBTTracker
model = SUMBTTracker()
else:
raise Exception("Available models: sumbt")
else:
if model_name == 'sumbt':
from convlab2.dst.sumbt.multiwoz.sumbt import SUMBTTracker
model = SUMBTTracker()
elif model_name == 'TRADE':
from convlab2.dst.trade.multiwoz.trade import MultiWOZTRADE
model = MultiWOZTRADE()
elif model_name == 'mdbt':
from convlab2.dst.mdbt.multiwoz.dst import MultiWozMDBT
model = MultiWozMDBT()
elif model_name == 'sumbt':
from convlab2.dst.sumbt.multiwoz.sumbt import SUMBTTracker
model = SUMBTTracker()
else:
raise Exception("Available models: TRADE/mdbt/sumbt")
## load data
from convlab2.util.dataloader.module_dataloader import AgentDSTDataloader
from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
dataloader = AgentDSTDataloader(dataset_dataloader=MultiWOZDataloader()) dataloader = AgentDSTDataloader(dataset_dataloader=MultiWOZDataloader(dataset_name.endswith('zh')))
data = dataloader.load_data(data_key='test')['test'] data = dataloader.load_data(data_key=data_key)[data_key]
context, golden_truth = data['context'], data['belief_state']
all_predictions = {}
test_set = []
...@@ -160,7 +178,6 @@ if __name__ == '__main__':
turn_count = 0
is_start = True
for i in tqdm(range(len(context))):
# for i in tqdm(range(200)): # for test
if len(context[i]) == 0:
turn_count = 0
if is_start:
...@@ -181,19 +198,23 @@ if __name__ == '__main__':
'turn_belief': reformat_state(y),
'pred_bs_ptr': reformat_state(pred)
}
# print('golden: ', reformat_state(y))
# print('pred :', reformat_state(pred))
turn_count += 1
# add last session
if len(curr_sess) > 0:
all_predictions[session_count] = copy.deepcopy(curr_sess)
joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = evaluate_metrics(all_predictions, "pred_bs_ptr", multiwoz_slot_list) slot_list = multiwoz_zh_slot_list if dataset_name.endswith('zh') else multiwoz_slot_list
joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = evaluate_metrics(all_predictions, "pred_bs_ptr", slot_list)
evaluation_metrics = {"Joint Acc": joint_acc_score_ptr, "Turn Acc": turn_acc_score_ptr,
"Joint F1": F1_score_ptr}
print(evaluation_metrics)
elif dataset_name.startswith('CrossWOZ'):
elif dataset_name == 'CrossWOZ': en = dataset_name.endswith('en')
if en:
if model_name == 'sumbt':
from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker
model = SUMBTTracker()
else:
if model_name == 'TRADE':
from convlab2.dst.trade.crosswoz.trade import CrossWOZTRADE
model = CrossWOZTRADE()
...@@ -210,8 +231,8 @@ if __name__ == '__main__':
from convlab2.util.dataloader.module_dataloader import CrossWOZAgentDSTDataloader
from convlab2.util.dataloader.dataset_dataloader import CrossWOZDataloader
dataloader = CrossWOZAgentDSTDataloader(dataset_dataloader=CrossWOZDataloader()) dataloader = CrossWOZAgentDSTDataloader(dataset_dataloader=CrossWOZDataloader(en))
data = dataloader.load_data(data_key='test')['test'] data = dataloader.load_data(data_key=data_key)[data_key]
context, golden_truth = data['context'], data['sys_state_init']
all_predictions = {}
test_set = []
...@@ -220,7 +241,6 @@ if __name__ == '__main__':
turn_count = 0
is_start = True
for i in tqdm(range(len(context))):
# for i in tqdm(range(10)): # for test
if len(context[i]) == 0:
turn_count = 0
if is_start:
...@@ -229,12 +249,17 @@ if __name__ == '__main__':
all_predictions[session_count] = copy.deepcopy(curr_sess)
session_count += 1
curr_sess = {}
# skip usr turn
if len(context[i]) % 2 == 0:
continue
# add turn
x = context[i]
y = golden_truth[i]
# process y
if not en:
for domain in y.keys():
domain_data = y[domain]
for slot in domain_data.keys():
...@@ -242,9 +267,9 @@ if __name__ == '__main__':
val = domain_data[slot]
if val is not None and val != '':
val = sentseg(val)
y[domain][slot] = val domain_data[slot] = val
model.init_session()
model.state['history'] = format_history([sentseg(item) for item in context[i]]) model.state['history'] = format_history([item if en else sentseg(item) for item in context[i]])
pred = model.update(x[-1] if len(x) > 0 else '')
curr_sess[turn_count] = {
'turn_belief': reformat_state_crosswoz(y),
...@@ -255,8 +280,9 @@ if __name__ == '__main__':
if len(curr_sess) > 0:
all_predictions[session_count] = copy.deepcopy(curr_sess)
slot_list = crosswoz_en_slot_list if en else crosswoz_slot_list
joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = evaluate_metrics(all_predictions, "pred_bs_ptr",
crosswoz_slot_list) slot_list)
evaluation_metrics = {"Joint Acc": joint_acc_score_ptr, "Turn Acc": turn_acc_score_ptr,
"Joint F1": F1_score_ptr}
print(evaluation_metrics)
*/model_output/
...@@ -272,6 +272,8 @@ class BeliefTracker(nn.Module):
# calculate joint accuracy
pred_slot = torch.cat(pred_slot, 2)
# print('pred slot:', pred_slot[0][0])
# print('labels:', labels[0][0])
accuracy = (pred_slot == labels).view(-1, slot_dim)
acc_slot = torch.sum(accuracy, 0).float() \
/ torch.sum(labels.view(-1, slot_dim) > -1, 0).float()
......
model_output/
pre-trained/
from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker as SUMBT
import json
import zipfile
from convlab2.dst.sumbt.crosswoz_en.sumbt_config import *
null = 'none'
def trans_value(value):
trans = {
'': 'none',
}
value = value.strip()
value = trans.get(value, value)
value = value.replace('', "'")
value = value.replace('', "'")
return value
def convert_to_glue_format(data_dir, sumbt_dir):
if not os.path.isdir(os.path.join(sumbt_dir, args.tmp_data_dir)):
os.mkdir(os.path.join(sumbt_dir, args.tmp_data_dir))
### Read ontology file
with open(os.path.join(data_dir, "ontology.json"), "r") as fp_ont:
data_ont = json.load(fp_ont)
ontology = {}
facilities = []
for domain_slot in data_ont:
domain, slot = domain_slot.split('-', 1)
if domain not in ontology:
ontology[domain] = {}
if slot.startswith('Hotel Facilities'):
facilities.append(slot.split(' - ')[1])
ontology[domain][slot] = set(map(str.lower, data_ont[domain_slot]))
### Read woz logs and write to tsv files
tsv_filename = os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv")
print('tsv file: ', os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv"))
if os.path.exists(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv")):
print('data has been processed!')
return 0
else:
print('processing data')
with open(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv"), "w") as fp_train, \
open(os.path.join(sumbt_dir, args.tmp_data_dir, "dev.tsv"), "w") as fp_dev, \
open(os.path.join(sumbt_dir, args.tmp_data_dir, "test.tsv"), "w") as fp_test:
fp_train.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_dev.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_test.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
fp_train.write(f'{str(domain)}-{str(slot)}\t')
fp_dev.write(f'{str(domain)}-{str(slot)}\t')
fp_test.write(f'{str(domain)}-{str(slot)}\t')
fp_train.write('\n')
fp_dev.write('\n')
fp_test.write('\n')
# fp_data = open(os.path.join(SELF_DATA_DIR, "data.json"), "r")
# data = json.load(fp_data)
file_split = ['train', 'val', 'test']
fp = [fp_train, fp_dev, fp_test]
for split_type, split_fp in zip(file_split, fp):
zipfile_name = "{}.json.zip".format(split_type)
zip_fp = zipfile.ZipFile(os.path.join(data_dir, zipfile_name))
data = json.loads(str(zip_fp.read(zip_fp.namelist()[0]), 'utf-8'))
for file_id in data:
user_utterance = ''
system_response = ''
turn_idx = 0
messages = data[file_id]['messages']
for idx, turn in enumerate(messages):
if idx % 2 == 0: # user turn
user_utterance = turn['content']
else: # system turn
user_utterance = user_utterance.replace('\t', ' ')
user_utterance = user_utterance.replace('\n', ' ')
user_utterance = user_utterance.replace(' ', ' ')
system_response = system_response.replace('\t', ' ')
system_response = system_response.replace('\n', ' ')
system_response = system_response.replace(' ', ' ')
split_fp.write(str(file_id)) # 0: dialogue ID
split_fp.write('\t' + str(turn_idx)) # 1: turn index
split_fp.write('\t' + str(user_utterance)) # 2: user utterance
split_fp.write('\t' + str(system_response)) # 3: system response
# hardcode the value of facilities as 'yes' and 'no'
belief = {f'Hotel-Hotel Facilities - {str(facility)}': null for facility in facilities}
sys_state_init = turn['sys_state_init']
for domain, slots in sys_state_init.items():
for slot, value in slots.items():
# skip selected results
if isinstance(value, list):
continue
if domain not in ontology:
print("domain (%s) is not defined" % domain)
continue
if slot == 'Hotel Facilities':
for facility in value.split(','):
belief[f'{str(domain)}-Hotel Facilities - {str(facility)}'] = 'yes'
else:
if slot not in ontology[domain]:
print("slot (%s) in domain (%s) is not defined" % (slot, domain)) # bus-arriveBy not defined
continue
value = trans_value(value).lower()
if value not in ontology[domain][slot] and value != null:
print("%s: value (%s) in domain (%s) slot (%s) is not defined in ontology" %
(file_id, value, domain, slot))
value = null
belief[f'{str(domain)}-{str(slot)}'] = value
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
key = str(domain) + '-' + str(slot)
if key in belief:
val = belief[key]
split_fp.write('\t' + val)
else:
split_fp.write(f'\t{null}')
split_fp.write('\n')
split_fp.flush()
system_response = turn['content']
turn_idx += 1
print('data has been processed!')
This diff is collapsed.
import os
import convlab2
class DotMap():
def __init__(self):
self.max_label_length = 35
self.num_rnn_layers = 1
self.zero_init_rnn = False
self.attn_head = 4
self.do_eval = True
self.do_train = False
self.train_batch_size = 3
self.dev_batch_size = 1
self.eval_batch_size = 16
self.learning_rate = 5e-5
self.warmup_proportion = 0.1
self.local_rank = -1
self.seed = 42
self.gradient_accumulation_steps = 1
self.fp16 = False
self.loss_scale = 0
self.do_not_use_tensorboard = False
self.fix_utterance_encoder = False
self.do_eval = True
self.num_train_epochs = 300
self.bert_model = os.path.join(convlab2.get_root_path(), "pre-trained-models/bert-base-uncased")
self.do_lower_case = True
self.task_name = 'bert-gru-sumbt'
self.nbt = 'rnn'
self.target_slot = 'all'
self.distance_metric = 'euclidean'
self.patience = 15
self.hidden_dim = 300
self.max_seq_length = 35
self.max_turn_length = 23
self.fp16_loss_scale = 0.0
self.data_dir = 'data/crosswoz_en/'
self.tf_dir = 'tensorboard'
self.tmp_data_dir = 'processed_data/'
self.output_dir = 'model_output/'
args = DotMap()
\ No newline at end of file
import csv
import os
import json
import collections
import logging
import re
import torch
from convlab2.dst.sumbt.crosswoz_en.convert_to_glue_format import null
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r", encoding='utf-8') as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
if len(line) > 0 and line[0][0] == '#': # ignore comments (starting with '#')
continue
lines.append(line)
return lines
class Processor(DataProcessor):
"""Processor for the belief tracking dataset (GLUE version)."""
def __init__(self, config):
super(Processor, self).__init__()
# crosswoz dataset
with open(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))), config.data_dir, "ontology.json"), "r") as fp_ontology:
ontology = json.load(fp_ontology)
for slot in ontology.keys():
ontology[slot].append(null)
assert config.target_slot == 'all'
# if not config.target_slot == 'all':
# slot_idx = {'attraction': '0:1:2', 'bus': '3:4:5:6', 'hospital': '7',
# 'hotel': '8:9:10:11:12:13:14:15:16:17', \
# 'restaurant': '18:19:20:21:22:23:24', 'taxi': '25:26:27:28', 'train': '29:30:31:32:33:34'}
# target_slot = []
# for key, value in slot_idx.items():
# if key != config.target_slot:
# target_slot.append(value)
# config.target_slot = ':'.join(target_slot)
# sorting the ontology according to the alphabetic order of the slots
ontology = collections.OrderedDict(sorted(ontology.items()))
# select slots to train
nslots = len(ontology.keys())
target_slot = list(ontology.keys())
if config.target_slot == 'all':
self.target_slot_idx = [*range(0, nslots)]
else:
self.target_slot_idx = sorted([int(x) for x in config.target_slot.split(':')])
for idx in range(0, nslots):
if not idx in self.target_slot_idx:
del ontology[target_slot[idx]]
self.ontology = ontology
self.target_slot = list(self.ontology.keys())
# for i, slot in enumerate(self.target_slot):
# if slot == "pricerange":
# self.target_slot[i] = "price range"
logger.info('Processor: target_slot')
logger.info(self.target_slot)
def get_train_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train", accumulation)
def get_dev_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev", accumulation)
def get_test_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test", accumulation)
def get_labels(self):
"""See base class."""
return [list(map(str.lower, self.ontology[slot])) for slot in self.target_slot]
def _create_examples(self, lines, set_type, accumulation=False):
"""Creates examples for the training and dev sets."""
prev_dialogue_index = None
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s-%s" % (set_type, line[0], line[1]) # line[0]: dialogue index, line[1]: turn index
if accumulation:
if prev_dialogue_index is None or prev_dialogue_index != line[0]:
text_a = line[2]
text_b = line[3]
prev_dialogue_index = line[0]
else:
# The symbol '#' will be replaced with '[SEP]' after tokenization.
text_a = line[2] + " # " + text_a
text_b = line[3] + " # " + text_b
else:
text_a = line[2] # line[2]: user utterance
text_b = line[3] # line[3]: system response
label = [line[4 + idx] for idx in self.target_slot_idx]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def normalize_text(text):
global replacements
# lower case every word
text = text.lower()
# replace white spaces in front and end
text = re.sub(r'^\s*|\s*$', '', text)
# hotel domain pfb30
text = re.sub(r"b&b", "bed and breakfast", text)
text = re.sub(r"b and b", "bed and breakfast", text)
# replace st.
text = text.replace(';', ',')
text = re.sub('$\/', '', text)
text = text.replace('/', ' and ')
# replace other special characters
text = text.replace('-', ' ')
text = re.sub('[\"\<>@\(\)]', '', text) # remove
# insert white space before and after tokens:
for token in ['?', '.', ',', '!']:
text = insertSpace(token, text)
# insert white space for 's
text = insertSpace('\'s', text)
# replace it's, does't, you'd ... etc
text = re.sub('^\'', '', text)
text = re.sub('\'$', '', text)
text = re.sub('\'\s', ' ', text)
text = re.sub('\s\'', ' ', text)
for fromx, tox in replacements:
text = ' ' + text + ' '
text = text.replace(fromx, tox)[1:-1]
# remove multiple spaces
text = re.sub(' +', ' ', text)
# concatenate numbers
tmp = text
tokens = text.split()
i = 1
while i < len(tokens):
if re.match(u'^\d+$', tokens[i]) and \
re.match(u'\d+$', tokens[i - 1]):
tokens[i - 1] += tokens[i]
del tokens[i]
else:
i += 1
text = ' '.join(tokens)
return text
def insertSpace(token, text):
sidx = 0
while True:
sidx = text.find(token, sidx)
if sidx == -1:
break
if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
re.match('[0-9]', text[sidx + 1]):
sidx += 1
continue
if text[sidx - 1] != ' ':
text = text[:sidx] + ' ' + text[sidx:]
sidx += 1
if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
text = text[:sidx + 1] + ' ' + text[sidx + 1:]
sidx += 1
return text
# convert tokens in labels to the identifier in vocabulary
def get_label_embedding(labels, max_seq_length, tokenizer, device):
features = []
for label in labels:
label_tokens = ["[CLS]"] + tokenizer.tokenize(label) + ["[SEP]"]
# just truncate, some names are unreasonable long
label_token_ids = tokenizer.convert_tokens_to_ids(label_tokens)[:max_seq_length]
label_len = len(label_token_ids)
label_padding = [0] * (max_seq_length - len(label_token_ids))
label_token_ids += label_padding
assert len(label_token_ids) == max_seq_length
features.append((label_token_ids, label_len))
all_label_token_ids = torch.tensor([f[0] for f in features], dtype=torch.long).to(device)
all_label_len = torch.tensor([f[1] for f in features], dtype=torch.long).to(device)
return all_label_token_ids, all_label_len
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x / warmup
return 1.0 - x
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, input_len, label_id):
self.input_ids = input_ids
self.input_len = input_len
self.label_id = label_id
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, max_turn_length):
"""Loads a data file into a list of `InputBatch`s."""
label_map = [{label: i for i, label in enumerate(labels)} for labels in label_list]
slot_dim = len(label_list)
features = []
prev_dialogue_idx = None
all_padding = [0] * max_seq_length
all_padding_len = [0, 0]
max_turn = 0
for (ex_index, example) in enumerate(examples):
if max_turn < int(example.guid.split('-')[2]):
max_turn = int(example.guid.split('-')[2])
max_turn_length = min(max_turn + 1, max_turn_length)
logger.info("max_turn_length = %d" % max_turn)
for (ex_index, example) in enumerate(examples):
tokens_a = [x if x != '#' else '[SEP]' for x in tokenizer.tokenize(example.text_a)]
tokens_b = None
if example.text_b:
tokens_b = [x if x != '#' else '[SEP]' for x in tokenizer.tokenize(example.text_b)]
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[:(max_seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambigiously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
input_len = [len(tokens), 0]
if tokens_b:
tokens += tokens_b + ["[SEP]"]
input_len[1] = len(tokens_b) + 1
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# Zero-pad up to the sequence length.
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
assert len(input_ids) == max_seq_length
FLAG_TEST = False
if example.label is not None:
label_id = []
label_info = 'label: '
for i, label in enumerate(example.label):
if label == 'dontcare':
label = 'do not care'
label_id.append(label_map[i][label])
label_info += '%s (id = %d) ' % (label, label_map[i][label])
if ex_index < 5:
logger.info("*** Example ***")
logger.info("guid: %s" % example.guid)
logger.info("tokens: %s" % " ".join(
[str(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_len: %s" % " ".join([str(x) for x in input_len]))
logger.info("label: " + label_info)
else:
FLAG_TEST = True
label_id = None
curr_dialogue_idx = example.guid.split('-')[1]
curr_turn_idx = int(example.guid.split('-')[2])
if prev_dialogue_idx is not None and prev_dialogue_idx != curr_dialogue_idx:
if prev_turn_idx < max_turn_length:
features += [InputFeatures(input_ids=all_padding,
input_len=all_padding_len,
label_id=[-1] * slot_dim)] \
* (max_turn_length - prev_turn_idx - 1)
# print(len(features), max_turn_length)
assert len(features) % max_turn_length == 0
if prev_dialogue_idx is None or prev_turn_idx < max_turn_length:
features.append(
InputFeatures(input_ids=input_ids,
input_len=input_len,
label_id=label_id))
prev_dialogue_idx = curr_dialogue_idx
prev_turn_idx = curr_turn_idx
if prev_turn_idx < max_turn_length:
features += [InputFeatures(input_ids=all_padding,
input_len=all_padding_len,
label_id=[-1] * slot_dim)] \
* (max_turn_length - prev_turn_idx - 1)
assert len(features) % max_turn_length == 0
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_len = torch.tensor([f.input_len for f in features], dtype=torch.long)
if not FLAG_TEST:
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
# reshape tensors to [#batch, #max_turn_length, #max_seq_length]
all_input_ids = all_input_ids.view(-1, max_turn_length, max_seq_length)
all_input_len = all_input_len.view(-1, max_turn_length, 2)
if not FLAG_TEST:
all_label_ids = all_label_ids.view(-1, max_turn_length, slot_dim)
else:
all_label_ids = None
return all_input_ids, all_input_len, all_label_ids
def eval_all_accs(pred_slot, labels, accuracies):
def _eval_acc(_pred_slot, _labels):
slot_dim = _labels.size(-1)
accuracy = (_pred_slot == _labels).view(-1, slot_dim)
num_turn = torch.sum(_labels[:, :, 0].view(-1) > -1, 0).float()
num_data = torch.sum(_labels > -1).float()
# joint accuracy
joint_acc = sum(torch.sum(accuracy, 1) / slot_dim).float()
# slot accuracy
slot_acc = torch.sum(accuracy).float()
return joint_acc, slot_acc, num_turn, num_data
# 7 domains
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot, labels)
accuracies['joint7'] += joint_acc
accuracies['slot7'] += slot_acc
accuracies['num_turn'] += num_turn
accuracies['num_slot7'] += num_data
# restaurant domain
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot[:,:,18:25], labels[:,:,18:25])
accuracies['joint_rest'] += joint_acc
accuracies['slot_rest'] += slot_acc
accuracies['num_slot_rest'] += num_data
pred_slot5 = torch.cat((pred_slot[:,:,0:3], pred_slot[:,:,8:]), 2)
label_slot5 = torch.cat((labels[:,:,0:3], labels[:,:,8:]), 2)
# 5 domains (excluding bus and hotel domain)
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot5, label_slot5)
accuracies['joint5'] += joint_acc
accuracies['slot5'] += slot_acc
accuracies['num_slot5'] += num_data
return accuracies
import os
import convlab2
class DotMap():
def __init__(self):
self.max_label_length = 32
...@@ -28,7 +30,7 @@ class DotMap():
self.num_train_epochs = 300
self.bert_model = 'bert-base-uncased' self.bert_model = os.path.join(convlab2.get_root_path(), "pre-trained-models/bert-base-uncased")
self.do_lower_case = True
self.task_name = 'bert-gru-sumbt'
self.nbt = 'rnn'
......
pre-trained/
model_output/
from convlab2.dst.sumbt.multiwoz_zh.sumbt import SUMBTTracker as SUMBT
import json
import zipfile
from convlab2.dst.sumbt.multiwoz_zh.sumbt_config import *
def trans_value(value):
trans = {
'': '未提及',
'没有提到': '未提及',
'没有': '未提及',
'未提到': '未提及',
'一个也没有': '未提及',
'': '未提及',
'是的': '',
'不是': '没有',
'不关心': '不在意',
'不在乎': '不在意',
}
return trans.get(value, value)
def convert_to_glue_format(data_dir, sumbt_dir):
if not os.path.isdir(os.path.join(sumbt_dir, args.tmp_data_dir)):
os.mkdir(os.path.join(sumbt_dir, args.tmp_data_dir))
### Read ontology file
with open(os.path.join(data_dir, "ontology.json"), "r") as fp_ont:
data_ont = json.load(fp_ont)
ontology = {}
for domain_slot in data_ont:
domain, slot = domain_slot.split('-')
if domain not in ontology:
ontology[domain] = {}
ontology[domain][slot] = {}
for value in data_ont[domain_slot]:
ontology[domain][slot][value] = 1
### Read woz logs and write to tsv files
if os.path.exists(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv")):
print('data has been processed!')
return 0
print('begin processing data')
fp_train = open(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv"), "w")
fp_dev = open(os.path.join(sumbt_dir, args.tmp_data_dir, "dev.tsv"), "w")
fp_test = open(os.path.join(sumbt_dir, args.tmp_data_dir, "test.tsv"), "w")
fp_train.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_dev.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_test.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
fp_train.write(str(domain) + '-' + str(slot) + '\t')
fp_dev.write(str(domain) + '-' + str(slot) + '\t')
fp_test.write(str(domain) + '-' + str(slot) + '\t')
fp_train.write('\n')
fp_dev.write('\n')
fp_test.write('\n')
# fp_data = open(os.path.join(SELF_DATA_DIR, "data.json"), "r")
# data = json.load(fp_data)
file_split = ['train', 'val', 'test']
fp = [fp_train, fp_dev, fp_test]
for split_type, split_fp in zip(file_split, fp):
zipfile_name = "{}.json.zip".format(split_type)
zip_fp = zipfile.ZipFile(os.path.join(data_dir, zipfile_name))
data = json.loads(str(zip_fp.read(zip_fp.namelist()[0]), 'utf-8'))
for file_id in data:
user_utterance = ''
system_response = ''
turn_idx = 0
for idx, turn in enumerate(data[file_id]['log']):
if idx % 2 == 0: # user turn
user_utterance = data[file_id]['log'][idx]['text']
else: # system turn
user_utterance = user_utterance.replace('\t', ' ')
user_utterance = user_utterance.replace('\n', ' ')
user_utterance = user_utterance.replace(' ', ' ')
system_response = system_response.replace('\t', ' ')
system_response = system_response.replace('\n', ' ')
system_response = system_response.replace(' ', ' ')
split_fp.write(str(file_id)) # 0: dialogue ID
split_fp.write('\t' + str(turn_idx)) # 1: turn index
split_fp.write('\t' + str(user_utterance)) # 2: user utterance
split_fp.write('\t' + str(system_response)) # 3: system response
belief = {}
for domain in data[file_id]['log'][idx]['metadata'].keys():
for slot in data[file_id]['log'][idx]['metadata'][domain]['semi'].keys():
value = data[file_id]['log'][idx]['metadata'][domain]['semi'][slot].strip()
# value = value_trans.get(value, value)
value = trans_value(value)
if domain not in ontology:
print("domain (%s) is not defined" % domain)
continue
if slot not in ontology[domain]:
print("slot (%s) in domain (%s) is not defined" % (slot, domain)) # bus-arriveBy not defined
continue
if value not in ontology[domain][slot] and value != '未提及':
print("%s: value (%s) in domain (%s) slot (%s) is not defined in ontology" %
(file_id, value, domain, slot))
value = '未提及'
belief[str(domain) + '-' + str(slot)] = value
for slot in data[file_id]['log'][idx]['metadata'][domain]['book'].keys():
if slot == 'booked':
continue
if domain == '公共汽车' and slot == '人数' or domain == '列车' and slot == '票价':
continue # not defined in ontology
value = data[file_id]['log'][idx]['metadata'][domain]['book'][slot].strip()
value = trans_value(value)
if str('预订' + slot) not in ontology[domain]:
print("预订%s is not defined in domain %s" % (slot, domain))
continue
if value not in ontology[domain]['预订' + slot] and value != '未提及':
print("%s: value (%s) in domain (%s) slot (预订%s) is not defined in ontology" %
(file_id, value, domain, slot))
value = '未提及'
belief[str(domain) + '-预订' + str(slot)] = value
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
key = str(domain) + '-' + str(slot)
if key in belief:
split_fp.write('\t' + belief[key])
else:
split_fp.write('\t未提及')
split_fp.write('\n')
split_fp.flush()
system_response = data[file_id]['log'][idx]['text']
turn_idx += 1
fp_train.close()
fp_dev.close()
fp_test.close()
print('data has been processed!')
\ No newline at end of file
This diff is collapsed.
import os
import convlab2
class DotMap():
def __init__(self):
self.max_label_length = 32
self.max_turn_length = 22
self.num_rnn_layers = 1
self.zero_init_rnn = False
self.attn_head = 4
self.do_eval = True
self.do_train = False
self.train_batch_size = 3
self.dev_batch_size = 1
self.eval_batch_size = 1
self.learning_rate = 5e-5
self.num_train_epochs = 3
self.patience = 10
self.warmup_proportion = 0.1
self.local_rank = -1
self.seed = 42
self.gradient_accumulation_steps = 1
self.fp16 = False
self.loss_scale = 0
self.do_not_use_tensorboard = False
self.fix_utterance_encoder = False
self.do_eval = True
self.num_train_epochs = 300
self.bert_model = os.path.join(convlab2.get_root_path(), "pre-trained-models/bert-chinese-wwm-ext")
self.do_lower_case = True
self.task_name = 'bert-gru-sumbt'
self.nbt = 'rnn'
# self.output_dir = os.path.join(path, 'ckpt/')
self.target_slot = 'all'
self.learning_rate = 5e-5
self.distance_metric = 'euclidean'
self.patience = 15
self.hidden_dim = 300
self.max_label_length = 32
self.max_seq_length = 64
self.max_turn_length = 22
self.fp16_loss_scale = 0.0
self.data_dir = 'data/multiwoz_zh/'
self.tf_dir = 'tensorboard'
self.tmp_data_dir = 'processed_data/'
self.output_dir = 'model_output/'
args = DotMap()
\ No newline at end of file
import csv
import os
import json
import collections
import logging
import re
import torch
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r", encoding='utf-8') as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
if len(line) > 0 and line[0][0] == '#': # ignore comments (starting with '#')
continue
lines.append(line)
return lines
class Processor(DataProcessor):
"""Processor for the belief tracking dataset (GLUE version)."""
def __init__(self, config):
super(Processor, self).__init__()
print(config)
# MultiWOZ dataset
with open(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))), config.data_dir, "ontology.json"), "r") as fp_ontology:
ontology = json.load(fp_ontology)
for slot in ontology.keys():
ontology[slot].append("未提及")
if config.target_slot != 'all':
raise Exception('unsupported')
# sorting the ontology according to the alphabetic order of the slots
ontology = collections.OrderedDict(sorted(ontology.items()))
# select slots to train
nslots = len(ontology.keys())
target_slot = list(ontology.keys())
self.target_slot_idx = [*range(0, nslots)]
for idx in range(0, nslots):
if not idx in self.target_slot_idx:
del ontology[target_slot[idx]]
self.ontology = ontology
self.target_slot = list(self.ontology.keys())
for i, slot in enumerate(self.target_slot):
if slot == "pricerange":
self.target_slot[i] = "price range"
logger.info('Processor: target_slot')
logger.info(self.target_slot)
def get_train_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train", accumulation)
def get_dev_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev", accumulation)
def get_test_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test", accumulation)
def get_labels(self):
"""See base class."""
return [self.ontology[slot] for slot in self.target_slot]
def _create_examples(self, lines, set_type, accumulation=False):
"""Creates examples for the training and dev sets."""
prev_dialogue_index = None
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s-%s" % (set_type, line[0], line[1]) # line[0]: dialogue index, line[1]: turn index
if accumulation:
if prev_dialogue_index is None or prev_dialogue_index != line[0]:
text_a = line[2]
text_b = line[3]
prev_dialogue_index = line[0]
else:
# The symbol '#' will be replaced with '[SEP]' after tokenization.
text_a = line[2] + " # " + text_a
text_b = line[3] + " # " + text_b
else:
text_a = line[2] # line[2]: user utterance
text_b = line[3] # line[3]: system response
label = [line[4 + idx] for idx in self.target_slot_idx]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def normalize_text(text):
global replacements
# lower case every word
text = text.lower()
# replace white spaces in front and end
text = re.sub(r'^\s*|\s*$', '', text)
# hotel domain pfb30
text = re.sub(r"b&b", "bed and breakfast", text)
text = re.sub(r"b and b", "bed and breakfast", text)
# replace st.
text = text.replace(';', ',')
text = re.sub('$\/', '', text)
text = text.replace('/', ' and ')
# replace other special characters
text = text.replace('-', ' ')
text = re.sub('[\"\<>@\(\)]', '', text) # remove
# insert white space before and after tokens:
for token in ['?', '.', ',', '!']:
text = insertSpace(token, text)
# insert white space for 's
text = insertSpace('\'s', text)
# replace it's, does't, you'd ... etc
text = re.sub('^\'', '', text)
text = re.sub('\'$', '', text)
text = re.sub('\'\s', ' ', text)
text = re.sub('\s\'', ' ', text)
for fromx, tox in replacements:
text = ' ' + text + ' '
text = text.replace(fromx, tox)[1:-1]
# remove multiple spaces
text = re.sub(' +', ' ', text)
# concatenate numbers
tmp = text
tokens = text.split()
i = 1
while i < len(tokens):
if re.match(u'^\d+$', tokens[i]) and \
re.match(u'\d+$', tokens[i - 1]):
tokens[i - 1] += tokens[i]
del tokens[i]
else:
i += 1
text = ' '.join(tokens)
return text
def insertSpace(token, text):
sidx = 0
while True:
sidx = text.find(token, sidx)
if sidx == -1:
break
if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
re.match('[0-9]', text[sidx + 1]):
sidx += 1
continue
if text[sidx - 1] != ' ':
text = text[:sidx] + ' ' + text[sidx:]
sidx += 1
if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
text = text[:sidx + 1] + ' ' + text[sidx + 1:]
sidx += 1
return text
def get_label_embedding(labels, max_seq_length, tokenizer, device):
features = []
for label in labels:
label_tokens = ["[CLS]"] + tokenizer.tokenize(label) + ["[SEP]"]
label_token_ids = tokenizer.convert_tokens_to_ids(label_tokens)
label_len = len(label_token_ids)
label_padding = [0] * (max_seq_length - len(label_token_ids))
label_token_ids += label_padding
assert len(label_token_ids) == max_seq_length
features.append((label_token_ids, label_len))
all_label_token_ids = torch.tensor([f[0] for f in features], dtype=torch.long).to(device)
all_label_len = torch.tensor([f[1] for f in features], dtype=torch.long).to(device)
return all_label_token_ids, all_label_len
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x / warmup
return 1.0 - x
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, input_len, label_id):
self.input_ids = input_ids
self.input_len = input_len
self.label_id = label_id
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, max_turn_length):
"""Loads a data file into a list of `InputBatch`s."""
label_map = [{label: i for i, label in enumerate(labels)} for labels in label_list]
slot_dim = len(label_list)
features = []
prev_dialogue_idx = None
all_padding = [0] * max_seq_length
all_padding_len = [0, 0]
max_turn = 0
for (ex_index, example) in enumerate(examples):
if max_turn < int(example.guid.split('-')[2]):
max_turn = int(example.guid.split('-')[2])
max_turn_length = min(max_turn + 1, max_turn_length)
logger.info("max_turn_length = %d" % max_turn)
for (ex_index, example) in enumerate(examples):
tokens_a = [x if x != '#' else '[SEP]' for x in tokenizer.tokenize(example.text_a)]
tokens_b = None
if example.text_b:
tokens_b = [x if x != '#' else '[SEP]' for x in tokenizer.tokenize(example.text_b)]
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[:(max_seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambigiously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
input_len = [len(tokens), 0]
if tokens_b:
tokens += tokens_b + ["[SEP]"]
input_len[1] = len(tokens_b) + 1
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# Zero-pad up to the sequence length.
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
assert len(input_ids) == max_seq_length
FLAG_TEST = False
if example.label is not None:
label_id = []
label_info = 'label: '
for i, label in enumerate(example.label):
if label == 'dontcare':
label = 'do not care'
label_id.append(label_map[i][label])
label_info += '%s (id = %d) ' % (label, label_map[i][label])
if ex_index < 5:
logger.info("*** Example ***")
logger.info("guid: %s" % example.guid)
logger.info("tokens: %s" % " ".join(
[str(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_len: %s" % " ".join([str(x) for x in input_len]))
logger.info("label: " + label_info)
else:
FLAG_TEST = True
label_id = None
curr_dialogue_idx = example.guid.split('-')[1]
curr_turn_idx = int(example.guid.split('-')[2])
if prev_dialogue_idx is not None and prev_dialogue_idx != curr_dialogue_idx:
if prev_turn_idx < max_turn_length:
features += [InputFeatures(input_ids=all_padding,
input_len=all_padding_len,
label_id=[-1] * slot_dim)] \
* (max_turn_length - prev_turn_idx - 1)
assert len(features) % max_turn_length == 0
if prev_dialogue_idx is None or prev_turn_idx < max_turn_length:
features.append(
InputFeatures(input_ids=input_ids,
input_len=input_len,
label_id=label_id))
prev_dialogue_idx = curr_dialogue_idx
prev_turn_idx = curr_turn_idx
if prev_turn_idx < max_turn_length:
features += [InputFeatures(input_ids=all_padding,
input_len=all_padding_len,
label_id=[-1] * slot_dim)] \
* (max_turn_length - prev_turn_idx - 1)
assert len(features) % max_turn_length == 0
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_len = torch.tensor([f.input_len for f in features], dtype=torch.long)
if not FLAG_TEST:
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
# reshape tensors to [#batch, #max_turn_length, #max_seq_length]
all_input_ids = all_input_ids.view(-1, max_turn_length, max_seq_length)
all_input_len = all_input_len.view(-1, max_turn_length, 2)
if not FLAG_TEST:
all_label_ids = all_label_ids.view(-1, max_turn_length, slot_dim)
else:
all_label_ids = None
return all_input_ids, all_input_len, all_label_ids
def eval_all_accs(pred_slot, labels, accuracies):
def _eval_acc(_pred_slot, _labels):
slot_dim = _labels.size(-1)
accuracy = (_pred_slot == _labels).view(-1, slot_dim)
num_turn = torch.sum(_labels[:, :, 0].view(-1) > -1, 0).float()
num_data = torch.sum(_labels > -1).float()
# joint accuracy
# joint_acc = sum(torch.sum(accuracy, 1) / slot_dim).float()
num_slots = accuracy.shape[1]
joint_acc = sum(torch.sum(accuracy, 1) == num_slots)
# slot accuracy
slot_acc = torch.sum(accuracy).float()
return joint_acc, slot_acc, num_turn, num_data
# 7 domains
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot, labels)
accuracies['joint7'] += joint_acc
accuracies['slot7'] += slot_acc
accuracies['num_turn'] += num_turn
accuracies['num_slot7'] += num_data
# restaurant domain
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot[:,:,18:25], labels[:,:,18:25])
accuracies['joint_rest'] += joint_acc
accuracies['slot_rest'] += slot_acc
accuracies['num_slot_rest'] += num_data
pred_slot5 = torch.cat((pred_slot[:,:,0:3], pred_slot[:,:,8:]), 2)
label_slot5 = torch.cat((labels[:,:,0:3], labels[:,:,8:]), 2)
# 5 domains (excluding bus and hotel domain)
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot5, label_slot5)
accuracies['joint5'] += joint_acc
accuracies['slot5'] += slot_acc
accuracies['num_slot5'] += num_data
return accuracies