From 67529531b7bf0084e7e3c6930577085a67402ed9 Mon Sep 17 00:00:00 2001
From: Carrey Wang <cwhongru@cuc.edu.cn>
Date: Mon, 15 Jun 2020 11:06:28 +0800
Subject: [PATCH] Add CrossWoz Web support and some minor bug fix (#19)

* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
---
 .gitignore                             |  1 +
 convlab2/nlu/jointBERT/crosswoz/nlu.py |  3 +-
 convlab2/policy/mle/crosswoz/mle.py    |  2 +-
 convlab2/policy/mle/mle.py             |  2 +-
 convlab2/policy/pg/pg.py               |  1 -
 deploy/dep_config.json                 | 43 ++++++++++++++++++++++++++
 deploy/templates/dialog.html           |  4 ++-
 7 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/.gitignore b/.gitignore
index bdb4d68..2823cc3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,6 +32,7 @@ convlab2/nlu/jointBERT/**/output/
 convlab2/dst/sumbt/multiwoz/output/
 convlab2/nlg/sclstm/**/generated_sens_sys.json
 convlab2/nlg/template/**/generated_sens_sys.json
+convlab2/nlu/jointBERT/crosswoz/**/data
 # test script
 *_test.py
 
diff --git a/convlab2/nlu/jointBERT/crosswoz/nlu.py b/convlab2/nlu/jointBERT/crosswoz/nlu.py
index d9f7e4c..594f728 100755
--- a/convlab2/nlu/jointBERT/crosswoz/nlu.py
+++ b/convlab2/nlu/jointBERT/crosswoz/nlu.py
@@ -17,7 +17,8 @@ class BERTNLU(NLU):
         assert mode == 'usr' or mode == 'sys' or mode == 'all'
         config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs/{}'.format(config_file))
         config = json.load(open(config_file))
-        DEVICE = config['DEVICE']
+        # DEVICE = config['DEVICE']
+        DEVICE = 'cpu' if not torch.cuda.is_available() else config['DEVICE']
         root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
         data_dir = os.path.join(root_dir, config['data_dir'])
         output_dir = os.path.join(root_dir, config['output_dir'])
diff --git a/convlab2/policy/mle/crosswoz/mle.py b/convlab2/policy/mle/crosswoz/mle.py
index abb1249..ef65fd7 100755
--- a/convlab2/policy/mle/crosswoz/mle.py
+++ b/convlab2/policy/mle/crosswoz/mle.py
@@ -41,4 +41,4 @@ class MLE(MLEAbstract):
         if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')):
             archive = zipfile.ZipFile(archive_file, 'r')
             archive.extractall(model_dir)
-        self.load(archive_file, model_file, cfg['load'])
+        self.load_from_pretrained(archive_file, model_file, cfg['load'])
diff --git a/convlab2/policy/mle/mle.py b/convlab2/policy/mle/mle.py
index 21d6436..54e2f02 100755
--- a/convlab2/policy/mle/mle.py
+++ b/convlab2/policy/mle/mle.py
@@ -25,7 +25,7 @@ class MLEAbstract(Policy):
         """
         s_vec = torch.Tensor(self.vector.state_vectorize(state))
         a = self.policy.select_action(s_vec.to(device=DEVICE), False).cpu()
-        action = self.vector.action_devectorize(a.numpy())
+        action = self.vector.action_devectorize(a.detach().numpy())
         state['system_action'] = action
         return action
 
diff --git a/convlab2/policy/pg/pg.py b/convlab2/policy/pg/pg.py
index 82df69a..c605c80 100755
--- a/convlab2/policy/pg/pg.py
+++ b/convlab2/policy/pg/pg.py
@@ -126,7 +126,6 @@ class PG(Policy):
 
                 # backprop
                 surrogate.backward()
-                
                 for p in self.policy.parameters():
                     p.grad[p.grad != p.grad] = 0.0
                 # gradient clipping, for stability
diff --git a/deploy/dep_config.json b/deploy/dep_config.json
index b9c478b..4afd888 100755
--- a/deploy/dep_config.json
+++ b/deploy/dep_config.json
@@ -27,6 +27,19 @@
       "preload": false,
       "enable": true
     },
+    "bert-cro": {
+      "class_path": "convlab2.nlu.jointBERT.crosswoz.nlu.BERTNLU",
+      "data_set": "crosswoz",
+      "ini_params": {
+        "mode": "all",
+        "config_file": "crosswoz_all.json",
+        "model_file": "https://convlab.blob.core.windows.net/convlab-2/bert_crosswoz_all.zip"
+      },
+      "model_name": "bert-cro",
+      "max_core": 1,
+      "preload": false,
+      "enable": true
+    },
     "bert-mul": {
       "class_path": "convlab2.nlu.jointBERT.multiwoz.nlu.BERTNLU",
       "data_set": "multiwoz",
@@ -60,6 +73,15 @@
       "preload": true,
       "enable": true
     },
+    "rule-cro": {
+      "class_path": "convlab2.dst.rule.crosswoz.dst.RuleDST",
+      "data_set": "crosswoz",
+      "ini_params": {},
+      "model_name": "rule-cro",
+      "max_core": 1,
+      "preload": true,
+      "enable": true
+    },
     "trade-mul": {
       "class_path": "convlab2.dst.trade.multiwoz.trade.MultiWOZTRADE",
       "data_set": "multiwoz",
@@ -106,6 +128,15 @@
       "max_core": 1,
       "preload": true,
       "enable": true
+    },
+    "mle-cro": {
+      "class_path": "convlab2.policy.mle.crosswoz.mle.MLE",
+      "data_set": "crosswoz",
+      "ini_params": {},
+      "model_name": "mle-cro",
+      "max_core": 1,
+      "preload": false,
+      "enable": true
     }
   },
   "nlg": {
@@ -143,6 +174,18 @@
       "max_core": 1,
       "preload": true,
       "enable": true
+    },
+    "tmp-auto_manual-cro": {
+      "class_path": "convlab2.nlg.template.crosswoz.nlg.TemplateNLG",
+      "data_set": "crosswoz",
+      "ini_params": {
+        "is_user": false,
+        "mode": "auto_manual"
+      },
+      "model_name": "tmp-auto_manual-cro",
+      "max_core": 1,
+      "preload": true,
+      "enable": true
     }
   }
 }
\ No newline at end of file
diff --git a/deploy/templates/dialog.html b/deploy/templates/dialog.html
index d521809..ac3e15c 100755
--- a/deploy/templates/dialog.html
+++ b/deploy/templates/dialog.html
@@ -279,7 +279,7 @@
         data: {
             dataset: 'MultiWoz',
             dataset_short: 'mul',
-            dataset_list: ['MultiWoz'],
+            dataset_list: ['MultiWoz', 'CrossWoz'],
             nlu: 'BERTNLU',
             nlu_list: [],
             nlu_output: {},
@@ -317,6 +317,8 @@
             dataset: function() {
                 if (this.dataset === 'MultiWoz') {
                     this.dataset_short = 'mul'
+                } else if (this.dataset == 'CrossWoz') {
+                    this.dataset_short = 'cro'
                 } else {
                     this.dataset_short = 'cam'
                 }
-- 
GitLab