From aafdb765968e6d7948b5c7c38bd419686a471513 Mon Sep 17 00:00:00 2001
From: Carel van Niekerk <niekerk@hhu.de>
Date: Wed, 11 Jan 2023 13:28:28 +0100
Subject: [PATCH] Fix SUMBT import in test_SUMBT-LaRL.py

---
 examples/agent_examples/test_SUMBT-LaRL.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/agent_examples/test_SUMBT-LaRL.py b/examples/agent_examples/test_SUMBT-LaRL.py
index 3f2e9a28..91004b20 100755
--- a/examples/agent_examples/test_SUMBT-LaRL.py
+++ b/examples/agent_examples/test_SUMBT-LaRL.py
@@ -1,11 +1,11 @@
 # available NLU models
 # from convlab.nlu.svm.multiwoz import SVMNLU
-from convlab.nlu.jointBERT.multiwoz import BERTNLU
+from convlab.nlu.jointBERT.unified_datasets import BERTNLU
 # from convlab.nlu.milu.multiwoz import MILU
 # available DST models
 # from convlab.dst.rule.multiwoz import RuleDST
 # from convlab.dst.mdbt.multiwoz import MDBT
-from convlab.dst.sumbt.multiwoz import SUMBT
+from convlab.dst.setsumbt import SetSUMBTTracker
 # from convlab.dst.trade.multiwoz import TRADE
 # from convlab.dst.comer.multiwoz import COMER
 # available Policy models
@@ -44,7 +44,7 @@ def test_end2end():
     # BERT nlu
     sys_nlu = None
     # simple rule DST
-    sys_dst = SUMBT()
+    sys_dst = SetSUMBTTracker(model_type='bert', model_path="path/to/sumbt/checkpoint")
     # rule policy
     sys_policy = LaRL()
     # template NLG
@@ -53,7 +53,7 @@ def test_end2end():
     sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
 
     # BERT nlu trained on sys utterance
-    user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
+    user_nlu = BERTNLU(mode='sys', config_file='multiwoz21_sys_context.json',
                        model_file='https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_sys_context.zip')
     # not use dst
     user_dst = None
-- 
GitLab