diff --git a/examples/agent_examples/test_SUMBT-LaRL.py b/examples/agent_examples/test_SUMBT-LaRL.py index 3f2e9a28912114ba931493b99c78e6300b80f6b4..91004b206d78fbb907d4dc09ab1fe230b51a400c 100755 --- a/examples/agent_examples/test_SUMBT-LaRL.py +++ b/examples/agent_examples/test_SUMBT-LaRL.py @@ -1,11 +1,11 @@ # available NLU models # from convlab.nlu.svm.multiwoz import SVMNLU -from convlab.nlu.jointBERT.multiwoz import BERTNLU +from convlab.nlu.jointBERT.unified_datasets import BERTNLU # from convlab.nlu.milu.multiwoz import MILU # available DST models # from convlab.dst.rule.multiwoz import RuleDST # from convlab.dst.mdbt.multiwoz import MDBT -from convlab.dst.sumbt.multiwoz import SUMBT +from convlab.dst.setsumbt import SetSUMBTTracker # from convlab.dst.trade.multiwoz import TRADE # from convlab.dst.comer.multiwoz import COMER # available Policy models @@ -44,7 +44,7 @@ def test_end2end(): # BERT nlu sys_nlu = None # simple rule DST - sys_dst = SUMBT() + sys_dst = SetSUMBTTracker(model_type='bert', model_path="path/to/sumbt/checkpoint") # rule policy sys_policy = LaRL() # template NLG @@ -53,7 +53,7 @@ def test_end2end(): sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') # BERT nlu trained on sys utterance - user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', + user_nlu = BERTNLU(mode='sys', config_file='multiwoz21_sys_context.json', model_file='https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_sys_context.zip') # not use dst user_dst = None