From 86ed7ad3984b64ea58b18f5fa3d12ff67a672068 Mon Sep 17 00:00:00 2001
From: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com>
Date: Tue, 24 Jan 2023 11:15:27 +0100
Subject: [PATCH] Setsumbt bug fix (#123)

* Separate test and train domains

* Add progress bars in ontology embedder

* Update custom_util.py

* Fix custom_util things I broke

* Github master

* Save dialogue ids in prediction file

* Fix bug in ontology extraction

* Return dialogue ids in predictions file and fix bugs

* Add setsumbt starting config loader

* Add script to extract golden labels from dataset to match model predictions

* Add more setsumbt configs

* Add option to use local files only in transformers package

* Update starting configurations for setsumbt

* Github master

* Update README.md

* Update README.md

* Update convlab/dialog_agent/agent.py

* Revert custom_util.py

* Update custom_util.py

* Commit unverified changes :(:(:(:(

* Fix SetSUMBT bug resulting from new torch feature

* Setsumbt bug fixes

Co-authored-by: Carel van Niekerk <carel.niekerk@hhu.de>
Co-authored-by: Michael Heck <michael.heck@hhu.de>
Co-authored-by: Christian Geishauser <christian.geishauser@hhu.de>
---
 .../setsumbt/configs/sumbt_multiwoz21.json    | 11 ++++--
 .../dst/setsumbt/dataset/unified_format.py    | 18 +++++++--
 convlab/dst/setsumbt/do/nbt.py                | 37 +++++++++----------
 3 files changed, 40 insertions(+), 26 deletions(-)

diff --git a/convlab/dst/setsumbt/configs/sumbt_multiwoz21.json b/convlab/dst/setsumbt/configs/sumbt_multiwoz21.json
index ebebe8a6..39aeceb2 100644
--- a/convlab/dst/setsumbt/configs/sumbt_multiwoz21.json
+++ b/convlab/dst/setsumbt/configs/sumbt_multiwoz21.json
@@ -2,14 +2,17 @@
   "model_type": "SetSUMBT",
   "dataset": "multiwoz21",
   "no_action_prediction": true,
-  "model_type": "bert",
   "model_name_or_path": "bert-base-uncased",
   "candidate_embedding_model_name": "bert-base-uncased",
   "transformers_local_files_only": false,
-  "no_set_similarity": false,
-  "candidate_pooling": "cls",
+  "no_set_similarity": true,
+  "candidate_pooling": "mean",
+  "loss_function": "labelsmoothing",
+  "num_train_epochs": 50,
+  "learning_rate": 5e-5,
+  "warmup_proportion": 0.2,
   "train_batch_size": 3,
   "dev_batch_size": 16,
   "test_batch_size": 16,
   "run_nbt": true
-}
\ No newline at end of file
+}
diff --git a/convlab/dst/setsumbt/dataset/unified_format.py b/convlab/dst/setsumbt/dataset/unified_format.py
index 1c3a68c3..55483e0f 100644
--- a/convlab/dst/setsumbt/dataset/unified_format.py
+++ b/convlab/dst/setsumbt/dataset/unified_format.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Convlab3 Unified Format Dialogue Datasets"""
-
+import pdb
 from copy import deepcopy
 
 import torch
@@ -294,8 +294,20 @@ class UnifiedFormatDataset(Dataset):
         Returns:
             features (dict): All inputs and labels required to train the model
         """
-        return {label: self.features[label][index] for label in self.features
-                if self.features[label] is not None}
+        feats = dict()
+        for label in self.features:
+            if self.features[label] is not None:
+                if label == 'dialogue_ids':
+                    if type(index) == int:
+                        feat = self.features[label][index]
+                    else:
+                        feat = [self.features[label][idx] for idx in index]
+                else:
+                    feat = self.features[label][index]
+
+                feats[label] = feat
+
+        return feats
 
     def __len__(self):
         """
diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py
index d250f29e..21949e72 100644
--- a/convlab/dst/setsumbt/do/nbt.py
+++ b/convlab/dst/setsumbt/do/nbt.py
@@ -58,26 +58,25 @@ def main(args=None, config=None):
     # Set up output directory
     OUTPUT_DIR = args.output_dir
 
-    # Download model if needed
     if not os.path.exists(OUTPUT_DIR):
-        # Get path /.../convlab/dst/setsumbt/multiwoz/models
-        download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-        download_path = os.path.join(download_path, 'models')
-        if not os.path.exists(download_path):
-            os.mkdir(download_path)
-        model_downloader(download_path, OUTPUT_DIR)
-        # Downloadable model path format http://.../model_name.zip
-        OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '')
-        OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR)
-
-        args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1])
-        args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1])
-        os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
-
-    if not os.path.exists(OUTPUT_DIR):
-        os.makedirs(OUTPUT_DIR)
-        os.mkdir(os.path.join(OUTPUT_DIR, 'database'))
-        os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
+        if "http" not in OUTPUT_DIR:
+            os.makedirs(OUTPUT_DIR)
+            os.mkdir(os.path.join(OUTPUT_DIR, 'database'))
+            os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
+        else:
+            # Get path /.../convlab/dst/setsumbt/multiwoz/models
+            download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+            download_path = os.path.join(download_path, 'models')
+            if not os.path.exists(download_path):
+                os.mkdir(download_path)
+            model_downloader(download_path, OUTPUT_DIR)
+            # Downloadable model path format http://.../model_name.zip
+            OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '')
+            OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR)
+
+            args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1])
+            args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1])
+            os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
     args.output_dir = OUTPUT_DIR
 
     # Set pretrained model path to the trained checkpoint
-- 
GitLab