diff --git a/.gitignore b/.gitignore
index cd356208b42f43779d592bd7dd043a4bcbf306b6..a2820f1eab2c138e99371195495312a75e1bcb0e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -84,4 +84,4 @@ deploy/templates/dialog_eg.html
 test.py
 
 *.egg-info
-pre-trained-models/
\ No newline at end of file
+pre-trained-models/
diff --git a/convlab2/policy/dqn/dqn.py b/convlab2/policy/dqn/dqn.py
index fed3e04bc8f9ce928c92b3f68cacc3368d1d2da6..39c04ac13c6e3e1cee03b591c058b86e15e0de47 100644
--- a/convlab2/policy/dqn/dqn.py
+++ b/convlab2/policy/dqn/dqn.py
@@ -148,12 +148,42 @@ class DQN(Policy):
 
     def load(self, filename):
+        """Load a checkpoint into ``self.net`` and ``self.target_net``.
+
+        Two candidate paths are tried in order: ``<filename>_dqn.pol.mdl`` as
+        given, then the same path relative to this module's directory. The
+        first one that exists is used; a warning is logged when neither does.
+        """
         dqn_mdl_candidates = [
-            filename + '.dqn.mdl',
-            os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.dqn.mdl'),
+            filename + '_dqn.pol.mdl',
+            os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_dqn.pol.mdl'),
         ]
+
         for dqn_mdl in dqn_mdl_candidates:
             if os.path.exists(dqn_mdl):
-                self.net.load_state_dict(torch.load(dqn_mdl, map_location=DEVICE))
-                self.target_net.load_state_dict(torch.load(dqn_mdl, map_location=DEVICE))
+                # Read the checkpoint once and copy it into both networks.
+                state_dict = torch.load(dqn_mdl, map_location=DEVICE)
+                self.net.load_state_dict(state_dict)
+                self.target_net.load_state_dict(state_dict)
                 logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(dqn_mdl))
                 break
