Skip to content
Snippets Groups Projects
Unverified Commit 67529531 authored by Carrey Wang's avatar Carrey Wang Committed by GitHub
Browse files

Add CrossWoz Web support and some minor bug fix (#19)


* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

Co-authored-by: default avatarkflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: default avatarCarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
parent 9b7c0cff
Branches
Tags
No related merge requests found
......@@ -32,6 +32,7 @@ convlab2/nlu/jointBERT/**/output/
convlab2/dst/sumbt/multiwoz/output/
convlab2/nlg/sclstm/**/generated_sens_sys.json
convlab2/nlg/template/**/generated_sens_sys.json
convlab2/nlu/jointBERT/crosswoz/**/data
# test script
*_test.py
......
......@@ -17,7 +17,8 @@ class BERTNLU(NLU):
assert mode == 'usr' or mode == 'sys' or mode == 'all'
config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs/{}'.format(config_file))
config = json.load(open(config_file))
DEVICE = config['DEVICE']
# DEVICE = config['DEVICE']
DEVICE = 'cpu' if not torch.cuda.is_available() else config['DEVICE']
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
data_dir = os.path.join(root_dir, config['data_dir'])
output_dir = os.path.join(root_dir, config['output_dir'])
......
......@@ -41,4 +41,4 @@ class MLE(MLEAbstract):
if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')):
archive = zipfile.ZipFile(archive_file, 'r')
archive.extractall(model_dir)
self.load(archive_file, model_file, cfg['load'])
self.load_from_pretrained(archive_file, model_file, cfg['load'])
......@@ -25,7 +25,7 @@ class MLEAbstract(Policy):
"""
s_vec = torch.Tensor(self.vector.state_vectorize(state))
a = self.policy.select_action(s_vec.to(device=DEVICE), False).cpu()
action = self.vector.action_devectorize(a.numpy())
action = self.vector.action_devectorize(a.detach().numpy())
state['system_action'] = action
return action
......
......@@ -126,7 +126,6 @@ class PG(Policy):
# backprop
surrogate.backward()
for p in self.policy.parameters():
p.grad[p.grad != p.grad] = 0.0
# gradient clipping, for stability
......
......@@ -27,6 +27,19 @@
"preload": false,
"enable": true
},
"bert-cro": {
"class_path": "convlab2.nlu.jointBERT.crosswoz.nlu.BERTNLU",
"data_set": "crosswoz",
"ini_params": {
"mode": "all",
"config_file": "crosswoz_all.json",
"model_file": "https://convlab.blob.core.windows.net/convlab-2/bert_crosswoz_all.zip"
},
"model_name": "bert-cro",
"max_core": 1,
"preload": false,
"enable": true
},
"bert-mul": {
"class_path": "convlab2.nlu.jointBERT.multiwoz.nlu.BERTNLU",
"data_set": "multiwoz",
......@@ -60,6 +73,15 @@
"preload": true,
"enable": true
},
"rule-cro": {
"class_path": "convlab2.dst.rule.crosswoz.dst.RuleDST",
"data_set": "crosswoz",
"ini_params": {},
"model_name": "rule-cro",
"max_core": 1,
"preload": true,
"enable": true
},
"trade-mul": {
"class_path": "convlab2.dst.trade.multiwoz.trade.MultiWOZTRADE",
"data_set": "multiwoz",
......@@ -106,6 +128,15 @@
"max_core": 1,
"preload": true,
"enable": true
},
"mle-cro": {
"class_path": "convlab2.policy.mle.crosswoz.mle.MLE",
"data_set": "crosswoz",
"ini_params": {},
"model_name": "mle-cro",
"max_core": 1,
"preload": false,
"enable": true
}
},
"nlg": {
......@@ -143,6 +174,18 @@
"max_core": 1,
"preload": true,
"enable": true
},
"tmp-auto_manual-cro": {
"class_path": "convlab2.nlg.template.crosswoz.nlg.TemplateNLG",
"data_set": "crosswoz",
"ini_params": {
"is_user": false,
"mode": "auto_manual"
},
"model_name": "tmp-auto_manual-cro",
"max_core": 1,
"preload": true,
"enable": true
}
}
}
\ No newline at end of file
......@@ -279,7 +279,7 @@
data: {
dataset: 'MultiWoz',
dataset_short: 'mul',
dataset_list: ['MultiWoz'],
dataset_list: ['MultiWoz', 'CrossWoz'],
nlu: 'BERTNLU',
nlu_list: [],
nlu_output: {},
......@@ -317,6 +317,8 @@
dataset: function() {
if (this.dataset === 'MultiWoz') {
this.dataset_short = 'mul'
} else if (this.dataset == 'CrossWoz') {
this.dataset_short = 'cro'
} else {
this.dataset_short = 'cam'
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment