Skip to content
Snippets Groups Projects
Select Git revision
  • d0dff197253673a11eb94eb65c7955fbe193ee80
  • master default protected
  • exec_auto_adjust_trace
  • let_variables
  • v1.4.1
  • v1.4.0
  • v1.3.0
  • v1.2.0
  • v1.1.0
  • v1.0.0
10 results

gradle-wrapper.properties

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    test_BERTNLU-RuleDST-RulePolicy-SCLSTM.py 2.60 KiB
    # available NLU models
    # from convlab2.nlu.svm.multiwoz import SVMNLU
    from convlab2.nlu.jointBERT.multiwoz import BERTNLU
    # from convlab2.nlu.milu.multiwoz import MILU
    # available DST models
    from convlab2.dst.rule.multiwoz import RuleDST
    # from convlab2.dst.mdbt.multiwoz import MDBT
    # from convlab2.dst.sumbt.multiwoz import SUMBT
    # from convlab2.dst.trade.multiwoz import TRADE
    # from convlab2.dst.comer.multiwoz import COMER
    # available Policy models
    from convlab2.policy.rule.multiwoz import RulePolicy
    # from convlab2.policy.ppo.multiwoz import PPOPolicy
    # from convlab2.policy.pg.multiwoz import PGPolicy
    # from convlab2.policy.mle.multiwoz import MLEPolicy
    # from convlab2.policy.gdpl.multiwoz import GDPLPolicy
    # from convlab2.policy.vhus.multiwoz import UserPolicyVHUS
    # from convlab2.policy.mdrg.multiwoz import MDRGWordPolicy
    # from convlab2.policy.hdsa.multiwoz import HDSA
    # from convlab2.policy.larl.multiwoz import LaRL
    # available NLG models
    from convlab2.nlg.template.multiwoz import TemplateNLG
    from convlab2.nlg.sclstm.multiwoz import SCLSTM
    # available E2E models
    # from convlab2.e2e.sequicity.multiwoz import Sequicity
    # from convlab2.e2e.damd.multiwoz import Damd
    from convlab2.dialog_agent import PipelineAgent, BiSession
    from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
    from convlab2.util.analysis_tool.analyzer import Analyzer
    from pprint import pprint
    import random
    import numpy as np
    import torch
    
    
    def set_seed(r_seed):
        random.seed(r_seed)
        np.random.seed(r_seed)
        torch.manual_seed(r_seed)
    
    
    def test_end2end():
        # go to README.md of each model for more information
        # BERT nlu
        sys_nlu = BERTNLU()
        # simple rule DST
        sys_dst = RuleDST()
        # rule policy
        sys_policy = RulePolicy()
        # template NLG
        sys_nlg = SCLSTM(use_cuda=True)
        # assemble
        sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
    
        # BERT nlu trained on sys utterance
        user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
                           model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
        # not use dst
        user_dst = None
        # rule policy
        user_policy = RulePolicy(character='usr')
        # template NLG
        user_nlg = TemplateNLG(is_user=True)
        # assemble
        user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
    
        analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
    
        set_seed(20200202)
        analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='BERTNLU-RuleDST-RulePolicy-SCLSTM', total_dialog=1000)
    
    if __name__ == '__main__':
        test_end2end()