From aafdb765968e6d7948b5c7c38bd419686a471513 Mon Sep 17 00:00:00 2001 From: Carel van Niekerk <niekerk@hhu.de> Date: Wed, 11 Jan 2023 13:28:28 +0100 Subject: [PATCH] Fix SUMBT import in test_SUMBT-LaRL.py --- examples/agent_examples/test_SUMBT-LaRL.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/agent_examples/test_SUMBT-LaRL.py b/examples/agent_examples/test_SUMBT-LaRL.py index 3f2e9a28..91004b20 100755 --- a/examples/agent_examples/test_SUMBT-LaRL.py +++ b/examples/agent_examples/test_SUMBT-LaRL.py @@ -1,11 +1,11 @@ # available NLU models # from convlab.nlu.svm.multiwoz import SVMNLU -from convlab.nlu.jointBERT.multiwoz import BERTNLU +from convlab.nlu.jointBERT.unified_datasets import BERTNLU # from convlab.nlu.milu.multiwoz import MILU # available DST models # from convlab.dst.rule.multiwoz import RuleDST # from convlab.dst.mdbt.multiwoz import MDBT -from convlab.dst.sumbt.multiwoz import SUMBT +from convlab.dst.setsumbt import SetSUMBTTracker # from convlab.dst.trade.multiwoz import TRADE # from convlab.dst.comer.multiwoz import COMER # available Policy models @@ -44,7 +44,7 @@ def test_end2end(): # BERT nlu sys_nlu = None # simple rule DST - sys_dst = SUMBT() + sys_dst = SetSUMBTTracker(model_type='bert', model_path="path/to/sumbt/checkpoint") # rule policy sys_policy = LaRL() # template NLG @@ -53,7 +53,7 @@ def test_end2end(): sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') # BERT nlu trained on sys utterance - user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', + user_nlu = BERTNLU(mode='sys', config_file='multiwoz21_sys_context.json', model_file='https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_sys_context.zip') # not use dst user_dst = None -- GitLab