Select Git revision
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
test_BERTNLU-RuleDST-LaRL.py 2.57 KiB
# available NLU models
# from convlab.nlu.svm.multiwoz import SVMNLU
from convlab.nlu.jointBERT.multiwoz import BERTNLU
# from convlab.nlu.milu.multiwoz import MILU
# available DST models
from convlab.dst.rule.multiwoz import RuleDST
# from convlab.dst.mdbt.multiwoz import MDBT
# from convlab.dst.sumbt.multiwoz import SUMBT
# from convlab.dst.trade.multiwoz import TRADE
# from convlab.dst.comer.multiwoz import COMER
# available Policy models
from convlab.policy.rule.multiwoz import RulePolicy
# from convlab.policy.ppo.multiwoz import PPOPolicy
# from convlab.policy.pg.multiwoz import PGPolicy
# from convlab.policy.mle.multiwoz import MLEPolicy
# from convlab.policy.gdpl.multiwoz import GDPLPolicy
# from convlab.policy.vhus.multiwoz import UserPolicyVHUS
# from convlab.policy.mdrg.multiwoz import MDRGWordPolicy
#from convlab.policy.hdsa.multiwoz import HDSA
from convlab.policy.larl.multiwoz import LaRL
# available NLG models
from convlab.nlg.template.multiwoz import TemplateNLG
from convlab.nlg.sclstm.multiwoz import SCLSTM
# available E2E models
# from convlab.e2e.sequicity.multiwoz import Sequicity
# from convlab.e2e.damd.multiwoz import Damd
from convlab.dialog_agent import PipelineAgent, BiSession
from convlab.evaluator.multiwoz_eval import MultiWozEvaluator
from convlab.util.analysis_tool.analyzer import Analyzer
from pprint import pprint
import random
import numpy as np
import torch
import pdb
def set_seed(r_seed):
    """Seed every global RNG used by the evaluation with the same value.

    Seeds, in order, Python's ``random`` module, NumPy's legacy global
    generator and PyTorch's default generator.

    Args:
        r_seed: integer seed passed unchanged to each generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for seed_fn in seeders:
        seed_fn(r_seed)
def test_end2end():
    """Evaluate a BERTNLU + RuleDST + LaRL system agent against a user simulator.

    The system pipeline is BERT NLU, rule-based DST and the LaRL policy, with
    no separate NLG component. The user simulator is a BERT NLU trained on
    system utterances, a rule policy and template NLG, with no DST. After
    seeding all RNGs, ``Analyzer.comprehensive_analyze`` runs 1000 simulated
    dialogs under the model name ``'BERTNLU-RuleDST-LaRL'``.
    """
    # go to README.md of each model for more information
    # BERT NLU
    sys_nlu = BERTNLU()
    # simple rule DST
    sys_dst = RuleDST()
    # LaRL policy
    sys_policy = LaRL()
    # no separate NLG; presumably LaRL emits the system utterance itself
    # (NOTE(review): confirm against the LaRL README)
    sys_nlg = None
    # assemble
    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')

    # BERT NLU trained on sys utterance
    user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
                       model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
    # not use dst
    user_dst = None
    # rule policy
    user_policy = RulePolicy(character='usr')
    # template NLG
    user_nlg = TemplateNLG(is_user=True)
    # assemble
    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')

    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')

    # Seed right before the simulated dialogs so any RNG use during model
    # construction does not shift the dialog sampling.
    set_seed(20200202)
    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='BERTNLU-RuleDST-LaRL', total_dialog=1000)
if __name__ == '__main__':
    # Run the full end-to-end evaluation when executed as a script.
    test_end2end()