+        else:
+            # No candidate exists: report it rather than silently keeping the current weights.
+            logging.warning('<<dialog policy>> no checkpoint found, tried: {}'.format(dqn_mdl_candidates))
+
+    @classmethod
+    def from_pretrained(cls,
+                        archive_file="",
+                        model_file="https://convlab.blob.core.windows.net/convlab-2/dqn_policy_multiwoz.zip",
+                        is_train=False,
+                        dataset='Multiwoz'):
+        """Build a policy and load the checkpoint named in ``config.json``.
+
+        ``archive_file`` and ``model_file`` are accepted but not used here:
+        nothing is downloaded, and the checkpoint prefix comes from the
+        ``load`` entry of the ``config.json`` next to this module.
+        """
+        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
+            cfg = json.load(f)
+        model = cls(is_train=is_train, dataset=dataset)
+        model.load(cfg['load'])
+        return model
diff --git a/convlab2/policy/dqn/multiwoz/__init__.py b/convlab2/policy/dqn/multiwoz/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3694e4bc165bbbab401e6b54388f4fefedcba2fa
--- /dev/null
+++ b/convlab2/policy/dqn/multiwoz/__init__.py
@@ -0,0 +1 @@
+from convlab2.policy.dqn.multiwoz.dqn_policy import DQNPolicy
diff --git a/convlab2/policy/dqn/multiwoz/config.json b/convlab2/policy/dqn/multiwoz/config.json
new file mode 100755
index 0000000000000000000000000000000000000000..1c3fd41fa2ee68ae36c68e8e769c298edd53177e
--- /dev/null
+++ b/convlab2/policy/dqn/multiwoz/config.json
@@ -0,0 +1,20 @@
+{
+    "batch_size": 16,
+    "gamma": 0.99,
+    "lr": 0.001,
+    "save_dir": "save",
+    "log_dir": "log",
+    "save_per_epoch": 5,
+    "training_iter": 10,
+    "training_batch_iter": 3,
+    "h_dim": 100,
+    "hv_dim": 50,
+    "memory_size": 5000,
+    "epsilon_spec": {
+        "start": 0.1,
+        "end": 0.0,
+        "end_epoch": 200
+    },
+    "load": "save/best",
+    "vocab_size": 500
+}
diff --git a/convlab2/policy/dqn/multiwoz/dqn_policy.py b/convlab2/policy/dqn/multiwoz/dqn_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ef43121df9785f106caf65a9dbff02beb291383
--- /dev/null
+++ b/convlab2/policy/dqn/multiwoz/dqn_policy.py
@@ -0,0 +1,22 @@
+from convlab2.policy.dqn import DQN
+import os
+import json
+
+
+class DQNPolicy(DQN):
+    """DQN policy that loads its checkpoint as soon as it is constructed.
+
+    The checkpoint prefix is read from the ``load`` entry of the
+    ``config.json`` next to this file. ``archive_file`` and ``model_file``
+    are accepted but not used here; nothing is downloaded.
+    """
+
+    def __init__(self,
+                 is_train=False,
+                 dataset="Multiwoz",
+                 archive_file="",
+                 model_file="https://convlab.blob.core.windows.net/convlab-2/dqn_policy_multiwoz.zip"):
+        super().__init__(is_train=is_train, dataset=dataset)
+        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
+            cfg = json.load(f)
+        self.load(cfg['load'])
diff --git a/tests/test_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py b/tests/test_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfb8c0db0905f20c03525200ce0dfcd9e5818b93
--- /dev/null
+++ b/tests/test_BERTNLU-RuleDST-DQNPolicy-TemplateNLG.py
@@ -0,0 +1,76 @@
+# available NLU models
+# from convlab2.nlu.svm.multiwoz import SVMNLU
+from convlab2.policy.dqn.multiwoz.dqn_policy import DQNPolicy
+from convlab2.nlu.jointBERT.multiwoz import BERTNLU
+# from convlab2.nlu.milu.multiwoz import MILU
+# available DST models
+from convlab2.dst.rule.multiwoz import RuleDST
+# from convlab2.dst.mdbt.multiwoz import MDBT
+# from convlab2.dst.sumbt.multiwoz import SUMBT
+# from convlab2.dst.trade.multiwoz import TRADE
+# from convlab2.dst.comer.multiwoz import COMER
+# available Policy models
+from convlab2.policy.rule.multiwoz import RulePolicy
+# from convlab2.policy.ppo.multiwoz import PPOPolicy
+# from convlab2.policy.pg.multiwoz import PGPolicy
+# from convlab2.policy.mle.multiwoz import MLEPolicy
+# from convlab2.policy.vhus.multiwoz import UserPolicyVHUS
+# from convlab2.policy.mdrg.multiwoz import MDRGWordPolicy
+# from convlab2.policy.hdsa.multiwoz import HDSA
+# from convlab2.policy.larl.multiwoz import LaRL
+# available NLG models
+from convlab2.nlg.template.multiwoz import TemplateNLG
+from convlab2.nlg.sclstm.multiwoz import SCLSTM
+# available E2E models
+# from convlab2.e2e.sequicity.multiwoz import Sequicity
+# from convlab2.e2e.damd.multiwoz import Damd
+from convlab2.dialog_agent import PipelineAgent, BiSession
+from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
+from convlab2.util.analysis_tool.analyzer import Analyzer
+from pprint import pprint
+import random
+import numpy as np
+import torch
+
+
+def set_seed(r_seed):
+    """Seed the ``random``, ``numpy.random`` and ``torch`` generators with ``r_seed``."""
+    random.seed(r_seed)
+    np.random.seed(r_seed)
+    torch.manual_seed(r_seed)
+
+
+def test_end2end():
+    """Assemble the system and user agents and run the analyzer over 1000 dialogs."""
+    # go to README.md of each model for more information
+    # BERT nlu
+    sys_nlu = BERTNLU()
+    # simple rule DST
+    sys_dst = RuleDST()
+    # DQN policy
+    sys_policy = DQNPolicy()
+    # template NLG
+    sys_nlg = TemplateNLG(is_user=False)
+    # assemble
+    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
+
+    # BERT nlu trained on sys utterance
+    user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
+                       model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
+    # not use dst
+    user_dst = None
+    # rule policy
+    user_policy = RulePolicy(character='usr')
+    # template NLG
+    user_nlg = TemplateNLG(is_user=True)
+    # assemble
+    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
+
+    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
+
+    set_seed(20200202)
+    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='BERTNLU-RuleDST-DQNPolicy-TemplateNLG', total_dialog=1000)
+
+
+if __name__ == '__main__':
+    test_end2end()