Unverified commit 4b133a8d authored by 罗崚骁, committed by GitHub

Translation train on MultiWOZ (Chinese) and CrossWOZ (English) of SUMBT (#17)

* multiwoz_zh

* crosswoz_en

* translation train

* test translation train

* update evaluation code

* update evaluation code for crosswoz

* evaluate human val set

* update readme

* evaluate machine val

* extract all ontology, bad result

* update evaluate

* update evaluation result on crosswoz-en
parent 206af602
Showing 2815 additions and 51 deletions
...@@ -66,3 +66,6 @@ convlab2/dst/trade/multiwoz_config/
deploy/bert_multiwoz_all.zip
deploy/templates/dialog_eg.html
test.py
*.egg-info
pre-trained-models/
\ No newline at end of file
...@@ -159,6 +159,64 @@ By running `convlab2/nlg/evaluate.py MultiWOZ $model sys`
| Template | 0.3309 |
| SCLSTM | 0.4884 |
## Translation train with SUMBT
SUMBT can be trained on the machine-translated datasets (MultiWOZ translated to Chinese, CrossWOZ translated to English) and then evaluated on the corresponding test sets.
### Train
With ConvLab-2, you can train SUMBT on a translated dataset like this:
```python
# train.py
from sys import argv

if __name__ == "__main__":
    if len(argv) != 2:
        print('usage: python3 train.py [dataset]')
        exit(1)
    assert argv[1] in ['multiwoz', 'crosswoz']
    if argv[1] == 'multiwoz':
        from convlab2.dst.sumbt.multiwoz_zh.sumbt import SUMBTTracker as SUMBT
    elif argv[1] == 'crosswoz':
        from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker as SUMBT
    sumbt = SUMBT()
    sumbt.train(True)
```
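For example, saving the script above as `train.py` in the project root and running `python3 train.py crosswoz` trains SUMBT on the English translation of CrossWOZ; the best checkpoint is written to the `model_output/` directory of the corresponding SUMBT module.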
### Evaluate
Execute `evaluate.py` (under `convlab2/dst/`) with the following command:
```bash
python3 evaluate.py [MultiWOZ-zh|CrossWOZ-en] sumbt [val|test|human]
```
The `human` option evaluates the model on the human-translated validation set, while `val` uses the machine-translated validation set.
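For example, `python3 evaluate.py MultiWOZ-zh sumbt human` evaluates the Chinese SUMBT model on the human-translated validation set.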
The evaluation results of our pre-trained models are:
| type | CrossWOZ-en | MultiWOZ-zh |
| ----- | ----------- | ----------- |
| test | 12.4% | 42.3% |
| human | 10.9% | 48.2% |
| val | 12.2% | 44.8% |
Note: you may want to download the pre-trained BERT models and the translation-train DST models we provide.
Without modifying any code, you can:
- download the pre-trained BERT model from:
  - [CrossWOZ-en](https://huggingface.co/bert-base-uncased)
  - [MultiWOZ-zh](https://huggingface.co/hfl/chinese-bert-wwm-ext)

  and extract it to `./pre-trained-models`.
- for a pre-trained DST model (e.g. SUMBT trained on CrossWOZ in English), extract the archive and save the checkpoint as `./convlab2/dst/sumbt/crosswoz_en/pre-trained/pytorch_model.bin`.
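If you are unsure whether everything landed in the right place, the sketch below (a hypothetical helper script, not shipped with ConvLab-2) prints which of the expected files are present. The `chinese-bert-wwm-ext` folder name and the `multiwoz_zh` checkpoint path are assumptions made by analogy with the CrossWOZ-en layout described above.
```python
# check_pretrained.py -- minimal sanity-check sketch (hypothetical, not part of the repo)
import os

import convlab2

root = convlab2.get_root_path()
expected = [
    # BERT encoders extracted under ./pre-trained-models (folder names assumed to
    # match the Hugging Face model ids linked above)
    os.path.join(root, 'pre-trained-models', 'bert-base-uncased'),      # CrossWOZ-en
    os.path.join(root, 'pre-trained-models', 'chinese-bert-wwm-ext'),   # MultiWOZ-zh (assumed)
    # translation-train DST checkpoints
    os.path.join(root, 'convlab2', 'dst', 'sumbt', 'crosswoz_en', 'pre-trained', 'pytorch_model.bin'),
    os.path.join(root, 'convlab2', 'dst', 'sumbt', 'multiwoz_zh', 'pre-trained', 'pytorch_model.bin'),  # assumed by analogy
]
for path in expected:
    print('OK     ' if os.path.exists(path) else 'MISSING', path)
```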
## Issues
You are welcome to create an issue if you want to request a feature, report a bug or ask a general question.
......
...@@ -10,3 +10,7 @@ from os.path import abspath, dirname
def get_root_path():
return dirname(dirname(abspath(__file__)))
import os
DATA_ROOT = os.path.join(get_root_path(), 'data')
\ No newline at end of file
# -*- coding: gbk -*-
""" """
Evaluate NLU models on specified dataset Evaluate NLU models on specified dataset
Usage: python evaluate.py [MultiWOZ|CrossWOZ] [TRADE|mdbt|sumbt|rule] Usage: python evaluate.py [MultiWOZ|CrossWOZ] [TRADE|mdbt|sumbt|rule]
...@@ -12,8 +11,9 @@ import copy ...@@ -12,8 +11,9 @@ import copy
import jieba
multiwoz_slot_list = ['attraction-area', 'attraction-name', 'attraction-type', 'hotel-day', 'hotel-people', 'hotel-stay', 'hotel-area', 'hotel-internet', 'hotel-name', 'hotel-parking', 'hotel-pricerange', 'hotel-stars', 'hotel-type', 'restaurant-day', 'restaurant-people', 'restaurant-time', 'restaurant-area', 'restaurant-food', 'restaurant-name', 'restaurant-pricerange', 'taxi-arriveby', 'taxi-departure', 'taxi-destination', 'taxi-leaveat', 'train-people', 'train-arriveby', 'train-day', 'train-departure', 'train-destination', 'train-leaveat']
crosswoz_slot_list = ["景点-门票", "景点-评分", "餐馆-名称", "酒店-价格", "酒店-评分", "景点-名称", "景点-地址", "景点-游玩时间", "餐馆-营业时间", "餐馆-评分", "酒店-名称", "酒店-周边景点", "酒店-酒店设施-叫醒服务", "酒店-酒店类型", "餐馆-人均消费", "餐馆-推荐菜", "酒店-酒店设施", "酒店-电话", "景点-电话", "餐馆-周边餐馆", "餐馆-电话", "餐馆-none", "餐馆-地址", "酒店-酒店设施-无烟房", "酒店-地址", "景点-周边景点", "景点-周边酒店", "出租-出发地", "出租-目的地", "地铁-出发地", "地铁-目的地", "景点-周边餐馆", "酒店-周边餐馆", "出租-车型", "餐馆-周边景点", "餐馆-周边酒店", "地铁-出发地附近地铁站", "地铁-目的地附近地铁站", "景点-none", "酒店-酒店设施-商务中心", "餐馆-源领域", "酒店-酒店设施-中式餐厅", "酒店-酒店设施-接站服务", "酒店-酒店设施-国际长途电话", "酒店-酒店设施-吹风机", "酒店-酒店设施-会议室", "酒店-源领域", "酒店-none", "酒店-酒店设施-宽带上网", "酒店-酒店设施-看护小孩服务", "酒店-酒店设施-酒店各处提供wifi", "酒店-酒店设施-暖气", "酒店-酒店设施-spa", "出租-车牌", "景点-源领域", "酒店-酒店设施-行李寄存", "酒店-酒店设施-西式餐厅", "酒店-酒店设施-酒吧", "酒店-酒店设施-早餐服务", "酒店-酒店设施-健身房", "酒店-酒店设施-残疾人设施", "酒店-酒店设施-免费市内电话", "酒店-酒店设施-接待外宾", "酒店-酒店设施-部分房间提供wifi", "酒店-酒店设施-洗衣服务", "酒店-酒店设施-租车", "酒店-酒店设施-公共区域和部分房间提供wifi", "酒店-酒店设施-24小时热水", "酒店-酒店设施-温泉", "酒店-酒店设施-桑拿", "酒店-酒店设施-收费停车位", "酒店-周边酒店", "酒店-酒店设施-接机服务", "酒店-酒店设施-所有房间提供wifi", "酒店-酒店设施-棋牌室", "酒店-酒店设施-免费国内长途电话", "酒店-酒店设施-室内游泳池", "酒店-酒店设施-早餐服务免费", "酒店-酒店设施-公共区域提供wifi", "酒店-酒店设施-室外游泳池"]
from convlab2.dst.sumbt.multiwoz_zh.sumbt import multiwoz_zh_slot_list
from convlab2.dst.sumbt.crosswoz_en.sumbt import crosswoz_en_slot_list
def format_history(context):
history = []
...@@ -37,7 +37,7 @@ def reformat_state(state):
domain_data = domain_data['semi']
for slot in domain_data.keys():
val = domain_data[slot]
if val is not None and val != '' and val != 'not mentioned': if val is not None and val not in ['', 'not mentioned', '未提及', '未提到', '没有提到']:
new_state.append(domain + '-' + slot + '-' + val)
# lower
new_state = [item.lower() for item in new_state]
...@@ -47,19 +47,27 @@ def reformat_state_crosswoz(state):
if 'belief_state' in state:
state = state['belief_state']
new_state = []
# print(state)
for domain in state.keys():
domain_data = state[domain]
for slot in domain_data.keys():
if slot == 'selectedResults': continue
val = domain_data[slot]
if val is not None and val != '': if slot == 'Hotel Facilities' and val not in ['', 'none']:
for facility in val.split(','):
new_state.append(domain + '-' + f'Hotel Facilities - {facility}' + 'yes')
else:
if val is not None and val not in ['', 'none']:
# print(domain, slot, val)
new_state.append(domain + '-' + slot + '-' + val)
return new_state
def compute_acc(gold, pred, slot_temp):
# TODO: not mentioned in gold
miss_gold = 0
miss_slot = []
# print(gold, pred)
for g in gold:
if g not in pred:
miss_gold += 1
...@@ -124,34 +132,44 @@ if __name__ == '__main__':
numpy.random.seed(seed)
torch.manual_seed(seed)
if len(sys.argv) != 3: if len(sys.argv) != 4:
print("usage:")
print("\t python evaluate.py dataset model")
print("\t dataset=MultiWOZ, CrossWOZ") print("\t dataset=MultiWOZ, MultiWOZ-zh, CrossWOZ, CrossWOZ-en")
print("\t model=TRADE, mdbt, sumbt")
print("\t val=[val|test|human]")
sys.exit()
## init phase
dataset_name = sys.argv[1]
model_name = sys.argv[2]
if dataset_name == 'MultiWOZ': data_key = sys.argv[3]
if model_name == 'TRADE':
if dataset_name.startswith('MultiWOZ'):
if dataset_name.endswith('zh'):
if model_name == 'sumbt':
from convlab2.dst.sumbt.multiwoz_zh.sumbt import SUMBTTracker
model = SUMBTTracker()
else:
raise Exception("Available models: sumbt")
else:
if model_name == 'sumbt':
from convlab2.dst.sumbt.multiwoz.sumbt import SUMBTTracker
model = SUMBTTracker()
elif model_name == 'TRADE':
from convlab2.dst.trade.multiwoz.trade import MultiWOZTRADE
model = MultiWOZTRADE()
elif model_name == 'mdbt':
from convlab2.dst.mdbt.multiwoz.dst import MultiWozMDBT
model = MultiWozMDBT()
elif model_name == 'sumbt':
from convlab2.dst.sumbt.multiwoz.sumbt import SUMBTTracker
model = SUMBTTracker()
else:
raise Exception("Available models: TRADE/mdbt/sumbt")
## load data
from convlab2.util.dataloader.module_dataloader import AgentDSTDataloader
from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
dataloader = AgentDSTDataloader(dataset_dataloader=MultiWOZDataloader()) dataloader = AgentDSTDataloader(dataset_dataloader=MultiWOZDataloader(dataset_name.endswith('zh')))
data = dataloader.load_data(data_key='test')['test'] data = dataloader.load_data(data_key=data_key)[data_key]
context, golden_truth = data['context'], data['belief_state']
all_predictions = {}
test_set = []
...@@ -160,7 +178,6 @@ if __name__ == '__main__':
turn_count = 0
is_start = True
for i in tqdm(range(len(context))):
# for i in tqdm(range(200)): # for test
if len(context[i]) == 0:
turn_count = 0
if is_start:
...@@ -181,19 +198,23 @@ if __name__ == '__main__':
'turn_belief': reformat_state(y),
'pred_bs_ptr': reformat_state(pred)
}
# print('golden: ', reformat_state(y))
# print('pred :', reformat_state(pred))
turn_count += 1
# add last session
if len(curr_sess) > 0:
all_predictions[session_count] = copy.deepcopy(curr_sess)
joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = evaluate_metrics(all_predictions, "pred_bs_ptr", multiwoz_slot_list) slot_list = multiwoz_zh_slot_list if dataset_name.endswith('zh') else multiwoz_slot_list
joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = evaluate_metrics(all_predictions, "pred_bs_ptr", slot_list)
evaluation_metrics = {"Joint Acc": joint_acc_score_ptr, "Turn Acc": turn_acc_score_ptr,
"Joint F1": F1_score_ptr}
print(evaluation_metrics)
elif dataset_name.startswith('CrossWOZ'):
elif dataset_name == 'CrossWOZ': en = dataset_name.endswith('en')
if en:
if model_name == 'sumbt':
from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker
model = SUMBTTracker()
else:
if model_name == 'TRADE':
from convlab2.dst.trade.crosswoz.trade import CrossWOZTRADE
model = CrossWOZTRADE()
...@@ -210,8 +231,8 @@ if __name__ == '__main__':
from convlab2.util.dataloader.module_dataloader import CrossWOZAgentDSTDataloader
from convlab2.util.dataloader.dataset_dataloader import CrossWOZDataloader
dataloader = CrossWOZAgentDSTDataloader(dataset_dataloader=CrossWOZDataloader()) dataloader = CrossWOZAgentDSTDataloader(dataset_dataloader=CrossWOZDataloader(en))
data = dataloader.load_data(data_key='test')['test'] data = dataloader.load_data(data_key=data_key)[data_key]
context, golden_truth = data['context'], data['sys_state_init']
all_predictions = {}
test_set = []
...@@ -220,7 +241,6 @@ if __name__ == '__main__':
turn_count = 0
is_start = True
for i in tqdm(range(len(context))):
# for i in tqdm(range(10)): # for test
if len(context[i]) == 0:
turn_count = 0
if is_start:
...@@ -229,12 +249,17 @@ if __name__ == '__main__':
all_predictions[session_count] = copy.deepcopy(curr_sess)
session_count += 1
curr_sess = {}
# skip usr turn
if len(context[i]) % 2 == 0:
continue
# add turn
x = context[i]
y = golden_truth[i]
# process y
if not en:
for domain in y.keys():
domain_data = y[domain]
for slot in domain_data.keys():
...@@ -242,9 +267,9 @@ if __name__ == '__main__':
val = domain_data[slot]
if val is not None and val != '':
val = sentseg(val)
y[domain][slot] = val domain_data[slot] = val
model.init_session()
model.state['history'] = format_history([sentseg(item) for item in context[i]]) model.state['history'] = format_history([item if en else sentseg(item) for item in context[i]])
pred = model.update(x[-1] if len(x) > 0 else '')
curr_sess[turn_count] = {
'turn_belief': reformat_state_crosswoz(y),
...@@ -255,8 +280,9 @@ if __name__ == '__main__':
if len(curr_sess) > 0:
all_predictions[session_count] = copy.deepcopy(curr_sess)
slot_list = crosswoz_en_slot_list if en else crosswoz_slot_list
joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = evaluate_metrics(all_predictions, "pred_bs_ptr",
crosswoz_slot_list) slot_list)
evaluation_metrics = {"Joint Acc": joint_acc_score_ptr, "Turn Acc": turn_acc_score_ptr,
"Joint F1": F1_score_ptr}
print(evaluation_metrics)
*/model_output/
...@@ -272,6 +272,8 @@ class BeliefTracker(nn.Module):
# calculate joint accuracy
pred_slot = torch.cat(pred_slot, 2)
# print('pred slot:', pred_slot[0][0])
# print('labels:', labels[0][0])
accuracy = (pred_slot == labels).view(-1, slot_dim)
acc_slot = torch.sum(accuracy, 0).float() \
/ torch.sum(labels.view(-1, slot_dim) > -1, 0).float()
......
model_output/
pre-trained/
from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker as SUMBT
import json
import zipfile
from convlab2.dst.sumbt.crosswoz_en.sumbt_config import *
null = 'none'
def trans_value(value):
trans = {
'': 'none',
}
value = value.strip()
value = trans.get(value, value)
# normalize curly apostrophes to plain quotes (assumption: the replaced characters
# were curly quotes that did not survive rendering)
value = value.replace('\u2018', "'")
value = value.replace('\u2019', "'")
return value
def convert_to_glue_format(data_dir, sumbt_dir):
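"""Convert the CrossWOZ-en dialogues into GLUE-style train/dev/test.tsv files (one column per domain-slot) under the temporary data directory read by the SUMBT Processor."""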
if not os.path.isdir(os.path.join(sumbt_dir, args.tmp_data_dir)):
os.mkdir(os.path.join(sumbt_dir, args.tmp_data_dir))
### Read ontology file
with open(os.path.join(data_dir, "ontology.json"), "r") as fp_ont:
data_ont = json.load(fp_ont)
ontology = {}
facilities = []
for domain_slot in data_ont:
domain, slot = domain_slot.split('-', 1)
if domain not in ontology:
ontology[domain] = {}
if slot.startswith('Hotel Facilities'):
facilities.append(slot.split(' - ')[1])
ontology[domain][slot] = set(map(str.lower, data_ont[domain_slot]))
### Read woz logs and write to tsv files
tsv_filename = os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv")
print('tsv file: ', os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv"))
if os.path.exists(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv")):
print('data has been processed!')
return 0
else:
print('processing data')
with open(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv"), "w") as fp_train, \
open(os.path.join(sumbt_dir, args.tmp_data_dir, "dev.tsv"), "w") as fp_dev, \
open(os.path.join(sumbt_dir, args.tmp_data_dir, "test.tsv"), "w") as fp_test:
fp_train.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_dev.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_test.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
fp_train.write(f'{str(domain)}-{str(slot)}\t')
fp_dev.write(f'{str(domain)}-{str(slot)}\t')
fp_test.write(f'{str(domain)}-{str(slot)}\t')
fp_train.write('\n')
fp_dev.write('\n')
fp_test.write('\n')
# fp_data = open(os.path.join(SELF_DATA_DIR, "data.json"), "r")
# data = json.load(fp_data)
file_split = ['train', 'val', 'test']
fp = [fp_train, fp_dev, fp_test]
for split_type, split_fp in zip(file_split, fp):
zipfile_name = "{}.json.zip".format(split_type)
zip_fp = zipfile.ZipFile(os.path.join(data_dir, zipfile_name))
data = json.loads(str(zip_fp.read(zip_fp.namelist()[0]), 'utf-8'))
for file_id in data:
user_utterance = ''
system_response = ''
turn_idx = 0
messages = data[file_id]['messages']
for idx, turn in enumerate(messages):
if idx % 2 == 0: # user turn
user_utterance = turn['content']
else: # system turn
user_utterance = user_utterance.replace('\t', ' ')
user_utterance = user_utterance.replace('\n', ' ')
user_utterance = user_utterance.replace('  ', ' ')  # collapse double spaces (assumed; the original pattern was lost in rendering)
system_response = system_response.replace('\t', ' ')
system_response = system_response.replace('\n', ' ')
system_response = system_response.replace('  ', ' ')  # collapse double spaces (assumed; the original pattern was lost in rendering)
split_fp.write(str(file_id)) # 0: dialogue ID
split_fp.write('\t' + str(turn_idx)) # 1: turn index
split_fp.write('\t' + str(user_utterance)) # 2: user utterance
split_fp.write('\t' + str(system_response)) # 3: system response
# hardcode the value of facilities as 'yes' and 'no'
belief = {f'Hotel-Hotel Facilities - {str(facility)}': null for facility in facilities}
sys_state_init = turn['sys_state_init']
for domain, slots in sys_state_init.items():
for slot, value in slots.items():
# skip selected results
if isinstance(value, list):
continue
if domain not in ontology:
print("domain (%s) is not defined" % domain)
continue
if slot == 'Hotel Facilities':
for facility in value.split(','):
belief[f'{str(domain)}-Hotel Facilities - {str(facility)}'] = 'yes'
else:
if slot not in ontology[domain]:
print("slot (%s) in domain (%s) is not defined" % (slot, domain)) # bus-arriveBy not defined
continue
value = trans_value(value).lower()
if value not in ontology[domain][slot] and value != null:
print("%s: value (%s) in domain (%s) slot (%s) is not defined in ontology" %
(file_id, value, domain, slot))
value = null
belief[f'{str(domain)}-{str(slot)}'] = value
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
key = str(domain) + '-' + str(slot)
if key in belief:
val = belief[key]
split_fp.write('\t' + val)
else:
split_fp.write(f'\t{null}')
split_fp.write('\n')
split_fp.flush()
system_response = turn['content']
turn_idx += 1
print('data has been processed!')
import copy
from pprint import pprint
import random
from itertools import chain
import numpy as np
import zipfile
from matplotlib import pyplot as plt
# from tensorboardX.writer import SummaryWriter
from tqdm._tqdm import trange, tqdm
from convlab2.util.file_util import cached_path
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam
from convlab2.dst.dst import DST
from convlab2.dst.sumbt.crosswoz_en.convert_to_glue_format import convert_to_glue_format, trans_value
from convlab2.util.crosswoz_en.state import default_state
from convlab2.dst.sumbt.BeliefTrackerSlotQueryMultiSlot import BeliefTracker
from convlab2.dst.sumbt.crosswoz_en.sumbt_utils import *
from convlab2.dst.sumbt.crosswoz_en.sumbt_config import *
from convlab2.dst.sumbt.crosswoz_en.convert_to_glue_format import null
USE_CUDA = torch.cuda.is_available()
N_GPU = torch.cuda.device_count() if USE_CUDA else 1
DEVICE = "cuda" if USE_CUDA else "cpu"
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
SUMBT_PATH = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = os.path.join(ROOT_PATH, 'data/crosswoz_en')
DOWNLOAD_DIRECTORY = os.path.join(SUMBT_PATH, "downloaded_model/")
crosswoz_en_slot_list = ['Attraction-duration', 'Attraction-fee', 'Attraction-name', 'Attraction-nearby attract.', 'Attraction-nearby hotels', 'Attraction-nearby rest.', 'Attraction-rating', 'Hotel-Hotel Facilities - 24-hour Hot Water', 'Hotel-Hotel Facilities - Bar', 'Hotel-Hotel Facilities - Breakfast Service', 'Hotel-Hotel Facilities - Broadband Internet', 'Hotel-Hotel Facilities - Business Center', 'Hotel-Hotel Facilities - Car Rental', 'Hotel-Hotel Facilities - Chess-Poker Room', 'Hotel-Hotel Facilities - Childcare Services', 'Hotel-Hotel Facilities - Chinese Restaurant', 'Hotel-Hotel Facilities - Disabled Facilities', 'Hotel-Hotel Facilities - Foreign Guests Reception', 'Hotel-Hotel Facilities - Free Breakfast Service', 'Hotel-Hotel Facilities - Free Domestic Long Distance Call', 'Hotel-Hotel Facilities - Free Local Calls', 'Hotel-Hotel Facilities - Gym', 'Hotel-Hotel Facilities - Hair Dryer', 'Hotel-Hotel Facilities - Heating', 'Hotel-Hotel Facilities - Hot Spring', 'Hotel-Hotel Facilities - Indoor Swimming Pool', 'Hotel-Hotel Facilities - International Call', 'Hotel-Hotel Facilities - Laundry Service', 'Hotel-Hotel Facilities - Luggage Storage', 'Hotel-Hotel Facilities - Meeting Room', 'Hotel-Hotel Facilities - Non-smoking Room', 'Hotel-Hotel Facilities - Outdoor Swimming Pool', 'Hotel-Hotel Facilities - Pay Parking', 'Hotel-Hotel Facilities - Pick-up Service', 'Hotel-Hotel Facilities - SPA', 'Hotel-Hotel Facilities - Sauna', 'Hotel-Hotel Facilities - Wake Up Service', 'Hotel-Hotel Facilities - Western Restaurant', 'Hotel-Hotel Facilities - WiFi in All Rooms', 'Hotel-Hotel Facilities - WiFi in Public Areas', 'Hotel-Hotel Facilities - WiFi in Public Areas and Some Rooms', 'Hotel-Hotel Facilities - WiFi in Some Rooms', 'Hotel-Hotel Facilities - WiFi throughout the Hotel', 'Hotel-name', 'Hotel-nearby attract.', 'Hotel-nearby hotels', 'Hotel-nearby rest.', 'Hotel-price', 'Hotel-rating', 'Hotel-type', 'Metro-from', 'Metro-to', 'Restaurant-cost', 'Restaurant-dishes', 'Restaurant-name', 'Restaurant-nearby attract.', 'Restaurant-nearby hotels', 'Restaurant-nearby rest.', 'Restaurant-rating', 'Taxi-from', 'Taxi-to']
def plot(x, y):
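# Small helper: sort the (x, y) points by x and plot them with matplotlib.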
a, b = [], []
for x, y in sorted(zip(x, y)):
a.append(x)
b.append(y)
plt.plot(a, b)
# def get_label_embedding(labels, max_seq_length, tokenizer, device):
# features = []
# for label in labels:
# label_tokens = ["[CLS]"] + tokenizer.tokenize(label) + ["[SEP]"]
# label_token_ids = tokenizer.convert_tokens_to_ids(label_tokens)
# label_len = len(label_token_ids)
# label_padding = [0] * (max_seq_length - len(label_token_ids))
# label_token_ids += label_padding
# assert len(label_token_ids) == max_seq_length
# features.append((label_token_ids, label_len))
# all_label_token_ids = torch.tensor([f[0] for f in features], dtype=torch.long).to(device)
# all_label_len = torch.tensor([f[1] for f in features], dtype=torch.long).to(device)
# return all_label_token_ids, all_label_len
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
class SUMBTTracker(DST):
"""
Transferable multi-domain dialogue state tracker, adopted from https://github.com/SKTBrain/SUMBT
"""
@staticmethod
def init_data():
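# Unpack the machine-translated data (mt.zip) into the per-split
# train/val/test .json.zip archives expected by convert_to_glue_format.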
if not os.path.exists(os.path.join(DATA_PATH, 'train.json.zip')):
with zipfile.ZipFile(os.path.join(DATA_PATH, 'mt.zip')) as f:
f.extractall(DATA_PATH)
for split in ['train', 'test', 'val']:
with zipfile.ZipFile(os.path.join(DATA_PATH, f'{split}.json.zip'), 'w') as f:
f.write(os.path.join(DATA_PATH, f'{split}.json'), f'{split}.json')
def __init__(self, data_dir=DATA_PATH):
DST.__init__(self)
# if not os.path.exists(data_dir):
# if model_file == '':
# raise Exception(
# 'Please provide remote model file path in config')
# resp = urllib.request.urlretrieve(model_file)[0]
# temp_file = tarfile.open(resp)
# temp_file.extractall('data')
# assert os.path.exists(data_dir)
processor = Processor(args)
self.processor = processor
# values of each slot e.g. values_list
label_list = processor.get_labels()
num_labels = [len(labels) for labels in label_list] # number of slot-values in each slot-type
# tokenizer
# vocab_dir = os.path.join(data_dir, 'model', '%s-vocab.txt' % args.bert_model)
# if not os.path.exists(vocab_dir):
# raise ValueError("Can't find %s " % vocab_dir)
self.tokenizer = BertTokenizer.from_pretrained(args.bert_model)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
self.device = torch.device("cuda" if USE_CUDA else "cpu")
self.sumbt_model = BeliefTracker(args, num_labels, self.device)
if USE_CUDA and N_GPU > 1:
self.sumbt_model = torch.nn.DataParallel(self.sumbt_model)
if args.fp16:
self.sumbt_model.half()
self.sumbt_model.to(self.device)
## Get slot-value embeddings
self.label_token_ids, self.label_len = [], []
for labels in label_list:
# encoding values
token_ids, lens = get_label_embedding(labels, args.max_label_length, self.tokenizer, self.device)
self.label_token_ids.append(token_ids)
self.label_len.append(lens)
self.label_map = [{label: i for i, label in enumerate(labels)} for labels in label_list]
self.label_map_inv = [{i: label for i, label in enumerate(labels)} for labels in label_list]
self.label_list = label_list
self.target_slot = processor.target_slot
## Get domain-slot-type embeddings
self.slot_token_ids, self.slot_len = \
get_label_embedding(processor.target_slot, args.max_label_length, self.tokenizer, self.device)
self.args = args
self.state = default_state()
self.param_restored = False
if USE_CUDA and N_GPU == 1:
self.sumbt_model.initialize_slot_value_lookup(self.label_token_ids, self.slot_token_ids)
elif USE_CUDA and N_GPU > 1:
self.sumbt_model.module.initialize_slot_value_lookup(self.label_token_ids, self.slot_token_ids)
self.cached_res = {}
convert_to_glue_format(DATA_PATH, SUMBT_PATH)
if not os.path.isdir(os.path.join(SUMBT_PATH, args.output_dir)):
os.makedirs(os.path.join(SUMBT_PATH, args.output_dir))
self.train_examples = processor.get_train_examples(os.path.join(SUMBT_PATH, args.tmp_data_dir), accumulation=False)
self.dev_examples = processor.get_dev_examples(os.path.join(SUMBT_PATH, args.tmp_data_dir), accumulation=False)
self.test_examples = processor.get_test_examples(os.path.join(SUMBT_PATH, args.tmp_data_dir), accumulation=False)
def load_weights(self, model_path=None):
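# Load a trained checkpoint into the belief tracker; defaults to the bundled
# pre-trained/pytorch_model.bin next to this module.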
if model_path is None:
model_ckpt = os.path.join(SUMBT_PATH, 'pre-trained/pytorch_model.bin')
else:
model_ckpt = model_path
model = self.sumbt_model
# in the case that slot and values are different between the training and evaluation
if not USE_CUDA:
ptr_model = torch.load(model_ckpt, map_location=torch.device('cpu'))
else:
ptr_model = torch.load(model_ckpt)
print('loading pretrained weights')
if not USE_CUDA or N_GPU == 1:
state = model.state_dict()
state.update(ptr_model)
model.load_state_dict(state)
else:
# print("Evaluate using only one device!")
model.module.load_state_dict(ptr_model)
if USE_CUDA:
model.to("cuda")
def init_session(self):
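# Reset the dialogue state; on the first call, also restore model weights from
# the downloaded or locally trained checkpoint.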
self.state = default_state()
if not self.param_restored:
if os.path.isfile(os.path.join(DOWNLOAD_DIRECTORY, 'pytorch_model.bin')):
print('loading weights from downloaded model')
self.load_weights(model_path=os.path.join(DOWNLOAD_DIRECTORY, 'pytorch_model.bin'))
elif os.path.isfile(os.path.join(SUMBT_PATH, args.output_dir, 'pytorch_model.bin')):
print('loading weights from trained model')
self.load_weights(model_path=os.path.join(SUMBT_PATH, args.output_dir, 'pytorch_model.bin'))
else:
raise ValueError('no available weights found.')
self.param_restored = True
def construct_query(self, context):
'''Construct query from context'''
ids = []
lens = []
context_len = len(context)
if context[0][0] != 'sys':
context = [['sys', '']] + context
for i in range(0, context_len, 2):
# utt_user = ''
# utt_sys = ''
# for evaluation
utt_sys = context[i][1]
utt_user = context[i + 1][1]
tokens_user = [x if x != '#' else '[SEP]' for x in self.tokenizer.tokenize(utt_user)]
tokens_sys = [x if x != '#' else '[SEP]' for x in self.tokenizer.tokenize(utt_sys)]
_truncate_seq_pair(tokens_user, tokens_sys, self.args.max_seq_length - 3)
tokens = ["[CLS]"] + tokens_user + ["[SEP]"] + tokens_sys + ["[SEP]"]
input_len = [len(tokens_user) + 2, len(tokens_sys) + 1]
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
padding = [0] * (self.args.max_seq_length - len(input_ids))
input_ids += padding
assert len(input_ids) == self.args.max_seq_length
ids.append(input_ids)
lens.append(input_len)
return (ids, lens)
def update(self, user_act=None):
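'''Predict the belief state from the accumulated dialogue history and merge the predicted slot values into a copy of the current state.'''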
if not isinstance(user_act, str):
raise Exception(
'Expected user_act is str but found {}'.format(type(user_act))
)
prev_state = self.state
actual_history = copy.deepcopy(prev_state['history'])
# if actual_history[-1][0] == 'user':
# actual_history[-1][1] += user_act
# else:
# actual_history.append(['user', user_act])
query = self.construct_query(actual_history)
pred_states = self.predict(query)
new_belief_state = copy.deepcopy(prev_state['belief_state'])
for domain_slot, value in pred_states:
domain, slot = domain_slot.split('-', 1)
value = trans_value(value)
# print(domain, slot, value)
if domain not in new_belief_state:
raise Exception(
'Error: domain <{}> not in belief state'.format(domain))
domain_dic = new_belief_state[domain]
if slot in domain_dic:
domain_dic[slot] = value
else:
with open('sumbt_tracker_unknown_slot.log', 'a+') as f:
f.write(
'unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format(slot, value, domain, prev_state)
)
new_state = copy.deepcopy(dict(prev_state))
new_state['belief_state'] = new_belief_state
self.state = new_state
return self.state
def predict(self, query):
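# Encode the dialogue turns, run the belief tracker, and map predicted value
# indices back to (domain-slot, value) pairs; results are cached per query.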
cache_query_key = ''.join(str(list(chain.from_iterable(query[0]))))
if cache_query_key in self.cached_res.keys():
return self.cached_res[cache_query_key]
input_ids, input_len = query
input_ids = torch.tensor(input_ids).to(self.device).unsqueeze(0)
input_len = torch.tensor(input_len).to(self.device).unsqueeze(0)
labels = None
_, pred_slot = self.sumbt_model(input_ids, input_len, labels)
pred_slot_t = pred_slot[0][-1].tolist()
predict_belief = []
for idx, i in enumerate(pred_slot_t):
predict_belief.append((self.target_slot[idx], self.label_map_inv[idx][i]))
# predict_belief.append('{}-{}'.format(self.target_slot[idx], self.label_map_inv[idx][i]))
self.cached_res[cache_query_key] = predict_belief
return predict_belief
def train(self, load_model=False, model_path=None):
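'''Fine-tune the belief tracker on the translated training set, validating after each epoch and early-stopping when validation loss stops improving for args.patience epochs.'''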
if load_model:
if model_path is not None:
self.load_weights(model_path)
## Training utterances
all_input_ids, all_input_len, all_label_ids = convert_examples_to_features(
self.train_examples, self.label_list, args.max_seq_length, self.tokenizer, args.max_turn_length)
print('all input ids size: ', all_input_ids.size())
num_train_batches = all_input_ids.size(0)
num_train_steps = int(
num_train_batches / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
logger.info("***** training *****")
logger.info(" Num examples = %d", len(self.train_examples))
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps)
all_input_ids, all_input_len, all_label_ids = all_input_ids.to(DEVICE), all_input_len.to(
DEVICE), all_label_ids.to(DEVICE)
train_data = TensorDataset(all_input_ids, all_input_len, all_label_ids)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
all_input_ids_dev, all_input_len_dev, all_label_ids_dev = convert_examples_to_features(
self.dev_examples, self.label_list, args.max_seq_length, self.tokenizer, args.max_turn_length)
logger.info("***** validation *****")
logger.info(" Num examples = %d", len(self.dev_examples))
logger.info(" Batch size = %d", args.dev_batch_size)
all_input_ids_dev, all_input_len_dev, all_label_ids_dev = \
all_input_ids_dev.to(DEVICE), all_input_len_dev.to(DEVICE), all_label_ids_dev.to(DEVICE)
dev_data = TensorDataset(all_input_ids_dev, all_input_len_dev, all_label_ids_dev)
dev_sampler = SequentialSampler(dev_data)
dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, batch_size=args.dev_batch_size)
logger.info("Loaded data!")
if args.fp16:
self.sumbt_model.half()
self.sumbt_model.to(DEVICE)
# ## Get domain-slot-type embeddings
# slot_token_ids, slot_len = \
# get_label_embedding(self.processor.target_slot, args.max_label_length, self.tokenizer, DEVICE)
# # for slot_idx, slot_str in zip(slot_token_ids, self.processor.target_slot):
# # self.idx2slot[slot_idx] = slot_str
# ## Get slot-value embeddings
# label_token_ids, label_len = [], []
# for slot_idx, labels in zip(slot_token_ids, self.label_list):
# # self.idx2value[slot_idx] = {}
# token_ids, lens = get_label_embedding(labels, args.max_label_length, self.tokenizer, DEVICE)
# label_token_ids.append(token_ids)
# label_len.append(lens)
# # for label, token_id in zip(labels, token_ids):
# # self.idx2value[slot_idx][token_id] = label
# logger.info('embeddings prepared')
# if USE_CUDA and N_GPU > 1:
# self.sumbt_model.module.initialize_slot_value_lookup(label_token_ids, slot_token_ids)
# else:
# self.sumbt_model.initialize_slot_value_lookup(label_token_ids, slot_token_ids)
def get_optimizer_grouped_parameters(model):
param_optimizer = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01,
'lr': args.learning_rate},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0,
'lr': args.learning_rate},
]
return optimizer_grouped_parameters
if not USE_CUDA or N_GPU == 1:
optimizer_grouped_parameters = get_optimizer_grouped_parameters(self.sumbt_model)
else:
optimizer_grouped_parameters = get_optimizer_grouped_parameters(self.sumbt_model.module)
t_total = num_train_steps
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
optimizer = FusedAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
bias_correction=False,
max_grad_norm=1.0)
if args.fp16_loss_scale == 0:
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.fp16_loss_scale)
else:
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=t_total)
logger.info(optimizer)
# Training code
###############################################################################
print(torch.cuda.memory_allocated())
logger.info("Training...")
global_step = 0
last_update = None
best_loss = None
model = self.sumbt_model
if args.do_not_use_tensorboard:
summary_writer = None
else:
# imported here so tensorboardX remains optional when logging is disabled
from tensorboardX.writer import SummaryWriter
summary_writer = SummaryWriter("./tensorboard_summary/logs_1214/")
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
# Train
model.train()
tr_loss = 0
nb_tr_examples = 0
nb_tr_steps = 0
for step, batch in enumerate(tqdm(train_dataloader)):
batch = tuple(t.to(DEVICE) for t in batch)
input_ids, input_len, label_ids = batch
# print(input_ids.size())
# Forward
if N_GPU == 1:
loss, loss_slot, acc, acc_slot, _ = model(input_ids, input_len, label_ids, N_GPU)
else:
loss, _, acc, acc_slot, _ = model(input_ids, input_len, label_ids, N_GPU)
# average to multi-gpus
loss = loss.mean()
acc = acc.mean()
acc_slot = acc_slot.mean(0)
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
# Backward
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
# tensrboard logging
if summary_writer is not None:
summary_writer.add_scalar("Epoch", epoch, global_step)
summary_writer.add_scalar("Train/Loss", loss, global_step)
summary_writer.add_scalar("Train/JointAcc", acc, global_step)
if N_GPU == 1:
for i, slot in enumerate(self.processor.target_slot):
summary_writer.add_scalar("Train/Loss_%s" % slot.replace(' ', '_'), loss_slot[i],
global_step)
summary_writer.add_scalar("Train/Acc_%s" % slot.replace(' ', '_'), acc_slot[i], global_step)
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
if (step + 1) % args.gradient_accumulation_steps == 0:
# modify learning rate with special warm up BERT uses
lr_this_step = args.learning_rate * warmup_linear(global_step / t_total, args.warmup_proportion)
if summary_writer is not None:
summary_writer.add_scalar("Train/LearningRate", lr_this_step, global_step)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step()
optimizer.zero_grad()
global_step += 1
# Perform evaluation on validation dataset
model.eval()
dev_loss = 0
dev_acc = 0
dev_loss_slot, dev_acc_slot = None, None
nb_dev_examples, nb_dev_steps = 0, 0
for step, batch in enumerate(tqdm(dev_dataloader, desc="Validation")):
batch = tuple(t.to(DEVICE) for t in batch)
input_ids, input_len, label_ids = batch
if input_ids.dim() == 2:
input_ids = input_ids.unsqueeze(0)
input_len = input_len.unsqueeze(0)
label_ids = label_ids.unsqueeze(0)
with torch.no_grad():
if N_GPU == 1:
loss, loss_slot, acc, acc_slot, _ = model(input_ids, input_len, label_ids, N_GPU)
else:
loss, _, acc, acc_slot, _ = model(input_ids, input_len, label_ids, N_GPU)
# average to multi-gpus
loss = loss.mean()
acc = acc.mean()
acc_slot = acc_slot.mean(0)
num_valid_turn = torch.sum(label_ids[:, :, 0].view(-1) > -1, 0).item()
dev_loss += loss.item() * num_valid_turn
dev_acc += acc.item() * num_valid_turn
if N_GPU == 1:
if dev_loss_slot is None:
dev_loss_slot = [l * num_valid_turn for l in loss_slot]
dev_acc_slot = acc_slot * num_valid_turn
else:
for i, l in enumerate(loss_slot):
dev_loss_slot[i] = dev_loss_slot[i] + l * num_valid_turn
dev_acc_slot += acc_slot * num_valid_turn
nb_dev_examples += num_valid_turn
dev_loss = dev_loss / nb_dev_examples
dev_acc = dev_acc / nb_dev_examples
if N_GPU == 1:
dev_acc_slot = dev_acc_slot / nb_dev_examples
# tensorboard logging
if summary_writer is not None:
summary_writer.add_scalar("Validate/Loss", dev_loss, global_step)
summary_writer.add_scalar("Validate/Acc", dev_acc, global_step)
if N_GPU == 1:
for i, slot in enumerate(self.processor.target_slot):
summary_writer.add_scalar("Validate/Loss_%s" % slot.replace(' ', '_'),
dev_loss_slot[i] / nb_dev_examples, global_step)
summary_writer.add_scalar("Validate/Acc_%s" % slot.replace(' ', '_'), dev_acc_slot[i],
global_step)
dev_loss = round(dev_loss, 6)
output_model_file = os.path.join(os.path.join(SUMBT_PATH, args.output_dir), "pytorch_model.bin")
if last_update is None or dev_loss < best_loss:
last_update = epoch
best_loss = dev_loss
best_acc = dev_acc
if not USE_CUDA or N_GPU == 1:
torch.save(model.state_dict(), output_model_file)
else:
torch.save(model.module.state_dict(), output_model_file)
logger.info(
"*** Model Updated: Epoch=%d, Validation Loss=%.6f, Validation Acc=%.6f, global_step=%d ***" % (
last_update, best_loss, best_acc, global_step))
else:
logger.info(
"*** Model NOT Updated: Epoch=%d, Validation Loss=%.6f, Validation Acc=%.6f, global_step=%d ***" % (
epoch, dev_loss, dev_acc, global_step))
if last_update + args.patience <= epoch:
break
def test(self, mode='dev', model_path=os.path.join(os.path.join(SUMBT_PATH, args.output_dir), "pytorch_model.bin")):
'''Evaluate a trained SUMBT checkpoint on the dev or test split'''
# Evaluation
self.load_weights(model_path)
if mode == 'test':
eval_examples = self.test_examples
elif mode == 'dev':
eval_examples = self.dev_examples
all_input_ids, all_input_len, all_label_ids = convert_examples_to_features(
eval_examples, self.label_list, args.max_seq_length, self.tokenizer, args.max_turn_length)
all_input_ids, all_input_len, all_label_ids = all_input_ids.to(DEVICE), all_input_len.to(
DEVICE), all_label_ids.to(DEVICE)
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.dev_batch_size)
eval_data = TensorDataset(all_input_ids, all_input_len, all_label_ids)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.dev_batch_size)
model = self.sumbt_model
eval_loss, eval_accuracy = 0, 0
eval_loss_slot, eval_acc_slot = None, None
nb_eval_steps, nb_eval_examples = 0, 0
accuracies = {'joint7': 0, 'slot7': 0, 'joint5': 0, 'slot5': 0, 'joint_rest': 0, 'slot_rest': 0,
'num_turn': 0, 'num_slot7': 0, 'num_slot5': 0, 'num_slot_rest': 0}
for input_ids, input_len, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
# if input_ids.dim() == 2:
# input_ids = input_ids.unsqueeze(0)
# input_len = input_len.unsqueeze(0)
# label_ids = label_ids.unsuqeeze(0)
with torch.no_grad():
if not USE_CUDA or N_GPU == 1:
loss, loss_slot, acc, acc_slot, pred_slot = model(input_ids, input_len, label_ids, 1)
else:
loss, _, acc, acc_slot, pred_slot = model(input_ids, input_len, label_ids, N_GPU)
nbatch = label_ids.size(0)
nslot = pred_slot.size(3)
pred_slot = pred_slot.view(nbatch, -1, nslot)
accuracies = eval_all_accs(pred_slot, label_ids, accuracies)
nb_eval_ex = (label_ids[:, :, 0].view(-1) != -1).sum().item()
nb_eval_examples += nb_eval_ex
nb_eval_steps += 1
if not USE_CUDA or N_GPU == 1:
eval_loss += loss.item() * nb_eval_ex
eval_accuracy += acc.item() * nb_eval_ex
if eval_loss_slot is None:
eval_loss_slot = [l * nb_eval_ex for l in loss_slot]
eval_acc_slot = acc_slot * nb_eval_ex
else:
for i, l in enumerate(loss_slot):
eval_loss_slot[i] = eval_loss_slot[i] + l * nb_eval_ex
eval_acc_slot += acc_slot * nb_eval_ex
else:
eval_loss += sum(loss) * nb_eval_ex
eval_accuracy += sum(acc) * nb_eval_ex
# exit(1)
eval_loss = eval_loss / nb_eval_examples
eval_accuracy = eval_accuracy / nb_eval_examples
if not USE_CUDA or N_GPU == 1:
eval_acc_slot = eval_acc_slot / nb_eval_examples
loss = None
if not USE_CUDA or N_GPU == 1:
result = {
# 'num': '\t'.join([str(x) for x in model.num_labels]),
'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy,
'loss': loss,
'eval_loss_slot': '\t'.join([str(val / nb_eval_examples) for val in eval_loss_slot]),
'eval_acc_slot': '\t'.join([str((val).item()) for val in eval_acc_slot]),
}
else:
result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy,
'loss': loss
}
out_file_name = 'eval_results'
# if TARGET_SLOT == 'all':
# out_file_name += '_all'
output_eval_file = os.path.join(os.path.join(SUMBT_PATH, args.output_dir), "%s.txt" % out_file_name)
if not USE_CUDA or N_GPU == 1:
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
out_file_name = 'eval_all_accuracies'
with open(os.path.join(os.path.join(SUMBT_PATH, args.output_dir), "%s.txt" % out_file_name), 'w') as f:
s = '{:^22s}:{:^22s}:{:^22s}:{:^22s}:{:^22s}:{:^22s}'.format(
'joint acc (7 domain)',
'slot acc (7 domain)',
'joint acc (5 domain)',
'slot acc (5 domain)',
'joint restaurant',
'slot acc restaurant')
f.write(s + '\n')
print(s)
s = '{:^22.5f}:{:^22.5f}:{:^22.5f}:{:^22.5f}:{:^22.5f}:{:^22.5f}'.format(
(accuracies['joint7'] / accuracies['num_turn']).item(),
(accuracies['slot7'] / accuracies['num_slot7']).item(),
(accuracies['joint5'] / accuracies['num_turn']).item(),
(accuracies['slot5'] / accuracies['num_slot5']).item(),
(accuracies['joint_rest'] / accuracies['num_turn']).item(),
(accuracies['slot_rest'] / accuracies['num_slot_rest']).item()
)
f.write(s + '\n')
print(s)
import os
import convlab2
class DotMap():
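# Plain attribute container holding SUMBT hyper-parameters and paths for
# CrossWOZ-en (used in place of command-line arguments).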
def __init__(self):
self.max_label_length = 35
self.num_rnn_layers = 1
self.zero_init_rnn = False
self.attn_head = 4
self.do_eval = True
self.do_train = False
self.train_batch_size = 3
self.dev_batch_size = 1
self.eval_batch_size = 16
self.learning_rate = 5e-5
self.warmup_proportion = 0.1
self.local_rank = -1
self.seed = 42
self.gradient_accumulation_steps = 1
self.fp16 = False
self.loss_scale = 0
self.do_not_use_tensorboard = False
self.fix_utterance_encoder = False
self.do_eval = True
self.num_train_epochs = 300
self.bert_model = os.path.join(convlab2.get_root_path(), "pre-trained-models/bert-base-uncased")
self.do_lower_case = True
self.task_name = 'bert-gru-sumbt'
self.nbt = 'rnn'
self.target_slot = 'all'
self.distance_metric = 'euclidean'
self.patience = 15
self.hidden_dim = 300
self.max_seq_length = 35
self.max_turn_length = 23
self.fp16_loss_scale = 0.0
self.data_dir = 'data/crosswoz_en/'
self.tf_dir = 'tensorboard'
self.tmp_data_dir = 'processed_data/'
self.output_dir = 'model_output/'
args = DotMap()
\ No newline at end of file
import csv
import os
import json
import collections
import logging
import re
import torch
from convlab2.dst.sumbt.crosswoz_en.convert_to_glue_format import null
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r", encoding='utf-8') as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
if len(line) > 0 and line[0][0] == '#': # ignore comments (starting with '#')
continue
lines.append(line)
return lines
class Processor(DataProcessor):
"""Processor for the belief tracking dataset (GLUE version)."""
def __init__(self, config):
super(Processor, self).__init__()
# crosswoz dataset
with open(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))), config.data_dir, "ontology.json"), "r") as fp_ontology:
ontology = json.load(fp_ontology)
for slot in ontology.keys():
ontology[slot].append(null)
assert config.target_slot == 'all'
# if not config.target_slot == 'all':
# slot_idx = {'attraction': '0:1:2', 'bus': '3:4:5:6', 'hospital': '7',
# 'hotel': '8:9:10:11:12:13:14:15:16:17', \
# 'restaurant': '18:19:20:21:22:23:24', 'taxi': '25:26:27:28', 'train': '29:30:31:32:33:34'}
# target_slot = []
# for key, value in slot_idx.items():
# if key != config.target_slot:
# target_slot.append(value)
# config.target_slot = ':'.join(target_slot)
# sorting the ontology according to the alphabetic order of the slots
ontology = collections.OrderedDict(sorted(ontology.items()))
# select slots to train
nslots = len(ontology.keys())
target_slot = list(ontology.keys())
if config.target_slot == 'all':
self.target_slot_idx = [*range(0, nslots)]
else:
self.target_slot_idx = sorted([int(x) for x in config.target_slot.split(':')])
for idx in range(0, nslots):
if not idx in self.target_slot_idx:
del ontology[target_slot[idx]]
self.ontology = ontology
self.target_slot = list(self.ontology.keys())
# for i, slot in enumerate(self.target_slot):
# if slot == "pricerange":
# self.target_slot[i] = "price range"
logger.info('Processor: target_slot')
logger.info(self.target_slot)
def get_train_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train", accumulation)
def get_dev_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev", accumulation)
def get_test_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test", accumulation)
def get_labels(self):
"""See base class."""
return [list(map(str.lower, self.ontology[slot])) for slot in self.target_slot]
def _create_examples(self, lines, set_type, accumulation=False):
"""Creates examples for the training and dev sets."""
prev_dialogue_index = None
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s-%s" % (set_type, line[0], line[1]) # line[0]: dialogue index, line[1]: turn index
if accumulation:
if prev_dialogue_index is None or prev_dialogue_index != line[0]:
text_a = line[2]
text_b = line[3]
prev_dialogue_index = line[0]
else:
# The symbol '#' will be replaced with '[SEP]' after tokenization.
text_a = line[2] + " # " + text_a
text_b = line[3] + " # " + text_b
else:
text_a = line[2] # line[2]: user utterance
text_b = line[3] # line[3]: system response
label = [line[4 + idx] for idx in self.target_slot_idx]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
# NOTE: 'replacements' is used below but never defined in this file; default to an
# empty list so the module stays importable (the original utilities presumably load
# (from, to) pairs from a mapping file).
replacements = []
def normalize_text(text):
global replacements
# lower case every word
text = text.lower()
# replace white spaces in front and end
text = re.sub(r'^\s*|\s*$', '', text)
# hotel domain pfb30
text = re.sub(r"b&b", "bed and breakfast", text)
text = re.sub(r"b and b", "bed and breakfast", text)
# replace st.
text = text.replace(';', ',')
text = re.sub('$\/', '', text)
text = text.replace('/', ' and ')
# replace other special characters
text = text.replace('-', ' ')
text = re.sub('[\"\<>@\(\)]', '', text) # remove
# insert white space before and after tokens:
for token in ['?', '.', ',', '!']:
text = insertSpace(token, text)
# insert white space for 's
text = insertSpace('\'s', text)
# replace it's, does't, you'd ... etc
text = re.sub('^\'', '', text)
text = re.sub('\'$', '', text)
text = re.sub('\'\s', ' ', text)
text = re.sub('\s\'', ' ', text)
for fromx, tox in replacements:
text = ' ' + text + ' '
text = text.replace(fromx, tox)[1:-1]
# remove multiple spaces
text = re.sub(' +', ' ', text)
# concatenate numbers
tmp = text
tokens = text.split()
i = 1
while i < len(tokens):
if re.match(u'^\d+$', tokens[i]) and \
re.match(u'\d+$', tokens[i - 1]):
tokens[i - 1] += tokens[i]
del tokens[i]
else:
i += 1
text = ' '.join(tokens)
return text
def insertSpace(token, text):
sidx = 0
while True:
sidx = text.find(token, sidx)
if sidx == -1:
break
if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
re.match('[0-9]', text[sidx + 1]):
sidx += 1
continue
if text[sidx - 1] != ' ':
text = text[:sidx] + ' ' + text[sidx:]
sidx += 1
if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
text = text[:sidx + 1] + ' ' + text[sidx + 1:]
sidx += 1
return text
# convert tokens in labels to the identifier in vocabulary
def get_label_embedding(labels, max_seq_length, tokenizer, device):
features = []
for label in labels:
label_tokens = ["[CLS]"] + tokenizer.tokenize(label) + ["[SEP]"]
# just truncate, some names are unreasonable long
label_token_ids = tokenizer.convert_tokens_to_ids(label_tokens)[:max_seq_length]
label_len = len(label_token_ids)
label_padding = [0] * (max_seq_length - len(label_token_ids))
label_token_ids += label_padding
assert len(label_token_ids) == max_seq_length
features.append((label_token_ids, label_len))
all_label_token_ids = torch.tensor([f[0] for f in features], dtype=torch.long).to(device)
all_label_len = torch.tensor([f[1] for f in features], dtype=torch.long).to(device)
return all_label_token_ids, all_label_len
def warmup_linear(x, warmup=0.002):
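# Linear learning-rate warm-up for the first `warmup` fraction of training, then linear decay.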
if x < warmup:
return x / warmup
return 1.0 - x
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, input_len, label_id):
self.input_ids = input_ids
self.input_len = input_len
self.label_id = label_id
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, max_turn_length):
"""Loads a data file into a list of `InputBatch`s."""
label_map = [{label: i for i, label in enumerate(labels)} for labels in label_list]
slot_dim = len(label_list)
features = []
prev_dialogue_idx = None
all_padding = [0] * max_seq_length
all_padding_len = [0, 0]
max_turn = 0
for (ex_index, example) in enumerate(examples):
if max_turn < int(example.guid.split('-')[2]):
max_turn = int(example.guid.split('-')[2])
max_turn_length = min(max_turn + 1, max_turn_length)
logger.info("max_turn_length = %d" % max_turn)
for (ex_index, example) in enumerate(examples):
tokens_a = [x if x != '#' else '[SEP]' for x in tokenizer.tokenize(example.text_a)]
tokens_b = None
if example.text_b:
tokens_b = [x if x != '#' else '[SEP]' for x in tokenizer.tokenize(example.text_b)]
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[:(max_seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
input_len = [len(tokens), 0]
if tokens_b:
tokens += tokens_b + ["[SEP]"]
input_len[1] = len(tokens_b) + 1
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# Zero-pad up to the sequence length.
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
assert len(input_ids) == max_seq_length
FLAG_TEST = False
if example.label is not None:
label_id = []
label_info = 'label: '
for i, label in enumerate(example.label):
if label == 'dontcare':
label = 'do not care'
label_id.append(label_map[i][label])
label_info += '%s (id = %d) ' % (label, label_map[i][label])
if ex_index < 5:
logger.info("*** Example ***")
logger.info("guid: %s" % example.guid)
logger.info("tokens: %s" % " ".join(
[str(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_len: %s" % " ".join([str(x) for x in input_len]))
logger.info("label: " + label_info)
else:
FLAG_TEST = True
label_id = None
curr_dialogue_idx = example.guid.split('-')[1]
curr_turn_idx = int(example.guid.split('-')[2])
if prev_dialogue_idx is not None and prev_dialogue_idx != curr_dialogue_idx:
if prev_turn_idx < max_turn_length:
features += [InputFeatures(input_ids=all_padding,
input_len=all_padding_len,
label_id=[-1] * slot_dim)] \
* (max_turn_length - prev_turn_idx - 1)
# print(len(features), max_turn_length)
assert len(features) % max_turn_length == 0
if prev_dialogue_idx is None or prev_turn_idx < max_turn_length:
features.append(
InputFeatures(input_ids=input_ids,
input_len=input_len,
label_id=label_id))
prev_dialogue_idx = curr_dialogue_idx
prev_turn_idx = curr_turn_idx
if prev_turn_idx < max_turn_length:
features += [InputFeatures(input_ids=all_padding,
input_len=all_padding_len,
label_id=[-1] * slot_dim)] \
* (max_turn_length - prev_turn_idx - 1)
assert len(features) % max_turn_length == 0
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_len = torch.tensor([f.input_len for f in features], dtype=torch.long)
if not FLAG_TEST:
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
# reshape tensors to [#batch, #max_turn_length, #max_seq_length]
all_input_ids = all_input_ids.view(-1, max_turn_length, max_seq_length)
all_input_len = all_input_len.view(-1, max_turn_length, 2)
if not FLAG_TEST:
all_label_ids = all_label_ids.view(-1, max_turn_length, slot_dim)
else:
all_label_ids = None
return all_input_ids, all_input_len, all_label_ids
def eval_all_accs(pred_slot, labels, accuracies):
def _eval_acc(_pred_slot, _labels):
slot_dim = _labels.size(-1)
accuracy = (_pred_slot == _labels).view(-1, slot_dim)
num_turn = torch.sum(_labels[:, :, 0].view(-1) > -1, 0).float()
num_data = torch.sum(_labels > -1).float()
# joint accuracy
joint_acc = sum(torch.sum(accuracy, 1) / slot_dim).float()
# slot accuracy
slot_acc = torch.sum(accuracy).float()
return joint_acc, slot_acc, num_turn, num_data
# 7 domains
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot, labels)
accuracies['joint7'] += joint_acc
accuracies['slot7'] += slot_acc
accuracies['num_turn'] += num_turn
accuracies['num_slot7'] += num_data
# restaurant domain
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot[:,:,18:25], labels[:,:,18:25])
accuracies['joint_rest'] += joint_acc
accuracies['slot_rest'] += slot_acc
accuracies['num_slot_rest'] += num_data
pred_slot5 = torch.cat((pred_slot[:,:,0:3], pred_slot[:,:,8:]), 2)
label_slot5 = torch.cat((labels[:,:,0:3], labels[:,:,8:]), 2)
# 5 domains (excluding bus and hotel domain)
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot5, label_slot5)
accuracies['joint5'] += joint_acc
accuracies['slot5'] += slot_acc
accuracies['num_slot5'] += num_data
return accuracies
import os
import convlab2
class DotMap():
def __init__(self):
self.max_label_length = 32
...@@ -28,7 +30,7 @@ class DotMap():
self.num_train_epochs = 300
self.bert_model = os.path.join(convlab2.get_root_path(), "pre-trained-models/bert-base-uncased")
self.do_lower_case = True
self.task_name = 'bert-gru-sumbt'
self.nbt = 'rnn'
...
pre-trained/
model_output/
from convlab2.dst.sumbt.multiwoz_zh.sumbt import SUMBTTracker as SUMBT
import json
import zipfile
from convlab2.dst.sumbt.multiwoz_zh.sumbt_config import *
def trans_value(value):
trans = {
'': '未提及',
'没有提到': '未提及',
'没有': '未提及',
'未提到': '未提及',
'一个也没有': '未提及',
'': '未提及',
'是的': '',
'不是': '没有',
'不关心': '不在意',
'不在乎': '不在意',
}
return trans.get(value, value)
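# Illustrative use of the mapping above: trans_value('没有') -> '未提及' and
# trans_value('不关心') -> '不在意'; values not listed in `trans` pass through unchanged.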
def convert_to_glue_format(data_dir, sumbt_dir):
if not os.path.isdir(os.path.join(sumbt_dir, args.tmp_data_dir)):
os.mkdir(os.path.join(sumbt_dir, args.tmp_data_dir))
### Read ontology file
with open(os.path.join(data_dir, "ontology.json"), "r") as fp_ont:
data_ont = json.load(fp_ont)
ontology = {}
for domain_slot in data_ont:
domain, slot = domain_slot.split('-')
if domain not in ontology:
ontology[domain] = {}
ontology[domain][slot] = {}
for value in data_ont[domain_slot]:
ontology[domain][slot][value] = 1
### Read woz logs and write to tsv files
if os.path.exists(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv")):
print('data has been processed!')
return 0
print('begin processing data')
fp_train = open(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv"), "w")
fp_dev = open(os.path.join(sumbt_dir, args.tmp_data_dir, "dev.tsv"), "w")
fp_test = open(os.path.join(sumbt_dir, args.tmp_data_dir, "test.tsv"), "w")
fp_train.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_dev.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_test.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
fp_train.write(str(domain) + '-' + str(slot) + '\t')
fp_dev.write(str(domain) + '-' + str(slot) + '\t')
fp_test.write(str(domain) + '-' + str(slot) + '\t')
fp_train.write('\n')
fp_dev.write('\n')
fp_test.write('\n')
# fp_data = open(os.path.join(SELF_DATA_DIR, "data.json"), "r")
# data = json.load(fp_data)
file_split = ['train', 'val', 'test']
fp = [fp_train, fp_dev, fp_test]
for split_type, split_fp in zip(file_split, fp):
zipfile_name = "{}.json.zip".format(split_type)
zip_fp = zipfile.ZipFile(os.path.join(data_dir, zipfile_name))
data = json.loads(str(zip_fp.read(zip_fp.namelist()[0]), 'utf-8'))
for file_id in data:
user_utterance = ''
system_response = ''
turn_idx = 0
for idx, turn in enumerate(data[file_id]['log']):
if idx % 2 == 0: # user turn
user_utterance = data[file_id]['log'][idx]['text']
else: # system turn
user_utterance = user_utterance.replace('\t', ' ')
user_utterance = user_utterance.replace('\n', ' ')
user_utterance = user_utterance.replace(' ', ' ')
system_response = system_response.replace('\t', ' ')
system_response = system_response.replace('\n', ' ')
system_response = system_response.replace(' ', ' ')
split_fp.write(str(file_id)) # 0: dialogue ID
split_fp.write('\t' + str(turn_idx)) # 1: turn index
split_fp.write('\t' + str(user_utterance)) # 2: user utterance
split_fp.write('\t' + str(system_response)) # 3: system response
belief = {}
for domain in data[file_id]['log'][idx]['metadata'].keys():
for slot in data[file_id]['log'][idx]['metadata'][domain]['semi'].keys():
value = data[file_id]['log'][idx]['metadata'][domain]['semi'][slot].strip()
# value = value_trans.get(value, value)
value = trans_value(value)
if domain not in ontology:
print("domain (%s) is not defined" % domain)
continue
if slot not in ontology[domain]:
print("slot (%s) in domain (%s) is not defined" % (slot, domain)) # bus-arriveBy not defined
continue
if value not in ontology[domain][slot] and value != '未提及':
print("%s: value (%s) in domain (%s) slot (%s) is not defined in ontology" %
(file_id, value, domain, slot))
value = '未提及'
belief[str(domain) + '-' + str(slot)] = value
for slot in data[file_id]['log'][idx]['metadata'][domain]['book'].keys():
if slot == 'booked':
continue
if domain == '公共汽车' and slot == '人数' or domain == '列车' and slot == '票价':
continue # not defined in ontology
value = data[file_id]['log'][idx]['metadata'][domain]['book'][slot].strip()
value = trans_value(value)
if str('预订' + slot) not in ontology[domain]:
print("预订%s is not defined in domain %s" % (slot, domain))
continue
if value not in ontology[domain]['预订' + slot] and value != '未提及':
print("%s: value (%s) in domain (%s) slot (预订%s) is not defined in ontology" %
(file_id, value, domain, slot))
value = '未提及'
belief[str(domain) + '-预订' + str(slot)] = value
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
key = str(domain) + '-' + str(slot)
if key in belief:
split_fp.write('\t' + belief[key])
else:
split_fp.write('\t未提及')
split_fp.write('\n')
split_fp.flush()
system_response = data[file_id]['log'][idx]['text']
turn_idx += 1
fp_train.close()
fp_dev.close()
fp_test.close()
print('data has been processed!')
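# Resulting TSV layout (one row per user/system exchange): dialogue ID, turn index,
# user utterance, system response, then one value column per "domain-slot" key in
# sorted ontology order, with '未提及' written for slots that have not been mentioned yet.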
\ No newline at end of file
import os
import copy
from pprint import pprint
import random
from itertools import chain
import numpy as np
import zipfile
# from tensorboardX.writer import SummaryWriter
from tqdm._tqdm import trange, tqdm
from convlab2.util.file_util import cached_path
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam
from convlab2.dst.dst import DST
from convlab2.dst.sumbt.multiwoz_zh.convert_to_glue_format import convert_to_glue_format, trans_value
from convlab2.util.multiwoz_zh.state import default_state
from convlab2.dst.sumbt.BeliefTrackerSlotQueryMultiSlot import BeliefTracker
from convlab2.dst.sumbt.multiwoz_zh.sumbt_utils import *
from convlab2.dst.sumbt.multiwoz_zh.sumbt_config import *
USE_CUDA = torch.cuda.is_available()
N_GPU = torch.cuda.device_count() if USE_CUDA else 1
DEVICE = "cuda" if USE_CUDA else "cpu"
ROOT_PATH = convlab2.get_root_path()
SUMBT_PATH = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = os.path.join(ROOT_PATH, 'data/multiwoz_zh')
DOWNLOAD_DIRECTORY = os.path.join(SUMBT_PATH, 'pre-trained')
multiwoz_zh_slot_list = ['公共汽车-出发地', '公共汽车-出发时间', '公共汽车-到达时间', '公共汽车-日期', '公共汽车-目的地', '出租车-出发地', '出租车-出发时间', '出租车-到达时间', '出租车-目的地', '列车-出发地', '列车-出发时间', '列车-到达时间', '列车-日期', '列车-目的地', '列车-预订人数', '医院-科室', '旅馆-互联网', '旅馆-价格范围', '旅馆-停车处', '旅馆-区域', '旅馆-名称', '旅馆-星级', '旅馆-类型', '旅馆-预订人数', '旅馆-预订停留天数', '旅馆-预订日期', '景点-区域', '景点-名称', '景点-类型', '餐厅-价格范围', '餐厅-区域', '餐厅-名称', '餐厅-预订人数', '餐厅-预订日期', '餐厅-预订时间', '餐厅-食物']
def get_label_embedding(labels, max_seq_length, tokenizer, device):
features = []
for label in labels:
label_tokens = ["[CLS]"] + tokenizer.tokenize(label) + ["[SEP]"]
label_token_ids = tokenizer.convert_tokens_to_ids(label_tokens)
label_len = len(label_token_ids)
label_padding = [0] * (max_seq_length - len(label_token_ids))
label_token_ids += label_padding
assert len(label_token_ids) == max_seq_length
features.append((label_token_ids, label_len))
all_label_token_ids = torch.tensor([f[0] for f in features], dtype=torch.long).to(device)
all_label_len = torch.tensor([f[1] for f in features], dtype=torch.long).to(device)
return all_label_token_ids, all_label_len
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
class SUMBTTracker(DST):
"""
Transferable multi-domain dialogue state tracker, adopted from https://github.com/SKTBrain/SUMBT
"""
# adapt data provider
# unzip mt.zip, and zip each [train|val|test].json
@staticmethod
def init_data():
if not os.path.exists(os.path.join(DATA_PATH, 'train.json.zip')):
with zipfile.ZipFile(os.path.join(DATA_PATH, 'mt.zip')) as f:
f.extractall(DATA_PATH)
for split in ['train', 'test', 'val']:
with zipfile.ZipFile(os.path.join(DATA_PATH, f'{split}.json.zip'), 'w') as f:
f.write(os.path.join(DATA_PATH, f'{split}.json'), f'{split}.json')
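# After this has run once, DATA_PATH holds train/val/test .json.zip archives, which
# convert_to_glue_format() (called from __init__ below) reads to build the
# train/dev/test TSV files.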
def __init__(self, data_dir=DATA_PATH, eval_slots=multiwoz_zh_slot_list):
DST.__init__(self)
self.init_data()
processor = Processor(args)
self.processor = processor
label_list = processor.get_labels()
num_labels = [len(labels) for labels in label_list] # number of slot-values in each slot-type
# tokenizer
# vocab_dir = os.path.join(data_dir, 'model', '%s-vocab.txt' % args.bert_model)
# if not os.path.exists(vocab_dir):
# raise ValueError("Can't find %s " % vocab_dir)
self.tokenizer = BertTokenizer.from_pretrained(args.bert_model)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
self.device = torch.device("cuda" if USE_CUDA else "cpu")
self.sumbt_model = BeliefTracker(args, num_labels, self.device)
if USE_CUDA and N_GPU > 1:
self.sumbt_model = torch.nn.DataParallel(self.sumbt_model)
if args.fp16:
self.sumbt_model.half()
self.sumbt_model.to(self.device)
## Get slot-value embeddings
self.label_token_ids, self.label_len = [], []
for labels in label_list:
token_ids, lens = get_label_embedding(labels, args.max_label_length, self.tokenizer, self.device)
self.label_token_ids.append(token_ids)
self.label_len.append(lens)
self.label_map = [{label: i for i, label in enumerate(labels)} for labels in label_list]
self.label_map_inv = [{i: label for i, label in enumerate(labels)} for labels in label_list]
self.label_list = label_list
self.target_slot = processor.target_slot
## Get domain-slot-type embeddings
self.slot_token_ids, self.slot_len = \
get_label_embedding(processor.target_slot, args.max_label_length, self.tokenizer, self.device)
self.args = args
self.state = default_state()
self.param_restored = False
if USE_CUDA and N_GPU == 1:
self.sumbt_model.initialize_slot_value_lookup(self.label_token_ids, self.slot_token_ids)
elif USE_CUDA and N_GPU > 1:
self.sumbt_model.module.initialize_slot_value_lookup(self.label_token_ids, self.slot_token_ids)
self.cached_res = {}
convert_to_glue_format(DATA_PATH, SUMBT_PATH)
if not os.path.isdir(os.path.join(SUMBT_PATH, args.output_dir)):
os.makedirs(os.path.join(SUMBT_PATH, args.output_dir))
self.train_examples = processor.get_train_examples(os.path.join(SUMBT_PATH, args.tmp_data_dir), accumulation=False)
self.dev_examples = processor.get_dev_examples(os.path.join(SUMBT_PATH, args.tmp_data_dir), accumulation=False)
self.test_examples = processor.get_test_examples(os.path.join(SUMBT_PATH, args.tmp_data_dir), accumulation=False)
self.eval_slots = eval_slots
def load_weights(self, model_path=None):
if model_path is None:
model_ckpt = os.path.join(SUMBT_PATH, 'pre-trained/pytorch_model.bin')
else:
model_ckpt = model_path
model = self.sumbt_model
# in the case that slot and values are different between the training and evaluation
if not USE_CUDA:
ptr_model = torch.load(model_ckpt, map_location=torch.device('cpu'))
else:
ptr_model = torch.load(model_ckpt)
print('loading pretrained weights')
if not USE_CUDA or N_GPU == 1:
state = model.state_dict()
state.update(ptr_model)
model.load_state_dict(state)
else:
# print("Evaluate using only one device!")
model.module.load_state_dict(ptr_model)
if USE_CUDA:
model.to("cuda")
def train(self, load_model=False, model_path=None):
if load_model:
if model_path is not None:
self.load_weights(model_path)
## Training utterances
all_input_ids, all_input_len, all_label_ids = convert_examples_to_features(
self.train_examples, self.label_list, args.max_seq_length, self.tokenizer, args.max_turn_length)
num_train_batches = all_input_ids.size(0)
num_train_steps = int(
num_train_batches / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
logger.info("***** training *****")
logger.info(" Num examples = %d", len(self.train_examples))
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps)
all_input_ids, all_input_len, all_label_ids = all_input_ids.to(DEVICE), all_input_len.to(
DEVICE), all_label_ids.to(DEVICE)
train_data = TensorDataset(all_input_ids, all_input_len, all_label_ids)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
all_input_ids_dev, all_input_len_dev, all_label_ids_dev = convert_examples_to_features(
self.dev_examples, self.label_list, args.max_seq_length, self.tokenizer, args.max_turn_length)
logger.info("***** validation *****")
logger.info(" Num examples = %d", len(self.dev_examples))
logger.info(" Batch size = %d", args.dev_batch_size)
all_input_ids_dev, all_input_len_dev, all_label_ids_dev = \
all_input_ids_dev.to(DEVICE), all_input_len_dev.to(DEVICE), all_label_ids_dev.to(DEVICE)
dev_data = TensorDataset(all_input_ids_dev, all_input_len_dev, all_label_ids_dev)
dev_sampler = SequentialSampler(dev_data)
dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, batch_size=args.dev_batch_size)
logger.info("Loaded data!")
if args.fp16:
self.sumbt_model.half()
self.sumbt_model.to(DEVICE)
## Get domain-slot-type embeddings
slot_token_ids, slot_len = \
get_label_embedding(self.processor.target_slot, args.max_label_length, self.tokenizer, DEVICE)
# for slot_idx, slot_str in zip(slot_token_ids, self.processor.target_slot):
# self.idx2slot[slot_idx] = slot_str
## Get slot-value embeddings
label_token_ids, label_len = [], []
for slot_idx, labels in zip(slot_token_ids, self.label_list):
# self.idx2value[slot_idx] = {}
token_ids, lens = get_label_embedding(labels, args.max_label_length, self.tokenizer, DEVICE)
label_token_ids.append(token_ids)
label_len.append(lens)
# for label, token_id in zip(labels, token_ids):
# self.idx2value[slot_idx][token_id] = label
logger.info('embeddings prepared')
if USE_CUDA and N_GPU > 1:
self.sumbt_model.module.initialize_slot_value_lookup(label_token_ids, slot_token_ids)
else:
self.sumbt_model.initialize_slot_value_lookup(label_token_ids, slot_token_ids)
def get_optimizer_grouped_parameters(model):
param_optimizer = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01,
'lr': args.learning_rate},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0,
'lr': args.learning_rate},
]
return optimizer_grouped_parameters
if not USE_CUDA or N_GPU == 1:
optimizer_grouped_parameters = get_optimizer_grouped_parameters(self.sumbt_model)
else:
optimizer_grouped_parameters = get_optimizer_grouped_parameters(self.sumbt_model.module)
t_total = num_train_steps
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
optimizer = FusedAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
bias_correction=False,
max_grad_norm=1.0)
if args.fp16_loss_scale == 0:
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.fp16_loss_scale)
else:
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=t_total)
logger.info(optimizer)
# Training code
###############################################################################
logger.info("Training...")
global_step = 0
last_update = None
best_loss = None
model = self.sumbt_model
if not args.do_not_use_tensorboard:
summary_writer = None
else:
summary_writer = SummaryWriter("./tensorboard_summary/logs_1214/")
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
# Train
model.train()
tr_loss = 0
nb_tr_examples = 0
nb_tr_steps = 0
for step, batch in enumerate(tqdm(train_dataloader)):
batch = tuple(t.to(DEVICE) for t in batch)
input_ids, input_len, label_ids = batch
# Forward
if N_GPU == 1:
loss, loss_slot, acc, acc_slot, _ = model(input_ids, input_len, label_ids, N_GPU)
else:
loss, _, acc, acc_slot, _ = model(input_ids, input_len, label_ids, N_GPU)
# average to multi-gpus
loss = loss.mean()
acc = acc.mean()
acc_slot = acc_slot.mean(0)
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
# Backward
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
# tensrboard logging
if summary_writer is not None:
summary_writer.add_scalar("Epoch", epoch, global_step)
summary_writer.add_scalar("Train/Loss", loss, global_step)
summary_writer.add_scalar("Train/JointAcc", acc, global_step)
if N_GPU == 1:
for i, slot in enumerate(self.processor.target_slot):
summary_writer.add_scalar("Train/Loss_%s" % slot.replace(' ', '_'), loss_slot[i],
global_step)
summary_writer.add_scalar("Train/Acc_%s" % slot.replace(' ', '_'), acc_slot[i], global_step)
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
if (step + 1) % args.gradient_accumulation_steps == 0:
# modify learning rate with the special warm-up schedule BERT uses
lr_this_step = args.learning_rate * warmup_linear(global_step / t_total, args.warmup_proportion)
if summary_writer is not None:
summary_writer.add_scalar("Train/LearningRate", lr_this_step, global_step)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step()
optimizer.zero_grad()
global_step += 1
# Perform evaluation on validation dataset
model.eval()
dev_loss = 0
dev_acc = 0
dev_loss_slot, dev_acc_slot = None, None
nb_dev_examples, nb_dev_steps = 0, 0
for step, batch in enumerate(tqdm(dev_dataloader, desc="Validation")):
batch = tuple(t.to(DEVICE) for t in batch)
input_ids, input_len, label_ids = batch
if input_ids.dim() == 2:
input_ids = input_ids.unsqueeze(0)
input_len = input_len.unsqueeze(0)
label_ids = label_ids.unsqueeze(0)
with torch.no_grad():
if N_GPU == 1:
loss, loss_slot, acc, acc_slot, _ = model(input_ids, input_len, label_ids, N_GPU)
else:
loss, _, acc, acc_slot, _ = model(input_ids, input_len, label_ids, N_GPU)
# average to multi-gpus
loss = loss.mean()
acc = acc.mean()
acc_slot = acc_slot.mean(0)
num_valid_turn = torch.sum(label_ids[:, :, 0].view(-1) > -1, 0).item()
dev_loss += loss.item() * num_valid_turn
dev_acc += acc.item() * num_valid_turn
if N_GPU == 1:
if dev_loss_slot is None:
dev_loss_slot = [l * num_valid_turn for l in loss_slot]
dev_acc_slot = acc_slot * num_valid_turn
else:
for i, l in enumerate(loss_slot):
dev_loss_slot[i] = dev_loss_slot[i] + l * num_valid_turn
dev_acc_slot += acc_slot * num_valid_turn
nb_dev_examples += num_valid_turn
dev_loss = dev_loss / nb_dev_examples
dev_acc = dev_acc / nb_dev_examples
if N_GPU == 1:
dev_acc_slot = dev_acc_slot / nb_dev_examples
# tensorboard logging
if summary_writer is not None:
summary_writer.add_scalar("Validate/Loss", dev_loss, global_step)
summary_writer.add_scalar("Validate/Acc", dev_acc, global_step)
if N_GPU == 1:
for i, slot in enumerate(self.processor.target_slot):
summary_writer.add_scalar("Validate/Loss_%s" % slot.replace(' ', '_'),
dev_loss_slot[i] / nb_dev_examples, global_step)
summary_writer.add_scalar("Validate/Acc_%s" % slot.replace(' ', '_'), dev_acc_slot[i],
global_step)
dev_loss = round(dev_loss, 6)
output_model_file = os.path.join(os.path.join(SUMBT_PATH, args.output_dir), "pytorch_model.bin")
if last_update is None or dev_loss < best_loss:
if not USE_CUDA or N_GPU == 1:
torch.save(model.state_dict(), output_model_file)
else:
torch.save(model.module.state_dict(), output_model_file)
last_update = epoch
best_loss = dev_loss
best_acc = dev_acc
logger.info(
"*** Model Updated: Epoch=%d, Validation Loss=%.6f, Validation Acc=%.6f, global_step=%d ***" % (
last_update, best_loss, best_acc, global_step))
else:
logger.info(
"*** Model NOT Updated: Epoch=%d, Validation Loss=%.6f, Validation Acc=%.6f, global_step=%d ***" % (
epoch, dev_loss, dev_acc, global_step))
if last_update + args.patience <= epoch:
break
def test(self, mode='dev', model_path=None):
'''Testing function of SUMBT'''
# Evaluation
self.load_weights(model_path)
if mode == 'test':
eval_examples = self.test_examples
elif mode == 'dev':
eval_examples = self.dev_examples
all_input_ids, all_input_len, all_label_ids = convert_examples_to_features(
eval_examples, self.label_list, args.max_seq_length, self.tokenizer, args.max_turn_length)
all_input_ids, all_input_len, all_label_ids = all_input_ids.to(DEVICE), all_input_len.to(
DEVICE), all_label_ids.to(DEVICE)
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size)
eval_data = TensorDataset(all_input_ids, all_input_len, all_label_ids)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.dev_batch_size)
model = self.sumbt_model
eval_loss, eval_accuracy = 0, 0
eval_loss_slot, eval_acc_slot = None, None
nb_eval_steps, nb_eval_examples = 0, 0
accuracies = {'joint7': 0, 'slot7': 0, 'joint5': 0, 'slot5': 0, 'joint_rest': 0, 'slot_rest': 0,
'num_turn': 0, 'num_slot7': 0, 'num_slot5': 0, 'num_slot_rest': 0}
for input_ids, input_len, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
# if input_ids.dim() == 2:
# input_ids = input_ids.unsqueeze(0)
# input_len = input_len.unsqueeze(0)
# label_ids = label_ids.unsuqeeze(0)
with torch.no_grad():
if not USE_CUDA or N_GPU == 1:
loss, loss_slot, acc, acc_slot, pred_slot = model(input_ids, input_len, label_ids, 1)
else:
loss, _, acc, acc_slot, pred_slot = model(input_ids, input_len, label_ids, N_GPU)
nbatch = label_ids.size(0)
nslot = pred_slot.size(3)
pred_slot = pred_slot.view(nbatch, -1, nslot)
accuracies = eval_all_accs(pred_slot, label_ids, accuracies)
nb_eval_ex = (label_ids[:, :, 0].view(-1) != -1).sum().item()
nb_eval_examples += nb_eval_ex
nb_eval_steps += 1
if not USE_CUDA or N_GPU == 1:
eval_loss += loss.item() * nb_eval_ex
eval_accuracy += acc.item() * nb_eval_ex
if eval_loss_slot is None:
eval_loss_slot = [l * nb_eval_ex for l in loss_slot]
eval_acc_slot = acc_slot * nb_eval_ex
else:
for i, l in enumerate(loss_slot):
eval_loss_slot[i] = eval_loss_slot[i] + l * nb_eval_ex
eval_acc_slot += acc_slot * nb_eval_ex
else:
eval_loss += sum(loss) * nb_eval_ex
eval_accuracy += sum(acc) * nb_eval_ex
eval_loss = eval_loss / nb_eval_examples
eval_accuracy = eval_accuracy / nb_eval_examples
if not USE_CUDA or N_GPU == 1:
eval_acc_slot = eval_acc_slot / nb_eval_examples
loss = None
if not USE_CUDA or N_GPU == 1:
result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy,
'loss': loss,
'eval_loss_slot': '\t'.join([str(val / nb_eval_examples) for val in eval_loss_slot]),
'eval_acc_slot': '\t'.join([str((val).item()) for val in eval_acc_slot])
}
else:
result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy,
'loss': loss
}
out_file_name = 'eval_results'
# if TARGET_SLOT == 'all':
# out_file_name += '_all'
output_eval_file = os.path.join(os.path.join(SUMBT_PATH, args.output_dir), "%s.txt" % out_file_name)
if not USE_CUDA or N_GPU == 1:
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
out_file_name = 'eval_all_accuracies'
with open(os.path.join(os.path.join(SUMBT_PATH, args.output_dir), "%s.txt" % out_file_name), 'w') as f:
s = '{:^22s}:{:^22s}:{:^22s}:{:^22s}:{:^22s}:{:^22s}'.format(
'joint acc (7 domain)',
'slot acc (7 domain)',
'joint acc (5 domain)',
'slot acc (5 domain)',
'joint restaurant',
'slot acc restaurant')
f.write(s + '\n')
print(s)
s = '{:^22.5f}:{:^22.5f}:{:^22.5f}:{:^22.5f}:{:^22.5f}:{:^22.5f}'.format(
(accuracies['joint7'] / accuracies['num_turn']).item(),
(accuracies['slot7'] / accuracies['num_slot7']).item(),
(accuracies['joint5'] / accuracies['num_turn']).item(),
(accuracies['slot5'] / accuracies['num_slot5']).item(),
(accuracies['joint_rest'] / accuracies['num_turn']).item(),
(accuracies['slot_rest'] / accuracies['num_slot_rest']).item()
)
f.write(s + '\n')
print(s)
def init_session(self):
self.state = default_state()
if not self.param_restored:
if os.path.isfile(os.path.join(DOWNLOAD_DIRECTORY, 'pytorch_model.bin')):
print('loading weights from downloaded model')
self.load_weights(model_path=os.path.join(DOWNLOAD_DIRECTORY, 'pytorch_model.bin'))
elif os.path.isfile(os.path.join(SUMBT_PATH, args.output_dir, 'pytorch_model.bin')):
print('loading weights from trained model')
self.load_weights(model_path=os.path.join(SUMBT_PATH, args.output_dir, 'pytorch_model.bin'))
else:
raise ValueError('no available weights found.')
self.param_restored = True
def update(self, user_act=None):
"""Update the dialogue state with the generated tokens from TRADE"""
if not isinstance(user_act, str):
raise Exception(
'Expected user_act is str but found {}'.format(type(user_act))
)
prev_state = self.state
actual_history = copy.deepcopy(prev_state['history'])
query = self.construct_query(actual_history)
pred_states = self.predict(query)
new_belief_state = copy.deepcopy(prev_state['belief_state'])
for state in pred_states:
domain, slot, value = state.split('-', 2)
if slot not in ['name', 'book']:
if domain not in new_belief_state:
if domain == 'bus':
continue
else:
raise Exception(
'Error: domain <{}> not in belief state'.format(domain))
# slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
assert 'semi' in new_belief_state[domain]
assert 'book' in new_belief_state[domain]
if '预订' in slot:
assert slot.startswith('预订')
domain_dic = new_belief_state[domain]
if slot in domain_dic['semi']:
new_belief_state[domain]['semi'][slot] = value
# normalize_value(self.value_dict, domain, slot, value)
elif slot in domain_dic['book']:
new_belief_state[domain]['book'][slot] = value
elif slot.lower() in domain_dic['book']:
new_belief_state[domain]['book'][slot.lower()] = value
else:
with open('trade_tracker_unknown_slot.log', 'a+') as f:
f.write(
'unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format(slot, value, domain, state)
)
# new_request_state = copy.deepcopy(prev_state['request_state'])
# # update request_state
# user_request_slot = self.detect_requestable_slots(user_act)
# for domain in user_request_slot:
# for key in user_request_slot[domain]:
# if domain not in new_request_state:
# new_request_state[domain] = {}
# if key not in new_request_state[domain]:
# new_request_state[domain][key] = user_request_slot[domain][key]
new_state = copy.deepcopy(dict(prev_state))
new_state['belief_state'] = new_belief_state
# new_state['request_state'] = new_request_state
self.state = new_state
# print((pred_states, query))
return self.state
def predict(self, query):
cache_query_key = ''.join(str(list(chain.from_iterable(query[0]))))
if cache_query_key in self.cached_res.keys():
return self.cached_res[cache_query_key]
input_ids, input_len = query
input_ids = torch.tensor(input_ids).to(self.device).unsqueeze(0)
input_len = torch.tensor(input_len).to(self.device).unsqueeze(0)
labels = None
_, pred_slot = self.sumbt_model(input_ids, input_len, labels)
pred_slot_t = pred_slot[0][-1].tolist()
predict_belief = []
for idx, i in enumerate(pred_slot_t):
predict_belief.append('{}-{}'.format(self.target_slot[idx], self.label_map_inv[idx][i]))
self.cached_res[cache_query_key] = predict_belief
return predict_belief
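# Each returned entry is a '<domain>-<slot>-<value>' string, e.g. '餐厅-食物-未提及'
# for a slot that has not been mentioned; update() above splits each entry on the
# first two '-' characters to recover domain, slot and value.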
def construct_query(self, context):
'''Construct query from context'''
ids = []
lens = []
context_len = len(context)
if context[0][0] != 'sys':
context = [['sys', '']] + context
for i in range(0, context_len, 2):
# utt_user = ''
# utt_sys = ''
# for evaluation
utt_sys = context[i][1]
utt_user = context[i + 1][1]
tokens_user = [x if x != '#' else '[SEP]' for x in self.tokenizer.tokenize(utt_user)]
tokens_sys = [x if x != '#' else '[SEP]' for x in self.tokenizer.tokenize(utt_sys)]
_truncate_seq_pair(tokens_user, tokens_sys, self.args.max_seq_length - 3)
tokens = ["[CLS]"] + tokens_user + ["[SEP]"] + tokens_sys + ["[SEP]"]
input_len = [len(tokens_user) + 2, len(tokens_sys) + 1]
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
padding = [0] * (self.args.max_seq_length - len(input_ids))
input_ids += padding
assert len(input_ids) == self.args.max_seq_length
ids.append(input_ids)
lens.append(input_len)
return (ids, lens)
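# Returns (ids, lens): one entry per (system, user) exchange, where each ids row is
# the token ids of "[CLS] user [SEP] system [SEP]" padded to max_seq_length, and the
# matching lens row stores the lengths of the two segments.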
def detect_requestable_slots(self, observation):
result = {}
observation = observation.lower()
_observation = ' {} '.format(observation)
for value in self.det_dic.keys():
_value = ' {} '.format(value.strip())
if _value in _observation:
key, domain = self.det_dic[value].split('-')
if domain not in result:
result[domain] = {}
result[domain][key] = 0
return result
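# Note: this helper depends on self.det_dic, which is never initialised in this
# tracker, and its only call site in update() is commented out above, so it is
# effectively unused here.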
import os
import convlab2
class DotMap():
def __init__(self):
self.max_label_length = 32
self.max_turn_length = 22
self.num_rnn_layers = 1
self.zero_init_rnn = False
self.attn_head = 4
self.do_eval = True
self.do_train = False
self.train_batch_size = 3
self.dev_batch_size = 1
self.eval_batch_size = 1
self.learning_rate = 5e-5
self.num_train_epochs = 3
self.patience = 10
self.warmup_proportion = 0.1
self.local_rank = -1
self.seed = 42
self.gradient_accumulation_steps = 1
self.fp16 = False
self.loss_scale = 0
self.do_not_use_tensorboard = False
self.fix_utterance_encoder = False
self.do_eval = True
self.num_train_epochs = 300
self.bert_model = os.path.join(convlab2.get_root_path(), "pre-trained-models/bert-chinese-wwm-ext")
self.do_lower_case = True
self.task_name = 'bert-gru-sumbt'
self.nbt = 'rnn'
# self.output_dir = os.path.join(path, 'ckpt/')
self.target_slot = 'all'
self.learning_rate = 5e-5
self.distance_metric = 'euclidean'
self.patience = 15
self.hidden_dim = 300
self.max_label_length = 32
self.max_seq_length = 64
self.max_turn_length = 22
self.fp16_loss_scale = 0.0
self.data_dir = 'data/multiwoz_zh/'
self.tf_dir = 'tensorboard'
self.tmp_data_dir = 'processed_data/'
self.output_dir = 'model_output/'
args = DotMap()
\ No newline at end of file
import csv
import os
import json
import collections
import logging
import re
import torch
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r", encoding='utf-8') as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
if len(line) > 0 and line[0][0] == '#': # ignore comments (starting with '#')
continue
lines.append(line)
return lines
class Processor(DataProcessor):
"""Processor for the belief tracking dataset (GLUE version)."""
def __init__(self, config):
super(Processor, self).__init__()
print(config)
# MultiWOZ dataset
with open(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))), config.data_dir, "ontology.json"), "r") as fp_ontology:
ontology = json.load(fp_ontology)
for slot in ontology.keys():
ontology[slot].append("未提及")
if config.target_slot != 'all':
raise Exception('unsupported')
# sorting the ontology according to the alphabetic order of the slots
ontology = collections.OrderedDict(sorted(ontology.items()))
# select slots to train
nslots = len(ontology.keys())
target_slot = list(ontology.keys())
self.target_slot_idx = [*range(0, nslots)]
for idx in range(0, nslots):
if not idx in self.target_slot_idx:
del ontology[target_slot[idx]]
self.ontology = ontology
self.target_slot = list(self.ontology.keys())
for i, slot in enumerate(self.target_slot):
if slot == "pricerange":
self.target_slot[i] = "price range"
logger.info('Processor: target_slot')
logger.info(self.target_slot)
def get_train_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train", accumulation)
def get_dev_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev", accumulation)
def get_test_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test", accumulation)
def get_labels(self):
"""See base class."""
return [self.ontology[slot] for slot in self.target_slot]
def _create_examples(self, lines, set_type, accumulation=False):
"""Creates examples for the training and dev sets."""
prev_dialogue_index = None
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s-%s" % (set_type, line[0], line[1]) # line[0]: dialogue index, line[1]: turn index
if accumulation:
if prev_dialogue_index is None or prev_dialogue_index != line[0]:
text_a = line[2]
text_b = line[3]
prev_dialogue_index = line[0]
else:
# The symbol '#' will be replaced with '[SEP]' after tokenization.
text_a = line[2] + " # " + text_a
text_b = line[3] + " # " + text_b
else:
text_a = line[2] # line[2]: user utterance
text_b = line[3] # line[3]: system response
label = [line[4 + idx] for idx in self.target_slot_idx]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def normalize_text(text):
global replacements
# lower case every word
text = text.lower()
# replace white spaces in front and end
text = re.sub(r'^\s*|\s*$', '', text)
# hotel domain pfb30
text = re.sub(r"b&b", "bed and breakfast", text)
text = re.sub(r"b and b", "bed and breakfast", text)
# replace st.
text = text.replace(';', ',')
text = re.sub('$\/', '', text)
text = text.replace('/', ' and ')
# replace other special characters
text = text.replace('-', ' ')
text = re.sub('[\"\<>@\(\)]', '', text) # remove
# insert white space before and after tokens:
for token in ['?', '.', ',', '!']:
text = insertSpace(token, text)
# insert white space for 's
text = insertSpace('\'s', text)
# replace it's, doesn't, you'd ... etc
text = re.sub('^\'', '', text)
text = re.sub('\'$', '', text)
text = re.sub('\'\s', ' ', text)
text = re.sub('\s\'', ' ', text)
for fromx, tox in replacements:
text = ' ' + text + ' '
text = text.replace(fromx, tox)[1:-1]
# remove multiple spaces
text = re.sub(' +', ' ', text)
# concatenate numbers
tmp = text
tokens = text.split()
i = 1
while i < len(tokens):
if re.match(u'^\d+$', tokens[i]) and \
re.match(u'\d+$', tokens[i - 1]):
tokens[i - 1] += tokens[i]
del tokens[i]
else:
i += 1
text = ' '.join(tokens)
return text
def insertSpace(token, text):
sidx = 0
while True:
sidx = text.find(token, sidx)
if sidx == -1:
break
if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
re.match('[0-9]', text[sidx + 1]):
sidx += 1
continue
if text[sidx - 1] != ' ':
text = text[:sidx] + ' ' + text[sidx:]
sidx += 1
if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
text = text[:sidx + 1] + ' ' + text[sidx + 1:]
sidx += 1
return text
def get_label_embedding(labels, max_seq_length, tokenizer, device):
features = []
for label in labels:
label_tokens = ["[CLS]"] + tokenizer.tokenize(label) + ["[SEP]"]
label_token_ids = tokenizer.convert_tokens_to_ids(label_tokens)
label_len = len(label_token_ids)
label_padding = [0] * (max_seq_length - len(label_token_ids))
label_token_ids += label_padding
assert len(label_token_ids) == max_seq_length
features.append((label_token_ids, label_len))
all_label_token_ids = torch.tensor([f[0] for f in features], dtype=torch.long).to(device)
all_label_len = torch.tensor([f[1] for f in features], dtype=torch.long).to(device)
return all_label_token_ids, all_label_len
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x / warmup
return 1.0 - x
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, input_len, label_id):
self.input_ids = input_ids
self.input_len = input_len
self.label_id = label_id
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, max_turn_length):
"""Loads a data file into a list of `InputBatch`s."""
label_map = [{label: i for i, label in enumerate(labels)} for labels in label_list]
slot_dim = len(label_list)
features = []
prev_dialogue_idx = None
all_padding = [0] * max_seq_length
all_padding_len = [0, 0]
max_turn = 0
for (ex_index, example) in enumerate(examples):
if max_turn < int(example.guid.split('-')[2]):
max_turn = int(example.guid.split('-')[2])
max_turn_length = min(max_turn + 1, max_turn_length)
logger.info("max_turn_length = %d" % max_turn)
for (ex_index, example) in enumerate(examples):
tokens_a = [x if x != '#' else '[SEP]' for x in tokenizer.tokenize(example.text_a)]
tokens_b = None
if example.text_b:
tokens_b = [x if x != '#' else '[SEP]' for x in tokenizer.tokenize(example.text_b)]
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[:(max_seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
input_len = [len(tokens), 0]
if tokens_b:
tokens += tokens_b + ["[SEP]"]
input_len[1] = len(tokens_b) + 1
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# Zero-pad up to the sequence length.
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
assert len(input_ids) == max_seq_length
FLAG_TEST = False
if example.label is not None:
label_id = []
label_info = 'label: '
for i, label in enumerate(example.label):
if label == 'dontcare':
label = 'do not care'
label_id.append(label_map[i][label])
label_info += '%s (id = %d) ' % (label, label_map[i][label])
if ex_index < 5:
logger.info("*** Example ***")
logger.info("guid: %s" % example.guid)
logger.info("tokens: %s" % " ".join(
[str(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_len: %s" % " ".join([str(x) for x in input_len]))
logger.info("label: " + label_info)
else:
FLAG_TEST = True
label_id = None
curr_dialogue_idx = example.guid.split('-')[1]
curr_turn_idx = int(example.guid.split('-')[2])
if prev_dialogue_idx is not None and prev_dialogue_idx != curr_dialogue_idx:
if prev_turn_idx < max_turn_length:
features += [InputFeatures(input_ids=all_padding,
input_len=all_padding_len,
label_id=[-1] * slot_dim)] \
* (max_turn_length - prev_turn_idx - 1)
assert len(features) % max_turn_length == 0
if prev_dialogue_idx is None or prev_turn_idx < max_turn_length:
features.append(
InputFeatures(input_ids=input_ids,
input_len=input_len,
label_id=label_id))
prev_dialogue_idx = curr_dialogue_idx
prev_turn_idx = curr_turn_idx
if prev_turn_idx < max_turn_length:
features += [InputFeatures(input_ids=all_padding,
input_len=all_padding_len,
label_id=[-1] * slot_dim)] \
* (max_turn_length - prev_turn_idx - 1)
assert len(features) % max_turn_length == 0
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_len = torch.tensor([f.input_len for f in features], dtype=torch.long)
if not FLAG_TEST:
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
# reshape tensors to [#batch, #max_turn_length, #max_seq_length]
all_input_ids = all_input_ids.view(-1, max_turn_length, max_seq_length)
all_input_len = all_input_len.view(-1, max_turn_length, 2)
if not FLAG_TEST:
all_label_ids = all_label_ids.view(-1, max_turn_length, slot_dim)
else:
all_label_ids = None
return all_input_ids, all_input_len, all_label_ids
def eval_all_accs(pred_slot, labels, accuracies):
def _eval_acc(_pred_slot, _labels):
slot_dim = _labels.size(-1)
accuracy = (_pred_slot == _labels).view(-1, slot_dim)
num_turn = torch.sum(_labels[:, :, 0].view(-1) > -1, 0).float()
num_data = torch.sum(_labels > -1).float()
# joint accuracy
# joint_acc = sum(torch.sum(accuracy, 1) / slot_dim).float()
num_slots = accuracy.shape[1]
joint_acc = sum(torch.sum(accuracy, 1) == num_slots)
# slot accuracy
slot_acc = torch.sum(accuracy).float()
return joint_acc, slot_acc, num_turn, num_data
# 7 domains
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot, labels)
accuracies['joint7'] += joint_acc
accuracies['slot7'] += slot_acc
accuracies['num_turn'] += num_turn
accuracies['num_slot7'] += num_data
# restaurant domain
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot[:,:,18:25], labels[:,:,18:25])
accuracies['joint_rest'] += joint_acc
accuracies['slot_rest'] += slot_acc
accuracies['num_slot_rest'] += num_data
pred_slot5 = torch.cat((pred_slot[:,:,0:3], pred_slot[:,:,8:]), 2)
label_slot5 = torch.cat((labels[:,:,0:3], labels[:,:,8:]), 2)
# 5 domains (excluding bus and hotel domain)
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot5, label_slot5)
accuracies['joint5'] += joint_acc
accuracies['slot5'] += slot_acc
accuracies['num_slot5'] += num_data
return accuracies
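# Unlike the commented-out computation inside _eval_acc above, joint accuracy here
# counts a turn as correct only when every slot in that turn is predicted correctly.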