From 8603ecbc9948212549989efdeb612f8f624f8269 Mon Sep 17 00:00:00 2001
From: Carrey Wang <hrwang@se.cuhk.edu.hk>
Date: Wed, 7 Oct 2020 20:30:45 +0800
Subject: [PATCH] Add DQN Test and Change file structure  (#146)

* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

* Update server.py

* add test for DQN

* change server

Co-authored-by: Carrey Wang <cwhongru@cuc.edu.cn>
Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
Co-authored-by: zimozhou <47972969+zimozhou@users.noreply.github.com>
Co-authored-by: MR. WANG <hrwang@kfsrv03.se.cuhk.edu.hk>
---
 .gitignore                                    |  2 +-
 convlab2/policy/dqn/dqn.py                    | 17 ++++-
 convlab2/policy/dqn/multiwoz/__init__.py      |  1 +
 convlab2/policy/dqn/multiwoz/config.json      | 20 +++++
 convlab2/policy/dqn/multiwoz/dqn_policy.py    | 14 ++++
 ...t_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py | 73 +++++++++++++++++++
 6 files changed, 124 insertions(+), 3 deletions(-)
 create mode 100644 convlab2/policy/dqn/multiwoz/__init__.py
 create mode 100755 convlab2/policy/dqn/multiwoz/config.json
 create mode 100644 convlab2/policy/dqn/multiwoz/dqn_policy.py
 create mode 100644 tests/test_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py

diff --git a/.gitignore b/.gitignore
index cd35620..a2820f1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -84,4 +84,4 @@ deploy/templates/dialog_eg.html
 test.py
 
 *.egg-info
-pre-trained-models/
\ No newline at end of file
+pre-trained-models/
diff --git a/convlab2/policy/dqn/dqn.py b/convlab2/policy/dqn/dqn.py
index fed3e04..39c04ac 100644
--- a/convlab2/policy/dqn/dqn.py
+++ b/convlab2/policy/dqn/dqn.py
@@ -148,12 +148,25 @@ class DQN(Policy):
     
     def load(self, filename):
         dqn_mdl_candidates = [
-            filename + '.dqn.mdl',
-            os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.dqn.mdl'),
+            filename + '_dqn.pol.mdl',
+            os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_dqn.pol.mdl'),
         ]
+
         for dqn_mdl in dqn_mdl_candidates:
             if os.path.exists(dqn_mdl):
                 self.net.load_state_dict(torch.load(dqn_mdl, map_location=DEVICE))
                 self.target_net.load_state_dict(torch.load(dqn_mdl, map_location=DEVICE))
                 logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(dqn_mdl))
                 break
