Skip to content
Snippets Groups Projects
Unverified Commit 2746ba72 authored by zhuqi's avatar zhuqi Committed by GitHub
Browse files

remove tensorflow dependency and mdbt (#164)

parent 1e3b2cfd
No related branches found
No related tags found
No related merge requests found
......@@ -29,7 +29,6 @@ RUN pip install scipy
RUN pip install scikit-learn==0.20.3
RUN pip install pytorch-pretrained-bert==0.6.1
RUN pip install transformers==2.3.0
RUN pip install tensorflow==1.14
RUN pip install tensorboard==1.14.0
RUN pip install tensorboardX==1.7
RUN pip install tokenizers==0.8.0
......
......@@ -46,7 +46,7 @@ Our documents are on https://thu-coai.github.io/ConvLab-2_docs/convlab2.html.
We provide the following models:
- NLU: SVMNLU, MILU, BERTNLU
- DST: rule, MDBT, TRADE, SUMBT
- DST: rule, TRADE, SUMBT
- Policy: rule, Imitation, REINFORCE, PPO, GDPL, MDRG, HDSA, LaRL
- Simulator policy: Agenda, VHUS
- NLG: Template, SCLSTM
......
# Multi-domain Belief DST
The multi-domain belief tracker (MDBT) is a belief tracking model that
fully utilizes the semantic similarity between dialogue utterances and
ontology terms. It was proposed by [Ramadan et al., 2018](https://www.aclweb.org/anthology/P18-2069).
## Package Structure
We adapted the original code into a flexible module that can be
easily imported into a pipeline dialog framework. The dataset-independent
implementation of MDBT is in ```convlab2/dst/mdbt```, and the implementation for the MultiWOZ
dataset is in ```convlab2/dst/mdbt/multiwoz```.
## Run the Code
The framework automatically downloads the pre-trained models and data
before running. If the automatic download fails, download the pre-trained model and data manually
from [here](https://drive.google.com/open?id=1k6wbabIlYju7kR0Zr4aVXwE_fsGBOtdw),
and put the ```word-vectors```, ```models``` and ```data``` directories under
```convlab2/dst/mdbt/multiwoz/configs```.
git
## Performance
The performance of our pre-trained MDBT model is 13.9%.
You can train the model yourself for better performance.
import copy
import json
import os
import tensorflow as tf
from convlab2.dst.mdbt.mdbt_util import model_definition, \
track_dialogue, generate_batch, process_history
from convlab2.dst.rule.multiwoz import normalize_value
from convlab2.util.multiwoz.state import default_state
from convlab2.dst.dst import DST
from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
from os.path import dirname
# Module-level settings. Presumably training/evaluation knobs carried over
# from the original MDBT training script -- TODO confirm where each one is
# consumed before changing it.

# Device selector and resume point.
device = 'gpu'
start_batch = 0

# Loop sizes.
no_epochs = 600
batches_per_eval = 10
train_batch_size = 1
class MDBT(DST):
    """A multi-domain belief tracker.

    Adapted from https://github.com/osmanio2/multi-domain-belief-tracking.

    ``update`` reads ``self.word_vectors`` and ``self.ontology``. Neither is
    assigned by this constructor, so both must be set on the instance (for
    example by a subclass) before ``update`` is called.
    """

    def __init__(self, ontology_vectors, ontology, slots, data_dir):
        """Define the model, open a TensorFlow session and load lookup tables.

        Args:
            ontology_vectors: Forwarded unchanged to ``model_definition``.
            ontology: Only its length is used here (second argument of
                ``model_definition``).
            slots: Forwarded unchanged to ``model_definition``.
            data_dir (str): Directory under which the data, word-vector,
                model, graph and result paths are resolved.
        """
        DST.__init__(self)

        # data profile
        self.data_dir = data_dir
        self.validation_url = os.path.join(self.data_dir, 'data/validate.json')
        self.word_vectors_url = os.path.join(self.data_dir, 'word-vectors/paragram_300_sl999.txt')
        self.training_url = os.path.join(self.data_dir, 'data/train.json')
        self.ontology_url = os.path.join(self.data_dir, 'data/ontology.json')
        self.testing_url = os.path.join(self.data_dir, 'data/test.json')
        self.model_url = os.path.join(self.data_dir, 'models/model-1')
        self.graph_url = os.path.join(self.data_dir, 'graphs/graph-1')
        self.results_url = os.path.join(self.data_dir, 'results/log-1.txt')
        self.kb_url = os.path.join(self.data_dir, 'data/')  # not used
        self.train_model_url = os.path.join(self.data_dir, 'train_models/model-1')
        self.train_graph_url = os.path.join(self.data_dir, 'train_graph/graph-1')

        self.model_variables = model_definition(ontology_vectors, len(ontology), slots, num_hidden=None,
                                                bidir=True, net_type=None, test=True, dev='cpu')

        self.state = default_state()

        _config = tf.ConfigProto()
        _config.gpu_options.allow_growth = True
        _config.allow_soft_placement = True
        self.sess = tf.Session(config=_config)
        # Weights are loaded lazily by init_session() / restore().
        self.param_restored = False

        # Lower-cased REF_USR_DA keys *and* values both map to "<key>-<domain>",
        # where <domain> is the outer REF_USR_DA key. '-' is the separator, so
        # it must not occur inside a key.
        self.det_dic = {}
        for domain, dic in REF_USR_DA.items():
            for key, value in dic.items():
                assert '-' not in key
                self.det_dic[key.lower()] = key + '-' + domain
                self.det_dic[value.lower()] = key + '-' + domain

        # value_dict.json lives under a directory four levels above this file.
        root_dir = os.path.abspath(__file__)
        for _ in range(4):
            root_dir = os.path.dirname(root_dir)
        # Use a context manager so the file handle is closed deterministically.
        with open(os.path.join(root_dir, 'data/multiwoz/value_dict.json')) as value_file:
            self.value_dict = json.load(value_file)

    def init_session(self):
        """Reset the dialog state and restore the model weights on first use."""
        self.state = default_state()
        if not self.param_restored:
            self.restore()

    def restore(self):
        """Restore the trained parameters from ``self.model_url`` into the session."""
        self.__restore_model(self.sess, tf.train.Saver())

    def update_batch(self, batch_action):
        """Batch update is not implemented; this is a no-op."""
        pass

    def update(self, user_act=None):
        """Update the dialog state from ``self.state['history']`` and ``user_act``.

        The belief state is re-predicted from the whole history by the model.
        ``user_act`` itself is only used to detect requestable slots.

        Args:
            user_act (str): The latest user utterance.

        Returns:
            dict: The new state, which is also stored in ``self.state``.

        Raises:
            Exception: If ``user_act`` is not a string, or if a predicted
                domain is missing from the belief state.
            AssertionError: If the history is empty or, after a leading
                system turn has been ensured, has an odd number of entries.
        """
        if not isinstance(user_act, str):
            raise Exception('Expected user_act to be <class \'str\'> type, but get {}.'.format(type(user_act)))
        prev_state = copy.deepcopy(self.state)

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(os.path.join(self.data_dir, "results"), exist_ok=True)

        # Only the placeholders and the two prediction outputs are needed for
        # inference; the remaining graph nodes are unpacked but unused.
        (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels, domain_labels, _domain_accuracy,
         _slot_accuracy, _value_accuracy, _value_f1, _train_step, keep_prob, predictions_op,
         _true_predictions, [y, _]) = self.model_variables

        # The history must start with a system turn so that entries pair up
        # as (system, user); prepend an empty one when it does not.
        assert len(prev_state['history']) > 0
        if prev_state['history'][0][0] != 'sys':
            prev_state['history'] = [['sys', '']] + prev_state['history']
        assert len(prev_state['history']) % 2 == 0

        # Group consecutive utterances into [system_utt, user_utt] pairs,
        # replacing empty utterances with the literal 'null'.
        actual_history = []
        for _speaker, utt in prev_state['history']:
            if not utt:
                utt = 'null'
            if not actual_history or len(actual_history[-1]) == 2:
                actual_history.append([utt])
            else:
                actual_history[-1].append(utt)

        # Wrap the pairs in the dialogue structure process_history() consumes;
        # the belief state of every turn is a fresh default (no gold labels).
        fake_dialogue = {}
        for turn_no, (_sys, _user) in enumerate(actual_history):
            fake_dialogue[str(turn_no)] = {
                'system': _sys,
                'user': {
                    'text': _user,
                    'belief_state': default_state()['belief_state'],
                },
            }

        context, actual_context = process_history([fake_dialogue], self.word_vectors, self.ontology)
        batch_user, batch_sys, batch_labels, batch_domain_labels, batch_user_uttr_len, batch_sys_uttr_len, \
            batch_no_turns = generate_batch(context, 0, 1, len(self.ontology))  # old feature

        # run model
        [pred, y_pred] = self.sess.run(
            [predictions_op, y],
            feed_dict={user: batch_user, sys_res: batch_sys,
                       labels: batch_labels,
                       domain_labels: batch_domain_labels,
                       user_uttr_len: batch_user_uttr_len,
                       sys_uttr_len: batch_sys_uttr_len,
                       no_turns: batch_no_turns,
                       keep_prob: 1.0})

        # convert to str output
        dialgs, _, _ = track_dialogue(actual_context, self.ontology, pred, y_pred)
        assert len(dialgs) >= 1
        last_turn = dialgs[0][-1]
        predicted_items = last_turn['prediction']

        # update belief state
        new_belief_state = copy.deepcopy(prev_state['belief_state'])
        for item in predicted_items:
            item = item.lower()
            # Items look like "<domain>-<slot>-<value>:<suffix>". Splitting at
            # most twice keeps values that themselves contain '-' intact
            # instead of failing the three-way unpack.
            # NOTE(review): assumes domain and slot names never contain '-'.
            domain, slot, value = item.strip().split('-', 2)
            # Drop everything after the last ':' (looks like a score suffix
            # appended by track_dialogue -- confirm).
            value = value.rsplit(':', 1)[0]
            if slot == 'price range':
                slot = 'pricerange'
            if slot not in ['name', 'book']:
                if domain not in new_belief_state:
                    raise Exception('Error: domain <{}> not in belief state'.format(domain))
            slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
            assert 'semi' in new_belief_state[domain]
            assert 'book' in new_belief_state[domain]
            if 'book' in slot:
                # "book <slot>" -> "<slot>"
                assert slot.startswith('book ')
                slot = slot.strip().split()[1]
            # Restore the camelCase spelling used by the belief-state keys.
            if slot == 'arriveby':
                slot = 'arriveBy'
            elif slot == 'leaveat':
                slot = 'leaveAt'

            domain_dic = new_belief_state[domain]
            if slot in domain_dic['semi']:
                domain_dic['semi'][slot] = normalize_value(self.value_dict, domain, slot, value)
            elif slot in domain_dic['book']:
                domain_dic['book'][slot] = value
            elif slot.lower() in domain_dic['book']:
                domain_dic['book'][slot.lower()] = value
            else:
                # Unknown slots are logged (in the working directory) and
                # otherwise ignored.
                with open('mdbt_unknown_slot.log', 'a+') as f:
                    f.write('unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format(slot, value,
                                                                                                      domain, item))

        # update request_state: add newly detected slots, never overwrite
        # entries that are already present.
        new_request_state = copy.deepcopy(prev_state['request_state'])
        user_request_slot = self.detect_requestable_slots(user_act)
        for domain, requested in user_request_slot.items():
            for key, flag in requested.items():
                new_request_state.setdefault(domain, {}).setdefault(key, flag)

        # update state
        new_state = copy.deepcopy(dict(prev_state))
        new_state['belief_state'] = new_belief_state
        new_state['request_state'] = new_request_state
        self.state = new_state
        return self.state

    def normalize_history(self, history):
        """Replace zero-length utterances with placeholders, in place.

        Args:
            history (list): Mutable two-element turns; an empty first element
                becomes ``'sys'`` and an empty second element ``'user'``.

        Returns:
            list: The same ``history`` object.
        """
        for turn in history:
            a, b = turn
            if len(a) == 0:
                turn[0] = 'sys'
            if len(b) == 0:
                turn[1] = 'user'
        return history

    def detect_requestable_slots(self, observation):
        """Find the slots from ``self.det_dic`` mentioned in ``observation``.

        Matching is case-insensitive and requires the phrase to be delimited
        by spaces (or the start/end of the text).

        Args:
            observation (str): The utterance to scan.

        Returns:
            dict: ``{domain: {key: 0}}`` for every matched phrase.
        """
        result = {}
        # Pad with spaces so phrases at either end still match as whole words.
        _observation = ' {} '.format(observation.lower())
        for value, target in self.det_dic.items():
            _value = ' {} '.format(value.strip())
            if _value in _observation:
                key, domain = target.split('-')
                result.setdefault(domain, {})[key] = 0
        return result

    def __restore_model(self, sess, saver):
        """Load the checkpoint at ``self.model_url`` into ``sess`` using ``saver``."""
        saver.restore(sess, self.model_url)
        print('Loading trained MDBT model from ', self.model_url)
        self.param_restored = True
This diff is collapsed.
from convlab2.dst.mdbt.multiwoz.dst import MultiWozMDBT as MDBT
import json
import os
import time
import tensorflow as tf
import shutil
import zipfile
from convlab2.dst.mdbt.mdbt import MDBT
from convlab2.dst.mdbt.mdbt_util import load_word_vectors, load_ontology, load_woz_data_new
from convlab2.util.dataloader.module_dataloader import AgentDSTDataloader
from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
from convlab2.util.file_util import cached_path
from pprint import pprint
# Module-level settings. Presumably training/evaluation knobs carried over
# from the original MDBT training script -- TODO confirm where each one is
# consumed before changing it.

# Device selector and resume point.
device = 'gpu'
start_batch = 0

# Loop sizes.
no_epochs = 600
batches_per_eval = 10
train_batch_size = 1
class MultiWozMDBT(MDBT):
    """MDBT tracker configured with the MultiWOZ data, ontology and pretrained model."""

    def __init__(self, data_dir='configs', data=None):
        """Constructor of MultiWozMDBT class.

        Downloads the pretrained resources when they are missing, loads the
        word vectors, ontology and test dialogues, then initialises the base
        tracker.

        Args:
            data_dir (str): The path of data dir, where the root path is
                convlab2/dst/mdbt/multiwoz (it is joined onto the directory
                containing this file).
            data (dict, optional): Pre-loaded dataset; only ``data['test']``
                is read here. When ``None`` it is loaded through
                ``AgentDSTDataloader(MultiWOZDataloader())``.
        """
        if data is None:
            loader = AgentDSTDataloader(MultiWOZDataloader())
            data = loader.load_data()

        self.file_url = 'https://convlab.blob.core.windows.net/convlab-2/mdbt_multiwoz_sys.zip'
        local_path = os.path.dirname(os.path.abspath(__file__))
        self.data_dir = os.path.join(local_path, data_dir)  # abstract data path

        self.validation_url = os.path.join(self.data_dir, 'data/validate.json')
        self.training_url = os.path.join(self.data_dir, 'data/train.json')
        self.testing_url = os.path.join(self.data_dir, 'data/test.json')
        self.word_vectors_url = os.path.join(self.data_dir, 'word-vectors/paragram_300_sl999.txt')
        self.ontology_url = os.path.join(self.data_dir, 'data/ontology.json')
        self.model_url = os.path.join(self.data_dir, 'models/model-1')
        self.graph_url = os.path.join(self.data_dir, 'graphs/graph-1')
        self.results_url = os.path.join(self.data_dir, 'results/log-1.txt')
        self.kb_url = os.path.join(self.data_dir, 'data/')  # not used
        self.train_model_url = os.path.join(self.data_dir, 'train_models/model-1')
        self.train_graph_url = os.path.join(self.data_dir, 'train_graph/graph-1')

        # The files read below must exist before anything is loaded.
        self.auto_download()

        print('Configuring MDBT model...')
        self.word_vectors = load_word_vectors(self.word_vectors_url)

        # Load the ontology and extract the feature vectors
        self.ontology, self.ontology_vectors, self.slots = load_ontology(self.ontology_url, self.word_vectors)

        # Load and process the test data
        self.test_dialogues, self.actual_dialogues = load_woz_data_new(data['test'], self.word_vectors,
                                                                       self.ontology, url=self.testing_url)
        self.no_dialogues = len(self.test_dialogues)

        super(MultiWozMDBT, self).__init__(self.ontology_vectors, self.ontology, self.slots, self.data_dir)

    def auto_download(self):
        """Automatically download the pretrained model and necessary data.

        Nothing is done when the ``models``, ``data`` and ``word-vectors``
        directories already exist under ``self.data_dir``.

        Raises:
            AssertionError: If no downloaded archive can be located in
                ``self.data_dir`` after the download call.
        """
        required_dirs = ('models', 'data', 'word-vectors')
        if all(os.path.exists(os.path.join(self.data_dir, name)) for name in required_dirs):
            return

        cached_path(self.file_url, self.data_dir)

        # Locate the downloaded file: a '<name>.json' entry in the directory
        # is expected to sit next to a file called '<name>' (presumably the
        # cache metadata and the archive itself -- confirm against
        # cached_path). If several '.json' entries exist the last one listed
        # wins.
        files = os.listdir(self.data_dir)
        target_file = ''
        for name in files:
            if name.endswith('.json'):
                target_file = name[:-5]

        # Explicit check instead of ``assert`` so it is not stripped under
        # ``python -O``; the exception type is kept for existing callers.
        if target_file not in files:
            message = 'allennlp download file error: MDBT Multiwoz data download failed.'
            print(message)
            raise AssertionError(message)

        # Copy the archive to a '.zip' name, then unpack it in place.
        zip_file_path = os.path.join(self.data_dir, target_file + '.zip')
        shutil.copyfile(os.path.join(self.data_dir, target_file), zip_file_path)
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(self.data_dir)
def test_update():
    """Run ``MultiWozMDBT.update`` three times on a sample dialogue.

    The state returned by the first call is pretty-printed, and the elapsed
    time of every call is printed.
    """
    from timeit import default_timer as timer

    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    tracker = MultiWozMDBT()
    tracker.init_session()

    # History entries are [speaker, utterance] pairs, starting with an empty
    # system turn and ending with the utterance passed to update().
    user_utt = 'Friday and Can you book it for me and get a reference number ?'
    tracker.state['history'] = [
        ['sys', ''],
        ['user', 'Could you book a 4 stars hotel for one night, 1 person?'],
        ['sys', 'If you\'d like something cheap, I recommend the Allenbell'],
        ['user', user_utt],
    ]

    for attempt in range(3):
        started = timer()
        new_state = tracker.update(user_utt)
        if attempt == 0:
            # Printing happens inside the timed region, as it always did.
            pprint(new_state)
        finished = timer()
        print(finished - started)


if __name__ == '__main__':
    test_update()
......@@ -44,21 +44,21 @@ setup(
'scikit-learn==0.20.3',
'pytorch-pretrained-bert>=0.6.1',
'transformers>=2.3.0,<3.0.0',
'tensorflow==1.14',
'tensorboard>=1.14.0',
'tensorboardX==1.7',
'tokenizers>=0.8.0',
'allennlp==0.9.0',
'requests',
'simplejson',
'spacy',
'spacy==2.1.9',
'unidecode',
'jieba',
'embeddings',
'quadprog',
'pyyaml',
'fuzzywuzzy',
'python-Levenshtein'
'python-Levenshtein',
'json_lines'
],
extras_require={
'develop': [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment