Commit 9d0061d3 authored by truthless11

Merge branch 'master' into dev

parents 8b5143ce acfef62a
Showing with 2380 additions and 92 deletions
# extras
Pipfile*
results*
*.pyc
__pycache__
......@@ -13,6 +17,7 @@ __pycache__
data/**/train.json
data/**/val.json
data/**/test.json
data/**/human_val.json
data/camrest/CamRest676_v2.json
data/multiwoz/annotated_user_da_with_span_full.json
data/schema/dstc8-schema-guided-dialogue-master
......@@ -32,6 +37,9 @@ convlab2/nlu/jointBERT/**/output/
convlab2/dst/sumbt/multiwoz/output/
convlab2/nlg/sclstm/**/generated_sens_sys.json
convlab2/nlg/template/**/generated_sens_sys.json
convlab2/nlu/jointBERT/crosswoz/**/data
convlab2/nlu/jointBERT/multiwoz/**/data
# test script
*_test.py
......@@ -65,3 +73,6 @@ convlab2/dst/trade/multiwoz_config/
deploy/bert_multiwoz_all.zip
deploy/templates/dialog_eg.html
test.py
*.egg-info
pre-trained-models/
\ No newline at end of file
......@@ -9,31 +9,27 @@ install:
- pip install --upgrade pip
- pip install --progress-bar off -e .[develop]
- pip install sphinx
- pip install sphinx_rtd_theme
script:
# - python setup.py test
# - cd docs && rm source/convlab2.*.rst
# - sphinx-apidoc -o ./source ../convlab2/
# - cd source && python gen_rst.py --project convlab2 && cd ..
# - make html
# - cd source
# - python modify_py_modindex.py -d ../build/html/
# - cd ..
# - mv ./build/html ./build/docs && rm -r ./build/doctrees && mv LICENSE.txt ./build && mv README.md ./build && cd ..
- cd docs
- make html && mv ./build/html ./build/docs && rm -r ./build/doctrees
- cd ..
- gem install travis --no-document
deploy:
- provider: pages
skip-cleanup: true
github-token: $DEPLOY_KEY
keep-history: true
repo: thu-coai/convlab2_docs
repo: thu-coai/ConvLab-2_docs
target-branch: master
local-dir: ./docs/build
on:
branch: master
- provider: script
skip-cleanup: true
script: coveralls
on:
all_branches: true
condition: true
# - provider: script
# skip-cleanup: true
# script: coveralls
# on:
# all_branches: true
# condition: true
# ConvLab-2
[![Build Status](https://travis-ci.com/thu-coai/ConvLab-2.svg?branch=master)](https://travis-ci.com/thu-coai/ConvLab-2)
**ConvLab-2** is an open-source toolkit that enables researchers to build task-oriented dialogue systems with state-of-the-art models, perform an end-to-end evaluation, and diagnose the weaknesses of systems. As the successor of [ConvLab](https://github.com/ConvLab/ConvLab), ConvLab-2 inherits ConvLab's framework but integrates more powerful dialogue models and supports more datasets. In addition, we have developed an analysis tool and an interactive tool to assist researchers in diagnosing dialogue systems. [[paper]](https://arxiv.org/abs/2002.04793)
- [Installation](#installation)
- [Tutorials](#tutorials)
- [Documents](#documents)
- [Models](#models)
- [Supported Dataset](#Supported-Dataset)
- [Supported Datasets](#Supported-Datasets)
- [End-to-end Performance on MultiWOZ](#End-to-end-Performance-on-MultiWOZ)
- [Module Performance on MultiWOZ](#Module-Performance-on-MultiWOZ)
- [Issues](#issues)
......@@ -33,7 +36,10 @@ pip install -e .
- [Getting Started](https://github.com/thu-coai/ConvLab-2/blob/master/tutorials/Getting_Started.ipynb) (Have a try on [Colab](https://colab.research.google.com/github/thu-coai/ConvLab-2/blob/master/tutorials/Getting_Started.ipynb)!)
- [Add New Model](https://github.com/thu-coai/ConvLab-2/blob/master/tutorials/Add_New_Model.md)
- [Train RL Policies](https://github.com/thu-coai/ConvLab-2/blob/master/tutorials/Train_RL_Policies)
- [Interactive Tool](https://github.com/thu-coai/ConvLab-2/blob/master/deploy) [[demo video]](https://drive.google.com/file/d/1HR3mjhgLL0g9IbqU443NsH2G0-PpAsog/view?usp=sharing)
- [Interactive Tool](https://github.com/thu-coai/ConvLab-2/blob/master/deploy) [[demo video]](https://youtu.be/00VWzbcx26E)
## Documents
Our documentation is available at https://thu-coai.github.io/ConvLab-2_docs/convlab2.html.
## Models
......@@ -69,6 +75,8 @@ For more details about these models, You can refer to `README.md` under `convla
## End-to-end Performance on MultiWOZ
*Notice*: The results are for commits before [`bdc9dba`](https://github.com/thu-coai/ConvLab-2/commit/bdc9dba72c957d97788e533f9458ed03a4b0137b) (inclusive). We will update the results after improving user policy.
We perform end-to-end evaluation (1000 dialogues) on MultiWOZ using the user simulator below (a full example is in `tests/test_end2end.py`):
```python
......@@ -141,6 +149,8 @@ By running `convlab2/dst/evaluate.py MultiWOZ $model`:
### Policy
*Notice*: The results are for commits before [`bdc9dba`](https://github.com/thu-coai/ConvLab-2/commit/bdc9dba72c957d97788e533f9458ed03a4b0137b) (inclusive). We will update the results after improving user policy.
By running `convlab2/policy/evalutate.py --model_name $model`
| | Task Success Rate |
......@@ -159,6 +169,68 @@ By running `convlab2/nlg/evaluate.py MultiWOZ $model sys`
| Template | 0.3309 |
| SCLSTM | 0.4884 |
## Translation-train SUMBT for cross-lingual DST
### Train
With ConvLab-2, you can train SUMBT on a machine-translated dataset as follows:
```python
# train.py
import os
from sys import argv
if __name__ == "__main__":
    if len(argv) != 2:
        print('usage: python3 train.py [dataset]')
        exit(1)
    assert argv[1] in ['multiwoz', 'crosswoz']
    from convlab2.dst.sumbt.multiwoz_zh.sumbt import SUMBT_PATH
    if argv[1] == 'multiwoz':
        from convlab2.dst.sumbt.multiwoz_zh.sumbt import SUMBTTracker as SUMBT
    elif argv[1] == 'crosswoz':
        from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker as SUMBT
    sumbt = SUMBT()
    sumbt.train(True)
```
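For example, saving the script above as `train.py` and running `python3 train.py crosswoz` trains the cross-lingual tracker on the English translation of CrossWOZ, while `python3 train.py multiwoz` trains on the Chinese translation of MultiWOZ.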
### Evaluate
Execute `evaluate.py` (under `convlab2/dst/`) with the following command:
```bash
python3 evaluate.py [CrossWOZ-en|MultiWOZ-zh] [val|test|human_val]
```
The evaluation results (joint accuracy) of our pre-trained models are:
| type | CrossWOZ-en | MultiWOZ-zh |
| ----- | ----------- | ----------- |
| val | 12.2% | 44.8% |
| test | 12.4% | 42.3% |
| human_val | 10.9% | 48.2% |
The `human_val` option evaluates the model on the validation set translated by humans.
Note: You may want to download the pre-trained BERT models and the translation-train SUMBT models provided by us.
Without modifying any code, you can:
- Download the pre-trained BERT models and extract them to `./pre-trained-models`:
  - [bert-base-uncased](https://huggingface.co/bert-base-uncased) for CrossWOZ-en
  - [chinese-bert-wwm-ext](https://huggingface.co/hfl/chinese-bert-wwm-ext) for MultiWOZ-zh
- Download the translation-train SUMBT models:
  - [trained on CrossWOZ-en](https://convlab.blob.core.windows.net/convlab-2/crosswoz_en-pytorch_model.bin.zip)
  - [trained on MultiWOZ-zh](https://convlab.blob.core.windows.net/convlab-2/multiwoz_zh-pytorch_model.bin.zip)
- For example, if the dataset is CrossWOZ (English), extract the downloaded archive, save the model under `./convlab2/dst/sumbt/crosswoz_en/pre-trained`, and name it `pytorch_model.bin` (a scripted version of these steps is sketched below).
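A minimal sketch of the download-and-place steps for the CrossWOZ-en checkpoint, using only the Python standard library; it assumes the archive contains a single `pytorch_model.bin` at its root (the script name `fetch_sumbt_crosswoz_en.py` is illustrative only):

```python
# fetch_sumbt_crosswoz_en.py - minimal sketch; assumes the zip holds pytorch_model.bin at its root
import os
import urllib.request
import zipfile

MODEL_URL = "https://convlab.blob.core.windows.net/convlab-2/crosswoz_en-pytorch_model.bin.zip"
TARGET_DIR = "./convlab2/dst/sumbt/crosswoz_en/pre-trained"

os.makedirs(TARGET_DIR, exist_ok=True)
archive_path = os.path.join(TARGET_DIR, "model.zip")

# download the translation-train SUMBT checkpoint
urllib.request.urlretrieve(MODEL_URL, archive_path)

# unpack it next to the tracker code and remove the archive
with zipfile.ZipFile(archive_path) as zf:
    zf.extractall(TARGET_DIR)
os.remove(archive_path)

# the tracker loads its weights from pre-trained/pytorch_model.bin
print(os.path.exists(os.path.join(TARGET_DIR, "pytorch_model.bin")))
```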
## Issues
You are welcome to create an issue if you want to request a feature, report a bug or ask a general question.
......@@ -177,7 +249,7 @@ We welcome contributions from community.
We would like to thank:
Yan Fang, Zhuoer Feng, Jianfeng Gao, Qihan Guo, Kaili Huang, Minlie Huang, Sungjin Lee, Bing Li, Jinchao Li, Xiang Li, Xiujun Li, Wenchang Ma, Baolin Peng, Runze Liang, Ryuichi Takanobu, Jiaxin Wen, Yaoqin Zhang, Zheng Zhang, Qi Zhu, Xiaoyan Zhu.
Yan Fang, Zhuoer Feng, Jianfeng Gao, Qihan Guo, Kaili Huang, Minlie Huang, Sungjin Lee, Bing Li, Jinchao Li, Xiang Li, Xiujun Li, Lingxiao Luo, Wenchang Ma, Mehrad Moradshahi, Baolin Peng, Runze Liang, Ryuichi Takanobu, Hongru Wang, Jiaxin Wen, Yaoqin Zhang, Zheng Zhang, Qi Zhu, Xiaoyan Zhu.
## Citing
......
......@@ -10,3 +10,7 @@ from os.path import abspath, dirname
def get_root_path():
return dirname(dirname(abspath(__file__)))
import os
DATA_ROOT = os.path.join(get_root_path(), 'data')
\ No newline at end of file
......@@ -28,6 +28,7 @@ class Environment():
self.evaluator.add_sys_da(self.usr.get_in_da())
self.evaluator.add_usr_da(self.usr.get_out_da())
dialog_act = self.sys_nlu.predict(observation) if self.sys_nlu else observation
self.sys_dst.state['user_action'] = dialog_act
state = self.sys_dst.update(dialog_act)
if self.evaluator:
......
# -*- coding: gbk -*-
# -*- coding: utf-8 -*-
"""
Evaluate NLU models on specified dataset
Usage: python evaluate.py [MultiWOZ|CrossWOZ] [TRADE|mdbt|sumbt|rule]
Evaluate DST models on specified dataset
Usage: python evaluate.py [MultiWOZ|CrossWOZ|MultiWOZ-zh|CrossWOZ-en] [TRADE|mdbt|sumbt] [val|test|human_val]
"""
import random
import numpy
......@@ -12,8 +12,9 @@ import copy
import jieba
multiwoz_slot_list = ['attraction-area', 'attraction-name', 'attraction-type', 'hotel-day', 'hotel-people', 'hotel-stay', 'hotel-area', 'hotel-internet', 'hotel-name', 'hotel-parking', 'hotel-pricerange', 'hotel-stars', 'hotel-type', 'restaurant-day', 'restaurant-people', 'restaurant-time', 'restaurant-area', 'restaurant-food', 'restaurant-name', 'restaurant-pricerange', 'taxi-arriveby', 'taxi-departure', 'taxi-destination', 'taxi-leaveat', 'train-people', 'train-arriveby', 'train-day', 'train-departure', 'train-destination', 'train-leaveat']
crosswoz_slot_list = ["쒼듐-쳔튿", "쒼듐-팀롸", "꽜반-츰냔", "아듦-송목", "아듦-팀롸", "쒼듐-츰냔", "쒼듐-뒈囹", "쒼듐-踏鯤珂쇌", "꽜반-檀撚珂쇌", "꽜반-팀롸", "아듦-츰냔", "아듦-鷺긋쒼듐", "아듦-아듦嘉-싻今륩蛟", "아듦-아듦잚謹", "꽜반-훙엇句롤", "꽜반-股수꽉", "아듦-아듦嘉", "아듦-든뺐", "쒼듐-든뺐", "꽜반-鷺긋꽜반", "꽜반-든뺐", "꽜반-none", "꽜반-뒈囹", "아듦-아듦嘉-轟緊렛", "아듦-뒈囹", "쒼듐-鷺긋쒼듐", "쒼듐-鷺긋아듦", "놔理-놔랙뒈", "놔理-커돨뒈", "뒈屆-놔랙뒈", "뒈屆-커돨뒈", "쒼듐-鷺긋꽜반", "아듦-鷺긋꽜반", "놔理-났謹", "꽜반-鷺긋쒼듐", "꽜반-鷺긋아듦", "뒈屆-놔랙뒈맒쐤뒈屆籃", "뒈屆-커돨뒈맒쐤뒈屆籃", "쒼듐-none", "아듦-아듦嘉-蛟櫓懃", "꽜반-都쥴堵", "아듦-아듦嘉-櫓駕꽜戒", "아듦-아듦嘉-쌈籃륩蛟", "아듦-아듦嘉-벌셥낀槁든뺐", "아듦-아듦嘉-뉘루샙", "아듦-아듦嘉-삔累杆", "아듦-都쥴堵", "아듦-none", "아듦-아듦嘉-욱던貢", "아듦-아듦嘉-였빱鬼벚륩蛟", "아듦-아듦嘉-아듦몹뇹瓊묩wifi", "아듦-아듦嘉-킁폭", "아듦-아듦嘉-spa", "놔理-났탬", "쒼듐-都쥴堵", "아듦-아듦嘉-契쟀셍닸", "아듦-아듦嘉-鮫駕꽜戒", "아듦-아듦嘉-아걸", "아듦-아듦嘉-豆꽜륩蛟", "아듦-아듦嘉-숯렛", "아듦-아듦嘉-꽥섣훙嘉", "아듦-아듦嘉-출롤懇코든뺐", "아듦-아듦嘉-쌈덤棍깟", "아듦-아듦嘉-꼬롸렛쇌瓊묩wifi", "아듦-아듦嘉-求擄륩蛟", "아듦-아듦嘉-理났", "아듦-아듦嘉-무묾혐堵뵨꼬롸렛쇌瓊묩wifi", "아듦-아듦嘉-24鬼珂훑彊", "아듦-아듦嘉-侊홋", "아듦-아듦嘉-컬", "아듦-아듦嘉-澗롤界났貫", "아듦-鷺긋아듦", "아듦-아듦嘉-쌈샙륩蛟", "아듦-아듦嘉-杰唐렛쇌瓊묩wifi", "아듦-아듦嘉-펙탬杆", "아듦-아듦嘉-출롤벌코낀槁든뺐", "아듦-아듦嘉-杆코踏曇넥", "아듦-아듦嘉-豆꽜륩蛟출롤", "아듦-아듦嘉-무묾혐堵瓊묩wifi", "아듦-아듦嘉-杆棍踏曇넥"]
crosswoz_slot_list = ["景点-门票", "景点-评分", "餐馆-名称", "酒店-价格", "酒店-评分", "景点-名称", "景点-地址", "景点-游玩时间", "餐馆-营业时间", "餐馆-评分", "酒店-名称", "酒店-周边景点", "酒店-酒店设施-叫醒服务", "酒店-酒店类型", "餐馆-人均消费", "餐馆-推荐菜", "酒店-酒店设施", "酒店-电话", "景点-电话", "餐馆-周边餐馆", "餐馆-电话", "餐馆-none", "餐馆-地址", "酒店-酒店设施-无烟房", "酒店-地址", "景点-周边景点", "景点-周边酒店", "出租-出发地", "出租-目的地", "地铁-出发地", "地铁-目的地", "景点-周边餐馆", "酒店-周边餐馆", "出租-车型", "餐馆-周边景点", "餐馆-周边酒店", "地铁-出发地附近地铁站", "地铁-目的地附近地铁站", "景点-none", "酒店-酒店设施-商务中心", "餐馆-源领域", "酒店-酒店设施-中式餐厅", "酒店-酒店设施-接站服务", "酒店-酒店设施-国际长途电话", "酒店-酒店设施-吹风机", "酒店-酒店设施-会议室", "酒店-源领域", "酒店-none", "酒店-酒店设施-宽带上网", "酒店-酒店设施-看护小孩服务", "酒店-酒店设施-酒店各处提供wifi", "酒店-酒店设施-暖气", "酒店-酒店设施-spa", "出租-车牌", "景点-源领域", "酒店-酒店设施-行李寄存", "酒店-酒店设施-西式餐厅", "酒店-酒店设施-酒吧", "酒店-酒店设施-早餐服务", "酒店-酒店设施-健身房", "酒店-酒店设施-残疾人设施", "酒店-酒店设施-免费市内电话", "酒店-酒店设施-接待外宾", "酒店-酒店设施-部分房间提供wifi", "酒店-酒店设施-洗衣服务", "酒店-酒店设施-租车", "酒店-酒店设施-公共区域和部分房间提供wifi", "酒店-酒店设施-24小时热水", "酒店-酒店设施-温泉", "酒店-酒店设施-桑拿", "酒店-酒店设施-收费停车位", "酒店-周边酒店", "酒店-酒店设施-接机服务", "酒店-酒店设施-所有房间提供wifi", "酒店-酒店设施-棋牌室", "酒店-酒店设施-免费国内长途电话", "酒店-酒店设施-室内游泳池", "酒店-酒店设施-早餐服务免费", "酒店-酒店设施-公共区域提供wifi", "酒店-酒店设施-室外游泳池"]
from convlab2.dst.sumbt.multiwoz_zh.sumbt import multiwoz_zh_slot_list
from convlab2.dst.sumbt.crosswoz_en.sumbt import crosswoz_en_slot_list
def format_history(context):
history = []
......@@ -37,7 +38,7 @@ def reformat_state(state):
domain_data = domain_data['semi']
for slot in domain_data.keys():
val = domain_data[slot]
if val is not None and val != '' and val != 'not mentioned':
if val is not None and val not in ['', 'not mentioned', '未提及', '未提到', '没有提到']:
new_state.append(domain + '-' + slot + '-' + val)
# lower
new_state = [item.lower() for item in new_state]
......@@ -47,19 +48,27 @@ def reformat_state_crosswoz(state):
if 'belief_state' in state:
state = state['belief_state']
new_state = []
# print(state)
for domain in state.keys():
domain_data = state[domain]
for slot in domain_data.keys():
if slot == 'selectedResults': continue
val = domain_data[slot]
if val is not None and val != '':
if slot == 'Hotel Facilities' and val not in ['', 'none']:
for facility in val.split(','):
new_state.append(domain + '-' + f'Hotel Facilities - {facility}' + 'yes')
else:
if val is not None and val not in ['', 'none']:
# print(domain, slot, val)
new_state.append(domain + '-' + slot + '-' + val)
return new_state
def compute_acc(gold, pred, slot_temp):
# TODO: not mentioned in gold
miss_gold = 0
miss_slot = []
# print(gold, pred)
for g in gold:
if g not in pred:
miss_gold += 1
......@@ -124,34 +133,44 @@ if __name__ == '__main__':
numpy.random.seed(seed)
torch.manual_seed(seed)
if len(sys.argv) != 3:
if len(sys.argv) != 4:
print("usage:")
print("\t python evaluate.py dataset model")
print("\t dataset=MultiWOZ, CrossWOZ")
print("\t dataset=MultiWOZ, MultiWOZ-zh, CrossWOZ, CrossWOZ-en")
print("\t model=TRADE, mdbt, sumbt")
print("\t val=[val|test|human_val]")
sys.exit()
## init phase
dataset_name = sys.argv[1]
model_name = sys.argv[2]
if dataset_name == 'MultiWOZ':
if model_name == 'TRADE':
data_key = sys.argv[3]
if dataset_name.startswith('MultiWOZ'):
if dataset_name.endswith('zh'):
if model_name == 'sumbt':
from convlab2.dst.sumbt.multiwoz_zh.sumbt import SUMBTTracker
model = SUMBTTracker()
else:
raise Exception("Available models: sumbt")
else:
if model_name == 'sumbt':
from convlab2.dst.sumbt.multiwoz.sumbt import SUMBTTracker
model = SUMBTTracker()
elif model_name == 'TRADE':
from convlab2.dst.trade.multiwoz.trade import MultiWOZTRADE
model = MultiWOZTRADE()
elif model_name == 'mdbt':
from convlab2.dst.mdbt.multiwoz.dst import MultiWozMDBT
model = MultiWozMDBT()
elif model_name == 'sumbt':
from convlab2.dst.sumbt.multiwoz.sumbt import SUMBTTracker
model = SUMBTTracker()
else:
raise Exception("Available models: TRADE/mdbt/sumbt")
## load data
from convlab2.util.dataloader.module_dataloader import AgentDSTDataloader
from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
dataloader = AgentDSTDataloader(dataset_dataloader=MultiWOZDataloader())
data = dataloader.load_data(data_key='test')['test']
dataloader = AgentDSTDataloader(dataset_dataloader=MultiWOZDataloader(dataset_name.endswith('zh')))
data = dataloader.load_data(data_key=data_key)[data_key]
context, golden_truth = data['context'], data['belief_state']
all_predictions = {}
test_set = []
......@@ -160,7 +179,6 @@ if __name__ == '__main__':
turn_count = 0
is_start = True
for i in tqdm(range(len(context))):
# for i in tqdm(range(200)): # for test
if len(context[i]) == 0:
turn_count = 0
if is_start:
......@@ -181,19 +199,23 @@ if __name__ == '__main__':
'turn_belief': reformat_state(y),
'pred_bs_ptr': reformat_state(pred)
}
# print('golden: ', reformat_state(y))
# print('pred :', reformat_state(pred))
turn_count += 1
# add last session
if len(curr_sess) > 0:
all_predictions[session_count] = copy.deepcopy(curr_sess)
joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = evaluate_metrics(all_predictions, "pred_bs_ptr", multiwoz_slot_list)
slot_list = multiwoz_zh_slot_list if dataset_name.endswith('zh') else multiwoz_slot_list
joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = evaluate_metrics(all_predictions, "pred_bs_ptr", slot_list)
evaluation_metrics = {"Joint Acc": joint_acc_score_ptr, "Turn Acc": turn_acc_score_ptr,
"Joint F1": F1_score_ptr}
print(evaluation_metrics)
elif dataset_name == 'CrossWOZ':
elif dataset_name.startswith('CrossWOZ'):
en = dataset_name.endswith('en')
if en:
if model_name == 'sumbt':
from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker
model = SUMBTTracker()
else:
if model_name == 'TRADE':
from convlab2.dst.trade.crosswoz.trade import CrossWOZTRADE
model = CrossWOZTRADE()
......@@ -210,8 +232,8 @@ if __name__ == '__main__':
from convlab2.util.dataloader.module_dataloader import CrossWOZAgentDSTDataloader
from convlab2.util.dataloader.dataset_dataloader import CrossWOZDataloader
dataloader = CrossWOZAgentDSTDataloader(dataset_dataloader=CrossWOZDataloader())
data = dataloader.load_data(data_key='test')['test']
dataloader = CrossWOZAgentDSTDataloader(dataset_dataloader=CrossWOZDataloader(en))
data = dataloader.load_data(data_key=data_key)[data_key]
context, golden_truth = data['context'], data['sys_state_init']
all_predictions = {}
test_set = []
......@@ -220,7 +242,6 @@ if __name__ == '__main__':
turn_count = 0
is_start = True
for i in tqdm(range(len(context))):
# for i in tqdm(range(10)): # for test
if len(context[i]) == 0:
turn_count = 0
if is_start:
......@@ -229,12 +250,17 @@ if __name__ == '__main__':
all_predictions[session_count] = copy.deepcopy(curr_sess)
session_count += 1
curr_sess = {}
# skip usr turn
if len(context[i]) % 2 == 0:
continue
# add turn
x = context[i]
y = golden_truth[i]
# process y
if not en:
for domain in y.keys():
domain_data = y[domain]
for slot in domain_data.keys():
......@@ -242,9 +268,9 @@ if __name__ == '__main__':
val = domain_data[slot]
if val is not None and val != '':
val = sentseg(val)
y[domain][slot] = val
domain_data[slot] = val
model.init_session()
model.state['history'] = format_history([sentseg(item) for item in context[i]])
model.state['history'] = format_history([item if en else sentseg(item) for item in context[i]])
pred = model.update(x[-1] if len(x) > 0 else '')
curr_sess[turn_count] = {
'turn_belief': reformat_state_crosswoz(y),
......@@ -255,8 +281,9 @@ if __name__ == '__main__':
if len(curr_sess) > 0:
all_predictions[session_count] = copy.deepcopy(curr_sess)
slot_list = crosswoz_en_slot_list if en else crosswoz_slot_list
joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = evaluate_metrics(all_predictions, "pred_bs_ptr",
crosswoz_slot_list)
slot_list)
evaluation_metrics = {"Joint Acc": joint_acc_score_ptr, "Turn Acc": turn_acc_score_ptr,
"Joint F1": F1_score_ptr}
print(evaluation_metrics)
*/model_output/
import os.path
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn import CosineEmbeddingLoss
from pytorch_pretrained_bert.modeling import BertModel
from pytorch_pretrained_bert.modeling import BertPreTrainedModel
from transformers import BertModel
from transformers import BertPreTrainedModel
class BertForUtteranceEncoding(BertPreTrainedModel):
......@@ -19,7 +17,7 @@ class BertForUtteranceEncoding(BertPreTrainedModel):
self.bert = BertModel(config)
def forward(self, input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False):
return self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers)
return self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, encoder_hidden_states=output_all_encoded_layers)
class MultiHeadAttention(nn.Module):
......@@ -93,7 +91,8 @@ class BeliefTracker(nn.Module):
self.device = device
### Utterance Encoder
self.utterance_encoder = BertForUtteranceEncoding.from_pretrained(args.bert_model)
self.utterance_encoder = BertForUtteranceEncoding.from_pretrained(args.bert_model_name, cache_dir=args.bert_model_cache_dir)
self.utterance_encoder.train()
self.bert_output_dim = self.utterance_encoder.config.hidden_size
self.hidden_dropout_prob = self.utterance_encoder.config.hidden_dropout_prob
if args.fix_utterance_encoder:
......@@ -101,7 +100,8 @@ class BeliefTracker(nn.Module):
p.requires_grad = False
### slot, slot-value Encoder (not trainable)
self.sv_encoder = BertForUtteranceEncoding.from_pretrained(args.bert_model)
self.sv_encoder = BertForUtteranceEncoding.from_pretrained(args.bert_model_name, cache_dir=args.bert_model_cache_dir)
self.sv_encoder.train()
for p in self.sv_encoder.bert.parameters():
p.requires_grad = False
......@@ -272,6 +272,8 @@ class BeliefTracker(nn.Module):
# calculate joint accuracy
pred_slot = torch.cat(pred_slot, 2)
# print('pred slot:', pred_slot[0][0])
# print('labels:', labels[0][0])
accuracy = (pred_slot == labels).view(-1, slot_dim)
acc_slot = torch.sum(accuracy, 0).float() \
/ torch.sum(labels.view(-1, slot_dim) > -1, 0).float()
......
model_output/
pre-trained/
from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker as SUMBT
import json
import zipfile
from convlab2.dst.sumbt.crosswoz_en.sumbt_config import *
null = 'none'
def trans_value(value):
trans = {
'': 'none',
}
value = value.strip()
value = trans.get(value, value)
value = value.replace('', "'")
value = value.replace('', "'")
return value
def convert_to_glue_format(data_dir, sumbt_dir):
if not os.path.isdir(os.path.join(sumbt_dir, args.tmp_data_dir)):
os.mkdir(os.path.join(sumbt_dir, args.tmp_data_dir))
### Read ontology file
with open(os.path.join(data_dir, "ontology.json"), "r") as fp_ont:
data_ont = json.load(fp_ont)
ontology = {}
facilities = []
for domain_slot in data_ont:
domain, slot = domain_slot.split('-', 1)
if domain not in ontology:
ontology[domain] = {}
if slot.startswith('Hotel Facilities'):
facilities.append(slot.split(' - ')[1])
ontology[domain][slot] = set(map(str.lower, data_ont[domain_slot]))
### Read woz logs and write to tsv files
tsv_filename = os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv")
print('tsv file: ', os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv"))
if os.path.exists(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv")):
print('data has been processed!')
return 0
else:
print('processing data')
with open(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv"), "w") as fp_train, \
open(os.path.join(sumbt_dir, args.tmp_data_dir, "dev.tsv"), "w") as fp_dev, \
open(os.path.join(sumbt_dir, args.tmp_data_dir, "test.tsv"), "w") as fp_test:
fp_train.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_dev.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_test.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
fp_train.write(f'{str(domain)}-{str(slot)}\t')
fp_dev.write(f'{str(domain)}-{str(slot)}\t')
fp_test.write(f'{str(domain)}-{str(slot)}\t')
fp_train.write('\n')
fp_dev.write('\n')
fp_test.write('\n')
# fp_data = open(os.path.join(SELF_DATA_DIR, "data.json"), "r")
# data = json.load(fp_data)
file_split = ['train', 'val', 'test']
fp = [fp_train, fp_dev, fp_test]
for split_type, split_fp in zip(file_split, fp):
zipfile_name = "{}.json.zip".format(split_type)
zip_fp = zipfile.ZipFile(os.path.join(data_dir, zipfile_name))
data = json.loads(str(zip_fp.read(zip_fp.namelist()[0]), 'utf-8'))
for file_id in data:
user_utterance = ''
system_response = ''
turn_idx = 0
messages = data[file_id]['messages']
for idx, turn in enumerate(messages):
if idx % 2 == 0: # user turn
user_utterance = turn['content']
else: # system turn
user_utterance = user_utterance.replace('\t', ' ')
user_utterance = user_utterance.replace('\n', ' ')
user_utterance = user_utterance.replace(' ', ' ')
system_response = system_response.replace('\t', ' ')
system_response = system_response.replace('\n', ' ')
system_response = system_response.replace(' ', ' ')
split_fp.write(str(file_id)) # 0: dialogue ID
split_fp.write('\t' + str(turn_idx)) # 1: turn index
split_fp.write('\t' + str(user_utterance)) # 2: user utterance
split_fp.write('\t' + str(system_response)) # 3: system response
# hardcode the value of facilities as 'yes' and 'no'
belief = {f'Hotel-Hotel Facilities - {str(facility)}': null for facility in facilities}
sys_state_init = turn['sys_state_init']
for domain, slots in sys_state_init.items():
for slot, value in slots.items():
# skip selected results
if isinstance(value, list):
continue
if domain not in ontology:
print("domain (%s) is not defined" % domain)
continue
if slot == 'Hotel Facilities':
for facility in value.split(','):
belief[f'{str(domain)}-Hotel Facilities - {str(facility)}'] = 'yes'
else:
if slot not in ontology[domain]:
print("slot (%s) in domain (%s) is not defined" % (slot, domain)) # bus-arriveBy not defined
continue
value = trans_value(value).lower()
if value not in ontology[domain][slot] and value != null:
print("%s: value (%s) in domain (%s) slot (%s) is not defined in ontology" %
(file_id, value, domain, slot))
value = null
belief[f'{str(domain)}-{str(slot)}'] = value
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
key = str(domain) + '-' + str(slot)
if key in belief:
val = belief[key]
split_fp.write('\t' + val)
else:
split_fp.write(f'\t{null}')
split_fp.write('\n')
split_fp.flush()
system_response = turn['content']
turn_idx += 1
print('data has been processed!')
import os
import convlab2
class DotMap():
def __init__(self):
self.max_label_length = 35
self.num_rnn_layers = 1
self.zero_init_rnn = False
self.attn_head = 4
self.do_eval = True
self.do_train = False
self.train_batch_size = 3
self.dev_batch_size = 1
self.eval_batch_size = 16
self.learning_rate = 5e-5
self.warmup_proportion = 0.1
self.local_rank = -1
self.seed = 42
self.gradient_accumulation_steps = 1
self.fp16 = False
self.loss_scale = 0
self.do_not_use_tensorboard = False
self.fix_utterance_encoder = False
self.do_eval = True
self.num_train_epochs = 300
self.bert_model = os.path.join(convlab2.get_root_path(), "pre-trained-models/bert-base-uncased")
self.bert_model_cache_dir = os.path.join(convlab2.get_root_path(), "pre-trained-models/")
self.bert_model_name = "bert-base-uncased"
self.do_lower_case = True
self.task_name = 'bert-gru-sumbt'
self.nbt = 'rnn'
self.target_slot = 'all'
self.distance_metric = 'euclidean'
self.patience = 15
self.hidden_dim = 300
self.max_seq_length = 35
self.max_turn_length = 23
self.fp16_loss_scale = 0.0
self.data_dir = 'data/crosswoz_en/'
self.tf_dir = 'tensorboard'
self.tmp_data_dir = 'processed_data/'
self.output_dir = 'model_output/'
args = DotMap()
\ No newline at end of file
import csv
import os
import json
import collections
import logging
import re
import torch
from convlab2.dst.sumbt.crosswoz_en.convert_to_glue_format import null
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r", encoding='utf-8') as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
if len(line) > 0 and line[0][0] == '#': # ignore comments (starting with '#')
continue
lines.append(line)
return lines
class Processor(DataProcessor):
"""Processor for the belief tracking dataset (GLUE version)."""
def __init__(self, config):
super(Processor, self).__init__()
# crosswoz dataset
with open(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))), config.data_dir, "ontology.json"), "r") as fp_ontology:
ontology = json.load(fp_ontology)
for slot in ontology.keys():
ontology[slot].append(null)
assert config.target_slot == 'all'
# if not config.target_slot == 'all':
# slot_idx = {'attraction': '0:1:2', 'bus': '3:4:5:6', 'hospital': '7',
# 'hotel': '8:9:10:11:12:13:14:15:16:17', \
# 'restaurant': '18:19:20:21:22:23:24', 'taxi': '25:26:27:28', 'train': '29:30:31:32:33:34'}
# target_slot = []
# for key, value in slot_idx.items():
# if key != config.target_slot:
# target_slot.append(value)
# config.target_slot = ':'.join(target_slot)
# sorting the ontology according to the alphabetic order of the slots
ontology = collections.OrderedDict(sorted(ontology.items()))
# select slots to train
nslots = len(ontology.keys())
target_slot = list(ontology.keys())
if config.target_slot == 'all':
self.target_slot_idx = [*range(0, nslots)]
else:
self.target_slot_idx = sorted([int(x) for x in config.target_slot.split(':')])
for idx in range(0, nslots):
if not idx in self.target_slot_idx:
del ontology[target_slot[idx]]
self.ontology = ontology
self.target_slot = list(self.ontology.keys())
# for i, slot in enumerate(self.target_slot):
# if slot == "pricerange":
# self.target_slot[i] = "price range"
logger.info('Processor: target_slot')
logger.info(self.target_slot)
def get_train_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train", accumulation)
def get_dev_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev", accumulation)
def get_test_examples(self, data_dir, accumulation=False):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test", accumulation)
def get_labels(self):
"""See base class."""
return [list(map(str.lower, self.ontology[slot])) for slot in self.target_slot]
def _create_examples(self, lines, set_type, accumulation=False):
"""Creates examples for the training and dev sets."""
prev_dialogue_index = None
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s-%s" % (set_type, line[0], line[1]) # line[0]: dialogue index, line[1]: turn index
if accumulation:
if prev_dialogue_index is None or prev_dialogue_index != line[0]:
text_a = line[2]
text_b = line[3]
prev_dialogue_index = line[0]
else:
# The symbol '#' will be replaced with '[SEP]' after tokenization.
text_a = line[2] + " # " + text_a
text_b = line[3] + " # " + text_b
else:
text_a = line[2] # line[2]: user utterance
text_b = line[3] # line[3]: system response
label = [line[4 + idx] for idx in self.target_slot_idx]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def normalize_text(text):
global replacements
# lower case every word
text = text.lower()
# replace white spaces in front and end
text = re.sub(r'^\s*|\s*$', '', text)
# hotel domain pfb30
text = re.sub(r"b&b", "bed and breakfast", text)
text = re.sub(r"b and b", "bed and breakfast", text)
# replace st.
text = text.replace(';', ',')
text = re.sub('$\/', '', text)
text = text.replace('/', ' and ')
# replace other special characters
text = text.replace('-', ' ')
text = re.sub('[\"\<>@\(\)]', '', text) # remove
# insert white space before and after tokens:
for token in ['?', '.', ',', '!']:
text = insertSpace(token, text)
# insert white space for 's
text = insertSpace('\'s', text)
# replace it's, does't, you'd ... etc
text = re.sub('^\'', '', text)
text = re.sub('\'$', '', text)
text = re.sub('\'\s', ' ', text)
text = re.sub('\s\'', ' ', text)
for fromx, tox in replacements:
text = ' ' + text + ' '
text = text.replace(fromx, tox)[1:-1]
# remove multiple spaces
text = re.sub(' +', ' ', text)
# concatenate numbers
tmp = text
tokens = text.split()
i = 1
while i < len(tokens):
if re.match(u'^\d+$', tokens[i]) and \
re.match(u'\d+$', tokens[i - 1]):
tokens[i - 1] += tokens[i]
del tokens[i]
else:
i += 1
text = ' '.join(tokens)
return text
def insertSpace(token, text):
sidx = 0
while True:
sidx = text.find(token, sidx)
if sidx == -1:
break
if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
re.match('[0-9]', text[sidx + 1]):
sidx += 1
continue
if text[sidx - 1] != ' ':
text = text[:sidx] + ' ' + text[sidx:]
sidx += 1
if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
text = text[:sidx + 1] + ' ' + text[sidx + 1:]
sidx += 1
return text
# convert tokens in labels to the identifier in vocabulary
def get_label_embedding(labels, max_seq_length, tokenizer, device):
features = []
for label in labels:
label_tokens = ["[CLS]"] + tokenizer.tokenize(label) + ["[SEP]"]
# just truncate, some names are unreasonable long
label_token_ids = tokenizer.convert_tokens_to_ids(label_tokens)[:max_seq_length]
label_len = len(label_token_ids)
label_padding = [0] * (max_seq_length - len(label_token_ids))
label_token_ids += label_padding
assert len(label_token_ids) == max_seq_length
features.append((label_token_ids, label_len))
all_label_token_ids = torch.tensor([f[0] for f in features], dtype=torch.long).to(device)
all_label_len = torch.tensor([f[1] for f in features], dtype=torch.long).to(device)
return all_label_token_ids, all_label_len
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x / warmup
return 1.0 - x
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, input_len, label_id):
self.input_ids = input_ids
self.input_len = input_len
self.label_id = label_id
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, max_turn_length):
"""Loads a data file into a list of `InputBatch`s."""
label_map = [{label: i for i, label in enumerate(labels)} for labels in label_list]
slot_dim = len(label_list)
features = []
prev_dialogue_idx = None
all_padding = [0] * max_seq_length
all_padding_len = [0, 0]
max_turn = 0
for (ex_index, example) in enumerate(examples):
if max_turn < int(example.guid.split('-')[2]):
max_turn = int(example.guid.split('-')[2])
max_turn_length = min(max_turn + 1, max_turn_length)
logger.info("max_turn_length = %d" % max_turn)
for (ex_index, example) in enumerate(examples):
tokens_a = [x if x != '#' else '[SEP]' for x in tokenizer.tokenize(example.text_a)]
tokens_b = None
if example.text_b:
tokens_b = [x if x != '#' else '[SEP]' for x in tokenizer.tokenize(example.text_b)]
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[:(max_seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambigiously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
input_len = [len(tokens), 0]
if tokens_b:
tokens += tokens_b + ["[SEP]"]
input_len[1] = len(tokens_b) + 1
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# Zero-pad up to the sequence length.
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
assert len(input_ids) == max_seq_length
FLAG_TEST = False
if example.label is not None:
label_id = []
label_info = 'label: '
for i, label in enumerate(example.label):
if label == 'dontcare':
label = 'do not care'
label_id.append(label_map[i][label])
label_info += '%s (id = %d) ' % (label, label_map[i][label])
if ex_index < 5:
logger.info("*** Example ***")
logger.info("guid: %s" % example.guid)
logger.info("tokens: %s" % " ".join(
[str(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_len: %s" % " ".join([str(x) for x in input_len]))
logger.info("label: " + label_info)
else:
FLAG_TEST = True
label_id = None
curr_dialogue_idx = example.guid.split('-')[1]
curr_turn_idx = int(example.guid.split('-')[2])
if prev_dialogue_idx is not None and prev_dialogue_idx != curr_dialogue_idx:
if prev_turn_idx < max_turn_length:
features += [InputFeatures(input_ids=all_padding,
input_len=all_padding_len,
label_id=[-1] * slot_dim)] \
* (max_turn_length - prev_turn_idx - 1)
# print(len(features), max_turn_length)
assert len(features) % max_turn_length == 0
if prev_dialogue_idx is None or prev_turn_idx < max_turn_length:
features.append(
InputFeatures(input_ids=input_ids,
input_len=input_len,
label_id=label_id))
prev_dialogue_idx = curr_dialogue_idx
prev_turn_idx = curr_turn_idx
if prev_turn_idx < max_turn_length:
features += [InputFeatures(input_ids=all_padding,
input_len=all_padding_len,
label_id=[-1] * slot_dim)] \
* (max_turn_length - prev_turn_idx - 1)
assert len(features) % max_turn_length == 0
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_len = torch.tensor([f.input_len for f in features], dtype=torch.long)
if not FLAG_TEST:
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
# reshape tensors to [#batch, #max_turn_length, #max_seq_length]
all_input_ids = all_input_ids.view(-1, max_turn_length, max_seq_length)
all_input_len = all_input_len.view(-1, max_turn_length, 2)
if not FLAG_TEST:
all_label_ids = all_label_ids.view(-1, max_turn_length, slot_dim)
else:
all_label_ids = None
return all_input_ids, all_input_len, all_label_ids
def eval_all_accs(pred_slot, labels, accuracies):
def _eval_acc(_pred_slot, _labels):
slot_dim = _labels.size(-1)
accuracy = (_pred_slot == _labels).view(-1, slot_dim)
num_turn = torch.sum(_labels[:, :, 0].view(-1) > -1, 0).float()
num_data = torch.sum(_labels > -1).float()
# joint accuracy
joint_acc = sum(torch.sum(accuracy, 1) / slot_dim).float()
# slot accuracy
slot_acc = torch.sum(accuracy).float()
return joint_acc, slot_acc, num_turn, num_data
# 7 domains
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot, labels)
accuracies['joint7'] += joint_acc
accuracies['slot7'] += slot_acc
accuracies['num_turn'] += num_turn
accuracies['num_slot7'] += num_data
# restaurant domain
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot[:,:,18:25], labels[:,:,18:25])
accuracies['joint_rest'] += joint_acc
accuracies['slot_rest'] += slot_acc
accuracies['num_slot_rest'] += num_data
pred_slot5 = torch.cat((pred_slot[:,:,0:3], pred_slot[:,:,8:]), 2)
label_slot5 = torch.cat((labels[:,:,0:3], labels[:,:,8:]), 2)
# 5 domains (excluding bus and hotel domain)
joint_acc, slot_acc, num_turn, num_data = _eval_acc(pred_slot5, label_slot5)
accuracies['joint5'] += joint_acc
accuracies['slot5'] += slot_acc
accuracies['num_slot5'] += num_data
return accuracies
......@@ -5,15 +5,15 @@ from itertools import chain
import numpy as np
import zipfile
# from tensorboardX.writer import SummaryWriter
from tensorboardX.writer import SummaryWriter
from tqdm._tqdm import trange, tqdm
from convlab2.util.file_util import cached_path
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam
from transformers import BertTokenizer
from transformers import get_linear_schedule_with_warmup, AdamW
from convlab2.dst.dst import DST
from convlab2.dst.sumbt.multiwoz.convert_to_glue_format import convert_to_glue_format
......@@ -94,10 +94,7 @@ class SUMBTTracker(DST):
num_labels = [len(labels) for labels in label_list] # number of slot-values in each slot-type
# tokenizer
# vocab_dir = os.path.join(data_dir, 'model', '%s-vocab.txt' % args.bert_model)
# if not os.path.exists(vocab_dir):
# raise ValueError("Can't find %s " % vocab_dir)
self.tokenizer = BertTokenizer.from_pretrained(args.bert_model)
self.tokenizer = BertTokenizer.from_pretrained(args.bert_model_name, cache_dir=args.bert_model_cache_dir)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
......@@ -402,6 +399,7 @@ class SUMBTTracker(DST):
t_total = num_train_steps
scheduler = None
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
......@@ -420,10 +418,8 @@ class SUMBTTracker(DST):
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.fp16_loss_scale)
else:
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=t_total)
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_proportion*t_total, num_training_steps=t_total)
logger.info(optimizer)
# Training code
......@@ -492,7 +488,11 @@ class SUMBTTracker(DST):
summary_writer.add_scalar("Train/LearningRate", lr_this_step, global_step)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
if scheduler is not None:
torch.nn.utils.clip_grad_norm_(optimizer_grouped_parameters, 1.0)
optimizer.step()
if scheduler is not None:
scheduler.step()
optimizer.zero_grad()
global_step += 1
......
import os
import convlab2
class DotMap():
def __init__(self):
self.max_label_length = 32
......@@ -27,8 +29,9 @@ class DotMap():
self.do_eval = True
self.num_train_epochs = 300
self.bert_model = 'bert-base-uncased'
self.bert_model = os.path.join(convlab2.get_root_path(), "pre-trained-models/bert-base-uncased")
self.bert_model_cache_dir = os.path.join(convlab2.get_root_path(), "pre-trained-models/")
self.bert_model_name = "bert-base-uncased"
self.do_lower_case = True
self.task_name = 'bert-gru-sumbt'
self.nbt = 'rnn'
......
pre-trained/
model_output/
from convlab2.dst.sumbt.multiwoz_zh.sumbt import SUMBTTracker as SUMBT
import json
import zipfile
from convlab2.dst.sumbt.multiwoz_zh.sumbt_config import *
def trans_value(value):
trans = {
'': '未提及',
'没有提到': '未提及',
'没有': '未提及',
'未提到': '未提及',
'一个也没有': '未提及',
'': '未提及',
'是的': '',
'不是': '没有',
'不关心': '不在意',
'不在乎': '不在意',
}
return trans.get(value, value)
def convert_to_glue_format(data_dir, sumbt_dir):
if not os.path.isdir(os.path.join(sumbt_dir, args.tmp_data_dir)):
os.mkdir(os.path.join(sumbt_dir, args.tmp_data_dir))
### Read ontology file
with open(os.path.join(data_dir, "ontology.json"), "r") as fp_ont:
data_ont = json.load(fp_ont)
ontology = {}
for domain_slot in data_ont:
domain, slot = domain_slot.split('-')
if domain not in ontology:
ontology[domain] = {}
ontology[domain][slot] = {}
for value in data_ont[domain_slot]:
ontology[domain][slot][value] = 1
### Read woz logs and write to tsv files
if os.path.exists(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv")):
print('data has been processed!')
return 0
print('begin processing data')
fp_train = open(os.path.join(sumbt_dir, args.tmp_data_dir, "train.tsv"), "w")
fp_dev = open(os.path.join(sumbt_dir, args.tmp_data_dir, "dev.tsv"), "w")
fp_test = open(os.path.join(sumbt_dir, args.tmp_data_dir, "test.tsv"), "w")
fp_train.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_dev.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
fp_test.write('# Dialogue ID\tTurn Index\tUser Utterance\tSystem Response\t')
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
fp_train.write(str(domain) + '-' + str(slot) + '\t')
fp_dev.write(str(domain) + '-' + str(slot) + '\t')
fp_test.write(str(domain) + '-' + str(slot) + '\t')
fp_train.write('\n')
fp_dev.write('\n')
fp_test.write('\n')
# fp_data = open(os.path.join(SELF_DATA_DIR, "data.json"), "r")
# data = json.load(fp_data)
file_split = ['train', 'val', 'test']
fp = [fp_train, fp_dev, fp_test]
for split_type, split_fp in zip(file_split, fp):
zipfile_name = "{}.json.zip".format(split_type)
zip_fp = zipfile.ZipFile(os.path.join(data_dir, zipfile_name))
data = json.loads(str(zip_fp.read(zip_fp.namelist()[0]), 'utf-8'))
for file_id in data:
user_utterance = ''
system_response = ''
turn_idx = 0
for idx, turn in enumerate(data[file_id]['log']):
if idx % 2 == 0: # user turn
user_utterance = data[file_id]['log'][idx]['text']
else: # system turn
user_utterance = user_utterance.replace('\t', ' ')
user_utterance = user_utterance.replace('\n', ' ')
user_utterance = user_utterance.replace(' ', ' ')
system_response = system_response.replace('\t', ' ')
system_response = system_response.replace('\n', ' ')
system_response = system_response.replace(' ', ' ')
split_fp.write(str(file_id)) # 0: dialogue ID
split_fp.write('\t' + str(turn_idx)) # 1: turn index
split_fp.write('\t' + str(user_utterance)) # 2: user utterance
split_fp.write('\t' + str(system_response)) # 3: system response
belief = {}
for domain in data[file_id]['log'][idx]['metadata'].keys():
for slot in data[file_id]['log'][idx]['metadata'][domain]['semi'].keys():
value = data[file_id]['log'][idx]['metadata'][domain]['semi'][slot].strip()
# value = value_trans.get(value, value)
value = trans_value(value)
if domain not in ontology:
print("domain (%s) is not defined" % domain)
continue
if slot not in ontology[domain]:
print("slot (%s) in domain (%s) is not defined" % (slot, domain)) # bus-arriveBy not defined
continue
if value not in ontology[domain][slot] and value != '未提及':
print("%s: value (%s) in domain (%s) slot (%s) is not defined in ontology" %
(file_id, value, domain, slot))
value = '未提及'
belief[str(domain) + '-' + str(slot)] = value
for slot in data[file_id]['log'][idx]['metadata'][domain]['book'].keys():
if slot == 'booked':
continue
if domain == '公共汽车' and slot == '人数' or domain == '列车' and slot == '票价':
continue # not defined in ontology
value = data[file_id]['log'][idx]['metadata'][domain]['book'][slot].strip()
value = trans_value(value)
if str('预订' + slot) not in ontology[domain]:
print("预订%s is not defined in domain %s" % (slot, domain))
continue
if value not in ontology[domain]['预订' + slot] and value != '未提及':
print("%s: value (%s) in domain (%s) slot (预订%s) is not defined in ontology" %
(file_id, value, domain, slot))
value = '未提及'
belief[str(domain) + '-预订' + str(slot)] = value
for domain in sorted(ontology.keys()):
for slot in sorted(ontology[domain].keys()):
key = str(domain) + '-' + str(slot)
if key in belief:
split_fp.write('\t' + belief[key])
else:
split_fp.write('\t未提及')
split_fp.write('\n')
split_fp.flush()
system_response = data[file_id]['log'][idx]['text']
turn_idx += 1
fp_train.close()
fp_dev.close()
fp_test.close()
print('data has been processed!')
\ No newline at end of file