From 67529531b7bf0084e7e3c6930577085a67402ed9 Mon Sep 17 00:00:00 2001 From: Carrey Wang <cwhongru@cuc.edu.cn> Date: Mon, 15 Jun 2020 11:06:28 +0800 Subject: [PATCH] Add CrossWoz Web support and some minor bug fix (#19) * Initial commit * first commit * add build * add build * add build * add recommend * add crosswoz config in deploy * add crosswoz at html * debug chinese vision * fix system bug according to convlab2 * master change * modify .gitignore * delete svm_camrest_usr.pickle Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local> Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local> --- .gitignore | 1 + convlab2/nlu/jointBERT/crosswoz/nlu.py | 3 +- convlab2/policy/mle/crosswoz/mle.py | 2 +- convlab2/policy/mle/mle.py | 2 +- convlab2/policy/pg/pg.py | 1 - deploy/dep_config.json | 43 ++++++++++++++++++++++++++ deploy/templates/dialog.html | 4 ++- 7 files changed, 51 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index bdb4d68..2823cc3 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ convlab2/nlu/jointBERT/**/output/ convlab2/dst/sumbt/multiwoz/output/ convlab2/nlg/sclstm/**/generated_sens_sys.json convlab2/nlg/template/**/generated_sens_sys.json +convlab2/nlu/jointBERT/crosswoz/**/data # test script *_test.py diff --git a/convlab2/nlu/jointBERT/crosswoz/nlu.py b/convlab2/nlu/jointBERT/crosswoz/nlu.py index d9f7e4c..594f728 100755 --- a/convlab2/nlu/jointBERT/crosswoz/nlu.py +++ b/convlab2/nlu/jointBERT/crosswoz/nlu.py @@ -17,7 +17,8 @@ class BERTNLU(NLU): assert mode == 'usr' or mode == 'sys' or mode == 'all' config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs/{}'.format(config_file)) config = json.load(open(config_file)) - DEVICE = config['DEVICE'] + # DEVICE = config['DEVICE'] + DEVICE = 'cpu' if not torch.cuda.is_available() else config['DEVICE'] root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) data_dir = os.path.join(root_dir, config['data_dir']) output_dir = os.path.join(root_dir, config['output_dir']) diff --git a/convlab2/policy/mle/crosswoz/mle.py b/convlab2/policy/mle/crosswoz/mle.py index abb1249..ef65fd7 100755 --- a/convlab2/policy/mle/crosswoz/mle.py +++ b/convlab2/policy/mle/crosswoz/mle.py @@ -41,4 +41,4 @@ class MLE(MLEAbstract): if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')): archive = zipfile.ZipFile(archive_file, 'r') archive.extractall(model_dir) - self.load(archive_file, model_file, cfg['load']) + self.load_from_pretrained(archive_file, model_file, cfg['load']) diff --git a/convlab2/policy/mle/mle.py b/convlab2/policy/mle/mle.py index 21d6436..54e2f02 100755 --- a/convlab2/policy/mle/mle.py +++ b/convlab2/policy/mle/mle.py @@ -25,7 +25,7 @@ class MLEAbstract(Policy): """ s_vec = torch.Tensor(self.vector.state_vectorize(state)) a = self.policy.select_action(s_vec.to(device=DEVICE), False).cpu() - action = self.vector.action_devectorize(a.numpy()) + action = self.vector.action_devectorize(a.detach().numpy()) state['system_action'] = action return action diff --git a/convlab2/policy/pg/pg.py b/convlab2/policy/pg/pg.py index 82df69a..c605c80 100755 --- a/convlab2/policy/pg/pg.py +++ b/convlab2/policy/pg/pg.py @@ -126,7 +126,6 @@ class PG(Policy): # backprop surrogate.backward() - for p in self.policy.parameters(): p.grad[p.grad != p.grad] = 0.0 # gradient clipping, for stability diff --git a/deploy/dep_config.json b/deploy/dep_config.json index b9c478b..4afd888 100755 --- a/deploy/dep_config.json +++ b/deploy/dep_config.json @@ -27,6 +27,19 @@ "preload": false, "enable": true }, + "bert-cro": { + "class_path": "convlab2.nlu.jointBERT.crosswoz.nlu.BERTNLU", + "data_set": "crosswoz", + "ini_params": { + "mode": "all", + "config_file": "crosswoz_all.json", + "model_file": "https://convlab.blob.core.windows.net/convlab-2/bert_crosswoz_all.zip" + }, + "model_name": "bert-cro", + "max_core": 1, + "preload": false, + "enable": true + }, "bert-mul": { "class_path": "convlab2.nlu.jointBERT.multiwoz.nlu.BERTNLU", "data_set": "multiwoz", @@ -60,6 +73,15 @@ "preload": true, "enable": true }, + "rule-cro": { + "class_path": "convlab2.dst.rule.crosswoz.dst.RuleDST", + "data_set": "crosswoz", + "ini_params": {}, + "model_name": "rule-cro", + "max_core": 1, + "preload": true, + "enable": true + }, "trade-mul": { "class_path": "convlab2.dst.trade.multiwoz.trade.MultiWOZTRADE", "data_set": "multiwoz", @@ -106,6 +128,15 @@ "max_core": 1, "preload": true, "enable": true + }, + "mle-cro": { + "class_path": "convlab2.policy.mle.crosswoz.mle.MLE", + "data_set": "crosswoz", + "ini_params": {}, + "model_name": "mle-cro", + "max_core": 1, + "preload": false, + "enable": true } }, "nlg": { @@ -143,6 +174,18 @@ "max_core": 1, "preload": true, "enable": true + }, + "tmp-auto_manual-cro": { + "class_path": "convlab2.nlg.template.crosswoz.nlg.TemplateNLG", + "data_set": "crosswoz", + "ini_params": { + "is_user": false, + "mode": "auto_manual" + }, + "model_name": "tmp-auto_manual-cro", + "max_core": 1, + "preload": true, + "enable": true } } } \ No newline at end of file diff --git a/deploy/templates/dialog.html b/deploy/templates/dialog.html index d521809..ac3e15c 100755 --- a/deploy/templates/dialog.html +++ b/deploy/templates/dialog.html @@ -279,7 +279,7 @@ data: { dataset: 'MultiWoz', dataset_short: 'mul', - dataset_list: ['MultiWoz'], + dataset_list: ['MultiWoz', 'CrossWoz'], nlu: 'BERTNLU', nlu_list: [], nlu_output: {}, @@ -317,6 +317,8 @@ dataset: function() { if (this.dataset === 'MultiWoz') { this.dataset_short = 'mul' + } else if (this.dataset == 'CrossWoz') { + this.dataset_short = 'cro' } else { this.dataset_short = 'cam' } -- GitLab