+
+    @classmethod
+    def from_pretrained(cls,
+                        archive_file="",
+                        model_file="https://convlab.blob.core.windows.net/convlab-2/dqn_policy_multiwoz.zip",
+                        is_train=False,
+                        dataset='Multiwoz'):
+        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
+            cfg = json.load(f)
+        model = cls(is_train=is_train, dataset=dataset)
+        model.load(cfg['load'])
+        return model
\ No newline at end of file
diff --git a/convlab2/policy/dqn/multiwoz/__init__.py b/convlab2/policy/dqn/multiwoz/__init__.py
new file mode 100644
index 0000000..3694e4b
--- /dev/null
+++ b/convlab2/policy/dqn/multiwoz/__init__.py
@@ -0,0 +1 @@
+from convlab2.policy.dqn.multiwoz.dqn_policy import DQNPolicy
\ No newline at end of file
diff --git a/convlab2/policy/dqn/multiwoz/config.json b/convlab2/policy/dqn/multiwoz/config.json
new file mode 100755
index 0000000..1c3fd41
--- /dev/null
+++ b/convlab2/policy/dqn/multiwoz/config.json
@@ -0,0 +1,20 @@
+{
+	"batch_size": 16,
+	"gamma": 0.99,
+	"lr": 0.001,
+	"save_dir": "save",
+	"log_dir": "log",
+	"save_per_epoch": 5,
+	"training_iter": 10,
+	"training_batch_iter": 3,
+	"h_dim": 100,
+	"hv_dim": 50,
+	"memory_size": 5000,
+	"epsilon_spec": {
+		"start": 0.1,
+		"end": 0.0,
+		"end_epoch": 200
+	},
+	"load": "save/best",
+	"vocab_size": 500
+}
diff --git a/convlab2/policy/dqn/multiwoz/dqn_policy.py b/convlab2/policy/dqn/multiwoz/dqn_policy.py
new file mode 100644
index 0000000..2ef4312
--- /dev/null
+++ b/convlab2/policy/dqn/multiwoz/dqn_policy.py
@@ -0,0 +1,14 @@
+from convlab2.policy.dqn import DQN
+import os
+import json
+
+class DQNPolicy(DQN):
+    def __init__(self,
+                is_train=False,
+                dataset="Multiwoz",
+                archive_file="",
+                model_file="https://convlab.blob.core.windows.net/convlab-2/dqn_policy_multiwoz.zip"):
+        super().__init__(is_train=is_train, dataset=dataset)
+        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
+            cfg = json.load(f)
+        self.load(cfg['load'])
\ No newline at end of file
diff --git a/tests/test_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py b/tests/test_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py
new file mode 100644
index 0000000..cfb8c0d
--- /dev/null
+++ b/tests/test_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py
@@ -0,0 +1,73 @@
+# available NLU models
+# from convlab2.nlu.svm.multiwoz import SVMNLU
+from convlab2.policy.dqn.multiwoz.dqn_policy import DQNPolicy
+from convlab2.nlu.jointBERT.multiwoz import BERTNLU
+# from convlab2.nlu.milu.multiwoz import MILU
+# available DST models
+from convlab2.dst.rule.multiwoz import RuleDST
+# from convlab2.dst.mdbt.multiwoz import MDBT
+# from convlab2.dst.sumbt.multiwoz import SUMBT
+# from convlab2.dst.trade.multiwoz import TRADE
+# from convlab2.dst.comer.multiwoz import COMER
+# available Policy models
+from convlab2.policy.rule.multiwoz import RulePolicy
+# from convlab2.policy.ppo.multiwoz import PPOPolicy
+# from convlab2.policy.pg.multiwoz import PGPolicy
+# from convlab2.policy.mle.multiwoz import MLEPolicy
+# from convlab2.policy.vhus.multiwoz import UserPolicyVHUS
+# from convlab2.policy.mdrg.multiwoz import MDRGWordPolicy
+# from convlab2.policy.hdsa.multiwoz import HDSA
+# from convlab2.policy.larl.multiwoz import LaRL
+# available NLG models
+from convlab2.nlg.template.multiwoz import TemplateNLG
+from convlab2.nlg.sclstm.multiwoz import SCLSTM
+# available E2E models
+# from convlab2.e2e.sequicity.multiwoz import Sequicity
+# from convlab2.e2e.damd.multiwoz import Damd
+from convlab2.dialog_agent import PipelineAgent, BiSession
+from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
+from convlab2.util.analysis_tool.analyzer import Analyzer
+from pprint import pprint
+import random
+import numpy as np
+import torch
+
+
+def set_seed(r_seed):
+    random.seed(r_seed)
+    np.random.seed(r_seed)
+    torch.manual_seed(r_seed)
+
+
+def test_end2end():
+    # see the README.md of each model for more information
+    # BERT nlu
+    sys_nlu = BERTNLU()
+    # simple rule DST
+    sys_dst = RuleDST()
+    # DQN policy
+    sys_policy = DQNPolicy()
+    # template NLG
+    sys_nlg = TemplateNLG(is_user=False)
+    # assemble
+    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
+
+    # BERT NLU trained on system utterances
+    user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
+                       model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
+    # the user agent does not use a DST
+    user_dst = None
+    # rule policy
+    user_policy = RulePolicy(character='usr')
+    # template NLG
+    user_nlg = TemplateNLG(is_user=True)
+    # assemble
+    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
+
+    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
+
+    set_seed(20200202)
+    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='BERTNLU-RuleDST-DQNPolicy-TemplateNLG', total_dialog=1000)
+
+if __name__ == '__main__':
+    test_end2end()
-- 
GitLab