diff --git a/convlab/dst/setsumbt/configs/sumbt_multiwoz21.json b/convlab/dst/setsumbt/configs/sumbt_multiwoz21.json index ebebe8a6631a03aff85a08ee6608b0d757f2a33d..39aeceb27f28fe20a499d140779be8e2efa49b28 100644 --- a/convlab/dst/setsumbt/configs/sumbt_multiwoz21.json +++ b/convlab/dst/setsumbt/configs/sumbt_multiwoz21.json @@ -2,14 +2,17 @@ "model_type": "SetSUMBT", "dataset": "multiwoz21", "no_action_prediction": true, - "model_type": "bert", "model_name_or_path": "bert-base-uncased", "candidate_embedding_model_name": "bert-base-uncased", "transformers_local_files_only": false, - "no_set_similarity": false, - "candidate_pooling": "cls", + "no_set_similarity": true, + "candidate_pooling": "mean", + "loss_function": "labelsmoothing", + "num_train_epochs": 50, + "learning_rate": 5e-5, + "warmup_proportion": 0.2, "train_batch_size": 3, "dev_batch_size": 16, "test_batch_size": 16, "run_nbt": true -} \ No newline at end of file +} diff --git a/convlab/dst/setsumbt/dataset/unified_format.py b/convlab/dst/setsumbt/dataset/unified_format.py index 1c3a68c3b2e627ac60f555a642dfa837734249b6..55483e0f4e3404e96d817395c53dd9a6fcd57c3e 100644 --- a/convlab/dst/setsumbt/dataset/unified_format.py +++ b/convlab/dst/setsumbt/dataset/unified_format.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Convlab3 Unified Format Dialogue Datasets""" - +import pdb from copy import deepcopy import torch @@ -294,8 +294,20 @@ class UnifiedFormatDataset(Dataset): Returns: features (dict): All inputs and labels required to train the model """ - return {label: self.features[label][index] for label in self.features - if self.features[label] is not None} + feats = dict() + for label in self.features: + if self.features[label] is not None: + if label == 'dialogue_ids': + if type(index) == int: + feat = self.features[label][index] + else: + feat = [self.features[label][idx] for idx in index] + else: + feat = self.features[label][index] + + feats[label] = feat + + return feats def __len__(self): """ diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py index d250f29e0186732b27fb284273d9f9ff4f166d2f..21949e728aa03d261dbb901e64fbb73bfd662d13 100644 --- a/convlab/dst/setsumbt/do/nbt.py +++ b/convlab/dst/setsumbt/do/nbt.py @@ -58,26 +58,25 @@ def main(args=None, config=None): # Set up output directory OUTPUT_DIR = args.output_dir - # Download model if needed if not os.path.exists(OUTPUT_DIR): - # Get path /.../convlab/dst/setsumbt/multiwoz/models - download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - download_path = os.path.join(download_path, 'models') - if not os.path.exists(download_path): - os.mkdir(download_path) - model_downloader(download_path, OUTPUT_DIR) - # Downloadable model path format http://.../model_name.zip - OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '') - OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR) - - args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1]) - args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1]) - os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders')) - - if not os.path.exists(OUTPUT_DIR): - os.makedirs(OUTPUT_DIR) - os.mkdir(os.path.join(OUTPUT_DIR, 'database')) - os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders')) + if "http" not in OUTPUT_DIR: + os.makedirs(OUTPUT_DIR) + os.mkdir(os.path.join(OUTPUT_DIR, 'database')) + os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders')) + else: + # Get path /.../convlab/dst/setsumbt/multiwoz/models + download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + download_path = os.path.join(download_path, 'models') + if not os.path.exists(download_path): + os.mkdir(download_path) + model_downloader(download_path, OUTPUT_DIR) + # Downloadable model path format http://.../model_name.zip + OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '') + OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR) + + args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1]) + args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1]) + os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders')) args.output_dir = OUTPUT_DIR # Set pretrained model path to the trained checkpoint