Skip to content
Snippets Groups Projects
Select Git revision
  • 1c614c730f081f9769a4e3b2b3a9989b1ed0f1c2
  • master default protected
2 results

main.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    test_BERTNLU-RuleDST-LaRL.py 2.57 KiB
    # available NLU models
    # from convlab.nlu.svm.multiwoz import SVMNLU
    from convlab.nlu.jointBERT.multiwoz import BERTNLU
    # from convlab.nlu.milu.multiwoz import MILU
    # available DST models
    from convlab.dst.rule.multiwoz import RuleDST
    # from convlab.dst.mdbt.multiwoz import MDBT
    # from convlab.dst.sumbt.multiwoz import SUMBT
    # from convlab.dst.trade.multiwoz import TRADE
    # from convlab.dst.comer.multiwoz import COMER
    # available Policy models
    from convlab.policy.rule.multiwoz import RulePolicy
    # from convlab.policy.ppo.multiwoz import PPOPolicy
    # from convlab.policy.pg.multiwoz import PGPolicy
    # from convlab.policy.mle.multiwoz import MLEPolicy
    # from convlab.policy.gdpl.multiwoz import GDPLPolicy
    # from convlab.policy.vhus.multiwoz import UserPolicyVHUS
    # from convlab.policy.mdrg.multiwoz import MDRGWordPolicy
    #from convlab.policy.hdsa.multiwoz import HDSA
    from convlab.policy.larl.multiwoz import LaRL
    # available NLG models
    from convlab.nlg.template.multiwoz import TemplateNLG
    from convlab.nlg.sclstm.multiwoz import SCLSTM
    # available E2E models
    # from convlab.e2e.sequicity.multiwoz import Sequicity
    # from convlab.e2e.damd.multiwoz import Damd
    from convlab.dialog_agent import PipelineAgent, BiSession
    from convlab.evaluator.multiwoz_eval import MultiWozEvaluator
    from convlab.util.analysis_tool.analyzer import Analyzer
    from pprint import pprint
    import random
    import numpy as np
    import torch
    import pdb
    
    
    def set_seed(r_seed):
        """Seed every random number generator this script relies on.

        The same value is handed, in order, to Python's ``random`` module,
        NumPy's global RNG and PyTorch, so repeated evaluation runs draw the
        same random numbers.

        Args:
            r_seed: Seed value (``test_end2end`` passes an int).
        """
        seeders = (random.seed, np.random.seed, torch.manual_seed)
        for apply_seed in seeders:
            apply_seed(r_seed)
    
    
    def test_end2end():
        """Evaluate a BERTNLU + RuleDST + LaRL system against a rule-based user simulator.

        Builds two ``PipelineAgent``s:

        * system side: BERT NLU, rule DST, LaRL policy and no NLG module;
        * user side: BERT NLU trained on system utterances (with context),
          no DST, rule policy and template NLG.

        It then seeds all RNGs and runs ``Analyzer.comprehensive_analyze`` for
        1000 simulated MultiWOZ dialogues under the model name
        ``'BERTNLU-RuleDST-LaRL'``.
        """
        # go to README.md of each model for more information

        # --- System agent ---------------------------------------------------
        # BERT NLU
        sys_nlu = BERTNLU()
        # simple rule DST
        sys_dst = RuleDST()
        # LaRL policy
        sys_policy = LaRL()
        # No NLG module is attached on the system side. Presumably LaRL emits
        # surface text itself -- confirm against the LaRL README.
        sys_nlg = None
        # assemble
        sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')

        # --- User simulator -------------------------------------------------
        # BERT NLU trained on sys utterance
        user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
                           model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
        # the user side does not use a DST
        user_dst = None
        # rule policy acting as the user
        user_policy = RulePolicy(character='usr')
        # template NLG for user utterances
        user_nlg = TemplateNLG(is_user=True)
        # assemble
        user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')

        analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')

        # Seed right before the analysis so the simulated dialogues are reproducible.
        set_seed(20200202)
        analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='BERTNLU-RuleDST-LaRL', total_dialog=1000)
    
    if __name__ == '__main__':
        # Executed only when this file is run as a script, not when it is imported:
        # launch the full end-to-end evaluation.
        test_end2end()