Skip to content
Snippets Groups Projects
Unverified Commit 86ed7ad3 authored by Carel van Niekerk's avatar Carel van Niekerk Committed by GitHub
Browse files

Setsumbt bug fix (#123)


* Separate test and train domains

* Add progress bars in ontology embedder

* Update custom_util.py

* Fix custom_util things I broke

* Github master

* Save dialogue ids in prediction file

* Fix bug in ontology extraction

* Return dialogue ids in predictions file and fix bugs

* Add setsumbt starting config loader

* Add script to extract golden labels from dataset to match model predictions

* Add more setsumbt configs

* Add option to use local files only in transformers package

* Update starting configurations for setsumbt

* Github master

* Update README.md

* Update README.md

* Update convlab/dialog_agent/agent.py

* Revert custom_util.py

* Update custom_util.py

* Commit unverified changes :(:(:(:(

* Fix SetSUMBT bug resulting from new torch feature

* Setsumbt bug fixes

Co-authored-by: default avatarCarel van Niekerk <carel.niekerk@hhu.de>
Co-authored-by: default avatarMichael Heck <michael.heck@hhu.de>
Co-authored-by: default avatarChristian Geishauser <christian.geishauser@hhu.de>
parent a6dea527
Branches
No related tags found
No related merge requests found
...@@ -2,12 +2,15 @@ ...@@ -2,12 +2,15 @@
"model_type": "SetSUMBT", "model_type": "SetSUMBT",
"dataset": "multiwoz21", "dataset": "multiwoz21",
"no_action_prediction": true, "no_action_prediction": true,
"model_type": "bert",
"model_name_or_path": "bert-base-uncased", "model_name_or_path": "bert-base-uncased",
"candidate_embedding_model_name": "bert-base-uncased", "candidate_embedding_model_name": "bert-base-uncased",
"transformers_local_files_only": false, "transformers_local_files_only": false,
"no_set_similarity": false, "no_set_similarity": true,
"candidate_pooling": "cls", "candidate_pooling": "mean",
"loss_function": "labelsmoothing",
"num_train_epochs": 50,
"learning_rate": 5e-5,
"warmup_proportion": 0.2,
"train_batch_size": 3, "train_batch_size": 3,
"dev_batch_size": 16, "dev_batch_size": 16,
"test_batch_size": 16, "test_batch_size": 16,
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Convlab3 Unified Format Dialogue Datasets""" """Convlab3 Unified Format Dialogue Datasets"""
import pdb
from copy import deepcopy from copy import deepcopy
import torch import torch
...@@ -294,8 +294,20 @@ class UnifiedFormatDataset(Dataset): ...@@ -294,8 +294,20 @@ class UnifiedFormatDataset(Dataset):
Returns: Returns:
features (dict): All inputs and labels required to train the model features (dict): All inputs and labels required to train the model
""" """
return {label: self.features[label][index] for label in self.features feats = dict()
if self.features[label] is not None} for label in self.features:
if self.features[label] is not None:
if label == 'dialogue_ids':
if type(index) == int:
feat = self.features[label][index]
else:
feat = [self.features[label][idx] for idx in index]
else:
feat = self.features[label][index]
feats[label] = feat
return feats
def __len__(self): def __len__(self):
""" """
......
...@@ -58,8 +58,12 @@ def main(args=None, config=None): ...@@ -58,8 +58,12 @@ def main(args=None, config=None):
# Set up output directory # Set up output directory
OUTPUT_DIR = args.output_dir OUTPUT_DIR = args.output_dir
# Download model if needed
if not os.path.exists(OUTPUT_DIR): if not os.path.exists(OUTPUT_DIR):
if "http" not in OUTPUT_DIR:
os.makedirs(OUTPUT_DIR)
os.mkdir(os.path.join(OUTPUT_DIR, 'database'))
os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
else:
# Get path /.../convlab/dst/setsumbt/multiwoz/models # Get path /.../convlab/dst/setsumbt/multiwoz/models
download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
download_path = os.path.join(download_path, 'models') download_path = os.path.join(download_path, 'models')
...@@ -73,11 +77,6 @@ def main(args=None, config=None): ...@@ -73,11 +77,6 @@ def main(args=None, config=None):
args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1]) args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1])
args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1]) args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1])
os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders')) os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
if not os.path.exists(OUTPUT_DIR):
os.makedirs(OUTPUT_DIR)
os.mkdir(os.path.join(OUTPUT_DIR, 'database'))
os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
args.output_dir = OUTPUT_DIR args.output_dir = OUTPUT_DIR
# Set pretrained model path to the trained checkpoint # Set pretrained model path to the trained checkpoint
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment