Skip to content
Snippets Groups Projects
Unverified Commit 8603ecbc authored by Carrey Wang's avatar Carrey Wang Committed by GitHub
Browse files

Add DQN Test and Change file structure (#146)


* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

* Update server.py

* add test for DQN

* change server

Co-authored-by: Carrey Wang <cwhongru@cuc.edu.cn>
Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
Co-authored-by: zimozhou <47972969+zimozhou@users.noreply.github.com>
Co-authored-by: MR. WANG <hrwang@kfsrv03.se.cuhk.edu.hk>
parent 3fd04ce2
No related branches found
No related tags found
No related merge requests found
......@@ -148,12 +148,25 @@ class DQN(Policy):
def load(self, filename):
    """Restore the Q-network and target-network weights from a checkpoint.

    Candidate paths are tried in this order, and the first one that exists
    is loaded into both ``self.net`` and ``self.target_net``:

    1. ``filename + '.dqn.mdl'`` as given,
    2. the same name resolved against this module's directory,
    3. ``filename + '_dqn.pol.mdl'`` as given,
    4. the same name resolved against this module's directory.

    If none of the candidates exists, the networks are left untouched and a
    warning listing the attempted paths is logged.

    Args:
        filename: Checkpoint path prefix, without the model suffix.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    dqn_mdl_candidates = [
        filename + '.dqn.mdl',
        os.path.join(module_dir, filename + '.dqn.mdl'),
        filename + '_dqn.pol.mdl',
        os.path.join(module_dir, filename + '_dqn.pol.mdl'),
    ]
    for dqn_mdl in dqn_mdl_candidates:
        if os.path.exists(dqn_mdl):
            # Read the checkpoint from disk once and reuse it for both
            # networks instead of deserializing the same file twice.
            state_dict = torch.load(dqn_mdl, map_location=DEVICE)
            self.net.load_state_dict(state_dict)
            self.target_net.load_state_dict(state_dict)
            logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(dqn_mdl))
            return
    # Make a missing checkpoint visible instead of silently keeping the
    # current (possibly randomly initialized) weights.
    logging.warning('<<dialog policy>> no checkpoint found, tried: {}'.format(dqn_mdl_candidates))
@classmethod
def from_pretrained(cls,
                    archive_file="",
                    model_file="https://convlab.blob.core.windows.net/convlab-2/dqn_policy_multiwoz.zip",
                    is_train=False,
                    dataset='Multiwoz'):
    """Create a policy and load the checkpoint named in the local config.

    Reads ``config.json`` from the directory containing this module,
    instantiates ``cls`` with ``is_train`` and ``dataset``, and calls
    ``load`` with the config's ``"load"`` entry.

    Args:
        archive_file: Accepted but not used by this method.
        model_file: Accepted but not used by this method.
        is_train: Forwarded to the class constructor.
        dataset: Forwarded to the class constructor.

    Returns:
        The newly constructed instance after ``load`` has been called.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(here, 'config.json')
    with open(config_path, 'r') as config_fp:
        settings = json.load(config_fp)

    policy = cls(is_train=is_train, dataset=dataset)
    policy.load(settings['load'])
    return policy
\ No newline at end of file
from convlab2.policy.dqn.multiwoz.dqn_policy import DQNPolicy
\ No newline at end of file
{
"batch_size": 16,
"gamma": 0.99,
"lr": 0.001,
"save_dir": "save",
"log_dir": "log",
"save_per_epoch": 5,
"training_iter": 10,
"training_batch_iter": 3,
"h_dim": 100,
"hv_dim": 50,
"memory_size": 5000,
"epsilon_spec": {
"start": 0.1,
"end": 0.0,
"end_epoch": 200
},
"load": "save/best",
"vocab_size": 500
}
from convlab2.policy.dqn import DQN
import os
import json
class DQNPolicy(DQN):
    """DQN policy that loads the checkpoint named in config.json on creation."""

    def __init__(self,
                 is_train=False,
                 dataset="Multiwoz",
                 archive_file="",
                 model_file="https://convlab.blob.core.windows.net/convlab-2/dqn_policy_multiwoz.zip"):
        """Initialize the base DQN and load the configured checkpoint.

        Args:
            is_train: Forwarded to the base-class constructor.
            dataset: Forwarded to the base-class constructor.
            archive_file: Accepted but not used by this constructor.
            model_file: Accepted but not used by this constructor.
        """
        super().__init__(is_train=is_train, dataset=dataset)

        # The config file sits next to this module; its "load" entry is the
        # checkpoint path prefix handed to ``load``.
        package_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(package_dir, 'config.json')
        with open(config_path, 'r') as config_fp:
            settings = json.load(config_fp)
        self.load(settings['load'])
\ No newline at end of file
# available NLU models
# from convlab2.nlu.svm.multiwoz import SVMNLU
from convlab2.policy.dqn.multiwoz.dqn_policy import DQNPolicy
from convlab2.nlu.jointBERT.multiwoz import BERTNLU
# from convlab2.nlu.milu.multiwoz import MILU
# available DST models
from convlab2.dst.rule.multiwoz import RuleDST
# from convlab2.dst.mdbt.multiwoz import MDBT
# from convlab2.dst.sumbt.multiwoz import SUMBT
# from convlab2.dst.trade.multiwoz import TRADE
# from convlab2.dst.comer.multiwoz import COMER
# available Policy models
from convlab2.policy.rule.multiwoz import RulePolicy
# from convlab2.policy.ppo.multiwoz import PPOPolicy
# from convlab2.policy.pg.multiwoz import PGPolicy
# from convlab2.policy.mle.multiwoz import MLEPolicy
# from convlab2.policy.vhus.multiwoz import UserPolicyVHUS
# from convlab2.policy.mdrg.multiwoz import MDRGWordPolicy
# from convlab2.policy.hdsa.multiwoz import HDSA
# from convlab2.policy.larl.multiwoz import LaRL
# available NLG models
from convlab2.nlg.template.multiwoz import TemplateNLG
from convlab2.nlg.sclstm.multiwoz import SCLSTM
# available E2E models
# from convlab2.e2e.sequicity.multiwoz import Sequicity
# from convlab2.e2e.damd.multiwoz import Damd
from convlab2.dialog_agent import PipelineAgent, BiSession
from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
from convlab2.util.analysis_tool.analyzer import Analyzer
from pprint import pprint
import random
import numpy as np
import torch
def set_seed(r_seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        r_seed: Seed value passed unchanged to each generator.
    """
    # Same order as before: stdlib random, then NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(r_seed)
def test_end2end():
    """Evaluate a DQN-policy system agent against a rule-based user agent.

    Builds a system pipeline (BERT NLU, rule DST, DQN policy, template NLG)
    and a user pipeline (BERT NLU for system utterances, no DST, rule policy,
    template NLG), seeds the random number generators, and runs the
    analyzer's comprehensive analysis over 1000 dialogs.

    See the README.md of each model for more information.
    """
    # System side, built in order: BERT NLU, simple rule DST, DQN policy,
    # template NLG.
    system_modules = (
        BERTNLU(),
        RuleDST(),
        DQNPolicy(),
        TemplateNLG(is_user=False),
    )
    sys_agent = PipelineAgent(*system_modules, name='sys')

    # User side, built in order: BERT NLU trained on system utterances,
    # no DST, rule policy acting as the user, template NLG.
    user_modules = (
        BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
                model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip'),
        None,
        RulePolicy(character='usr'),
        TemplateNLG(is_user=True),
    )
    user_agent = PipelineAgent(*user_modules, name='user')

    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')

    # Seed after the agents are built and right before the analysis starts.
    set_seed(20200202)
    analyzer.comprehensive_analyze(
        sys_agent=sys_agent,
        model_name='BERTNLU-RuleDST-DQNPolicy-TemplateNLG',
        total_dialog=1000,
    )
# Run the end-to-end evaluation only when this file is executed as a script.
if __name__ == "__main__":
    test_end2end()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment