Skip to content
Snippets Groups Projects
Unverified Commit 67529531 authored by Carrey Wang's avatar Carrey Wang Committed by GitHub
Browse files

Add CrossWoz Web support and some minor bug fix (#19)


* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

Co-authored-by: default avatarkflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: default avatarCarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
parent 9b7c0cff
No related branches found
No related tags found
No related merge requests found
...@@ -32,6 +32,7 @@ convlab2/nlu/jointBERT/**/output/ ...@@ -32,6 +32,7 @@ convlab2/nlu/jointBERT/**/output/
convlab2/dst/sumbt/multiwoz/output/ convlab2/dst/sumbt/multiwoz/output/
convlab2/nlg/sclstm/**/generated_sens_sys.json convlab2/nlg/sclstm/**/generated_sens_sys.json
convlab2/nlg/template/**/generated_sens_sys.json convlab2/nlg/template/**/generated_sens_sys.json
convlab2/nlu/jointBERT/crosswoz/**/data
# test script # test script
*_test.py *_test.py
......
...@@ -17,7 +17,8 @@ class BERTNLU(NLU): ...@@ -17,7 +17,8 @@ class BERTNLU(NLU):
assert mode == 'usr' or mode == 'sys' or mode == 'all' assert mode == 'usr' or mode == 'sys' or mode == 'all'
config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs/{}'.format(config_file)) config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs/{}'.format(config_file))
config = json.load(open(config_file)) config = json.load(open(config_file))
DEVICE = config['DEVICE'] # DEVICE = config['DEVICE']
DEVICE = 'cpu' if not torch.cuda.is_available() else config['DEVICE']
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
data_dir = os.path.join(root_dir, config['data_dir']) data_dir = os.path.join(root_dir, config['data_dir'])
output_dir = os.path.join(root_dir, config['output_dir']) output_dir = os.path.join(root_dir, config['output_dir'])
......
...@@ -41,4 +41,4 @@ class MLE(MLEAbstract): ...@@ -41,4 +41,4 @@ class MLE(MLEAbstract):
if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')): if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')):
archive = zipfile.ZipFile(archive_file, 'r') archive = zipfile.ZipFile(archive_file, 'r')
archive.extractall(model_dir) archive.extractall(model_dir)
self.load(archive_file, model_file, cfg['load']) self.load_from_pretrained(archive_file, model_file, cfg['load'])
...@@ -25,7 +25,7 @@ class MLEAbstract(Policy): ...@@ -25,7 +25,7 @@ class MLEAbstract(Policy):
""" """
s_vec = torch.Tensor(self.vector.state_vectorize(state)) s_vec = torch.Tensor(self.vector.state_vectorize(state))
a = self.policy.select_action(s_vec.to(device=DEVICE), False).cpu() a = self.policy.select_action(s_vec.to(device=DEVICE), False).cpu()
action = self.vector.action_devectorize(a.numpy()) action = self.vector.action_devectorize(a.detach().numpy())
state['system_action'] = action state['system_action'] = action
return action return action
......
...@@ -126,7 +126,6 @@ class PG(Policy): ...@@ -126,7 +126,6 @@ class PG(Policy):
# backprop # backprop
surrogate.backward() surrogate.backward()
for p in self.policy.parameters(): for p in self.policy.parameters():
p.grad[p.grad != p.grad] = 0.0 p.grad[p.grad != p.grad] = 0.0
# gradient clipping, for stability # gradient clipping, for stability
......
...@@ -27,6 +27,19 @@ ...@@ -27,6 +27,19 @@
"preload": false, "preload": false,
"enable": true "enable": true
}, },
"bert-cro": {
"class_path": "convlab2.nlu.jointBERT.crosswoz.nlu.BERTNLU",
"data_set": "crosswoz",
"ini_params": {
"mode": "all",
"config_file": "crosswoz_all.json",
"model_file": "https://convlab.blob.core.windows.net/convlab-2/bert_crosswoz_all.zip"
},
"model_name": "bert-cro",
"max_core": 1,
"preload": false,
"enable": true
},
"bert-mul": { "bert-mul": {
"class_path": "convlab2.nlu.jointBERT.multiwoz.nlu.BERTNLU", "class_path": "convlab2.nlu.jointBERT.multiwoz.nlu.BERTNLU",
"data_set": "multiwoz", "data_set": "multiwoz",
...@@ -60,6 +73,15 @@ ...@@ -60,6 +73,15 @@
"preload": true, "preload": true,
"enable": true "enable": true
}, },
"rule-cro": {
"class_path": "convlab2.dst.rule.crosswoz.dst.RuleDST",
"data_set": "crosswoz",
"ini_params": {},
"model_name": "rule-cro",
"max_core": 1,
"preload": true,
"enable": true
},
"trade-mul": { "trade-mul": {
"class_path": "convlab2.dst.trade.multiwoz.trade.MultiWOZTRADE", "class_path": "convlab2.dst.trade.multiwoz.trade.MultiWOZTRADE",
"data_set": "multiwoz", "data_set": "multiwoz",
...@@ -106,6 +128,15 @@ ...@@ -106,6 +128,15 @@
"max_core": 1, "max_core": 1,
"preload": true, "preload": true,
"enable": true "enable": true
},
"mle-cro": {
"class_path": "convlab2.policy.mle.crosswoz.mle.MLE",
"data_set": "crosswoz",
"ini_params": {},
"model_name": "mle-cro",
"max_core": 1,
"preload": false,
"enable": true
} }
}, },
"nlg": { "nlg": {
...@@ -143,6 +174,18 @@ ...@@ -143,6 +174,18 @@
"max_core": 1, "max_core": 1,
"preload": true, "preload": true,
"enable": true "enable": true
},
"tmp-auto_manual-cro": {
"class_path": "convlab2.nlg.template.crosswoz.nlg.TemplateNLG",
"data_set": "crosswoz",
"ini_params": {
"is_user": false,
"mode": "auto_manual"
},
"model_name": "tmp-auto_manual-cro",
"max_core": 1,
"preload": true,
"enable": true
} }
} }
} }
\ No newline at end of file
...@@ -279,7 +279,7 @@ ...@@ -279,7 +279,7 @@
data: { data: {
dataset: 'MultiWoz', dataset: 'MultiWoz',
dataset_short: 'mul', dataset_short: 'mul',
dataset_list: ['MultiWoz'], dataset_list: ['MultiWoz', 'CrossWoz'],
nlu: 'BERTNLU', nlu: 'BERTNLU',
nlu_list: [], nlu_list: [],
nlu_output: {}, nlu_output: {},
...@@ -317,6 +317,8 @@ ...@@ -317,6 +317,8 @@
dataset: function() { dataset: function() {
if (this.dataset === 'MultiWoz') { if (this.dataset === 'MultiWoz') {
this.dataset_short = 'mul' this.dataset_short = 'mul'
} else if (this.dataset == 'CrossWoz') {
this.dataset_short = 'cro'
} else { } else {
this.dataset_short = 'cam' this.dataset_short = 'cam'
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment