Skip to content
Snippets Groups Projects
Commit aafdb765 authored by Carel van Niekerk's avatar Carel van Niekerk :computer:
Browse files

Fix SUMBT import in test_SUMBT-LaRL.py

parent b5eae17f
No related branches found
No related tags found
No related merge requests found
# available NLU models # available NLU models
# from convlab.nlu.svm.multiwoz import SVMNLU # from convlab.nlu.svm.multiwoz import SVMNLU
from convlab.nlu.jointBERT.multiwoz import BERTNLU from convlab.nlu.jointBERT.unified_datasets import BERTNLU
# from convlab.nlu.milu.multiwoz import MILU # from convlab.nlu.milu.multiwoz import MILU
# available DST models # available DST models
# from convlab.dst.rule.multiwoz import RuleDST # from convlab.dst.rule.multiwoz import RuleDST
# from convlab.dst.mdbt.multiwoz import MDBT # from convlab.dst.mdbt.multiwoz import MDBT
from convlab.dst.sumbt.multiwoz import SUMBT from convlab.dst.setsumbt import SetSUMBTTracker
# from convlab.dst.trade.multiwoz import TRADE # from convlab.dst.trade.multiwoz import TRADE
# from convlab.dst.comer.multiwoz import COMER # from convlab.dst.comer.multiwoz import COMER
# available Policy models # available Policy models
...@@ -44,7 +44,7 @@ def test_end2end(): ...@@ -44,7 +44,7 @@ def test_end2end():
# BERT nlu # BERT nlu
sys_nlu = None sys_nlu = None
# simple rule DST # simple rule DST
sys_dst = SUMBT() sys_dst = SetSUMBTTracker(model_type='bert', model_path="path/to/sumbt/checkpoint")
# rule policy # rule policy
sys_policy = LaRL() sys_policy = LaRL()
# template NLG # template NLG
...@@ -53,7 +53,7 @@ def test_end2end(): ...@@ -53,7 +53,7 @@ def test_end2end():
sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
# BERT nlu trained on sys utterance # BERT nlu trained on sys utterance
user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', user_nlu = BERTNLU(mode='sys', config_file='multiwoz21_sys_context.json',
model_file='https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_sys_context.zip') model_file='https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_sys_context.zip')
# not use dst # not use dst
user_dst = None user_dst = None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment