From 8603ecbc9948212549989efdeb612f8f624f8269 Mon Sep 17 00:00:00 2001 From: Carrey Wang <hrwang@se.cuhk.edu.hk> Date: Wed, 7 Oct 2020 20:30:45 +0800 Subject: [PATCH] Add DQN Test and Change file structure (#146) * Initial commit * first commit * add build * add build * add build * add recommend * add crosswoz config in deploy * add crosswoz at html * debug chinese vision * fix system bug according to convlab2 * master change * modify .gitignore * delete svm_camrest_usr.pickle * Update server.py * add test for DQN * change server Co-authored-by: Carrey Wang <cwhongru@cuc.edu.cn> Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local> Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local> Co-authored-by: zimozhou <47972969+zimozhou@users.noreply.github.com> Co-authored-by: MR. WANG <hrwang@kfsrv03.se.cuhk.edu.hk> --- .gitignore | 2 +- convlab2/policy/dqn/dqn.py | 17 ++++- convlab2/policy/dqn/multiwoz/__init__.py | 1 + convlab2/policy/dqn/multiwoz/config.json | 20 +++++ convlab2/policy/dqn/multiwoz/dqn_policy.py | 14 ++++ ...t_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py | 73 +++++++++++++++++++ 6 files changed, 124 insertions(+), 3 deletions(-) create mode 100644 convlab2/policy/dqn/multiwoz/__init__.py create mode 100755 convlab2/policy/dqn/multiwoz/config.json create mode 100644 convlab2/policy/dqn/multiwoz/dqn_policy.py create mode 100644 tests/test_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py diff --git a/.gitignore b/.gitignore index cd35620..a2820f1 100644 --- a/.gitignore +++ b/.gitignore @@ -84,4 +84,4 @@ deploy/templates/dialog_eg.html test.py *.egg-info -pre-trained-models/ \ No newline at end of file +pre-trained-models/ diff --git a/convlab2/policy/dqn/dqn.py b/convlab2/policy/dqn/dqn.py index fed3e04..39c04ac 100644 --- a/convlab2/policy/dqn/dqn.py +++ b/convlab2/policy/dqn/dqn.py @@ -148,12 +148,25 @@ class DQN(Policy): def load(self, filename): dqn_mdl_candidates = [ - filename + '.dqn.mdl', - 
os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.dqn.mdl'), + filename + '_dqn.pol.mdl', + os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_dqn.pol.mdl'), ] + for dqn_mdl in dqn_mdl_candidates: if os.path.exists(dqn_mdl): self.net.load_state_dict(torch.load(dqn_mdl, map_location=DEVICE)) self.target_net.load_state_dict(torch.load(dqn_mdl, map_location=DEVICE)) logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(dqn_mdl)) break + + @classmethod + def from_pretrained(cls, + archive_file="", + model_file="https://convlab.blob.core.windows.net/convlab-2/dqn_policy_multiwoz.zip", + is_train=False, + dataset='Multiwoz'): + with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f: + cfg = json.load(f) + model = cls(is_train=is_train, dataset=dataset) + model.load(cfg['load']) + return model \ No newline at end of file diff --git a/convlab2/policy/dqn/multiwoz/__init__.py b/convlab2/policy/dqn/multiwoz/__init__.py new file mode 100644 index 0000000..3694e4b --- /dev/null +++ b/convlab2/policy/dqn/multiwoz/__init__.py @@ -0,0 +1 @@ +from convlab2.policy.dqn.multiwoz.dqn_policy import DQNPolicy \ No newline at end of file diff --git a/convlab2/policy/dqn/multiwoz/config.json b/convlab2/policy/dqn/multiwoz/config.json new file mode 100755 index 0000000..1c3fd41 --- /dev/null +++ b/convlab2/policy/dqn/multiwoz/config.json @@ -0,0 +1,20 @@ +{ + "batch_size": 16, + "gamma": 0.99, + "lr": 0.001, + "save_dir": "save", + "log_dir": "log", + "save_per_epoch": 5, + "training_iter": 10, + "training_batch_iter": 3, + "h_dim": 100, + "hv_dim": 50, + "memory_size": 5000, + "epsilon_spec": { + "start": 0.1, + "end": 0.0, + "end_epoch": 200 + }, + "load": "save/best", + "vocab_size": 500 +} diff --git a/convlab2/policy/dqn/multiwoz/dqn_policy.py b/convlab2/policy/dqn/multiwoz/dqn_policy.py new file mode 100644 index 0000000..2ef4312 --- /dev/null +++ 
b/convlab2/policy/dqn/multiwoz/dqn_policy.py @@ -0,0 +1,14 @@ +from convlab2.policy.dqn import DQN +import os +import json + +class DQNPolicy(DQN): + def __init__(self, + is_train=False, + dataset="Multiwoz", + archive_file="", + model_file="https://convlab.blob.core.windows.net/convlab-2/dqn_policy_multiwoz.zip"): + super().__init__(is_train=is_train, dataset=dataset) + with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f: + cfg = json.load(f) + self.load(cfg['load']) \ No newline at end of file diff --git a/tests/test_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py b/tests/test_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py new file mode 100644 index 0000000..cfb8c0d --- /dev/null +++ b/tests/test_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py @@ -0,0 +1,73 @@ +# available NLU models +# from convlab2.nlu.svm.multiwoz import SVMNLU +from convlab2.policy.dqn.multiwoz.dqn_policy import DQNPolicy +from convlab2.nlu.jointBERT.multiwoz import BERTNLU +# from convlab2.nlu.milu.multiwoz import MILU +# available DST models +from convlab2.dst.rule.multiwoz import RuleDST +# from convlab2.dst.mdbt.multiwoz import MDBT +# from convlab2.dst.sumbt.multiwoz import SUMBT +# from convlab2.dst.trade.multiwoz import TRADE +# from convlab2.dst.comer.multiwoz import COMER +# available Policy models +from convlab2.policy.rule.multiwoz import RulePolicy +# from convlab2.policy.ppo.multiwoz import PPOPolicy +# from convlab2.policy.pg.multiwoz import PGPolicy +# from convlab2.policy.mle.multiwoz import MLEPolicy +# from convlab2.policy.vhus.multiwoz import UserPolicyVHUS +# from convlab2.policy.mdrg.multiwoz import MDRGWordPolicy +# from convlab2.policy.hdsa.multiwoz import HDSA +# from convlab2.policy.larl.multiwoz import LaRL +# available NLG models +from convlab2.nlg.template.multiwoz import TemplateNLG +from convlab2.nlg.sclstm.multiwoz import SCLSTM +# available E2E models +# from convlab2.e2e.sequicity.multiwoz import Sequicity +# from 
convlab2.e2e.damd.multiwoz import Damd +from convlab2.dialog_agent import PipelineAgent, BiSession +from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator +from convlab2.util.analysis_tool.analyzer import Analyzer +from pprint import pprint +import random +import numpy as np +import torch + + +def set_seed(r_seed): + random.seed(r_seed) + np.random.seed(r_seed) + torch.manual_seed(r_seed) + + +def test_end2end(): + # go to README.md of each model for more information + # BERT nlu + sys_nlu = BERTNLU() + # simple rule DST + sys_dst = RuleDST() + # DQN policy + sys_policy = DQNPolicy() + # template NLG + sys_nlg = TemplateNLG(is_user=False) + # assemble + sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') + + # BERT nlu trained on sys utterance + user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', + model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip') + # no DST for the user agent + user_dst = None + # rule policy + user_policy = RulePolicy(character='usr') + # template NLG + user_nlg = TemplateNLG(is_user=True) + # assemble + user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') + + analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') + + set_seed(20200202) + analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='BERTNLU-RuleDST-DQNPolicy-TemplateNLG', total_dialog=1000) + +if __name__ == '__main__': + test_end2end() -- GitLab