From 26efc1edd7062450fb9ba25d613427a28601c938 Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Thu, 19 Jan 2023 15:56:24 +0100
Subject: [PATCH] wip

---
 convlab/nlu/jointBERT/multiwoz/nlu.py                  | 10 ++++++----
 .../policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json   |  5 +++--
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/convlab/nlu/jointBERT/multiwoz/nlu.py b/convlab/nlu/jointBERT/multiwoz/nlu.py
index 1373919e..7f707766 100755
--- a/convlab/nlu/jointBERT/multiwoz/nlu.py
+++ b/convlab/nlu/jointBERT/multiwoz/nlu.py
@@ -41,12 +41,13 @@ class BERTNLU(NLU):
         dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab,
                                 pretrained_weights=config['model']['pretrained_weights'])
 
-        logging.info('intent num:' +  str(len(intent_vocab)))
+        logging.info('intent num:' + str(len(intent_vocab)))
         logging.info('tag num:' + str(len(tag_vocab)))
 
         if not os.path.exists(output_dir):
             model_downloader(root_dir, model_file)
-        model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim)
+        model = JointBERT(config['model'], DEVICE,
+                          dataloader.tag_dim, dataloader.intent_dim)
 
         state_dict = torch.load(os.path.join(
             output_dir, 'pytorch_model.bin'), DEVICE)
@@ -74,7 +75,7 @@ class BERTNLU(NLU):
         for token in token_list:
             token = token.strip()
             self.nlp.tokenizer.add_special_case(
-                #token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}])
+                # token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}])
                 token, [{ORTH: token}])
         logging.info("BERTNLU loaded")
 
@@ -97,7 +98,8 @@ class BERTNLU(NLU):
         intents = []
         da = {}
 
-        word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(ori_word_seq, ori_tag_seq)
+        word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(
+            ori_word_seq, ori_tag_seq)
         word_seq = word_seq[:510]
         tag_seq = tag_seq[:510]
         batch_data = [[ori_word_seq, ori_tag_seq, intents, da, context_seq,
diff --git a/convlab/policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json b/convlab/policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json
index 9cdbd8bb..57dcd4de 100644
--- a/convlab/policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json
+++ b/convlab/policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json
@@ -23,9 +23,10 @@
     },
     "nlu_sys": {
         "BertNLU": {
-            "class_path": "convlab.nlu.jointBERT.multiwoz.BERTNLU",
+            "class_path": "convlab.nlu.jointBERT.unified_datasets.BERTNLU",
             "ini_params": {
-                "config_file": "multiwoz_all.json",
+                "mode": "all",
+                "config_file": "multiwoz21_sys_context3.json",
                 "model_file": "https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_all.zip"
             }
         }
-- 
GitLab