diff --git a/convlab/dst/setsumbt/__init__.py b/convlab/dst/setsumbt/__init__.py
index 9492faa9c9a20d1c476819bb995900ca71d56607..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/convlab/dst/setsumbt/__init__.py
+++ b/convlab/dst/setsumbt/__init__.py
@@ -1 +0,0 @@
-from convlab.dst.setsumbt.tracker import SetSUMBTTracker
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/calibration_plots.py b/convlab/dst/setsumbt/calibration_plots.py
deleted file mode 100644
index a41f280d3349164a2a67333d0ab176a37cbe50ea..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/calibration_plots.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Calibration Plot plotting script"""
-
-import os
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-import torch
-from matplotlib import pyplot as plt
-
-
-def main():
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--data_dir', help='Location of the belief states', required=True)
-    parser.add_argument('--output', help='Output image path', default='calibration_plot.png')
-    parser.add_argument('--n_bins', help='Number of bins', default=10, type=int)
-    args = parser.parse_args()
-
-    if torch.cuda.is_available():
-        device = torch.device('cuda')
-    else:
-        device = torch.device('cpu')
-    path = args.data_dir
-
-    models = os.listdir(path)
-    models = [os.path.join(path, model, 'test.predictions') for model in models]
-
-    fig = plt.figure(figsize=(14,8))
-    font=20
-    plt.tick_params(labelsize=font-2)
-    linestyle = ['-', ':', (0, (3, 5, 1, 5)), '-.', (0, (5, 10))]
-    for i, model in enumerate(models):
-        conf, acc = get_calibration(model, device, n_bins=args.n_bins)
-        name = model.split('/')[-2].strip()
-        print(name, conf, acc)
-        plt.plot(conf, acc, label=name, linestyle=linestyle[i], linewidth=3)
-
-    plt.plot(torch.tensor([0,1]), torch.tensor([0,1]), linestyle='--', color='black', linewidth=3)
-    plt.xlabel('Confidence', fontsize=font)
-    plt.ylabel('Joint Goal Accuracy', fontsize=font)
-    plt.legend(fontsize=font)
-
-    plt.savefig(args.output)
-
-
-def get_calibration(path, device, n_bins=10, temperature=1.00):
-    probs = torch.load(path, map_location=device)
-    y_true = probs['state_labels']
-    probs = probs['belief_states']
-
-    y_pred = {slot: probs[slot].reshape(-1, probs[slot].size(-1)).argmax(-1) for slot in probs}
-    goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred}
-    goal_acc = sum([goal_acc[slot] for slot in goal_acc])
-    goal_acc = (goal_acc == len(y_true)).int()
-
-    scores = [probs[slot].reshape(-1, probs[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in probs]
-    scores = torch.cat(scores, 0).min(0)[0]
-
-    step = 1.0 / float(n_bins)
-    bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
-    bins = []
-    for b in range(n_bins):
-        lower, upper = bin_ranges[b], bin_ranges[b + 1]
-        if b == 0:
-            ids = torch.where((scores >= lower) * (scores <= upper))[0]
-        else:
-            ids = torch.where((scores > lower) * (scores <= upper))[0]
-        bins.append(ids)
-
-    conf = [0.0]
-    for b in bins:
-        if b.size(0) > 0:
-            l = scores[b]
-            conf.append(l.mean())
-        else:
-            conf.append(-1)
-    conf = torch.tensor(conf)
-
-    slot = [s for s in y_true][0]
-    acc = [0.0]
-    for b in bins:
-        if b.size(0) > 0:
-            acc_ = goal_acc[b]
-            acc_ = acc_[y_true[slot].reshape(-1)[b] >= 0]
-            if acc_.size(0) >= 0:
-                acc.append(acc_.float().mean())
-            else:
-                acc.append(-1)
-        else:
-            acc.append(-1)
-    acc = torch.tensor(acc)
-
-    conf = conf[acc != -1]
-    acc = acc[acc != -1]
-
-    return conf, acc
-
-
-if __name__ == '__main__':
-    main()
diff --git a/convlab/dst/setsumbt/configs/end2_setsumbt_multiwoz21.json b/convlab/dst/setsumbt/configs/end2_setsumbt_multiwoz21.json
new file mode 100644
index 0000000000000000000000000000000000000000..185d397c86a2c9cffc9abe1e826414acf84f4e1b
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/end2_setsumbt_multiwoz21.json
@@ -0,0 +1,11 @@
+{
+  "model_type": "SetSUMBT",
+  "dataset": "multiwoz21",
+  "no_action_prediction": false,
+  "loss_function": "distribution_distillation",
+  "model_name_or_path": "roberta-base",
+  "candidate_embedding_model_name": "roberta-base",
+  "train_batch_size": 3,
+  "dev_batch_size": 12,
+  "test_batch_size": 16
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/configs/ensemble_setsumbt_multiwoz21.json b/convlab/dst/setsumbt/configs/ensemble_setsumbt_multiwoz21.json
new file mode 100644
index 0000000000000000000000000000000000000000..cde2a77e25e68ad48f3d1c10ae3e965fc2b1b6dd
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/ensemble_setsumbt_multiwoz21.json
@@ -0,0 +1,12 @@
+{
+  "model_type": "Ensemble-SetSUMBT",
+  "dataset": "multiwoz21",
+  "ensemble_size": 5,
+  "data_sampling_size": 7500,
+  "no_action_prediction": false,
+  "model_name_or_path": "roberta-base",
+  "candidate_embedding_model_name": "roberta-base",
+  "train_batch_size": 3,
+  "dev_batch_size": 3,
+  "test_batch_size": 3
+}
\ No newline at end of file
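Note: the two new training configs above are consumed as plain JSON. A minimal sketch of reading one into runtime arguments (the merge into the actual training entry point is an assumption; that script is not part of this diff):

```python
import json
from argparse import Namespace

# Minimal sketch, assuming the training script merges a JSON config into its
# runtime arguments; the entry point itself is not shown in this diff.
with open('convlab/dst/setsumbt/configs/ensemble_setsumbt_multiwoz21.json') as f:
    args = Namespace(**json.load(f))

print(args.model_type)          # 'Ensemble-SetSUMBT'
print(args.ensemble_size)       # 5 ensemble members
print(args.data_sampling_size)  # 7500 dialogues sampled per member
```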
diff --git a/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json b/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json
index 57a245518aae0a111f6220b1a088943b8b64ee4c..9463243b4ab22b2e096ec266f2b6922a584c5ad6 100644
--- a/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json
+++ b/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json
@@ -4,9 +4,7 @@
   "no_action_prediction": true,
   "model_name_or_path": "roberta-base",
   "candidate_embedding_model_name": "roberta-base",
-  "transformers_local_files_only": false,
   "train_batch_size": 3,
-  "dev_batch_size": 16,
-  "test_batch_size": 16,
-  "run_nbt": true
+  "dev_batch_size": 12,
+  "test_batch_size": 16
 }
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/dataset/__init__.py b/convlab/dst/setsumbt/dataset/__init__.py
deleted file mode 100644
index 17b1f93b3b39f95827cf6c09e8826383cd00b805..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/dataset/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from convlab.dst.setsumbt.dataset.unified_format import get_dataloader, change_batch_size
-from convlab.dst.setsumbt.dataset.ontology import get_slot_candidate_embeddings
diff --git a/convlab/dst/setsumbt/dataset/ontology.py b/convlab/dst/setsumbt/dataset/ontology.py
deleted file mode 100644
index ce150a61077ad61ab9d7af2ae3537971ae925f55..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/dataset/ontology.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Create Ontology Embeddings"""
-
-import json
-import os
-import random
-from copy import deepcopy
-
-import torch
-import numpy as np
-from tqdm import tqdm
-
-
-def set_seed(args):
-    """
-    Set random seeds
-
-    Args:
-        args (Arguments class): Arguments class containing seed and number of gpus to use
-    """
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-
-def encode_candidates(candidates: list, args, tokenizer, embedding_model) -> torch.tensor:
-    """
-    Embed candidates
-
-    Args:
-        candidates (list): List of candidate descriptions
-        args (argument class): Runtime arguments
-        tokenizer (transformers Tokenizer): Tokenizer for the embedding_model
-        embedding_model (transformer Model): Transformer model for embedding candidate descriptions
-
-    Returns:
-        feats (torch.tensor): Embeddings of the candidate descriptions
-    """
-    # Tokenize candidate descriptions
-    feats = [tokenizer.encode_plus(val, add_special_tokens=True,max_length=args.max_candidate_len,
-                                   padding='max_length', truncation='longest_first')
-             for val in candidates]
-
-    # Encode tokenized descriptions
-    with torch.no_grad():
-        feats = {key: torch.tensor([f[key] for f in feats]).to(embedding_model.device) for key in feats[0]}
-        embedded_feats = embedding_model(**feats)  # [num_candidates, max_candidate_len, hidden_dim]
-
-    # Reduce/pool descriptions embeddings if required
-    if args.set_similarity:
-        feats = embedded_feats.last_hidden_state.detach().cpu()  # [num_candidates, max_candidate_len, hidden_dim]
-    elif args.candidate_pooling == 'cls':
-        feats = embedded_feats.pooler_output.detach().cpu()  # [num_candidates, hidden_dim]
-    elif args.candidate_pooling == "mean":
-        feats = embedded_feats.last_hidden_state.detach().cpu()
-        feats = feats.sum(1)
-        feats = torch.nn.functional.layer_norm(feats, feats.size())
-        feats = feats.detach().cpu()  # [num_candidates, hidden_dim]
-
-    return feats
-
-
-def get_slot_candidate_embeddings(ontology: dict, set_type: str, args, tokenizer, embedding_model, save_to_file=True):
-    """
-    Get embeddings for slots and candidates
-
-    Args:
-        ontology (dict): Dictionary of domain-slot pair descriptions and possible value sets
-        set_type (str): Subset of the dataset being used (train/validation/test)
-        args (argument class): Runtime arguments
-        tokenizer (transformers Tokenizer): Tokenizer for the embedding_model
-        embedding_model (transformer Model): Transormer model for embedding candidate descriptions
-        save_to_file (bool): Indication of whether to save information to file
-
-    Returns:
-        slots (dict): domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
-    """
-    # Set model to eval mode
-    embedding_model.eval()
-
-    slots = dict()
-    for domain, subset in tqdm(ontology.items(), desc='Domains'):
-        for slot, slot_info in tqdm(subset.items(), desc='Slots'):
-            # Get description or use "domain-slot"
-            if args.use_descriptions:
-                desc = slot_info['description']
-            else:
-                desc = f"{domain}-{slot}"
-
-            # Encode domain-slot pair description
-            slot_emb = encode_candidates([desc], args, tokenizer, embedding_model)[0]
-
-            # Obtain possible value set and discard requestable value
-            values = deepcopy(slot_info['possible_values'])
-            is_requestable = False
-            if '?' in values:
-                is_requestable = True
-                values.remove('?')
-
-            # Encode value candidates
-            if values:
-                feats = encode_candidates(values, args, tokenizer, embedding_model)
-            else:
-                feats = None
-
-            # Store domain-slot description embeddings, candidate embeddings and requestabke flag for each domain-slot
-            slots[f"{domain}-{slot}"] = (slot_emb, feats, is_requestable)
-
-    # Dump tensors and ontology for use in training and evaluation
-    if save_to_file:
-        writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type)
-        torch.save(slots, writer)
-
-        writer = open(os.path.join(args.output_dir, 'database', '%s.json' % set_type), 'w')
-        json.dump(ontology, writer, indent=2)
-        writer.close()
-    
-    return slots
diff --git a/convlab/dst/setsumbt/datasets/__init__.py b/convlab/dst/setsumbt/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f78b84c53dabfea8fac71a5d7f99ef1a7ac69dd0
--- /dev/null
+++ b/convlab/dst/setsumbt/datasets/__init__.py
@@ -0,0 +1,4 @@
+from convlab.dst.setsumbt.datasets.unified_format import get_dataloader, change_batch_size, dataloader_sample_dialogues
+from convlab.dst.setsumbt.datasets.metrics import (JointGoalAccuracy, BeliefStateUncertainty,
+                                                   ActPredictionAccuracy, Metrics)
+from convlab.dst.setsumbt.datasets.distillation import get_dataloader as get_distillation_dataloader
diff --git a/convlab/dst/setsumbt/datasets/distillation.py b/convlab/dst/setsumbt/datasets/distillation.py
new file mode 100644
index 0000000000000000000000000000000000000000..50697582cb220d6e009d21dbc57b03caa37841b3
--- /dev/null
+++ b/convlab/dst/setsumbt/datasets/distillation.py
@@ -0,0 +1,135 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Get ensemble predictions and build distillation dataloaders"""
+
+import os
+
+import torch
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+
+from convlab.dst.setsumbt.datasets.unified_format import UnifiedFormatDataset
+from convlab.dst.setsumbt.datasets.utils import IdTensor
+
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def get_dataloader(ensemble_path: str, set_type: str = 'train', batch_size: int = 3,
+                   reduction: str = 'none') -> DataLoader:
+    """
+    Get dataloader for distillation of ensemble.
+
+    Args:
+        ensemble_path: Path to ensemble model and predictive distributions.
+        set_type: Dataset split to load.
+        batch_size: Batch size.
+        reduction: Reduction to apply to ensemble predictive distributions.
+
+    Returns:
+        loader: Dataloader for distillation.
+    """
+    # Load data and predictions from ensemble
+    path = os.path.join(ensemble_path, 'dataloaders', f"{set_type}.dataloader")
+    dataset = torch.load(path).dataset
+
+    path = os.path.join(ensemble_path, 'predictions', f"{set_type}.data")
+    data = torch.load(path)
+
+    dialogue_ids = data.pop('dialogue_ids')
+
+    # Preprocess data
+    data = reduce_data(data, reduction=reduction)
+    data = flatten_data(data)
+    data = do_label_padding(data)
+
+    # Build dataset and dataloader
+    data = UnifiedFormatDataset.from_datadict(set_type=set_type if set_type != 'dev' else 'validation',
+                                              data=data,
+                                              ontology=dataset.ontology,
+                                              ontology_embeddings=dataset.ontology_embeddings)
+    data.features['dialogue_ids'] = IdTensor(dialogue_ids)
+
+    if set_type == 'train':
+        sampler = RandomSampler(data)
+    else:
+        sampler = SequentialSampler(data)
+
+    loader = DataLoader(data, sampler=sampler, batch_size=batch_size)
+    return loader
+
+
+def reduce_data(data: dict, reduction: str = 'none') -> dict:
+    """
+    Reduce ensemble predictive distributions.
+
+    Args:
+        data: Dictionary of ensemble predictive distributions.
+        reduction: Reduction to apply to ensemble predictive distributions.
+
+    Returns:
+        data: Reduced ensemble predictive distributions.
+    """
+    if reduction == 'mean':
+        data['belief_state'] = {slot: probs.mean(-2) for slot, probs in data['belief_state'].items()}
+        if 'request_probabilities' in data:
+            data['request_probabilities'] = {slot: probs.mean(-1)
+                                             for slot, probs in data['request_probabilities'].items()}
+            data['active_domain_probabilities'] = {domain: probs.mean(-1)
+                                                   for domain, probs in data['active_domain_probabilities'].items()}
+            data['general_act_probabilities'] = data['general_act_probabilities'].mean(-2)
+    return data
+
+
+def do_label_padding(data: dict) -> dict:
+    """
+    Add padding to the ensemble predictions (used as labels in distillation)
+
+    Args:
+        data: Dictionary of ensemble predictions
+
+    Returns:
+        data: Padded ensemble predictions
+    """
+    if 'attention_mask' in data:
+        dialogs, turns = torch.where(data['attention_mask'].sum(-1) == 0.0)
+    else:
+        dialogs, turns = torch.where(data['input_ids'].sum(-1) == 0.0)
+
+    for key in data:
+        if key not in ['input_ids', 'attention_mask', 'token_type_ids']:
+            data[key][dialogs, turns] = -1
+
+    return data
+
+
+def flatten_data(data: dict) -> dict:
+    """
+    Map data to flattened feature format used in training
+
+    Args:
+        data: Ensemble prediction data
+
+    Returns:
+        data: Flattened ensemble prediction data
+    """
+    data_new = dict()
+    for label, feats in data.items():
+        if isinstance(feats, dict):
+            for label_, feats_ in feats.items():
+                data_new[label + '-' + label_] = feats_
+        else:
+            data_new[label] = feats
+
+    return data_new
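Note: `flatten_data` above maps the nested ensemble-prediction dicts to the flat `label-sublabel` feature keys used in training. A quick sketch of its behaviour on toy tensors (the slot name is illustrative, not taken from a real ontology):

```python
import torch

from convlab.dst.setsumbt.datasets.distillation import flatten_data

# Toy input mimicking the ensemble prediction structure: one nested
# belief-state entry and one flat entry. The slot name is illustrative.
data = {
    'belief_state': {'hotel-area': torch.rand(2, 12, 5)},  # [dials, turns, values]
    'general_act_probabilities': torch.rand(2, 12, 3),
}
flat = flatten_data(data)
print(sorted(flat))  # ['belief_state-hotel-area', 'general_act_probabilities']
```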
diff --git a/convlab/dst/setsumbt/datasets/metrics.py b/convlab/dst/setsumbt/datasets/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..690d75baeaee4ef6ec5c1d963e8238db0044c57b
--- /dev/null
+++ b/convlab/dst/setsumbt/datasets/metrics.py
@@ -0,0 +1,566 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Metrics for DST models."""
+
+import json
+import os
+
+import torch
+from transformers.utils import ModelOutput
+from matplotlib import pyplot as plt
+
+from convlab.util import load_dataset
+from convlab.util import load_dst_data
+from convlab.dst.setsumbt.datasets.value_maps import VALUE_MAP, QUANTITIES
+
+
+class Metrics(ModelOutput):
+    """Metrics for DST models."""
+    def __add__(self, other):
+        """Add two metrics objects."""
+        for key, itm in other.items():
+            assert key not in self
+            self[key] = itm
+        return self
+
+    def compute_score(self, **weights):
+        """
+        Compute the score for the metrics object.
+
+        Args:
+            request (float): The weight for the request F1 score.
+            active_domain (float): The weight for the active domain F1 score.
+            general_act (float): The weight for the general act F1 score.
+        """
+        assert 'joint_goal_accuracy' in self
+        self.score = 0.0
+        if 'request_f1' in self and 'request' in weights:
+            self.score += self.request_f1 * weights['request']
+        if 'active_domain_f1' in self and 'active_domain' in weights:
+            self.score += self.active_domain_f1 * weights['active_domain']
+        if 'general_act_f1' in self and 'general_act' in weights:
+            self.score += self.general_act_f1 * weights['general_act']
+        self.score += self.joint_goal_accuracy
+
+    def __gt__(self, other):
+        """Compare two metrics objects."""
+        assert isinstance(other, Metrics)
+
+        if self.joint_goal_accuracy > other.joint_goal_accuracy:
+            return True
+        elif 'score' in self and 'score' in other and self.score > other.score:
+            return True
+        elif self.training_loss < other.training_loss:
+            return True
+        else:
+            return False
+
+class JointGoalAccuracy:
+    """Joint goal accuracy metric."""
+
+    def __init__(self, dataset_names, validation_split='test'):
+        """
+        Initialize the joint goal accuracy metric.
+
+        Args:
+            dataset_names (str): The name of the dataset(s) to use for computing the metric.
+            validation_split (str): The split of the dataset to use for computing the metric.
+        """
+        self.dataset_names = dataset_names.split('+')
+        self.validation_split = validation_split
+        self._extract_data()
+        self._extract_states()
+        self._init_session()
+
+    def _extract_data(self):
+        """Extract the data from the dataset."""
+        dataset_dicts = [load_dataset(dataset_name=name) for name in self.dataset_names]
+        self.golden_states = dict()
+        for dataset_dict in dataset_dicts:
+            dataset = load_dst_data(dataset_dict, data_split=self.validation_split, speaker='all', dialogue_acts=True,
+                                    split_to_turn=False)
+            for dial in dataset[self.validation_split]:
+                self.golden_states[dial['dialogue_id']] = dial['turns']
+
+    @staticmethod
+    def _clean_state(state):
+        """
+        Clean the state to remove pipe separated values and map values to the standard set.
+
+        Args:
+            state (dict): The state to clean.
+
+        Returns:
+            dict: The cleaned state.
+        """
+        clean_state = dict()
+        for domain, subset in state.items():
+            clean_state[domain] = {}
+            for slot, value in subset.items():
+                value = value.split('|')
+
+                # Map values using value_map
+                for old, new in VALUE_MAP.items():
+                    value = [val.replace(old, new) for val in value]
+                value = '|'.join(value)
+
+                # Map dontcare to "do not care" and empty to 'none'
+                value = value.replace('dontcare', 'do not care')
+                value = value if value else 'none'
+
+                # Map quantity values to the integer quantity value
+                if 'people' in slot or 'duration' in slot or 'stay' in slot:
+                    try:
+                        if value not in ['do not care', 'none']:
+                            value = int(value)
+                            value = str(value) if value < 10 else QUANTITIES[-1]
+                    except ValueError:
+                        pass
+                # Map time values to the most appropriate value in the standard time set
+                elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
+                    try:
+                        if value not in ['do not care', 'none']:
+                            # Strip after/before from time value
+                            value = value.replace('after ', '').replace('before ', '')
+                            # Extract hours and minutes from different possible formats
+                            if ':' not in value and len(value) == 4:
+                                h, m = value[:2], value[2:]
+                            elif len(value) == 1:
+                                h = int(value)
+                                m = 0
+                            elif 'pm' in value:
+                                h = int(value.replace('pm', '')) + 12
+                                m = 0
+                            elif 'am' in value:
+                                h = int(value.replace('am', ''))
+                                m = 0
+                            elif ':' in value:
+                                h, m = value.split(':')
+                            elif ';' in value:
+                                h, m = value.split(';')
+                            # Map to closest 5 minutes
+                            if int(m) % 5 != 0:
+                                m = round(int(m) / 5) * 5
+                                h = int(h)
+                                if m == 60:
+                                    m = 0
+                                    h += 1
+                                if h >= 24:
+                                    h -= 24
+                            # Set in standard 24 hour format
+                            h, m = int(h), int(m)
+                            value = '%02i:%02i' % (h, m)
+                    except Exception:
+                        pass
+                # Map boolean slots to yes/no value
+                elif 'parking' in slot or 'internet' in slot:
+                    if value not in ['do not care', 'none']:
+                        if value == 'free':
+                            value = 'yes'
+                        elif any(v in value.lower() for v in ['yes', 'no']):
+                            value = [v for v in ['yes', 'no'] if v in value.lower()][0]
+
+                value = value if value != 'none' else ''
+
+                clean_state[domain][slot] = value
+
+        return clean_state
+
+    def _extract_states(self):
+        """Extract the states from the dataset."""
+        for dial_id, dial in self.golden_states.items():
+            states = list()
+            for turn in dial:
+                if 'state' in turn:
+                    state = self._clean_state(turn['state'])
+                    states.append(state)
+            self.golden_states[dial_id] = states
+
+    def _init_session(self):
+        """Initialize the session."""
+        self.samples = dict()
+
+    def add_dialogues(self, predictions):
+        """
+        Add dialogues to the metric.
+
+        Args:
+            predictions (dict): Dictionary of predicted dialogue belief states.
+        """
+        for dial_id, dialogue in predictions.items():
+            for turn_id, turn in enumerate(dialogue):
+                if dial_id in self.golden_states:
+                    sample = {'dialogue_id': dial_id,
+                              'turn_id': turn_id,
+                              'state': self.golden_states[dial_id][turn_id],
+                              'predictions': turn['belief_state']}
+                    self.samples[f"{dial_id}_{turn_id}"] = sample
+
+    def save_dialogues(self, path):
+        """
+        Save the dialogues and predictions to a file.
+
+        Args:
+            path (str): The path to save the dialogues to.
+        """
+        dialogues = list()
+        for idx, turn in self.samples.items():
+            predictions = dict()
+            for domain in turn['state']:
+                predictions[domain] = dict()
+                for slot in turn['state'][domain]:
+                    predictions[domain][slot] = turn['predictions'].get(domain, dict()).get(slot, '')
+            dialogues.append({'dialogue_id': turn['dialogue_id'],
+                              'turn_id': turn['turn_id'],
+                              'state': turn['state'],
+                              'predictions': {'state': predictions}})
+
+        with open(path, 'w') as writer:
+            json.dump(dialogues, writer, indent=2)
+            writer.close()
+
+    def evaluate(self):
+        """Evaluate the metric."""
+        assert len(self.samples) > 0
+        metrics = {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0, 'Correct': 0, 'N': 0}
+        for dial_id, sample in self.samples.items():
+            correct = True
+            for domain in sample['state']:
+                for slot, values in sample['state'][domain].items():
+                    metrics['N'] += 1
+                    if domain not in sample['predictions'] or slot not in sample['predictions'][domain]:
+                        predict_values = ''
+                    else:
+                        predict_values = ''.join(sample['predictions'][domain][slot].split()).lower()
+                    if len(values) > 0:
+                        if len(predict_values) > 0:
+                            values = [''.join(value.split()).lower() for value in values.split('|')]
+                            predict_values = [''.join(value.split()).lower() for value in predict_values.split('|')]
+                            if any([value in values for value in predict_values]):
+                                metrics['TP'] += 1
+                            else:
+                                correct = False
+                                metrics['FP'] += 1
+                        else:
+                            metrics['FN'] += 1
+                            correct = False
+                    else:
+                        if len(predict_values) > 0:
+                            metrics['FP'] += 1
+                            correct = False
+                        else:
+                            metrics['TN'] += 1
+
+            metrics['Correct'] += int(correct)
+
+        TP = metrics.pop('TP')
+        FP = metrics.pop('FP')
+        FN = metrics.pop('FN')
+        TN = metrics.pop('TN')
+        Correct = metrics.pop('Correct')
+        N = metrics.pop('N')
+        precision = 1.0 * TP / (TP + FP) if TP + FP else 0.
+        recall = 1.0 * TP / (TP + FN) if TP + FN else 0.
+        f1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0.
+        slot_accuracy = (TP + TN) / N
+        joint_goal_accuracy = Correct / len(self.samples)
+
+        metrics = Metrics(joint_goal_accuracy=joint_goal_accuracy * 100.,
+                          slot_accuracy=slot_accuracy * 100.,
+                          slot_f1=f1 * 100.,
+                          slot_precision=precision * 100.,
+                          slot_recall=recall * 100.)
+
+        return metrics
+
+
+class BeliefStateUncertainty:
+    """Compute the uncertainty of the belief state predictions."""
+
+    def __init__(self, n_confidence_bins=10):
+        """
+        Initialize the metric.
+
+        Args:
+            n_confidence_bins (int): Number of confidence bins.
+        """
+        self._init_session()
+        self.n_confidence_bins = n_confidence_bins
+
+    def _init_session(self):
+        """Initialize the session."""
+        self.samples = {'belief_state': dict(),
+                        'golden_state': dict()}
+        self.bin_info = {'confidence': None,
+                         'accuracy': None}
+
+    def add_dialogues(self, predictions, labels):
+        """
+        Add dialogues to the metric.
+
+        Args:
+            predictions (dict): Dictionary of predicted dialogue belief states.
+            labels (dict): Dictionary of golden dialogue belief states.
+        """
+        for slot, probs in predictions.items():
+            if slot not in self.samples['belief_state']:
+                self.samples['belief_state'][slot] = probs.reshape(-1, probs.size(-1)).cpu()
+                self.samples['golden_state'][slot] = labels[slot].reshape(-1).cpu()
+            else:
+                self.samples['belief_state'][slot] = torch.cat((self.samples['belief_state'][slot],
+                                                                probs.reshape(-1, probs.size(-1)).cpu()), 0)
+                self.samples['golden_state'][slot] = torch.cat((self.samples['golden_state'][slot],
+                                                                labels[slot].reshape(-1).cpu()), 0)
+
+    def _fill_bins(self, probs: torch.Tensor) -> list:
+        """
+        Fill the bins with the relevant observation ids.
+
+        Args:
+            probs (Tensor): Predicted probabilities.
+
+        Returns:
+            list: List of bins.
+        """
+        assert probs.dim() == 2
+        probs = probs.max(-1)[0]
+
+        step = 1.0 / self.n_confidence_bins
+        bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
+        bins = []
+        # Compute the bin ranges
+        for b in range(self.n_confidence_bins):
+            lower, upper = bin_ranges[b], bin_ranges[b + 1]
+            if b == 0:
+                ids = torch.where((probs >= lower) * (probs <= upper))[0]
+            else:
+                ids = torch.where((probs > lower) * (probs <= upper))[0]
+            bins.append(ids)
+
+        return bins
+
+    @staticmethod
+    def _bin_confidence(bins: list, probs: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the average confidence score for each bin.
+
+        Args:
+            bins (list): List of confidence bins.
+            probs (Tensor): Predicted probabilities.
+
+        Returns:
+            scores: Confidence score for each bin.
+        """
+        probs = probs.max(-1)[0]
+
+        scores = []
+        for b in bins:
+            if b.size(0) > 0:
+                scores.append(probs[b].mean())
+            else:
+                scores.append(-1)
+        scores = torch.tensor(scores)
+        return scores
+
+    def _jg_ece(self) -> float:
+        """Compute the joint goal Expected Calibration Error."""
+        y_pred = {slot: probs.argmax(-1) for slot, probs in self.samples['belief_state'].items()}
+        goal_acc = [(y_pred[slot] == y_true).int() for slot, y_true in self.samples['golden_state'].items()]
+        goal_acc = (sum(goal_acc) / len(goal_acc)).int()
+
+        # Confidence score is the minimum across slots, as a single bad prediction makes the whole state prediction incorrect
+        scores = [probs.max(-1)[0].unsqueeze(1) for slot, probs in self.samples["belief_state"].items()]
+        scores = torch.cat(scores, 1).min(1)[0]
+
+        bins = self._fill_bins(scores.unsqueeze(-1))
+        conf = self._bin_confidence(bins, scores.unsqueeze(-1))
+
+        slot_0 = list(self.samples['golden_state'].keys())[0]
+        acc = []
+        for b in bins:
+            if b.size(0) > 0:
+                acc_ = goal_acc[b]
+                acc_ = acc_[self.samples['golden_state'][slot_0][b] >= 0]
+                if acc_.size(0) > 0:
+                    acc.append(acc_.float().mean())
+                else:
+                    acc.append(-1)
+            else:
+                acc.append(-1)
+        acc = torch.tensor(acc)
+
+        self.bin_info['confidence'] = conf
+        self.bin_info['accuracy'] = acc
+
+        n = self.samples["belief_state"][slot_0].size(0)
+        bk = torch.tensor([b.size(0) for b in bins])
+
+        ece = torch.abs(conf - acc) * bk / n
+        ece = ece[acc >= 0.0]
+        ece = ece.sum().item()
+
+        return ece
+
+    def draw_calibration_diagram(self, save_path: str, validation_split=None):
+        """
+        Draw the calibration diagram.
+
+        Args:
+            save_path (str): Path to save the calibration diagram.
+            validation_split (str): Validation split.
+        """
+        if self.bin_info['confidence'] is None:
+            self._jg_ece()
+
+        acc = self.bin_info['accuracy']
+        conf = self.bin_info['confidence']
+        conf = conf[acc >= 0.0]
+        acc = acc[acc >= 0.0]
+
+        plt.figure(figsize=(14, 8))
+        font = 20
+        plt.tick_params(labelsize=font - 2)
+        linestyle = '-'
+
+        plt.plot(torch.tensor([0, 1]), torch.tensor([0, 1]), linestyle='--', color='black', linewidth=3)
+        plt.plot(conf, acc, linestyle=linestyle, color='red', linewidth=3)
+        plt.xlabel('Confidence', fontsize=font)
+        plt.ylabel('Joint Goal Accuracy', fontsize=font)
+
+        path = validation_split + '_calibration_diagram.json' if validation_split else 'calibration_diagram.json'
+        path = os.path.join(save_path, 'predictions', path)
+        with open(path, 'w') as f:
+            json.dump({'confidence': conf.tolist(), 'accuracy': acc.tolist()}, f)
+
+        path = validation_split + '_calibration_diagram.png' if validation_split else 'calibration_diagram.png'
+        path = os.path.join(save_path, path)
+        plt.savefig(path)
+
+    def _l2_err(self, remove_belief: bool = False) -> float:
+        """
+        Compute the L2 error between the predicted and target distribution.
+
+        Args:
+            remove_belief (bool): Remove the belief state and replace it with a 1 hot prediction.
+
+        Returns:
+            l2_err: L2 error between the predicted and target distribution.
+        """
+        # Get ids used for removing padding turns.
+        slot_0 = list(self.samples['golden_state'].keys())[0]
+        padding = torch.where(self.samples['golden_state'][slot_0] != -1)[0]
+
+        distributions = []
+        labels = []
+        for slot, probs in self.samples['belief_state'].items():
+            # Replace distribution by a 1 hot prediction
+            if remove_belief:
+                probs_ = torch.zeros(probs.shape).float()
+                probs_[range(probs.size(0)), probs.argmax(-1)] = 1.0
+                probs = probs_
+                del probs_
+            # Remove padding turns
+            lab = self.samples['golden_state'][slot]
+            probs = probs[padding]
+            lab = lab[padding]
+
+            # Target distribution
+            y = torch.zeros(probs.shape)
+            y[range(y.size(0)), lab] = 1.0
+
+            distributions.append(probs)
+            labels.append(y)
+
+        # Concatenate all slots into a single belief state
+        distributions = torch.cat(distributions, -1)
+        labels = torch.cat(labels, -1)
+
+        # Calculate L2-Error for each turn
+        err = torch.sqrt(((labels - distributions) ** 2).sum(-1))
+        return err.mean().item()
+
+    def evaluate(self):
+        """Evaluate the metrics."""
+        l2_err = self._l2_err(remove_belief=False)
+        binary_l2_err = self._l2_err(remove_belief=True)
+        l2_err_ratio = (binary_l2_err - l2_err) / binary_l2_err
+        metrics = Metrics(
+            joint_goal_ece=self._jg_ece() * 100.,
+            joint_l2_error=l2_err,
+            joint_l2_error_ratio=l2_err_ratio * 100.
+        )
+        return metrics
+
+
+class ActPredictionAccuracy:
+    """Calculate the accuracy of the action predictions."""
+
+    def __init__(self, act_type, binary=False):
+        """
+        Args:
+            act_type (str): Type of action to evaluate.
+            binary (bool): Whether the action is binary or multilabel.
+        """
+        self.act_type = act_type
+        self.binary = binary
+        self._init_session()
+
+    def _init_session(self):
+        """Initialize the session."""
+        self.samples = {'predictions': dict(),
+                        'labels': dict()}
+
+    def add_dialogues(self, predictions, labels):
+        """
+        Add dialogues to the session.
+
+        Args:
+            predictions (dict): Action predictions.
+            labels (dict): Action labels.
+        """
+        for slot, probs in predictions.items():
+            if slot in labels:
+                pred = probs.cpu().argmax(-1).reshape(-1) if not self.binary else probs.cpu().round().int().reshape(-1)
+                if slot not in self.samples['predictions']:
+                    self.samples['predictions'][slot] = pred
+                    self.samples['labels'][slot] = labels[slot].reshape(-1).cpu()
+                else:
+                    self.samples['predictions'][slot] = torch.cat((self.samples['predictions'][slot], pred), 0)
+                    self.samples['labels'][slot] = torch.cat((self.samples['labels'][slot],
+                                                              labels[slot].reshape(-1).cpu()), 0)
+
+    def evaluate(self):
+        """Evaluate the metrics."""
+        metrics = {'TP': 0, 'FP': 0, 'FN': 0, 'Correct': 0, 'N': 0}
+        for slot, pred in self.samples['predictions'].items():
+            metrics['N'] += pred.size(0)
+            metrics['Correct'] += (pred == self.samples['labels'][slot]).sum()
+            tp = (pred > 0) * (self.samples['labels'][slot] > 0) * (pred == self.samples['labels'][slot])
+            metrics['TP'] += tp.sum()
+            metrics['FP'] += ((pred > 0) * (self.samples['labels'][slot] == 0)).sum()
+            metrics['FN'] += ((pred == 0) * (self.samples['labels'][slot] > 0)).sum()
+
+        TP = metrics.pop('TP')
+        FP = metrics.pop('FP')
+        FN = metrics.pop('FN')
+        Correct = metrics.pop('Correct')
+        N = metrics.pop('N')
+        precision = 1.0 * TP / (TP + FP) if TP + FP else 0.
+        recall = 1.0 * TP / (TP + FN) if TP + FN else 0.
+        f1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0.
+
+        metrics = {f'{self.act_type}_f1': f1 * 100.}
+        return Metrics(**metrics)
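Note: `_jg_ece` above implements the standard binned expected calibration error: a turn's joint-goal confidence is the minimum slot confidence, turns are bucketed into confidence bins, and the per-bin gaps between mean confidence and mean accuracy are summed weighted by bin occupancy. A toy sketch of that final reduction (all numbers invented):

```python
import torch

# Toy sketch of the final reduction in BeliefStateUncertainty._jg_ece;
# per-bin confidences, accuracies and counts below are invented numbers.
conf = torch.tensor([0.55, 0.75, 0.95])  # mean minimum-slot confidence per bin
acc = torch.tensor([0.50, 0.80, 0.90])   # mean joint goal accuracy per bin
bk = torch.tensor([10, 30, 60])          # observations per bin
n = bk.sum()

ece = (torch.abs(conf - acc) * bk / n).sum().item()
print(f"{ece:.4f}")  # 0.0500
```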
diff --git a/convlab/dst/setsumbt/dataset/unified_format.py b/convlab/dst/setsumbt/datasets/unified_format.py
similarity index 50%
rename from convlab/dst/setsumbt/dataset/unified_format.py
rename to convlab/dst/setsumbt/datasets/unified_format.py
index 55483e0f4e3404e96d817395c53dd9a6fcd57c3e..68a371f7a8f54c8258b20e584abba6a9894a9e17 100644
--- a/convlab/dst/setsumbt/dataset/unified_format.py
+++ b/convlab/dst/setsumbt/datasets/unified_format.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,260 +14,81 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Convlab3 Unified Format Dialogue Datasets"""
-import pdb
-from copy import deepcopy
 
 import torch
 import transformers
 from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
 from transformers.tokenization_utils import PreTrainedTokenizer
-from tqdm import tqdm
 
 from convlab.util import load_dataset
-from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values,
-                                                get_values_from_data, ontology_add_requestable_slots,
-                                                get_requestable_slots, load_dst_data, extract_dialogues,
-                                                combine_value_sets, IdTensor)
+from convlab.dst.setsumbt.datasets.utils import (get_ontology_slots, ontology_add_values,
+                                                 get_values_from_data, ontology_add_requestable_slots,
+                                                 get_requestable_slots, load_dst_data, extract_dialogues,
+                                                 combine_value_sets)
 
 transformers.logging.set_verbosity_error()
 
 
-def convert_examples_to_features(data: list,
-                                 ontology: dict,
-                                 tokenizer: PreTrainedTokenizer,
-                                 max_turns: int = 12,
-                                 max_seq_len: int = 64) -> dict:
-    """
-    Convert dialogue examples to model input features and labels
-
-    Args:
-        data (list): List of all extracted dialogues
-        ontology (dict): Ontology dictionary containing slots, slot descriptions and
-        possible value sets including requests
-        tokenizer (PreTrainedTokenizer): Tokenizer for the encoder model used
-        max_turns (int): Maximum numbers of turns in a dialogue
-        max_seq_len (int): Maximum number of tokens in a dialogue turn
-
-    Returns:
-        features (dict): All inputs and labels required to train the model
-    """
-    features = dict()
-    ontology = deepcopy(ontology)
-
-    # Get encoder input for system, user utterance pairs
-    input_feats = []
-    for dial in tqdm(data):
-        dial_feats = []
-        for turn in dial:
-            if len(turn['system_utterance']) == 0:
-                usr = turn['user_utterance']
-                dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens=True,
-                                                        max_length=max_seq_len, padding='max_length',
-                                                        truncation='longest_first'))
-            else:
-                usr = turn['user_utterance']
-                sys = turn['system_utterance']
-                dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens=True,
-                                                        max_length=max_seq_len, padding='max_length',
-                                                        truncation='longest_first'))
-            # Truncate
-            if len(dial_feats) >= max_turns:
-                break
-        input_feats.append(dial_feats)
-    del dial_feats
-
-    # Perform turn level padding
-    dial_ids = list()
-    for dial in data:
-        _ids = [turn['dialogue_id'] for turn in dial][:max_turns]
-        _ids += [''] * (max_turns - len(_ids))
-        dial_ids.append(_ids)
-    input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
-                 for dial in input_feats]
-    if 'token_type_ids' in input_feats[0][0]:
-        token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
-                          for dial in input_feats]
-    else:
-        token_type_ids = None
-    if 'attention_mask' in input_feats[0][0]:
-        attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
-                          for dial in input_feats]
-    else:
-        attention_mask = None
-    del input_feats
-
-    # Create torch data tensors
-    features['dialogue_ids'] = IdTensor(dial_ids)
-    features['input_ids'] = torch.tensor(input_ids)
-    features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
-    features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
-    del input_ids, token_type_ids, attention_mask
-
-    # Extract all informable and requestable slots from the ontology
-    informable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain]
-                        if ontology[domain][slot]['possible_values']
-                        and ontology[domain][slot]['possible_values'] != ['?']]
-    requestable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain]
-                         if '?' in ontology[domain][slot]['possible_values']]
-    for slot in requestable_slots:
-        domain, slot = slot.split('-', 1)
-        ontology[domain][slot]['possible_values'].remove('?')
-
-    # Extract a list of domains from the ontology slots
-    domains = list(set(informable_slots + requestable_slots))
-    domains = list(set([slot.split('-', 1)[0] for slot in domains]))
-
-    # Create slot labels
-    for domslot in tqdm(informable_slots):
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial:
-                value = [v for d, substate in turn['state'].items() for s, v in substate.items()
-                         if f'{d}-{s}' == domslot]
-                domain, slot = domslot.split('-', 1)
-                if turn['dataset_name'] in ontology[domain][slot]['dataset_names']:
-                    value = value[0] if value else 'none'
-                else:
-                    value = -1
-                if value in ontology[domain][slot]['possible_values'] and value != -1:
-                    value = ontology[domain][slot]['possible_values'].index(value)
-                else:
-                    value = -1  # If value is not in ontology then we do not penalise the model
-                labs.append(value)
-                if len(labs) >= max_turns:
-                    break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-
-        labels = torch.tensor(labels)
-        features['state_labels-' + domslot] = labels
-
-    # Create requestable slot labels
-    for domslot in tqdm(requestable_slots):
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial:
-                domain, slot = domslot.split('-', 1)
-                if turn['dataset_name'] in ontology[domain][slot]['dataset_names']:
-                    acts = [act['intent'] for act in turn['dialogue_acts']
-                            if act['domain'] == domain and act['slot'] == slot]
-                    if acts:
-                        act_ = acts[0]
-                        if act_ == 'request':
-                            labs.append(1)
-                        else:
-                            labs.append(0)
-                    else:
-                        labs.append(0)
-                else:
-                    labs.append(-1)
-                if len(labs) >= max_turns:
-                    break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-
-        labels = torch.tensor(labels)
-        features['request_labels-' + domslot] = labels
-
-    # General act labels (1-goodbye, 2-thank you)
-    labels = []
-    for dial in tqdm(data):
-        labs = []
-        for turn in dial:
-            acts = [act['intent'] for act in turn['dialogue_acts'] if act['intent'] in ['bye', 'thank']]
-            if acts:
-                if 'bye' in acts:
-                    labs.append(1)
-                else:
-                    labs.append(2)
-            else:
-                labs.append(0)
-            if len(labs) >= max_turns:
-                break
-        labs = labs + [-1] * (max_turns - len(labs))
-        labels.append(labs)
-
-    labels = torch.tensor(labels)
-    features['general_act_labels'] = labels
-
-    # Create active domain labels
-    for domain in tqdm(domains):
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial:
-                possible_domains = list()
-                for dom in ontology:
-                    for slt in ontology[dom]:
-                        if turn['dataset_name'] in ontology[dom][slt]['dataset_names']:
-                            possible_domains.append(dom)
-
-                if domain in turn['active_domains']:
-                    labs.append(1)
-                elif domain in possible_domains:
-                    labs.append(0)
-                else:
-                    labs.append(-1)
-                if len(labs) >= max_turns:
-                    break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-
-        labels = torch.tensor(labels)
-        features['active_domain_labels-' + domain] = labels
-
-    del labels
-
-    return features
-
-
 class UnifiedFormatDataset(Dataset):
     """
     Class for preprocessing, and storing data easily from the Convlab3 unified format.
 
     Attributes:
-        dataset_dict (dict): Dictionary containing all the data in dataset
+        set_type (str): Subset of the dataset to load (train, validation or test)
+        dataset_dicts (list): List of loaded dataset dictionaries containing all the data
         ontology (dict): Set of all domain-slot-value triplets in the ontology of the model
+        ontology_embeddings (dict): Embeddings of the domain-slot descriptions and candidate values in the ontology
         features (dict): Set of numeric features containing all inputs and labels formatted for the SetSUMBT model
     """
     def __init__(self,
                  dataset_name: str,
                  set_type: str,
                  tokenizer: PreTrainedTokenizer,
+                 ontology_encoder,
                  max_turns: int = 12,
                  max_seq_len: int = 64,
                  train_ratio: float = 1.0,
                  seed: int = 0,
                  data: dict = None,
-                 ontology: dict = None):
+                 ontology: dict = None,
+                 ontology_embeddings: dict = None):
         """
         Args:
             dataset_name (str): Name of the dataset(s) to load (multiple datasets separated by +)
             set_type (str): Subset of the dataset to load (train, validation or test)
             tokenizer (transformers tokenizer): Tokenizer for the encoder model used
+            ontology_encoder (transformers model): Ontology encoder model
             max_turns (int): Maximum number of turns in a dialogue
             max_seq_len (int): Maximum number of tokens in a dialogue turn
             train_ratio (float): Fraction of training data to use during training
             seed (int): Seed governing random order of ids for subsampling
             data (dict): Dataset features for loading from dict
             ontology (dict): Ontology dict for loading from dict
+            ontology_embeddings (dict): Ontology embeddings for loading from dict
         """
+        # Load data from dict if provided
         if data is not None:
+            self.set_type = set_type
             self.ontology = ontology
+            self.ontology_embeddings = ontology_embeddings
             self.features = data
+        # Load data from dataset if data is not provided
         else:
             if '+' in dataset_name:
                 dataset_args = [{"dataset_name": name} for name in dataset_name.split('+')]
             else:
                 dataset_args = [{"dataset_name": dataset_name}]
             self.dataset_dicts = [load_dataset(**dataset_args_) for dataset_args_ in dataset_args]
+            self.set_type = set_type
+
             self.ontology = get_ontology_slots(dataset_name)
             values = [get_values_from_data(dataset, set_type) for dataset in self.dataset_dicts]
             self.ontology = ontology_add_values(self.ontology, combine_value_sets(values), set_type)
             self.ontology = ontology_add_requestable_slots(self.ontology, get_requestable_slots(self.dataset_dicts))
 
+            tokenizer.set_setsumbt_ontology(self.ontology)
+            self.ontology_embeddings = ontology_encoder.get_slot_candidate_embeddings()
+
             if train_ratio != 1.0:
                 for dataset_args_ in dataset_args:
                     dataset_args_['dial_ids_order'] = seed
@@ -282,7 +103,7 @@ class UnifiedFormatDataset(Dataset):
             data = []
             for idx, data_ in enumerate(data_list):
                 data += extract_dialogues(data_, dataset_args[idx]["dataset_name"])
-            self.features = convert_examples_to_features(data, self.ontology, tokenizer, max_turns, max_seq_len)
+            self.features = tokenizer.encode(data, max_turns, max_seq_len)
 
     def __getitem__(self, index: int) -> dict:
         """
@@ -350,14 +171,15 @@ class UnifiedFormatDataset(Dataset):
                          if self.features[label] is not None}
 
     @classmethod
-    def from_datadict(cls, data: dict, ontology: dict):
-        return cls(None, None, None, data=data, ontology=ontology)
+    def from_datadict(cls, set_type: str, data: dict, ontology: dict, ontology_embeddings: dict):
+        return cls(None, set_type, None, None, data=data, ontology=ontology, ontology_embeddings=ontology_embeddings)
 
 
 def get_dataloader(dataset_name: str,
                    set_type: str,
                    batch_size: int,
                    tokenizer: PreTrainedTokenizer,
+                   ontology_encoder,
                    max_turns: int = 12,
                    max_seq_len: int = 64,
                    device='cpu',
@@ -372,6 +194,7 @@ def get_dataloader(dataset_name: str,
         set_type (str): Subset of the dataset to load (train, validation or test)
         batch_size (int): Batch size for the dataloader
         tokenizer (transformers tokenizer): Tokenizer for the encoder model used
+        ontology_encoder (OntologyEncoder): Ontology encoder object
         max_turns (int): Maximum number of turns in a dialogue
         max_seq_len (int): Maximum number of tokens in a dialogue turn
         device (torch device): Device to map data to
@@ -382,8 +205,8 @@ def get_dataloader(dataset_name: str,
     Returns:
         loader (torch dataloader): Dataloader to train and evaluate the setsumbt model
     '''
-    data = UnifiedFormatDataset(dataset_name, set_type, tokenizer, max_turns, max_seq_len, train_ratio=train_ratio,
-                                seed=seed)
+    data = UnifiedFormatDataset(dataset_name, set_type, tokenizer, ontology_encoder, max_turns, max_seq_len,
+                                train_ratio=train_ratio, seed=seed)
     data.to(device)
 
     if resampled_size:
@@ -418,6 +241,7 @@ def change_batch_size(loader: DataLoader, batch_size: int) -> DataLoader:
 
     return loader
 
+
 def dataloader_sample_dialogues(loader: DataLoader, sample_size: int) -> DataLoader:
     """
     Sample a subset of the dialogues in a dataloader
diff --git a/convlab/dst/setsumbt/dataset/utils.py b/convlab/dst/setsumbt/datasets/utils.py
similarity index 99%
rename from convlab/dst/setsumbt/dataset/utils.py
rename to convlab/dst/setsumbt/datasets/utils.py
index 96773d6b9b181925b3004e4971e440d9c7720bfb..f227a569c5cfc782dda1fbedb61b5afbffa17cf5 100644
--- a/convlab/dst/setsumbt/dataset/utils.py
+++ b/convlab/dst/setsumbt/datasets/utils.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,10 +16,9 @@
 """Convlab3 Unified dataset data processing utilities"""
 
 import numpy
-import pdb
 
 from convlab.util import load_ontology, load_dst_data, load_nlu_data
-from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
+from convlab.dst.setsumbt.datasets.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
 
 
 def get_ontology_slots(dataset_name: str) -> dict:
@@ -424,6 +423,7 @@ class IdTensor:
 def extract_dialogues(data: list, dataset_name: str) -> list:
     """
     Extract all dialogues from dataset
+
     Args:
         data (list): List of all dialogues in a subset of the data
         dataset_name (str): Name of the dataset to which the dialogues belong
diff --git a/convlab/dst/setsumbt/dataset/value_maps.py b/convlab/dst/setsumbt/datasets/value_maps.py
similarity index 96%
rename from convlab/dst/setsumbt/dataset/value_maps.py
rename to convlab/dst/setsumbt/datasets/value_maps.py
index 619600a7b0a57096918058ff117aa2ca5aac864a..d4ef64a0e21839e3ecade16ded4c03aea738fa98 100644
--- a/convlab/dst/setsumbt/dataset/value_maps.py
+++ b/convlab/dst/setsumbt/datasets/value_maps.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -47,4 +47,4 @@ DOMAINS_MAP = {'Alarm_1': 'alarm', 'Banks_1': 'banks', 'Banks_2': 'banks', 'Buse
 # Generic value sets for quantity and time slots
 QUANTITIES = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10 or more']
 TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)]
-TIME = ['%02i:%02i' % t for l in TIME for t in l]
\ No newline at end of file
+TIME = ['%02i:%02i' % t for l in TIME for t in l]
diff --git a/convlab/dst/setsumbt/distillation_setup.py b/convlab/dst/setsumbt/distillation_setup.py
deleted file mode 100644
index 2279e22265ea417ebe9a13e63837a625f858e73d..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/distillation_setup.py
+++ /dev/null
@@ -1,277 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Get ensemble predictions and build distillation dataloaders"""
-
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-import os
-import json
-
-import torch
-from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from tqdm import tqdm
-
-from convlab.dst.setsumbt.dataset.unified_format import UnifiedFormatDataset, change_batch_size
-from convlab.dst.setsumbt.modeling import EnsembleSetSUMBT
-from convlab.dst.setsumbt.modeling import training
-
-DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-
-def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size: int = 3) -> DataLoader:
-    """
-    Build dataloader from ensemble prediction data
-
-    Args:
-        data: Dictionary of ensemble predictions
-        ontology: Data ontology
-        set_type: Data subset (train/validation/test)
-        batch_size: Number of dialogues per batch
-
-    Returns:
-        loader: Data loader object
-    """
-    data = flatten_data(data)
-    data = do_label_padding(data)
-    data = UnifiedFormatDataset.from_datadict(data, ontology)
-    if set_type == 'train':
-        sampler = RandomSampler(data)
-    else:
-        sampler = SequentialSampler(data)
-
-    loader = DataLoader(data, sampler=sampler, batch_size=batch_size)
-    return loader
-
-
-def do_label_padding(data: dict) -> dict:
-    """
-    Add padding to the ensemble predictions (used as labels in distillation)
-
-    Args:
-        data: Dictionary of ensemble predictions
-
-    Returns:
-        data: Padded ensemble predictions
-    """
-    if 'attention_mask' in data:
-        dialogs, turns = torch.where(data['attention_mask'].sum(-1) == 0.0)
-    else:
-        dialogs, turns = torch.where(data['input_ids'].sum(-1) == 0.0)
-    
-    for key in data:
-        if key not in ['input_ids', 'attention_mask', 'token_type_ids']:
-            data[key][dialogs, turns] = -1
-    
-    return data
-
-
-def flatten_data(data: dict) -> dict:
-    """
-    Map data to flattened feature format used in training
-    Args:
-        data: Ensemble prediction data
-
-    Returns:
-        data: Flattened ensemble prediction data
-    """
-    data_new = dict()
-    for label, feats in data.items():
-        if type(feats) == dict:
-            for label_, feats_ in feats.items():
-                data_new[label + '-' + label_] = feats_
-        else:
-            data_new[label] = feats
-    
-    return data_new
-
-
-def get_ensemble_distributions(args):
-    """
-    Load data and get ensemble predictions
-    Args:
-        args: Runtime arguments
-    """
-    device = DEVICE
-
-    model = EnsembleSetSUMBT.from_pretrained(args.model_path)
-    model = model.to(device)
-
-    print('Model Loaded!')
-
-    dataloader = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
-    database = os.path.join(args.model_path, 'database', f'{args.set_type}.db')
-
-    dataloader = torch.load(dataloader)
-    database = torch.load(database)
-
-    if dataloader.batch_size != args.batch_size:
-        dataloader = change_batch_size(dataloader, args.batch_size)
-
-    training.set_ontology_embeddings(model, database)
-
-    print('Environment set up.')
-
-    input_ids = []
-    token_type_ids = []
-    attention_mask = []
-    state_labels = {slot: [] for slot in model.informable_slot_ids}
-    request_labels = {slot: [] for slot in model.requestable_slot_ids}
-    active_domain_labels = {domain: [] for domain in model.domain_ids}
-    general_act_labels = []
-
-    is_noisy = [] if 'is_noisy' in dataloader.dataset.features else None
-
-    belief_state = {slot: [] for slot in model.informable_slot_ids}
-    request_probs = {slot: [] for slot in model.requestable_slot_ids}
-    active_domain_probs = {domain: [] for domain in model.domain_ids}
-    general_act_probs = []
-    model.eval()
-    for batch in tqdm(dataloader, desc='Batch:'):
-        ids = batch['input_ids']
-        tt_ids = batch['token_type_ids'] if 'token_type_ids' in batch else None
-        mask = batch['attention_mask'] if 'attention_mask' in batch else None
-
-        if 'is_noisy' in batch:
-            is_noisy.append(batch['is_noisy'])
-
-        input_ids.append(ids)
-        token_type_ids.append(tt_ids)
-        attention_mask.append(mask)
-
-        ids = ids.to(device)
-        tt_ids = tt_ids.to(device) if tt_ids is not None else None
-        mask = mask.to(device) if mask is not None else None
-
-        for slot in state_labels:
-            state_labels[slot].append(batch['state_labels-' + slot])
-        if model.config.predict_actions:
-            for slot in request_labels:
-                request_labels[slot].append(batch['request_labels-' + slot])
-            for domain in active_domain_labels:
-                active_domain_labels[domain].append(batch['active_domain_labels-' + domain])
-            general_act_labels.append(batch['general_act_labels'])
-
-        with torch.no_grad():
-            p, p_req, p_dom, p_gen, _ = model(ids, mask, tt_ids, reduction=args.reduction)
-
-        for slot in belief_state:
-            belief_state[slot].append(p[slot].cpu())
-        if model.config.predict_actions:
-            for slot in request_probs:
-                request_probs[slot].append(p_req[slot].cpu())
-            for domain in active_domain_probs:
-                active_domain_probs[domain].append(p_dom[domain].cpu())
-            general_act_probs.append(p_gen.cpu())
-    
-    input_ids = torch.cat(input_ids, 0) if input_ids[0] is not None else None
-    token_type_ids = torch.cat(token_type_ids, 0) if token_type_ids[0] is not None else None
-    attention_mask = torch.cat(attention_mask, 0) if attention_mask[0] is not None else None
-    is_noisy = torch.cat(is_noisy, 0) if is_noisy is not None else None
-
-    state_labels = {slot: torch.cat(l, 0) for slot, l in state_labels.items()}
-    if model.config.predict_actions:
-        request_labels = {slot: torch.cat(l, 0) for slot, l in request_labels.items()}
-        active_domain_labels = {domain: torch.cat(l, 0) for domain, l in active_domain_labels.items()}
-        general_act_labels = torch.cat(general_act_labels, 0)
-    
-    belief_state = {slot: torch.cat(p, 0) for slot, p in belief_state.items()}
-    if model.config.predict_actions:
-        request_probs = {slot: torch.cat(p, 0) for slot, p in request_probs.items()}
-        active_domain_probs = {domain: torch.cat(p, 0) for domain, p in active_domain_probs.items()}
-        general_act_probs = torch.cat(general_act_probs, 0)
-
-    data = {'input_ids': input_ids}
-    if token_type_ids is not None:
-        data['token_type_ids'] = token_type_ids
-    if attention_mask is not None:
-        data['attention_mask'] = attention_mask
-    if is_noisy is not None:
-        data['is_noisy'] = is_noisy
-    data['state_labels'] = state_labels
-    data['belief_state'] = belief_state
-    if model.config.predict_actions:
-        data['request_labels'] = request_labels
-        data['active_domain_labels'] = active_domain_labels
-        data['general_act_labels'] = general_act_labels
-        data['request_probs'] = request_probs
-        data['active_domain_probs'] = active_domain_probs
-        data['general_act_probs'] = general_act_probs
-
-    file = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
-    torch.save(data, file)
-
-
-def ensemble_distribution_data_to_predictions_format(model_path: str, set_type: str):
-    """
-    Convert ensemble predictions to predictions file format.
-
-    Args:
-        model_path: Path to ensemble location.
-        set_type: Evaluation dataset (train/dev/test).
-    """
-    data = torch.load(os.path.join(model_path, 'dataloaders', f"{set_type}.data"))
-
-    # Get oracle labels
-    if 'request_probs' in data:
-        data_new = {'state_labels': data['state_labels'],
-                    'request_labels': data['request_labels'],
-                    'active_domain_labels': data['active_domain_labels'],
-                    'general_act_labels': data['general_act_labels']}
-    else:
-        data_new = {'state_labels': data['state_labels']}
-
-    # Marginalising across ensemble distributions
-    data_new['belief_states'] = {slot: distribution.mean(-2) for slot, distribution in data['belief_state'].items()}
-    if 'request_probs' in data:
-        data_new['request_probs'] = {slot: distribution.mean(-1)
-                                     for slot, distribution in data['request_probs'].items()}
-        data_new['active_domain_probs'] = {domain: distribution.mean(-1)
-                                           for domain, distribution in data['active_domain_probs'].items()}
-        data_new['general_act_probs'] = data['general_act_probs'].mean(-2)
-
-    # Save predictions file
-    predictions_dir = os.path.join(model_path, 'predictions')
-    if not os.path.exists(predictions_dir):
-        os.mkdir(predictions_dir)
-    torch.save(data_new, os.path.join(predictions_dir, f"{set_type}.predictions"))
-
-
-if __name__ == "__main__":
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--model_path', type=str)
-    parser.add_argument('--set_type', type=str)
-    parser.add_argument('--batch_size', type=int, default=3)
-    parser.add_argument('--reduction', type=str, default='none')
-    parser.add_argument('--get_ensemble_distributions', action='store_true')
-    parser.add_argument('--convert_distributions_to_predictions', action='store_true')
-    parser.add_argument('--build_dataloaders', action='store_true')
-    args = parser.parse_args()
-
-    if args.get_ensemble_distributions:
-        get_ensemble_distributions(args)
-    if args.convert_distributions_to_predictions:
-        ensemble_distribution_data_to_predictions_format(args.model_path, args.set_type)
-    if args.build_dataloaders:
-        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
-        data = torch.load(path)
-
-        reader = open(os.path.join(args.model_path, 'database', f'{args.set_type}.json'), 'r')
-        ontology = json.load(reader)
-        reader.close()
-
-        loader = get_loader(data, ontology, args.set_type, args.batch_size)
-
-        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
-        torch.save(loader, path)
diff --git a/convlab/dst/setsumbt/do/evaluate.py b/convlab/dst/setsumbt/do/evaluate.py
deleted file mode 100644
index 2fe351b3d5c2af187da58ffcc46e8184013bbcdb..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/do/evaluate.py
+++ /dev/null
@@ -1,296 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Run SetSUMBT Calibration"""
-
-import logging
-import os
-
-import torch
-from transformers import (BertModel, BertConfig, BertTokenizer,
-                          RobertaModel, RobertaConfig, RobertaTokenizer)
-
-from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
-from convlab.dst.setsumbt.dataset import unified_format
-from convlab.dst.setsumbt.dataset import ontology as embeddings
-from convlab.dst.setsumbt.utils import get_args, update_args
-from convlab.dst.setsumbt.modeling import evaluation_utils
-from convlab.dst.setsumbt.loss.uncertainty_measures import ece, jg_ece, l2_acc
-from convlab.dst.setsumbt.modeling import training
-
-
-# Available model
-MODELS = {
-    'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
-    'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
-}
-
-
-def main(args=None, config=None):
-    # Get arguments
-    if args is None:
-        args, config = get_args(MODELS)
-
-    if args.model_type in MODELS:
-        SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
-    else:
-        raise NameError('NotImplemented')
-
-    # Set up output directory
-    OUTPUT_DIR = args.output_dir
-    args.output_dir = OUTPUT_DIR
-    if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
-        os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
-
-    # Set pretrained model path to the trained checkpoint
-    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
-    if 'pytorch_model.bin' in paths and 'config.json' in paths:
-        args.model_name_or_path = args.output_dir
-        config = ConfigClass.from_pretrained(args.model_name_or_path)
-    else:
-        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
-        if paths:
-            paths = paths[0]
-            args.model_name_or_path = paths
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
-
-    args = update_args(args, config)
-
-    # Create logger
-    global logger
-    logger = logging.getLogger(__name__)
-    logger.setLevel(logging.INFO)
-
-    formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')
-
-    fh = logging.FileHandler(args.logging_path)
-    fh.setLevel(logging.INFO)
-    fh.setFormatter(formatter)
-    logger.addHandler(fh)
-
-    # Get device
-    if torch.cuda.is_available() and args.n_gpu > 0:
-        device = torch.device('cuda')
-    else:
-        device = torch.device('cpu')
-        args.n_gpu = 0
-
-    if args.n_gpu == 0:
-        args.fp16 = False
-
-    # Set up model training/evaluation
-    evaluation_utils.set_seed(args)
-
-    # Perform tasks
-    if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')):
-        pred = torch.load(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
-        state_labels = pred['state_labels']
-        belief_states = pred['belief_states']
-        if 'request_labels' in pred:
-            request_labels = pred['request_labels']
-            request_probs = pred['request_probs']
-            active_domain_labels = pred['active_domain_labels']
-            active_domain_probs = pred['active_domain_probs']
-            general_act_labels = pred['general_act_labels']
-            general_act_probs = pred['general_act_probs']
-        else:
-            request_probs = None
-        del pred
-    else:
-        # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size != args.test_batch_size:
-                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
-        else:
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
-                                                            args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                            config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
-        else:
-            encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
-            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
-                                                                  'test', args, tokenizer, encoder)
-
-        # Initialise Model
-        model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
-        model = model.to(device)
-
-        training.set_ontology_embeddings(model, test_slots)
-
-        belief_states = evaluation_utils.get_predictions(args, model, device, test_dataloader)
-        state_labels = belief_states[1]
-        request_probs = belief_states[2]
-        request_labels = belief_states[3]
-        active_domain_probs = belief_states[4]
-        active_domain_labels = belief_states[5]
-        general_act_probs = belief_states[6]
-        general_act_labels = belief_states[7]
-        belief_states = belief_states[0]
-        out = {'belief_states': belief_states, 'state_labels': state_labels, 'request_probs': request_probs,
-               'request_labels': request_labels, 'active_domain_probs': active_domain_probs,
-               'active_domain_labels': active_domain_labels, 'general_act_probs': general_act_probs,
-               'general_act_labels': general_act_labels}
-        torch.save(out, os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
-
-    # Calculate calibration metrics
-    jg = jg_ece(belief_states, state_labels, 10)
-    logger.info('Joint Goal ECE: %f' % jg)
-
-    jg_acc = 0.0
-    padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0
-    padding = (padding == len(state_labels))
-    padding = padding.reshape(-1)
-    for slot in belief_states:
-        p_ = belief_states[slot]
-        gold = state_labels[slot]
-
-        pred = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1)
-        acc = [lab in s for lab, s, pad in zip(gold.reshape(-1), pred, padding) if not pad]
-        acc = torch.tensor(acc).float()
-
-        jg_acc += acc
-
-    n_turns = jg_acc.size(0)
-    jg_acc = sum((jg_acc / len(belief_states)).int()).float()
-
-    jg_acc /= n_turns
-
-    logger.info(f'Joint Goal Accuracy: {jg_acc}')
-
-    l2 = l2_acc(belief_states, state_labels, remove_belief=False)
-    logger.info(f'Model L2 Norm Goal Accuracy: {l2}')
-    l2 = l2_acc(belief_states, state_labels, remove_belief=True)
-    logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}')
-
-    padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0
-    padding = (padding == len(state_labels))
-    padding = padding.reshape(-1)
-
-    tp, fp, fn, tn, n = 0.0, 0.0, 0.0, 0.0, 0.0
-    for slot in belief_states:
-        p_ = belief_states[slot]
-        gold = state_labels[slot].reshape(-1)
-        p_ = p_.reshape(-1, p_.size(-1))
-
-        p_ = p_[~padding].argmax(-1)
-        gold = gold[~padding]
-
-        tp += (p_ == gold)[gold != 0].int().sum().item()
-        fp += (p_ != 0)[gold == 0].int().sum().item()
-        fp += (p_ != gold)[gold != 0].int().sum().item()
-        fp -= (p_ == 0)[gold != 0].int().sum().item()
-        fn += (p_ == 0)[gold != 0].int().sum().item()
-        tn += (p_ == 0)[gold == 0].int().sum().item()
-        n += p_.size(0)
-
-    acc = (tp + tn) / n
-    prec = tp / (tp + fp)
-    rec = tp / (tp + fn)
-    f1 = 2 * (prec * rec) / (prec + rec)
-
-    logger.info(f"Slot Accuracy: {acc}, Slot F1: {f1}, Slot Precision: {prec}, Slot Recall: {rec}")
-
-    if request_probs is not None:
-        tp, fp, fn = 0.0, 0.0, 0.0
-        for slot in request_probs:
-            p = request_probs[slot]
-            l = request_labels[slot]
-
-            tp += (p.round().int() * (l == 1)).reshape(-1).float()
-            fp += (p.round().int() * (l == 0)).reshape(-1).float()
-            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
-        tp /= len(request_probs)
-        fp /= len(request_probs)
-        fn /= len(request_probs)
-        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
-        logger.info('Request F1 Score: %f' % f1.item())
-
-        for slot in request_probs:
-            p = request_probs[slot]
-            p = p.unsqueeze(-1)
-            p = torch.cat((1 - p, p), -1)
-            request_probs[slot] = p
-        jg = jg_ece(request_probs, request_labels, 10)
-        logger.info('Request Joint Goal ECE: %f' % jg)
-
-        tp, fp, fn = 0.0, 0.0, 0.0
-        for dom in active_domain_probs:
-            p = active_domain_probs[dom]
-            l = active_domain_labels[dom]
-
-            tp += (p.round().int() * (l == 1)).reshape(-1).float()
-            fp += (p.round().int() * (l == 0)).reshape(-1).float()
-            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
-        tp /= len(active_domain_probs)
-        fp /= len(active_domain_probs)
-        fn /= len(active_domain_probs)
-        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
-        logger.info('Domain F1 Score: %f' % f1.item())
-
-        for dom in active_domain_probs:
-            p = active_domain_probs[dom]
-            p = p.unsqueeze(-1)
-            p = torch.cat((1 - p, p), -1)
-            active_domain_probs[dom] = p
-        jg = jg_ece(active_domain_probs, active_domain_labels, 10)
-        logger.info('Domain Joint Goal ECE: %f' % jg)
-
-        tp = ((general_act_probs.argmax(-1) > 0) *
-              (general_act_labels > 0)).reshape(-1).float().sum()
-        fp = ((general_act_probs.argmax(-1) > 0) *
-              (general_act_labels == 0)).reshape(-1).float().sum()
-        fn = ((general_act_probs.argmax(-1) == 0) *
-              (general_act_labels > 0)).reshape(-1).float().sum()
-        f1 = tp / (tp + 0.5 * (fp + fn))
-        logger.info('General Act F1 Score: %f' % f1.item())
-
-        err = ece(general_act_probs.reshape(-1, general_act_probs.size(-1)),
-                  general_act_labels.reshape(-1), 10)
-        logger.info('General Act ECE: %f' % err)
-
-        for slot in request_probs:
-            p = request_probs[slot].unsqueeze(-1)
-            request_probs[slot] = torch.cat((1 - p, p), -1)
-
-        l2 = l2_acc(request_probs, request_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm Request Accuracy: {l2}')
-        l2 = l2_acc(request_probs, request_labels, remove_belief=True)
-        logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}')
-
-        for slot in active_domain_probs:
-            p = active_domain_probs[slot].unsqueeze(-1)
-            active_domain_probs[slot] = torch.cat((1 - p, p), -1)
-
-        l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm Domain Accuracy: {l2}')
-        l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=True)
-        logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}')
-
-        general_act_labels = {'general': general_act_labels}
-        general_act_probs = {'general': general_act_probs}
-
-        l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm General Act Accuracy: {l2}')
-        l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False)
-        logger.info(f'Binary Model L2 Norm General Act Accuracy: {l2}')
-
-
-if __name__ == "__main__":
-    main()
diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py
deleted file mode 100644
index 21949e728aa03d261dbb901e64fbb73bfd662d13..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/do/nbt.py
+++ /dev/null
@@ -1,328 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Run SetSUMBT training/eval"""
-
-import logging
-import os
-from shutil import copy2 as copy
-import json
-from copy import deepcopy
-import pdb
-
-import torch
-import transformers
-from transformers import (BertModel, BertConfig, BertTokenizer,
-                          RobertaModel, RobertaConfig, RobertaTokenizer)
-from tensorboardX import SummaryWriter
-from tqdm import tqdm
-
-from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
-from convlab.dst.setsumbt.dataset import unified_format
-from convlab.dst.setsumbt.modeling import training
-from convlab.dst.setsumbt.dataset import ontology as embeddings
-from convlab.dst.setsumbt.utils import get_args, update_args
-from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble
-from convlab.util.custom_util import model_downloader
-
-
-# Available model
-MODELS = {
-    'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
-    'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
-}
-
-
-def main(args=None, config=None):
-    # Get arguments
-    if args is None:
-        args, config = get_args(MODELS)
-
-    if args.model_type in MODELS:
-        SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
-    else:
-        raise NameError('NotImplemented')
-
-    # Set up output directory
-    OUTPUT_DIR = args.output_dir
-
-    if not os.path.exists(OUTPUT_DIR):
-        if "http" not in OUTPUT_DIR:
-            os.makedirs(OUTPUT_DIR)
-            os.mkdir(os.path.join(OUTPUT_DIR, 'database'))
-            os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
-        else:
-            # Get path /.../convlab/dst/setsumbt/multiwoz/models
-            download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-            download_path = os.path.join(download_path, 'models')
-            if not os.path.exists(download_path):
-                os.mkdir(download_path)
-            model_downloader(download_path, OUTPUT_DIR)
-            # Downloadable model path format http://.../model_name.zip
-            OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '')
-            OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR)
-
-            args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1])
-            args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1])
-            os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
-    args.output_dir = OUTPUT_DIR
-
-    # Set pretrained model path to the trained checkpoint
-    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
-    if 'pytorch_model.bin' in paths and 'config.json' in paths:
-        args.model_name_or_path = args.output_dir
-        config = ConfigClass.from_pretrained(args.model_name_or_path,
-                                             local_files_only=args.transformers_local_files_only)
-    else:
-        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
-        if paths:
-            paths = paths[0]
-            args.model_name_or_path = paths
-            config = ConfigClass.from_pretrained(args.model_name_or_path,
-                                                 local_files_only=args.transformers_local_files_only)
-
-    args = update_args(args, config)
-
-    # Create TensorboardX writer
-    tb_writer = SummaryWriter(logdir=args.tensorboard_path)
-
-    # Create logger
-    global logger
-    logger = logging.getLogger(__name__)
-    logger.setLevel(logging.INFO)
-
-    formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')
-
-    fh = logging.FileHandler(args.logging_path)
-    fh.setLevel(logging.INFO)
-    fh.setFormatter(formatter)
-    logger.addHandler(fh)
-
-    # Get device
-    if torch.cuda.is_available() and args.n_gpu > 0:
-        device = torch.device('cuda')
-    else:
-        device = torch.device('cpu')
-        args.n_gpu = 0
-
-    if args.n_gpu == 0:
-        args.fp16 = False
-
-    # Initialise Model
-    transformers.utils.logging.set_verbosity_info()
-    model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config,
-                                          local_files_only=args.transformers_local_files_only)
-    model = model.to(device)
-
-    # Create Tokenizer and embedding model for Data Loaders and ontology
-    encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name,
-                                                    local_files_only=args.transformers_local_files_only)
-    tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config,
-                                          local_files_only=args.transformers_local_files_only)
-
-    # Set up model training/evaluation
-    training.set_logger(logger, tb_writer)
-    training.set_seed(args)
-    embeddings.set_seed(args)
-
-    transformers.utils.logging.set_verbosity_error()
-    if args.ensemble_size > 1:
-        # Build all dataloaders
-        train_dataloader = unified_format.get_dataloader(args.dataset,
-                                                         'train',
-                                                         args.train_batch_size,
-                                                         tokenizer,
-                                                         args.max_dialogue_len,
-                                                         args.max_turn_len,
-                                                         train_ratio=args.dataset_train_ratio,
-                                                         seed=args.seed)
-        torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
-        dev_dataloader = unified_format.get_dataloader(args.dataset,
-                                                       'validation',
-                                                       args.dev_batch_size,
-                                                       tokenizer,
-                                                       args.max_dialogue_len,
-                                                       args.max_turn_len,
-                                                       train_ratio=args.dataset_train_ratio,
-                                                       seed=args.seed)
-        torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-        test_dataloader = unified_format.get_dataloader(args.dataset,
-                                                        'test',
-                                                        args.test_batch_size,
-                                                        tokenizer,
-                                                        args.max_dialogue_len,
-                                                        args.max_turn_len,
-                                                        train_ratio=args.dataset_train_ratio,
-                                                        seed=args.seed)
-        torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
-        embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology, 'train', args, tokenizer, encoder)
-        embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, 'dev', args, tokenizer, encoder)
-        embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, 'test', args, tokenizer, encoder)
-
-        setup_ensemble(OUTPUT_DIR, args.ensemble_size)
-
-        logger.info(f'Building {args.ensemble_size} resampled dataloaders each of size {args.data_sampling_size}.')
-        dataloaders = [unified_format.dataloader_sample_dialogues(deepcopy(train_dataloader), args.data_sampling_size)
-                       for _ in tqdm(range(args.ensemble_size))]
-        logger.info('Dataloaders built.')
-
-        for i, loader in enumerate(dataloaders):
-            path = os.path.join(OUTPUT_DIR, 'ens-%i' % i)
-            if not os.path.exists(path):
-                os.mkdir(path)
-            path = os.path.join(path, 'dataloaders', 'train.dataloader')
-            torch.save(loader, path)
-        logger.info('Dataloaders saved.')
-
-        # Do not perform standard training after ensemble setup is created
-        return 0
-
-    # Perform tasks
-    # TRAINING
-    if args.do_train:
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')):
-            train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
-            if train_dataloader.batch_size != args.train_batch_size:
-                train_dataloader = unified_format.change_batch_size(train_dataloader, args.train_batch_size)
-        else:
-            if args.data_sampling_size <= 0:
-                args.data_sampling_size = None
-            train_dataloader = unified_format.get_dataloader(args.dataset,
-                                                             'train',
-                                                             args.train_batch_size,
-                                                             tokenizer,
-                                                             args.max_dialogue_len,
-                                                             config.max_turn_len,
-                                                             resampled_size=args.data_sampling_size,
-                                                             train_ratio=args.dataset_train_ratio,
-                                                             seed=args.seed)
-            torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
-
-        # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')):
-            train_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'train.db'))
-        else:
-            train_slots = embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology,
-                                                                   'train', args, tokenizer, encoder)
-
-        # Get development set batch loaders= and ontology embeddings
-        if args.do_eval:
-            if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
-                dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-                if dev_dataloader.batch_size != args.dev_batch_size:
-                    dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size)
-            else:
-                dev_dataloader = unified_format.get_dataloader(args.dataset,
-                                                               'validation',
-                                                               args.dev_batch_size,
-                                                               tokenizer,
-                                                               args.max_dialogue_len,
-                                                               config.max_turn_len)
-                torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-
-            if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
-                dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
-            else:
-                dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
-                                                                     'dev', args, tokenizer, encoder)
-        else:
-            dev_dataloader = None
-            dev_slots = None
-
-        # Load model ontology
-        training.set_ontology_embeddings(model, train_slots)
-
-        # TRAINING !!!!!!!!!!!!!!!!!!
-        training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots)
-
-        # Copy final best model to the output dir
-        checkpoints = os.listdir(OUTPUT_DIR)
-        checkpoints = [p for p in checkpoints if 'checkpoint' in p]
-        checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints])
-        best_checkpoint = os.path.join(OUTPUT_DIR, f'checkpoint-{checkpoints[-1]}')
-        copy(os.path.join(best_checkpoint, 'pytorch_model.bin'), os.path.join(OUTPUT_DIR, 'pytorch_model.bin'))
-        copy(os.path.join(best_checkpoint, 'config.json'), os.path.join(OUTPUT_DIR, 'config.json'))
-
-        # Load best model for evaluation
-        model = SetSumbtModel.from_pretrained(OUTPUT_DIR)
-        model = model.to(device)
-
-    # Evaluation on the development set
-    if args.do_eval:
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
-            dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-            if dev_dataloader.batch_size != args.dev_batch_size:
-                dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size)
-        else:
-            dev_dataloader = unified_format.get_dataloader(args.dataset,
-                                                           'validation',
-                                                           args.dev_batch_size,
-                                                           tokenizer,
-                                                           args.max_dialogue_len,
-                                                           config.max_turn_len)
-            torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
-            dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
-        else:
-            dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
-                                                                 'dev', args, tokenizer, encoder)
-
-        # Load model ontology
-        training.set_ontology_embeddings(model, dev_slots)
-
-        # EVALUATION
-        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss = training.evaluate(args, model, device, dev_dataloader)
-        training.log_info('dev', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
-
-    # Evaluation on the test set
-    if args.do_test:
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size != args.test_batch_size:
-                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
-        else:
-            test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
-                                                            args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                            config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
-        else:
-            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
-                                                                  'test', args, tokenizer, encoder)
-
-        # Load model ontology
-        training.set_ontology_embeddings(model, test_slots)
-
-        # TESTING
-        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, output = training.evaluate(args, model, device, test_dataloader,
-                                                                                 return_eval_output=True)
-
-        if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
-            os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
-        writer = open(os.path.join(OUTPUT_DIR, 'predictions', 'test.json'), 'w')
-        json.dump(output, writer)
-        writer.close()
-
-        training.log_info('test', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
-
-    tb_writer.close()
-
-
-if __name__ == "__main__":
-    main()
diff --git a/convlab/dst/setsumbt/get_golden_labels.py b/convlab/dst/setsumbt/get_golden_labels.py
deleted file mode 100644
index 7fb2841d0d503181119c791a7046fd7e0025d236..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/get_golden_labels.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import json
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-import os
-
-from tqdm import tqdm
-
-from convlab.util import load_dataset
-from convlab.util import load_dst_data
-from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
-
-
-def extract_data(dataset_names: str) -> list:
-    dataset_dicts = [load_dataset(dataset_name=name) for name in dataset_names.split('+')]
-    data = []
-    for dataset_dict in dataset_dicts:
-        dataset = load_dst_data(dataset_dict, data_split='test', speaker='all', dialogue_acts=True, split_to_turn=False)
-        for dial in dataset['test']:
-            data.append(dial)
-
-    return data
-
-def clean_state(state):
-    clean_state = dict()
-    for domain, subset in state.items():
-        clean_state[domain] = {}
-        for slot, value in subset.items():
-            # Remove pipe separated values
-            value = value.split('|')
-
-            # Map values using value_map
-            for old, new in VALUE_MAP.items():
-                value = [val.replace(old, new) for val in value]
-            value = '|'.join(value)
-
-            # Map dontcare to "do not care" and empty to 'none'
-            value = value.replace('dontcare', 'do not care')
-            value = value if value else 'none'
-
-            # Map quantity values to the integer quantity value
-            if 'people' in slot or 'duration' in slot or 'stay' in slot:
-                try:
-                    if value not in ['do not care', 'none']:
-                        value = int(value)
-                        value = str(value) if value < 10 else QUANTITIES[-1]
-                except:
-                    value = value
-            # Map time values to the most appropriate value in the standard time set
-            elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
-                try:
-                    if value not in ['do not care', 'none']:
-                        # Strip after/before from time value
-                        value = value.replace('after ', '').replace('before ', '')
-                        # Extract hours and minutes from different possible formats
-                        if ':' not in value and len(value) == 4:
-                            h, m = value[:2], value[2:]
-                        elif len(value) == 1:
-                            h = int(value)
-                            m = 0
-                        elif 'pm' in value:
-                            h = int(value.replace('pm', '')) + 12
-                            m = 0
-                        elif 'am' in value:
-                            h = int(value.replace('pm', ''))
-                            m = 0
-                        elif ':' in value:
-                            h, m = value.split(':')
-                        elif ';' in value:
-                            h, m = value.split(';')
-                        # Map to closest 5 minutes
-                        if int(m) % 5 != 0:
-                            m = round(int(m) / 5) * 5
-                            h = int(h)
-                            if m == 60:
-                                m = 0
-                                h += 1
-                            if h >= 24:
-                                h -= 24
-                        # Set in standard 24 hour format
-                        h, m = int(h), int(m)
-                        value = '%02i:%02i' % (h, m)
-                except:
-                    value = value
-            # Map boolean slots to yes/no value
-            elif 'parking' in slot or 'internet' in slot:
-                if value not in ['do not care', 'none']:
-                    if value == 'free':
-                        value = 'yes'
-                    elif True in [v in value.lower() for v in ['yes', 'no']]:
-                        value = [v for v in ['yes', 'no'] if v in value][0]
-
-            value = value if value != 'none' else ''
-
-            clean_state[domain][slot] = value
-
-    return clean_state
-
-def extract_states(data):
-    states_data = {}
-    for dial in data:
-        states = []
-        for turn in dial['turns']:
-            if 'state' in turn:
-                state = clean_state(turn['state'])
-                states.append(state)
-        states_data[dial['dialogue_id']] = states
-
-    return states_data
-
-
-def get_golden_state(prediction, data):
-    state = data[prediction['dial_idx']][prediction['utt_idx']]
-    pred = prediction['predictions']['state']
-    pred = {domain: {slot: pred.get(DOMAINS_MAP.get(domain, domain.lower()), dict()).get(slot, '')
-                     for slot in state[domain]} for domain in state}
-    prediction['state'] = state
-    prediction['predictions']['state'] = pred
-
-    return prediction
-
-
-if __name__ == "__main__":
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
-    parser.add_argument('--model_path', type=str, help='Path to model dir')
-    args = parser.parse_args()
-
-    data = extract_data(args.dataset_name)
-    data = extract_states(data)
-
-    reader = open(os.path.join(args.model_path, "predictions", "test.json"), 'r')
-    predictions = json.load(reader)
-    reader.close()
-
-    predictions = [get_golden_state(pred, data) for pred in tqdm(predictions)]
-
-    writer = open(os.path.join(args.model_path, "predictions", f"test_{args.dataset_name}.json"), 'w')
-    json.dump(predictions, writer)
-    writer.close()
diff --git a/convlab/dst/setsumbt/loss/__init__.py b/convlab/dst/setsumbt/loss/__init__.py
deleted file mode 100644
index 475f7646126ea03b630efcbbc688f86c5a8ec16e..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/loss/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from convlab.dst.setsumbt.loss.bayesian_matching import BayesianMatchingLoss, BinaryBayesianMatchingLoss
-from convlab.dst.setsumbt.loss.kl_distillation import KLDistillationLoss, BinaryKLDistillationLoss
-from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss
-from convlab.dst.setsumbt.loss.endd_loss import RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss
diff --git a/convlab/dst/setsumbt/loss/uncertainty_measures.py b/convlab/dst/setsumbt/loss/uncertainty_measures.py
deleted file mode 100644
index 87c89dd31c724cc7d599230c6d4a15faee9b680e..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/loss/uncertainty_measures.py
+++ /dev/null
@@ -1,222 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Uncertainty evaluation metrics for dialogue belief tracking"""
-
-import torch
-
-
-def fill_bins(n_bins: int, probs: torch.Tensor) -> list:
-    """
-    Function to split observations into bins based on predictive probabilities
-
-    Args:
-        n_bins (int): Number of bins
-        probs (Tensor): Predictive probabilities for the observations
-
-    Returns:
-        bins (list): List of observation ids for each bin
-    """
-    assert probs.dim() == 2
-    probs = probs.max(-1)[0]
-
-    step = 1.0 / n_bins
-    bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
-    bins = []
-    for b in range(n_bins):
-        lower, upper = bin_ranges[b], bin_ranges[b + 1]
-        if b == 0:
-            ids = torch.where((probs >= lower) * (probs <= upper))[0]
-        else:
-            ids = torch.where((probs > lower) * (probs <= upper))[0]
-        bins.append(ids)
-    return bins
-
-
-def bin_confidence(bins: list, probs: torch.Tensor) -> torch.Tensor:
-    """
-    Compute the confidence score within each bin
-
-    Args:
-        bins (list): List of observation ids for each bin
-        probs (Tensor): Predictive probabilities for the observations
-
-    Returns:
-        scores (Tensor): Average confidence score within each bin
-    """
-    probs = probs.max(-1)[0]
-
-    scores = []
-    for b in bins:
-        if b is not None:
-            scores.append(probs[b].mean())
-        else:
-            scores.append(-1)
-    scores = torch.tensor(scores)
-    return scores
-
-
-def bin_accuracy(bins: list, probs: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
-    """
-    Compute the accuracy score for observations in each bin
-
-    Args:
-        bins (list): List of observation ids for each bin
-        probs (Tensor): Predictive probabilities for the observations
-        y_true (Tensor): Labels for the observations
-
-    Returns:
-        acc (Tensor): Accuracies for the observations in each bin
-    """
-    y_pred = probs.argmax(-1)
-
-    acc = []
-    for b in bins:
-        if b is not None:
-            p = y_pred[b]
-            acc_ = (p == y_true[b]).float()
-            acc_ = acc_[y_true[b] >= 0]
-            if acc_.size(0) >= 0:
-                acc.append(acc_.mean())
-            else:
-                acc.append(-1)
-        else:
-            acc.append(-1)
-    acc = torch.tensor(acc)
-    return acc
-
-
-def ece(probs: torch.Tensor, y_true: torch.Tensor, n_bins: int) -> float:
-    """
-    Expected calibration error calculation
-
-    Args:
-        probs (Tensor): Predictive probabilities for the observations
-        y_true (Tensor): Labels for the observations
-        n_bins (int): Number of bins
-
-    Returns:
-        ece (float): Expected calibration error
-    """
-    bins = fill_bins(n_bins, probs)
-
-    scores = bin_confidence(bins, probs)
-    acc = bin_accuracy(bins, probs, y_true)
-
-    n = probs.size(0)
-    bk = torch.tensor([b.size(0) for b in bins])
-
-    ece = torch.abs(scores - acc) * bk / n
-    ece = ece[acc >= 0.0]
-    ece = ece.sum().item()
-
-    return ece
-
-
-def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float:
-    """
-        Joint goal expected calibration error calculation
-
-        Args:
-            belief_state (dict): Belief state probabilities for the dialogue turns
-            y_true (dict): Labels for the state in dialogue turns
-            n_bins (int): Number of bins
-
-        Returns:
-            ece (float): Joint goal expected calibration error
-        """
-    y_pred = {slot: bs.reshape(-1, bs.size(-1)).argmax(-1) for slot, bs in belief_state.items()}
-    goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred}
-    goal_acc = sum([goal_acc[slot] for slot in goal_acc])
-    goal_acc = (goal_acc == len(y_true)).int()
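-    # The joint goal is correct only if every slot in the turn is predicted correctly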
-
-    # Confidence score is the minimum across slots, as a single bad prediction leads to an incorrect state prediction
-    scores = [bs.reshape(-1, bs.size(-1)).max(-1)[0].unsqueeze(0) for slot, bs in belief_state.items()]
-    scores = torch.cat(scores, 0).min(0)[0]
-
-    bins = fill_bins(n_bins, scores.unsqueeze(-1))
-
-    conf = bin_confidence(bins, scores.unsqueeze(-1))
-
-    slot = [s for s in y_true][0]
-    acc = []
-    for b in bins:
-        if b.size(0) > 0:
-            acc_ = goal_acc[b]
-            acc_ = acc_[y_true[slot].reshape(-1)[b] >= 0]
-            if acc_.size(0) > 0:
-                acc.append(acc_.float().mean())
-            else:
-                acc.append(-1)
-        else:
-            acc.append(-1)
-    acc = torch.tensor(acc)
-
-    n = belief_state[slot].reshape(-1, belief_state[slot].size(-1)).size(0)
-    bk = torch.tensor([b.size(0) for b in bins])
-
-    ece = torch.abs(conf - acc) * bk / n
-    ece = ece[acc >= 0.0]
-    ece = ece.sum().item()
-
-    return ece
-
-
-def l2_acc(belief_state: dict, labels: dict, remove_belief: bool = False) -> float:
-    """
-    Compute L2 Error of belief state prediction
-
-    Args:
-        belief_state (dict): Belief state probabilities for the dialogue turns
-        labels (dict): Labels for the state in dialogue turns
-        remove_belief (bool): Convert belief state to dialogue state
-
-    Returns:
-        err (float): L2 Error of belief state prediction
-    """
-    # Get ids used for removing padding turns.
-    padding = labels[list(labels.keys())[0]].reshape(-1)
-    padding = torch.where(padding != -1)[0]
-
-    state = []
-    labs = []
-    for slot, bs in belief_state.items():
-        # Predictive Distribution
-        bs = bs.reshape(-1, bs.size(-1)).cuda()
-        # Replace distribution by a 1 hot prediction
-        if remove_belief:
-            bs_ = torch.zeros(bs.shape).float().cuda()
-            bs_[range(bs.size(0)), bs.argmax(-1)] = 1.0
-            bs = bs_
-            del bs_
-        # Remove padding turns
-        lab = labels[slot].reshape(-1).cuda()
-        bs = bs[padding]
-        lab = lab[padding]
-
-        # Target distribution
-        y = torch.zeros(bs.shape).cuda()
-        y[range(y.size(0)), lab] = 1.0
-
-        state.append(bs)
-        labs.append(y)
-
-    # Concatenate all slots into a single belief state
-    state = torch.cat(state, -1)
-    labs = torch.cat(labs, -1)
-
-    # Calculate L2-Error for each turn
-    err = torch.sqrt(((labs - state) ** 2).sum(-1))
-    return err.mean()
diff --git a/convlab/dst/setsumbt/modeling/__init__.py b/convlab/dst/setsumbt/modeling/__init__.py
index 59f1439948421ac365e4602b7800c94d3b8b32dd..502db2810b33262ab7edb40412a93dfcb7ba0786 100644
--- a/convlab/dst/setsumbt/modeling/__init__.py
+++ b/convlab/dst/setsumbt/modeling/__init__.py
@@ -1,5 +1,16 @@
-from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
-from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
-from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT
+from transformers import BertConfig, RobertaConfig
 
+from convlab.dst.setsumbt.modeling.setsumbt_nbt import BertSetSUMBT, RobertaSetSUMBT, EnsembleSetSUMBT
+from convlab.dst.setsumbt.modeling.ontology_encoder import OntologyEncoder
 from convlab.dst.setsumbt.modeling.temperature_scheduler import LinearTemperatureScheduler
+from convlab.dst.setsumbt.modeling.trainer import SetSUMBTTrainer
+from convlab.dst.setsumbt.modeling.tokenization import SetSUMBTTokenizer
+
+class BertSetSUMBTTokenizer(SetSUMBTTokenizer('bert')): pass
+class RobertaSetSUMBTTokenizer(SetSUMBTTokenizer('roberta')): pass
+
+SetSUMBTModels = {
+    'bert': (BertSetSUMBT, OntologyEncoder('bert'), BertConfig, BertSetSUMBTTokenizer),
+    'roberta': (RobertaSetSUMBT, OntologyEncoder('roberta'), RobertaConfig, RobertaSetSUMBTTokenizer),
+    'ensemble': (EnsembleSetSUMBT, None, None, None)
+}
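+
+# Illustrative usage (a sketch; 'roberta-base' is an assumed checkpoint name,
+# not a setting prescribed by this module):
+#
+#   model_cls, encoder_cls, config_cls, tokenizer_cls = SetSUMBTModels['roberta']
+#   config = config_cls.from_pretrained('roberta-base')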
diff --git a/convlab/dst/setsumbt/modeling/bert_nbt.py b/convlab/dst/setsumbt/modeling/bert_nbt.py
deleted file mode 100644
index 6762fb3891b4720c3889d8c0809b8791f3bf7633..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/modeling/bert_nbt.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""BERT SetSUMBT"""
-
-import torch
-from torch.autograd import Variable
-from transformers import BertModel, BertPreTrainedModel
-
-from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead
-
-
-class BertSetSUMBT(BertPreTrainedModel):
-
-    def __init__(self, config):
-        super(BertSetSUMBT, self).__init__(config)
-        self.config = config
-
-        # Turn Encoder
-        self.bert = BertModel(config)
-        if config.freeze_encoder:
-            for p in self.bert.parameters():
-                p.requires_grad = False
-
-        self.setsumbt = SetSUMBTHead(config)
-        self.add_slot_candidates = self.setsumbt.add_slot_candidates
-        self.add_value_candidates = self.setsumbt.add_value_candidates
-
-    def forward(self,
-                input_ids: torch.Tensor,
-                attention_mask: torch.Tensor,
-                token_type_ids: torch.Tensor = None,
-                hidden_state: torch.Tensor = None,
-                state_labels: torch.Tensor = None,
-                request_labels: torch.Tensor = None,
-                active_domain_labels: torch.Tensor = None,
-                general_act_labels: torch.Tensor = None,
-                get_turn_pooled_representation: bool = False,
-                calculate_state_mutual_info: bool = False):
-        """
-        Args:
-            input_ids: Input token ids
-            attention_mask: Input padding mask
-            token_type_ids: Token type indicator
-            hidden_state: Latent internal dialogue belief state
-            state_labels: Dialogue state labels
-            request_labels: User request action labels
-            active_domain_labels: Current active domain labels
-            general_act_labels: General user action labels
-            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
-            calculate_state_mutual_info: Return mutual information in the dialogue state
-
-        Returns:
-            out: Tuple containing loss, predictive distributions, model statistics and state mutual information
-        """
-
-        # Encode Dialogues
-        batch_size, dialogue_size, turn_size = input_ids.size()
-        input_ids = input_ids.reshape(-1, turn_size)
-        token_type_ids = token_type_ids.reshape(-1, turn_size)
-        attention_mask = attention_mask.reshape(-1, turn_size)
-
-        bert_output = self.bert(input_ids, token_type_ids, attention_mask)
-
-        attention_mask = attention_mask.float().unsqueeze(2)
-        attention_mask = attention_mask.repeat((1, 1, bert_output.last_hidden_state.size(-1)))
-        turn_embeddings = bert_output.last_hidden_state * attention_mask
-        turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
-
-        if get_turn_pooled_representation:
-            return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask,
-                                 batch_size, dialogue_size, hidden_state, state_labels,
-                                 request_labels, active_domain_labels, general_act_labels,
-                                 calculate_state_mutual_info) + (bert_output.pooler_output,)
-        return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask, batch_size,
-                             dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
-                             general_act_labels, calculate_state_mutual_info)
diff --git a/convlab/dst/setsumbt/modeling/ensemble_nbt.py b/convlab/dst/setsumbt/modeling/ensemble_nbt.py
deleted file mode 100644
index 6d3d8035a4d6f47f2ea8551050ca8da682ea0376..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/modeling/ensemble_nbt.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Ensemble SetSUMBT"""
-
-import os
-from shutil import copy2 as copy
-
-import torch
-from torch.nn import Module
-from transformers import RobertaConfig, BertConfig
-
-from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
-from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
-
-MODELS = {'bert': BertSetSUMBT, 'roberta': RobertaSetSUMBT}
-
-
-class EnsembleSetSUMBT(Module):
-    """Ensemble SetSUMBT Model for joint ensemble prediction"""
-
-    def __init__(self, config):
-        """
-        Args:
-            config (configuration): Model configuration class
-        """
-        super(EnsembleSetSUMBT, self).__init__()
-        self.config = config
-
-        # Initialise ensemble members
-        model_cls = MODELS[self.config.model_type]
-        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
-            setattr(self, attr, model_cls(config))
-
-    def _load(self, path: str):
-        """
-        Load parameters
-        Args:
-            path: Location of model parameters
-        """
-        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
-            idx = attr.split('_', 1)[-1]
-            state_dict = torch.load(os.path.join(path, f'ens-{idx}/pytorch_model.bin'))
-            getattr(self, attr).load_state_dict(state_dict)
-
-    def add_slot_candidates(self, slot_candidates: tuple):
-        """
-        Add slots to the model ontology. Each tuple should contain the slot embedding, the informable value
-        embeddings and a request indicator. If the informable value embeddings are None the slot is not
-        informable, and if the request indicator is false the slot is not requestable.
-
-        Args:
-            slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator
-        """
-        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
-            getattr(self, attr).add_slot_candidates(slot_candidates)
-        self.requestable_slot_ids = self.model_0.setsumbt.requestable_slot_ids
-        self.informable_slot_ids = self.model_0.setsumbt.informable_slot_ids
-        self.domain_ids = self.model_0.setsumbt.domain_ids
-
-    def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
-        """
-        Add value candidates for a slot
-
-        Args:
-            slot: Slot name
-            value_candidates: Value candidate embeddings
-            replace: If true existing value candidates are replaced
-        """
-        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
-            getattr(self, attr).add_value_candidates(slot, value_candidates, replace)
-
-    def forward(self,
-                input_ids: torch.Tensor,
-                attention_mask: torch.Tensor,
-                token_type_ids: torch.Tensor = None,
-                reduction: str = 'mean') -> tuple:
-        """
-        Args:
-            input_ids: Input token ids
-            attention_mask: Input padding mask
-            token_type_ids: Token type indicator
-            reduction: Reduction of ensemble member predictive distributions (mean, none)
-
-        Returns:
-            Tuple of belief state, request, active domain and general act probability distributions
-        """
-        belief_state_probs = {slot: [] for slot in self.informable_slot_ids}
-        request_probs = {slot: [] for slot in self.requestable_slot_ids}
-        active_domain_probs = {dom: [] for dom in self.domain_ids}
-        general_act_probs = []
-        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
-            # Prediction from each ensemble member
-            b, r, d, g, _ = getattr(self, attr)(input_ids=input_ids,
-                                                token_type_ids=token_type_ids,
-                                                attention_mask=attention_mask)
-            for slot in belief_state_probs:
-                belief_state_probs[slot].append(b[slot].unsqueeze(-2))
-            if self.config.predict_actions:
-                for slot in request_probs:
-                    request_probs[slot].append(r[slot].unsqueeze(-1))
-                for dom in active_domain_probs:
-                    active_domain_probs[dom].append(d[dom].unsqueeze(-1))
-                general_act_probs.append(g.unsqueeze(-2))
-        
-        belief_state_probs = {slot: torch.cat(l, -2) for slot, l in belief_state_probs.items()}
-        if self.config.predict_actions:
-            request_probs = {slot: torch.cat(l, -1) for slot, l in request_probs.items()}
-            active_domain_probs = {dom: torch.cat(l, -1) for dom, l in active_domain_probs.items()}
-            general_act_probs = torch.cat(general_act_probs, -2)
-        else:
-            request_probs = {}
-            active_domain_probs = {}
-            general_act_probs = torch.tensor(0.0)
-
-        # Apply reduction of ensemble to single posterior
-        if reduction == 'mean':
-            belief_state_probs = {slot: l.mean(-2) for slot, l in belief_state_probs.items()}
-            request_probs = {slot: l.mean(-1) for slot, l in request_probs.items()}
-            active_domain_probs = {dom: l.mean(-1) for dom, l in active_domain_probs.items()}
-            general_act_probs = general_act_probs.mean(-2)
-        elif reduction != 'none':
-            raise(NameError('Not Implemented!'))
-
-        return belief_state_probs, request_probs, active_domain_probs, general_act_probs, _
-    
-
-    @classmethod
-    def from_pretrained(cls, path):
-        config_path = os.path.join(path, 'ens-0', 'config.json')
-        if not os.path.exists(config_path):
-            raise(NameError('Could not find config.json in model path.'))
-        
-        try:
-            config = RobertaConfig.from_pretrained(config_path)
-        except:
-            config = BertConfig.from_pretrained(config_path)
-
-        config.ensemble_size = len([dir for dir in os.listdir(path) if 'ens-' in dir])
-        
-        model = cls(config)
-        model._load(path)
-
-        return model
-
-
-def setup_ensemble(model_path: str, ensemble_size: int):
-    """
-    Setup ensemble model directory structure.
-
-    Args:
-        model_path: Path to ensemble model directory
-        ensemble_size: Number of ensemble members
-    """
-    for i in range(ensemble_size):
-        path = os.path.join(model_path, f'ens-{i}')
-        if not os.path.exists(path):
-            os.mkdir(path)
-            os.mkdir(os.path.join(path, 'dataloaders'))
-            os.mkdir(os.path.join(path, 'database'))
-            # Add development set dataloader to each ensemble member directory
-            for set_type in ['dev']:
-                copy(os.path.join(model_path, 'dataloaders', f'{set_type}.dataloader'),
-                     os.path.join(path, 'dataloaders', f'{set_type}.dataloader'))
-            # Add training and development set ontologies to each ensemble member directory
-            for set_type in ['train', 'dev']:
-                copy(os.path.join(model_path, 'database', f'{set_type}.db'),
-                     os.path.join(path, 'database', f'{set_type}.db'))
diff --git a/convlab/dst/setsumbt/modeling/ensemble_utils.py b/convlab/dst/setsumbt/modeling/ensemble_utils.py
deleted file mode 100644
index 19f6abf81a4070b9498310adfab93d50f5a692f5..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/modeling/ensemble_utils.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Discriminative models calibration"""
-
-import random
-import os
-
-import torch
-import numpy as np
-from torch.distributions import Categorical
-from torch.nn.functional import kl_div
-from torch.nn import Module
-from tqdm import tqdm
-
-
-# Load logger and tensorboard summary writer
-def set_logger(logger_, tb_writer_):
-    global logger, tb_writer
-    logger = logger_
-    tb_writer = tb_writer_
-
-
-# Set seeds
-def set_seed(args):
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-    logger.info('Seed set to %d.' % args.seed)
-
-
-def build_train_loaders(args, tokenizer, dataset):
-    dataloaders = [dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len,
-                                            args.max_turn_len, resampled_size=args.data_sampling_size)
-                        for _ in range(args.ensemble_size)]
-    return dataloaders
diff --git a/convlab/dst/setsumbt/modeling/evaluation_utils.py b/convlab/dst/setsumbt/modeling/evaluation_utils.py
deleted file mode 100644
index c73d4b6d32a485a2cf2b5948dbd6a9a4d7f346cb..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/modeling/evaluation_utils.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Evaluation Utilities"""
-
-import random
-
-import torch
-import numpy as np
-from tqdm import tqdm
-
-
-def set_seed(args):
-    """
-    Set random seeds
-
-    Args:
-        args (Arguments class): Arguments class containing seed and number of gpus to use
-    """
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-
-def get_predictions(args, model, device: torch.device, dataloader: torch.utils.data.DataLoader) -> tuple:
-    """
-    Get model predictions
-
-    Args:
-        args: Runtime arguments
-        model: SetSUMBT Model
-        device: Torch device
-        dataloader: Dataloader containing eval data
-
-    Returns:
-        out: Tuple of predictive distributions and labels for states, requests, active domains and general acts
-    """
-    model.eval()
-    
-    belief_states = {slot: [] for slot in model.setsumbt.informable_slot_ids}
-    request_probs = {slot: [] for slot in model.setsumbt.requestable_slot_ids}
-    active_domain_probs = {dom: [] for dom in model.setsumbt.domain_ids}
-    general_act_probs = []
-    state_labels = {slot: [] for slot in model.setsumbt.informable_slot_ids}
-    request_labels = {slot: [] for slot in model.setsumbt.requestable_slot_ids}
-    active_domain_labels = {dom: [] for dom in model.setsumbt.domain_ids}
-    general_act_labels = []
-    epoch_iterator = tqdm(dataloader, desc="Iteration")
-    for step, batch in enumerate(epoch_iterator):
-        with torch.no_grad():    
-            input_ids = batch['input_ids'].to(device)
-            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
-            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
-
-            p, p_req, p_dom, p_gen, _ = model(input_ids=input_ids, token_type_ids=token_type_ids,
-                                              attention_mask=attention_mask)
-
-            for slot in belief_states:
-                p_ = p[slot]
-                labs = batch['state_labels-' + slot].to(device)
-                
-                belief_states[slot].append(p_)
-                state_labels[slot].append(labs)
-            
-            if p_req is not None:
-                for slot in request_probs:
-                    p_ = p_req[slot]
-                    labs = batch['request_labels-' + slot].to(device)
-
-                    request_probs[slot].append(p_)
-                    request_labels[slot].append(labs)
-                
-                for domain in active_domain_probs:
-                    p_ = p_dom[domain]
-                    labs = batch['active_domain_labels-' + domain].to(device)
-
-                    active_domain_probs[domain].append(p_)
-                    active_domain_labels[domain].append(labs)
-                
-                general_act_probs.append(p_gen)
-                general_act_labels.append(batch['general_act_labels'].to(device))
-    
-    for slot in belief_states:
-        belief_states[slot] = torch.cat(belief_states[slot], 0)
-        state_labels[slot] = torch.cat(state_labels[slot], 0)
-    if p_req is not None:
-        for slot in request_probs:
-            request_probs[slot] = torch.cat(request_probs[slot], 0)
-            request_labels[slot] = torch.cat(request_labels[slot], 0)
-        for domain in active_domain_probs:
-            active_domain_probs[domain] = torch.cat(active_domain_probs[domain], 0)
-            active_domain_labels[domain] = torch.cat(active_domain_labels[domain], 0)
-        general_act_probs = torch.cat(general_act_probs, 0)
-        general_act_labels = torch.cat(general_act_labels, 0)
-    else:
-        request_probs, request_labels, active_domain_probs, active_domain_labels = [None] * 4
-        general_act_probs, general_act_labels = [None] * 2
-
-    out = (belief_states, state_labels, request_probs, request_labels)
-    out += (active_domain_probs, active_domain_labels, general_act_probs, general_act_labels)
-    return out
diff --git a/convlab/dst/setsumbt/modeling/loss/__init__.py b/convlab/dst/setsumbt/modeling/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..81a5f76cdccf266f8ac702ad8a58e223d97fbc2c
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/loss/__init__.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Loss functions for SetSUMBT"""
+
+from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
+
+from convlab.dst.setsumbt.modeling.loss.bayesian_matching import (BayesianMatchingLoss,
+                                                                  BinaryBayesianMatchingLoss)
+from convlab.dst.setsumbt.modeling.loss.kl_distillation import KLDistillationLoss, BinaryKLDistillationLoss
+from convlab.dst.setsumbt.modeling.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss
+from convlab.dst.setsumbt.modeling.loss.endd_loss import (RKLDirichletMediatorLoss,
+                                                          BinaryRKLDirichletMediatorLoss)
+
+LOSS_MAP = {
+    'crossentropy': {'non-binary': CrossEntropyLoss,
+                     'binary': BCEWithLogitsLoss,
+                     'args': []},
+    'bayesianmatching': {'non-binary': BayesianMatchingLoss,
+                         'binary': BinaryBayesianMatchingLoss,
+                         'args': ['kl_scaling_factor']},
+    'labelsmoothing': {'non-binary': LabelSmoothingLoss,
+                       'binary': BinaryLabelSmoothingLoss,
+                       'args': ['label_smoothing']},
+    'distillation': {'non-binary': KLDistillationLoss,
+                     'binary': BinaryKLDistillationLoss,
+                     'args': ['ensemble_smoothing']},
+    'distribution_distillation': {'non-binary': RKLDirichletMediatorLoss,
+                                  'binary': BinaryRKLDirichletMediatorLoss,
+                                  'args': []}
+}
+
+def load(loss_function, binary=False):
+    """
+    Load loss function
+
+    Args:
+        loss_function (str): Loss function name
+        binary (bool): Whether to use binary loss function
+
+    Returns:
+        callable: Factory that builds the configured loss function module
+    """
+    assert loss_function in LOSS_MAP
+    args_list = LOSS_MAP[loss_function]['args']
+    loss_function = LOSS_MAP[loss_function]['binary' if binary else 'non-binary']
+
+    def loss_builder(ignore_index=-1, **kwargs):
+        args = {'ignore_index': ignore_index} if loss_function != BCEWithLogitsLoss else dict()
+        for arg, val in kwargs.items():
+            if arg in args_list:
+                args[arg] = val
+
+        return loss_function(**args)
+
+    return loss_builder
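+
+# Illustrative usage (a sketch with hypothetical hyperparameter values):
+#
+#   loss_fn = load('labelsmoothing')(ignore_index=-1, label_smoothing=0.05)
+#   binary_loss_fn = load('bayesianmatching', binary=True)(kl_scaling_factor=0.001)
+#
+# Keyword arguments not listed under LOSS_MAP[name]['args'] are dropped silently.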
diff --git a/convlab/dst/setsumbt/loss/bayesian_matching.py b/convlab/dst/setsumbt/modeling/loss/bayesian_matching.py
similarity index 87%
rename from convlab/dst/setsumbt/loss/bayesian_matching.py
rename to convlab/dst/setsumbt/modeling/loss/bayesian_matching.py
index 3e91444d60afeeb6e2ca54192dd2283810fc5135..66e37e6eddcc99457535564a6f49b4c11190a49c 100644
--- a/convlab/dst/setsumbt/loss/bayesian_matching.py
+++ b/convlab/dst/setsumbt/modeling/loss/bayesian_matching.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,15 +23,15 @@ from torch.nn import Module
 class BayesianMatchingLoss(Module):
     """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation"""
 
-    def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module:
+    def __init__(self, kl_scaling_factor: float = 0.001, ignore_index: int = -1) -> Module:
         """
         Args:
-            lamb (float): Weighting factor for the KL Divergence component
+            kl_scaling_factor (float): Weighting factor for the KL Divergence component
             ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
         """
         super(BayesianMatchingLoss, self).__init__()
 
-        self.lamb = lamb
+        self.lamb = kl_scaling_factor
         self.ignore_index = ignore_index
     
     def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor:
@@ -46,7 +46,7 @@ class BayesianMatchingLoss(Module):
         """
         # Assert input sizes
         assert inputs.dim() == 2                 # Observations, predictive distribution
-        assert labels.dim() == 1                # Label for each observation
+        assert labels.dim() == 1                 # Label for each observation
         assert labels.size(0) == inputs.size(0)  # Equal number of observation
 
         # Confirm predictive distribution dimension
@@ -88,13 +88,13 @@ class BayesianMatchingLoss(Module):
 class BinaryBayesianMatchingLoss(BayesianMatchingLoss):
     """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation"""
 
-    def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module:
+    def __init__(self, kl_scaling_factor: float = 0.001, ignore_index: int = -1) -> Module:
         """
         Args:
-            lamb (float): Weighting factor for the KL Divergence component
+            kl_scaling_factor (float): Weighting factor for the KL Divergence component
             ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
         """
-        super(BinaryBayesianMatchingLoss, self).__init__(lamb, ignore_index)
+        super(BinaryBayesianMatchingLoss, self).__init__(kl_scaling_factor, ignore_index)
 
     def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor:
         """
diff --git a/convlab/dst/setsumbt/loss/endd_loss.py b/convlab/dst/setsumbt/modeling/loss/endd_loss.py
similarity index 93%
rename from convlab/dst/setsumbt/loss/endd_loss.py
rename to convlab/dst/setsumbt/modeling/loss/endd_loss.py
index 9bd794bf4569f54f5896e1e88ed1edeadc0fe1e2..f979cc66b668473b8b620cfae1885eef9717c049 100644
--- a/convlab/dst/setsumbt/loss/endd_loss.py
+++ b/convlab/dst/setsumbt/modeling/loss/endd_loss.py
@@ -1,3 +1,20 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Ensemble Distribution Distillation Loss Function (see https://arxiv.org/pdf/2105.06987.pdf for details)"""
+
 import torch
 from torch.nn import Module
 from torch.nn.functional import kl_div
diff --git a/convlab/dst/setsumbt/loss/kl_distillation.py b/convlab/dst/setsumbt/modeling/loss/kl_distillation.py
similarity index 88%
rename from convlab/dst/setsumbt/loss/kl_distillation.py
rename to convlab/dst/setsumbt/modeling/loss/kl_distillation.py
index 9aee234ab68054f2b4a83d6feb5e453384d89e94..6f3971aaad98482a7136ee2507322161cba3ffd0 100644
--- a/convlab/dst/setsumbt/loss/kl_distillation.py
+++ b/convlab/dst/setsumbt/modeling/loss/kl_distillation.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""KL Divergence Ensemble Distillation loss"""
+"""KL Divergence Ensemble Distillation loss (See https://arxiv.org/pdf/1503.02531.pdf for details)"""
 
 import torch
 from torch.nn import Module
@@ -23,7 +23,7 @@ from torch.nn.functional import kl_div
 class KLDistillationLoss(Module):
     """Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation"""
 
-    def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module:
+    def __init__(self, ensemble_smoothing: float = 1e-4, ignore_index: int = -1) -> Module:
         """
         Args:
-            lamb (float): Target smoothing parameter
+            ensemble_smoothing (float): Target smoothing parameter
@@ -31,7 +31,7 @@ class KLDistillationLoss(Module):
         """
         super(KLDistillationLoss, self).__init__()
 
-        self.lamb = lamb
+        self.lamb = ensemble_smoothing
         self.ignore_index = ignore_index
     
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
@@ -71,13 +71,13 @@ class KLDistillationLoss(Module):
 class BinaryKLDistillationLoss(KLDistillationLoss):
     """Binary Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation"""
 
-    def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module:
+    def __init__(self, ensemble_smoothing: float = 1e-4, ignore_index: int = -1) -> Module:
         """
         Args:
-            lamb (float): Target smoothing parameter
+            ensemble_smoothing (float): Target smoothing parameter
             ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
         """
-        super(BinaryKLDistillationLoss, self).__init__(lamb, ignore_index)
+        super(BinaryKLDistillationLoss, self).__init__(ensemble_smoothing, ignore_index)
 
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
         """
@@ -101,4 +101,4 @@ class BinaryKLDistillationLoss(KLDistillationLoss):
         targets = targets.unsqueeze(-1)
         targets = torch.cat((1 - targets, targets), -1)
 
-        return super().forward(input, targets, temp)
+        return super().forward(inputs, targets, temp)
diff --git a/convlab/dst/setsumbt/loss/labelsmoothing.py b/convlab/dst/setsumbt/modeling/loss/labelsmoothing.py
similarity index 98%
rename from convlab/dst/setsumbt/loss/labelsmoothing.py
rename to convlab/dst/setsumbt/modeling/loss/labelsmoothing.py
index 61d4b353303451eac7eb09592bdb2c5200328250..61a2eeeef3d47c9f9df5275e0316184b6048626e 100644
--- a/convlab/dst/setsumbt/loss/labelsmoothing.py
+++ b/convlab/dst/setsumbt/modeling/loss/labelsmoothing.py
@@ -17,7 +17,7 @@
 
 
 import torch
-from torch.nn import Softmax, Module, CrossEntropyLoss
+from torch.nn import Module
 from torch.nn.functional import kl_div
 
 
diff --git a/convlab/dst/setsumbt/modeling/ontology_encoder.py b/convlab/dst/setsumbt/modeling/ontology_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2c68d131b584e20a86c8b37ee43a10e7ebce0dc
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/ontology_encoder.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Ontology Encoder Model"""
+
+import random
+from copy import deepcopy
+
+import torch
+from transformers import RobertaModel, BertModel
+import numpy as np
+from tqdm import tqdm
+
+PARENT_CLASSES = {'bert': BertModel,
+                  'roberta': RobertaModel}
+
+
+def OntologyEncoder(parent_name: str):
+    """
+    Return the Ontology Encoder model based on the parent transformer model.
+
+    Args:
+        parent_name (str): Name of the parent transformer model
+
+    Returns:
+        OntologyEncoder (class): Ontology Encoder model
+    """
+    parent_class = PARENT_CLASSES.get(parent_name.lower())
+    if parent_class is None:
+        raise NameError(f"Unknown parent model type: {parent_name}")
+
+    class OntologyEncoder(parent_class):
+        """Ontology Encoder model based on parent transformer model"""
+        def __init__(self, config, args, tokenizer):
+            """
+            Initialize Ontology Encoder model.
+
+            Args:
+                config (transformers.configuration_utils.PretrainedConfig): Configuration of the transformer model
+                args (argparse.Namespace): Arguments
+                tokenizer (transformers.tokenization_utils_base.PreTrainedTokenizer): Tokenizer
+            """
+            super().__init__(config)
+
+            # Set random seeds
+            random.seed(args.seed)
+            np.random.seed(args.seed)
+            torch.manual_seed(args.seed)
+            if args.n_gpu > 0:
+                torch.cuda.manual_seed_all(args.seed)
+
+            self.args = args
+            self.config = config
+            self.tokenizer = tokenizer
+
+        def _encode_candidates(self, candidates: list) -> torch.Tensor:
+            """
+            Embed candidates
+
+            Args:
+                candidates (list): List of candidate descriptions
+
+            Returns:
+                feats (torch.Tensor): Embeddings of the candidate descriptions
+            """
+            # Tokenize candidate descriptions
+            feats = [self.tokenizer.encode_plus(val, add_special_tokens=True, max_length=self.args.max_candidate_len,
+                                                padding='max_length', truncation='longest_first')
+                     for val in candidates]
+
+            # Encode tokenized descriptions
+            with torch.no_grad():
+                feats = {key: torch.tensor([f[key] for f in feats]).to(self.device) for key in feats[0]}
+                embedded_feats = self(**feats)  # [num_candidates, max_candidate_len, hidden_dim]
+
+            # Reduce/pool descriptions embeddings if required
+            if self.args.set_similarity:
+                feats = embedded_feats.last_hidden_state.detach().cpu()  # [num_candidates, max_candidate_len, hidden_dim]
+            elif self.args.candidate_pooling == 'cls':
+                feats = embedded_feats.pooler_output.detach().cpu()  # [num_candidates, hidden_dim]
+            elif self.args.candidate_pooling == "mean":
+                feats = embedded_feats.last_hidden_state.detach().cpu()
+                feats = feats.sum(1)
+                feats = torch.nn.functional.layer_norm(feats, feats.size())
+                feats = feats.detach().cpu()  # [num_candidates, hidden_dim]
+
+            return feats
+
+        def get_slot_candidate_embeddings(self):
+            """
+            Get embeddings for slots and candidates
+
+
+            Returns:
+                slots (dict): domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
+            """
+            # Set model to eval mode
+            self.eval()
+
+            slots = dict()
+            for domain, subset in tqdm(self.tokenizer.ontology.items(), desc='Domains'):
+                for slot, slot_info in tqdm(subset.items(), desc='Slots'):
+                    # Get description or use "domain-slot"
+                    if self.args.use_descriptions:
+                        desc = slot_info['description']
+                    else:
+                        desc = f"{domain}-{slot}"
+
+                    # Encode domain-slot pair description
+                    slot_emb = self._encode_candidates([desc])[0]
+
+                    # Obtain the possible value set and strip the requestable indicator '?'
+                    values = deepcopy(slot_info['possible_values'])
+                    is_requestable = False
+                    if '?' in values:
+                        is_requestable = True
+                        values.remove('?')
+
+                    # Encode value candidates
+                    if values:
+                        feats = self._encode_candidates(values)
+                    else:
+                        feats = None
+
+                    # Store domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
+                    slots[f"{domain}-{slot}"] = (slot_emb, feats, is_requestable)
+
+            return slots
+
+    return OntologyEncoder
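+
+# Illustrative usage (a sketch; `args` and `tokenizer` are assumed to be built
+# elsewhere in the training pipeline, and 'roberta-base' is a hypothetical checkpoint):
+#
+#   RobertaOntologyEncoder = OntologyEncoder('roberta')
+#   encoder = RobertaOntologyEncoder.from_pretrained('roberta-base', args, tokenizer)
+#   slots = encoder.get_slot_candidate_embeddings()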
diff --git a/convlab/dst/setsumbt/modeling/roberta_nbt.py b/convlab/dst/setsumbt/modeling/roberta_nbt.py
deleted file mode 100644
index f72d17fafa50553434b6d4dcd20b8e53d143892f..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/modeling/roberta_nbt.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""RoBERTa SetSUMBT"""
-
-import torch
-from transformers import RobertaModel, RobertaPreTrainedModel
-
-from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead
-
-
-class RobertaSetSUMBT(RobertaPreTrainedModel):
-    """Roberta based SetSUMBT model"""
-
-    def __init__(self, config):
-        """
-        Args:
-            config (configuration): Model configuration class
-        """
-        super(RobertaSetSUMBT, self).__init__(config)
-        self.config = config
-
-        # Turn Encoder
-        self.roberta = RobertaModel(config)
-        if config.freeze_encoder:
-            for p in self.roberta.parameters():
-                p.requires_grad = False
-
-        self.setsumbt = SetSUMBTHead(config)
-        self.add_slot_candidates = self.setsumbt.add_slot_candidates
-        self.add_value_candidates = self.setsumbt.add_value_candidates
-    
-    def forward(self,
-                input_ids: torch.Tensor,
-                attention_mask: torch.Tensor,
-                token_type_ids: torch.Tensor = None,
-                hidden_state: torch.Tensor = None,
-                state_labels: torch.Tensor = None,
-                request_labels: torch.Tensor = None,
-                active_domain_labels: torch.Tensor = None,
-                general_act_labels: torch.Tensor = None,
-                get_turn_pooled_representation: bool = False,
-                calculate_state_mutual_info: bool = False):
-        """
-        Args:
-            input_ids: Input token ids
-            attention_mask: Input padding mask
-            token_type_ids: Token type indicator
-            hidden_state: Latent internal dialogue belief state
-            state_labels: Dialogue state labels
-            request_labels: User request action labels
-            active_domain_labels: Current active domain labels
-            general_act_labels: General user action labels
-            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
-            calculate_state_mutual_info: Return mutual information in the dialogue state
-
-        Returns:
-            out: Tuple containing loss, predictive distributions, model statistics and state mutual information
-        """
-        if token_type_ids is not None:
-            token_type_ids = None
-
-        # Encode Dialogues
-        batch_size, dialogue_size, turn_size = input_ids.size()
-        input_ids = input_ids.reshape(-1, turn_size)
-        attention_mask = attention_mask.reshape(-1, turn_size)
-
-        roberta_output = self.roberta(input_ids, attention_mask)
-
-        # Apply mask and reshape the dialogue turn token embeddings
-        attention_mask = attention_mask.float().unsqueeze(2)
-        attention_mask = attention_mask.repeat((1, 1, roberta_output.last_hidden_state.size(-1)))
-        turn_embeddings = roberta_output.last_hidden_state * attention_mask
-        turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
-        
-        if get_turn_pooled_representation:
-            return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask,
-                                 batch_size, dialogue_size, hidden_state, state_labels,
-                                 request_labels, active_domain_labels, general_act_labels,
-                                 calculate_state_mutual_info) + (roberta_output.pooler_output,)
-        return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size,
-                             dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
-                             general_act_labels, calculate_state_mutual_info)
diff --git a/convlab/dst/setsumbt/modeling/setsumbt.py b/convlab/dst/setsumbt/modeling/setsumbt.py
index 4b67e35c3e8d2e0eaa5abcff2376d89ff67ae3cc..19f7408e95a03cec7bf76665da815eae85e216e7 100644
--- a/convlab/dst/setsumbt/modeling/setsumbt.py
+++ b/convlab/dst/setsumbt/modeling/setsumbt.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,26 +16,22 @@
 """SetSUMBT Prediction Head"""
 
 import torch
-from torch.autograd import Variable
 from torch.nn import (Module, MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout,
-                      CosineSimilarity, CrossEntropyLoss, PairwiseDistance,
-                      Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss)
+                      CosineSimilarity, PairwiseDistance, Sequential, ReLU, Conv1d, GELU, Parameter)
 from torch.nn.init import (xavier_normal_, constant_)
+from transformers.utils import ModelOutput
 
-from convlab.dst.setsumbt.loss import (BayesianMatchingLoss, BinaryBayesianMatchingLoss,
-                                       KLDistillationLoss, BinaryKLDistillationLoss,
-                                       LabelSmoothingLoss, BinaryLabelSmoothingLoss,
-                                       RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss)
+from convlab.dst.setsumbt.modeling import loss
 
 
 class SlotUtteranceMatching(Module):
-    """Slot Utterance matching attention based information extractor"""
+    """Slot Utterance Matching module for information extraction from utterances"""
 
     def __init__(self, hidden_size: int = 768, attention_heads: int = 12):
         """
         Args:
-            hidden_size (int): Dimension of token embeddings
-            attention_heads (int): Number of attention heads to use in attention module
+            hidden_size: Hidden size of the transformer
+            attention_heads: Number of attention heads
         """
         super(SlotUtteranceMatching, self).__init__()
 
@@ -47,12 +43,12 @@ class SlotUtteranceMatching(Module):
                 slot_embeddings: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            turn_embeddings: Embeddings for each token in each turn [n_turns, turn_length, hidden_size]
-            attention_mask: Padding mask for each turn [n_turns, turn_length, hidden_size]
-            slot_embeddings: Embeddings for each token in the slot descriptions
+            turn_embeddings: Turn level embeddings for the dialogue
+            attention_mask: Mask for the attention related to turn embeddings
+            slot_embeddings: Slot level embeddings for the dialogue
 
         Returns:
-            hidden: Information extracted from turn related to slot descriptions
+            hidden: Turn level embeddings for the dialogue conditioned on the slot embeddings
         """
         turn_embeddings = turn_embeddings.transpose(0, 1)
 
@@ -69,7 +65,7 @@ class SlotUtteranceMatching(Module):
 
 
 class RecurrentNeuralBeliefTracker(Module):
-    """Recurrent latent neural belief tracking module"""
+    """Recurrent Neural Belief Tracker module for tracking the latent dialogue state"""
 
     def __init__(self,
                  nbt_type: str = 'gru',
@@ -80,15 +76,16 @@ class RecurrentNeuralBeliefTracker(Module):
                  dropout_rate: float = 0.3):
         """
         Args:
-            nbt_type: Type of recurrent neural network (gru/lstm)
-            rnn_zero_init: Use zero initialised state for the RNN
-            input_size: Embedding size of the inputs
+            nbt_type: Type of recurrent neural network to use (lstm/gru)
+            rnn_zero_init: Whether to initialise the hidden state of the RNN to zero
+            input_size: Input embedding size
             hidden_size: Hidden size of the RNN
-            hidden_layers: Number of RNN Layers
+            hidden_layers: Number of hidden layers of the RNN
             dropout_rate: Dropout rate
         """
         super(RecurrentNeuralBeliefTracker, self).__init__()
 
+        # Initialise Initial Belief State Layer
         if rnn_zero_init:
             self.belief_init = Sequential(Linear(input_size, hidden_size), ReLU(), Dropout(dropout_rate))
         else:
@@ -126,12 +123,12 @@ class RecurrentNeuralBeliefTracker(Module):
     def forward(self, inputs: torch.Tensor, hidden_state: torch.Tensor = None) -> torch.Tensor:
         """
         Args:
-            inputs: Latent turn level information
-            hidden_state: Latent internal belief state
+            inputs: Input embeddings
+            hidden_state: Hidden state of the RNN
 
         Returns:
-            belief_embedding: Belief state embeddings
-            context: Latent internal belief state
+            belief_embedding: Latent belief state embeddings
+            context: Hidden state of the RNN
         """
         self.nbt.flatten_parameters()
         if hidden_state is None:
@@ -155,13 +152,13 @@ class RecurrentNeuralBeliefTracker(Module):
 
 
 class SetPooler(Module):
-    """Token set pooler"""
+    """Set Pooler module for pooling the set of token embeddings"""
 
     def __init__(self, pooling_strategy: str = 'cnn', hidden_size: int = 768):
         """
         Args:
-            pooling_strategy: Type of set pooler (cnn/dan/mean)
-            hidden_size: Token embedding size
+            pooling_strategy: Pooling strategy to use (mean/cnn/dan)
+            hidden_size: Hidden size of the set of token embeddings
         """
         super(SetPooler, self).__init__()
 
@@ -172,14 +169,14 @@ class SetPooler(Module):
         elif pooling_strategy == 'dan':
-            self.pooler = Sequential(Linear(hidden_size, hidden_size), GELU(), Linear(2 * hidden_size, hidden_size))
+            self.pooler = Sequential(Linear(hidden_size, 2 * hidden_size), GELU(), Linear(2 * hidden_size, hidden_size))
 
-    def forward(self, inputs, attention_mask):
+    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor):
         """
         Args:
-            inputs: Token set embeddings
-            attention_mask: Padding mask for the set of tokens
+            inputs: Set of token embeddings
+            attention_mask: Attention mask for the set of token embeddings
 
         Returns:
-
+            hidden: Pooled embeddings
         """
         if self.pooling_strategy == "mean":
             hidden = inputs.sum(1) / attention_mask.sum(1)
@@ -192,13 +189,25 @@ class SetPooler(Module):
         return hidden
 
 
+class SetSUMBTOutput(ModelOutput):
+    """SetSUMBT Output class"""
+    loss = None
+    belief_state = None
+    request_probabilities = None
+    active_domain_probabilities = None
+    general_act_probabilities = None
+    hidden_state = None
+    belief_state_summary = None
+    belief_state_mutual_information = None
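+    # Note: all fields default to None; a forward pass populates only the outputs
+    # relevant to its configuration (e.g. the action heads when predict_actions is set).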
+
+
 class SetSUMBTHead(Module):
     """SetSUMBT Prediction Head for Language Models"""
 
     def __init__(self, config):
         """
         Args:
-            config (configuration): Model configuration class
+            config: Model configuration
         """
         super(SetSUMBTHead, self).__init__()
         self.config = config
@@ -214,11 +223,24 @@ class SetSUMBTHead(Module):
             self.set_pooler = SetPooler(config.set_pooling, config.hidden_size)
 
         # Model ontology placeholders
-        self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False)
-        self.slot_ids = dict()
-        self.requestable_slot_ids = dict()
-        self.informable_slot_ids = dict()
-        self.domain_ids = dict()
+        if not hasattr(self.config, 'num_slots'):
+            self.config.num_slots = 1
+        self.slot_embeddings = Parameter(torch.zeros(self.config.num_slots, self.config.max_candidate_len,
+                                                     self.config.hidden_size), requires_grad=False)
+        if not hasattr(self.config, 'slot_ids'):
+            self.config.slot_ids = dict()
+            self.config.requestable_slot_ids = dict()
+            self.config.informable_slot_ids = dict()
+            self.config.domain_ids = dict()
+        if not hasattr(self.config, 'num_values'):
+            self.config.num_values = dict()
+        for slot in self.config.slot_ids:
+            if slot not in self.config.num_values:
+                self.config.num_values[slot] = 1
+            setattr(self, slot + '_value_embeddings', Parameter(torch.zeros(self.config.num_values[slot],
+                                                                            self.config.max_candidate_len,
+                                                                            self.config.hidden_size),
+                                                                requires_grad=False))
 
         # Matching network similarity measure
         if config.distance_measure == 'cosine':
@@ -229,19 +251,12 @@ class SetSUMBTHead(Module):
             raise NameError('NotImplemented')
 
         # User goal prediction loss function
-        if config.loss_function == 'crossentropy':
-            self.loss = CrossEntropyLoss(ignore_index=-1)
-        elif config.loss_function == 'bayesianmatching':
-            self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-        elif config.loss_function == 'labelsmoothing':
-            self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
-        elif config.loss_function == 'distillation':
-            self.loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
-            self.temp = 1.0
-        elif config.loss_function == 'distribution_distillation':
-            self.loss = RKLDirichletMediatorLoss(ignore_index=-1)
-        else:
-            raise NameError('NotImplemented')
+        loss_args = {'ignore_index': -1,
+                     'kl_scaling_factor': config.to_dict().get('kl_scaling_factor', 0.0),
+                     'label_smoothing': config.to_dict().get('label_smoothing', 0.0),
+                     'ensemble_smoothing': config.to_dict().get('ensemble_smoothing', 0.0)}
+        self.loss = loss.load(config.loss_function)(**loss_args)
+        self.temp = 1.0
 
         # Intent and domain prediction heads
         if config.predict_actions:
@@ -253,26 +268,10 @@ class SetSUMBTHead(Module):
             self.request_weight = float(self.config.user_request_loss_weight)
             self.general_act_weight = float(self.config.user_general_act_loss_weight)
             self.active_domain_weight = float(self.config.active_domain_loss_weight)
-            if config.loss_function == 'crossentropy':
-                self.request_loss = BCEWithLogitsLoss()
-                self.general_act_loss = CrossEntropyLoss(ignore_index=-1)
-                self.active_domain_loss = BCEWithLogitsLoss()
-            elif config.loss_function == 'labelsmoothing':
-                self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
-                self.general_act_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
-                self.active_domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
-            elif config.loss_function == 'bayesianmatching':
-                self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-                self.general_act_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-                self.active_domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-            elif config.loss_function == 'distillation':
-                self.request_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
-                self.general_act_loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
-                self.active_domain_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
-            elif config.loss_function == 'distribution_distillation':
-                self.request_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1)
-                self.general_act_loss = RKLDirichletMediatorLoss(ignore_index=-1)
-                self.active_domain_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1)
+
+            self.request_loss = loss.load(config.loss_function, binary=True)(**loss_args)
+            self.general_act_loss = loss.load(config.loss_function)(**loss_args)
+            self.active_domain_loss = loss.load(config.loss_function, binary=True)(**loss_args)
 
     def add_slot_candidates(self, slot_candidates: tuple):
         """
@@ -281,7 +280,7 @@ class SetSUMBTHead(Module):
         the request indicator is false the slot is not requestable.
 
         Args:
-            slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator
+            slot_candidates: Tuples of slot embedding, informable value embeddings and request indicator
         """
         if self.slot_embeddings.size(0) != 0:
             embeddings = self.slot_embeddings.detach()
@@ -289,28 +288,33 @@ class SetSUMBTHead(Module):
             embeddings = torch.zeros(0)
 
         for slot in slot_candidates:
-            if slot in self.slot_ids:
-                index = self.slot_ids[slot]
+            if slot in self.config.slot_ids:
+                index = self.config.slot_ids[slot]
                 embeddings[index, :] = slot_candidates[slot][0]
             else:
                 index = embeddings.size(0)
                 emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device)
                 embeddings = torch.cat((embeddings, emb), 0)
-                self.slot_ids[slot] = index
-                setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False))
+                self.config.slot_ids[slot] = index
+                self.config.num_values[slot] = 1
+                setattr(self, slot + '_value_embeddings', Parameter(torch.zeros(self.config.num_values[slot],
+                                                                                self.config.max_candidate_len,
+                                                                                self.config.hidden_size),
+                                                                    requires_grad=False))
             # Add slot to relevant requestable and informable slot lists
             if slot_candidates[slot][2]:
-                self.requestable_slot_ids[slot] = index
+                self.config.requestable_slot_ids[slot] = index
             if slot_candidates[slot][1] is not None:
-                self.informable_slot_ids[slot] = index
+                self.config.informable_slot_ids[slot] = index
 
             domain = slot.split('-', 1)[0]
-            if domain not in self.domain_ids:
-                self.domain_ids[domain] = []
-            self.domain_ids[domain].append(index)
-            self.domain_ids[domain] = list(set(self.domain_ids[domain]))
+            if domain not in self.config.domain_ids:
+                self.config.domain_ids[domain] = []
+            self.config.domain_ids[domain].append(index)
+            self.config.domain_ids[domain] = list(set(self.config.domain_ids[domain]))
 
-        self.slot_embeddings = Variable(embeddings, requires_grad=False)
+        self.config.num_slots = embeddings.size(0)
+        self.slot_embeddings = Parameter(embeddings, requires_grad=False)
 
     def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
         """
@@ -319,7 +323,7 @@ class SetSUMBTHead(Module):
         Args:
             slot: Slot name
             value_candidates: Value candidate embeddings
-            replace: If true existing value candidates are replaced
+            replace: Replace existing value candidates
         """
         embeddings = getattr(self, slot + '_value_embeddings')
 
@@ -328,7 +332,8 @@ class SetSUMBTHead(Module):
         else:
             embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0)
 
-        setattr(self, slot + '_value_embeddings', embeddings)
+        self.config.num_values[slot] = embeddings.size(0)
+        setattr(self, slot + '_value_embeddings', Parameter(embeddings, requires_grad=False))
 
     def forward(self,
                 turn_embeddings: torch.Tensor,
@@ -344,20 +349,20 @@ class SetSUMBTHead(Module):
                 calculate_state_mutual_info: bool = False):
         """
         Args:
-            turn_embeddings: Token embeddings in the current turn
-            turn_pooled_representation: Pooled representation of the current dialogue turn
-            attention_mask: Padding mask for the current dialogue turn
-            batch_size: Number of dialogues in the batch
-            dialogue_size: Number of turns in each dialogue
-            hidden_state: Latent internal dialogue belief state
-            state_labels: Dialogue state labels
-            request_labels: User request action labels
-            active_domain_labels: Current active domain labels
-            general_act_labels: General user action labels
-            calculate_state_mutual_info: Return mutual information in the dialogue state
+            turn_embeddings: Turn embeddings for dialogue turns
+            turn_pooled_representation: Turn pooled representation for dialogue turns
+            attention_mask: Attention mask for dialogue turns
+            batch_size: Batch size
+            dialogue_size: Number of turns in dialogue
+            hidden_state: RNN Hidden state / Latent Belief State for dialogue turns
+            state_labels: State labels for dialogue turns
+            request_labels: Request labels for dialogue turns
+            active_domain_labels: Active domain labels for dialogue turns
+            general_act_labels: General action labels for dialogue turns
+            calculate_state_mutual_info: Calculate state mutual information
 
         Returns:
-            out: Tuple containing loss, predictive distributions, model statistics and state mutual information
+            output: Model output containing the loss and the belief state, request, active domain and general action distributions
         """
         hidden_size = turn_embeddings.size(-1)
         # Initialise loss
@@ -432,7 +437,7 @@ class SetSUMBTHead(Module):
         if self.config.predict_actions:
             # User request prediction
             request_probs = dict()
-            for slot, slot_id in self.requestable_slot_ids.items():
+            for slot, slot_id in self.config.requestable_slot_ids.items():
                 request_logits = self.request_gate(belief_embedding[:, :, slot_id, :])
 
                 # Store output probabilities
@@ -441,7 +446,7 @@ class SetSUMBTHead(Module):
                 request_logits[batches, dialogues] = 0.0
                 request_probs[slot] = torch.sigmoid(request_logits)
 
-                if request_labels is not None:
+                if request_labels is not None and slot in request_labels:
                     # Compute request gate loss
                     request_logits = request_logits.reshape(-1)
                     if self.config.loss_function == 'distillation':
@@ -457,7 +462,7 @@ class SetSUMBTHead(Module):
 
             # Active domain prediction
             active_domain_probs = dict()
-            for domain, slot_ids in self.domain_ids.items():
+            for domain, slot_ids in self.config.domain_ids.items():
                 belief = belief_embedding[:, :, slot_ids, :]
                 if len(slot_ids) > 1:
                     # SqrtN reduction across all slots within a domain
@@ -490,7 +495,7 @@ class SetSUMBTHead(Module):
         belief_state_probs = dict()
         belief_state_mutual_info = dict()
         belief_state_stats = dict()
-        for slot, slot_id in self.informable_slot_ids.items():
+        for slot, slot_id in self.config.informable_slot_ids.items():
             # Get slot belief embedding and value candidates
             candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device)
             belief = belief_embedding[:, :, slot_id, :]
@@ -556,9 +561,17 @@ class SetSUMBTHead(Module):
                     loss += self.loss(logits, state_labels[slot].reshape(-1))
 
         # Return model outputs
-        out = belief_state_probs, request_probs, active_domain_probs, general_act_probs, hidden_state
+        output = SetSUMBTOutput(belief_state=belief_state_probs,
+                                request_probabilities=request_probs,
+                                active_domain_probabilities=active_domain_probs,
+                                general_act_probabilities=general_act_probs,
+                                hidden_state=hidden_state,
+                                loss=None,
+                                belief_state_summary=None,
+                                belief_state_mutual_information=None)
         if state_labels is not None or request_labels is not None:
-            out = (loss,) + out + (belief_state_stats,)
+            output.loss = loss
+            output.belief_state_summary = belief_state_stats
         if calculate_state_mutual_info:
-            out = out + (belief_state_mutual_info,)
-        return out
+            output.belief_state_mutual_information = belief_state_mutual_info
+        return output
diff --git a/convlab/dst/setsumbt/modeling/setsumbt_nbt.py b/convlab/dst/setsumbt/modeling/setsumbt_nbt.py
new file mode 100644
index 0000000000000000000000000000000000000000..a03983e75580fdadb05225b4e69f89ee5b813759
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/setsumbt_nbt.py
@@ -0,0 +1,339 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SetSUMBT Models"""
+
+import os
+from copy import deepcopy
+
+import torch
+from torch.nn import Module
+from transformers import (BertModel, BertPreTrainedModel, BertConfig,
+                          RobertaModel, RobertaPreTrainedModel, RobertaConfig)
+
+from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead, SetSUMBTOutput
+
+
+class BertSetSUMBT(BertPreTrainedModel):
+    """Bert based SetSUMBT model"""
+
+    def __init__(self, config):
+        """
+        Args:
+            config: Model configuration
+        """
+        super(BertSetSUMBT, self).__init__(config)
+        self.config = config
+
+        # Turn Encoder
+        self.bert = BertModel(config)
+        if config.freeze_encoder:
+            for p in self.bert.parameters():
+                p.requires_grad = False
+
+        self.setsumbt = SetSUMBTHead(config)
+        self.add_slot_candidates = self.setsumbt.add_slot_candidates
+        self.add_value_candidates = self.setsumbt.add_value_candidates
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                hidden_state: torch.Tensor = None,
+                state_labels: torch.Tensor = None,
+                request_labels: torch.Tensor = None,
+                active_domain_labels: torch.Tensor = None,
+                general_act_labels: torch.Tensor = None,
+                get_turn_pooled_representation: bool = False,
+                calculate_state_mutual_info: bool = False):
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator
+            hidden_state: Latent internal dialogue belief state
+            state_labels: Dialogue state labels
+            request_labels: User request action labels
+            active_domain_labels: Current active domain labels
+            general_act_labels: General user action labels
+            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
+            calculate_state_mutual_info: Return mutual information in the dialogue state
+
+        Returns:
+            output: Model output containing the loss, predictive distributions, model statistics and state mutual information
+        """
+
+        # Encode Dialogues
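+        # input_ids is expected with shape (batch_size, dialogue_size, turn_size);
+        # turns are flattened so the encoder processes one utterance (pair) per row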
+        batch_size, dialogue_size, turn_size = input_ids.size()
+        input_ids = input_ids.reshape(-1, turn_size)
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.reshape(-1, turn_size)
+        attention_mask = attention_mask.reshape(-1, turn_size)
+
+        bert_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+        attention_mask = attention_mask.float().unsqueeze(2)
+        attention_mask = attention_mask.repeat((1, 1, bert_output.last_hidden_state.size(-1)))
+        turn_embeddings = bert_output.last_hidden_state * attention_mask
+        turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
+
+        output = self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask,
+                               batch_size, dialogue_size, hidden_state, state_labels,
+                               request_labels, active_domain_labels, general_act_labels,
+                               calculate_state_mutual_info)
+        output.turn_pooled_representation = bert_output.pooler_output if get_turn_pooled_representation else None
+        return output
+
+
+class RobertaSetSUMBT(RobertaPreTrainedModel):
+    """Roberta based SetSUMBT model"""
+
+    def __init__(self, config):
+        """
+        Args:
+            config: Model configuration
+        """
+        super(RobertaSetSUMBT, self).__init__(config)
+        self.config = config
+
+        # Turn Encoder
+        self.roberta = RobertaModel(config)
+        if config.freeze_encoder:
+            for p in self.roberta.parameters():
+                p.requires_grad = False
+
+        self.setsumbt = SetSUMBTHead(config)
+        self.add_slot_candidates = self.setsumbt.add_slot_candidates
+        self.add_value_candidates = self.setsumbt.add_value_candidates
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                hidden_state: torch.Tensor = None,
+                state_labels: torch.Tensor = None,
+                request_labels: torch.Tensor = None,
+                active_domain_labels: torch.Tensor = None,
+                general_act_labels: torch.Tensor = None,
+                get_turn_pooled_representation: bool = False,
+                calculate_state_mutual_info: bool = False):
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator
+            hidden_state: Latent internal dialogue belief state
+            state_labels: Dialogue state labels
+            request_labels: User request action labels
+            active_domain_labels: Current active domain labels
+            general_act_labels: General user action labels
+            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
+            calculate_state_mutual_info: Return mutual information in the dialogue state
+
+        Returns:
+            output: Model output containing the loss, predictive distributions, model statistics and state mutual information
+        """
+        # RoBERTa does not use token type ids
+        token_type_ids = None
+
+        # Encode Dialogues
+        batch_size, dialogue_size, turn_size = input_ids.size()
+        input_ids = input_ids.reshape(-1, turn_size)
+        attention_mask = attention_mask.reshape(-1, turn_size)
+
+        roberta_output = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
+
+        # Apply mask and reshape the dialogue turn token embeddings
+        attention_mask = attention_mask.float().unsqueeze(2)
+        attention_mask = attention_mask.repeat((1, 1, roberta_output.last_hidden_state.size(-1)))
+        turn_embeddings = roberta_output.last_hidden_state * attention_mask
+        turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
+
+        output = self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask,
+                               batch_size, dialogue_size, hidden_state, state_labels,
+                               request_labels, active_domain_labels, general_act_labels,
+                               calculate_state_mutual_info)
+        output.turn_pooled_representation = roberta_output.pooler_output if get_turn_pooled_representation else None
+        return output
+
+
+MODELS = {'bert': BertSetSUMBT, 'roberta': RobertaSetSUMBT}
+
+
+class EnsembleSetSUMBT(Module):
+    """Ensemble SetSUMBT Model for joint ensemble prediction"""
+
+    def __init__(self, config):
+        """
+        Args:
+            config: Model configuration
+        """
+        super(EnsembleSetSUMBT, self).__init__()
+        self.config = config
+
+        # Initialise ensemble members
+        model_cls = MODELS[self.config.model_type]
+        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
+            setattr(self, attr, model_cls(self.get_clean_config(config)))
+
+    @staticmethod
+    def get_clean_config(config):
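+        """Return a copy of the config with its ontology fields reset, so each ensemble member starts with an empty ontology."""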
+        config = deepcopy(config)
+        config.slot_ids = dict()
+        config.requestable_slot_ids = dict()
+        config.informable_slot_ids = dict()
+        config.domain_ids = dict()
+        config.num_values = dict()
+
+        return config
+
+    def _load(self, path: str):
+        """
+        Load parameters
+        Args:
+            path: Location of model parameters
+        """
+        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
+            idx = int(attr.split('_', 1)[-1])
+            state_dict = torch.load(os.path.join(self._get_checkpoint_path(path, idx), 'pytorch_model.bin'))
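+            # Stored value embedding entries are dropped here; they are assumed to be
+            # restored later via add_slot_candidates/add_value_candidates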
+            state_dict = {key: itm for key, itm in state_dict.items() if '_value_embeddings' not in key}
+            getattr(self, attr).load_state_dict(state_dict)
+
+    def add_slot_candidates(self, slot_candidates: tuple):
+        """
+        Add slots to the model ontology, the tuples should contain the slot embedding, informable value embeddings
+        and a request indicator, if the informable value embeddings is None the slot is not informable and if
+        the request indicator is false the slot is not requestable.
+
+        Args:
+            slot_candidates: Tuples of slot embedding, informable value embeddings and request indicator
+        """
+        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
+            getattr(self, attr).add_slot_candidates(slot_candidates)
+        self.setsumbt = self.model_0.setsumbt
+
+    def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
+        """
+        Add value candidates for a slot
+
+        Args:
+            slot: Slot name
+            value_candidates: Value candidate embeddings
+            replace: Replace existing value candidates
+        """
+        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
+            getattr(self, attr).add_value_candidates(slot, value_candidates, replace)
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                reduction: str = 'mean',
+                **kwargs) -> SetSUMBTOutput:
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator
+            reduction: Reduction of ensemble member predictive distributions (mean, none)
+
+        Returns:
+            output: Model output containing the ensemble loss and the (reduced) predictive distributions
+        """
+        belief_state_probs = {slot: [] for slot in self.setsumbt.config.informable_slot_ids}
+        request_probs = {slot: [] for slot in self.setsumbt.config.requestable_slot_ids}
+        active_domain_probs = {dom: [] for dom in self.setsumbt.config.domain_ids}
+        general_act_probs = []
+        loss = 0.0 if 'state_labels' in kwargs else None
+        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
+            # Prediction from each ensemble member
+            with torch.no_grad():
+                _out = getattr(self, attr)(input_ids=input_ids,
+                                           token_type_ids=token_type_ids,
+                                           attention_mask=attention_mask,
+                                           **kwargs)
+            if loss is not None:
+                loss += _out.loss
+            for slot in belief_state_probs:
+                belief_state_probs[slot].append(_out.belief_state[slot].unsqueeze(-2).detach().cpu())
+            if self.config.predict_actions:
+                for slot in request_probs:
+                    request_probs[slot].append(_out.request_probabilities[slot].unsqueeze(-1).detach().cpu())
+                for dom in active_domain_probs:
+                    active_domain_probs[dom].append(_out.active_domain_probabilities[dom].unsqueeze(-1).detach().cpu())
+                general_act_probs.append(_out.general_act_probabilities.unsqueeze(-2).detach().cpu())
+
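+        # Member predictions are concatenated along an ensemble dimension:
+        # dim -2 for categorical distributions, dim -1 for binary probabilities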
+        belief_state_probs = {slot: torch.cat(l, -2) for slot, l in belief_state_probs.items()}
+        if self.config.predict_actions:
+            request_probs = {slot: torch.cat(l, -1) for slot, l in request_probs.items()}
+            active_domain_probs = {dom: torch.cat(l, -1) for dom, l in active_domain_probs.items()}
+            general_act_probs = torch.cat(general_act_probs, -2)
+        else:
+            request_probs = {}
+            active_domain_probs = {}
+            general_act_probs = torch.tensor(0.0)
+
+        # Apply reduction of ensemble to single posterior
+        if reduction == 'mean':
+            belief_state_probs = {slot: l.mean(-2) for slot, l in belief_state_probs.items()}
+            request_probs = {slot: l.mean(-1) for slot, l in request_probs.items()}
+            active_domain_probs = {dom: l.mean(-1) for dom, l in active_domain_probs.items()}
+            general_act_probs = general_act_probs.mean(-2)
+        elif reduction != 'none':
+            raise NameError(f"Reduction '{reduction}' is not implemented.")
+
+        if loss is not None:
+            loss /= self.config.ensemble_size
+
+        output = SetSUMBTOutput(loss=loss,
+                                belief_state=belief_state_probs,
+                                request_probabilities=request_probs,
+                                active_domain_probabilities=active_domain_probs,
+                                general_act_probabilities=general_act_probs)
+
+        return output
+
+    @staticmethod
+    def _get_checkpoint_path(path: str, idx: int):
+        """
+        Get checkpoint path for ensemble member
+        Args:
+            path: Location of ensemble
+            idx: Ensemble member index
+
+        Returns:
+            Checkpoint path
+        """
+
+        checkpoints = os.listdir(os.path.join(path, f'ens-{idx}'))
+        checkpoints = [int(p.split('-', 1)[-1]) for p in checkpoints if 'checkpoint-' in p]
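+        # e.g. if ens-0 contains checkpoint-500 and checkpoint-1000, the most
+        # recent checkpoint-1000 is selected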
+        checkpoint = f"checkpoint-{max(checkpoints)}"
+        return os.path.join(path, f'ens-{idx}', checkpoint)
+
+    @classmethod
+    def from_pretrained(cls, path, config=None):
+        config_path = os.path.join(cls._get_checkpoint_path(path, 0), 'config.json')
+        if not os.path.exists(config_path):
+            raise FileNotFoundError('Could not find config.json in model path.')
+
+        if config is None:
+            try:
+                config = RobertaConfig.from_pretrained(config_path)
+            except Exception:
+                config = BertConfig.from_pretrained(config_path)
+
+        config.ensemble_size = len([d for d in os.listdir(path) if 'ens-' in d])
+
+        model = cls(config)
+        model._load(path)
+
+        return model
diff --git a/convlab/dst/setsumbt/modeling/temperature_scheduler.py b/convlab/dst/setsumbt/modeling/temperature_scheduler.py
index 654e83c5d1ad9dc908213cca8967a84893395b04..549ae76a13af09b681eb0d1fb893bdff361fb014 100644
--- a/convlab/dst/setsumbt/modeling/temperature_scheduler.py
+++ b/convlab/dst/setsumbt/modeling/temperature_scheduler.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +16,6 @@
 """Linear Temperature Scheduler Class"""
 
 
-# Temp scheduler class for ensemble distillation
 class LinearTemperatureScheduler:
     """
     Temperature scheduler object used for distribution temperature scheduling in distillation
diff --git a/convlab/dst/setsumbt/modeling/tokenization.py b/convlab/dst/setsumbt/modeling/tokenization.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbee3baa5188023a2c07dca96ae02ddfe5c7d298
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/tokenization.py
@@ -0,0 +1,401 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SetSUMBT Tokenizer"""
+
+import json
+import os
+
+import torch
+from transformers import RobertaTokenizer, BertTokenizer
+from tqdm import tqdm
+
+from convlab.dst.setsumbt.datasets.utils import IdTensor
+
+PARENT_CLASSES = {'bert': BertTokenizer,
+                  'roberta': RobertaTokenizer}
+
+
+def SetSUMBTTokenizer(parent_name):
+    """SetSUMBT Tokenizer Class Factory"""
+    parent_class = PARENT_CLASSES.get(parent_name.lower())
+    if parent_class is None:
+        raise NameError(f"Unknown tokenizer parent class: {parent_name}")
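+    # Usage sketch (identifiers illustrative): the factory returns a tokenizer
+    # class for the requested encoder, e.g.
+    #     Tokenizer = SetSUMBTTokenizer('roberta')
+    #     tokenizer = Tokenizer.from_pretrained('roberta-base')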
+
+    class SetSUMBTTokenizer(parent_class):
+        """SetSUMBT Tokenizer Class"""
+
+        def __init__(
+                self,
+                vocab_file,
+                merges_file,
+                errors="replace",
+                bos_token="<s>",
+                eos_token="</s>",
+                sep_token="</s>",
+                cls_token="<s>",
+                unk_token="<unk>",
+                pad_token="<pad>",
+                mask_token="<mask>",
+                add_prefix_space=False,
+                **kwargs,
+        ):
+            """
+            Initialize the tokenizer.
+
+            Args:
+                vocab_file (str): Path to the vocabulary file.
+                merges_file (str): Path to the merges file.
+                errors (str): Error handling for the tokenizer.
+                bos_token (str): Beginning of sentence token.
+                eos_token (str): End of sentence token.
+                sep_token (str): Separator token.
+                cls_token (str): Classification token.
+                unk_token (str): Unknown token.
+                pad_token (str): Padding token.
+                mask_token (str): Masking token.
+                add_prefix_space (bool): Whether to add a space before the first token.
+                **kwargs: Additional arguments for the tokenizer.
+            """
+
+            # Load ontology and tokenizer vocab
+            with open(vocab_file, 'r', encoding="utf-8") as vocab_handle:
+                self.encoder = json.load(vocab_handle)
+            self.ontology = self.encoder['SETSUMBT_ONTOLOGY'] if 'SETSUMBT_ONTOLOGY' in self.encoder else dict()
+            self.encoder = {k: v for k, v in self.encoder.items() if 'SETSUMBT_ONTOLOGY' not in k}
+            vocab_dir = os.path.dirname(vocab_file)
+            vocab_file = os.path.basename(vocab_file).split('.')
+            vocab_file = vocab_file[0] + "_base." + vocab_file[-1]
+            vocab_file = os.path.join(vocab_dir, vocab_file)
+            with open(vocab_file, 'w', encoding="utf-8") as vocab_handle:
+                json.dump(self.encoder, vocab_handle)
+
+            super().__init__(vocab_file, merges_file, errors, bos_token, eos_token, sep_token, cls_token, unk_token,
+                             pad_token, mask_token, add_prefix_space, **kwargs)
+
+        def set_setsumbt_ontology(self, ontology):
+            """
+            Set the ontology for the tokenizer.
+
+            Args:
+                ontology (dict): The dialogue system ontology to use.
+            """
+            self.ontology = ontology
+
+        def save_vocabulary(self, save_directory: str, filename_prefix: str = None) -> tuple:
+            """
+            Save the tokenizer vocabulary and merges files to a directory.
+
+            Args:
+                save_directory (str): Directory to which to save.
+                filename_prefix (str): Optional prefix to add to the files.
+
+            Returns:
+                vocab_file (str): Path to the saved vocabulary file.
+                merge_file (str): Path to the saved merges file.
+            """
+            self.encoder['SETSUMBT_ONTOLOGY'] = self.ontology
+            vocab_file, merge_file = super().save_vocabulary(save_directory, filename_prefix)
+            self.encoder = {k: v for k, v in self.encoder.items() if 'SETSUMBT_ONTOLOGY' not in k}
+
+            return vocab_file, merge_file
+
+        def decode_state(self, belief_state, request_probs=None, active_domain_probs=None, general_act_probs=None):
+            """
+            Decode a belief state, request, active domain and general action distributions into a dialogue state.
+
+            Args:
+                belief_state (dict): The belief state distributions.
+                request_probs (dict): The request distributions.
+                active_domain_probs (dict): The active domain distributions.
+                general_act_probs (dict): The general action distributions.
+
+            Returns:
+                dialogue_state (dict): The decoded dialogue state.
+            """
+            dialogue_state = {domain: {slot: '' for slot, slot_info in domain_info.items()
+                                       if slot_info['possible_values'] != ["?"] and slot_info['possible_values']}
+                              for domain, domain_info in self.ontology.items()}
+
+            for slot, probs in belief_state.items():
+                dom, slot = slot.split('-', 1)
+                val = self.ontology.get(dom, dict()).get(slot, dict()).get('possible_values', [])
+                val = val[probs.argmax().item()] if val else 'none'
+                if val != 'none' and slot in dialogue_state.get(dom, dict()):
+                    dialogue_state[dom][slot] = val
+
+            request_acts = list()
+            if request_probs is not None:
+                request_acts = [slot for slot, p in request_probs.items() if p.item() > 0.5]
+                request_acts = [slot.split('-', 1) for slot in request_acts]
+                request_acts = [[dom, slt] for dom, slt in request_acts
+                                if '?' in self.ontology.get(dom, dict()).get(slt, dict()).get('possible_values', [])]
+                request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts]
+
+            # Construct active domain set
+            active_domains = dict()
+            if active_domain_probs is not None:
+                active_domains = {dom: active_domain_probs.get(dom, torch.tensor(0.0)).item() > 0.5
+                                  for dom in self.ontology}
+
+            # Construct general domain action
+            general_acts = list()
+            if general_act_probs is not None:
+                general_acts = general_act_probs.argmax(-1).item()
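+                # index 0 = no general act, 1 = goodbye, 2 = thank you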
+                general_acts = [[], ['bye'], ['thank']][general_acts]
+                general_acts = [[act, 'general', 'none', 'none'] for act in general_acts]
+
+            user_acts = request_acts + general_acts
+            dialogue_state = {'belief_state': dialogue_state,
+                              'user_action': user_acts,
+                              'active_domains': active_domains}
+
+            return dialogue_state
+
+        def decode_state_batch(self,
+                               belief_state,
+                               request_probs=None,
+                               active_domain_probs=None,
+                               general_act_probs=None,
+                               dialogue_ids=None):
+            """
+            Decode a batch of belief state, request, active domain and general action distributions.
+
+            Args:
+                belief_state (dict): The belief state distributions.
+                request_probs (dict): The request distributions.
+                active_domain_probs (dict): The active domain distributions.
+                general_act_probs (dict): The general action distributions.
+                dialogue_ids (list): The dialogue IDs.
+
+            Returns:
+                data (dict): The decoded dialogue states.
+            """
+
+            data = dict()
+            slot_0 = next(iter(belief_state))
+
+            if dialogue_ids is None:
+                dialogue_ids = [["{:06d}".format(i) for i in range(belief_state[slot_0].size(0))]]
+
+            for dial_idx in range(belief_state[slot_0].size(0)):
+                dialogue = list()
+                for turn_idx in range(belief_state[slot_0].size(1)):
+                    if belief_state[slot_0][dial_idx, turn_idx].sum() != 0.0:
+                        belief = {slot: p[dial_idx, turn_idx] for slot, p in belief_state.items()}
+                        req = {slot: p[dial_idx, turn_idx]
+                               for slot, p in request_probs.items()} if request_probs is not None else None
+                        dom = {dom: p[dial_idx, turn_idx]
+                               for dom, p in active_domain_probs.items()} if active_domain_probs is not None else None
+                        gen = general_act_probs[dial_idx, turn_idx] if general_act_probs is not None else None
+
+                        state = self.decode_state(belief, req, dom, gen)
+                        dialogue.append(state)
+                data[dialogue_ids[0][dial_idx]] = dialogue
+
+            return data
+
+        def encode(self, dialogues: list, max_turns: int = 12, max_seq_len: int = 64) -> dict:
+            """
+            Convert dialogue examples to model input features and labels
+
+            Args:
+                dialogues (list): List of all extracted dialogues
+                max_turns (int): Maximum numbers of turns in a dialogue
+                max_seq_len (int): Maximum number of tokens in a dialogue turn
+
+            Returns:
+                features (dict): All inputs and labels required to train the model
+            """
+            features = dict()
+
+            # Get encoder input for system, user utterance pairs
+            input_feats = []
+            if len(dialogues) > 5:
+                iterator = tqdm(dialogues)
+            else:
+                iterator = dialogues
+            for dial in iterator:
+                dial_feats = []
+                for turn in dial:
+                    if len(turn['system_utterance']) == 0:
+                        usr = turn['user_utterance']
+                        dial_feats.append(super().encode_plus(usr, add_special_tokens=True, max_length=max_seq_len,
+                                                              padding='max_length', truncation='longest_first'))
+                    else:
+                        usr = turn['user_utterance']
+                        sys = turn['system_utterance']
+                        dial_feats.append(super().encode_plus(usr, sys, add_special_tokens=True,
+                                                              max_length=max_seq_len, padding='max_length',
+                                                              truncation='longest_first'))
+                    # Truncate
+                    if len(dial_feats) >= max_turns:
+                        break
+                input_feats.append(dial_feats)
+            del dial_feats
+
+            # Perform turn level padding
+            if 'dialogue_id' in dialogues[0][0]:
+                dial_ids = list()
+                for dial in dialogues:
+                    _ids = [turn['dialogue_id'] for turn in dial][:max_turns]
+                    _ids += [''] * (max_turns - len(_ids))
+                    dial_ids.append(_ids)
+            input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
+                         for dial in input_feats]
+            if 'token_type_ids' in input_feats[0][0]:
+                token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
+                                  for dial in input_feats]
+            else:
+                token_type_ids = None
+            if 'attention_mask' in input_feats[0][0]:
+                attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
+                                  for dial in input_feats]
+            else:
+                attention_mask = None
+            del input_feats
+
+            # Create torch data tensors
+            if 'dialogue_id' in dialogues[0][0]:
+                features['dialogue_ids'] = IdTensor(dial_ids)
+            features['input_ids'] = torch.tensor(input_ids)
+            features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
+            features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
+            del input_ids, token_type_ids, attention_mask
+
+            # Extract all informable and requestable slots from the ontology
+            informable_slots = [f"{domain}-{slot}" for domain in self.ontology for slot in self.ontology[domain]
+                                if self.ontology[domain][slot]['possible_values']
+                                and self.ontology[domain][slot]['possible_values'] != ['?']]
+            requestable_slots = [f"{domain}-{slot}" for domain in self.ontology for slot in self.ontology[domain]
+                                 if '?' in self.ontology[domain][slot]['possible_values']]
+
+            # Extract a list of domains from the ontology slots
+            domains = [domain for domain in self.ontology]
+
+            # Create slot labels
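+            # Each label is the index of the turn's value in the slot's possible_values
+            # list, or -1 for out-of-ontology values (ignored by the loss, ignore_index=-1)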
+            if 'state' in dialogues[0][0]:
+                for domslot in tqdm(informable_slots):
+                    labels = []
+                    for dial in dialogues:
+                        labs = []
+                        for turn in dial:
+                            value = [v for d, substate in turn['state'].items() for s, v in substate.items()
+                                     if f'{d}-{s}' == domslot]
+                            domain, slot = domslot.split('-', 1)
+                            if turn['dataset_name'] in self.ontology[domain][slot]['dataset_names']:
+                                value = value[0] if value else 'none'
+                            else:
+                                value = -1
+                            if value in self.ontology[domain][slot]['possible_values'] and value != -1:
+                                value = self.ontology[domain][slot]['possible_values'].index(value)
+                            else:
+                                value = -1  # If value is not in ontology then we do not penalise the model
+                            labs.append(value)
+                            if len(labs) >= max_turns:
+                                break
+                        labs = labs + [-1] * (max_turns - len(labs))
+                        labels.append(labs)
+
+                    labels = torch.tensor(labels)
+                    features['state_labels-' + domslot] = labels
+
+            # Create requestable slot labels
+            if 'dialogue_acts' in dialogues[0][0]:
+                for domslot in tqdm(requestable_slots):
+                    labels = []
+                    for dial in dialogues:
+                        labs = []
+                        for turn in dial:
+                            domain, slot = domslot.split('-', 1)
+                            if turn['dataset_name'] in self.ontology[domain][slot]['dataset_names']:
+                                acts = [act['intent'] for act in turn['dialogue_acts']
+                                        if act['domain'] == domain and act['slot'] == slot]
+                                labs.append(1 if acts and acts[0] == 'request' else 0)
+                            else:
+                                labs.append(-1)
+                            if len(labs) >= max_turns:
+                                break
+                        labs = labs + [-1] * (max_turns - len(labs))
+                        labels.append(labs)
+
+                    labels = torch.tensor(labels)
+                    features['request_labels-' + domslot] = labels
+
+                # General act labels (1-goodbye, 2-thank you)
+                labels = []
+                for dial in tqdm(dialogues):
+                    labs = []
+                    for turn in dial:
+                        acts = [act['intent'] for act in turn['dialogue_acts'] if act['intent'] in ['bye', 'thank']]
+                        if acts:
+                            if 'bye' in acts:
+                                labs.append(1)
+                            else:
+                                labs.append(2)
+                        else:
+                            labs.append(0)
+                        if len(labs) >= max_turns:
+                            break
+                    labs = labs + [-1] * (max_turns - len(labs))
+                    labels.append(labs)
+
+                labels = torch.tensor(labels)
+                features['general_act_labels'] = labels
+
+            # Create active domain labels
+            if 'active_domains' in dialogues[0][0]:
+                for domain in tqdm(domains):
+                    labels = []
+                    for dial in dialogues:
+                        labs = []
+                        for turn in dial:
+                            possible_domains = list()
+                            for dom in self.ontology:
+                                for slt in self.ontology[dom]:
+                                    if turn['dataset_name'] in self.ontology[dom][slt]['dataset_names']:
+                                        possible_domains.append(dom)
+
+                            if domain in turn['active_domains']:
+                                labs.append(1)
+                            elif domain in possible_domains:
+                                labs.append(0)
+                            else:
+                                labs.append(-1)
+                            if len(labs) >= max_turns:
+                                break
+                        labs = labs + [-1] * (max_turns - len(labs))
+                        labels.append(labs)
+
+                    labels = torch.tensor(labels)
+                    features['active_domain_labels-' + domain] = labels
+
+
+            return features
+
+    return SetSUMBTTokenizer
diff --git a/convlab/dst/setsumbt/modeling/trainer.py b/convlab/dst/setsumbt/modeling/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..50f0ac9d2e6acbe9e8ffa6634f47d395cece9c0d
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/trainer.py
@@ -0,0 +1,681 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SetSUMBT Trainer Class"""
+
+import random
+import os
+from copy import deepcopy
+
+import torch
+from torch.nn import DataParallel
+import numpy as np
+from transformers import get_linear_schedule_with_warmup
+from torch.optim import AdamW
+from tqdm import tqdm, trange
+try:
+    from apex import amp
+except ModuleNotFoundError:
+    print('Apex not used')
+
+from convlab.dst.setsumbt.utils import clear_checkpoints, EnsembleAggregator
+from convlab.dst.setsumbt.datasets import JointGoalAccuracy, BeliefStateUncertainty, ActPredictionAccuracy, Metrics
+from convlab.dst.setsumbt.modeling import LinearTemperatureScheduler
+
+
+class SetSUMBTTrainer:
+    """Trainer class for SetSUMBT Model"""
+
+    def __init__(self,
+                 args,
+                 model,
+                 tokenizer,
+                 train_dataloader,
+                 validation_dataloader,
+                 logger,
+                 tb_writer,
+                 device='cpu'):
+        """
+        Initialise the trainer class.
+
+        Args:
+            args (argparse.Namespace): Arguments passed to the script
+            model (torch.nn.Module): SetSUMBT to be trained/evaluated
+            tokenizer (transformers.PreTrainedTokenizer): Tokenizer used to encode the data
+            train_dataloader (torch.utils.data.DataLoader): Dataloader for training data
+            validation_dataloader (torch.utils.data.DataLoader): Dataloader for validation data
+            logger (logging.Logger): Logger to log training progress
+            tb_writer (tensorboardX.SummaryWriter): Tensorboard writer to log training progress
+            device (str): Device to use for training
+        """
+        self.args = args
+        self.model = model
+        self.tokenizer = tokenizer
+        self.train_dataloader = train_dataloader
+        self.validation_dataloader = validation_dataloader
+        self.logger = logger
+        self.tb_writer = tb_writer
+        self.device = device
+
+        # Initialise metrics
+        if self.validation_dataloader is not None:
+            self.joint_goal_accuracy = JointGoalAccuracy(self.args.dataset, validation_dataloader.dataset.set_type)
+            self.belief_state_uncertainty_metrics = BeliefStateUncertainty()
+            self.ensemble_aggregator = EnsembleAggregator()
+            if self.args.predict_actions:
+                self.request_accuracy = ActPredictionAccuracy('request', binary=True)
+                self.active_domain_accuracy = ActPredictionAccuracy('active_domain', binary=True)
+                self.general_act_accuracy = ActPredictionAccuracy('general_act', binary=False)
+
+        self._set_seed()
+
+        if train_dataloader is not None:
+            self.training_mode(load_slots=True)
+            self._configure_optimiser()
+            self._configure_schedulers()
+
+            # Set up fp16 and multi gpu usage
+            if self.args.fp16:
+                self.model, self.optimizer = amp.initialize(self.model, self.optimizer,
+                                                            opt_level=self.args.fp16_opt_level)
+            if self.args.n_gpu > 1:
+                self.model = DataParallel(self.model)
+
+        # Initialise training parameters
+        self.best_model = Metrics(joint_goal_accuracy=0.0,
+                                  training_loss=np.inf)
+        self.global_step = 0
+        self.epochs_trained = 0
+        self.steps_trained_in_current_epoch = 0
+
+        logger.info(f"Device: {device}, Number of GPUs: {args.n_gpu}, FP16 training: {args.fp16}")
+
+    def _configure_optimiser(self):
+        """Configure the optimiser for training."""
+        assert self.train_dataloader is not None
+        # Group weight decay and no decay parameters in the model
+        no_decay = ["bias", "LayerNorm.weight"]
+        optimizer_grouped_parameters = [
+            {
+                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)
+                           and 'value_embeddings' not in n],
+                "weight_decay": self.args.weight_decay,
+                "lr": self.args.learning_rate
+            },
+            {
+                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)
+                           and 'value_embeddings' not in n],
+                "weight_decay": 0.0,
+                "lr": self.args.learning_rate
+            },
+        ]
+
+        # Initialise the optimizer
+        self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate)
+
+    def _configure_schedulers(self):
+        """Configure the learning rate and temperature schedulers for training."""
+        assert self.train_dataloader is not None
+        # Calculate the total number of training steps to be performed
+        if self.args.max_training_steps > 0:
+            # A fixed step budget is given: derive the number of epochs from it
+            steps_per_epoch = (len(self.train_dataloader) // self.args.gradient_accumulation_steps) + 1
+            self.args.num_train_epochs = self.args.max_training_steps // steps_per_epoch
+        else:
+            # Otherwise derive the step budget from the number of epochs
+            self.args.max_training_steps = len(self.train_dataloader) // self.args.gradient_accumulation_steps
+            self.args.max_training_steps *= self.args.num_train_epochs
+
+        if self.args.save_steps <= 0:
+            self.args.save_steps = len(self.train_dataloader) // self.args.gradient_accumulation_steps
+
+        # Initialise linear lr scheduler
+        self.args.num_warmup_steps = int(self.args.max_training_steps * self.args.warmup_proportion)
+        self.lr_scheduler = get_linear_schedule_with_warmup(self.optimizer,
+                                                            num_warmup_steps=self.args.num_warmup_steps,
+                                                            num_training_steps=self.args.max_training_steps)
+
+        # Initialise distillation temp scheduler
+        if self.model.config.loss_function in ['distillation']:
+            self.temp_scheduler = LinearTemperatureScheduler(total_steps=self.args.max_training_steps,
+                                                             base_temp=self.args.annealing_base_temp,
+                                                             cycle_len=self.args.annealing_cycle_len)
+        else:
+            self.temp_scheduler = None
+
+    def _set_seed(self):
+        """Set the seed for reproducibility."""
+        random.seed(self.args.seed)
+        np.random.seed(self.args.seed)
+        torch.manual_seed(self.args.seed)
+        if self.args.n_gpu > 0:
+            torch.cuda.manual_seed_all(self.args.seed)
+        self.logger.info('Seed set to %d.' % self.args.seed)
+
+    @staticmethod
+    def _set_ontology_embeddings(model, slots, load_slots=True):
+        """
+        Set the ontology embeddings for the model.
+
+        Args:
+            model (torch.nn.Module): Model to set the ontology embeddings for.
+            slots (dict): Dictionary of slot names and their corresponding information.
+            load_slots (bool): Whether to load/reload the slot embeddings.
+        """
+        # Get slot and value embeddings
+        values = {slot: slots[slot][1] for slot in slots}
+
+        # Load model ontology
+        if load_slots:
+            model.add_slot_candidates(dict(slots))
+        try:
+            informable_slot_ids = model.setsumbt.config.informable_slot_ids
+        except AttributeError:
+            informable_slot_ids = model.config.informable_slot_ids
+        for slot in informable_slot_ids:
+            model.add_value_candidates(slot, values[slot], replace=True)
+
+    def set_ontology_embeddings(self, slots, load_slots=True):
+        """
+        Set the ontology embeddings for the model.
+
+        Args:
+            slots (dict): Dictionary of slot names and their corresponding information.
+            load_slots (bool): Whether to load/reload the slot embeddings.
+        """
+        self._set_ontology_embeddings(self.model, slots, load_slots=load_slots)
+
+    def load_state(self):
+        """Load the model, optimiser and schedulers state from a previous run."""
+        if os.path.isfile(os.path.join(self.args.model_name_or_path, 'optimizer.pt')):
+            self.logger.info("Optimizer loaded from previous run.")
+            self.optimizer.load_state_dict(torch.load(os.path.join(self.args.model_name_or_path, 'optimizer.pt')))
+            self.lr_scheduler.load_state_dict(torch.load(os.path.join(self.args.model_name_or_path, 'lr_scheduler.pt')))
+            if self.temp_scheduler is not None:
+                self.temp_scheduler.load_state_dict(torch.load(os.path.join(self.args.model_name_or_path,
+                                                                            'temp_scheduler.pt')))
+            if self.args.fp16 and os.path.isfile(os.path.join(self.args.model_name_or_path, 'amp.pt')):
+                self.logger.info("FP16 Apex Amp loaded from previous run.")
+                amp.load_state_dict(torch.load(os.path.join(self.args.model_name_or_path, 'amp.pt')))
+
+        # Evaluate initialised model
+        if self.args.do_eval:
+            self.eval_mode()
+            metrics = self.evaluate(is_train=True)
+            self.training_mode()
+
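+            # Use the initial model's metrics as the baseline best model; an infinite
+            # training loss ensures any subsequent checkpoint can replace this baseline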
+            self.best_model = metrics
+            self.best_model.training_loss = np.inf
+
+    def save_state(self):
+        """Save the model, optimiser and scheduler states for future runs."""
+        output_dir = os.path.join(self.args.output_dir, f"checkpoint-{self.global_step}")
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir, exist_ok=True)
+
+        self.tokenizer.save_pretrained(output_dir)
+        if self.args.n_gpu > 1:
+            self.model.module.save_pretrained(output_dir)
+        else:
+            self.model.save_pretrained(output_dir)
+
+        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "lr_scheduler.pt"))
+        if self.temp_scheduler is not None:
+            torch.save(self.temp_scheduler.state_dict(), os.path.join(output_dir, 'temp_scheduler.pt'))
+        if self.args.fp16:
+            torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt"))
+
+        # Remove older training checkpoints
+        clear_checkpoints(self.args.output_dir, self.args.keep_models)
+
+    def training_mode(self, load_slots=False):
+        """
+        Set the model and trainer to training mode.
+
+        Args:
+            load_slots (bool): Whether to load/reload the slot embeddings.
+        """
+        assert self.train_dataloader is not None
+        self.model.train()
+        self.tokenizer.set_setsumbt_ontology(self.train_dataloader.dataset.ontology)
+        self.model.zero_grad()
+        self.set_ontology_embeddings(self.train_dataloader.dataset.ontology_embeddings, load_slots=load_slots)
+
+    def eval_mode(self, load_slots=False):
+        """
+        Set the model and trainer to evaluation mode.
+
+        Args:
+            load_slots (bool): Whether to load/reload the slot embeddings.
+        """
+        self.model.eval()
+        self.model.zero_grad()
+        self.tokenizer.set_setsumbt_ontology(self.validation_dataloader.dataset.ontology)
+        self.set_ontology_embeddings(self.validation_dataloader.dataset.ontology_embeddings, load_slots=load_slots)
+
+    def log_info(self, metrics, logging_stage='update'):
+        """
+        Log information about the training/evaluation.
+
+        Args:
+            metrics (Metrics): Metrics object containing the relevant information.
+            logging_stage (str): The stage of the training/evaluation to log.
+        """
+        if logging_stage == "update":
+            info = f"{self.global_step} steps complete, "
+            info += f"Loss since last update: {metrics.training_loss}."
+            self.logger.info(info)
+            self.logger.info("Validation set statistics:")
+        elif logging_stage == 'training_complete':
+            self.logger.info("Training Complete.")
+            self.logger.info("Validation set statistics:")
+        elif logging_stage == 'dev':
+            self.logger.info("Validation set statistics:")
+            self.logger.info(f"\tLoss: {metrics.validation_loss}")
+        elif logging_stage == 'test':
+            self.logger.info("Test set statistics:")
+            self.logger.info(f"\tLoss: {metrics.validation_loss}")
+        self.logger.info(f"\tJoint Goal Accuracy: {metrics.joint_goal_accuracy}")
+        self.logger.info(f"\tGoal Slot F1 Score: {metrics.slot_f1}")
+        self.logger.info(f"\tGoal Slot Precision: {metrics.slot_precision}")
+        self.logger.info(f"\tGoal Slot Recall: {metrics.slot_recall}")
+        self.logger.info(f"\tJoint Goal ECE: {metrics.joint_goal_ece}")
+        self.logger.info(f"\tJoint Goal L2-Error: {metrics.joint_l2_error}")
+        self.logger.info(f"\tJoint Goal L2-Error Ratio: {metrics.joint_l2_error_ratio}")
+        if 'request_f1' in metrics:
+            self.logger.info(f"\tRequest Action F1 Score: {metrics.request_f1}")
+            self.logger.info(f"\tActive Domain F1 Score: {metrics.active_domain_f1}")
+            self.logger.info(f"\tGeneral Action F1 Score: {metrics.general_act_f1}")
+
+        # Log to tensorboard
+        if logging_stage == "update":
+            self.tb_writer.add_scalar('JointGoalAccuracy/Dev', metrics.joint_goal_accuracy, self.global_step)
+            self.tb_writer.add_scalar('SlotAccuracy/Dev', metrics.slot_accuracy, self.global_step)
+            self.tb_writer.add_scalar('SlotF1/Dev', metrics.slot_f1, self.global_step)
+            self.tb_writer.add_scalar('SlotPrecision/Dev', metrics.slot_precision, self.global_step)
+            self.tb_writer.add_scalar('JointGoalECE/Dev', metrics.joint_goal_ece, self.global_step)
+            self.tb_writer.add_scalar('JointGoalL2ErrorRatio/Dev', metrics.joint_l2_error_ratio, self.global_step)
+            if 'request_f1' in metrics:
+                self.tb_writer.add_scalar('RequestF1Score/Dev', metrics.request_f1, self.global_step)
+                self.tb_writer.add_scalar('ActiveDomainF1Score/Dev', metrics.active_domain_f1, self.global_step)
+                self.tb_writer.add_scalar('GeneralActionF1Score/Dev', metrics.general_act_f1, self.global_step)
+            self.tb_writer.add_scalar('Loss/Dev', metrics.validation_loss, self.global_step)
+
+            if 'belief_state_summary' in metrics:
+                for slot, stats_slot in metrics.belief_state_summary.items():
+                    for key, item in stats_slot.items():
+                        self.tb_writer.add_scalar(f'{key}_{slot}/Dev', item, self.global_step)
+
+    def get_input_dict(self, batch: dict) -> dict:
+        """
+        Get the input dictionary for the model.
+
+        Args:
+            batch (dict): The batch of data to be passed to the model.
+
+        Returns:
+            input_dict (dict): The input dictionary for the model.
+        """
+        input_dict = dict()
+
+        # Add the input ids, token type ids, and attention mask
+        input_dict['input_ids'] = batch['input_ids'].to(self.device)
+        input_dict['token_type_ids'] = batch['token_type_ids'].to(self.device) if 'token_type_ids' in batch else None
+        input_dict['attention_mask'] = batch['attention_mask'].to(self.device) if 'attention_mask' in batch else None
+
+        # Add the labels
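+        # Batches with 'belief_state' keys carry soft label distributions
+        # (e.g. ensemble-generated targets for distillation); otherwise hard labels are used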
+        if any('belief_state' in key for key in batch):
+            input_dict['state_labels'] = {slot: batch['belief_state-' + slot].to(self.device)
+                                          for slot in self.model.setsumbt.config.informable_slot_ids
+                                          if ('belief_state-' + slot) in batch}
+            if self.args.predict_actions:
+                input_dict['request_labels'] = {slot: batch['request_probabilities-' + slot].to(self.device)
+                                                for slot in self.model.setsumbt.config.requestable_slot_ids
+                                                if ('request_probabilities-' + slot) in batch}
+                input_dict['active_domain_labels'] = {domain: batch['active_domain_probabilities-' + domain].to(self.device)
+                                                      for domain in self.model.setsumbt.config.domain_ids
+                                                      if ('active_domain_probabilities-' + domain) in batch}
+                input_dict['general_act_labels'] = batch['general_act_probabilities'].to(self.device)
+        else:
+            input_dict['state_labels'] = {slot: batch['state_labels-' + slot].to(self.device)
+                                          for slot in self.model.setsumbt.config.informable_slot_ids
+                                          if ('state_labels-' + slot) in batch}
+            if self.args.predict_actions:
+                input_dict['request_labels'] = {slot: batch['request_labels-' + slot].to(self.device)
+                                                for slot in self.model.setsumbt.config.requestable_slot_ids
+                                                if ('request_labels-' + slot) in batch}
+                input_dict['active_domain_labels'] = {domain: batch['active_domain_labels-' + domain].to(self.device)
+                                                      for domain in self.model.setsumbt.config.domain_ids
+                                                      if ('active_domain_labels-' + domain) in batch}
+                input_dict['general_act_labels'] = batch['general_act_labels'].to(self.device)
+
+        return input_dict
+
+    def train(self):
+        """Train the SetSUMBT model."""
+        # Set the model to training mode
+        self.training_mode(load_slots=True)
+        self.load_state()
+
+        # Log training set up
+        self.logger.info("***** Running training *****")
+        self.logger.info(f"\tNum Batches = {len(self.train_dataloader)}")
+        self.logger.info(f"\tNum Epochs = {self.args.num_train_epochs}")
+        self.logger.info(f"\tGradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+        self.logger.info(f"\tTotal optimization steps = {self.args.max_training_steps}")
+
+        # Check if continuing training from a checkpoint
+        if os.path.exists(self.args.model_name_or_path):
+            try:
+                # Set global_step to the global_step of the last saved checkpoint from the model path
+                checkpoint_suffix = self.args.model_name_or_path.split("-")[-1].split("/")[0]
+                self.global_step = int(checkpoint_suffix)
+                steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps
+                self.steps_trained_in_current_epoch = self.global_step % steps_per_epoch
+                self.epochs_trained = self.global_step // steps_per_epoch
+
+                self.logger.info("\tContinuing training from checkpoint, will skip to saved global_step")
+                self.logger.info(f"\tContinuing training from epoch {self.epochs_trained}")
+                self.logger.info(f"\tContinuing training from global step {self.global_step}")
+                self.logger.info(f"\tWill skip the first {self.steps_trained_in_current_epoch} steps in the first epoch")
+            except ValueError:
+                self.logger.info("\tStarting fine-tuning.")
+
+        # Prepare iterator for training
+        tr_loss, logging_loss = 0.0, 0.0
+        train_iterator = trange(self.epochs_trained, int(self.args.num_train_epochs), desc="Epoch")
+
+        steps_since_last_update = 0
+        # Perform training
+        for e in train_iterator:
+            epoch_iterator = tqdm(self.train_dataloader, desc="Iteration")
+            # Iterate over all batches
+            for step, batch in enumerate(epoch_iterator):
+                # Skip batches already trained on
+                if step < self.steps_trained_in_current_epoch:
+                    continue
+
+                # Extract all label dictionaries from the batch
+                input_dict = self.get_input_dict(batch)
+
+                # Set up temperature scaling for the model
+                if self.temp_scheduler is not None:
+                    self.model.setsumbt.temp = self.temp_scheduler.temp()
+
+                # Forward pass to obtain loss
+                output = self.model(**input_dict)
+
+                if self.args.n_gpu > 1:
+                    output.loss = output.loss.mean()
+
+                # Update step
+                if step % self.args.gradient_accumulation_steps == 0:
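+                    # Normalise the accumulated loss by the number of accumulation steps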
+                    output.loss = output.loss / self.args.gradient_accumulation_steps
+                    if self.temp_scheduler is not None:
+                        self.tb_writer.add_scalar('Temp', self.temp_scheduler.temp(), self.global_step)
+                    self.tb_writer.add_scalar('Loss/train', output.loss, self.global_step)
+                    # Backpropagate accumulated loss
+                    if self.args.fp16:
+                        with amp.scale_loss(output.loss, self.optimizer) as scaled_loss:
+                            scaled_loss.backward()
+                            torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
+                            self.tb_writer.add_scalar('Scaled_Loss/train', scaled_loss, self.global_step)
+                    else:
+                        output.loss.backward()
+                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
+
+                    # Get learning rate
+                    self.tb_writer.add_scalar('LearningRate', self.optimizer.param_groups[0]['lr'], self.global_step)
+
+                    if output.belief_state_summary:
+                        for slot, stats_slot in output.belief_state_summary.items():
+                            for key, item in stats_slot.items():
+                                self.tb_writer.add_scalar(f'{key}_{slot}/Train', item, self.global_step)
+
+                    # Update model parameters
+                    self.optimizer.step()
+                    self.lr_scheduler.step()
+                    self.model.zero_grad()
+                    if self.temp_scheduler is not None:
+                        self.temp_scheduler.step()
+
+                    tr_loss += output.loss.float().item()
+                    epoch_iterator.set_postfix(loss=output.loss.float().item())
+                    self.global_step += 1
+
+                # Save model checkpoint
+                if self.global_step % self.args.save_steps == 0:
+                    logging_loss = tr_loss - logging_loss
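+                    # logging_loss now holds the loss accumulated since the previous logging update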
+
+                    # Evaluate model
+                    if self.args.do_eval:
+                        self.eval_mode()
+                        metrics = self.evaluate(is_train=True)
+                        metrics.training_loss = logging_loss / self.args.save_steps
+                        # Log model eval information
+                        self.log_info(metrics)
+                        self.training_mode()
+                    else:
+                        metrics = Metrics(training_loss=logging_loss / self.args.save_steps,
+                                          joint_goal_accuracy=0.0)
+                        self.log_info(metrics)
+
+                    logging_loss = tr_loss
+
+                    try:
+                        # Compute the score of the best model
+                        self.best_model.compute_score(request=self.model.config.user_request_loss_weight,
+                                                      active_domain=self.model.config.active_domain_loss_weight,
+                                                      general_act=self.model.config.user_general_act_loss_weight)
+
+                        # Compute the score of the current model
+                        metrics.compute_score(request=self.model.config.user_request_loss_weight,
+                                              active_domain=self.model.config.active_domain_loss_weight,
+                                              general_act=self.model.config.user_general_act_loss_weight)
+                    except AttributeError:
+                        self.best_model.compute_score()
+                        metrics.compute_score()
+
+                    metrics.training_loss = tr_loss / self.global_step
+
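+                    # Metrics objects are compared via the scores computed above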
+                    if metrics > self.best_model:
+                        steps_since_last_update = 0
+                        self.logger.info('Model saved.')
+                        self.best_model = deepcopy(metrics)
+
+                        self.save_state()
+                    else:
+                        steps_since_last_update += 1
+                        self.logger.info('Model not saved.')
+
+                # Stop training after max training steps or if the model has not updated for too long
+                if self.args.max_training_steps > 0 and self.global_step > self.args.max_training_steps:
+                    epoch_iterator.close()
+                    break
+                if self.args.patience > 0 and steps_since_last_update >= self.args.patience:
+                    epoch_iterator.close()
+                    break
+
+            self.steps_trained_in_current_epoch = 0
+            self.logger.info(f'Epoch {e + 1} complete, average training loss = {tr_loss / self.global_step}')
+
+            if self.args.max_training_steps > 0 and self.global_step > self.args.max_training_steps:
+                train_iterator.close()
+                break
+            if self.args.patience > 0 and steps_since_last_update >= self.args.patience:
+                train_iterator.close()
+                self.logger.info(f'Model has not improved for at least {self.args.patience} steps. Training stopped!')
+                break
+
+        # Evaluate final model
+        if self.args.do_eval:
+            self.eval_mode()
+            metrics = self.evaluate(is_train=True)
+            metrics.training_loss = tr_loss / self.global_step
+            self.log_info(metrics, logging_stage='training_complete')
+        else:
+            self.logger.info('Training complete!')
+
+        # Store final model
+        try:
+            self.best_model.compute_score(request=self.model.config.user_request_loss_weight,
+                                          active_domain=self.model.config.active_domain_loss_weight,
+                                          general_act=self.model.config.user_general_act_loss_weight)
+
+            metrics.compute_score(request=self.model.config.user_request_loss_weight,
+                                  active_domain=self.model.config.active_domain_loss_weight,
+                                  general_act=self.model.config.user_general_act_loss_weight)
+        except AttributeError:
+            self.best_model.compute_score()
+            metrics.compute_score()
+
+        metrics.training_loss = tr_loss / self.global_step
+
+        if metrics > self.best_model:
+            self.logger.info('Final model saved.')
+            self.save_state()
+        else:
+            self.logger.info('Final model not saved, as it is not the best performing model.')
+
+    def evaluate(self, save_eval_path=None, is_train=False, save_pred_dist_path=None, draw_calibration_diagram=False):
+        """
+        Evaluate the model on the validation set.
+
+        Args:
+            save_eval_path (str): Path to save the evaluation results.
+            is_train (bool): Whether the evaluation is performed during training.
+            save_pred_dist_path (str): Path to save the predicted distribution.
+            draw_calibration_diagram (bool): Whether to draw the calibration diagram.
+        Returns:
+            Metrics: The evaluation metrics.
+        """
+        save_eval_path = None if is_train else save_eval_path
+        save_pred_dist_path = None if is_train else save_pred_dist_path
+        draw_calibration_diagram = False if is_train else draw_calibration_diagram
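+        # Intermediate evaluations during training skip prediction saving and calibration plots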
+        if not is_train:
+            self.logger.info("***** Running evaluation *****")
+            self.logger.info("  Num Batches = %d", len(self.validation_dataloader))
+
+        eval_loss = 0.0
+        belief_state_summary = dict()
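+        # Reset the metric accumulators before starting a new evaluation session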
+        self.joint_goal_accuracy._init_session()
+        self.belief_state_uncertainty_metrics._init_session()
+        self.eval_mode(load_slots=True)
+
+        if not is_train:
+            epoch_iterator = tqdm(self.validation_dataloader, desc="Iteration")
+        else:
+            epoch_iterator = self.validation_dataloader
+        for batch in epoch_iterator:
+            with torch.no_grad():
+                input_dict = self.get_input_dict(batch)
+                if not is_train and 'distillation' in self.model.config.loss_function:
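+                    # For distillation-trained models, standalone evaluation passes only the
+                    # encoder inputs, since the stored soft labels are not needed for prediction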
+                    input_dict = {key: input_dict[key] for key in ['input_ids', 'attention_mask', 'token_type_ids']}
+                if self.args.ensemble and save_pred_dist_path is not None:
+                    input_dict['reduction'] = 'none'
+                output = self.model(**input_dict)
+                output.loss = output.loss if output.loss is not None else 0.0
+
+            eval_loss += output.loss
+
+            if self.args.ensemble and save_pred_dist_path is not None:
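+                # Keep the unreduced per-member predictions for saving, then average them
+                # so the downstream metrics score the ensemble mean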
+                self.ensemble_aggregator.add_batch(input_dict, output, batch['dialogue_ids'])
+                output.belief_state = {slot: probs.mean(-2) for slot, probs in output.belief_state.items()}
+                if self.args.predict_actions:
+                    output.request_probabilities = {slot: probs.mean(-1)
+                                                    for slot, probs in output.request_probabilities.items()}
+                    output.active_domain_probabilities = {domain: probs.mean(-1)
+                                                          for domain, probs in output.active_domain_probabilities.items()}
+                    output.general_act_probabilities = output.general_act_probabilities.mean(-2)
+
+            # Accumulate belief state summary across batches
+            if output.belief_state_summary is not None:
+                for slot, slot_summary in output.belief_state_summary.items():
+                    if slot not in belief_state_summary:
+                        belief_state_summary[slot] = dict()
+                    for key, item in slot_summary.items():
+                        if key not in belief_state_summary[slot]:
+                            belief_state_summary[slot][key] = item
+                        else:
+                            if 'min' in key:
+                                belief_state_summary[slot][key] = min(belief_state_summary[slot][key], item)
+                            elif 'max' in key:
+                                belief_state_summary[slot][key] = max(belief_state_summary[slot][key], item)
+                            elif 'mean' in key:
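+                                # Pairwise running average (an approximation of the mean across batches)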
+                                belief_state_summary[slot][key] = (belief_state_summary[slot][key] + item) / 2
+
+            slot_0 = next(iter(input_dict['state_labels']), None) if 'state_labels' in input_dict else None
+            if slot_0 is not None:
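+                # Padded dialogue turns are marked with a -1 in the first input id position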
+                pad_dials, pad_turns = torch.where(input_dict['input_ids'][:, :, 0] == -1)
+                if len(input_dict['state_labels'][slot_0].size()) == 4:
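+                    # 4-dimensional state labels are soft distributions: average across the sampled
+                    # distributions and take the argmax to recover hard labels for scoring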
+                    for slot in input_dict['state_labels']:
+                        input_dict['state_labels'][slot] = input_dict['state_labels'][slot].mean(-2).argmax(-1)
+                        input_dict['state_labels'][slot][pad_dials, pad_turns] = -1
+                    if self.args.predict_actions:
+                        for slot in input_dict['request_labels']:
+                            input_dict['request_labels'][slot] = input_dict['request_labels'][slot].mean(-1).round().int()
+                            input_dict['request_labels'][slot][pad_dials, pad_turns] = -1
+                        for domain in input_dict['active_domain_labels']:
+                            input_dict['active_domain_labels'][domain] = input_dict['active_domain_labels'][domain].mean(-1).round().int()
+                            input_dict['active_domain_labels'][domain][pad_dials, pad_turns] = -1
+                        input_dict['general_act_labels'] = input_dict['general_act_labels'].mean(-2).argmax(-1)
+                        input_dict['general_act_labels'][pad_dials, pad_turns] = -1
+            else:
+                input_dict = self.get_input_dict(batch)
+
+            # Add batch to metrics
+            self.belief_state_uncertainty_metrics.add_dialogues(output.belief_state, input_dict['state_labels'])
+
+            predicted_states = self.tokenizer.decode_state_batch(output.belief_state,
+                                                                 output.request_probabilities,
+                                                                 output.active_domain_probabilities,
+                                                                 output.general_act_probabilities,
+                                                                 batch['dialogue_ids'])
+
+            self.joint_goal_accuracy.add_dialogues(predicted_states)
+
+            if self.args.predict_actions:
+                self.request_accuracy.add_dialogues(output.request_probabilities, input_dict['request_labels'])
+                self.active_domain_accuracy.add_dialogues(output.active_domain_probabilities,
+                                                          input_dict['active_domain_labels'])
+                self.general_act_accuracy.add_dialogues({'gen': output.general_act_probabilities},
+                                                        {'gen': input_dict['general_act_labels']})
+
+        # Compute metrics
+        metrics = self.joint_goal_accuracy.evaluate()
+        metrics += self.belief_state_uncertainty_metrics.evaluate()
+        if self.args.predict_actions:
+            metrics += self.request_accuracy.evaluate()
+            metrics += self.active_domain_accuracy.evaluate()
+            metrics += self.general_act_accuracy.evaluate()
+        metrics.validation_loss = eval_loss
+        if belief_state_summary:
+            metrics.belief_state_summary = belief_state_summary
+
+        # Save model predictions
+        if save_eval_path is not None:
+            self.joint_goal_accuracy.save_dialogues(save_eval_path)
+        if save_pred_dist_path is not None:
+            self.ensemble_aggregator.save(save_pred_dist_path)
+        if draw_calibration_diagram:
+            self.belief_state_uncertainty_metrics.draw_calibration_diagram(
+                save_path=self.args.output_dir,
+                validation_split=self.joint_goal_accuracy.validation_split
+            )
+
+        return metrics
diff --git a/convlab/dst/setsumbt/modeling/training.py b/convlab/dst/setsumbt/modeling/training.py
deleted file mode 100644
index 590b2ac7372b26262625d08691a8528ffddd82d2..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/modeling/training.py
+++ /dev/null
@@ -1,715 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Training and evaluation utils"""
-
-import random
-import os
-import logging
-from copy import deepcopy
-
-import torch
-from torch.nn import DataParallel
-from torch.distributions import Categorical
-import numpy as np
-from transformers import get_linear_schedule_with_warmup
-from torch.optim import AdamW
-from tqdm import tqdm, trange
-try:
-    from apex import amp
-except:
-    print('Apex not used')
-
-from convlab.dst.setsumbt.utils import clear_checkpoints
-from convlab.dst.setsumbt.modeling import LinearTemperatureScheduler
-
-
-# Load logger and tensorboard summary writer
-def set_logger(logger_, tb_writer_):
-    global logger, tb_writer
-    logger = logger_
-    tb_writer = tb_writer_
-
-
-# Set seeds
-def set_seed(args):
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-    logger.info('Seed set to %d.' % args.seed)
-
-
-def set_ontology_embeddings(model, slots, load_slots=True):
-    # Get slot and value embeddings
-    values = {slot: slots[slot][1] for slot in slots}
-
-    # Load model ontology
-    if load_slots:
-        slots = {slot: embs for slot, embs in slots.items()}
-        model.add_slot_candidates(slots)
-    try:
-        informable_slot_ids = model.setsumbt.informable_slot_ids
-    except:
-        informable_slot_ids = model.informable_slot_ids
-    for slot in informable_slot_ids:
-        model.add_value_candidates(slot, values[slot], replace=True)
-
-
-def log_info(global_step, loss, jg_acc=None, sl_acc=None, req_f1=None, dom_f1=None, gen_f1=None, stats=None):
-    """
-    Log training statistics.
-
-    Args:
-        global_step: Number of global training steps completed
-        loss: Training loss
-        jg_acc: Joint goal accuracy
-        sl_acc: Slot accuracy
-        req_f1: Request prediction F1 score
-        dom_f1: Active domain prediction F1 score
-        gen_f1: General action prediction F1 score
-        stats: Uncertainty measure statistics of model
-    """
-    if type(global_step) == int:
-        info = f"{global_step} steps complete, "
-        info += f"Loss since last update: {loss}. Validation set stats: "
-    elif global_step == 'training_complete':
-        info = f"Training Complete, "
-        info += f"Validation set stats: "
-    elif global_step == 'dev':
-        info = f"Validation set stats: Loss: {loss}, "
-    elif global_step == 'test':
-        info = f"Test set stats: Loss: {loss}, "
-    info += f"Joint Goal Acc: {jg_acc}, Slot Acc: {sl_acc}, "
-    if req_f1 is not None:
-        info += f"Request F1 Score: {req_f1}, Active Domain F1 Score: {dom_f1}, "
-        info += f"General Action F1 Score: {gen_f1}"
-    logger.info(info)
-
-    if type(global_step) == int:
-        tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
-        tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)
-        if req_f1 is not None:
-            tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step)
-            tb_writer.add_scalar('ActiveDomainF1Score/Dev', dom_f1, global_step)
-            tb_writer.add_scalar('GeneralActionF1Score/Dev', gen_f1, global_step)
-        tb_writer.add_scalar('Loss/Dev', loss, global_step)
-
-        if stats:
-            for slot, stats_slot in stats.items():
-                for key, item in stats_slot.items():
-                    tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step)
-
-
-def get_input_dict(batch: dict,
-                   predict_actions: bool,
-                   model_informable_slot_ids: list,
-                   model_requestable_slot_ids: list = None,
-                   model_domain_ids: list = None,
-                   device = 'cpu') -> dict:
-    """
-    Produce model input arguments
-
-    Args:
-        batch: Batch of data from the dataloader
-        predict_actions: Model should predict user actions if set true
-        model_informable_slot_ids: List of model dialogue state slots
-        model_requestable_slot_ids: List of model requestable slots
-        model_domain_ids: List of model domains
-        device: Current torch device in use
-
-    Returns:
-        input_dict: Dictrionary containing model inputs for the batch
-    """
-    input_dict = dict()
-
-    input_dict['input_ids'] = batch['input_ids'].to(device)
-    input_dict['token_type_ids'] = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
-    input_dict['attention_mask'] = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
-
-    if any('belief_state' in key for key in batch):
-        input_dict['state_labels'] = {slot: batch['belief_state-' + slot].to(device)
-                                      for slot in model_informable_slot_ids
-                                      if ('belief_state-' + slot) in batch}
-        if predict_actions:
-            input_dict['request_labels'] = {slot: batch['request_probs-' + slot].to(device)
-                                            for slot in model_requestable_slot_ids
-                                            if ('request_probs-' + slot) in batch}
-            input_dict['active_domain_labels'] = {domain: batch['active_domain_probs-' + domain].to(device)
-                                                  for domain in model_domain_ids
-                                                  if ('active_domain_probs-' + domain) in batch}
-            input_dict['general_act_labels'] = batch['general_act_probs'].to(device)
-    else:
-        input_dict['state_labels'] = {slot: batch['state_labels-' + slot].to(device)
-                                      for slot in model_informable_slot_ids if ('state_labels-' + slot) in batch}
-        if predict_actions:
-            input_dict['request_labels'] = {slot: batch['request_labels-' + slot].to(device)
-                                            for slot in model_requestable_slot_ids
-                                            if ('request_labels-' + slot) in batch}
-            input_dict['active_domain_labels'] = {domain: batch['active_domain_labels-' + domain].to(device)
-                                                  for domain in model_domain_ids
-                                                  if ('active_domain_labels-' + domain) in batch}
-            input_dict['general_act_labels'] = batch['general_act_labels'].to(device)
-
-    return input_dict
-
-
-def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, slots_dev: dict):
-    """
-    Train the SetSUMBT model.
-
-    Args:
-        args: Runtime arguments
-        model: SetSUMBT Model instance to train
-        device: Torch device to use during training
-        train_dataloader: Dataloader containing the training data
-        dev_dataloader: Dataloader containing the validation set data
-        slots: Model ontology used for training
-        slots_dev: Model ontology used for evaluating on the validation set
-    """
-
-    # Calculate the total number of training steps to be performed
-    if args.max_training_steps > 0:
-        t_total = args.max_training_steps
-        args.num_train_epochs = (len(train_dataloader) // args.gradient_accumulation_steps) + 1
-        args.num_train_epochs = args.max_training_steps // args.num_train_epochs
-    else:
-        t_total = (len(train_dataloader) // args.gradient_accumulation_steps) * args.num_train_epochs
-
-    if args.save_steps <= 0:
-        args.save_steps = len(train_dataloader) // args.gradient_accumulation_steps
-
-    # Group weight decay and no decay parameters in the model
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": args.weight_decay,
-            "lr": args.learning_rate
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-            "lr": args.learning_rate
-        },
-    ]
-
-    # Initialise the optimizer
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
-
-    # Initialise linear lr scheduler
-    num_warmup_steps = int(t_total * args.warmup_proportion)
-    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps,
-                                                num_training_steps=t_total)
-
-    # Initialise distillation temp scheduler
-    if model.config.loss_function in ['distillation']:
-        temp_scheduler = TemperatureScheduler(total_steps=t_total, base_temp=args.annealing_base_temp,
-                                              cycle_len=args.annealing_cycle_len)
-    else:
-        temp_scheduler = None
-
-    # Set up fp16 and multi gpu usage
-    if args.fp16:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
-    if args.n_gpu > 1:
-        model = DataParallel(model)
-
-    # Load optimizer checkpoint if available
-    best_model = {'joint goal accuracy': 0.0,
-                  'request f1 score': 0.0,
-                  'active domain f1 score': 0.0,
-                  'general act f1 score': 0.0,
-                  'train loss': np.inf}
-    if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')):
-        logger.info("Optimizer loaded from previous run.")
-        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
-        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
-        if temp_scheduler is not None:
-            temp_scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'temp_scheduler.pt')))
-        if args.fp16 and os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')):
-            logger.info("FP16 Apex Amp loaded from previous run.")
-            amp.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'amp.pt')))
-
-        # Evaluate initialised model
-        if args.do_eval:
-            # Set up model for evaluation
-            model.eval()
-            set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
-
-            jg_acc, sl_acc, req_f1, dom_f1, gen_f1, _, _ = evaluate(args, model, device, dev_dataloader, is_train=True)
-
-            # Set model back to training mode
-            model.train()
-            model.zero_grad()
-            set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False)
-        else:
-            jg_acc, req_f1, dom_f1, gen_f1 = 0.0, 0.0, 0.0, 0.0
-
-        best_model['joint goal accuracy'] = jg_acc
-        best_model['request f1 score'] = req_f1
-        best_model['active domain f1 score'] = dom_f1
-        best_model['general act f1 score'] = gen_f1
-
-    # Log training set up
-    logger.info(f"Device: {device}, Number of GPUs: {args.n_gpu}, FP16 training: {args.fp16}")
-    logger.info("***** Running training *****")
-    logger.info(f"  Num Batches = {len(train_dataloader)}")
-    logger.info(f"  Num Epochs = {args.num_train_epochs}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {t_total}")
-
-    # Initialise training parameters
-    global_step = 0
-    epochs_trained = 0
-    steps_trained_in_current_epoch = 0
-
-    # Check if continuing training from a checkpoint
-    if os.path.exists(args.model_name_or_path):
-        try:
-            # set global_step to gobal_step of last saved checkpoint from model path
-            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
-            global_step = int(checkpoint_suffix)
-            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
-            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
-
-            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info(f"  Continuing training from epoch {epochs_trained}")
-            logger.info(f"  Continuing training from global step {global_step}")
-            logger.info(f"  Will skip the first {steps_trained_in_current_epoch} steps in the first epoch")
-        except ValueError:
-            logger.info(f"  Starting fine-tuning.")
-
-    # Prepare model for training
-    tr_loss, logging_loss = 0.0, 0.0
-    model.train()
-    model.zero_grad()
-    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch")
-
-    steps_since_last_update = 0
-    # Perform training
-    for e in train_iterator:
-        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
-        # Iterate over all batches
-        for step, batch in enumerate(epoch_iterator):
-            # Skip batches already trained on
-            if step < steps_trained_in_current_epoch:
-                continue
-
-            # Extract all label dictionaries from the batch
-            input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids,
-                                        model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device)
-
-            # Set up temperature scaling for the model
-            if temp_scheduler is not None:
-                model.setsumbt.temp = temp_scheduler.temp()
-
-            # Forward pass to obtain loss
-            loss, _, _, _, _, _, stats = model(**input_dict)
-
-            if args.n_gpu > 1:
-                loss = loss.mean()
-
-            # Update step
-            if step % args.gradient_accumulation_steps == 0:
-                loss = loss / args.gradient_accumulation_steps
-                if temp_scheduler is not None:
-                    tb_writer.add_scalar('Temp', temp_scheduler.temp(), global_step)
-                tb_writer.add_scalar('Loss/train', loss, global_step)
-                # Backpropogate accumulated loss
-                if args.fp16:
-                    with amp.scale_loss(loss, optimizer) as scaled_loss:
-                        scaled_loss.backward()
-                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
-                        tb_writer.add_scalar('Scaled_Loss/train', scaled_loss, global_step)
-                else:
-                    loss.backward()
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
-
-                # Get learning rate
-                lr = optimizer.param_groups[0]['lr']
-                tb_writer.add_scalar('LearningRate', lr, global_step)
-
-                if stats:
-                    for slot, stats_slot in stats.items():
-                        for key, item in stats_slot.items():
-                            tb_writer.add_scalar(f'{key}_{slot}/Train', item, global_step)
-
-                # Update model parameters
-                optimizer.step()
-                scheduler.step()
-                model.zero_grad()
-
-                if temp_scheduler is not None:
-                    temp_scheduler.step()
-
-                tr_loss += loss.float().item()
-                epoch_iterator.set_postfix(loss=loss.float().item())
-                global_step += 1
-
-            # Save model checkpoint
-            if global_step % args.save_steps == 0:
-                logging_loss = tr_loss - logging_loss
-
-                # Evaluate model
-                if args.do_eval:
-                    # Set up model for evaluation
-                    model.eval()
-                    set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
-
-                    jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader,
-                                                                                   is_train=True)
-                    # Log model eval information
-                    log_info(global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, gen_f1, stats)
-
-                    # Set model back to training mode
-                    model.train()
-                    model.zero_grad()
-                    set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False)
-                else:
-                    log_info(global_step, logging_loss / args.save_steps)
-
-                logging_loss = tr_loss
-
-                # Compute the score of the best model
-                try:
-                    best_score = best_model['request f1 score'] * model.config.user_request_loss_weight
-                    best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight
-                    best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight
-                except AttributeError:
-                    best_score = 0.0
-                best_score += best_model['joint goal accuracy']
-
-                # Compute the score of the current model
-                try:
-                    current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0
-                    current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0
-                    current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0
-                except AttributeError:
-                    current_score = 0.0
-                current_score += jg_acc
-
-                # Decide whether to update the model
-                if best_model['joint goal accuracy'] < jg_acc and jg_acc > 0.0:
-                    update = True
-                elif current_score > best_score and current_score > 0.0:
-                    update = True
-                elif best_model['train loss'] > (tr_loss / global_step) and best_model['joint goal accuracy'] == 0.0:
-                    update = True
-                else:
-                    update = False
-
-                if update:
-                    steps_since_last_update = 0
-                    logger.info('Model saved.')
-                    best_model['joint goal accuracy'] = jg_acc
-                    if req_f1:
-                        best_model['request f1 score'] = req_f1
-                        best_model['active domain f1 score'] = dom_f1
-                        best_model['general act f1 score'] = gen_f1
-                    best_model['train loss'] = tr_loss / global_step
-
-                    output_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    if not os.path.exists(output_dir):
-                        os.makedirs(output_dir, exist_ok=True)
-
-                    if args.n_gpu > 1:
-                        model.module.save_pretrained(output_dir)
-                    else:
-                        model.save_pretrained(output_dir)
-
-                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-                    if temp_scheduler is not None:
-                        torch.save(temp_scheduler.state_dict(), os.path.join(output_dir, 'temp_scheduler.pt'))
-                    if args.fp16:
-                        torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt"))
-
-                    # Remove older training checkpoints
-                    clear_checkpoints(args.output_dir, args.keep_models)
-                else:
-                    steps_since_last_update += 1
-                    logger.info('Model not saved.')
-
-            # Stop training after max training steps or if the model has not updated for too long
-            if args.max_training_steps > 0 and global_step > args.max_training_steps:
-                epoch_iterator.close()
-                break
-            if args.patience > 0 and steps_since_last_update >= args.patience:
-                epoch_iterator.close()
-                break
-
-        steps_trained_in_current_epoch = 0
-        logger.info(f'Epoch {e + 1} complete, average training loss = {tr_loss / global_step}')
-
-        if args.max_training_steps > 0 and global_step > args.max_training_steps:
-            train_iterator.close()
-            break
-        if args.patience > 0 and steps_since_last_update >= args.patience:
-            train_iterator.close()
-            logger.info(f'Model has not improved for at least {args.patience} steps. Training stopped!')
-            break
-
-    # Evaluate final model
-    if args.do_eval:
-        model.eval()
-        set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
-
-        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader,
-                                                                       is_train=True)
-
-        log_info('training_complete', tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
-    else:
-        logger.info('Training complete!')
-
-    # Store final model
-    try:
-        best_score = best_model['request f1 score'] * model.config.user_request_loss_weight
-        best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight
-        best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight
-    except AttributeError:
-        best_score = 0.0
-    best_score += best_model['joint goal accuracy']
-    try:
-        current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0
-        current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0
-        current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0
-    except AttributeError:
-        current_score = 0.0
-    current_score += jg_acc
-    if best_model['joint goal accuracy'] < jg_acc and jg_acc > 0.0:
-        update = True
-    elif current_score > best_score and current_score > 0.0:
-        update = True
-    elif best_model['train loss'] > (tr_loss / global_step) and best_model['joint goal accuracy'] == 0.0:
-        update = True
-    else:
-        update = False
-
-    if update:
-        logger.info('Final model saved.')
-        output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-
-        if args.n_gpu > 1:
-            model.module.save_pretrained(output_dir)
-        else:
-            model.save_pretrained(output_dir)
-
-        torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-        torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-        if temp_scheduler is not None:
-            torch.save(temp_scheduler.state_dict(), os.path.join(output_dir, 'temp_scheduler.pt'))
-        if args.fp16:
-            torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt"))
-        clear_checkpoints(args.output_dir)
-    else:
-        logger.info('Final model not saved, as it is not the best performing model.')
-
-
-def evaluate(args, model, device, dataloader, return_eval_output=False, is_train=False):
-    """
-    Evaluate model
-
-    Args:
-        args: Runtime arguments
-        model: SetSUMBT model instance
-        device: Torch device in use
-        dataloader: Dataloader of data to evaluate on
-        return_eval_output: If true return predicted and true states for all dialogues evaluated in semantic format
-        is_train: If true model is training and no logging is performed
-
-    Returns:
-        out: Evaluated model statistics
-    """
-    return_eval_output = False if is_train else return_eval_output
-    if not is_train:
-        logger.info("***** Running evaluation *****")
-        logger.info("  Num Batches = %d", len(dataloader))
-
-    tr_loss = 0.0
-    model.eval()
-    if return_eval_output:
-        ontology = dataloader.dataset.ontology
-
-    accuracy_jg = []
-    accuracy_sl = []
-    truepos_req, falsepos_req, falseneg_req = [], [], []
-    truepos_dom, falsepos_dom, falseneg_dom = [], [], []
-    truepos_gen, falsepos_gen, falseneg_gen = [], [], []
-    turns = []
-    if return_eval_output:
-        evaluation_output = []
-    epoch_iterator = tqdm(dataloader, desc="Iteration") if not is_train else dataloader
-    for batch in epoch_iterator:
-        with torch.no_grad():
-            input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids,
-                                        model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device)
-
-            loss, p, p_req, p_dom, p_gen, _, stats = model(**input_dict)
-
-        jg_acc = 0.0
-        num_inform_slots = 0.0
-        req_acc = 0.0
-        req_tp, req_fp, req_fn = 0.0, 0.0, 0.0
-        dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0
-        dom_acc = 0.0
-
-        if return_eval_output:
-            eval_output_batch = []
-            for dial_id, dial in enumerate(input_dict['input_ids']):
-                for turn_id, turn in enumerate(dial):
-                    if turn.sum() != 0:
-                        eval_output_batch.append({'dial_idx': dial_id,
-                                                  'utt_idx': turn_id,
-                                                  'state': dict(),
-                                                  'predictions': {'state': dict()}
-                                                  })
-
-        for slot in model.setsumbt.informable_slot_ids:
-            p_ = p[slot]
-            state_labels = batch['state_labels-' + slot].to(device)
-
-            if return_eval_output:
-                prediction = p_.argmax(-1)
-
-                for sample in eval_output_batch:
-                    dom, slt = slot.split('-', 1)
-                    lab = state_labels[sample['dial_idx']][sample['utt_idx']].item()
-                    lab = ontology[dom][slt]['possible_values'][lab] if lab != -1 else 'NOT_IN_ONTOLOGY'
-                    pred = prediction[sample['dial_idx']][sample['utt_idx']].item()
-                    pred = ontology[dom][slt]['possible_values'][pred]
-
-                    if dom not in sample['state']:
-                        sample['state'][dom] = dict()
-                        sample['predictions']['state'][dom] = dict()
-
-                    sample['state'][dom][slt] = lab if lab != 'none' else ''
-                    sample['predictions']['state'][dom][slt] = pred if pred != 'none' else ''
-
-            if args.temp_scaling > 0.0:
-                p_ = torch.log(p_ + 1e-10) / args.temp_scaling
-                p_ = torch.softmax(p_, -1)
-            else:
-                p_ = torch.log(p_ + 1e-10) / 1.0
-                p_ = torch.softmax(p_, -1)
-
-            acc = (p_.argmax(-1) == state_labels).reshape(-1).float()
-
-            jg_acc += acc
-            num_inform_slots += (state_labels != -1).float().reshape(-1)
-
-        if return_eval_output:
-            for sample in eval_output_batch:
-                sample['dial_idx'] = batch['dialogue_ids'][sample['utt_idx']][sample['dial_idx']]
-                evaluation_output.append(deepcopy(sample))
-            eval_output_batch = []
-
-        if model.config.predict_actions:
-            for slot in model.setsumbt.requestable_slot_ids:
-                p_req_ = p_req[slot]
-                request_labels = batch['request_labels-' + slot].to(device)
-
-                acc = (p_req_.round().int() == request_labels).reshape(-1).float()
-                tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float()
-                fp = (p_req_.round().int() * (request_labels == 0)).reshape(-1).float()
-                fn = ((1 - p_req_.round().int()) * (request_labels == 1)).reshape(-1).float()
-                req_acc += acc
-                req_tp += tp
-                req_fp += fp
-                req_fn += fn
-
-            domains = [domain for domain in model.setsumbt.domain_ids if f'active_domain_labels-{domain}' in batch]
-            for domain in domains:
-                p_dom_ = p_dom[domain]
-                active_domain_labels = batch['active_domain_labels-' + domain].to(device)
-
-                acc = (p_dom_.round().int() == active_domain_labels).reshape(-1).float()
-                tp = (p_dom_.round().int() * (active_domain_labels == 1)).reshape(-1).float()
-                fp = (p_dom_.round().int() * (active_domain_labels == 0)).reshape(-1).float()
-                fn = ((1 - p_dom_.round().int()) * (active_domain_labels == 1)).reshape(-1).float()
-                dom_acc += acc
-                dom_tp += tp
-                dom_fp += fp
-                dom_fn += fn
-
-            general_act_labels = batch['general_act_labels'].to(device)
-            gen_tp = ((p_gen.argmax(-1) > 0) * (general_act_labels > 0)).reshape(-1).float().sum()
-            gen_fp = ((p_gen.argmax(-1) > 0) * (general_act_labels == 0)).reshape(-1).float().sum()
-            gen_fn = ((p_gen.argmax(-1) == 0) * (general_act_labels > 0)).reshape(-1).float().sum()
-        else:
-            req_tp, req_fp, req_fn = None, None, None
-            dom_tp, dom_fp, dom_fn = None, None, None
-            gen_tp, gen_fp, gen_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
-
-        jg_acc = jg_acc[num_inform_slots > 0]
-        num_inform_slots = num_inform_slots[num_inform_slots > 0]
-        sl_acc = sum(jg_acc / num_inform_slots).float()
-        jg_acc = sum((jg_acc == num_inform_slots).int()).float()
-        if req_tp is not None and model.setsumbt.requestable_slot_ids:
-            req_tp = sum(req_tp / len(model.setsumbt.requestable_slot_ids)).float()
-            req_fp = sum(req_fp / len(model.setsumbt.requestable_slot_ids)).float()
-            req_fn = sum(req_fn / len(model.setsumbt.requestable_slot_ids)).float()
-        else:
-            req_tp, req_fp, req_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
-        dom_tp = sum(dom_tp / len(model.setsumbt.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0)
-        dom_fp = sum(dom_fp / len(model.setsumbt.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0)
-        dom_fn = sum(dom_fn / len(model.setsumbt.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0)
-        n_turns = num_inform_slots.size(0)
-
-        accuracy_jg.append(jg_acc.item())
-        accuracy_sl.append(sl_acc.item())
-        truepos_req.append(req_tp.item())
-        falsepos_req.append(req_fp.item())
-        falseneg_req.append(req_fn.item())
-        truepos_dom.append(dom_tp.item())
-        falsepos_dom.append(dom_fp.item())
-        falseneg_dom.append(dom_fn.item())
-        truepos_gen.append(gen_tp.item())
-        falsepos_gen.append(gen_fp.item())
-        falseneg_gen.append(gen_fn.item())
-        turns.append(n_turns)
-        tr_loss += loss.item()
-
-    # Global accuracy reduction across batches
-    turns = sum(turns)
-    jg_acc = sum(accuracy_jg) / turns
-    sl_acc = sum(accuracy_sl) / turns
-    if model.config.predict_actions:
-        req_tp = sum(truepos_req)
-        req_fp = sum(falsepos_req)
-        req_fn = sum(falseneg_req)
-        req_f1 = req_tp + 0.5 * (req_fp + req_fn)
-        req_f1 = req_tp / req_f1 if req_f1 != 0.0 else 0.0
-        dom_tp = sum(truepos_dom)
-        dom_fp = sum(falsepos_dom)
-        dom_fn = sum(falseneg_dom)
-        dom_f1 = dom_tp + 0.5 * (dom_fp + dom_fn)
-        dom_f1 = dom_tp / dom_f1 if dom_f1 != 0.0 else 0.0
-        gen_tp = sum(truepos_gen)
-        gen_fp = sum(falsepos_gen)
-        gen_fn = sum(falseneg_gen)
-        gen_f1 = gen_tp + 0.5 * (gen_fp + gen_fn)
-        gen_f1 = gen_tp / gen_f1 if gen_f1 != 0.0 else 0.0
-    else:
-        req_f1, dom_f1, gen_f1 = None, None, None
-
-    if return_eval_output:
-        return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), evaluation_output
-    if is_train:
-        return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), stats
-    return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader)
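
The removed evaluation loop reduces the request, active-domain and general-act counts to F1 via F1 = TP / (TP + 0.5 * (FP + FN)). A minimal standalone sketch of that reduction:

def f1_score(tp: float, fp: float, fn: float) -> float:
    """F1 as computed above: TP / (TP + 0.5 * (FP + FN)), guarding the zero denominator."""
    denominator = tp + 0.5 * (fp + fn)
    return tp / denominator if denominator != 0.0 else 0.0

assert f1_score(8, 2, 2) == 0.8  # 8 true positives, 2 false positives, 2 false negatives
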
diff --git a/convlab/dst/setsumbt/predict_user_actions.py b/convlab/dst/setsumbt/predict_user_actions.py
deleted file mode 100644
index 2c304a569cb5e29920332ed21c8f862dd00c1e48..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/predict_user_actions.py
+++ /dev/null
@@ -1,178 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Predict dataset user action using SetSUMBT Model"""
-
-from copy import deepcopy
-import os
-import json
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-from convlab.util.custom_util import flatten_acts as flatten
-from convlab.util import load_dataset, load_policy_data
-from convlab.dst.setsumbt import SetSUMBTTracker
-
-
-def flatten_acts(acts: dict) -> list:
-    """
-    Flatten dictionary actions.
-
-    Args:
-        acts: Dictionary acts
-
-    Returns:
-        flat_acts: Flattened actions
-    """
-    acts = flatten(acts)
-    flat_acts = []
-    for intent, domain, slot, value in acts:
-        flat_acts.append([intent,
-                          domain,
-                          slot if slot != 'none' else '',
-                          value.lower() if value != 'none' else ''])
-
-    return flat_acts
-
-
-def get_user_actions(context: list, system_acts: list) -> list:
-    """
-    Extract user actions from the data.
-
-    Args:
-        context: Previous dialogue turns.
-        system_acts: List of flattened system actions.
-
-    Returns:
-        user_acts: List of flattened user actions.
-    """
-    user_acts = context[-1]['dialogue_acts']
-    user_acts = flatten_acts(user_acts)
-    if len(context) == 3:
-        prev_state = context[-3]['state']
-        cur_state = context[-1]['state']
-        for domain, substate in cur_state.items():
-            for slot, value in substate.items():
-                if prev_state[domain][slot] != value:
-                    act = ['inform', domain, slot, value]
-                    if act not in user_acts and act not in system_acts:
-                        user_acts.append(act)
-
-    return user_acts
-
-
-def extract_dataset(dataset: str = 'multiwoz21') -> list:
-    """
-    Extract acts and utterances from the dataset.
-
-    Args:
-        dataset: Dataset name
-
-    Returns:
-        data: Extracted data
-    """
-    data = load_dataset(dataset_name=dataset)
-    raw_data = load_policy_data(data, data_split='test', context_window_size=3)['test']
-
-    dialogue = list()
-    data = list()
-    for turn in raw_data:
-        state = dict()
-        state['system_utterance'] = turn['context'][-2]['utterance'] if len(turn['context']) > 1 else ''
-        state['utterance'] = turn['context'][-1]['utterance']
-        state['system_actions'] = turn['context'][-2]['dialogue_acts'] if len(turn['context']) > 1 else {}
-        state['system_actions'] = flatten_acts(state['system_actions'])
-        state['user_actions'] = get_user_actions(turn['context'], state['system_actions'])
-        dialogue.append(state)
-        if turn['terminated']:
-            data.append(dialogue)
-            dialogue = list()
-
-    return data
-
-
-def unflatten_acts(acts: list) -> dict:
-    """
-    Convert acts from flat list format to dict format.
-
-    Args:
-        acts: List of flat actions.
-
-    Returns:
-        unflat_acts: Dictionary of acts.
-    """
-    binary_acts = []
-    cat_acts = []
-    for intent, domain, slot, value in acts:
-        include = True if (domain == 'general') or (slot != 'none') else False
-        if include and (value == '' or value == 'none' or intent == 'request'):
-            binary_acts.append({'intent': intent,
-                                'domain': domain,
-                                'slot': slot if slot != 'none' else ''})
-        elif include:
-            cat_acts.append({'intent': intent,
-                             'domain': domain,
-                             'slot': slot if slot != 'none' else '',
-                             'value': value})
-
-    unflat_acts = {'categorical': cat_acts, 'binary': binary_acts, 'non-categorical': list()}
-
-    return unflat_acts
-
-
-def predict_user_acts(data: list, tracker: SetSUMBTTracker) -> list:
-    """
-    Predict the user actions using the SetSUMBT Tracker.
-
-    Args:
-        data: List of dialogues.
-        tracker: SetSUMBT Tracker
-
-    Returns:
-        predict_result: List of turns containing predictions and true user actions.
-    """
-    tracker.init_session()
-    predict_result = []
-    for dial_idx, dialogue in enumerate(data):
-        for turn_idx, state in enumerate(dialogue):
-            sample = {'dial_idx': dial_idx, 'turn_idx': turn_idx}
-
-            tracker.state['history'].append(['sys', state['system_utterance']])
-            predicted_state = deepcopy(tracker.update(state['utterance']))
-            tracker.state['history'].append(['usr', state['utterance']])
-            tracker.state['system_action'] = state['system_actions']
-
-            sample['predictions'] = {'dialogue_acts': unflatten_acts(predicted_state['user_action'])}
-            sample['dialogue_acts'] = unflatten_acts(state['user_actions'])
-
-            predict_result.append(sample)
-
-        tracker.init_session()
-
-    return predict_result
-
-
-if __name__ =="__main__":
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
-    parser.add_argument('--model_path', type=str, help='Path to model dir')
-    args = parser.parse_args()
-
-    dataset = extract_dataset(args.dataset_name)
-    tracker = SetSUMBTTracker(args.model_path)
-    predict_results = predict_user_acts(dataset, tracker)
-
-    with open(os.path.join(args.model_path, 'predictions', 'test_nlu.json'), 'w') as writer:
-        json.dump(predict_results, writer, indent=2)
-        writer.close()
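
For reference, the deleted `unflatten_acts` helper converts flat `[intent, domain, slot, value]` tuples into the ConvLab dialogue-act dictionary. A usage sketch matching the definition above:

acts = [['request', 'hotel', 'phone', '?'],
        ['inform', 'hotel', 'area', 'centre'],
        ['bye', 'general', 'none', 'none']]
unflatten_acts(acts)
# {'categorical': [{'intent': 'inform', 'domain': 'hotel', 'slot': 'area', 'value': 'centre'}],
#  'binary': [{'intent': 'request', 'domain': 'hotel', 'slot': 'phone'},
#             {'intent': 'bye', 'domain': 'general', 'slot': ''}],
#  'non-categorical': []}
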
diff --git a/convlab/dst/setsumbt/run.py b/convlab/dst/setsumbt/run.py
index e45bf129f0c9f2c5c1fba01d4b5eb80e29a5a1f0..d017fd8e824884df11548d4b466fdcae8f97e925 100644
--- a/convlab/dst/setsumbt/run.py
+++ b/convlab/dst/setsumbt/run.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,12 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Run"""
+"""Run SetSUMBT belief tracker training and evaluation."""
 
-from transformers import BertConfig, RobertaConfig
+import logging
+import os
+from shutil import copy2 as copy
+from copy import deepcopy
 
-from convlab.dst.setsumbt.utils import get_args
+import torch
+import transformers
+from transformers import BertConfig, RobertaConfig
+from tensorboardX import SummaryWriter
+from tqdm import tqdm
 
+from convlab.dst.setsumbt.modeling import SetSUMBTModels, SetSUMBTTrainer
+from convlab.dst.setsumbt.datasets import (get_dataloader, change_batch_size, dataloader_sample_dialogues,
+                                           get_distillation_dataloader)
+from convlab.dst.setsumbt.utils import get_args, update_args, setup_ensemble
 
 MODELS = {
     'bert': (BertConfig, "BertTokenizer"),
@@ -27,15 +38,277 @@ MODELS = {
 
 
 def main():
-    # Get arguments
     args, config = get_args(MODELS)
 
-    if args.run_nbt:
-        from convlab.dst.setsumbt.do.nbt import main
-        main(args, config)
-    if args.run_evaluation:
-        from convlab.dst.setsumbt.do.evaluate import main
-        main(args, config)
+    if args.model_type in SetSUMBTModels:
+        SetSumbtModel, OntologyEncoderModel, ConfigClass, Tokenizer = SetSUMBTModels[args.model_type]
+        if args.ensemble:
+            SetSumbtModel, _, _, _ = SetSUMBTModels['ensemble']
+    else:
+        raise NameError(f"Model type '{args.model_type}' is not implemented.")
+
+    # Set up output directory
+    OUTPUT_DIR = args.output_dir
+
+    if not os.path.exists(OUTPUT_DIR):
+        os.makedirs(OUTPUT_DIR)
+        os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
+    args.output_dir = OUTPUT_DIR
+
+    # Set pretrained model path to the trained checkpoint
+    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
+    if 'pytorch_model.bin' in paths and 'config.json' in paths:
+        args.model_name_or_path = args.output_dir
+        config = ConfigClass.from_pretrained(args.model_name_or_path)
+    elif 'ens-0' in paths:
+        paths = [p for p in os.listdir(os.path.join(args.output_dir, 'ens-0')) if 'checkpoint-' in p]
+        if paths:
+            args.model_name_or_path = args.output_dir
+            config = ConfigClass.from_pretrained(os.path.join(args.model_name_or_path, 'ens-0', paths[0]))
+    else:
+        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
+        if paths:
+            paths = paths[0]
+            args.model_name_or_path = paths
+            config = ConfigClass.from_pretrained(args.model_name_or_path)
+
+    args = update_args(args, config)
+
+    # Create TensorboardX writer
+    tb_writer = SummaryWriter(logdir=args.tensorboard_path)
+
+    # Create logger
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+
+    formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')
+
+    fh = logging.FileHandler(args.logging_path)
+    fh.setLevel(logging.INFO)
+    fh.setFormatter(formatter)
+    logger.addHandler(fh)
+
+    # Get device
+    if torch.cuda.is_available() and args.n_gpu > 0:
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+        args.n_gpu = 0
+
+    if args.n_gpu == 0:
+        args.fp16 = False
+
+    # Initialise Model
+    transformers.utils.logging.set_verbosity_info()
+    model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
+    model = model.to(device)
+
+    if args.ensemble:
+        args.model_name_or_path = model._get_checkpoint_path(args.model_name_or_path, 0)
+
+    # Create Tokenizer and embedding model for Data Loaders and ontology
+    tokenizer = Tokenizer.from_pretrained(args.model_name_or_path)
+    encoder = OntologyEncoderModel.from_pretrained(config.candidate_embedding_model_name,
+                                                   args=args, tokenizer=tokenizer)
+
+    transformers.utils.logging.set_verbosity_error()
+    if args.do_ensemble_setup:
+        # Build all dataloaders
+        train_dataloader = get_dataloader(args.dataset,
+                                          'train',
+                                          args.train_batch_size,
+                                          tokenizer,
+                                          encoder,
+                                          args.max_dialogue_len,
+                                          args.max_turn_len,
+                                          train_ratio=args.dataset_train_ratio,
+                                          seed=args.seed)
+        torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+        dev_dataloader = get_dataloader(args.dataset,
+                                        'validation',
+                                        args.dev_batch_size,
+                                        tokenizer,
+                                        encoder,
+                                        args.max_dialogue_len,
+                                        args.max_turn_len,
+                                        train_ratio=args.dataset_train_ratio,
+                                        seed=args.seed)
+        torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+        test_dataloader = get_dataloader(args.dataset,
+                                         'test',
+                                         args.test_batch_size,
+                                         tokenizer,
+                                         encoder,
+                                         args.max_dialogue_len,
+                                         args.max_turn_len,
+                                         train_ratio=args.dataset_train_ratio,
+                                         seed=args.seed)
+        torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+
+        setup_ensemble(OUTPUT_DIR, args.ensemble_size)
+
+        logger.info(f'Building {args.ensemble_size} resampled dataloaders each of size {args.data_sampling_size}.')
+        dataloaders = [dataloader_sample_dialogues(deepcopy(train_dataloader), args.data_sampling_size)
+                       for _ in tqdm(range(args.ensemble_size))]
+        logger.info('Dataloaders built.')
+
+        for i, loader in enumerate(dataloaders):
+            path = os.path.join(OUTPUT_DIR, f'ens-{i}')
+            if not os.path.exists(path):
+                os.mkdir(path)
+            path = os.path.join(path, 'dataloaders', 'train.dataloader')
+            torch.save(loader, path)
+        logger.info('Dataloaders saved.')
+
+        # Do not perform standard training after ensemble setup is created
+        return 0
+
+    # Perform tasks
+    # TRAINING
+    if args.do_train:
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')):
+            train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+            if train_dataloader.batch_size != args.train_batch_size:
+                train_dataloader = change_batch_size(train_dataloader, args.train_batch_size)
+        else:
+            if args.data_sampling_size <= 0:
+                args.data_sampling_size = None
+            if 'distillation' not in config.loss_function:
+                train_dataloader = get_dataloader(args.dataset,
+                                                  'train',
+                                                  args.train_batch_size,
+                                                  tokenizer,
+                                                  encoder,
+                                                  args.max_dialogue_len,
+                                                  config.max_turn_len,
+                                                  resampled_size=args.data_sampling_size,
+                                                  train_ratio=args.dataset_train_ratio,
+                                                  seed=args.seed)
+            else:
+                loader_args = {"ensemble_path": args.ensemble_model_path,
+                               "set_type": "train",
+                               "batch_size": args.train_batch_size,
+                               "reduction": "mean" if config.loss_function == 'distillation' else "none"}
+                train_dataloader = get_distillation_dataloader(**loader_args)
+            torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+
+        # Get development set batch loaders and ontology embeddings
+        if args.do_eval:
+            if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
+                dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+                if dev_dataloader.batch_size != args.dev_batch_size:
+                    dev_dataloader = change_batch_size(dev_dataloader, args.dev_batch_size)
+            else:
+                if 'distillation' not in config.loss_function:
+                    dev_dataloader = get_dataloader(args.dataset,
+                                                    'validation',
+                                                    args.dev_batch_size,
+                                                    tokenizer,
+                                                    encoder,
+                                                    args.max_dialogue_len,
+                                                    config.max_turn_len)
+                else:
+                    loader_args = {"ensemble_path": args.ensemble_model_path,
+                                   "set_type": "dev",
+                                   "batch_size": args.dev_batch_size,
+                                   "reduction": "mean" if config.loss_function == 'distillation' else "none"}
+                    dev_dataloader = get_distillation_dataloader(**loader_args)
+                torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+        else:
+            dev_dataloader = None
+
+        # TRAINING
+        trainer = SetSUMBTTrainer(args, model, tokenizer, train_dataloader, dev_dataloader, logger, tb_writer,
+                                  device)
+        trainer.train()
+
+        # Copy final best model to the output dir
+        checkpoints = os.listdir(OUTPUT_DIR)
+        checkpoints = [p for p in checkpoints if 'checkpoint' in p]
+        checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints])
+        best_checkpoint = os.path.join(OUTPUT_DIR, f'checkpoint-{checkpoints[-1]}')
+        files = ['pytorch_model.bin', 'config.json', 'merges.txt', 'special_tokens_map.json',
+                 'tokenizer_config.json', 'vocab.json']
+        for file in files:
+            copy(os.path.join(best_checkpoint, file), os.path.join(OUTPUT_DIR, file))
+
+        # Load best model for evaluation
+        tokenizer = Tokenizer.from_pretrained(OUTPUT_DIR)
+        model = SetSumbtModel.from_pretrained(OUTPUT_DIR)
+        model = model.to(device)
+
+    # Evaluation on the training set
+    if args.do_eval_trainset:
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')):
+            train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+            if train_dataloader.batch_size != args.train_batch_size:
+                train_dataloader = change_batch_size(train_dataloader, args.train_batch_size)
+        else:
+            train_dataloader = get_dataloader(args.dataset, 'train', args.train_batch_size, tokenizer,
+                                              encoder, args.max_dialogue_len, config.max_turn_len)
+            torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+
+        # EVALUATION
+        trainer = SetSUMBTTrainer(args, model, tokenizer, None, train_dataloader, logger, tb_writer, device)
+        trainer.eval_mode(load_slots=True)
+
+        if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
+            os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
+        save_pred_dist_path = os.path.join(OUTPUT_DIR, 'predictions', 'train.data') if args.ensemble else None
+        metrics = trainer.evaluate(save_pred_dist_path=save_pred_dist_path)
+        trainer.log_info(metrics, logging_stage='train')
+
+    # Evaluation on the development set
+    if args.do_eval:
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
+            dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+            if dev_dataloader.batch_size != args.dev_batch_size:
+                dev_dataloader = change_batch_size(dev_dataloader, args.dev_batch_size)
+        else:
+            dev_dataloader = get_dataloader(args.dataset, 'validation', args.dev_batch_size, tokenizer,
+                                            encoder, args.max_dialogue_len, config.max_turn_len)
+            torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+
+        # EVALUATION
+        trainer = SetSUMBTTrainer(args, model, tokenizer, None, dev_dataloader, logger, tb_writer, device)
+        trainer.eval_mode(load_slots=True)
+
+        if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
+            os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
+        save_pred_dist_path = os.path.join(OUTPUT_DIR, 'predictions', 'dev.data') if args.ensemble else None
+        metrics = trainer.evaluate(save_eval_path=os.path.join(OUTPUT_DIR, 'predictions', 'dev.json'),
+                                   save_pred_dist_path=save_pred_dist_path)
+        trainer.log_info(metrics, logging_stage='dev')
+
+    # Evaluation on the test set
+    if args.do_test:
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
+            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+            if test_dataloader.batch_size != args.test_batch_size:
+                test_dataloader = change_batch_size(test_dataloader, args.test_batch_size)
+        else:
+            test_dataloader = get_dataloader(args.dataset, 'test', args.test_batch_size, tokenizer,
+                                             encoder, args.max_dialogue_len, config.max_turn_len)
+            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+
+        trainer = SetSUMBTTrainer(args, model, tokenizer, None, test_dataloader, logger, tb_writer, device)
+        trainer.eval_mode(load_slots=True)
+
+        # TESTING
+        if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
+            os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
+
+        save_pred_dist_path = os.path.join(OUTPUT_DIR, 'predictions', 'test.data') if args.ensemble else None
+        metrics = trainer.evaluate(save_eval_path=os.path.join(OUTPUT_DIR, 'predictions', 'test.json'),
+                                   save_pred_dist_path=save_pred_dist_path, draw_calibration_diagram=True)
+        trainer.log_info(metrics, logging_stage='test')
+
+        # Save final model for inference
+        if not args.ensemble:
+            trainer.model.save_pretrained(OUTPUT_DIR)
+            trainer.tokenizer.save_pretrained(OUTPUT_DIR)
+
+    tb_writer.close()
 
 
 if __name__ == "__main__":
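
In the ensemble setup branch of `main`, `dataloader_sample_dialogues` (imported from `convlab.dst.setsumbt.datasets`, not shown in this diff) gives each of the `ensemble_size` members its own training dataloader over `data_sampling_size` resampled dialogues. A hedged sketch of dialogue-level resampling on a plain list, standing in for the real dataloader-based helper:

import random

def sample_dialogues(dialogues: list, sample_size: int, seed: int = 0) -> list:
    # Hypothetical stand-in for dataloader_sample_dialogues: draw a
    # fixed-size subset of dialogues for one ensemble member.
    rng = random.Random(seed)
    return rng.sample(dialogues, k=min(sample_size, len(dialogues)))
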
diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py
index f56bbadc2f4d8fdca102b2bbc996acb0ae5a4a58..5126fd3439c4dd77a2f27720bdd68f7c0f7e947a 100644
--- a/convlab/dst/setsumbt/tracker.py
+++ b/convlab/dst/setsumbt/tracker.py
@@ -1,16 +1,28 @@
-import os
-import json
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Run SetSUMBT belief tracker training and evaluation."""
+
 import copy
 import logging
 
 import torch
 import transformers
-from transformers import BertModel, BertConfig, BertTokenizer, RobertaModel, RobertaConfig, RobertaTokenizer
 
-from convlab.dst.setsumbt.modeling import RobertaSetSUMBT, BertSetSUMBT
-from convlab.dst.setsumbt.modeling.training import set_ontology_embeddings
+from convlab.dst.setsumbt.modeling import SetSUMBTModels
 from convlab.dst.dst import DST
-from convlab.util.custom_util import model_downloader
 
 USE_CUDA = torch.cuda.is_available()
 transformers.logging.set_verbosity_error()
@@ -20,7 +32,7 @@ class SetSUMBTTracker(DST):
     """SetSUMBT Tracker object for Convlab dialogue system"""
 
     def __init__(self,
-                 model_path: str = "",
+                 model_name_or_path: str = "",
                  model_type: str = "roberta",
                  return_turn_pooled_representation: bool = False,
                  return_confidence_scores: bool = False,
@@ -30,7 +42,7 @@ class SetSUMBTTracker(DST):
                  store_full_belief_state: bool = True):
         """
         Args:
-            model_path: Model path or download URL
+            model_name_or_path: Path to pretrained model or name of pretrained model
             model_type: Transformer type (roberta/bert)
             return_turn_pooled_representation: If true a turn level pooled representation is returned
             return_confidence_scores: If true act confidence scores are included in the state
@@ -42,7 +54,7 @@ class SetSUMBTTracker(DST):
         super(SetSUMBTTracker, self).__init__()
 
         self.model_type = model_type
-        self.model_path = model_path
+        self.model_name_or_path = model_name_or_path
         self.return_turn_pooled_representation = return_turn_pooled_representation
         self.return_confidence_scores = return_confidence_scores
         self.confidence_threshold = confidence_threshold
@@ -53,41 +65,24 @@ class SetSUMBTTracker(DST):
             self.full_belief_state = {}
         self.info_dict = {}
 
-        # Download model if needed
-        if not os.path.exists(self.model_path):
-            # Get path /.../convlab/dst/setsumbt/multiwoz/models
-            download_path = os.path.dirname(os.path.abspath(__file__))
-            download_path = os.path.join(download_path, 'models')
-            if not os.path.exists(download_path):
-                os.mkdir(download_path)
-            model_downloader(download_path, self.model_path)
-            # Downloadable model path format http://.../model_name.zip
-            self.model_path = self.model_path.split('/')[-1].replace('.zip', '')
-            self.model_path = os.path.join(download_path, self.model_path)
+        if self.model_type in SetSUMBTModels:
+            self.model, _, self.config, self.tokenizer = SetSUMBTModels[self.model_type]
+        else:
+            raise NameError(f"Model type '{self.model_type}' is not implemented.")
 
         # Select model type based on the encoder
-        if model_type == "roberta":
-            self.config = RobertaConfig.from_pretrained(self.model_path)
-            self.tokenizer = RobertaTokenizer
-            self.model = RobertaSetSUMBT
-        elif model_type == "bert":
-            self.config = BertConfig.from_pretrained(self.model_path)
-            self.tokenizer = BertTokenizer
-            self.model = BertSetSUMBT
-        else:
-            logging.debug("Name Error: Not Implemented")
+        self.config = self.config.from_pretrained(self.model_name_or_path)
 
         self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu')
-
         self.load_weights()
 
     def load_weights(self):
         """Load model weights and model ontology"""
         logging.info('Loading SetSUMBT pretrained model.')
-        self.tokenizer = self.tokenizer.from_pretrained(self.config.tokenizer_name)
-        logging.info(f'Model tokenizer loaded from {self.config.tokenizer_name}.')
-        self.model = self.model.from_pretrained(self.model_path, config=self.config)
-        logging.info(f'Model loaded from {self.model_path}.')
+        self.tokenizer = self.tokenizer.from_pretrained(self.model_name_or_path)
+        logging.info(f'Model tokenizer loaded from {self.model_name_or_path}.')
+        self.model = self.model.from_pretrained(self.model_name_or_path, config=self.config)
+        logging.info(f'Model loaded from {self.model_name_or_path}.')
 
         # Transfer model to compute device and setup eval environment
         self.model = self.model.to(self.device)
@@ -95,12 +90,7 @@ class SetSUMBTTracker(DST):
         logging.info(f'Model transferred to device: {self.device}')
 
         logging.info('Loading model ontology')
-        f = open(os.path.join(self.model_path, 'database', 'test.json'), 'r')
-        self.ontology = json.load(f)
-        f.close()
-
-        db = torch.load(os.path.join(self.model_path, 'database', 'test.db'))
-        set_ontology_embeddings(self.model, db)
+        self.ontology = self.tokenizer.ontology
 
         if self.return_confidence_scores:
             logging.info('Model returns user action and belief state confidence scores.')
@@ -138,6 +128,7 @@ class SetSUMBTTracker(DST):
         return self.confidence_thresholds
 
     def init_session(self):
+        """Initialize dialogue state"""
         self.state = dict()
         self.state['belief_state'] = dict()
         self.state['booked'] = dict()
@@ -157,21 +148,21 @@ class SetSUMBTTracker(DST):
 
     def update(self, user_act: str = '') -> dict:
         """
-        Update user actions and dialogue and belief states.
+        Update dialogue state based on user utterance.
 
         Args:
-            user_act:
+            user_act: User utterance
 
         Returns:
-
+            state: Dialogue state
         """
         prev_state = self.state
-        _output = self.predict(self.get_features(user_act))
+        outputs = self.predict(self.get_features(user_act))
 
         # Format state entropy
-        if _output[5] is not None:
+        if outputs.state_entropy is not None:
             state_entropy = dict()
-            for slot, e in _output[5].items():
+            for slot, e in outputs.state_entropy.items():
                 domain, slot = slot.split('-', 1)
                 if domain not in state_entropy:
                     state_entropy[domain] = dict()
@@ -180,9 +171,9 @@ class SetSUMBTTracker(DST):
             state_entropy = None
 
         # Format state mutual information
-        if _output[6] is not None:
+        if outputs.belief_state_mutual_information is not None:
             state_mutual_info = dict()
-            for slot, mi in _output[6].items():
+            for slot, mi in outputs.belief_state_mutual_information.items():
                 domain, slot = slot.split('-', 1)
                 if domain not in state_mutual_info:
                     state_mutual_info[domain] = dict()
@@ -192,9 +183,9 @@ class SetSUMBTTracker(DST):
 
         # Format all confidence scores
         belief_state_confidence = None
-        if _output[4] is not None:
+        if outputs.confidence_scores is not None:
             belief_state_confidence = dict()
-            belief_state_conf, request_probs, active_domain_probs, general_act_probs = _output[4]
+            belief_state_conf, request_probs, active_domain_probs, general_act_probs = outputs.confidence_scores
             for slot, p in belief_state_conf.items():
                 domain, slot = slot.split('-', 1)
                 if domain not in belief_state_confidence:
@@ -221,16 +212,16 @@ class SetSUMBTTracker(DST):
             belief_state_confidence['general']['none'] = general_act_probs
 
         # Get new domain activation actions
-        new_domains = [d for d, active in _output[1].items() if active]
+        new_domains = [d for d, active in outputs.state['active_domains'].items() if active]
         new_domains = [d for d in new_domains if not self.active_domains.get(d, False)]
-        self.active_domains = _output[1]
+        self.active_domains = outputs.state['active_domains']
 
-        user_acts = _output[2]
+        user_acts = outputs.state['user_action']
         for domain in new_domains:
             user_acts.append(['inform', domain, 'none', 'none'])
 
         new_belief_state = copy.deepcopy(prev_state['belief_state'])
-        for domain, substate in _output[0].items():
+        for domain, substate in outputs.state['belief_state'].items():
             for slot, value in substate.items():
                 value = '' if value == 'none' else value
                 value = 'dontcare' if value == 'do not care' else value
@@ -268,17 +259,17 @@ class SetSUMBTTracker(DST):
         user_acts = [act for act in user_acts if act not in new_state['system_action']]
         new_state['user_action'] = user_acts
 
-        if _output[3] is not None:
-            new_state['turn_pooled_representation'] = _output[3]
+        if outputs.turn_pooled_representation is not None:
+            new_state['turn_pooled_representation'] = outputs.turn_pooled_representation.reshape(-1)
 
         self.state = new_state
-        self.info_dict = copy.deepcopy(dict(new_state))
+        self.info_dict['belief_state'] = copy.deepcopy(dict(new_state))
 
         return self.state
 
     def predict(self, features: dict) -> tuple:
         """
-        Model forward pass and prediction post processing.
+        Model forward pass and prediction post-processing.
 
         Args:
             features: Dictionary of model input features
@@ -288,96 +279,51 @@ class SetSUMBTTracker(DST):
         """
         state_mutual_info = None
         with torch.no_grad():
-            turn_pooled_representation = None
-            if self.return_turn_pooled_representation:
-                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
-                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
-                                      get_turn_pooled_representation=True)
-                belief_state = _outputs[0]
-                request_probs = _outputs[1]
-                active_domain_probs = _outputs[2]
-                general_act_probs = _outputs[3]
-                self.hidden_states = _outputs[4]
-                turn_pooled_representation = _outputs[5]
-            elif self.return_belief_state_mutual_info:
-                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
-                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
-                                      get_turn_pooled_representation=True, calculate_state_mutual_info=True)
-                belief_state = _outputs[0]
-                request_probs = _outputs[1]
-                active_domain_probs = _outputs[2]
-                general_act_probs = _outputs[3]
-                self.hidden_states = _outputs[4]
-                state_mutual_info = _outputs[5]
-            else:
-                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
-                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
-                                      get_turn_pooled_representation=False)
-                belief_state, request_probs, active_domain_probs, general_act_probs, self.hidden_states = _outputs
+            features['hidden_state'] = self.hidden_states
+            features['get_turn_pooled_representation'] = self.return_turn_pooled_representation
+            features['calculate_state_mutual_info'] = self.return_belief_state_mutual_info
+            outputs = self.model(**features)
+            self.hidden_states = outputs.hidden_state
 
         # Convert belief state into dialog state
-        dialogue_state = dict()
-        for slot, probs in belief_state.items():
-            dom, slot = slot.split('-', 1)
-            if dom not in dialogue_state:
-                dialogue_state[dom] = dict()
-            val = self.ontology[dom][slot]['possible_values'][probs[0, 0, :].argmax().item()]
-            if val != 'none':
-                dialogue_state[dom][slot] = val
+        state = self.tokenizer.decode_state_batch(outputs.belief_state, outputs.request_probabilities,
+                                                  outputs.active_domain_probabilities,
+                                                  outputs.general_act_probabilities)
+        state = state['000000'][0]  # single-dialogue batch: placeholder dialogue id '000000', first turn
 
         if self.store_full_belief_state:
-            self.info_dict['belief_state_distributions'] = belief_state
+            self.info_dict['belief_state_distributions'] = outputs.belief_state
             if state_mutual_info is not None:
-                self.info_dict['belief_state_knowledge_uncertainty'] = state_mutual_info
+                self.info_dict['belief_state_knowledge_uncertainty'] = outputs.belief_state_mutual_information
 
         # Obtain model output probabilities
         if self.return_confidence_scores:
             state_entropy = None
             if self.return_belief_state_entropy:
-                state_entropy = {slot: probs[0, 0, :] for slot, probs in belief_state.items()}
+                state_entropy = {slot: probs[0, 0, :] for slot, probs in outputs.belief_state.items()}
                 state_entropy = {slot: self.relative_entropy(p).item() for slot, p in state_entropy.items()}
 
             # Confidence score is the max probability across all not "none" values candidates.
-            belief_state_conf = {slot: probs[0, 0, 1:].max().item() for slot, probs in belief_state.items()}
-            _request_probs = {slot: p[0, 0].item() for slot, p in request_probs.items()}
-            _active_domain_probs = {domain: p[0, 0].item() for domain, p in active_domain_probs.items()}
-            _general_act_probs = {'bye': general_act_probs[0, 0, 1].item(), 'thank': general_act_probs[0, 0, 2].item()}
+            belief_state_conf = {slot: probs[0, 0, 1:].max().item() for slot, probs in outputs.belief_state.items()}
+            _request_probs = {slot: p[0, 0].item() for slot, p in outputs.request_probabilities.items()}
+            _active_domain_probs = {domain: p[0, 0].item() for domain, p in outputs.active_domain_probabilities.items()}
+            _general_act_probs = {'bye': outputs.general_act_probabilities[0, 0, 1].item(),
+                                  'thank': outputs.general_act_probabilities[0, 0, 2].item()}
             confidence_scores = (belief_state_conf, _request_probs, _active_domain_probs, _general_act_probs)
         else:
             confidence_scores = None
             state_entropy = None
 
-        # Construct request action prediction
-        if request_probs is not None:
-            request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5]
-            request_acts = [slot.split('-', 1) for slot in request_acts]
-            request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts]
-        else:
-            request_acts = list()
-
-        # Construct active domain set
-        if active_domain_probs is not None:
-            active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()}
-        else:
-            active_domains = dict()
-
-        # Construct general domain action
-        if general_act_probs is not None:
-            general_acts = general_act_probs[0, 0, :].argmax(-1).item()
-            general_acts = [[], ['bye'], ['thank']][general_acts]
-            general_acts = [[act, 'general', 'none', 'none'] for act in general_acts]
-        else:
-            general_acts = list()
+        outputs.confidence_scores = confidence_scores
+        outputs.state_entropy = state_entropy
+        outputs.state = state
+        outputs.belief_state = None
+        return outputs
 
-        user_acts = request_acts + general_acts
-
-        out = (dialogue_state, active_domains, user_acts, turn_pooled_representation, confidence_scores)
-        out += (state_entropy, state_mutual_info)
-        return out
-
-    def relative_entropy(self, probs: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def relative_entropy(probs: torch.Tensor) -> torch.Tensor:
         """
-        Compute relative entrop for a probability distribution
+        Compute relative entropy for a probability distribution
 
         Args:
             probs: Probability distributions
@@ -412,18 +358,17 @@ class SetSUMBTTracker(DST):
         else:
             system_act = ''
 
+        dialogue = [[{
+            'user_utterance': user_act,
+            'system_utterance': system_act
+        }]]
+
         # Tokenize dialog
-        features = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True,
-                                              max_length=self.config.max_turn_len, padding='max_length',
-                                              truncation='longest_first')
-
-        input_ids = torch.tensor(features['input_ids']).reshape(
-            1, 1, -1).to(self.device) if 'input_ids' in features else None
-        token_type_ids = torch.tensor(features['token_type_ids']).reshape(
-            1, 1, -1).to(self.device) if 'token_type_ids' in features else None
-        attention_mask = torch.tensor(features['attention_mask']).reshape(
-            1, 1, -1).to(self.device) if 'attention_mask' in features else None
-        features = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask}
+        features = self.tokenizer.encode(dialogue, max_seq_len=self.config.max_turn_len, max_turns=1)
+
+        for key in features:
+            if features[key] is not None:
+                features[key] = features[key].to(self.device)
 
         return features
 
@@ -431,7 +376,7 @@ class SetSUMBTTracker(DST):
 # if __name__ == "__main__":
 #     from convlab.policy.vector.vector_uncertainty import VectorUncertainty
 #     # from convlab.policy.vector.vector_binary import VectorBinary
-#     tracker = SetSUMBTTracker(model_path='/gpfs/project/niekerk/src/SetSUMBT/models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-labelsmoothing-Seed0-10-08-22-12-42',
+#     tracker = SetSUMBTTracker(model_name_or_path='setsumbt_multiwoz21',
 #                               return_confidence_scores=True, confidence_threshold='auto',
 #                               return_belief_state_entropy=True)
 #     vector = VectorUncertainty(use_state_total_uncertainty=True, confidence_thresholds=tracker.confidence_thresholds,
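
The commented-out example above and the removed `predict_user_actions.py` follow the same calling convention for the tracker; condensed (utterance invented for illustration):

tracker = SetSUMBTTracker(model_name_or_path='ConvLab/setsumbt-dst_nlu-multiwoz21-EnD2')
tracker.init_session()
tracker.state['history'].append(['sys', ''])  # no system turn before the first user turn
state = tracker.update('i need a cheap hotel in the centre')
tracker.state['history'].append(['usr', 'i need a cheap hotel in the centre'])
print(state['belief_state'].get('hotel', {}))
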
diff --git a/convlab/dst/setsumbt/utils/__init__.py b/convlab/dst/setsumbt/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a96b9f8878cc81ea1e30e77deaf37366ee6348bc
--- /dev/null
+++ b/convlab/dst/setsumbt/utils/__init__.py
@@ -0,0 +1,2 @@
+from convlab.dst.setsumbt.utils.configuration import get_args, update_args, clear_checkpoints
+from convlab.dst.setsumbt.utils.ensemble import setup_ensemble, EnsembleAggregator
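
Because the new package re-exports the names the old `utils.py` module provided, existing imports such as the one in `run.py` keep working unchanged:

from convlab.dst.setsumbt.utils import get_args, update_args, setup_ensemble
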
diff --git a/convlab/dst/setsumbt/utils.py b/convlab/dst/setsumbt/utils/configuration.py
similarity index 88%
rename from convlab/dst/setsumbt/utils.py
rename to convlab/dst/setsumbt/utils/configuration.py
index ff374116a3f8e88e6219fdc8b134d40b0bee7caf..bd318a1bcbd76200d120e942f65a33294caab569 100644
--- a/convlab/dst/setsumbt/utils.py
+++ b/convlab/dst/setsumbt/utils/configuration.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""SetSUMBT utils"""
+"""SetSUMBT configuration utilities."""
 
 import os
 import json
@@ -25,6 +25,12 @@ from git import Repo
 
 
 def get_args(base_models: dict):
+    """
+    Get arguments from command line and config file.
+
+    Args:
+        base_models: Dictionary of base models to use for ensemble training
+    """
     # Get arguments
     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
 
@@ -46,7 +52,6 @@ def get_args(base_models: dict):
     parser.add_argument('--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int)
     parser.add_argument('--max_candidate_len', help='Maximum number of tokens per value candidate', default=12,
                         type=int)
-    parser.add_argument('--force_processing', action='store_true', help='Force preprocessing of data.')
     parser.add_argument('--data_sampling_size', help='Resampled dataset size', default=-1, type=int)
     parser.add_argument('--no_descriptions', help='Do not use slot descriptions rather than slot names for embeddings',
                         action='store_true')
@@ -56,6 +61,7 @@ def get_args(base_models: dict):
     parser.add_argument('--output_dir', help='Output storage directory', default=None)
     parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta', default='roberta')
     parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.', default=None)
+    parser.add_argument('--ensemble_model_path', help='Path to ensemble model', default=None)
     parser.add_argument('--candidate_embedding_model_name', default=None,
                         help='Name of the pretrained candidate embedding model.')
     parser.add_argument('--transformers_local_files_only', help='Use local files only for huggingface transformers',
@@ -140,13 +146,11 @@ def get_args(base_models: dict):
                              "See details at https://nvidia.github.io/apex/amp.html")
 
     # ACTIONS
-    parser.add_argument('--run_nbt', help='Run NBT script', action='store_true')
-    parser.add_argument('--run_evaluation', help='Run evaluation script', action='store_true')
-
-    # RUN_NBT ACTIONS
     parser.add_argument('--do_train', help='Perform training', action='store_true')
     parser.add_argument('--do_eval', help='Perform model evaluation during training', action='store_true')
+    parser.add_argument('--do_eval_trainset', help='Evaluate model on training data', action='store_true')
     parser.add_argument('--do_test', help='Evaluate model on test data', action='store_true')
+    parser.add_argument('--do_ensemble_setup', help='Setup the dataloaders for ensemble training', action='store_true')
     args = parser.parse_args()
 
     if args.starting_config_name:
@@ -162,10 +166,13 @@ def get_args(base_models: dict):
 
     # Setup default directories
     if not args.output_dir:
-        args.output_dir = os.path.dirname(os.path.abspath(__file__))
+        args.output_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
         args.output_dir = os.path.join(args.output_dir, 'models')
 
-        name = 'SetSUMBT' if args.set_similarity else 'SUMBT'
+        name = 'Ensemble-' if args.do_ensemble_setup else ''
+        name += 'EnD-' if args.loss_function == 'distillation' else ''
+        name += 'EnD2-' if args.loss_function == 'distribution_distillation' else ''
+        name += 'SetSUMBT' if args.set_similarity else 'SUMBT'
         name += '+ActPrediction' if args.predict_actions else ''
         name += '-' + args.dataset
         name += '-' + str(round(args.dataset_train_ratio*100)) + '%' if args.dataset_train_ratio != 1.0 else ''
@@ -230,13 +237,25 @@ def get_args(base_models: dict):
 
 
 def get_starting_config(args):
-    path = os.path.dirname(os.path.realpath(__file__))
+    """
+    Load a config file and update the args with the values from the config file.
+
+    Args:
+        args: The args object to update.
+    """
+    path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
     path = os.path.join(path, 'configs', f"{args.starting_config_name}.json")
     reader = open(path, 'r')
     config = json.load(reader)
     reader.close()
 
     if "model_type" in config:
+        if 'ensemble-' in config["model_type"].lower():
+            args.ensemble = True
+            config["model_type"] = config["model_type"].lower().replace('ensemble-', '')
+        else:
+            args.ensemble = False
+
         if config["model_type"].lower() == 'setsumbt':
             config["model_type"] = 'roberta'
             config["no_set_similarity"] = False
@@ -255,6 +274,7 @@ def get_starting_config(args):
 
 
 def get_git_info():
+    """Get the git info of the current branch and commit hash"""
     repo = Repo(os.path.dirname(os.path.realpath(__file__)), search_parent_directories=True)
     branch_name = repo.active_branch.name
     commit_hex = repo.head.object.hexsha
@@ -264,17 +284,21 @@ def get_git_info():
 
 
 def build_config(config_class, args):
+    """
+    Build a config object from the args.
+
+    Args:
+        config_class: The config class to use.
+        args: The args object to use.
+
+    Returns:
+        The config object.
+    """
     config = config_class.from_pretrained(args.model_name_or_path)
     config.code_version = get_git_info()
-    if not os.path.exists(args.model_name_or_path):
-        config.tokenizer_name = args.model_name_or_path
-    try:
-        config.tokenizer_name = config.tokenizer_name
-    except AttributeError:
-        config.tokenizer_name = args.model_name_or_path
     try:
         config.candidate_embedding_model_name = config.candidate_embedding_model_name
-    except:
+    except AttributeError:
         if args.candidate_embedding_model_name:
             config.candidate_embedding_model_name = args.candidate_embedding_model_name
     config.max_dialogue_len = args.max_dialogue_len
@@ -302,8 +326,6 @@ def build_config(config_class, args):
     if config.loss_function == 'bayesianmatching':
         config.kl_scaling_factor = args.kl_scaling_factor
         config.prior_constant = args.prior_constant
-    if config.loss_function == 'inhibitedce':
-        config.inhibiting_factor = args.inhibiting_factor
     if config.loss_function == 'labelsmoothing':
         config.label_smoothing = args.label_smoothing
     if config.loss_function == 'distillation':
@@ -320,6 +342,16 @@ def build_config(config_class, args):
 
 
 def update_args(args, config):
+    """
+    Update the args with the values from the config file.
+
+    Args:
+        args: The args object to update.
+        config: The config object to use.
+
+    Returns:
+        The updated args object.
+    """
     try:
         args.candidate_embedding_model_name = config.candidate_embedding_model_name
     except AttributeError:
@@ -342,6 +374,13 @@ def update_args(args, config):
 
 
 def clear_checkpoints(path, topn=1):
+    """
+    Clear all checkpoints except the top n.
+
+    Args:
+        path: The path to the checkpoints.
+        topn: The number of checkpoints to keep.
+    """
     checkpoints = os.listdir(path)
     checkpoints = [p for p in checkpoints if 'checkpoint' in p]
     checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints])
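
The tail of `clear_checkpoints` falls outside this hunk; per the new docstring it removes every checkpoint except the `topn` newest. A hedged sketch of that continuation (not the file's actual code):

from shutil import rmtree

for step in checkpoints[:-topn]:  # hypothetical: checkpoints is the sorted step list built above
    rmtree(os.path.join(path, f'checkpoint-{step}'))
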
diff --git a/convlab/dst/setsumbt/utils/ensemble.py b/convlab/dst/setsumbt/utils/ensemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc59cbf4751c5e46b8efa4e86165e9eb4416d65
--- /dev/null
+++ b/convlab/dst/setsumbt/utils/ensemble.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Ensemble setup ad inference utils."""
+
+import os
+from shutil import copy2 as copy
+
+import torch
+import numpy as np
+
+def setup_ensemble(model_path: str, ensemble_size: int):
+    """
+    Setup ensemble model directory structure.
+
+    Args:
+        model_path: Path to ensemble model directory
+        ensemble_size: Number of ensemble members
+    """
+    for i in range(ensemble_size):
+        path = os.path.join(model_path, f'ens-{i}')
+        if not os.path.exists(path):
+            os.mkdir(path)
+            os.mkdir(os.path.join(path, 'dataloaders'))
+            # Add development set dataloader to each ensemble member directory
+            for set_type in ['dev']:
+                copy(os.path.join(model_path, 'dataloaders', f'{set_type}.dataloader'),
+                     os.path.join(path, 'dataloaders', f'{set_type}.dataloader'))
+
+
+class EnsembleAggregator:
+    """Aggregator for ensemble model outputs."""
+
+    def __init__(self):
+        self.init_session()
+        self.input_items = ['input_ids', 'attention_mask', 'token_type_ids']
+        self.output_items = ['belief_state', 'request_probabilities', 'active_domain_probabilities',
+                             'general_act_probabilities']
+
+    def init_session(self):
+        """Initialize aggregator for new session."""
+        self.features = dict()
+
+    def add_batch(self, model_input: dict, model_output: dict, dialogue_ids=None):
+        """
+        Add batch of model outputs to aggregator.
+
+        Args:
+            model_input: Model input dictionary
+            model_output: Model output dictionary
+            dialogue_ids: List of dialogue ids
+        """
+        for key in self.input_items:
+            if key in model_input:
+                if key not in self.features:
+                    self.features[key] = list()
+                self.features[key].append(model_input[key])
+
+        for key in self.output_items:
+            if key in model_output:
+                if key not in self.features:
+                    self.features[key] = list()
+                self.features[key].append(model_output[key])
+
+        if dialogue_ids is not None:
+            if 'dialogue_ids' not in self.features:
+                self.features['dialogue_ids'] = [np.array([list(itm) for itm in dialogue_ids]).T]
+            else:
+                self.features['dialogue_ids'].append(np.array([list(itm) for itm in dialogue_ids]).T)
+
+    def _aggregate(self):
+        """Aggregate model outputs."""
+        for key in self.features:
+            self.features[key] = self._aggregate_item(self.features[key])
+
+    @staticmethod
+    def _aggregate_item(item):
+        """
+        Aggregate single model output item.
+
+        Args:
+            item: Model output item
+
+        Returns:
+            Aggregated model output item
+        """
+        if item[0] is None:
+            return None
+        elif isinstance(item[0], dict):
+            return {k: EnsembleAggregator._aggregate_item([i[k] for i in item]) for k in item[0]}
+        elif isinstance(item[0], np.ndarray):
+            return np.concatenate(item, 0)
+        else:
+            return torch.cat(item, 0)
+
+    def save(self, path):
+        """
+        Save aggregated model outputs to file.
+
+        Args:
+            path: Path to save file
+        """
+        self._aggregate()
+        torch.save(self.features, path)
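
A toy end-to-end use of `EnsembleAggregator`, with stand-in tensors in place of real SetSUMBT inputs and outputs:

import torch

aggregator = EnsembleAggregator()
model_input = {'input_ids': torch.zeros(2, 12, dtype=torch.long),
               'attention_mask': torch.ones(2, 12, dtype=torch.long)}
model_output = {'request_probabilities': {'hotel-phone': torch.rand(2, 12)}}
aggregator.add_batch(model_input, model_output, dialogue_ids=[['dial-0', 'dial-1']])
aggregator.save('dev.data')  # concatenates the per-batch items, then torch.save
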
diff --git a/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json
index a80c04c9656dbdb361de1ca74e3ca24db028b1cf..03220dd3d3331b9d8324e5800b580c044b225bb2 100644
--- a/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json
+++ b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json
@@ -26,9 +26,9 @@
 	"nlu_sys": {},
 	"dst_sys": {
 		"setsumbt-mul": {
-			"class_path": "convlab.dst.setsumbt.SetSUMBTTracker",
+			"class_path": "convlab.dst.setsumbt.tracker.SetSUMBTTracker",
 			"ini_params": {
-				"model_path": "https://huggingface.co/ConvLab/setsumbt-dst_nlu-multiwoz21-EnD2/resolve/main/SetSUMBT-nlu-multiwoz21-roberta-gru-cosine-distribution_distillation-Seed0.zip",
+				"model_name_or_path": "ConvLab/setsumbt-dst_nlu-multiwoz21-EnD2",
 				"return_confidence_scores": true,
 				"return_belief_state_mutual_info": true
 			}
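
ConvLab resolves these `class_path` / `ini_params` entries into live objects when the pipeline is built; conceptually (a hedged sketch, not the loader's actual code):

from importlib import import_module

def build_module(class_path: str, ini_params: dict):
    # Sketch: import the class named by class_path and instantiate it with ini_params.
    module_name, class_name = class_path.rsplit('.', 1)
    cls = getattr(import_module(module_name), class_name)
    return cls(**ini_params)

dst = build_module('convlab.dst.setsumbt.tracker.SetSUMBTTracker',
                   {'model_name_or_path': 'ConvLab/setsumbt-dst_nlu-multiwoz21-EnD2'})
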
diff --git a/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json
index bf9211006b6e2623016acfec18573768f73558fd..36f2d46d0c2ae32ca65c04e303c3967a9a56e53a 100644
--- a/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json
+++ b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json
@@ -24,9 +24,9 @@
 	"nlu_sys": {},
 	"dst_sys": {
 		"setsumbt-mul": {
-			"class_path": "convlab.dst.setsumbt.SetSUMBTTracker",
+			"class_path": "convlab.dst.setsumbt.tracker.SetSUMBTTracker",
 			"ini_params": {
-				"model_path": "https://huggingface.co/ConvLab/setsumbt-dst_nlu-multiwoz21-EnD2/resolve/main/SetSUMBT-nlu-multiwoz21-roberta-gru-cosine-distribution_distillation-Seed0.zip"
+				"model_name_or_path": "ConvLab/setsumbt-dst_nlu-multiwoz21-EnD2"
 			}
 		}
 	},