From e4d716f52c9df83a47c63c2d70b0ee7b9d5ace45 Mon Sep 17 00:00:00 2001
From: Carel van Niekerk <vniekerk.carel@gmail.com>
Date: Wed, 16 Nov 2022 11:41:47 +0100
Subject: [PATCH] Merge new setsumbt code into github repo copy

---
 convlab/dst/setsumbt/__init__.py              |    1 +
 convlab/dst/setsumbt/calibration_plots.py     |   12 +-
 convlab/dst/setsumbt/dataset/__init__.py      |    2 +
 convlab/dst/setsumbt/dataset/ontology.py      |  133 +
 .../dst/setsumbt/dataset/unified_format.py    |  423 +++
 convlab/dst/setsumbt/dataset/utils.py         |  409 +++
 convlab/dst/setsumbt/dataset/value_maps.py    |   50 +
 convlab/dst/setsumbt/distillation_setup.py    |  253 +-
 convlab/dst/setsumbt/do/calibration.py        |  481 ---
 convlab/dst/setsumbt/do/evaluate.py           |  296 ++
 convlab/dst/setsumbt/do/nbt.py                |  352 +-
 convlab/dst/setsumbt/loss/__init__.py         |    4 +
 convlab/dst/setsumbt/loss/bayesian.py         |  144 -
 .../dst/setsumbt/loss/bayesian_matching.py    |  115 +
 convlab/dst/setsumbt/loss/distillation.py     |  201 --
 convlab/dst/setsumbt/loss/endd_loss.py        |  314 +-
 convlab/dst/setsumbt/loss/kl_distillation.py  |  104 +
 convlab/dst/setsumbt/loss/labelsmoothing.py   |  115 +-
 .../loss/{ece.py => uncertainty_measures.py}  |  136 +-
 convlab/dst/setsumbt/modeling/__init__.py     |    4 +-
 convlab/dst/setsumbt/modeling/bert_nbt.py     |  102 +-
 .../setsumbt/modeling/calibration_utils.py    |  134 -
 convlab/dst/setsumbt/modeling/ensemble_nbt.py |  242 +-
 .../dst/setsumbt/modeling/evaluation_utils.py |  112 +
 convlab/dst/setsumbt/modeling/functional.py   |  456 ---
 convlab/dst/setsumbt/modeling/roberta_nbt.py  |  105 +-
 convlab/dst/setsumbt/modeling/setsumbt.py     |  564 +++
 .../modeling/temperature_scheduler.py         |   68 +-
 convlab/dst/setsumbt/modeling/training.py     |  663 ++--
 convlab/dst/setsumbt/multiwoz/Tracker.py      |  455 ---
 convlab/dst/setsumbt/multiwoz/__init__.py     |    2 -
 .../setsumbt/multiwoz/dataset/mapping.pair    |   83 -
 .../setsumbt/multiwoz/dataset/multiwoz21.py   |  502 ---
 .../setsumbt/multiwoz/dataset/mwoz21_ont.json | 2990 ----------------
 .../multiwoz/dataset/mwoz21_ont_request.json  | 3128 -----------------
 .../dataset/mwoz21_slot_descriptions.json     |   57 -
 .../dst/setsumbt/multiwoz/dataset/ontology.py |  168 -
 .../dst/setsumbt/multiwoz/dataset/utils.py    |  446 ---
 convlab/dst/setsumbt/predict_user_actions.py  |  178 +
 convlab/dst/setsumbt/process_mwoz_data.py     |   99 -
 convlab/dst/setsumbt/run.py                   |    4 +-
 convlab/dst/setsumbt/tracker.py               |  446 +++
 convlab/dst/setsumbt/utils.py                 |  234 +-
 convlab/policy/mle/loader.py                  |   37 +-
 convlab/policy/mle/train.py                   |   37 +-
 ...eline_config.json => setsumbt_config.json} |   30 +-
 convlab/policy/ppo/setsumbt_unc_config.json   |   65 +
 convlab/policy/ppo/train.py                   |   12 +-
 convlab/policy/vector/dataset.py              |   20 -
 convlab/policy/vector/vector_base.py          |   13 +-
 convlab/policy/vector/vector_binary.py        |    2 +-
 .../vector/vector_multiwoz_uncertainty.py     |  238 --
 convlab/policy/vector/vector_nodes.py         |   56 +-
 convlab/policy/vector/vector_uncertainty.py   |  166 +
 convlab/util/custom_util.py                   |  105 +-
 55 files changed, 4524 insertions(+), 11044 deletions(-)
 create mode 100644 convlab/dst/setsumbt/dataset/__init__.py
 create mode 100644 convlab/dst/setsumbt/dataset/ontology.py
 create mode 100644 convlab/dst/setsumbt/dataset/unified_format.py
 create mode 100644 convlab/dst/setsumbt/dataset/utils.py
 create mode 100644 convlab/dst/setsumbt/dataset/value_maps.py
 delete mode 100644 convlab/dst/setsumbt/do/calibration.py
 create mode 100644 convlab/dst/setsumbt/do/evaluate.py
 create mode 100644 convlab/dst/setsumbt/loss/__init__.py
 delete mode 100644 convlab/dst/setsumbt/loss/bayesian.py
 create mode 100644 convlab/dst/setsumbt/loss/bayesian_matching.py
 delete mode 100644 convlab/dst/setsumbt/loss/distillation.py
 create mode 100644 convlab/dst/setsumbt/loss/kl_distillation.py
 rename convlab/dst/setsumbt/loss/{ece.py => uncertainty_measures.py} (50%)
 delete mode 100644 convlab/dst/setsumbt/modeling/calibration_utils.py
 create mode 100644 convlab/dst/setsumbt/modeling/evaluation_utils.py
 delete mode 100644 convlab/dst/setsumbt/modeling/functional.py
 create mode 100644 convlab/dst/setsumbt/modeling/setsumbt.py
 delete mode 100644 convlab/dst/setsumbt/multiwoz/Tracker.py
 delete mode 100644 convlab/dst/setsumbt/multiwoz/__init__.py
 delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/mapping.pair
 delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py
 delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json
 delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json
 delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json
 delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/ontology.py
 delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/utils.py
 create mode 100644 convlab/dst/setsumbt/predict_user_actions.py
 delete mode 100755 convlab/dst/setsumbt/process_mwoz_data.py
 create mode 100644 convlab/dst/setsumbt/tracker.py
 rename convlab/policy/ppo/{setsumbt_end_baseline_config.json => setsumbt_config.json} (53%)
 create mode 100644 convlab/policy/ppo/setsumbt_unc_config.json
 delete mode 100644 convlab/policy/vector/vector_multiwoz_uncertainty.py
 create mode 100644 convlab/policy/vector/vector_uncertainty.py

diff --git a/convlab/dst/setsumbt/__init__.py b/convlab/dst/setsumbt/__init__.py
index e69de29b..9492faa9 100644
--- a/convlab/dst/setsumbt/__init__.py
+++ b/convlab/dst/setsumbt/__init__.py
@@ -0,0 +1 @@
+from convlab.dst.setsumbt.tracker import SetSUMBTTracker
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/calibration_plots.py b/convlab/dst/setsumbt/calibration_plots.py
index 379057e6..a41f280d 100644
--- a/convlab/dst/setsumbt/calibration_plots.py
+++ b/convlab/dst/setsumbt/calibration_plots.py
@@ -35,7 +35,7 @@ def main():
     path = args.data_dir
 
     models = os.listdir(path)
-    models = [os.path.join(path, model, 'test.belief') for model in models]
+    models = [os.path.join(path, model, 'test.predictions') for model in models]
 
     fig = plt.figure(figsize=(14,8))
     font=20
@@ -56,16 +56,16 @@ def main():
 
 
 def get_calibration(path, device, n_bins=10, temperature=1.00):
-    logits = torch.load(path, map_location=device)
-    y_true = logits['labels']
-    logits = logits['belief_states']
+    probs = torch.load(path, map_location=device)
+    y_true = probs['state_labels']
+    probs = probs['belief_states']
 
-    y_pred = {slot: logits[slot].reshape(-1, logits[slot].size(-1)).argmax(-1) for slot in logits}
+    y_pred = {slot: probs[slot].reshape(-1, probs[slot].size(-1)).argmax(-1) for slot in probs}
     goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred}
     goal_acc = sum([goal_acc[slot] for slot in goal_acc])
     goal_acc = (goal_acc == len(y_true)).int()
 
-    scores = [logits[slot].reshape(-1, logits[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in logits]
+    scores = [probs[slot].reshape(-1, probs[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in probs]
     scores = torch.cat(scores, 0).min(0)[0]
 
     step = 1.0 / float(n_bins)
diff --git a/convlab/dst/setsumbt/dataset/__init__.py b/convlab/dst/setsumbt/dataset/__init__.py
new file mode 100644
index 00000000..17b1f93b
--- /dev/null
+++ b/convlab/dst/setsumbt/dataset/__init__.py
@@ -0,0 +1,2 @@
+from convlab.dst.setsumbt.dataset.unified_format import get_dataloader, change_batch_size
+from convlab.dst.setsumbt.dataset.ontology import get_slot_candidate_embeddings
diff --git a/convlab/dst/setsumbt/dataset/ontology.py b/convlab/dst/setsumbt/dataset/ontology.py
new file mode 100644
index 00000000..81e20780
--- /dev/null
+++ b/convlab/dst/setsumbt/dataset/ontology.py
@@ -0,0 +1,133 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Create Ontology Embeddings"""
+
+import json
+import os
+import random
+from copy import deepcopy
+
+import torch
+import numpy as np
+
+
+def set_seed(args):
+    """
+    Set random seeds
+
+    Args:
+        args (Arguments class): Arguments class containing seed and number of gpus to use
+    """
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+
+
+def encode_candidates(candidates: list, args, tokenizer, embedding_model) -> torch.tensor:
+    """
+    Embed candidates
+
+    Args:
+        candidates (list): List of candidate descriptions
+        args (argument class): Runtime arguments
+        tokenizer (transformers Tokenizer): Tokenizer for the embedding_model
+        embedding_model (transformer Model): Transformer model for embedding candidate descriptions
+
+    Returns:
+        feats (torch.tensor): Embeddings of the candidate descriptions
+    """
+    # Tokenize candidate descriptions
+    feats = [tokenizer.encode_plus(val, add_special_tokens=True,max_length=args.max_candidate_len,
+                                   padding='max_length', truncation='longest_first')
+             for val in candidates]
+
+    # Encode tokenized descriptions
+    with torch.no_grad():
+        feats = {key: torch.tensor([f[key] for f in feats]).to(embedding_model.device) for key in feats[0]}
+        embedded_feats = embedding_model(**feats)  # [num_candidates, max_candidate_len, hidden_dim]
+
+    # Reduce/pool descriptions embeddings if required
+    if args.set_similarity:
+        feats = embedded_feats.last_hidden_state.detach().cpu()  # [num_candidates, max_candidate_len, hidden_dim]
+    elif args.candidate_pooling == 'cls':
+        feats = embedded_feats.pooler_output.detach().cpu()  # [num_candidates, hidden_dim]
+    elif args.candidate_pooling == "mean":
+        feats = embedded_feats.last_hidden_state.detach().cpu()
+        feats = feats.sum(1)
+        feats = torch.nn.functional.layer_norm(feats, feats.size())
+        feats = feats.detach().cpu()  # [num_candidates, hidden_dim]
+
+    return feats
+
+
+def get_slot_candidate_embeddings(ontology: dict, set_type: str, args, tokenizer, embedding_model, save_to_file=True):
+    """
+    Get embeddings for slots and candidates
+
+    Args:
+        ontology (dict): Dictionary of domain-slot pair descriptions and possible value sets
+        set_type (str): Subset of the dataset being used (train/validation/test)
+        args (argument class): Runtime arguments
+        tokenizer (transformers Tokenizer): Tokenizer for the embedding_model
+        embedding_model (transformer Model): Transormer model for embedding candidate descriptions
+        save_to_file (bool): Indication of whether to save information to file
+
+    Returns:
+        slots (dict): domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
+    """
+    # Set model to eval mode
+    embedding_model.eval()
+
+    slots = dict()
+    for domain, subset in ontology.items():
+        for slot, slot_info in subset.items():
+            # Get description or use "domain-slot"
+            if args.use_descriptions:
+                desc = slot_info['description']
+            else:
+                desc = f"{domain}-{slot}"
+
+            # Encode domain-slot pair description
+            slot_emb = encode_candidates([desc], args, tokenizer, embedding_model)[0]
+
+            # Obtain possible value set and discard requestable value
+            values = deepcopy(slot_info['possible_values'])
+            is_requestable = False
+            if '?' in values:
+                is_requestable = True
+                values.remove('?')
+
+            # Encode value candidates
+            if values:
+                feats = encode_candidates(values, args, tokenizer, embedding_model)
+            else:
+                feats = None
+
+            # Store domain-slot description embeddings, candidate embeddings and requestabke flag for each domain-slot
+            slots[f"{domain}-{slot}"] = (slot_emb, feats, is_requestable)
+
+    # Dump tensors and ontology for use in training and evaluation
+    if save_to_file:
+        writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type)
+        torch.save(slots, writer)
+
+        writer = open(os.path.join(args.output_dir, 'database', '%s.json' % set_type), 'w')
+        json.dump(ontology, writer, indent=2)
+        writer.close()
+    
+    return slots
diff --git a/convlab/dst/setsumbt/dataset/unified_format.py b/convlab/dst/setsumbt/dataset/unified_format.py
new file mode 100644
index 00000000..26b67268
--- /dev/null
+++ b/convlab/dst/setsumbt/dataset/unified_format.py
@@ -0,0 +1,423 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convlab3 Unified Format Dialogue Datasets"""
+
+from copy import deepcopy
+
+import torch
+import transformers
+from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
+from transformers.tokenization_utils import PreTrainedTokenizer
+from tqdm import tqdm
+
+from convlab.util import load_dataset
+from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values,
+                                                get_values_from_data, ontology_add_requestable_slots,
+                                                get_requestable_slots, load_dst_data, extract_dialogues,
+                                                combine_value_sets)
+
+transformers.logging.set_verbosity_error()
+
+
+def convert_examples_to_features(data: list,
+                                 ontology: dict,
+                                 tokenizer: PreTrainedTokenizer,
+                                 max_turns: int = 12,
+                                 max_seq_len: int = 64) -> dict:
+    """
+    Convert dialogue examples to model input features and labels
+
+    Args:
+        data (list): List of all extracted dialogues
+        ontology (dict): Ontology dictionary containing slots, slot descriptions and
+        possible value sets including requests
+        tokenizer (PreTrainedTokenizer): Tokenizer for the encoder model used
+        max_turns (int): Maximum numbers of turns in a dialogue
+        max_seq_len (int): Maximum number of tokens in a dialogue turn
+
+    Returns:
+        features (dict): All inputs and labels required to train the model
+    """
+    features = dict()
+    ontology = deepcopy(ontology)
+
+    # Get encoder input for system, user utterance pairs
+    input_feats = []
+    for dial in tqdm(data):
+        dial_feats = []
+        for turn in dial:
+            if len(turn['system_utterance']) == 0:
+                usr = turn['user_utterance']
+                dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens=True,
+                                                        max_length=max_seq_len, padding='max_length',
+                                                        truncation='longest_first'))
+            else:
+                usr = turn['user_utterance']
+                sys = turn['system_utterance']
+                dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens=True,
+                                                        max_length=max_seq_len, padding='max_length',
+                                                        truncation='longest_first'))
+            # Truncate
+            if len(dial_feats) >= max_turns:
+                break
+        input_feats.append(dial_feats)
+    del dial_feats
+
+    # Perform turn level padding
+    input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
+                 for dial in input_feats]
+    if 'token_type_ids' in input_feats[0][0]:
+        token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
+                          for dial in input_feats]
+    else:
+        token_type_ids = None
+    if 'attention_mask' in input_feats[0][0]:
+        attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
+                          for dial in input_feats]
+    else:
+        attention_mask = None
+    del input_feats
+
+    # Create torch data tensors
+    features['input_ids'] = torch.tensor(input_ids)
+    features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
+    features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
+    del input_ids, token_type_ids, attention_mask
+
+    # Extract all informable and requestable slots from the ontology
+    informable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain]
+                        if ontology[domain][slot]['possible_values']
+                        and ontology[domain][slot]['possible_values'] != ['?']]
+    requestable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain]
+                         if '?' in ontology[domain][slot]['possible_values']]
+    for slot in requestable_slots:
+        domain, slot = slot.split('-', 1)
+        ontology[domain][slot]['possible_values'].remove('?')
+
+    # Extract a list of domains from the ontology slots
+    domains = list(set(informable_slots + requestable_slots))
+    domains = list(set([slot.split('-', 1)[0] for slot in domains]))
+
+    # Create slot labels
+    for domslot in tqdm(informable_slots):
+        labels = []
+        for dial in data:
+            labs = []
+            for turn in dial:
+                value = [v for d, substate in turn['state'].items() for s, v in substate.items()
+                         if f'{d}-{s}' == domslot]
+                domain, slot = domslot.split('-', 1)
+                if turn['dataset_name'] in ontology[domain][slot]['dataset_names']:
+                    value = value[0] if value else 'none'
+                else:
+                    value = -1
+                if value in ontology[domain][slot]['possible_values'] and value != -1:
+                    value = ontology[domain][slot]['possible_values'].index(value)
+                else:
+                    value = -1  # If value is not in ontology then we do not penalise the model
+                labs.append(value)
+                if len(labs) >= max_turns:
+                    break
+            labs = labs + [-1] * (max_turns - len(labs))
+            labels.append(labs)
+
+        labels = torch.tensor(labels)
+        features['state_labels-' + domslot] = labels
+
+    # Create requestable slot labels
+    for domslot in tqdm(requestable_slots):
+        labels = []
+        for dial in data:
+            labs = []
+            for turn in dial:
+                domain, slot = domslot.split('-', 1)
+                if turn['dataset_name'] in ontology[domain][slot]['dataset_names']:
+                    acts = [act['intent'] for act in turn['dialogue_acts']
+                            if act['domain'] == domain and act['slot'] == slot]
+                    if acts:
+                        act_ = acts[0]
+                        if act_ == 'request':
+                            labs.append(1)
+                        else:
+                            labs.append(0)
+                    else:
+                        labs.append(0)
+                else:
+                    labs.append(-1)
+                if len(labs) >= max_turns:
+                    break
+            labs = labs + [-1] * (max_turns - len(labs))
+            labels.append(labs)
+
+        labels = torch.tensor(labels)
+        features['request_labels-' + domslot] = labels
+
+    # General act labels (1-goodbye, 2-thank you)
+    labels = []
+    for dial in tqdm(data):
+        labs = []
+        for turn in dial:
+            acts = [act['intent'] for act in turn['dialogue_acts'] if act['intent'] in ['bye', 'thank']]
+            if acts:
+                if 'bye' in acts:
+                    labs.append(1)
+                else:
+                    labs.append(2)
+            else:
+                labs.append(0)
+            if len(labs) >= max_turns:
+                break
+        labs = labs + [-1] * (max_turns - len(labs))
+        labels.append(labs)
+
+    labels = torch.tensor(labels)
+    features['general_act_labels'] = labels
+
+    # Create active domain labels
+    for domain in tqdm(domains):
+        labels = []
+        for dial in data:
+            labs = []
+            for turn in dial:
+                possible_domains = list()
+                for dom in ontology:
+                    for slt in ontology[dom]:
+                        if turn['dataset_name'] in ontology[dom][slt]['dataset_names']:
+                            possible_domains.append(dom)
+
+                if domain in turn['active_domains']:
+                    labs.append(1)
+                elif domain in possible_domains:
+                    labs.append(0)
+                else:
+                    labs.append(-1)
+                if len(labs) >= max_turns:
+                    break
+            labs = labs + [-1] * (max_turns - len(labs))
+            labels.append(labs)
+
+        labels = torch.tensor(labels)
+        features['active_domain_labels-' + domain] = labels
+
+    del labels
+
+    return features
+
+
+class UnifiedFormatDataset(Dataset):
+    """
+    Class for preprocessing, and storing data easily from the Convlab3 unified format.
+
+    Attributes:
+        dataset_dict (dict): Dictionary containing all the data in dataset
+        ontology (dict): Set of all domain-slot-value triplets in the ontology of the model
+        features (dict): Set of numeric features containing all inputs and labels formatted for the SetSUMBT model
+    """
+    def __init__(self,
+                 dataset_name: str,
+                 set_type: str,
+                 tokenizer: PreTrainedTokenizer,
+                 max_turns: int = 12,
+                 max_seq_len: int = 64,
+                 train_ratio: float = 1.0,
+                 seed: int = 0,
+                 data: dict = None,
+                 ontology: dict = None):
+        """
+        Args:
+            dataset_name (str): Name of the dataset/s to load (multiple to be seperated by +)
+            set_type (str): Subset of the dataset to load (train, validation or test)
+            tokenizer (transformers tokenizer): Tokenizer for the encoder model used
+            max_turns (int): Maximum numbers of turns in a dialogue
+            max_seq_len (int): Maximum number of tokens in a dialogue turn
+            train_ratio (float): Fraction of training data to use during training
+            seed (int): Seed governing random order of ids for subsampling
+            data (dict): Dataset features for loading from dict
+            ontology (dict): Ontology dict for loading from dict
+        """
+        if data is not None:
+            self.ontology = ontology
+            self.features = data
+        else:
+            if '+' in dataset_name:
+                dataset_args = [{"dataset_name": name} for name in dataset_name.split('+')]
+            else:
+                dataset_args = [{"dataset_name": dataset_name}]
+            self.dataset_dicts = [load_dataset(**dataset_args_) for dataset_args_ in dataset_args]
+            self.ontology = get_ontology_slots(dataset_name)
+            values = [get_values_from_data(dataset) for dataset in self.dataset_dicts]
+            self.ontology = ontology_add_values(self.ontology, combine_value_sets(values))
+            self.ontology = ontology_add_requestable_slots(self.ontology, get_requestable_slots(self.dataset_dicts))
+
+            if train_ratio != 1.0:
+                for dataset_args_ in dataset_args:
+                    dataset_args_['dial_ids_order'] = seed
+                    dataset_args_['split2ratio'] = {'train': train_ratio, 'validation': train_ratio}
+            self.dataset_dicts = [load_dataset(**dataset_args_) for dataset_args_ in dataset_args]
+
+            data = [load_dst_data(dataset_dict, data_split=set_type, speaker='all',
+                                  dialogue_acts=True, split_to_turn=False)
+                    for dataset_dict in self.dataset_dicts]
+            data_list = [data_[set_type] for data_ in data]
+
+            data = []
+            for idx, data_ in enumerate(data_list):
+                data += extract_dialogues(data_, dataset_args[idx]["dataset_name"])
+            self.features = convert_examples_to_features(data, self.ontology, tokenizer, max_turns, max_seq_len)
+
+    def __getitem__(self, index: int) -> dict:
+        """
+        Obtain dialogues with specific ids from dataset
+
+        Args:
+            index (int/list/tensor): Index/indices of dialogues to get
+
+        Returns:
+            features (dict): All inputs and labels required to train the model
+        """
+        return {label: self.features[label][index] for label in self.features
+                if self.features[label] is not None}
+
+    def __len__(self):
+        """
+        Get number of dialogues in the dataset
+
+        Returns:
+            len (int): Number of dialogues in the dataset object
+        """
+        return self.features['input_ids'].size(0)
+
+    def resample(self, size: int = None) -> Dataset:
+        """
+        Resample subset of the dataset
+
+        Args:
+            size (int): Number of dialogues to sample
+
+        Returns:
+            self (Dataset): Dataset object
+        """
+        # If no subset size is specified we resample a set with the same size as the full dataset
+        n_dialogues = self.__len__()
+        if not size:
+            size = n_dialogues
+
+        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
+        self.features = self.__getitem__(dialogues)
+        
+        return self
+
+    def to(self, device):
+        """
+        Map all data to a device
+
+        Args:
+            device (torch device): Device to map data to
+        """
+        self.device = device
+        self.features = {label: self.features[label].to(device) for label in self.features
+                         if self.features[label] is not None}
+
+    @classmethod
+    def from_datadict(cls, data: dict, ontology: dict):
+        return cls(None, None, None, data=data, ontology=ontology)
+
+
+def get_dataloader(dataset_name: str,
+                   set_type: str,
+                   batch_size: int,
+                   tokenizer: PreTrainedTokenizer,
+                   max_turns: int = 12,
+                   max_seq_len: int = 64,
+                   device='cpu',
+                   resampled_size: int = None,
+                   train_ratio: float = 1.0,
+                   seed: int = 0) -> DataLoader:
+    '''
+    Module to create torch dataloaders
+
+    Args:
+        dataset_name (str): Name of the dataset to load
+        set_type (str): Subset of the dataset to load (train, validation or test)
+        batch_size (int): Batch size for the dataloader
+        tokenizer (transformers tokenizer): Tokenizer for the encoder model used
+        max_turns (int): Maximum numbers of turns in a dialogue
+        max_seq_len (int): Maximum number of tokens in a dialogue turn
+        device (torch device): Device to map data to
+        resampled_size (int): Number of dialogues to sample
+        train_ratio (float): Ratio of training data to use for training
+        seed (int): Seed governing random order of ids for subsampling
+
+    Returns:
+        loader (torch dataloader): Dataloader to train and evaluate the setsumbt model
+    '''
+    data = UnifiedFormatDataset(dataset_name, set_type, tokenizer, max_turns, max_seq_len, train_ratio=train_ratio,
+                                seed=seed)
+    data.to(device)
+
+    if resampled_size:
+        data = data.resample(resampled_size)
+
+    if set_type in ['test', 'validation']:
+        sampler = SequentialSampler(data)
+    else:
+        sampler = RandomSampler(data)
+    loader = DataLoader(data, sampler=sampler, batch_size=batch_size)
+
+    return loader
+
+
+def change_batch_size(loader: DataLoader, batch_size: int) -> DataLoader:
+    """
+    Change the batch size of a preloaded loader
+
+    Args:
+        loader (DataLoader): Dataloader to train and evaluate the setsumbt model
+        batch_size (int): Batch size for the dataloader
+
+    Returns:
+        loader (DataLoader): Dataloader to train and evaluate the setsumbt model
+    """
+
+    if 'SequentialSampler' in str(loader.sampler):
+        sampler = SequentialSampler(loader.dataset)
+    else:
+        sampler = RandomSampler(loader.dataset)
+    loader = DataLoader(loader.dataset, sampler=sampler, batch_size=batch_size)
+
+    return loader
+
+def dataloader_sample_dialogues(loader: DataLoader, sample_size: int) -> DataLoader:
+    """
+    Sample a subset of the dialogues in a dataloader
+
+    Args:
+        loader (DataLoader): Dataloader to train and evaluate the setsumbt model
+        sample_size (int): Number of dialogues to sample
+
+    Returns:
+        loader (DataLoader): Dataloader to train and evaluate the setsumbt model
+    """
+
+    dataset = loader.dataset.resample(sample_size)
+
+    if 'SequentialSampler' in str(loader.sampler):
+        sampler = SequentialSampler(dataset)
+    else:
+        sampler = RandomSampler(dataset)
+    loader = DataLoader(loader.dataset, sampler=sampler, batch_size=loader.batch_size)
+
+    return loader
diff --git a/convlab/dst/setsumbt/dataset/utils.py b/convlab/dst/setsumbt/dataset/utils.py
new file mode 100644
index 00000000..088480c4
--- /dev/null
+++ b/convlab/dst/setsumbt/dataset/utils.py
@@ -0,0 +1,409 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convlab3 Unified dataset data processing utilities"""
+
+from convlab.util import load_ontology, load_dst_data, load_nlu_data
+from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
+
+
+def get_ontology_slots(dataset_name: str) -> dict:
+    """
+    Function to extract slots, slot descriptions and categorical slot values from the dataset ontology.
+
+    Args:
+        dataset_name (str): Dataset name
+
+    Returns:
+        ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
+    """
+    dataset_names = dataset_name.split('+') if '+' in dataset_name else [dataset_name]
+    ontology_slots = dict()
+    for dataset_name in dataset_names:
+        ontology = load_ontology(dataset_name)
+        domains = [domain for domain in ontology['domains'] if domain not in ['booking', 'general']]
+        for domain in domains:
+            domain_name = DOMAINS_MAP.get(domain, domain.lower())
+            if domain_name not in ontology_slots:
+                ontology_slots[domain_name] = dict()
+            for slot, slot_info in ontology['domains'][domain]['slots'].items():
+                if slot not in ontology_slots[domain_name]:
+                    ontology_slots[domain_name][slot] = {'description': slot_info['description'],
+                                                         'possible_values': list(),
+                                                         'dataset_names': list()}
+                if slot_info['is_categorical']:
+                    ontology_slots[domain_name][slot]['possible_values'] += slot_info['possible_values']
+
+                ontology_slots[domain_name][slot]['possible_values'] = list(set(ontology_slots[domain_name][slot]['possible_values']))
+                ontology_slots[domain_name][slot]['dataset_names'].append(dataset_name)
+
+    return ontology_slots
+
+
+def get_values_from_data(dataset: dict) -> dict:
+    """
+    Function to extract slots, slot descriptions and categorical slot values from the dataset ontology.
+
+    Args:
+        dataset (dict): Dataset dictionary obtained using the load_dataset function
+
+    Returns:
+        value_sets (dict): Dictionary containing possible values obtained from dataset
+    """
+    data = load_dst_data(dataset, data_split='all', speaker='user')
+    value_sets = {}
+    for set_type, dataset in data.items():
+        for turn in dataset:
+            for domain, substate in turn['state'].items():
+                domain_name = DOMAINS_MAP.get(domain, domain.lower())
+                if domain not in value_sets:
+                    value_sets[domain_name] = {}
+                for slot, value in substate.items():
+                    if slot not in value_sets[domain_name]:
+                        value_sets[domain_name][slot] = []
+                    if value and value not in value_sets[domain_name][slot]:
+                        value_sets[domain_name][slot].append(value)
+
+    return clean_values(value_sets)
+
+
+def combine_value_sets(value_sets: list) -> dict:
+    """
+    Function to combine value sets extracted from different datasets
+
+    Args:
+        value_sets (list): List of value sets extracted using the get_values_from_data function
+
+    Returns:
+        value_set (dict): Dictionary containing possible values obtained from datasets
+    """
+    value_set = value_sets[0]
+    for _value_set in value_sets[1:]:
+        for domain, domain_info in _value_set.items():
+            for slot, possible_values in domain_info.items():
+                if domain not in value_set:
+                    value_set[domain] = dict()
+                if slot not in value_set[domain]:
+                    value_set[domain][slot] = list()
+                value_set[domain][slot] += _value_set[domain][slot]
+                value_set[domain][slot] = list(set(value_set[domain][slot]))
+
+    return value_set
+
+
+def clean_values(value_sets: dict, value_map: dict = VALUE_MAP) -> dict:
+    """
+    Function to clean up the possible value sets extracted from the states in the dataset
+
+    Args:
+        value_sets (dict): Dictionary containing possible values obtained from dataset
+        value_map (dict): Label map to avoid duplication and typos in values
+
+    Returns:
+        clean_vals (dict): Cleaned Dictionary containing possible values obtained from dataset
+    """
+    clean_vals = {}
+    for domain, subset in value_sets.items():
+        clean_vals[domain] = {}
+        for slot, values in subset.items():
+            # Remove pipe separated values
+            values = list(set([val.split('|', 1)[0] for val in values]))
+
+            # Map values using value_map
+            for old, new in value_map.items():
+                values = list(set([val.replace(old, new) for val in values]))
+
+            # Remove empty and dontcare from possible value sets
+            values = [val for val in values if val not in ['', 'dontcare']]
+
+            # MultiWOZ specific value sets for quantity, time and boolean slots
+            if 'people' in slot or 'duration' in slot or 'stay' in slot:
+                values = QUANTITIES
+            elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
+                values = TIME
+            elif 'parking' in slot or 'internet' in slot:
+                values = ['yes', 'no']
+
+            clean_vals[domain][slot] = values
+
+    return clean_vals
+
+
+def ontology_add_values(ontology_slots: dict, value_sets: dict) -> dict:
+    """
+    Add value sets obtained from the dataset to the ontology
+    Args:
+        ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
+        value_sets (dict): Cleaned Dictionary containing possible values obtained from dataset
+
+    Returns:
+        ontology_slots (dict): Ontology dictionary containing slots, slot descriptions and possible value sets
+    """
+    ontology = {}
+    for domain in sorted(ontology_slots):
+        ontology[domain] = {}
+        for slot in sorted(ontology_slots[domain]):
+            if not ontology_slots[domain][slot]['possible_values']:
+                if domain in value_sets:
+                    if slot in value_sets[domain]:
+                        ontology_slots[domain][slot]['possible_values'] = value_sets[domain][slot]
+            if ontology_slots[domain][slot]['possible_values']:
+                values = sorted(ontology_slots[domain][slot]['possible_values'])
+                ontology_slots[domain][slot]['possible_values'] = ['none', 'do not care'] + values
+
+            ontology[domain][slot] = ontology_slots[domain][slot]
+
+    return ontology
+
+
+def get_requestable_slots(datasets: list) -> dict:
+    """
+    Function to get set of requestable slots from the dataset action labels.
+    Args:
+        dataset (dict): Dataset dictionary obtained using the load_dataset function
+
+    Returns:
+        slots (dict): Dictionary containing requestable domain-slot pairs
+    """
+    datasets = [load_nlu_data(dataset, data_split='all', speaker='user') for dataset in datasets]
+
+    slots = {}
+    for data in datasets:
+        for set_type, subset in data.items():
+            for turn in subset:
+                requests = [act for act in turn['dialogue_acts']['categorical'] if act['intent'] == 'request']
+                requests += [act for act in turn['dialogue_acts']['non-categorical'] if act['intent'] == 'request']
+                requests += [act for act in turn['dialogue_acts']['binary'] if act['intent'] == 'request']
+                requests = [(act['domain'], act['slot']) for act in requests]
+                for domain, slot in requests:
+                    domain_name = DOMAINS_MAP.get(domain, domain.lower())
+                    if domain_name not in slots:
+                        slots[domain_name] = []
+                    slots[domain_name].append(slot)
+
+    slots = {domain: list(set(slot_list)) for domain, slot_list in slots.items()}
+
+    return slots
+
+
+def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict) -> dict:
+    """
+    Add requestable slots obtained from the dataset to the ontology
+    Args:
+        ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
+        requestable_slots (dict): Dictionary containing requestable domain-slot pairs
+
+    Returns:
+        ontology_slots (dict): Ontology dictionary containing slots, slot descriptions and
+        possible value sets including requests
+    """
+    for domain in ontology_slots:
+        for slot in ontology_slots[domain]:
+            if domain in requestable_slots:
+                if slot in requestable_slots[domain]:
+                    ontology_slots[domain][slot]['possible_values'].append('?')
+
+    return ontology_slots
+
+
+def extract_turns(dialogue: list, dataset_name: str) -> list:
+    """
+    Extract the required information from the data provided by unified loader
+    Args:
+        dialogue (list): List of turns within a dialogue
+        dataset_name (str): Name of the dataset to which the dialogue belongs
+
+    Returns:
+        turns (list): List of turns within a dialogue
+    """
+    turns = []
+    turn_info = {}
+    for turn in dialogue:
+        if turn['speaker'] == 'system':
+            turn_info['system_utterance'] = turn['utterance']
+
+        # System utterance in the first turn is always empty as conversation is initiated by the user
+        if turn['utt_idx'] == 1:
+            turn_info['system_utterance'] = ''
+
+        if turn['speaker'] == 'user':
+            turn_info['user_utterance'] = turn['utterance']
+
+            # Inform acts not required by model
+            turn_info['dialogue_acts'] = [act for act in turn['dialogue_acts']['categorical']
+                                          if act['intent'] not in ['inform']]
+            turn_info['dialogue_acts'] += [act for act in turn['dialogue_acts']['non-categorical']
+                                           if act['intent'] not in ['inform']]
+            turn_info['dialogue_acts'] += [act for act in turn['dialogue_acts']['binary']
+                                           if act['intent'] not in ['inform']]
+
+            turn_info['state'] = turn['state']
+            turn_info['dataset_name'] = dataset_name
+
+        if 'system_utterance' in turn_info and 'user_utterance' in turn_info:
+            turns.append(turn_info)
+            turn_info = {}
+
+    return turns
+
+
+def clean_states(turns: list) -> list:
+    """
+    Clean the state within each turn of a dialogue (cleaning values and mapping to options used in ontology)
+    Args:
+        turns (list): List of turns within a dialogue
+
+    Returns:
+        clean_turns (list): List of turns within a dialogue
+    """
+    clean_turns = []
+    for turn in turns:
+        clean_state = {}
+        clean_acts = []
+        for act in turn['dialogue_acts']:
+            domain = act['domain']
+            act['domain'] = DOMAINS_MAP.get(domain, domain.lower())
+            clean_acts.append(act)
+        for domain, subset in turn['state'].items():
+            domain_name = DOMAINS_MAP.get(domain, domain.lower())
+            clean_state[domain_name] = {}
+            for slot, value in subset.items():
+                # Remove pipe separated values
+                value = value.split('|', 1)[0]
+
+                # Map values using value_map
+                for old, new in VALUE_MAP.items():
+                    value = value.replace(old, new)
+
+                # Map dontcare to "do not care" and empty to 'none'
+                value = value.replace('dontcare', 'do not care')
+                value = value if value else 'none'
+
+                # Map quantity values to the integer quantity value
+                if 'people' in slot or 'duration' in slot or 'stay' in slot:
+                    try:
+                        if value not in ['do not care', 'none']:
+                            value = int(value)
+                            value = str(value) if value < 10 else QUANTITIES[-1]
+                    except:
+                        value = value
+                # Map time values to the most appropriate value in the standard time set
+                elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
+                    try:
+                        if value not in ['do not care', 'none']:
+                            # Strip after/before from time value
+                            value = value.replace('after ', '').replace('before ', '')
+                            # Extract hours and minutes from different possible formats
+                            if ':' not in value and len(value) == 4:
+                                h, m = value[:2], value[2:]
+                            elif len(value) == 1:
+                                h = int(value)
+                                m = 0
+                            elif 'pm' in value:
+                                h = int(value.replace('pm', '')) + 12
+                                m = 0
+                            elif 'am' in value:
+                                h = int(value.replace('pm', ''))
+                                m = 0
+                            elif ':' in value:
+                                h, m = value.split(':')
+                            elif ';' in value:
+                                h, m = value.split(';')
+                            # Map to closest 5 minutes
+                            if int(m) % 5 != 0:
+                                m = round(int(m) / 5) * 5
+                                h = int(h)
+                                if m == 60:
+                                    m = 0
+                                    h += 1
+                                if h >= 24:
+                                    h -= 24
+                            # Set in standard 24 hour format
+                            h, m = int(h), int(m)
+                            value = '%02i:%02i' % (h, m)
+                    except:
+                        value = value
+                # Map boolean slots to yes/no value
+                elif 'parking' in slot or 'internet' in slot:
+                    if value not in ['do not care', 'none']:
+                        if value == 'free':
+                            value = 'yes'
+                        elif True in [v in value.lower() for v in ['yes', 'no']]:
+                            value = [v for v in ['yes', 'no'] if v in value][0]
+
+                clean_state[domain_name][slot] = value
+        turn['state'] = clean_state
+        turn['dialogue_acts'] = clean_acts
+        clean_turns.append(turn)
+
+    return clean_turns
+
+
+def get_active_domains(turns: list) -> list:
+    """
+    Get active domains at each turn in a dialogue
+    Args:
+        turns (list): List of turns within a dialogue
+
+    Returns:
+        turns (list): List of turns within a dialogue
+    """
+    for turn_id in range(len(turns)):
+        # At first turn all domains with not none values in the state are active
+        if turn_id == 0:
+            domains = [d for d, substate in turns[turn_id]['state'].items() for s, v in substate.items() if v != 'none']
+            domains += [act['domain'] for act in turns[turn_id]['dialogue_acts'] if act['domain'] in turns[turn_id]['state']]
+            domains = [DOMAINS_MAP.get(domain, domain.lower()) for domain in domains]
+            turns[turn_id]['active_domains'] = list(set(domains))
+        else:
+            # Use changes in domains to identify active domains
+            domains = []
+            for domain, substate in turns[turn_id]['state'].items():
+                domain_name = DOMAINS_MAP.get(domain, domain.lower())
+                for slot, value in substate.items():
+                    if value != turns[turn_id - 1]['state'][domain][slot]:
+                        val = value
+                    else:
+                        val = 'none'
+                    if value == 'none':
+                        val = 'none'
+                    if val != 'none':
+                        domains.append(domain_name)
+            # Add all domains activated by a user action
+            domains += [act['domain'] for act in turns[turn_id]['dialogue_acts']
+                        if act['domain'] in turns[turn_id]['state']]
+            turns[turn_id]['active_domains'] = list(set(domains))
+
+    return turns
+
+
+def extract_dialogues(data: list, dataset_name: str) -> list:
+    """
+    Extract all dialogues from dataset
+    Args:
+        data (list): List of all dialogues in a subset of the data
+        dataset_name (str): Name of the dataset to which the dialogues belongs
+
+    Returns:
+        dialogues (list): List of all extracted dialogues
+    """
+    dialogues = []
+    for dial in data:
+        turns = extract_turns(dial['turns'], dataset_name)
+        turns = clean_states(turns)
+        turns = get_active_domains(turns)
+        dialogues.append(turns)
+
+    return dialogues
diff --git a/convlab/dst/setsumbt/dataset/value_maps.py b/convlab/dst/setsumbt/dataset/value_maps.py
new file mode 100644
index 00000000..619600a7
--- /dev/null
+++ b/convlab/dst/setsumbt/dataset/value_maps.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convlab3 Unified dataset value maps"""
+
+
+# MultiWOZ specific label map to avoid duplication and typos in values
+VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast',
+             'cityroomz': 'city roomz', '  ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott',
+             'worth house': 'the worth house', 'alesbray lodge guest house': 'aylesbray lodge',
+             'huntingdon hotel': 'huntingdon marriott hotel', 'huntingd': 'huntingdon marriott hotel',
+             'jamaicanchinese': 'chinese', 'barbequemodern european': 'modern european',
+             'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"}
+
+
+# Domain map for SGD and TM Data
+DOMAINS_MAP = {'Alarm_1': 'alarm', 'Banks_1': 'banks', 'Banks_2': 'banks', 'Buses_1': 'bus', 'Buses_2': 'bus',
+               'Buses_3': 'bus', 'Calendar_1': 'calendar', 'Events_1': 'events', 'Events_2': 'events',
+               'Events_3': 'events', 'Flights_1': 'flights', 'Flights_2': 'flights', 'Flights_3': 'flights',
+               'Flights_4': 'flights', 'Homes_1': 'homes', 'Homes_2': 'homes', 'Hotels_1': 'hotel',
+               'Hotels_2': 'hotel', 'Hotels_3': 'hotel', 'Hotels_4': 'hotel', 'Media_1': 'media',
+               'Media_2': 'media', 'Media_3': 'media', 'Messaging_1': 'messaging', 'Movies_1': 'movies',
+               'Movies_2': 'movies', 'Movies_3': 'movies', 'Music_1': 'music', 'Music_2': 'music', 'Music_3': 'music',
+               'Payment_1': 'payment', 'RentalCars_1': 'rentalcars', 'RentalCars_2': 'rentalcars',
+               'RentalCars_3': 'rentalcars', 'Restaurants_1': 'restaurant', 'Restaurants_2': 'restaurant',
+               'RideSharing_1': 'ridesharing', 'RideSharing_2': 'ridesharing', 'Services_1': 'services',
+               'Services_2': 'services', 'Services_3': 'services', 'Services_4': 'services', 'Trains_1': 'train',
+               'Travel_1': 'travel', 'Weather_1': 'weather', 'movie_ticket': 'movies',
+               'restaurant_reservation': 'restaurant', 'coffee_ordering': 'coffee', 'pizza_ordering': 'takeout',
+               'auto_repair': 'car_repairs', 'flights': 'flights', 'food-ordering': 'takeout', 'hotels': 'hotel',
+               'movies': 'movies', 'music': 'music', 'restaurant-search': 'restaurant', 'sports': 'sports',
+               'movie': 'movies'}
+
+
+# Generic value sets for quantity and time slots
+QUANTITIES = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10 or more']
+TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)]
+TIME = ['%02i:%02i' % t for l in TIME for t in l]
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/distillation_setup.py b/convlab/dst/setsumbt/distillation_setup.py
index e0d87bb9..2279e222 100644
--- a/convlab/dst/setsumbt/distillation_setup.py
+++ b/convlab/dst/setsumbt/distillation_setup.py
@@ -1,53 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Get ensemble predictions and build distillation dataloaders"""
+
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
 import os
+import json
 
 import torch
-import transformers
-from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
-from transformers import RobertaConfig, BertConfig
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from tqdm import tqdm
 
-import convlab
-from convlab.dst.setsumbt.multiwoz.dataset.multiwoz21 import EnsembleMultiWoz21
+from convlab.dst.setsumbt.dataset.unified_format import UnifiedFormatDataset, change_batch_size
 from convlab.dst.setsumbt.modeling import EnsembleSetSUMBT
+from convlab.dst.setsumbt.modeling import training
 
-DEVICE = 'cuda'
-
-
-def args_parser():
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--model_path', type=str)
-    parser.add_argument('--model_type', type=str)
-    parser.add_argument('--set_type', type=str)
-    parser.add_argument('--batch_size', type=int)
-    parser.add_argument('--ensemble_size', type=int)
-    parser.add_argument('--reduction', type=str, default='mean')
-    parser.add_argument('--get_ensemble_distributions', action='store_true')
-    parser.add_argument('--build_dataloaders', action='store_true')
-    
-    return parser.parse_args()
-
-
-def main():
-    args = args_parser()
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-    if args.get_ensemble_distributions:
-        get_ensemble_distributions(args)
-    elif args.build_dataloaders:
-        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
-        data = torch.load(path)
-        loader = get_loader(data, args.set_type, args.batch_size)
 
-        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
-        torch.save(loader, path)
-    else:
-        raise NameError("NotImplemented")
+def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size: int = 3) -> DataLoader:
+    """
+    Build dataloader from ensemble prediction data
 
+    Args:
+        data: Dictionary of ensemble predictions
+        ontology: Data ontology
+        set_type: Data subset (train/validation/test)
+        batch_size: Number of dialogues per batch
 
-def get_loader(data, set_type='train', batch_size=3):
+    Returns:
+        loader: Data loader object
+    """
     data = flatten_data(data)
     data = do_label_padding(data)
-    data = EnsembleMultiWoz21(data)
+    data = UnifiedFormatDataset.from_datadict(data, ontology)
     if set_type == 'train':
         sampler = RandomSampler(data)
     else:
@@ -57,7 +55,16 @@ def get_loader(data, set_type='train', batch_size=3):
     return loader
 
 
-def do_label_padding(data):
+def do_label_padding(data: dict) -> dict:
+    """
+    Add padding to the ensemble predictions (used as labels in distillation)
+
+    Args:
+        data: Dictionary of ensemble predictions
+
+    Returns:
+        data: Padded ensemble predictions
+    """
     if 'attention_mask' in data:
         dialogs, turns = torch.where(data['attention_mask'].sum(-1) == 0.0)
     else:
@@ -70,13 +77,17 @@ def do_label_padding(data):
     return data
 
 
-map_dict = {'belief_state': 'belief', 'greeting_act_belief': 'goodbye_belief',
-            'state_labels': 'labels', 'request_labels': 'request',
-            'domain_labels': 'active', 'greeting_labels': 'goodbye'}
-def flatten_data(data):
+def flatten_data(data: dict) -> dict:
+    """
+    Map data to flattened feature format used in training
+    Args:
+        data: Ensemble prediction data
+
+    Returns:
+        data: Flattened ensemble prediction data
+    """
     data_new = dict()
     for label, feats in data.items():
-        label = map_dict.get(label, label)
         if type(feats) == dict:
             for label_, feats_ in feats.items():
                 data_new[label + '-' + label_] = feats_
@@ -87,13 +98,11 @@ def flatten_data(data):
 
 
 def get_ensemble_distributions(args):
-    if args.model_type == 'roberta':
-        config = RobertaConfig
-    elif args.model_type == 'bert':
-        config = BertConfig
-    config = config.from_pretrained(args.model_path)
-    config.ensemble_size = args.ensemble_size
-
+    """
+    Load data and get ensemble predictions
+    Args:
+        args: Runtime arguments
+    """
     device = DEVICE
 
     model = EnsembleSetSUMBT.from_pretrained(args.model_path)
@@ -107,16 +116,10 @@ def get_ensemble_distributions(args):
     dataloader = torch.load(dataloader)
     database = torch.load(database)
 
-    # Get slot and value embeddings
-    slots = {slot: val for slot, val in database.items()}
-    values = {slot: val[1] for slot, val in database.items()}
-    del database
+    if dataloader.batch_size != args.batch_size:
+        dataloader = change_batch_size(dataloader, args.batch_size)
 
-    # Load model ontology
-    model.add_slot_candidates(slots)
-    for slot in model.informable_slot_ids:
-        model.add_value_candidates(slot, values[slot], replace=True)
-    del slots, values
+    training.set_ontology_embeddings(model, database)
 
     print('Environment set up.')
 
@@ -125,18 +128,24 @@ def get_ensemble_distributions(args):
     attention_mask = []
     state_labels = {slot: [] for slot in model.informable_slot_ids}
     request_labels = {slot: [] for slot in model.requestable_slot_ids}
-    domain_labels = {domain: [] for domain in model.domain_ids}
-    greeting_labels = []
+    active_domain_labels = {domain: [] for domain in model.domain_ids}
+    general_act_labels = []
+
+    is_noisy = [] if 'is_noisy' in dataloader.dataset.features else None
+
     belief_state = {slot: [] for slot in model.informable_slot_ids}
-    request_belief = {slot: [] for slot in model.requestable_slot_ids}
-    domain_belief = {domain: [] for domain in model.domain_ids}
-    greeting_act_belief = []
+    request_probs = {slot: [] for slot in model.requestable_slot_ids}
+    active_domain_probs = {domain: [] for domain in model.domain_ids}
+    general_act_probs = []
     model.eval()
     for batch in tqdm(dataloader, desc='Batch:'):
         ids = batch['input_ids']
         tt_ids = batch['token_type_ids'] if 'token_type_ids' in batch else None
         mask = batch['attention_mask'] if 'attention_mask' in batch else None
 
+        if 'is_noisy' in batch:
+            is_noisy.append(batch['is_noisy'])
+
         input_ids.append(ids)
         token_type_ids.append(tt_ids)
         attention_mask.append(mask)
@@ -146,61 +155,123 @@ def get_ensemble_distributions(args):
         mask = mask.to(device) if mask is not None else None
 
         for slot in state_labels:
-            state_labels[slot].append(batch['labels-' + slot])
-        if model.config.predict_intents:
+            state_labels[slot].append(batch['state_labels-' + slot])
+        if model.config.predict_actions:
             for slot in request_labels:
-                request_labels[slot].append(batch['request-' + slot])
-            for domain in domain_labels:
-                domain_labels[domain].append(batch['active-' + domain])
-            greeting_labels.append(batch['goodbye'])
+                request_labels[slot].append(batch['request_labels-' + slot])
+            for domain in active_domain_labels:
+                active_domain_labels[domain].append(batch['active_domain_labels-' + domain])
+            general_act_labels.append(batch['general_act_labels'])
 
         with torch.no_grad():
-            p, p_req, p_dom, p_bye, _ = model(ids, mask, tt_ids,
-                                            reduction=args.reduction)
+            p, p_req, p_dom, p_gen, _ = model(ids, mask, tt_ids, reduction=args.reduction)
 
         for slot in belief_state:
             belief_state[slot].append(p[slot].cpu())
-        if model.config.predict_intents:
-            for slot in request_belief:
-                request_belief[slot].append(p_req[slot].cpu())
-            for domain in domain_belief:
-                domain_belief[domain].append(p_dom[domain].cpu())
-            greeting_act_belief.append(p_bye.cpu())
+        if model.config.predict_actions:
+            for slot in request_probs:
+                request_probs[slot].append(p_req[slot].cpu())
+            for domain in active_domain_probs:
+                active_domain_probs[domain].append(p_dom[domain].cpu())
+            general_act_probs.append(p_gen.cpu())
     
     input_ids = torch.cat(input_ids, 0) if input_ids[0] is not None else None
     token_type_ids = torch.cat(token_type_ids, 0) if token_type_ids[0] is not None else None
     attention_mask = torch.cat(attention_mask, 0) if attention_mask[0] is not None else None
+    is_noisy = torch.cat(is_noisy, 0) if is_noisy is not None else None
 
     state_labels = {slot: torch.cat(l, 0) for slot, l in state_labels.items()}
-    if model.config.predict_intents:
+    if model.config.predict_actions:
         request_labels = {slot: torch.cat(l, 0) for slot, l in request_labels.items()}
-        domain_labels = {domain: torch.cat(l, 0) for domain, l in domain_labels.items()}
-        greeting_labels = torch.cat(greeting_labels, 0)
+        active_domain_labels = {domain: torch.cat(l, 0) for domain, l in active_domain_labels.items()}
+        general_act_labels = torch.cat(general_act_labels, 0)
     
     belief_state = {slot: torch.cat(p, 0) for slot, p in belief_state.items()}
-    if model.config.predict_intents:
-        request_belief = {slot: torch.cat(p, 0) for slot, p in request_belief.items()}
-        domain_belief = {domain: torch.cat(p, 0) for domain, p in domain_belief.items()}
-        greeting_act_belief = torch.cat(greeting_act_belief, 0)
+    if model.config.predict_actions:
+        request_probs = {slot: torch.cat(p, 0) for slot, p in request_probs.items()}
+        active_domain_probs = {domain: torch.cat(p, 0) for domain, p in active_domain_probs.items()}
+        general_act_probs = torch.cat(general_act_probs, 0)
 
     data = {'input_ids': input_ids}
     if token_type_ids is not None:
         data['token_type_ids'] = token_type_ids
     if attention_mask is not None:
         data['attention_mask'] = attention_mask
+    if is_noisy is not None:
+        data['is_noisy'] = is_noisy
     data['state_labels'] = state_labels
     data['belief_state'] = belief_state
-    if model.config.predict_intents:
+    if model.config.predict_actions:
         data['request_labels'] = request_labels
-        data['domain_labels'] = domain_labels
-        data['greeting_labels'] = greeting_labels
-        data['request_belief'] = request_belief
-        data['domain_belief'] = domain_belief
-        data['greeting_act_belief'] = greeting_act_belief
+        data['active_domain_labels'] = active_domain_labels
+        data['general_act_labels'] = general_act_labels
+        data['request_probs'] = request_probs
+        data['active_domain_probs'] = active_domain_probs
+        data['general_act_probs'] = general_act_probs
 
     file = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
     torch.save(data, file)
 
 
+def ensemble_distribution_data_to_predictions_format(model_path: str, set_type: str):
+    """
+    Convert ensemble predictions to predictions file format.
+
+    Args:
+        model_path: Path to ensemble location.
+        set_type: Evaluation dataset (train/dev/test).
+    """
+    data = torch.load(os.path.join(model_path, 'dataloaders', f"{set_type}.data"))
+
+    # Get oracle labels
+    if 'request_probs' in data:
+        data_new = {'state_labels': data['state_labels'],
+                    'request_labels': data['request_labels'],
+                    'active_domain_labels': data['active_domain_labels'],
+                    'general_act_labels': data['general_act_labels']}
+    else:
+        data_new = {'state_labels': data['state_labels']}
+
+    # Marginalising across ensemble distributions
+    data_new['belief_states'] = {slot: distribution.mean(-2) for slot, distribution in data['belief_state'].items()}
+    if 'request_probs' in data:
+        data_new['request_probs'] = {slot: distribution.mean(-1)
+                                     for slot, distribution in data['request_probs'].items()}
+        data_new['active_domain_probs'] = {domain: distribution.mean(-1)
+                                           for domain, distribution in data['active_domain_probs'].items()}
+        data_new['general_act_probs'] = data['general_act_probs'].mean(-2)
+
+    # Save predictions file
+    predictions_dir = os.path.join(model_path, 'predictions')
+    if not os.path.exists(predictions_dir):
+        os.mkdir(predictions_dir)
+    torch.save(data_new, os.path.join(predictions_dir, f"{set_type}.predictions"))
+
+
 if __name__ == "__main__":
-    main()
+    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--model_path', type=str)
+    parser.add_argument('--set_type', type=str)
+    parser.add_argument('--batch_size', type=int, default=3)
+    parser.add_argument('--reduction', type=str, default='none')
+    parser.add_argument('--get_ensemble_distributions', action='store_true')
+    parser.add_argument('--convert_distributions_to_predictions', action='store_true')
+    parser.add_argument('--build_dataloaders', action='store_true')
+    args = parser.parse_args()
+
+    if args.get_ensemble_distributions:
+        get_ensemble_distributions(args)
+    if args.convert_distributions_to_predictions:
+        ensemble_distribution_data_to_predictions_format(args.model_path, args.set_type)
+    if args.build_dataloaders:
+        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
+        data = torch.load(path)
+
+        reader = open(os.path.join(args.model_path, 'database', f'{args.set_type}.json'), 'r')
+        ontology = json.load(reader)
+        reader.close()
+
+        loader = get_loader(data, ontology, args.set_type, args.batch_size)
+
+        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
+        torch.save(loader, path)
diff --git a/convlab/dst/setsumbt/do/calibration.py b/convlab/dst/setsumbt/do/calibration.py
deleted file mode 100644
index 27ee058e..00000000
--- a/convlab/dst/setsumbt/do/calibration.py
+++ /dev/null
@@ -1,481 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Run SetSUMBT Calibration"""
-
-import logging
-import random
-import os
-from shutil import copy2 as copy
-
-import torch
-from transformers import (BertModel, BertConfig, BertTokenizer,
-                          RobertaModel, RobertaConfig, RobertaTokenizer,
-                          AdamW, get_linear_schedule_with_warmup)
-from tqdm import tqdm, trange
-from tensorboardX import SummaryWriter
-from torch.distributions import Categorical
-
-from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
-from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
-from convlab.dst.setsumbt.multiwoz import multiwoz21
-from convlab.dst.setsumbt.multiwoz import ontology as embeddings
-from convlab.dst.setsumbt.utils import get_args, upload_local_directory_to_gcs, update_args
-from convlab.dst.setsumbt.modeling import calibration_utils
-from convlab.dst.setsumbt.modeling import ensemble_utils
-from convlab.dst.setsumbt.loss.ece import ece, jg_ece, l2_acc
-
-
-# Datasets
-DATASETS = {
-    'multiwoz21': multiwoz21
-}
-
-MODELS = {
-    'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
-    'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
-}
-
-
-def main(args=None, config=None):
-    # Get arguments
-    if args is None:
-        args, config = get_args(MODELS)
-
-    # Select Dataset object
-    if args.dataset in DATASETS:
-        Dataset = DATASETS[args.dataset]
-    else:
-        raise NameError('NotImplemented')
-
-    if args.model_type in MODELS:
-        SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
-    else:
-        raise NameError('NotImplemented')
-
-    # Set up output directory
-    OUTPUT_DIR = args.output_dir
-    if not os.path.exists(OUTPUT_DIR):
-        os.mkdir(OUTPUT_DIR)
-    args.output_dir = OUTPUT_DIR
-    if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
-        os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
-
-    paths = os.listdir(args.output_dir) if os.path.exists(
-        args.output_dir) else []
-    if 'pytorch_model.bin' in paths and 'config.json' in paths:
-        args.model_name_or_path = args.output_dir
-        config = ConfigClass.from_pretrained(args.model_name_or_path)
-    else:
-        paths = os.listdir(args.output_dir) if os.path.exists(
-            args.output_dir) else []
-        paths = [os.path.join(args.output_dir, p)
-                 for p in paths if 'checkpoint-' in p]
-        if paths:
-            paths = paths[0]
-            args.model_name_or_path = paths
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
-
-    if args.ensemble_size > 0:
-        paths = os.listdir(args.output_dir) if os.path.exists(
-            args.output_dir) else []
-        paths = [os.path.join(args.output_dir, p)
-                 for p in paths if 'ensemble_' in p]
-        if paths:
-            args.model_name_or_path = args.output_dir
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
-
-    args = update_args(args, config)
-
-    # Set up data directory
-    DATA_DIR = args.data_dir
-    Dataset.set_datadir(DATA_DIR)
-    embeddings.set_datadir(DATA_DIR)
-
-    if args.shrink_active_domains and args.dataset == 'multiwoz21':
-        Dataset.set_active_domains(
-            ['attraction', 'hotel', 'restaurant', 'taxi', 'train'])
-
-    # Download and preprocess
-    Dataset.create_examples(
-        args.max_turn_len, args.predict_intents, args.force_processing)
-
-    # Create logger
-    global logger
-    logger = logging.getLogger(__name__)
-    logger.setLevel(logging.INFO)
-
-    formatter = logging.Formatter(
-        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-
-    if 'stream' not in args.logging_path:
-        fh = logging.FileHandler(args.logging_path)
-        fh.setLevel(logging.INFO)
-        fh.setFormatter(formatter)
-        logger.addHandler(fh)
-    else:
-        ch = logging.StreamHandler()
-        ch.setLevel(level=logging.INFO)
-        ch.setFormatter(formatter)
-        logger.addHandler(ch)
-
-    # Get device
-    if torch.cuda.is_available() and args.n_gpu > 0:
-        device = torch.device('cuda')
-    else:
-        device = torch.device('cpu')
-        args.n_gpu = 0
-
-    if args.n_gpu == 0:
-        args.fp16 = False
-
-    # Set up model training/evaluation
-    calibration.set_logger(logger, None)
-    calibration.set_seed(args)
-
-    if args.ensemble_size > 0:
-        ensemble.set_logger(logger, tb_writer)
-        ensemble_utils.set_seed(args)
-
-    # Perform tasks
-
-    if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')):
-        pred = torch.load(os.path.join(
-            OUTPUT_DIR, 'predictions', 'test.predictions'))
-        labels = pred['labels']
-        belief_states = pred['belief_states']
-        if 'request_labels' in pred:
-            request_labels = pred['request_labels']
-            request_belief = pred['request_belief']
-            domain_labels = pred['domain_labels']
-            domain_belief = pred['domain_belief']
-            greeting_labels = pred['greeting_labels']
-            greeting_belief = pred['greeting_belief']
-        else:
-            request_belief = None
-        del pred
-    elif args.ensemble_size > 0:
-        # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-        else:
-            # Create Tokenizer and embedding model for Data Loaders and ontology
-            encoder = CandidateEncoderModel.from_pretrained(
-                config.candidate_embedding_model_name)
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            embeddings.get_slot_candidate_embeddings(
-                'test', args, tokenizer, encoder)
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-
-        exists = False
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size == args.test_batch_size:
-                exists = True
-        if not exists:
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                     config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
-        config, models = ensemble.get_models(
-            args.model_name_or_path, device, ConfigClass, SetSumbtModel)
-
-        belief_states, labels = ensemble_utils.get_predictions(
-            args, models, device, test_dataloader, test_slots)
-        torch.save({'belief_states': belief_states, 'labels': labels},
-                   os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
-    else:
-        # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-        else:
-            # Create Tokenizer and embedding model for Data Loaders and ontology
-            encoder = CandidateEncoderModel.from_pretrained(
-                config.candidate_embedding_model_name)
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            embeddings.get_slot_candidate_embeddings(
-                'test', args, tokenizer, encoder)
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-
-        exists = False
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size == args.test_batch_size:
-                exists = True
-        if not exists:
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                     config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
-        # Initialise Model
-        model = SetSumbtModel.from_pretrained(
-            args.model_name_or_path, config=config)
-        model = model.to(device)
-
-        # Get slot and value embeddings
-        slots = {slot: test_slots[slot] for slot in test_slots}
-        values = {slot: test_slots[slot][1] for slot in test_slots}
-
-        # Load model ontology
-        model.add_slot_candidates(slots)
-        for slot in model.informable_slot_ids:
-            model.add_value_candidates(slot, values[slot], replace=True)
-
-        belief_states = calibration.get_predictions(
-            args, model, device, test_dataloader)
-        belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = belief_states
-        out = {'belief_states': belief_states, 'labels': labels,
-               'request_belief': request_belief, 'request_labels': request_labels,
-               'domain_belief': domain_belief, 'domain_labels': domain_labels,
-               'greeting_belief': greeting_belief, 'greeting_labels': greeting_labels}
-        torch.save(out, os.path.join(
-            OUTPUT_DIR, 'predictions', 'test.predictions'))
-
-    # err = [ece(belief_states[slot].reshape(-1, belief_states[slot].size(-1)), labels[slot].reshape(-1), 10)
-    #         for slot in belief_states]
-    # err = max(err)
-    # logger.info('ECE: %f' % err)
-
-    # Calculate calibration metrics
-
-    jg = jg_ece(belief_states, labels, 10)
-    logger.info('Joint Goal ECE: %f' % jg)
-
-    binary_states = {}
-    for slot, p in belief_states.items():
-        shp = p.shape
-        p = p.reshape(-1, p.size(-1))
-        p_ = torch.ones(p.shape).to(p.device) * 1e-8
-        p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
-        binary_states[slot] = p_.reshape(shp)
-    jg = jg_ece(binary_states, labels, 10)
-    logger.info('Joint Goal Binary ECE: %f' % jg)
-
-    bs = {slot: torch.cat((p[:, :, 0].unsqueeze(-1), p[:, :, 1:].max(-1)
-                          [0].unsqueeze(-1)), -1) for slot, p in belief_states.items()}
-    ls = {}
-    for slot, l in labels.items():
-        y = torch.zeros((l.size(0), l.size(1))).to(l.device)
-        dials, turns = torch.where(l > 0)
-        y[dials, turns] = 1.0
-        dials, turns = torch.where(l < 0)
-        y[dials, turns] = -1.0
-        ls[slot] = y
-
-    jg = jg_ece(bs, ls, 10)
-    logger.info('Slot presence ECE: %f' % jg)
-
-    binary_states = {}
-    for slot, p in bs.items():
-        shp = p.shape
-        p = p.reshape(-1, p.size(-1))
-        p_ = torch.ones(p.shape).to(p.device) * 1e-8
-        p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
-        binary_states[slot] = p_.reshape(shp)
-    jg = jg_ece(binary_states, ls, 10)
-    logger.info('Slot presence Binary ECE: %f' % jg)
-
-    jg_acc = 0.0
-    padding = torch.cat([item.unsqueeze(-1)
-                        for _, item in labels.items()], -1).sum(-1) * -1.0
-    padding = (padding == len(labels))
-    padding = padding.reshape(-1)
-    for slot in belief_states:
-        topn = args.accuracy_topn
-        p_ = belief_states[slot]
-        gold = labels[slot]
-
-        if p_.size(-1) <= topn:
-            topn = p_.size(-1) - 1
-        if topn <= 0:
-            topn = 1
-
-        if topn > 1:
-            labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True)
-            labs = labs[:, :topn]
-        else:
-            labs = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1)
-        acc = [lab in s for lab, s, pad in zip(
-            gold.reshape(-1), labs, padding) if not pad]
-        acc = torch.tensor(acc).float()
-
-        jg_acc += acc
-
-    n_turns = jg_acc.size(0)
-    sl_acc = sum(jg_acc / len(belief_states)).float()
-    jg_acc = sum((jg_acc / len(belief_states)).int()).float()
-
-    sl_acc /= n_turns
-    jg_acc /= n_turns
-
-    logger.info('Joint Goal Accuracy: %f, Slot Accuracy %f' % (jg_acc, sl_acc))
-
-    l2 = l2_acc(belief_states, labels, remove_belief=False)
-    logger.info(f'Model L2 Norm Goal Accuracy: {l2}')
-    l2 = l2_acc(belief_states, labels, remove_belief=True)
-    logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}')
-
-    for slot in belief_states:
-        p = belief_states[slot]
-        p = p.reshape(-1, p.size(-1))
-        p = torch.cat(
-            (p[:, 0].unsqueeze(-1), p[:, 1:].max(-1)[0].unsqueeze(-1)), -1)
-        belief_states[slot] = p
-
-        l = labels[slot].reshape(-1)
-        l[l > 0] = 1
-        labels[slot] = l
-
-    f1 = 0.0
-    for slot in belief_states:
-        prd = belief_states[slot].argmax(-1)
-        tp = ((prd == 1) * (labels[slot] == 1)).sum()
-        fp = ((prd == 1) * (labels[slot] == 0)).sum()
-        fn = ((prd == 0) * (labels[slot] == 1)).sum()
-        if tp > 0:
-            f1 += tp / (tp + 0.5 * (fp + fn))
-    f1 /= len(belief_states)
-    logger.info(f'Trucated Goal F1 Score: {f1}')
-
-    l2 = l2_acc(belief_states, labels, remove_belief=False)
-    logger.info(f'Model L2 Norm Trucated Goal Accuracy: {l2}')
-    l2 = l2_acc(belief_states, labels, remove_belief=True)
-    logger.info(f'Binary Model L2 Norm Trucated Goal Accuracy: {l2}')
-
-    if request_belief is not None:
-        tp, fp, fn = 0.0, 0.0, 0.0
-        for slot in request_belief:
-            p = request_belief[slot]
-            l = request_labels[slot]
-
-            tp += (p.round().int() * (l == 1)).reshape(-1).float()
-            fp += (p.round().int() * (l == 0)).reshape(-1).float()
-            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
-        tp /= len(request_belief)
-        fp /= len(request_belief)
-        fn /= len(request_belief)
-        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
-        logger.info('Request F1 Score: %f' % f1.item())
-
-        for slot in request_belief:
-            p = request_belief[slot]
-            p = p.unsqueeze(-1)
-            p = torch.cat((1 - p, p), -1)
-            request_belief[slot] = p
-        jg = jg_ece(request_belief, request_labels, 10)
-        logger.info('Request Joint Goal ECE: %f' % jg)
-
-        binary_states = {}
-        for slot, p in request_belief.items():
-            shp = p.shape
-            p = p.reshape(-1, p.size(-1))
-            p_ = torch.ones(p.shape).to(p.device) * 1e-8
-            p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
-            binary_states[slot] = p_.reshape(shp)
-        jg = jg_ece(binary_states, request_labels, 10)
-        logger.info('Request Joint Goal Binary ECE: %f' % jg)
-
-        tp, fp, fn = 0.0, 0.0, 0.0
-        for dom in domain_belief:
-            p = domain_belief[dom]
-            l = domain_labels[dom]
-
-            tp += (p.round().int() * (l == 1)).reshape(-1).float()
-            fp += (p.round().int() * (l == 0)).reshape(-1).float()
-            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
-        tp /= len(domain_belief)
-        fp /= len(domain_belief)
-        fn /= len(domain_belief)
-        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
-        logger.info('Domain F1 Score: %f' % f1.item())
-
-        for dom in domain_belief:
-            p = domain_belief[dom]
-            p = p.unsqueeze(-1)
-            p = torch.cat((1 - p, p), -1)
-            domain_belief[dom] = p
-        jg = jg_ece(domain_belief, domain_labels, 10)
-        logger.info('Domain Joint Goal ECE: %f' % jg)
-
-        binary_states = {}
-        for slot, p in domain_belief.items():
-            shp = p.shape
-            p = p.reshape(-1, p.size(-1))
-            p_ = torch.ones(p.shape).to(p.device) * 1e-8
-            p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
-            binary_states[slot] = p_.reshape(shp)
-        jg = jg_ece(binary_states, domain_labels, 10)
-        logger.info('Domain Joint Goal Binary ECE: %f' % jg)
-
-        tp = ((greeting_belief.argmax(-1) > 0) *
-              (greeting_labels > 0)).reshape(-1).float().sum()
-        fp = ((greeting_belief.argmax(-1) > 0) *
-              (greeting_labels == 0)).reshape(-1).float().sum()
-        fn = ((greeting_belief.argmax(-1) == 0) *
-              (greeting_labels > 0)).reshape(-1).float().sum()
-        f1 = tp / (tp + 0.5 * (fp + fn))
-        logger.info('Greeting F1 Score: %f' % f1.item())
-
-        err = ece(greeting_belief.reshape(-1, greeting_belief.size(-1)),
-                  greeting_labels.reshape(-1), 10)
-        logger.info('Greetings ECE: %f' % err)
-
-        greeting_belief = greeting_belief.reshape(-1, greeting_belief.size(-1))
-        binary_states = torch.ones(greeting_belief.shape).to(
-            greeting_belief.device) * 1e-8
-        binary_states[range(greeting_belief.size(0)),
-                      greeting_belief.argmax(-1)] = 1.0 - 1e-8
-        err = ece(binary_states, greeting_labels.reshape(-1), 10)
-        logger.info('Greetings Binary ECE: %f' % err)
-
-        for slot in request_belief:
-            p = request_belief[slot].unsqueeze(-1)
-            request_belief[slot] = torch.cat((1 - p, p), -1)
-
-        l2 = l2_acc(request_belief, request_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm Request Accuracy: {l2}')
-        l2 = l2_acc(request_belief, request_labels, remove_belief=True)
-        logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}')
-
-        for slot in domain_belief:
-            p = domain_belief[slot].unsqueeze(-1)
-            domain_belief[slot] = torch.cat((1 - p, p), -1)
-
-        l2 = l2_acc(domain_belief, domain_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm Domain Accuracy: {l2}')
-        l2 = l2_acc(domain_belief, domain_labels, remove_belief=True)
-        logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}')
-
-        greeting_labels = {'bye': greeting_labels}
-        greeting_belief = {'bye': greeting_belief}
-
-        l2 = l2_acc(greeting_belief, greeting_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm Greeting Accuracy: {l2}')
-        l2 = l2_acc(greeting_belief, greeting_labels, remove_belief=False)
-        logger.info(f'Binary Model L2 Norm Greeting Accuracy: {l2}')
-
-
-if __name__ == "__main__":
-    main()
diff --git a/convlab/dst/setsumbt/do/evaluate.py b/convlab/dst/setsumbt/do/evaluate.py
new file mode 100644
index 00000000..2fe351b3
--- /dev/null
+++ b/convlab/dst/setsumbt/do/evaluate.py
@@ -0,0 +1,296 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Run SetSUMBT Calibration"""
+
+import logging
+import os
+
+import torch
+from transformers import (BertModel, BertConfig, BertTokenizer,
+                          RobertaModel, RobertaConfig, RobertaTokenizer)
+
+from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
+from convlab.dst.setsumbt.dataset import unified_format
+from convlab.dst.setsumbt.dataset import ontology as embeddings
+from convlab.dst.setsumbt.utils import get_args, update_args
+from convlab.dst.setsumbt.modeling import evaluation_utils
+from convlab.dst.setsumbt.loss.uncertainty_measures import ece, jg_ece, l2_acc
+from convlab.dst.setsumbt.modeling import training
+
+
+# Available model
+MODELS = {
+    'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
+    'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
+}
+
+
+def main(args=None, config=None):
+    # Get arguments
+    if args is None:
+        args, config = get_args(MODELS)
+
+    if args.model_type in MODELS:
+        SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
+    else:
+        raise NameError('NotImplemented')
+
+    # Set up output directory
+    OUTPUT_DIR = args.output_dir
+    args.output_dir = OUTPUT_DIR
+    if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
+        os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
+
+    # Set pretrained model path to the trained checkpoint
+    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
+    if 'pytorch_model.bin' in paths and 'config.json' in paths:
+        args.model_name_or_path = args.output_dir
+        config = ConfigClass.from_pretrained(args.model_name_or_path)
+    else:
+        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
+        if paths:
+            paths = paths[0]
+            args.model_name_or_path = paths
+            config = ConfigClass.from_pretrained(args.model_name_or_path)
+
+    args = update_args(args, config)
+
+    # Create logger
+    global logger
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+
+    formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')
+
+    fh = logging.FileHandler(args.logging_path)
+    fh.setLevel(logging.INFO)
+    fh.setFormatter(formatter)
+    logger.addHandler(fh)
+
+    # Get device
+    if torch.cuda.is_available() and args.n_gpu > 0:
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+        args.n_gpu = 0
+
+    if args.n_gpu == 0:
+        args.fp16 = False
+
+    # Set up model training/evaluation
+    evaluation_utils.set_seed(args)
+
+    # Perform tasks
+    if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')):
+        pred = torch.load(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
+        state_labels = pred['state_labels']
+        belief_states = pred['belief_states']
+        if 'request_labels' in pred:
+            request_labels = pred['request_labels']
+            request_probs = pred['request_probs']
+            active_domain_labels = pred['active_domain_labels']
+            active_domain_probs = pred['active_domain_probs']
+            general_act_labels = pred['general_act_labels']
+            general_act_probs = pred['general_act_probs']
+        else:
+            request_probs = None
+        del pred
+    else:
+        # Get training batch loaders and ontology embeddings
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
+            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+            if test_dataloader.batch_size != args.test_batch_size:
+                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
+        else:
+            tokenizer = Tokenizer(config.candidate_embedding_model_name)
+            test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
+                                                            args.test_batch_size, tokenizer, args.max_dialogue_len,
+                                                            config.max_turn_len)
+            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
+            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
+        else:
+            encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
+            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
+                                                                  'test', args, tokenizer, encoder)
+
+        # Initialise Model
+        model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
+        model = model.to(device)
+
+        training.set_ontology_embeddings(model, test_slots)
+
+        belief_states = evaluation_utils.get_predictions(args, model, device, test_dataloader)
+        state_labels = belief_states[1]
+        request_probs = belief_states[2]
+        request_labels = belief_states[3]
+        active_domain_probs = belief_states[4]
+        active_domain_labels = belief_states[5]
+        general_act_probs = belief_states[6]
+        general_act_labels = belief_states[7]
+        belief_states = belief_states[0]
+        out = {'belief_states': belief_states, 'state_labels': state_labels, 'request_probs': request_probs,
+               'request_labels': request_labels, 'active_domain_probs': active_domain_probs,
+               'active_domain_labels': active_domain_labels, 'general_act_probs': general_act_probs,
+               'general_act_labels': general_act_labels}
+        torch.save(out, os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
+
+    # Calculate calibration metrics
+    jg = jg_ece(belief_states, state_labels, 10)
+    logger.info('Joint Goal ECE: %f' % jg)
+
+    jg_acc = 0.0
+    padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0
+    padding = (padding == len(state_labels))
+    padding = padding.reshape(-1)
+    for slot in belief_states:
+        p_ = belief_states[slot]
+        gold = state_labels[slot]
+
+        pred = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1)
+        acc = [lab in s for lab, s, pad in zip(gold.reshape(-1), pred, padding) if not pad]
+        acc = torch.tensor(acc).float()
+
+        jg_acc += acc
+
+    n_turns = jg_acc.size(0)
+    jg_acc = sum((jg_acc / len(belief_states)).int()).float()
+
+    jg_acc /= n_turns
+
+    logger.info(f'Joint Goal Accuracy: {jg_acc}')
+
+    l2 = l2_acc(belief_states, state_labels, remove_belief=False)
+    logger.info(f'Model L2 Norm Goal Accuracy: {l2}')
+    l2 = l2_acc(belief_states, state_labels, remove_belief=True)
+    logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}')
+
+    padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0
+    padding = (padding == len(state_labels))
+    padding = padding.reshape(-1)
+
+    tp, fp, fn, tn, n = 0.0, 0.0, 0.0, 0.0, 0.0
+    for slot in belief_states:
+        p_ = belief_states[slot]
+        gold = state_labels[slot].reshape(-1)
+        p_ = p_.reshape(-1, p_.size(-1))
+
+        p_ = p_[~padding].argmax(-1)
+        gold = gold[~padding]
+
+        tp += (p_ == gold)[gold != 0].int().sum().item()
+        fp += (p_ != 0)[gold == 0].int().sum().item()
+        fp += (p_ != gold)[gold != 0].int().sum().item()
+        fp -= (p_ == 0)[gold != 0].int().sum().item()
+        fn += (p_ == 0)[gold != 0].int().sum().item()
+        tn += (p_ == 0)[gold == 0].int().sum().item()
+        n += p_.size(0)
+
+    acc = (tp + tn) / n
+    prec = tp / (tp + fp)
+    rec = tp / (tp + fn)
+    f1 = 2 * (prec * rec) / (prec + rec)
+
+    logger.info(f"Slot Accuracy: {acc}, Slot F1: {f1}, Slot Precision: {prec}, Slot Recall: {rec}")
+
+    if request_probs is not None:
+        tp, fp, fn = 0.0, 0.0, 0.0
+        for slot in request_probs:
+            p = request_probs[slot]
+            l = request_labels[slot]
+
+            tp += (p.round().int() * (l == 1)).reshape(-1).float()
+            fp += (p.round().int() * (l == 0)).reshape(-1).float()
+            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
+        tp /= len(request_probs)
+        fp /= len(request_probs)
+        fn /= len(request_probs)
+        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
+        logger.info('Request F1 Score: %f' % f1.item())
+
+        for slot in request_probs:
+            p = request_probs[slot]
+            p = p.unsqueeze(-1)
+            p = torch.cat((1 - p, p), -1)
+            request_probs[slot] = p
+        jg = jg_ece(request_probs, request_labels, 10)
+        logger.info('Request Joint Goal ECE: %f' % jg)
+
+        tp, fp, fn = 0.0, 0.0, 0.0
+        for dom in active_domain_probs:
+            p = active_domain_probs[dom]
+            l = active_domain_labels[dom]
+
+            tp += (p.round().int() * (l == 1)).reshape(-1).float()
+            fp += (p.round().int() * (l == 0)).reshape(-1).float()
+            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
+        tp /= len(active_domain_probs)
+        fp /= len(active_domain_probs)
+        fn /= len(active_domain_probs)
+        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
+        logger.info('Domain F1 Score: %f' % f1.item())
+
+        for dom in active_domain_probs:
+            p = active_domain_probs[dom]
+            p = p.unsqueeze(-1)
+            p = torch.cat((1 - p, p), -1)
+            active_domain_probs[dom] = p
+        jg = jg_ece(active_domain_probs, active_domain_labels, 10)
+        logger.info('Domain Joint Goal ECE: %f' % jg)
+
+        tp = ((general_act_probs.argmax(-1) > 0) *
+              (general_act_labels > 0)).reshape(-1).float().sum()
+        fp = ((general_act_probs.argmax(-1) > 0) *
+              (general_act_labels == 0)).reshape(-1).float().sum()
+        fn = ((general_act_probs.argmax(-1) == 0) *
+              (general_act_labels > 0)).reshape(-1).float().sum()
+        f1 = tp / (tp + 0.5 * (fp + fn))
+        logger.info('General Act F1 Score: %f' % f1.item())
+
+        err = ece(general_act_probs.reshape(-1, general_act_probs.size(-1)),
+                  general_act_labels.reshape(-1), 10)
+        logger.info('General Act ECE: %f' % err)
+
+        for slot in request_probs:
+            p = request_probs[slot].unsqueeze(-1)
+            request_probs[slot] = torch.cat((1 - p, p), -1)
+
+        l2 = l2_acc(request_probs, request_labels, remove_belief=False)
+        logger.info(f'Model L2 Norm Request Accuracy: {l2}')
+        l2 = l2_acc(request_probs, request_labels, remove_belief=True)
+        logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}')
+
+        for slot in active_domain_probs:
+            p = active_domain_probs[slot].unsqueeze(-1)
+            active_domain_probs[slot] = torch.cat((1 - p, p), -1)
+
+        l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=False)
+        logger.info(f'Model L2 Norm Domain Accuracy: {l2}')
+        l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=True)
+        logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}')
+
+        general_act_labels = {'general': general_act_labels}
+        general_act_probs = {'general': general_act_probs}
+
+        l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False)
+        logger.info(f'Model L2 Norm General Act Accuracy: {l2}')
+        l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False)
+        logger.info(f'Binary Model L2 Norm General Act Accuracy: {l2}')
+
+
+if __name__ == "__main__":
+    main()
diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py
index 821dca59..276d13f2 100644
--- a/convlab/dst/setsumbt/do/nbt.py
+++ b/convlab/dst/setsumbt/do/nbt.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,33 +16,27 @@
 """Run SetSUMBT training/eval"""
 
 import logging
-import random
 import os
 from shutil import copy2 as copy
+import json
+from copy import deepcopy
 
 import torch
-from torch.nn import DataParallel
+import transformers
 from transformers import (BertModel, BertConfig, BertTokenizer,
-                          RobertaModel, RobertaConfig, RobertaTokenizer,
-                          AdamW, get_linear_schedule_with_warmup)
-from tqdm import tqdm, trange
-import numpy as np
+                          RobertaModel, RobertaConfig, RobertaTokenizer)
 from tensorboardX import SummaryWriter
+from tqdm import tqdm
 
-from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
-from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
-from convlab.dst.setsumbt.multiwoz import multiwoz21
+from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
+from convlab.dst.setsumbt.dataset import unified_format
 from convlab.dst.setsumbt.modeling import training
-from convlab.dst.setsumbt.multiwoz import ontology as embeddings
+from convlab.dst.setsumbt.dataset import ontology as embeddings
 from convlab.dst.setsumbt.utils import get_args, update_args
-from convlab.dst.setsumbt.modeling import ensemble_utils
+from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble
 
 
-# Datasets
-DATASETS = {
-    'multiwoz21': multiwoz21
-}
-
+# Available model
 MODELS = {
     'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
     'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
@@ -54,12 +48,6 @@ def main(args=None, config=None):
     if args is None:
         args, config = get_args(MODELS)
 
-    # Select Dataset object
-    if args.dataset in DATASETS:
-        Dataset = DATASETS[args.dataset]
-    else:
-        raise NameError('NotImplemented')
-
     if args.model_type in MODELS:
         SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
     else:
@@ -74,53 +62,19 @@ def main(args=None, config=None):
     args.output_dir = OUTPUT_DIR
 
     # Set pretrained model path to the trained checkpoint
-    if args.do_train:
-        paths = os.listdir(args.output_dir) if os.path.exists(
-            args.output_dir) else []
-        paths = [os.path.join(args.output_dir, p)
-                 for p in paths if 'checkpoint-' in p]
+    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
+    if 'pytorch_model.bin' in paths and 'config.json' in paths:
+        args.model_name_or_path = args.output_dir
+        config = ConfigClass.from_pretrained(args.model_name_or_path)
+    else:
+        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
         if paths:
             paths = paths[0]
             args.model_name_or_path = paths
             config = ConfigClass.from_pretrained(args.model_name_or_path)
-        else:
-            paths = os.listdir(args.output_dir) if os.path.exists(
-                args.output_dir) else []
-            if 'pytorch_model.bin' in paths and 'config.json' in paths:
-                args.model_name_or_path = args.output_dir
-                config = ConfigClass.from_pretrained(args.model_name_or_path)
-    else:
-        paths = os.listdir(args.output_dir) if os.path.exists(
-            args.output_dir) else []
-        if 'pytorch_model.bin' in paths and 'config.json' in paths:
-            args.model_name_or_path = args.output_dir
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
-        else:
-            paths = os.listdir(args.output_dir) if os.path.exists(
-                args.output_dir) else []
-            paths = [os.path.join(args.output_dir, p)
-                     for p in paths if 'checkpoint-' in p]
-            if paths:
-                paths = paths[0]
-                args.model_name_or_path = paths
-                config = ConfigClass.from_pretrained(args.model_name_or_path)
 
     args = update_args(args, config)
 
-    # Set up data directory
-    DATA_DIR = args.data_dir
-    Dataset.set_datadir(DATA_DIR)
-    embeddings.set_datadir(DATA_DIR)
-
-    # If use shrinked domains, remove bus and hospital domains from the training data and model ontology
-    if args.shrink_active_domains and args.dataset == 'multiwoz21':
-        Dataset.set_active_domains(
-            ['attraction', 'hotel', 'restaurant', 'taxi', 'train'])
-
-    # Download and preprocess
-    Dataset.create_examples(
-        args.max_turn_len, args.predict_actions, args.force_processing)
-
     # Create TensorboardX writer
     tb_writer = SummaryWriter(logdir=args.tensorboard_path)
 
@@ -129,19 +83,12 @@ def main(args=None, config=None):
     logger = logging.getLogger(__name__)
     logger.setLevel(logging.INFO)
 
-    formatter = logging.Formatter(
-        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')
 
-    if 'stream' not in args.logging_path:
-        fh = logging.FileHandler(args.logging_path)
-        fh.setLevel(logging.INFO)
-        fh.setFormatter(formatter)
-        logger.addHandler(fh)
-    else:
-        ch = logging.StreamHandler()
-        ch.setLevel(level=logging.INFO)
-        ch.setFormatter(formatter)
-        logger.addHandler(ch)
+    fh = logging.FileHandler(args.logging_path)
+    fh.setLevel(logging.INFO)
+    fh.setFormatter(formatter)
+    logger.addHandler(fh)
 
     # Get device
     if torch.cuda.is_available() and args.n_gpu > 0:
@@ -154,14 +101,12 @@ def main(args=None, config=None):
         args.fp16 = False
 
     # Initialise Model
-    model = SetSumbtModel.from_pretrained(
-        args.model_name_or_path, config=config)
+    transformers.utils.logging.set_verbosity_info()
+    model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
     model = model.to(device)
 
     # Create Tokenizer and embedding model for Data Loaders and ontology
-    encoder = model.roberta if args.model_type == 'roberta' else None
-    encoder = model.bert if args.model_type == 'bert' else encoder
-
+    encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
     tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config)
 
     # Set up model training/evaluation
@@ -169,88 +114,107 @@ def main(args=None, config=None):
     training.set_seed(args)
     embeddings.set_seed(args)
 
+    transformers.utils.logging.set_verbosity_error()
     if args.ensemble_size > 1:
-        ensemble_utils.set_logger(logger, tb_writer)
-        ensemble.set_seed(args)
-        logger.info('Building %i resampled dataloaders each of size %i' % (args.ensemble_size,
-                                                                           args.data_sampling_size))
-        dataloaders = ensemble_utils.build_train_loaders(args, tokenizer, Dataset)
+        # Build all dataloaders
+        train_dataloader = unified_format.get_dataloader(args.dataset,
+                                                         'train',
+                                                         args.train_batch_size,
+                                                         tokenizer,
+                                                         args.max_dialogue_len,
+                                                         args.max_turn_len,
+                                                         train_ratio=args.dataset_train_ratio,
+                                                         seed=args.seed)
+        torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+        dev_dataloader = unified_format.get_dataloader(args.dataset,
+                                                       'validation',
+                                                       args.dev_batch_size,
+                                                       tokenizer,
+                                                       args.max_dialogue_len,
+                                                       args.max_turn_len,
+                                                       train_ratio=args.dataset_train_ratio,
+                                                       seed=args.seed)
+        torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+        test_dataloader = unified_format.get_dataloader(args.dataset,
+                                                        'test',
+                                                        args.test_batch_size,
+                                                        tokenizer,
+                                                        args.max_dialogue_len,
+                                                        args.max_turn_len,
+                                                        train_ratio=args.dataset_train_ratio,
+                                                        seed=args.seed)
+        torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+
+        embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology, 'train', args, tokenizer, encoder)
+        embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, 'dev', args, tokenizer, encoder)
+        embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, 'test', args, tokenizer, encoder)
+
+        setup_ensemble(OUTPUT_DIR, args.ensemble_size)
+
+        logger.info(f'Building {args.ensemble_size} resampled dataloaders each of size {args.data_sampling_size}.')
+        dataloaders = [unified_format.dataloader_sample_dialogues(deepcopy(train_dataloader), args.data_sampling_size)
+                       for _ in tqdm(range(args.ensemble_size))]
         logger.info('Dataloaders built.')
+
         for i, loader in enumerate(dataloaders):
-            path = os.path.join(OUTPUT_DIR, 'ensemble-%i' % i)
+            path = os.path.join(OUTPUT_DIR, 'ens-%i' % i)
             if not os.path.exists(path):
                 os.mkdir(path)
-            path = os.path.join(path, 'train.dataloader')
+            path = os.path.join(path, 'dataloaders', 'train.dataloader')
             torch.save(loader, path)
         logger.info('Dataloaders saved.')
 
-        train_slots = embeddings.get_slot_candidate_embeddings(
-            'train', args, tokenizer, encoder)
-        dev_slots = embeddings.get_slot_candidate_embeddings(
-            'dev', args, tokenizer, encoder)
-        test_slots = embeddings.get_slot_candidate_embeddings(
-            'test', args, tokenizer, encoder)
-
-        train_dataloader = Dataset.get_dataloader(
-            'train', args.train_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
-        torch.save(dev_dataloader, os.path.join(
-            OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
-        dev_dataloader = Dataset.get_dataloader(
-            'dev', args.dev_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
-        torch.save(dev_dataloader, os.path.join(
-            OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-        test_dataloader = Dataset.get_dataloader(
-            'test', args.test_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
-        torch.save(test_dataloader, os.path.join(
-            OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
         # Do not perform standard training after ensemble setup is created
         return 0
 
     # Perform tasks
     # TRAINING
     if args.do_train:
-        # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')):
-            train_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'train.db'))
-        else:
-            train_slots = embeddings.get_slot_candidate_embeddings(
-                'train', args, tokenizer, encoder)
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
-            dev_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'dev.db'))
-        else:
-            dev_slots = embeddings.get_slot_candidate_embeddings(
-                'dev', args, tokenizer, encoder)
-
-        exists = False
         if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')):
-            train_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
-            if train_dataloader.batch_size == args.train_batch_size:
-                exists = True
-        if not exists:
+            train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+            if train_dataloader.batch_size != args.train_batch_size:
+                train_dataloader = unified_format.change_batch_size(train_dataloader, args.train_batch_size)
+        else:
             if args.data_sampling_size <= 0:
                 args.data_sampling_size = None
-            train_dataloader = Dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len,
-                                                      config.max_turn_len, resampled_size=args.data_sampling_size)
-            torch.save(train_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+            train_dataloader = unified_format.get_dataloader(args.dataset,
+                                                             'train',
+                                                             args.train_batch_size,
+                                                             tokenizer,
+                                                             args.max_dialogue_len,
+                                                             config.max_turn_len,
+                                                             resampled_size=args.data_sampling_size,
+                                                             train_ratio=args.dataset_train_ratio,
+                                                             seed=args.seed)
+            torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+
+        # Get training batch loaders and ontology embeddings
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')):
+            train_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'train.db'))
+        else:
+            train_slots = embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology,
+                                                                   'train', args, tokenizer, encoder)
 
         # Get development set batch loaders= and ontology embeddings
         if args.do_eval:
-            exists = False
             if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
-                dev_dataloader = torch.load(os.path.join(
-                    OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-                if dev_dataloader.batch_size == args.dev_batch_size:
-                    exists = True
-            if not exists:
-                dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len,
-                                                        config.max_turn_len)
-                torch.save(dev_dataloader, os.path.join(
-                    OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+                dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+                if dev_dataloader.batch_size != args.dev_batch_size:
+                    dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size)
+            else:
+                dev_dataloader = unified_format.get_dataloader(args.dataset,
+                                                               'validation',
+                                                               args.dev_batch_size,
+                                                               tokenizer,
+                                                               args.max_dialogue_len,
+                                                               config.max_turn_len)
+                torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+
+            if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
+                dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
+            else:
+                dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
+                                                                     'dev', args, tokenizer, encoder)
         else:
             dev_dataloader = None
             dev_slots = None
@@ -259,94 +223,80 @@ def main(args=None, config=None):
         training.set_ontology_embeddings(model, train_slots)
 
         # TRAINING !!!!!!!!!!!!!!!!!!
-        training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots,
-                       embeddings=embeddings, tokenizer=tokenizer)
+        training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots)
 
         # Copy final best model to the output dir
         checkpoints = os.listdir(OUTPUT_DIR)
         checkpoints = [p for p in checkpoints if 'checkpoint' in p]
         checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints])
-        best_checkpoint = checkpoints[-1]
-        best_checkpoint = os.path.join(
-            OUTPUT_DIR, f'checkpoint-{best_checkpoint}')
-        copy(os.path.join(best_checkpoint, 'pytorch_model.bin'),
-             os.path.join(OUTPUT_DIR, 'pytorch_model.bin'))
-        copy(os.path.join(best_checkpoint, 'config.json'),
-             os.path.join(OUTPUT_DIR, 'config.json'))
+        best_checkpoint = os.path.join(OUTPUT_DIR, f'checkpoint-{checkpoints[-1]}')
+        copy(os.path.join(best_checkpoint, 'pytorch_model.bin'), os.path.join(OUTPUT_DIR, 'pytorch_model.bin'))
+        copy(os.path.join(best_checkpoint, 'config.json'), os.path.join(OUTPUT_DIR, 'config.json'))
 
         # Load best model for evaluation
-        model = SumbtModel.from_pretrained(OUTPUT_DIR)
+        model = SetSumbtModel.from_pretrained(OUTPUT_DIR)
         model = model.to(device)
 
     # Evaluation on the development set
     if args.do_eval:
-        # Get development set batch loaders= and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
-            dev_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'dev.db'))
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
+            dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+            if dev_dataloader.batch_size != args.dev_batch_size:
+                dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size)
         else:
-            dev_slots = embeddings.get_slot_candidate_embeddings(
-                'dev', args, tokenizer, encoder)
+            dev_dataloader = unified_format.get_dataloader(args.dataset,
+                                                           'validation',
+                                                           args.dev_batch_size,
+                                                           tokenizer,
+                                                           args.max_dialogue_len,
+                                                           config.max_turn_len)
+            torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
 
-        exists = False
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
-            dev_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-            if dev_dataloader.batch_size == args.dev_batch_size:
-                exists = True
-        if not exists:
-            dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len,
-                                                    config.max_turn_len)
-            torch.save(dev_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
+            dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
+        else:
+            dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
+                                                                 'dev', args, tokenizer, encoder)
 
         # Load model ontology
         training.set_ontology_embeddings(model, dev_slots)
 
         # EVALUATION
-        jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate(
-            args, model, device, dev_dataloader)
-        if req_f1:
-            logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f'
-                        % (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
-        else:
-            logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f'
-                        % (loss, jg_acc, sl_acc))
+        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss = training.evaluate(args, model, device, dev_dataloader)
+        training.log_info('dev', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
 
     # Evaluation on the test set
     if args.do_test:
-        # Get test set batch loaders= and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
+            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+            if test_dataloader.batch_size != args.test_batch_size:
+                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
         else:
-            test_slots = embeddings.get_slot_candidate_embeddings(
-                'test', args, tokenizer, encoder)
+            test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
+                                                            args.test_batch_size, tokenizer, args.max_dialogue_len,
+                                                            config.max_turn_len)
+            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
 
-        exists = False
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size == args.test_batch_size:
-                exists = True
-        if not exists:
-            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                     config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
+            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
+        else:
+            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
+                                                                  'test', args, tokenizer, encoder)
 
         # Load model ontology
         training.set_ontology_embeddings(model, test_slots)
 
         # TESTING
-        jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate(
-            args, model, device, test_dataloader)
-        if req_f1:
-            logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f'
-                        % (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
-        else:
-            logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f'
-                        % (loss, jg_acc, sl_acc))
+        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, output = training.evaluate(args, model, device, test_dataloader,
+                                                                                 return_eval_output=True)
+
+        if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
+            os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
+        writer = open(os.path.join(OUTPUT_DIR, 'predictions', 'test.json'), 'w')
+        json.dump(output, writer)
+        writer.close()
+
+        training.log_info('test', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
 
     tb_writer.close()
 
diff --git a/convlab/dst/setsumbt/loss/__init__.py b/convlab/dst/setsumbt/loss/__init__.py
new file mode 100644
index 00000000..475f7646
--- /dev/null
+++ b/convlab/dst/setsumbt/loss/__init__.py
@@ -0,0 +1,4 @@
+from convlab.dst.setsumbt.loss.bayesian_matching import BayesianMatchingLoss, BinaryBayesianMatchingLoss
+from convlab.dst.setsumbt.loss.kl_distillation import KLDistillationLoss, BinaryKLDistillationLoss
+from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss
+from convlab.dst.setsumbt.loss.endd_loss import RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss
diff --git a/convlab/dst/setsumbt/loss/bayesian.py b/convlab/dst/setsumbt/loss/bayesian.py
deleted file mode 100644
index e52d8d07..00000000
--- a/convlab/dst/setsumbt/loss/bayesian.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Bayesian Matching Activation and Loss Functions"""
-
-import torch
-from torch import digamma, lgamma
-from torch.nn import Module
-
-
-# Inverse Linear activation function
-def invlinear(x):
-    z = (1.0 / (1.0 - x)) * (x < 0)
-    z += (1.0 + x) * (x >= 0)
-    return z
-
-# Exponential activation function
-def exponential(x):
-    return torch.exp(x)
-
-
-# Dirichlet activation function for the model
-def dirichlet(a):
-    p = exponential(a)
-    repeat_dim = (1,)*(len(p.shape)-1) + (p.size(-1),)
-    p = p / p.sum(-1).unsqueeze(-1).repeat(repeat_dim)
-    return p
-
-
-# Pytorch BayesianMatchingLoss nn.Module
-class BayesianMatchingLoss(Module):
-
-    def __init__(self, lamb=0.01, ignore_index=-1):
-        super(BayesianMatchingLoss, self).__init__()
-
-        self.lamb = lamb
-        self.ignore_index = ignore_index
-    
-    def forward(self, alpha, labels, prior=None):
-        # Assert input sizes
-        assert alpha.dim() == 2                 # Observations, predictive distribution
-        assert labels.dim() == 1                # Label for each observation
-        assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-        # Confirm predictive distribution dimension
-        if labels.max() <= alpha.size(-1):
-            dimension = alpha.size(-1)
-        else:
-            raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.max(), alpha.size(-1)))
-        
-        # Remove observations with no labels
-        if prior is not None:
-            prior = prior[labels != self.ignore_index]
-        alpha = exponential(alpha[labels != self.ignore_index])
-        labels = labels[labels != self.ignore_index]
-        
-        # Initialise and reshape prior parameters
-        if prior is None:
-            prior = torch.ones(dimension)
-        prior = prior.to(alpha.device)
-
-        # KL divergence term
-        lb = lgamma(alpha.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(alpha)).sum(-1)
-        e = digamma(alpha) - digamma(alpha.sum(-1)).unsqueeze(-1).repeat((1, alpha.size(-1)))
-        e = ((alpha - prior) * e).sum(-1)
-        kl = lb + e
-        kl *= self.lamb
-        del lb, e, prior
-
-        # Expected log likelihood
-        expected_likelihood = digamma(alpha[range(labels.size(0)), labels]) - digamma(alpha.sum(1))
-        del alpha, labels
-
-        # Apply ELBO loss and mean reduction
-        loss = (kl - expected_likelihood).mean()
-        del kl, expected_likelihood
-
-        return loss
-
-
-# Pytorch BayesianMatchingLoss nn.Module
-class BinaryBayesianMatchingLoss(Module):
-
-    def __init__(self, lamb=0.01, ignore_index=-1):
-        super(BinaryBayesianMatchingLoss, self).__init__()
-
-        self.lamb = lamb
-        self.ignore_index = ignore_index
-    
-    def forward(self, alpha, labels, prior=None):
-        # Assert input sizes
-        assert alpha.dim() == 1                 # Observations, predictive distribution
-        assert labels.dim() == 1                # Label for each observation
-        assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-        # Confirm predictive distribution dimension
-        if labels.max() <= 2:
-            dimension = 2
-        else:
-            raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.max(), alpha.size(-1)))
-        
-        # Remove observations with no labels
-        if prior is not None:
-            prior = prior[labels != self.ignore_index]
-        alpha = alpha[labels != self.ignore_index]
-        alpha_sum = 1 + (1 / self.lamb)
-        alpha = (torch.sigmoid(alpha) * alpha_sum).reshape(-1, 1)
-        alpha = torch.cat((alpha_sum - alpha, alpha), 1)
-        labels = labels[labels != self.ignore_index]
-        
-        # Initialise and reshape prior parameters
-        if prior is None:
-            prior = torch.ones(dimension)
-        prior = prior.to(alpha.device)
-
-        # KL divergence term
-        lb = lgamma(alpha.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(alpha)).sum(-1)
-        e = digamma(alpha) - digamma(alpha.sum(-1)).unsqueeze(-1).repeat((1, alpha.size(-1)))
-        e = ((alpha - prior) * e).sum(-1)
-        kl = lb + e
-        kl *= self.lamb
-        del lb, e, prior
-
-        # Expected log likelihood
-        expected_likelihood = digamma(alpha[range(labels.size(0)), labels.long()]) - digamma(alpha.sum(1))
-        del alpha, labels
-
-        # Apply ELBO loss and mean reduction
-        loss = (kl - expected_likelihood).mean()
-        del kl, expected_likelihood
-
-        return loss
diff --git a/convlab/dst/setsumbt/loss/bayesian_matching.py b/convlab/dst/setsumbt/loss/bayesian_matching.py
new file mode 100644
index 00000000..3e91444d
--- /dev/null
+++ b/convlab/dst/setsumbt/loss/bayesian_matching.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Bayesian Matching Activation and Loss Functions (see https://arxiv.org/pdf/2002.07965.pdf for details)"""
+
+import torch
+from torch import digamma, lgamma
+from torch.nn import Module
+
+
+class BayesianMatchingLoss(Module):
+    """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation"""
+
+    def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module:
+        """
+        Args:
+            lamb (float): Weighting factor for the KL Divergence component
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
+        super(BayesianMatchingLoss, self).__init__()
+
+        self.lamb = lamb
+        self.ignore_index = ignore_index
+    
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor:
+        """
+        Args:
+            inputs (Tensor): Predictive distribution
+            labels (Tensor): Label indices
+            prior (Tensor): Prior distribution over label classes
+
+        Returns:
+            loss (Tensor): Loss value
+        """
+        # Assert input sizes
+        assert inputs.dim() == 2                 # Observations, predictive distribution
+        assert labels.dim() == 1                # Label for each observation
+        assert labels.size(0) == inputs.size(0)  # Equal number of observation
+
+        # Confirm predictive distribution dimension
+        if labels.max() <= inputs.size(-1):
+            dimension = inputs.size(-1)
+        else:
+            raise NameError(f'Label dimension {labels.max()} is larger than prediction dimension {inputs.size(-1)}.')
+        
+        # Remove observations to be ignored in loss calculation
+        if prior is not None:
+            prior = prior[labels != self.ignore_index]
+        inputs = torch.exp(inputs[labels != self.ignore_index])
+        labels = labels[labels != self.ignore_index]
+        
+        # Initialise and reshape prior parameters
+        if prior is None:
+            prior = torch.ones(dimension).to(inputs.device)
+        prior = prior.to(inputs.device)
+
+        # KL divergence term (divergence of predictive distribution from prior over label classes - regularisation term)
+        log_gamma_term = lgamma(inputs.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(inputs)).sum(-1)
+        div_term = digamma(inputs) - digamma(inputs.sum(-1)).unsqueeze(-1).repeat((1, inputs.size(-1)))
+        div_term = ((inputs - prior) * div_term).sum(-1)
+        kl_term = log_gamma_term + div_term
+        kl_term *= self.lamb
+        del log_gamma_term, div_term, prior
+
+        # Expected log likelihood
+        expected_likelihood = digamma(inputs[range(labels.size(0)), labels]) - digamma(inputs.sum(-1))
+        del inputs, labels
+
+        # Apply ELBO loss and mean reduction
+        loss = (kl_term - expected_likelihood).mean()
+        del kl_term, expected_likelihood
+
+        return loss
+
+
+class BinaryBayesianMatchingLoss(BayesianMatchingLoss):
+    """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation"""
+
+    def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module:
+        """
+        Args:
+            lamb (float): Weighting factor for the KL Divergence component
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
+        super(BinaryBayesianMatchingLoss, self).__init__(lamb, ignore_index)
+
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor:
+        """
+        Args:
+            inputs (Tensor): Predictive distribution
+            labels (Tensor): Label indices
+            prior (Tensor): Prior distribution over label classes
+
+        Returns:
+            loss (Tensor): Loss value
+        """
+        
+        # Create 2D input dirichlet distribution
+        input_sum = 1 + (1 / self.lamb)
+        inputs = (torch.sigmoid(inputs) * input_sum).reshape(-1, 1)
+        inputs = torch.cat((input_sum - inputs, inputs), 1)
+
+        return super().forward(inputs, labels, prior=prior)
diff --git a/convlab/dst/setsumbt/loss/distillation.py b/convlab/dst/setsumbt/loss/distillation.py
deleted file mode 100644
index 3cf13f10..00000000
--- a/convlab/dst/setsumbt/loss/distillation.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Bayesian Matching Activation and Loss Functions"""
-
-import torch
-from torch import lgamma, log
-from torch.nn import Module
-from torch.nn.functional import kl_div
-
-from convlab.dst.setsumbt.loss.bayesian import BayesianMatchingLoss
-
-
-# Pytorch BayesianMatchingLoss nn.Module
-class DistillationKL(Module):
-
-    def __init__(self, lamb=1e-4, ignore_index=-1):
-        super(DistillationKL, self).__init__()
-
-        self.lamb = lamb
-        self.ignore_index = ignore_index
-    
-    def forward(self, alpha, labels, temp=1.0):
-        # Assert input sizes
-        assert alpha.dim() == 2                 # Observations, predictive distribution
-        assert labels.dim() == 2                # Label for each observation
-        assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-        # Confirm predictive distribution dimension
-        if labels.size(-1) == alpha.size(-1):
-            dimension = alpha.size(-1)
-        else:
-            raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1)))
-        
-        alpha = torch.log(torch.softmax(alpha / temp, -1))
-        ids = torch.where(labels[:, 0] != self.ignore_index)[0]
-        alpha = alpha[ids]
-        labels = labels[ids]
-
-        labels = ((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1)))
-
-        kl = kl_div(alpha, labels, reduction='none').sum(-1).mean()
-        return kl    
-
-
-# Pytorch BayesianMatchingLoss nn.Module
-class BinaryDistillationKL(Module):
-
-    def __init__(self, lamb=1e-4, ignore_index=-1):
-        super(BinaryDistillationKL, self).__init__()
-
-        self.lamb = lamb
-        self.ignore_index = ignore_index
-    
-    def forward(self, alpha, labels, temp=0.0):
-        # Assert input sizes
-        assert alpha.dim() == 1                 # Observations, predictive distribution
-        assert labels.dim() == 1                # Label for each observation
-        assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-        # Confirm predictive distribution dimension
-        # if labels.size(-1) == alpha.size(-1):
-        #     dimension = alpha.size(-1)
-        # else:
-        #     raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1)))
-        
-        alpha = torch.sigmoid(alpha / temp).unsqueeze(-1)
-        ids = torch.where(labels != self.ignore_index)[0]
-        alpha = alpha[ids]
-        labels = labels[ids]
-
-        alpha = torch.log(torch.cat((1 - alpha, alpha), 1))
-        
-        labels = labels.unsqueeze(-1)
-        labels = torch.cat((1 - labels, labels), -1)
-        labels = ((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1)))
-
-        kl = kl_div(alpha, labels, reduction='none').sum(-1).mean()
-        return kl  
-
-
-# def smart_sort(x, permutation):
-#     assert x.dim() == permutation.dim()
-#     if x.dim() == 3:
-#         d1, d2, d3 = x.size()
-#         ret = x[torch.arange(d1).unsqueeze(-1).unsqueeze(-1).repeat((1, d2, d3)).flatten(),
-#                 torch.arange(d2).unsqueeze(0).unsqueeze(-1).repeat((d1, 1, d3)).flatten(),
-#                 permutation.flatten()].view(d1, d2, d3)
-#         return ret
-#     elif x.dim() == 2:
-#         d1, d2 = x.size()
-#         ret = x[torch.arange(d1).unsqueeze(-1).repeat((1, d2)).flatten(),
-#                 permutation.flatten()].view(d1, d2)
-#         return ret
-
-
-# # Pytorch BayesianMatchingLoss nn.Module
-# class DistillationNLL(Module):
-
-#     def __init__(self, lamb=1e-4, ignore_index=-1):
-#         super(DistillationNLL, self).__init__()
-
-#         self.lamb = lamb
-#         self.ignore_index = ignore_index
-#         self.loss_add = BayesianMatchingLoss(lamb=0.001)
-    
-#     def forward(self, alpha, labels, temp=1.0):
-#         # Assert input sizes
-#         assert alpha.dim() == 2                 # Observations, predictive distribution
-#         assert labels.dim() == 3                # Label for each observation
-#         assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-#         # Confirm predictive distribution dimension
-#         if labels.size(-1) == alpha.size(-1):
-#             dimension = alpha.size(-1)
-#         else:
-#             raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1)))
-        
-#         alpha = torch.exp(alpha / temp)
-#         ids = torch.where(labels[:, 0, 0] != self.ignore_index)[0]
-#         alpha = alpha[ids]
-#         labels = labels[ids]
-
-#         best_labels = labels.mean(-2).argmax(-1)
-#         loss2 = self.loss_add(alpha, best_labels)
-
-#         topn = labels.mean(-2).argsort(-1, descending=True)
-#         n = 10
-#         alpha = smart_sort(alpha, topn)[:, :n]
-#         labels = smart_sort(labels, topn.unsqueeze(-2).repeat((1, labels.size(-2), 1)))
-#         labels = labels[:, :, :n]
-#         labels = labels / labels.sum(-1).unsqueeze(-1).repeat((1, 1, labels.size(-1)))
-
-#         labels = log(((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1))))
-
-#         loss = (alpha - 1) * labels.mean(-2)
-#         # loss = (alpha - 1) * labels
-#         loss = lgamma(alpha.sum(-1)) - lgamma(alpha).sum(-1) + loss.sum(-1) 
-#         loss = -1.0 * loss.mean()
-#         # loss = -1.0 * loss.mean() / alpha.size(-1)
-
-#         return loss      
-
-
-# # Pytorch BayesianMatchingLoss nn.Module
-# class BinaryDistillationNLL(Module):
-
-#     def __init__(self, lamb=1e-4, ignore_index=-1):
-#         super(BinaryDistillationNLL, self).__init__()
-
-#         self.lamb = lamb
-#         self.ignore_index = ignore_index
-    
-#     def forward(self, alpha, labels, temp=0.0):
-#         # Assert input sizes
-#         assert alpha.dim() == 1                 # Observations, predictive distribution
-#         assert labels.dim() == 2                # Label for each observation
-#         assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-#         # Confirm predictive distribution dimension
-#         # if labels.size(-1) == alpha.size(-1):
-#         #     dimension = alpha.size(-1)
-#         # else:
-#         #     raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1)))
-        
-#         # Remove observations with no labels
-#         ids = torch.where(labels[:, 0] != self.ignore_index)[0]
-#         # alpha_sum = 1 + (1 / self.lamb)
-#         alpha_sum = 10.0
-#         alpha = (torch.sigmoid(alpha) * alpha_sum).reshape(-1, 1)
-#         alpha = alpha[ids]
-#         labels = labels[ids]
-
-#         if temp != 1.0:
-#             alpha = torch.log(alpha + 1e-4)
-#             alpha = torch.exp(alpha / temp)
-
-#         alpha = torch.cat((alpha_sum - alpha, alpha), 1)
-        
-#         labels = labels.unsqueeze(-1)
-#         labels = torch.cat((1 - labels, labels), -1)
-#         # labels[labels[:, 0, 0] == self.ignore_index] = 1
-#         labels = log(((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1))))
-
-#         loss = (alpha - 1) * labels.mean(-2)
-#         loss = lgamma(alpha.sum(-1)) - lgamma(alpha).sum(-1) + loss.sum(-1)
-#         loss = -1.0 * loss.mean()
-
-#         return loss    
diff --git a/convlab/dst/setsumbt/loss/endd_loss.py b/convlab/dst/setsumbt/loss/endd_loss.py
index d84c3f72..9bd794bf 100644
--- a/convlab/dst/setsumbt/loss/endd_loss.py
+++ b/convlab/dst/setsumbt/loss/endd_loss.py
@@ -1,30 +1,46 @@
 import torch
+from torch.nn import Module
+from torch.nn.functional import kl_div
 
 EPS = torch.finfo(torch.float32).eps
 
+
 @torch.no_grad()
-def compute_mkl(ensemble_probs, ensemble_mean_probs, ensemble_logprobs):
-    mkl = torch.nn.functional.kl_div(ensemble_logprobs, ensemble_mean_probs.unsqueeze(1).expand_as(ensemble_probs),
-                                     reduction='none').sum(-1).mean(1)
-    return mkl
+def compute_mkl(ensemble_mean_probs: torch.Tensor, ensemble_logprobs: torch.Tensor) -> torch.Tensor:
+    """
+    Computing MKL in ensemble.
+
+    Args:
+        ensemble_mean_probs (Tensor): Marginal predictive distribution of the ensemble
+        ensemble_logprobs (Tensor): Log predictive distributions of individual ensemble members
+
+    Returns:
+        mkl (Tensor): MKL
+    """
+    mkl = kl_div(ensemble_logprobs, ensemble_mean_probs.unsqueeze(1).expand_as(ensemble_logprobs),reduction='none')
+    return mkl.sum(-1).mean(1)
+
 
 @torch.no_grad()
-def compute_ensemble_stats(ensemble_logits):
-    # ensemble_probs = torch.softmax(ensemble_logits, dim=-1)
-    # ensemble_mean_probs = ensemble_probs.mean(dim=1)
-    # ensemble_logprobs = torch.log_softmax(ensemble_logits, dim=-1)
-    ensemble_probs = ensemble_logits
+def compute_ensemble_stats(ensemble_probs: torch.Tensor) -> dict:
+    """
+    Compute a range of ensemble uncertainty measures
+
+    Args:
+        ensemble_probs (Tensor): Predictive distributions of the ensemble members
+
+    Returns:
+        stats (dict): Dictionary of ensemble uncertainty measures
+    """
     ensemble_mean_probs = ensemble_probs.mean(dim=1)
-    num_classes = ensemble_logits.size(-1)
-    ensemble_logprobs = torch.log(ensemble_logits + (1e-4 / num_classes))
+    num_classes = ensemble_probs.size(-1)
+    ensemble_logprobs = torch.log(ensemble_probs + (1e-4 / num_classes))
 
     entropy_of_expected = torch.distributions.Categorical(probs=ensemble_mean_probs).entropy()
     expected_entropy = torch.distributions.Categorical(probs=ensemble_probs).entropy().mean(dim=1)
     mutual_info = entropy_of_expected - expected_entropy
 
-    mkl = compute_mkl(ensemble_probs, ensemble_mean_probs, ensemble_logprobs)
-
-    # num_classes = ensemble_logits.size(-1)
+    mkl = compute_mkl(ensemble_mean_probs, ensemble_logprobs)
 
     ensemble_precision = (num_classes - 1) / (2 * mkl.unsqueeze(1) + EPS)
 
@@ -39,108 +55,226 @@ def compute_ensemble_stats(ensemble_logits):
     }
     return stats
 
-def entropy(probs, dim: int = -1):
+
+def entropy(probs: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    """
+    Compute entropy in a predictive distribution
+
+    Args:
+        probs (Tensor): Predictive distributions
+        dim (int): Dimension representing the predictive probabilities for a single prediction
+
+    Returns:
+        entropy (Tensor): Entropy
+    """
     return -(probs * (probs + EPS).log()).sum(dim=dim)
 
 
-def compute_dirichlet_uncertainties(dirichlet_params, precisions, expected_dirichlet):
+def compute_dirichlet_uncertainties(dirichlet_params: torch.Tensor,
+                                    precisions: torch.Tensor,
+                                    expected_dirichlet: torch.Tensor) -> tuple:
     """
     Function which computes measures of uncertainty for Dirichlet model.
-    :param dirichlet_params:  Tensor of size [batch_size, n_classes] of Dirichlet concentration parameters.
-    :param precisions: Tensor of size [batch_size, 1] of Dirichlet Precisions
-    :param expected_dirichlet: Tensor of size [batch_size, n_classes] of probablities of expected categorical under Dirichlet.
-    :return: Tensors of token level uncertainties of size [batch_size]
+
+    Args:
+        dirichlet_params (Tensor): Dirichlet concentration parameters.
+        precisions (Tensor): Dirichlet Precisions
+        expected_dirichlet (Tensor): Probabities of expected categorical under Dirichlet.
+
+    Returns:
+        stats (tuple): Token level uncertainties
     """
     batch_size, n_classes = dirichlet_params.size()
 
     entropy_of_expected = entropy(expected_dirichlet)
 
-    expected_entropy = (
-            -expected_dirichlet * (torch.digamma(dirichlet_params + 1) - torch.digamma(precisions + 1))).sum(dim=-1)
+    expected_entropy = -expected_dirichlet * (torch.digamma(dirichlet_params + 1) - torch.digamma(precisions + 1))
+    expected_entropy = expected_entropy.sum(dim=-1)
 
-    mutual_information = -((expected_dirichlet + EPS) * (
-            torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + 1 + EPS) + torch.digamma(
-        precisions + 1 + EPS))).sum(dim=-1)
-    # assert torch.allclose(mutual_information, entropy_of_expected - expected_entropy, atol=1e-4, rtol=0)
+    mutual_information = torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + 1 + EPS)
+    mutual_information += torch.digamma(precisions + 1 + EPS)
+    mutual_information *= -(expected_dirichlet + EPS)
+    mutual_information = mutual_information.sum(dim=-1)
 
     epkl = (n_classes - 1) / precisions.squeeze(-1)
 
-    mkl = (expected_dirichlet * (
-            torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + EPS) + torch.digamma(
-        precisions + EPS))).sum(dim=-1)
+    mkl = torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + EPS)
+    mkl += torch.digamma(precisions + EPS)
+    mkl *= expected_dirichlet
+    mkl = mkl.sum(dim=-1)
+
+    stats = (entropy_of_expected.clamp(min=0), expected_entropy.clamp(min=0), mutual_information.clamp(min=0))
+    stats += (epkl.clamp(min=0), mkl.clamp(min=0))
+
+    return stats
+
+
+def get_dirichlet_parameters(logits: torch.Tensor,
+                             parametrization,
+                             add_to_alphas: float = 0,
+                             dtype=torch.double) -> tuple:
+    """
+    Get dirichlet parameters from model logits
 
-    return entropy_of_expected.clamp(min=0), \
-           expected_entropy.clamp(min=0), \
-           mutual_information.clamp(min=0), \
-           epkl.clamp(min=0), \
-           mkl.clamp(min=0)
+    Args:
+        logits (Tensor): Model logits
+        parametrization (function): Mapping from logits to concentration parameters
+        add_to_alphas (float): Addition constant for stability
+        dtype (data type): Data type of the parameters
 
-def get_dirichlet_parameters(logits, parametrization, add_to_alphas=0, dtype=torch.double):
+    Return:
+        params (tuple): Concentration and precision parameters of the model Dirichlet
+    """
     max_val = torch.finfo(dtype).max / logits.size(-1) - 1
     alphas = torch.clip(parametrization(logits.to(dtype=dtype)) + add_to_alphas, max=max_val)
     precision = torch.sum(alphas, dim=-1, dtype=dtype)
     return alphas, precision
 
 
-def logits_to_mutual_info(logits):
-    alphas, precision = get_dirichlet_parameters(logits, torch.exp, 1.0)
+def logits_to_mutual_info(logits: torch.Tensor) -> torch.Tensor:
+    """
+    Map modfel logits to mutual information of model Dirichlet
 
-    unsqueezed_precision = precision.unsqueeze(1)
-    normalized_probs = alphas / unsqueezed_precision
+    Args:
+        logits (Tensor): Model logits
 
-    entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = compute_dirichlet_uncertainties(alphas,
-                                                                                                           unsqueezed_precision,
-                                                                                                           normalized_probs)
-    
-    # Max entropy is log(K) for K classes. Hence relative MI is calculated as MI/log(K)
-    # mutual_information /= torch.log(torch.tensor(logits.size(-1)))
-    
-    return mutual_information
+    Returns:
+        mutual_information (Tensor): Mutual information of the model Dirichlet
+    """
+    alphas, precision = get_dirichlet_parameters(logits, torch.exp, 1.0)
 
+    normalized_probs = alphas / precision.unsqueeze(1)
 
-def rkl_dirichlet_mediator_loss(logits, ensemble_stats, model_offset, target_offset, parametrization=torch.exp):
-    turns = torch.where(ensemble_stats[:, 0, 0] != -1)[0]
-    logits = logits[turns]
-    ensemble_stats = ensemble_stats[turns]
+    _, _, mutual_information, _, _ = compute_dirichlet_uncertainties(alphas, precision.unsqueeze(1), normalized_probs)
     
-    ensemble_stats = compute_ensemble_stats(ensemble_stats)
-
-    alphas, precision = get_dirichlet_parameters(logits, parametrization, model_offset)
-
-    unsqueezed_precision = precision.unsqueeze(1)
-    normalized_probs = alphas / unsqueezed_precision
-
-    entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = compute_dirichlet_uncertainties(alphas,
-                                                                                                           unsqueezed_precision,
-                                                                                                           normalized_probs)
-
-    stats = {
-        'alpha_min': alphas.min(),
-        'alpha_mean': alphas.mean(),
-        'precision': precision,
-        'entropy_of_expected': entropy_of_expected,
-        'mutual_info': mutual_information,
-        'mkl': mkl,
-    }
-
-    num_classes = alphas.size(-1)
-
-    ensemble_precision = ensemble_stats['precision']
-
-    ensemble_precision += target_offset * num_classes
-    ensemble_probs = ensemble_stats['mean_probs']
-
-    expected_KL_term = -1.0 * torch.sum(ensemble_probs * (torch.digamma(alphas + EPS)
-                                                          - torch.digamma(precision.unsqueeze(-1) + EPS)), dim=-1)
-    assert torch.isfinite(expected_KL_term).all(), (torch.max(alphas), torch.max(precision), alphas.dtype)
-
-    differential_negentropy_term = torch.sum(torch.lgamma(alphas + EPS), dim=-1) - torch.lgamma(precision + EPS) \
-                                   - torch.sum(
-        (alphas - 1) * (torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS)), dim=-1)
-    assert torch.isfinite(differential_negentropy_term).all()
-
-    cost = expected_KL_term - differential_negentropy_term / ensemble_precision.squeeze(-1)
+    return mutual_information
 
-    assert torch.isfinite(cost).all()
-    return torch.mean(cost), stats, ensemble_stats
 
+class RKLDirichletMediatorLoss(Module):
+    """Reverse KL Dirichlet Mediator Loss (https://arxiv.org/abs/2105.06987)"""
+
+    def __init__(self,
+                 model_offset: float = 1.0,
+                 target_offset: float = 1,
+                 ignore_index: int = -1,
+                 parameterization=torch.exp):
+        """
+        Args:
+            model_offset (float): Offset of model Dirichlet for stability
+            target_offset (float): Offset of target Dirichlet for stability
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+            parameterization (function): Mapping from logits to concentration parameters
+        """
+        super(RKLDirichletMediatorLoss, self).__init__()
+
+        self.model_offset = model_offset
+        self.target_offset = target_offset
+        self.ignore_index = ignore_index
+        self.parameterization = parameterization
+
+    def logits_to_mutual_info(self, logits: torch.Tensor) -> torch.Tensor:
+        """
+        Map modfel logits to mutual information of model Dirichlet
+
+        Args:
+            logits (Tensor): Model logits
+
+        Returns:
+            mutual_information (Tensor): Mutual information of the model Dirichlet
+        """
+        return logits_to_mutual_info(logits)
+
+    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            logits (Tensor): Model logits
+            targets (Tensor): Ensemble predictive distributions
+
+        Returns:
+            loss (Tensor): RKL dirichlet mediator loss value
+        """
+
+        # Remove padding
+        turns = torch.where(targets[:, 0, 0] != self.ignore_index)[0]
+        logits = logits[turns]
+        targets = targets[turns]
+
+        ensemble_stats = compute_ensemble_stats(targets)
+
+        alphas, precision = get_dirichlet_parameters(logits, self.parameterization, self.model_offset)
+
+        normalized_probs = alphas / precision.unsqueeze(1)
+
+        stats = compute_dirichlet_uncertainties(alphas, precision.unsqueeze(1), normalized_probs)
+        entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = stats
+
+        stats = {
+            'alpha_min': alphas.min(),
+            'alpha_mean': alphas.mean(),
+            'precision': precision,
+            'entropy_of_expected': entropy_of_expected,
+            'mutual_info': mutual_information,
+            'mkl': mkl,
+        }
+
+        num_classes = alphas.size(-1)
+
+        ensemble_precision = ensemble_stats['precision']
+
+        ensemble_precision += self.target_offset * num_classes
+        ensemble_probs = ensemble_stats['mean_probs']
+
+        expected_kl_term = torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS)
+        expected_kl_term = -1.0 * torch.sum(ensemble_probs * expected_kl_term, dim=-1)
+        assert torch.isfinite(expected_kl_term).all(), (torch.max(alphas), torch.max(precision), alphas.dtype)
+
+        differential_negentropy_term_ = torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS)
+        differential_negentropy_term_ *= alphas - 1.0
+        differential_negentropy_term = torch.sum(torch.lgamma(alphas + EPS), dim=-1) - torch.lgamma(precision + EPS)
+        differential_negentropy_term -= torch.sum(differential_negentropy_term_, dim=-1)
+        assert torch.isfinite(differential_negentropy_term).all()
+
+        loss = expected_kl_term - differential_negentropy_term / ensemble_precision.squeeze(-1)
+        assert torch.isfinite(loss).all()
+
+        return torch.mean(loss), stats, ensemble_stats
+
+
+class BinaryRKLDirichletMediatorLoss(RKLDirichletMediatorLoss):
+    """Reverse KL Dirichlet Mediator Loss (https://arxiv.org/abs/2105.06987)"""
+
+    def __init__(self,
+                 model_offset: float = 1.0,
+                 target_offset: float = 1,
+                 ignore_index: int = -1,
+                 parameterization=torch.exp):
+        """
+        Args:
+            model_offset (float): Offset of model Dirichlet for stability
+            target_offset (float): Offset of target Dirichlet for stability
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+            parameterization (function): Mapping from logits to concentration parameters
+        """
+        super(BinaryRKLDirichletMediatorLoss, self).__init__(model_offset, target_offset,
+                                                             ignore_index, parameterization)
+
+    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            logits (Tensor): Model logits
+            targets (Tensor): Ensemble predictive distributions
+
+        Returns:
+            loss (Tensor): RKL dirichlet mediator loss value
+        """
+        # Convert single target probability p to distribution [1-p, p]
+        targets = targets.reshape(-1, targets.size(-1), 1)
+        targets = torch.cat([1 - targets, targets], -1)
+        targets[targets[:, 0, 1] == self.ignore_index] = self.ignore_index
+
+        # Convert input logits into predictive distribution [1-z, z]
+        logits = torch.sigmoid(logits).unsqueeze(1)
+        logits = torch.cat((1 - logits, logits), 1)
+        logits = -1.0 * torch.log((1 / (logits + 1e-8)) - 1)  # Inverse sigmoid
+
+        return super().forward(logits, targets)
diff --git a/convlab/dst/setsumbt/loss/kl_distillation.py b/convlab/dst/setsumbt/loss/kl_distillation.py
new file mode 100644
index 00000000..9aee234a
--- /dev/null
+++ b/convlab/dst/setsumbt/loss/kl_distillation.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""KL Divergence Ensemble Distillation loss"""
+
+import torch
+from torch.nn import Module
+from torch.nn.functional import kl_div
+
+
+class KLDistillationLoss(Module):
+    """Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation"""
+
+    def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module:
+        """
+        Args:
+            lamb (float): Target smoothing parameter
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
+        super(KLDistillationLoss, self).__init__()
+
+        self.lamb = lamb
+        self.ignore_index = ignore_index
+    
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
+        """
+        Args:
+            inputs (Tensor): Predictive distribution
+            targets (Tensor): Target distribution (ensemble marginal)
+            temp (float): Temperature scaling coefficient for predictive distribution
+
+        Returns:
+            loss (Tensor): Loss value
+        """
+        # Assert input sizes
+        assert inputs.dim() == 2                  # Observations, predictive distribution
+        assert targets.dim() == 2                # Label for each observation
+        assert targets.size(0) == inputs.size(0)  # Equal number of observation
+
+        # Confirm predictive distribution dimension
+        if targets.size(-1) != inputs.size(-1):
+            name_error = f'Target dimension {targets.size(-1)} is not the same as the prediction dimension '
+            name_error += f'{inputs.size(-1)}.'
+            raise NameError(name_error)
+
+        # Remove observations to be ignored in loss calculation
+        inputs = torch.log(torch.softmax(inputs / temp, -1))
+        ids = torch.where(targets[:, 0] != self.ignore_index)[0]
+        inputs = inputs[ids]
+        targets = targets[ids]
+
+        # Target smoothing
+        targets = ((1 - self.lamb) * targets) + (self.lamb / targets.size(-1))
+
+        return kl_div(inputs, targets, reduction='none').sum(-1).mean()
+
+
+# Pytorch BayesianMatchingLoss nn.Module
+class BinaryKLDistillationLoss(KLDistillationLoss):
+    """Binary Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation"""
+
+    def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module:
+        """
+        Args:
+            lamb (float): Target smoothing parameter
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
+        super(BinaryKLDistillationLoss, self).__init__(lamb, ignore_index)
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
+        """
+        Args:
+            inputs (Tensor): Predictive distribution
+            targets (Tensor): Target distribution (ensemble marginal)
+            temp (float): Temperature scaling coefficient for predictive distribution
+
+        Returns:
+            loss (Tensor): Loss value
+        """
+        # Assert input sizes
+        assert inputs.dim() == 1                 # Observations, predictive distribution
+        assert targets.dim() == 1                # Label for each observation
+        assert targets.size(0) == inputs.size(0)  # Equal number of observation
+        
+        # Convert input and target to 2D binary distribution for KL divergence computation
+        inputs = torch.sigmoid(inputs / temp).unsqueeze(-1)
+        inputs = torch.log(torch.cat((1 - inputs, inputs), 1))
+
+        targets = targets.unsqueeze(-1)
+        targets = torch.cat((1 - targets, targets), -1)
+
+        return super().forward(input, targets, temp)
diff --git a/convlab/dst/setsumbt/loss/labelsmoothing.py b/convlab/dst/setsumbt/loss/labelsmoothing.py
index 8fcc60af..61d4b353 100644
--- a/convlab/dst/setsumbt/loss/labelsmoothing.py
+++ b/convlab/dst/setsumbt/loss/labelsmoothing.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inhibited Softmax Activation and Loss Functions"""
+"""Label smoothing loss function"""
 
 
 import torch
@@ -23,66 +23,97 @@ from torch.nn.functional import kl_div
 
 class LabelSmoothingLoss(Module):
     """
-    With label smoothing,
-    KL-divergence between q_{smoothed ground truth prob.}(w)
-    and p_{prob. computed by model}(w) is minimized.
+    Label smoothing loss minimises the KL-divergence between q_{smoothed ground truth prob}(w)
+    and p_{prob. computed by model}(w).
     """
-    def __init__(self, label_smoothing=0.05, ignore_index=-1):
+
+    def __init__(self, label_smoothing: float = 0.05, ignore_index: int = -1) -> Module:
+        """
+        Args:
+            label_smoothing (float): Label smoothing constant
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
         super(LabelSmoothingLoss, self).__init__()
 
         assert 0.0 < label_smoothing <= 1.0
         self.ignore_index = ignore_index
         self.label_smoothing = float(label_smoothing)
 
-    def forward(self, logits, targets):
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         """
-        output (FloatTensor): batch_size x n_classes
-        target (LongTensor): batch_size
+        Args:
+            input (Tensor): Predictive distribution
+            labels (Tensor): Label indices
+
+        Returns:
+            loss (Tensor): Loss value
         """
-        assert logits.dim() == 2
-        assert targets.dim() == 1
-        assert self.label_smoothing <= ((logits.size(-1) - 1) / logits.size(-1))
+        # Assert input sizes
+        assert inputs.dim() == 2
+        assert labels.dim() == 1
+        assert self.label_smoothing <= ((inputs.size(-1) - 1) / inputs.size(-1))
 
-        logits = logits[targets != self.ignore_index]
-        targets = targets[targets != self.ignore_index]
+        # Confirm predictive distribution dimension
+        if labels.max() <= inputs.size(-1):
+            dimension = inputs.size(-1)
+        else:
+            raise NameError(f'Label dimension {labels.max()} is larger than prediction dimension {inputs.size(-1)}.')
 
-        logits = torch.log(torch.softmax(logits, -1))
-        labels = torch.ones(logits.size()).float().to(logits.device)
-        labels *= self.label_smoothing / (logits.size(-1) - 1)
-        labels[range(labels.size(0)), targets] = 1.0 - self.label_smoothing
+        # Remove observations to be ignored in loss calculation
+        inputs = inputs[labels != self.ignore_index]
+        labels = labels[labels != self.ignore_index]
 
-        kl = kl_div(logits, labels, reduction='none').sum(-1).mean()
-        del logits, targets, labels
-        return kl
+        if labels.size(0) == 0.0:
+            return torch.zeros(1).float().to(labels.device).mean()
 
+        # Create target distribution
+        inputs = torch.log(torch.softmax(inputs, -1))
+        targets = torch.ones(inputs.size()).float().to(inputs.device)
+        targets *= self.label_smoothing / (dimension - 1)
+        targets[range(labels.size(0)), labels] = 1.0 - self.label_smoothing
 
-class BinaryLabelSmoothingLoss(Module):
+        return kl_div(inputs, targets, reduction='none').sum(-1).mean()
+
+
+class BinaryLabelSmoothingLoss(LabelSmoothingLoss):
     """
-    With label smoothing,
-    KL-divergence between q_{smoothed ground truth prob.}(w)
-    and p_{prob. computed by model}(w) is minimized.
+    Label smoothing loss minimises the KL-divergence between q_{smoothed ground truth prob}(w)
+    and p_{prob. computed by model}(w).
     """
-    def __init__(self, label_smoothing=0.05):
-        super(BinaryLabelSmoothingLoss, self).__init__()
 
-        assert 0.0 < label_smoothing <= 1.0
-        self.label_smoothing = float(label_smoothing)
+    def __init__(self, label_smoothing: float = 0.05, ignore_index: int = -1) -> Module:
+        """
+        Args:
+            label_smoothing (float): Label smoothing constant
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
+        super(BinaryLabelSmoothingLoss, self).__init__(label_smoothing, ignore_index)
 
-    def forward(self, logits, targets):
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         """
-        output (FloatTensor): batch_size x n_classes
-        target (LongTensor): batch_size
+        Args:
+            input (Tensor): Predictive distribution
+            labels (Tensor): Label indices
+
+        Returns:
+            loss (Tensor): Loss value
         """
-        assert logits.dim() == 1
-        assert targets.dim() == 1
+        # Assert input sizes
+        assert inputs.dim() == 1
+        assert labels.dim() == 1
         assert self.label_smoothing <= 0.5
 
-        logits = torch.sigmoid(logits).reshape(-1, 1)
-        logits = torch.log(torch.cat((1 - logits, logits), 1))
-        labels = torch.ones(logits.size()).float().to(logits.device)
-        labels *= self.label_smoothing
-        labels[range(labels.size(0)), targets.long()] = 1.0 - self.label_smoothing
+        # Remove observations to be ignored in loss calculation
+        inputs = inputs[labels != self.ignore_index]
+        labels = labels[labels != self.ignore_index]
+
+        if labels.size(0) == 0.0:
+            return torch.zeros(1).float().to(labels.device).mean()
+
+        inputs = torch.sigmoid(inputs).reshape(-1, 1)
+        inputs = torch.log(torch.cat((1 - inputs, inputs), 1))
+        targets = torch.ones(inputs.size()).float().to(inputs.device)
+        targets *= self.label_smoothing
+        targets[range(labels.size(0)), labels.long()] = 1.0 - self.label_smoothing
 
-        kl = kl_div(logits, labels, reduction='none').sum(-1).mean()
-        del logits, targets
-        return kl
+        return kl_div(inputs, targets, reduction='none').sum(-1).mean()
diff --git a/convlab/dst/setsumbt/loss/ece.py b/convlab/dst/setsumbt/loss/uncertainty_measures.py
similarity index 50%
rename from convlab/dst/setsumbt/loss/ece.py
rename to convlab/dst/setsumbt/loss/uncertainty_measures.py
index 034b9aa0..87c89dd3 100644
--- a/convlab/dst/setsumbt/loss/ece.py
+++ b/convlab/dst/setsumbt/loss/uncertainty_measures.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,14 +13,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Expected calibration error"""
+"""Uncertainty evaluation metrics for dialogue belief tracking"""
 
 import torch
 
 
-def fill_bins(n_bins, logits):
-    assert logits.dim() == 2
-    logits = logits.max(-1)[0]
+def fill_bins(n_bins: int, probs: torch.Tensor) -> list:
+    """
+    Function to split observations into bins based on predictive probabilities
+
+    Args:
+        n_bins (int): Number of bins
+        probs (Tensor): Predictive probabilities for the observations
+
+    Returns:
+        bins (list): List of observation ids for each bin
+    """
+    assert probs.dim() == 2
+    probs = probs.max(-1)[0]
 
     step = 1.0 / n_bins
     bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
@@ -28,29 +38,49 @@ def fill_bins(n_bins, logits):
     for b in range(n_bins):
         lower, upper = bin_ranges[b], bin_ranges[b + 1]
         if b == 0:
-            ids = torch.where((logits >= lower) * (logits <= upper))[0]
+            ids = torch.where((probs >= lower) * (probs <= upper))[0]
         else:
-            ids = torch.where((logits > lower) * (logits <= upper))[0]
+            ids = torch.where((probs > lower) * (probs <= upper))[0]
         bins.append(ids)
     return bins
 
 
-def bin_confidence(bins, logits):
-    logits = logits.max(-1)[0]
+def bin_confidence(bins: list, probs: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the confidence score within each bin
+
+    Args:
+        bins (list): List of observation ids for each bin
+        probs (Tensor): Predictive probabilities for the observations
+
+    Returns:
+        scores (Tensor): Average confidence score within each bin
+    """
+    probs = probs.max(-1)[0]
 
     scores = []
     for b in bins:
         if b is not None:
-            l = logits[b]
-            scores.append(l.mean())
+            scores.append(probs[b].mean())
         else:
             scores.append(-1)
     scores = torch.tensor(scores)
     return scores
 
 
-def bin_accuracy(bins, logits, y_true):
-    y_pred = logits.argmax(-1)
+def bin_accuracy(bins: list, probs: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the accuracy score for observations in each bin
+
+    Args:
+        bins (list): List of observation ids for each bin
+        probs (Tensor): Predictive probabilities for the observations
+        y_true (Tensor): Labels for the observations
+
+    Returns:
+        acc (Tensor): Accuracies for the observations in each bin
+    """
+    y_pred = probs.argmax(-1)
 
     acc = []
     for b in bins:
@@ -68,13 +98,24 @@ def bin_accuracy(bins, logits, y_true):
     return acc
 
 
-def ece(logits, y_true, n_bins):
-    bins = fill_bins(n_bins, logits)
+def ece(probs: torch.Tensor, y_true: torch.Tensor, n_bins: int) -> float:
+    """
+    Expected calibration error calculation
 
-    scores = bin_confidence(bins, logits)
-    acc = bin_accuracy(bins, logits, y_true)
+    Args:
+        probs (Tensor): Predictive probabilities for the observations
+        y_true (Tensor): Labels for the observations
+        n_bins (int): Number of bins
 
-    n = logits.size(0)
+    Returns:
+        ece (float): Expected calibration error
+    """
+    bins = fill_bins(n_bins, probs)
+
+    scores = bin_confidence(bins, probs)
+    acc = bin_accuracy(bins, probs, y_true)
+
+    n = probs.size(0)
     bk = torch.tensor([b.size(0) for b in bins])
 
     ece = torch.abs(scores - acc) * bk / n
@@ -84,34 +125,30 @@ def ece(logits, y_true, n_bins):
     return ece
 
 
-def jg_ece(logits, y_true, n_bins):
-    y_pred = {slot: logits[slot].reshape(-1, logits[slot].size(-1)).argmax(-1) for slot in logits}
+def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float:
+    """
+        Joint goal expected calibration error calculation
+
+        Args:
+            belief_state (dict): Belief state probabilities for the dialogue turns
+            y_true (dict): Labels for the state in dialogue turns
+            n_bins (int): Number of bins
+
+        Returns:
+            ece (float): Joint goal expected calibration error
+        """
+    y_pred = {slot: bs.reshape(-1, bs.size(-1)).argmax(-1) for slot, bs in belief_state.items()}
     goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred}
     goal_acc = sum([goal_acc[slot] for slot in goal_acc])
     goal_acc = (goal_acc == len(y_true)).int()
 
-    scores = [logits[slot].reshape(-1, logits[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in logits]
+    # Confidence score is minimum across slots as a single bad predictions leads to incorrect prediction in state
+    scores = [bs.reshape(-1, bs.size(-1)).max(-1)[0].unsqueeze(0) for slot, bs in belief_state.items()]
     scores = torch.cat(scores, 0).min(0)[0]
 
-    step = 1.0 / n_bins
-    bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
-    bins = []
-    for b in range(n_bins):
-        lower, upper = bin_ranges[b], bin_ranges[b + 1]
-        if b == 0:
-            ids = torch.where((scores >= lower) * (scores <= upper))[0]
-        else:
-            ids = torch.where((scores > lower) * (scores <= upper))[0]
-        bins.append(ids)
+    bins = fill_bins(n_bins, scores.unsqueeze(-1))
 
-    conf = []
-    for b in bins:
-        if b is not None:
-            l = scores[b]
-            conf.append(l.mean())
-        else:
-            conf.append(-1)
-    conf = torch.tensor(conf)
+    conf = bin_confidence(bins, scores.unsqueeze(-1))
 
     slot = [s for s in y_true][0]
     acc = []
@@ -127,7 +164,7 @@ def jg_ece(logits, y_true, n_bins):
             acc.append(-1)
     acc = torch.tensor(acc)
 
-    n = logits[slot].reshape(-1, logits[slot].size(-1)).size(0)
+    n = belief_state[slot].reshape(-1, belief_state[slot].size(-1)).size(0)
     bk = torch.tensor([b.size(0) for b in bins])
 
     ece = torch.abs(conf - acc) * bk / n
@@ -137,12 +174,22 @@ def jg_ece(logits, y_true, n_bins):
     return ece
 
 
-def l2_acc(belief_state, labels, remove_belief=False):
+def l2_acc(belief_state: dict, labels: dict, remove_belief: bool = False) -> float:
+    """
+    Compute L2 Error of belief state prediction
+
+    Args:
+        belief_state (dict): Belief state probabilities for the dialogue turns
+        labels (dict): Labels for the state in dialogue turns
+        remove_belief (bool): Convert belief state to dialogue state
+
+    Returns:
+        err (float): L2 Error of belief state prediction
+    """
     # Get ids used for removing padding turns.
     padding = labels[list(labels.keys())[0]].reshape(-1)
     padding = torch.where(padding != -1)[0]
 
-    # l2 = []
     state = []
     labs = []
     for slot, bs in belief_state.items():
@@ -163,13 +210,8 @@ def l2_acc(belief_state, labels, remove_belief=False):
         y = torch.zeros(bs.shape).cuda()
         y[range(y.size(0)), lab] = 1.0
 
-        # err = torch.sqrt(((y - bs) ** 2).sum(-1))
-        # l2.append(err.unsqueeze(-1))
-
         state.append(bs)
         labs.append(y)
-    
-    # err = torch.cat(l2, -1).max(-1)[0]
 
     # Concatenate all slots into a single belief state
     state = torch.cat(state, -1)
diff --git a/convlab/dst/setsumbt/modeling/__init__.py b/convlab/dst/setsumbt/modeling/__init__.py
index 011a1a77..59f14399 100644
--- a/convlab/dst/setsumbt/modeling/__init__.py
+++ b/convlab/dst/setsumbt/modeling/__init__.py
@@ -1,3 +1,5 @@
 from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
 from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
-from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT, DropoutEnsembleSetSUMBT
+from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT
+
+from convlab.dst.setsumbt.modeling.temperature_scheduler import LinearTemperatureScheduler
diff --git a/convlab/dst/setsumbt/modeling/bert_nbt.py b/convlab/dst/setsumbt/modeling/bert_nbt.py
index 8b402b6b..6762fb38 100644
--- a/convlab/dst/setsumbt/modeling/bert_nbt.py
+++ b/convlab/dst/setsumbt/modeling/bert_nbt.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,11 +16,10 @@
 """BERT SetSUMBT"""
 
 import torch
-import transformers
 from torch.autograd import Variable
 from transformers import BertModel, BertPreTrainedModel
 
-from convlab.dst.setsumbt.modeling.functional import _initialise, _nbt_forward
+from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead
 
 
 class BertSetSUMBT(BertPreTrainedModel):
@@ -35,59 +34,37 @@ class BertSetSUMBT(BertPreTrainedModel):
             for p in self.bert.parameters():
                 p.requires_grad = False
 
-        _initialise(self, config)
-
-    # Add new slot candidates to the model
-    def add_slot_candidates(self, slot_candidates):
-        """slot_candidates is a list of tuples for each slot.
-        - The tuples contains the slot embedding, informable value embeddings and a request indicator.
-        - If the informable value embeddings is None the slot is not informable
-        - If the request indicator is false the slot is not requestable"""
-        if self.slot_embeddings.size(0) != 0:
-            embeddings = self.slot_embeddings.detach()
-        else:
-            embeddings = torch.zeros(0)
-
-        for slot in slot_candidates:
-            if slot in self.slot_ids:
-                index = self.slot_ids[slot]
-                embeddings[index, :] = slot_candidates[slot][0]
-            else:
-                index = embeddings.size(0)
-                emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device)
-                embeddings = torch.cat((embeddings, emb), 0)
-                self.slot_ids[slot] = index
-                setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False))
-            # Add slot to relevant requestable and informable slot lists
-            if slot_candidates[slot][2]:
-                self.requestable_slot_ids[slot] = index
-            if slot_candidates[slot][1] is not None:
-                self.informable_slot_ids[slot] = index
-            
-            domain = slot.split('-', 1)[0]
-            if domain not in self.domain_ids:
-                self.domain_ids[domain] = []
-            self.domain_ids[domain].append(index)
-            self.domain_ids[domain] = list(set(self.domain_ids[domain]))
-        
-        self.slot_embeddings = Variable(embeddings, requires_grad=False)
-
-
-    # Add new value candidates to the model
-    def add_value_candidates(self, slot, value_candidates, replace=False):
-        embeddings = getattr(self, slot + '_value_embeddings')
-
-        if embeddings.size(0) == 0 or replace:
-            embeddings = value_candidates
-        else:
-            embeddings = torch.cat((embeddings, value_candidates), 0)
-        
-        setattr(self, slot + '_value_embeddings', embeddings)
-
-    
-    def forward(self, input_ids, token_type_ids, attention_mask, hidden_state=None, inform_labels=None,
-                request_labels=None, domain_labels=None, goodbye_labels=None,
-                get_turn_pooled_representation=False, calculate_inform_mutual_info=False):
+        self.setsumbt = SetSUMBTHead(config)
+        self.add_slot_candidates = self.setsumbt.add_slot_candidates
+        self.add_value_candidates = self.setsumbt.add_value_candidates
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                hidden_state: torch.Tensor = None,
+                state_labels: torch.Tensor = None,
+                request_labels: torch.Tensor = None,
+                active_domain_labels: torch.Tensor = None,
+                general_act_labels: torch.Tensor = None,
+                get_turn_pooled_representation: bool = False,
+                calculate_state_mutual_info: bool = False):
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator
+            hidden_state: Latent internal dialogue belief state
+            state_labels: Dialogue state labels
+            request_labels: User request action labels
+            active_domain_labels: Current active domain labels
+            general_act_labels: General user action labels
+            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
+            calculate_state_mutual_info: Return mutual information in the dialogue state
+
+        Returns:
+            out: Tuple containing loss, predictive distributions, model statistics and state mutual information
+        """
 
         # Encode Dialogues
         batch_size, dialogue_size, turn_size = input_ids.size()
@@ -103,9 +80,10 @@ class BertSetSUMBT(BertPreTrainedModel):
         turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
 
         if get_turn_pooled_representation:
-            return _nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size,
-                                dialogue_size, turn_size, hidden_state, inform_labels, request_labels, domain_labels,
-                                goodbye_labels, calculate_inform_mutual_info) + (bert_output.pooler_output,)
-        return _nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size, dialogue_size,
-                            turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels,
-                            calculate_inform_mutual_info)
+            return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask,
+                                 batch_size, dialogue_size, hidden_state, state_labels,
+                                 request_labels, active_domain_labels, general_act_labels,
+                                 calculate_state_mutual_info) + (bert_output.pooler_output,)
+        return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask, batch_size,
+                             dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
+                             general_act_labels, calculate_state_mutual_info)
diff --git a/convlab/dst/setsumbt/modeling/calibration_utils.py b/convlab/dst/setsumbt/modeling/calibration_utils.py
deleted file mode 100644
index 8514ac8d..00000000
--- a/convlab/dst/setsumbt/modeling/calibration_utils.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Discriminative models calibration"""
-
-import random
-
-import torch
-import numpy as np
-from tqdm import tqdm
-
-
-# Load logger and tensorboard summary writer
-def set_logger(logger_, tb_writer_):
-    global logger, tb_writer
-    logger = logger_
-    tb_writer = tb_writer_
-
-
-# Set seeds
-def set_seed(args):
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-    logger.info('Seed set to %d.' % args.seed)
-
-
-def get_predictions(args, model, device, dataloader):
-    logger.info("  Num Batches = %d", len(dataloader))
-
-    model.eval()
-    if args.dropout_iterations > 1:
-        model.train()
-    
-    belief_states = {slot: [] for slot in model.informable_slot_ids}
-    request_belief = {slot: [] for slot in model.requestable_slot_ids}
-    domain_belief = {dom: [] for dom in model.domain_ids}
-    greeting_belief = []
-    labels = {slot: [] for slot in model.informable_slot_ids}
-    request_labels = {slot: [] for slot in model.requestable_slot_ids}
-    domain_labels = {dom: [] for dom in model.domain_ids}
-    greeting_labels = []
-    epoch_iterator = tqdm(dataloader, desc="Iteration")
-    for step, batch in enumerate(epoch_iterator):
-        with torch.no_grad():    
-            input_ids = batch['input_ids'].to(device)
-            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
-            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
-
-            if args.dropout_iterations > 1:
-                p = {slot: [] for slot in model.informable_slot_ids}
-                for _ in range(args.dropout_iterations):
-                    p_, p_req_, p_dom_, p_bye_, _ = model(input_ids=input_ids,
-                                                        token_type_ids=token_type_ids,
-                                                        attention_mask=attention_mask)
-                    for slot in model.informable_slot_ids:
-                        p[slot].append(p_[slot].unsqueeze(0))
-                
-                mu = {slot: torch.cat(p[slot], 0).mean(0) for slot in model.informable_slot_ids}
-                sig = {slot: torch.cat(p[slot], 0).var(0) for slot in model.informable_slot_ids}
-                p = {slot: mu[slot] / torch.sqrt(1 + sig[slot]) for slot in model.informable_slot_ids}
-                p = {slot: normalise(p[slot]) for slot in model.informable_slot_ids}
-            else:
-                p, p_req, p_dom, p_bye, _ = model(input_ids=input_ids,
-                                                token_type_ids=token_type_ids,
-                                                attention_mask=attention_mask)
-            
-            for slot in model.informable_slot_ids:
-                p_ = p[slot]
-                labs = batch['labels-' + slot].to(device)
-                
-                belief_states[slot].append(p_)
-                labels[slot].append(labs)
-            
-            if p_req is not None:
-                for slot in model.requestable_slot_ids:
-                    p_ = p_req[slot]
-                    labs = batch['request-' + slot].to(device)
-
-                    request_belief[slot].append(p_)
-                    request_labels[slot].append(labs)
-                
-                for domain in model.domain_ids:
-                    p_ = p_dom[domain]
-                    labs = batch['active-' + domain].to(device)
-
-                    domain_belief[domain].append(p_)
-                    domain_labels[domain].append(labs)
-                
-                greeting_belief.append(p_bye)
-                greeting_labels.append(batch['goodbye'].to(device))
-    
-    for slot in belief_states:
-        belief_states[slot] = torch.cat(belief_states[slot], 0)
-        labels[slot] = torch.cat(labels[slot], 0)
-    if p_req is not None:
-        for slot in request_belief:
-            request_belief[slot] = torch.cat(request_belief[slot], 0)
-            request_labels[slot] = torch.cat(request_labels[slot], 0)
-        for domain in domain_belief:
-            domain_belief[domain] = torch.cat(domain_belief[domain], 0)
-            domain_labels[domain] = torch.cat(domain_labels[domain], 0)
-        greeting_belief = torch.cat(greeting_belief, 0)
-        greeting_labels = torch.cat(greeting_labels, 0)
-    else:
-        request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = [None]*6
-
-    return belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels
-
-
-def normalise(p):
-    p_shape = p.size()
-
-    p = p.reshape(-1, p_shape[-1]) + 1e-10
-    p_sum = p.sum(-1).unsqueeze(1).repeat((1, p_shape[-1]))
-    p /= p_sum
-
-    p = p.reshape(p_shape)
-
-    return p
diff --git a/convlab/dst/setsumbt/modeling/ensemble_nbt.py b/convlab/dst/setsumbt/modeling/ensemble_nbt.py
index 9f101d12..6d3d8035 100644
--- a/convlab/dst/setsumbt/modeling/ensemble_nbt.py
+++ b/convlab/dst/setsumbt/modeling/ensemble_nbt.py
@@ -16,9 +16,9 @@
 """Ensemble SetSUMBT"""
 
 import os
+from shutil import copy2 as copy
 
 import torch
-import transformers
 from torch.nn import Module
 from transformers import RobertaConfig, BertConfig
 
@@ -29,8 +29,13 @@ MODELS = {'bert': BertSetSUMBT, 'roberta': RobertaSetSUMBT}
 
 
 class EnsembleSetSUMBT(Module):
+    """Ensemble SetSUMBT Model for joint ensemble prediction"""
 
     def __init__(self, config):
+        """
+        Args:
+            config (configuration): Model configuration class
+        """
         super(EnsembleSetSUMBT, self).__init__()
         self.config = config
 
@@ -38,175 +43,138 @@ class EnsembleSetSUMBT(Module):
         model_cls = MODELS[self.config.model_type]
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             setattr(self, attr, model_cls(config))
-    
 
-    # Load all ensemble memeber parameters
-    def load(self, path, config=None):
-        if config is None:
-            config = self.config
-        
+    def _load(self, path: str):
+        """
+        Load parameters
+        Args:
+            path: Location of model parameters
+        """
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             idx = attr.split('_', 1)[-1]
-            state_dict = torch.load(os.path.join(path, f'pytorch_model_{idx}.bin'))
+            state_dict = torch.load(os.path.join(path, f'ens-{idx}/pytorch_model.bin'))
             getattr(self, attr).load_state_dict(state_dict)
-    
 
-    # Add new slot candidates to the ensemble members
-    def add_slot_candidates(self, slot_candidates):
+    def add_slot_candidates(self, slot_candidates: tuple):
+        """
+        Add slots to the model ontology, the tuples should contain the slot embedding, informable value embeddings
+        and a request indicator, if the informable value embeddings is None the slot is not informable and if
+        the request indicator is false the slot is not requestable.
+
+        Args:
+            slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator
+        """
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             getattr(self, attr).add_slot_candidates(slot_candidates)
-        self.requestable_slot_ids = self.model_0.requestable_slot_ids
-        self.informable_slot_ids = self.model_0.informable_slot_ids
-        self.domain_ids = self.model_0.domain_ids
-
-
-    # Add new value candidates to the ensemble members
-    def add_value_candidates(self, slot, value_candidates, replace=False):
+        self.requestable_slot_ids = self.model_0.setsumbt.requestable_slot_ids
+        self.informable_slot_ids = self.model_0.setsumbt.informable_slot_ids
+        self.domain_ids = self.model_0.setsumbt.domain_ids
+
+    def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
+        """
+        Add value candidates for a slot
+
+        Args:
+            slot: Slot name
+            value_candidates: Value candidate embeddings
+            replace: If true existing value candidates are replaced
+        """
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             getattr(self, attr).add_value_candidates(slot, value_candidates, replace)
-        
 
-    # Forward pass of full ensemble
-    def forward(self, input_ids, attention_mask, token_type_ids=None, reduction='mean'):
-        logits, request_logits, domain_logits, goodbye_scores = [], [], [], []
-        logits = {slot: [] for slot in self.model_0.informable_slot_ids}
-        request_logits = {slot: [] for slot in self.model_0.requestable_slot_ids}
-        domain_logits = {dom: [] for dom in self.model_0.domain_ids}
-        goodbye_scores = []
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                reduction: str = 'mean') -> tuple:
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator
+            reduction: Reduction of ensemble member predictive distributions (mean, none)
+
+        Returns:
+
+        """
+        belief_state_probs = {slot: [] for slot in self.informable_slot_ids}
+        request_probs = {slot: [] for slot in self.requestable_slot_ids}
+        active_domain_probs = {dom: [] for dom in self.domain_ids}
+        general_act_probs = []
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             # Prediction from each ensemble member
-            l, r, d, g, _ = getattr(self, attr)(input_ids=input_ids,
+            b, r, d, g, _ = getattr(self, attr)(input_ids=input_ids,
                                                 token_type_ids=token_type_ids,
                                                 attention_mask=attention_mask)
-            for slot in logits:
-                logits[slot].append(l[slot].unsqueeze(-2))
-            if self.config.predict_intents:
-                for slot in request_logits:
-                    request_logits[slot].append(r[slot].unsqueeze(-1))
-                for dom in domain_logits:
-                    domain_logits[dom].append(d[dom].unsqueeze(-1))
-                goodbye_scores.append(g.unsqueeze(-2))
+            for slot in belief_state_probs:
+                belief_state_probs[slot].append(b[slot].unsqueeze(-2))
+            if self.config.predict_actions:
+                for slot in request_probs:
+                    request_probs[slot].append(r[slot].unsqueeze(-1))
+                for dom in active_domain_probs:
+                    active_domain_probs[dom].append(d[dom].unsqueeze(-1))
+                general_act_probs.append(g.unsqueeze(-2))
         
-        logits = {slot: torch.cat(l, -2) for slot, l in logits.items()}
-        if self.config.predict_intents:
-            request_logits = {slot: torch.cat(l, -1) for slot, l in request_logits.items()}
-            domain_logits = {dom: torch.cat(l, -1) for dom, l in domain_logits.items()}
-            goodbye_scores = torch.cat(goodbye_scores, -2)
+        belief_state_probs = {slot: torch.cat(l, -2) for slot, l in belief_state_probs.items()}
+        if self.config.predict_actions:
+            request_probs = {slot: torch.cat(l, -1) for slot, l in request_probs.items()}
+            active_domain_probs = {dom: torch.cat(l, -1) for dom, l in active_domain_probs.items()}
+            general_act_probs = torch.cat(general_act_probs, -2)
         else:
-            request_logits = {}
-            domain_logits = {}
-            goodbye_scores = torch.tensor(0.0)
+            request_probs = {}
+            active_domain_probs = {}
+            general_act_probs = torch.tensor(0.0)
 
         # Apply reduction of ensemble to single posterior
         if reduction == 'mean':
-            logits = {slot: l.mean(-2) for slot, l in logits.items()}
-            request_logits = {slot: l.mean(-1) for slot, l in request_logits.items()}
-            domain_logits = {dom: l.mean(-1) for dom, l in domain_logits.items()}
-            goodbye_scores = goodbye_scores.mean(-2)
+            belief_state_probs = {slot: l.mean(-2) for slot, l in belief_state_probs.items()}
+            request_probs = {slot: l.mean(-1) for slot, l in request_probs.items()}
+            active_domain_probs = {dom: l.mean(-1) for dom, l in active_domain_probs.items()}
+            general_act_probs = general_act_probs.mean(-2)
         elif reduction != 'none':
             raise(NameError('Not Implemented!'))
 
-        return logits, request_logits, domain_logits, goodbye_scores, _
+        return belief_state_probs, request_probs, active_domain_probs, general_act_probs, _
     
 
     @classmethod
     def from_pretrained(cls, path):
-        if not os.path.exists(os.path.join(path, 'config.json')):
+        config_path = os.path.join(path, 'ens-0', 'config.json')
+        if not os.path.exists(config_path):
             raise(NameError('Could not find config.json in model path.'))
-        if not os.path.exists(os.path.join(path, 'pytorch_model_0.bin')):
-            raise(NameError('Could not find a model binary in the model path.'))
         
         try:
-            config = RobertaConfig.from_pretrained(path)
+            config = RobertaConfig.from_pretrained(config_path)
         except:
-            config = BertConfig.from_pretrained(path)
+            config = BertConfig.from_pretrained(config_path)
+
+        config.ensemble_size = len([dir for dir in os.listdir(path) if 'ens-' in dir])
         
         model = cls(config)
-        model.load(path)
+        model._load(path)
 
         return model
 
 
-class DropoutEnsembleSetSUMBT(Module):
-
-    def __init__(self, config):
-        super(DropoutEnsembleBeliefTracker, self).__init__()
-        self.config = config
-
-        model_cls = MODELS[self.config.model_type]
-        self.model = model_cls(config)
-        self.model.train()
-    
-
-    def load(self, path, config=None):
-        if config is None:
-            config = self.config
-        state_dict = torch.load(os.path.join(path, f'pytorch_model.bin'))
-        self.model.load_state_dict(state_dict)
-    
-
-    # Add new slot candidates to the model
-    def add_slot_candidates(self, slot_candidates):
-        self.model.add_slot_candidates(slot_candidates)
-        self.requestable_slot_ids = self.model.requestable_slot_ids
-        self.informable_slot_ids = self.model.informable_slot_ids
-        self.domain_ids = self.model.domain_ids
-
-
-    # Add new value candidates to the model
-    def add_value_candidates(self, slot, value_candidates, replace=False):
-        self.model.add_value_candidates(slot, value_candidates, replace)
-        
-    
-    def forward(self, input_ids, attention_mask, token_type_ids=None, reduction='mean'):
-
-        input_ids = input_ids.unsqueeze(0).repeat((self.config.ensemble_size, 1, 1, 1))
-        input_ids = input_ids.reshape(-1, input_ids.size(-2), input_ids.size(-1))
-        if attention_mask is not None:
-            attention_mask = attention_mask.unsqueeze(0).repeat((10, 1, 1, 1))
-            attention_mask = attention_mask.reshape(-1, attention_mask.size(-2), attention_mask.size(-1))
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids.unsqueeze(0).repeat((10, 1, 1, 1))
-            token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-2), token_type_ids.size(-1))
-        
-        self.model.train()
-        logits, request_logits, domain_logits, goodbye_scores, _ = self.model(input_ids=input_ids,
-                                                                            attention_mask=attention_mask,
-                                                                            token_type_ids=token_type_ids)
-        
-        logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-2), l.size(-1)).transpose(0, 1).transpose(1, 2)
-                for s, l in logits.items()}
-        request_logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-1)).transpose(0, 1).transpose(1, 2)
-                        for s, l in request_logits.items()}
-        domain_logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-1)).transpose(0, 1).transpose(1, 2)
-                        for s, l in domain_logits.items()}
-        goodbye_scores = goodbye_scores.reshape(self.config.ensemble_size, -1, goodbye_scores.size(-2), goodbye_scores.size(-1))
-        goodbye_scores = goodbye_scores.transpose(0, 1).transpose(1, 2)
-
-        if reduction == 'mean':
-            logits = {slot: l.mean(-2) for slot, l in logits.items()}
-            request_logits = {slot: l.mean(-1) for slot, l in request_logits.items()}
-            domain_logits = {dom: l.mean(-1) for dom, l in domain_logits.items()}
-            goodbye_scores = goodbye_scores.mean(-2)
-        elif reduction != 'none':
-            raise(NameError('Not Implemented!'))
-
-        return logits, request_logits, domain_logits, goodbye_scores, _
-    
-
-    @classmethod
-    def from_pretrained(cls, path):
-        if not os.path.exists(os.path.join(path, 'config.json')):
-            raise(NameError('Could not find config.json in model path.'))
-        if not os.path.exists(os.path.join(path, 'pytorch_model.bin')):
-            raise(NameError('Could not find a model binary in the model path.'))
-        
-        try:
-            config = RobertaConfig.from_pretrained(path)
-        except:
-            config = BertConfig.from_pretrained(path)
-        
-        model = cls(config)
-        model.load(path)
-
-        return model
+def setup_ensemble(model_path: str, ensemble_size: int):
+    """
+    Setup ensemble model directory structure.
+
+    Args:
+        model_path: Path to ensemble model directory
+        ensemble_size: Number of ensemble members
+    """
+    for i in range(ensemble_size):
+        path = os.path.join(model_path, f'ens-{i}')
+        if not os.path.exists(path):
+            os.mkdir(path)
+            os.mkdir(os.path.join(path, 'dataloaders'))
+            os.mkdir(os.path.join(path, 'database'))
+            # Add development set dataloader to each ensemble member directory
+            for set_type in ['dev']:
+                copy(os.path.join(model_path, 'dataloaders', f'{set_type}.dataloader'),
+                     os.path.join(path, 'dataloaders', f'{set_type}.dataloader'))
+            # Add training and development set ontologies to each ensemble member directory
+            for set_type in ['train', 'dev']:
+                copy(os.path.join(model_path, 'database', f'{set_type}.db'),
+                     os.path.join(path, 'database', f'{set_type}.db'))
diff --git a/convlab/dst/setsumbt/modeling/evaluation_utils.py b/convlab/dst/setsumbt/modeling/evaluation_utils.py
new file mode 100644
index 00000000..c73d4b6d
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/evaluation_utils.py
@@ -0,0 +1,112 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluation Utilities"""
+
+import random
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+
+def set_seed(args):
+    """
+    Set random seeds
+
+    Args:
+        args (Arguments class): Arguments class containing seed and number of gpus to use
+    """
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+
+
+def get_predictions(args, model, device: torch.device, dataloader: torch.utils.data.DataLoader) -> tuple:
+    """
+    Get model predictions
+
+    Args:
+        args: Runtime arguments
+        model: SetSUMBT Model
+        device: Torch device
+        dataloader: Dataloader containing eval data
+    """
+    model.eval()
+    
+    belief_states = {slot: [] for slot in model.setsumbt.informable_slot_ids}
+    request_probs = {slot: [] for slot in model.setsumbt.requestable_slot_ids}
+    active_domain_probs = {dom: [] for dom in model.setsumbt.domain_ids}
+    general_act_probs = []
+    state_labels = {slot: [] for slot in model.setsumbt.informable_slot_ids}
+    request_labels = {slot: [] for slot in model.setsumbt.requestable_slot_ids}
+    active_domain_labels = {dom: [] for dom in model.setsumbt.domain_ids}
+    general_act_labels = []
+    epoch_iterator = tqdm(dataloader, desc="Iteration")
+    for step, batch in enumerate(epoch_iterator):
+        with torch.no_grad():    
+            input_ids = batch['input_ids'].to(device)
+            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
+            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
+
+            p, p_req, p_dom, p_gen, _ = model(input_ids=input_ids, token_type_ids=token_type_ids,
+                                              attention_mask=attention_mask)
+
+            for slot in belief_states:
+                p_ = p[slot]
+                labs = batch['state_labels-' + slot].to(device)
+                
+                belief_states[slot].append(p_)
+                state_labels[slot].append(labs)
+            
+            if p_req is not None:
+                for slot in request_probs:
+                    p_ = p_req[slot]
+                    labs = batch['request_labels-' + slot].to(device)
+
+                    request_probs[slot].append(p_)
+                    request_labels[slot].append(labs)
+                
+                for domain in active_domain_probs:
+                    p_ = p_dom[domain]
+                    labs = batch['active_domain_labels-' + domain].to(device)
+
+                    active_domain_probs[domain].append(p_)
+                    active_domain_labels[domain].append(labs)
+                
+                general_act_probs.append(p_gen)
+                general_act_labels.append(batch['general_act_labels'].to(device))
+    
+    for slot in belief_states:
+        belief_states[slot] = torch.cat(belief_states[slot], 0)
+        state_labels[slot] = torch.cat(state_labels[slot], 0)
+    if p_req is not None:
+        for slot in request_probs:
+            request_probs[slot] = torch.cat(request_probs[slot], 0)
+            request_labels[slot] = torch.cat(request_labels[slot], 0)
+        for domain in active_domain_probs:
+            active_domain_probs[domain] = torch.cat(active_domain_probs[domain], 0)
+            active_domain_labels[domain] = torch.cat(active_domain_labels[domain], 0)
+        general_act_probs = torch.cat(general_act_probs, 0)
+        general_act_labels = torch.cat(general_act_labels, 0)
+    else:
+        request_probs, request_labels, active_domain_probs, active_domain_labels = [None] * 4
+        general_act_probs, general_act_labels = [None] * 2
+
+    out = (belief_states, state_labels, request_probs, request_labels)
+    out += (active_domain_probs, active_domain_labels, general_act_probs, general_act_labels)
+    return out
diff --git a/convlab/dst/setsumbt/modeling/functional.py b/convlab/dst/setsumbt/modeling/functional.py
deleted file mode 100644
index 0dd083d0..00000000
--- a/convlab/dst/setsumbt/modeling/functional.py
+++ /dev/null
@@ -1,456 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""SetSUMBT functionals"""
-
-import torch
-import transformers
-from torch.autograd import Variable
-from torch.nn import (MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout,
-                      CosineSimilarity, CrossEntropyLoss, PairwiseDistance,
-                      Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss)
-from torch.nn.init import (xavier_normal_, constant_)
-
-from convlab.dst.setsumbt.loss.bayesian import BayesianMatchingLoss, BinaryBayesianMatchingLoss, dirichlet
-from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss
-from convlab.dst.setsumbt.loss.distillation import DistillationKL, BinaryDistillationKL
-from convlab.dst.setsumbt.loss.endd_loss import rkl_dirichlet_mediator_loss, logits_to_mutual_info
-
-
-# Default belief tracker model intialisation function
-def _initialise(self, config):
-    # Slot Utterance matching attention
-    self.slot_attention = MultiheadAttention(
-        config.hidden_size, config.slot_attention_heads)
-
-    # Latent context tracker
-    # Initial state prediction
-    if not config.rnn_zero_init and config.nbt_type in ['gru', 'lstm']:
-        self.belief_init = Sequential(Linear(config.hidden_size, config.nbt_hidden_size),
-                                      ReLU(), Dropout(config.dropout_rate))
-
-    # Recurrent context tracker setup
-    if config.nbt_type == 'gru':
-        self.nbt = GRU(input_size=config.hidden_size,
-                       hidden_size=config.nbt_hidden_size,
-                       num_layers=config.nbt_layers,
-                       dropout=0.0 if config.nbt_layers == 1 else config.dropout_rate,
-                       batch_first=True)
-        # Initialise Parameters
-        xavier_normal_(self.nbt.weight_ih_l0)
-        xavier_normal_(self.nbt.weight_hh_l0)
-        constant_(self.nbt.bias_ih_l0, 0.0)
-        constant_(self.nbt.bias_hh_l0, 0.0)
-    elif config.nbt_type == 'lstm':
-        self.nbt = LSTM(input_size=config.hidden_size,
-                        hidden_size=config.nbt_hidden_size,
-                        num_layers=config.nbt_layers,
-                        dropout=0.0 if config.nbt_layers == 1 else config.dropout_rate,
-                        batch_first=True)
-        # Initialise Parameters
-        xavier_normal_(self.nbt.weight_ih_l0)
-        xavier_normal_(self.nbt.weight_hh_l0)
-        constant_(self.nbt.bias_ih_l0, 0.0)
-        constant_(self.nbt.bias_hh_l0, 0.0)
-    else:
-        raise NameError('Not Implemented')
-
-    # Feature decoder and layer norm
-    self.intermediate = Linear(config.nbt_hidden_size, config.hidden_size)
-    self.layer_norm = LayerNorm(config.hidden_size)
-
-    # Dropout
-    self.dropout = Dropout(config.dropout_rate)
-
-    # Set pooler for set similarity model
-    if self.config.set_similarity:
-        # 1D convolutional set pooler
-        if self.config.set_pooling == 'cnn':
-            self.conv_pooler = Conv1d(
-                self.config.hidden_size, self.config.hidden_size, 3)
-        # Deep averaging network set pooler
-        elif self.config.set_pooling == 'dan':
-            self.avg_net = Sequential(Linear(self.config.hidden_size, 2 * self.config.hidden_size), GELU(),
-                                      Linear(2 * self.config.hidden_size, self.config.hidden_size))
-
-    # Model ontology placeholders
-    self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False)
-    self.slot_ids = dict()
-    self.requestable_slot_ids = dict()
-    self.informable_slot_ids = dict()
-    self.domain_ids = dict()
-
-    # Matching network similarity measure
-    if config.distance_measure == 'cosine':
-        self.distance = CosineSimilarity(dim=-1, eps=1e-8)
-    elif config.distance_measure == 'euclidean':
-        self.distance = PairwiseDistance(p=2.0, eps=1e-06, keepdim=False)
-    else:
-        raise NameError('NotImplemented')
-
-    # Belief state loss function
-    if config.loss_function == 'crossentropy':
-        self.loss = CrossEntropyLoss(ignore_index=-1)
-    elif config.loss_function == 'bayesianmatching':
-        self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-    elif config.loss_function == 'labelsmoothing':
-        self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
-    elif config.loss_function == 'distillation':
-        self.loss = DistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
-        self.temp = 1.0
-    elif config.loss_function == 'distribution_distillation':
-        self.loss = rkl_dirichlet_mediator_loss
-        self.temp = 1.0
-    else:
-        raise NameError('NotImplemented')
-
-    # Intent and domain prediction heads
-    if config.predict_actions:
-        self.request_gate = Linear(config.hidden_size, 1)
-        self.goodbye_gate = Linear(config.hidden_size, 3)
-        self.domain_gate = Linear(config.hidden_size, 1)
-
-        # Intent and domain loss function
-        self.request_weight = float(self.config.user_request_loss_weight)
-        self.goodbye_weight = float(self.config.user_general_act_loss_weight)
-        self.domain_weight = float(self.config.active_domain_loss_weight)
-        if config.loss_function == 'crossentropy':
-            self.request_loss = BCEWithLogitsLoss()
-            self.goodbye_loss = CrossEntropyLoss(ignore_index=-1)
-            self.domain_loss = BCEWithLogitsLoss()
-        elif config.loss_function == 'labelsmoothing':
-            self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
-            self.goodbye_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
-            self.domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
-        elif config.loss_function == 'bayesianmatching':
-            self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-            self.goodbye_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-            self.domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-        elif config.loss_function == 'distillation':
-            self.request_loss = BinaryDistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
-            self.goodbye_loss = DistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
-            self.domain_loss = BinaryDistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
-
-
-# Default belief tracker forward pass.
-def _nbt_forward(self, turn_embeddings,
-                 turn_pooled_representation,
-                 attention_mask,
-                 batch_size,
-                 dialogue_size,
-                 turn_size,
-                 hidden_state,
-                 inform_labels,
-                 request_labels,
-                 domain_labels,
-                 goodbye_labels,
-                 calculate_inform_mutual_info):
-    hidden_size = turn_embeddings.size(-1)
-    # Initialise loss
-    loss = 0.0
-
-    # Goodbye predictions
-    goodbye_probs = None
-    if self.config.predict_actions:
-        # General action prediction
-        goodbye_scores = self.goodbye_gate(
-            turn_pooled_representation.reshape(batch_size * dialogue_size, hidden_size))
-
-        # Compute loss for general action predictions (weighted loss)
-        if goodbye_labels is not None:
-            if self.config.loss_function == 'distillation':
-                goodbye_labels = goodbye_labels.reshape(-1, goodbye_labels.size(-1))
-                loss += self.goodbye_loss(goodbye_scores, goodbye_labels, self.temp) * self.goodbye_weight
-            elif self.config.loss_function == 'distribution_distillation':
-                goodbye_labels = goodbye_labels.reshape(-1, goodbye_labels.size(-2), goodbye_labels.size(-1))
-                loss += self.loss(goodbye_scores, goodbye_labels, 1.0, 1.0)[0] * self.goodbye_weight
-            else:
-                goodbye_labels = goodbye_labels.reshape(-1)
-                loss += self.goodbye_loss(goodbye_scores, goodbye_labels) * self.request_weight
-
-        # Compute general action probabilities
-        if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'distillation', 'distribution_distillation']:
-            goodbye_probs = torch.softmax(goodbye_scores, -1).reshape(batch_size, dialogue_size, -1)
-        elif self.config.loss_function in ['bayesianmatching']:
-            goodbye_probs = dirichlet(goodbye_scores.reshape(batch_size, dialogue_size, -1))
-
-    # Slot utterance matching
-    num_slots = self.slot_embeddings.size(0)
-    slot_embeddings = self.slot_embeddings.reshape(-1, hidden_size)
-    slot_embeddings = slot_embeddings.unsqueeze(1).repeat((1, batch_size * dialogue_size, 1)).to(turn_embeddings.device)
-
-    if self.config.set_similarity:
-        # Slot mask shape [num_slots * slot_len, batch_size * dialogue_size, 768]
-        slot_mask = (slot_embeddings != 0.0).float()
-
-    # Turn embeddings shape [turn_size, batch_size * dialogue_size, 768]
-    turn_embeddings = turn_embeddings.transpose(0, 1)
-    # Compute key padding mask
-    key_padding_mask = (attention_mask[:, :, 0] == 0.0)
-    key_padding_mask[key_padding_mask[:, 0] == True, :] = False
-    # Multi head attention of slot over tokens
-    hidden, _ = self.slot_attention(query=slot_embeddings,
-                                    key=turn_embeddings,
-                                    value=turn_embeddings,
-                                    key_padding_mask=key_padding_mask)  # [num_slots, batch_size * dialogue_size, 768]
-
-    # Set embeddings for all masked tokens to 0
-    attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1))
-    hidden = hidden * attention_mask
-    if self.config.set_similarity:
-        hidden = hidden * slot_mask
-    # Hidden layer shape [num_dials, num_slots, num_turns, 768]
-    hidden = hidden.transpose(0, 1).reshape(batch_size, dialogue_size, slot_embeddings.size(0), -1).transpose(1, 2)
-
-    # Latent context tracking
-    # [batch_size * num_slots, dialogue_size, 768]
-    hidden = hidden.reshape(batch_size * slot_embeddings.size(0), dialogue_size, -1)
-
-    if self.config.nbt_type == 'gru':
-        self.nbt.flatten_parameters()
-        if hidden_state is None:
-            if self.config.rnn_zero_init:
-                context = torch.zeros(self.config.nbt_layers, batch_size * slot_embeddings.size(0),
-                                      self.config.nbt_hidden_size)
-                context = context.to(turn_embeddings.device)
-            else:
-                context = self.belief_init(hidden[:, 0, :]).unsqueeze(0).repeat((self.config.nbt_layers, 1, 1))
-        else:
-            context = hidden_state.to(hidden.device)
-
-        # [batch_size, dialogue_size, nbt_hidden_size]
-        belief_embedding, context = self.nbt(hidden, context)
-    elif self.config.nbt_type == 'lstm':
-        self.nbt.flatten_parameters()
-        if self.config.rnn_zero_init:
-            context = (torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size),
-                       torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size))
-            context = (context[0].to(turn_embeddings.device),
-                       context[1].to(turn_embeddings.device))
-        else:
-            context = (self.belief_init(hidden[:, 0, :]).unsqueeze(0).repeat((self.config.nbt_layers, 1, 1)),
-                       torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size))
-            context = (context[0], context[1].to(turn_embeddings.device))
-
-        # [batch_size, dialogue_size, nbt_hidden_size]
-        belief_embedding, context = self.nbt(hidden, context)
-
-    # Decode features
-    belief_embedding = belief_embedding.reshape(batch_size, slot_embeddings.size(0), dialogue_size, -1).transpose(1, 2)
-    if self.config.set_similarity:
-        belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1,
-                                                    self.config.nbt_hidden_size)
-    # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768]
-    # Normalisation and regularisation
-    belief_embedding = self.layer_norm(self.intermediate(belief_embedding))
-    belief_embedding = self.dropout(belief_embedding)
-
-    # Pooling of the set of latent context representation
-    if self.config.set_similarity:
-        slot_mask = slot_mask.transpose(0, 1).reshape(batch_size, dialogue_size, num_slots, -1, hidden_size)
-        belief_embedding = belief_embedding * slot_mask
-
-        # Apply pooler to latent context sequence
-        if self.config.set_pooling == 'mean':
-            belief_embedding = belief_embedding.sum(-2) / slot_mask.sum(-2)
-            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1)
-        elif self.config.set_pooling == 'cnn':
-            belief_embedding = belief_embedding.reshape(-1, slot_mask.size(-2), hidden_size).transpose(1, 2)
-            belief_embedding = self.conv_pooler(belief_embedding)
-            # Mean pooling after CNN
-            belief_embedding = belief_embedding.mean(-1).reshape(batch_size, dialogue_size, num_slots, -1)
-        elif self.config.set_pooling == 'dan':
-            # sqrt N reduction
-            belief_embedding = belief_embedding.sum(-2) / torch.sqrt(torch.tensor(slot_mask.sum(-2)))
-            # Deep averaging feature extractor
-            belief_embedding = self.avg_net(belief_embedding)
-            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1)
-
-    # Perform classification
-    if self.config.predict_actions:
-        # User request prediction
-        request_probs = dict()
-        for slot, slot_id in self.requestable_slot_ids.items():
-            request_scores = self.request_gate(belief_embedding[:, :, slot_id, :])
-
-            # Store output probabilities
-            request_scores = request_scores.reshape(batch_size, dialogue_size)
-            mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size)
-            batches, dialogues = torch.where(mask == 0.0)
-            # Set request scores to 0.0 for padded turns
-            request_scores[batches, dialogues] = 0.0
-            if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'bayesianmatching',
-                                             'distillation', 'distribution_distillation']:
-                request_probs[slot] = torch.sigmoid(request_scores)
-
-            if request_labels is not None:
-                # Compute request gate loss
-                request_scores = request_scores.reshape(-1)
-                if self.config.loss_function == 'distillation':
-                    loss += self.request_loss(request_scores, request_labels[slot].reshape(-1),
-                                              self.temp) * self.request_weight
-                elif self.config.loss_function == 'distribution_distillation':
-                    scores, labs = convert_probs_to_logits(request_scores, request_labels[slot])
-                    loss += self.loss(scores, labs, 1.0, 1.0)[0] * self.request_weight
-                else:
-                    labs = request_labels[slot].reshape(-1)
-                    request_scores = request_scores[labs != -1]
-                    labs = labs[labs != -1].float()
-                    loss += self.request_loss(request_scores, labs) * self.request_weight
-
-        # Active domain prediction
-        domain_probs = dict()
-        for domain, slot_ids in self.domain_ids.items():
-            belief = belief_embedding[:, :, slot_ids, :]
-            if len(slot_ids) > 1:
-                # SqrtN reduction across all slots within a domain
-                belief = belief.sum(2) / ((belief != 0.0).float().sum(2) ** 0.5)
-            domain_scores = self.domain_gate(belief)
-
-            # Store output probabilities
-            domain_scores = domain_scores.reshape(batch_size, dialogue_size)
-            mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size)
-            batches, dialogues = torch.where(mask == 0.0)
-            domain_scores[batches, dialogues] = 0.0
-            if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'bayesianmatching', 'distillation',
-                                             'distribution_distillation']:
-                domain_probs[domain] = torch.sigmoid(domain_scores)
-
-            if domain_labels is not None:
-                # Compute domain prediction loss
-                domain_scores = domain_scores.reshape(-1)
-                if self.config.loss_function == 'distillation':
-                    loss += self.domain_loss(domain_scores, domain_labels[domain].reshape(-1),
-                                             self.temp) * self.domain_weight
-                elif self.config.loss_function == 'distribution_distillation':
-                    scores, labs = convert_probs_to_logits(domain_scores, domain_labels[domain])
-                    loss += self.loss(scores, labs, 1.0, 1.0)[0] * self.request_weight
-                else:
-                    labs = domain_labels[domain].reshape(-1)
-                    domain_scores = domain_scores[labs != -1]
-                    labs = labs[labs != -1].float()
-                    loss += self.domain_loss(domain_scores, labs) * self.domain_weight
-    else:
-        request_probs, domain_probs = None, None
-
-    # Informable slot predictions
-    inform_probs = dict()
-    out_dict = dict()
-    mutual_info = dict()
-    stats = dict()
-    for slot, slot_id in self.informable_slot_ids.items():
-        # Get slot belief embedding and value candidates
-        candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device)
-        belief = belief_embedding[:, :, slot_id, :]
-        slot_size = candidate_embeddings.size(0)
-
-        # Use similaroty matching to produce belief state
-        if self.config.distance_measure in ['cosine', 'euclidean']:
-            belief = belief.unsqueeze(2).repeat((1, 1, slot_size, 1))
-            belief = belief.reshape(-1, self.config.hidden_size)
-
-            # Pooling of set of value candidate description representation
-            if self.config.set_similarity and self.config.set_pooling == 'mean':
-                candidate_mask = (candidate_embeddings != 0.0).float()
-                candidate_embeddings = candidate_embeddings.sum(1) / candidate_mask.sum(1)
-            elif self.config.set_similarity and self.config.set_pooling == 'cnn':
-                candidate_embeddings = candidate_embeddings.transpose(1, 2)
-                candidate_embeddings = self.conv_pooler(candidate_embeddings).mean(-1)
-            elif self.config.set_similarity and self.config.set_pooling == 'dan':
-                candidate_mask = (candidate_embeddings != 0.0).float()
-                candidate_embeddings = candidate_embeddings.sum(1) / torch.sqrt(torch.tensor(candidate_mask.sum(1)))
-                candidate_embeddings = self.avg_net(candidate_embeddings)
-
-            candidate_embeddings = candidate_embeddings.unsqueeze(0).unsqueeze(0).repeat((batch_size,
-                                                                                          dialogue_size, 1, 1))
-            candidate_embeddings = candidate_embeddings.reshape(-1, self.config.hidden_size)
-
-        # Score value candidates
-        if self.config.distance_measure == 'cosine':
-            scores = self.distance(belief, candidate_embeddings)
-            # *27 here rescales the cosine similarity for better learning
-            scores = scores.reshape(batch_size * dialogue_size, -1) * 27.0
-        elif self.config.distance_measure == 'euclidean':
-            scores = -1.0 * self.distance(belief, candidate_embeddings)
-            scores = scores.reshape(batch_size * dialogue_size, -1)
-
-        # Calculate belief state
-        if self.config.loss_function in ['crossentropy', 'inhibitedce',
-                                         'labelsmoothing', 'distillation', 'distribution_distillation']:
-            probs_ = torch.softmax(scores.reshape(batch_size, dialogue_size, -1), -1)
-        elif self.config.loss_function in ['bayesianmatching']:
-            probs_ = dirichlet(scores.reshape(batch_size, dialogue_size, -1))
-
-        # Compute knowledge uncertainty in the beleif states
-        if calculate_inform_mutual_info and self.config.loss_function == 'distribution_distillation':
-            mutual_info[slot] = logits_to_mutual_info(scores).reshape(batch_size, dialogue_size)
-
-        # Set padded turn probabilities to zero
-        mask = attention_mask[self.slot_ids[slot],:, 0].reshape(batch_size, dialogue_size)
-        batches, dialogues = torch.where(mask == 0.0)
-        probs_[batches, dialogues, :] = 0.0
-        inform_probs[slot] = probs_
-
-        # Calculate belief state loss
-        if inform_labels is not None and slot in inform_labels:
-            if self.config.loss_function == 'bayesianmatching':
-                prior = torch.ones(scores.size(-1)).float().to(scores.device)
-                prior = prior * self.config.prior_constant
-                prior = prior.unsqueeze(0).repeat((scores.size(0), 1))
-
-                loss += self.loss(scores, inform_labels[slot].reshape(-1), prior=prior)
-            elif self.config.loss_function == 'distillation':
-                labels = inform_labels[slot]
-                labels = labels.reshape(-1, labels.size(-1))
-                loss += self.loss(scores, labels, self.temp)
-            elif self.config.loss_function == 'distribution_distillation':
-                labels = inform_labels[slot]
-                labels = labels.reshape(-1, labels.size(-2), labels.size(-1))
-                loss_, model_stats, ensemble_stats = self.loss(scores, labels, 1.0, 1.0)
-                loss += loss_
-
-                # Calculate stats regarding model precisions
-                precision = model_stats['precision']
-                ensemble_precision = ensemble_stats['precision']
-                stats[slot] = {'model_precision_min': precision.min(),
-                               'model_precision_max': precision.max(),
-                               'model_precision_mean': precision.mean(),
-                               'ensemble_precision_min': ensemble_precision.min(),
-                               'ensemble_precision_max': ensemble_precision.max(),
-                               'ensemble_precision_mean': ensemble_precision.mean()}
-            else:
-                loss += self.loss(scores, inform_labels[slot].reshape(-1))
-
-    # Return model outputs
-    out = inform_probs, request_probs, domain_probs, goodbye_probs, context
-    if inform_labels is not None or request_labels is not None or domain_labels is not None or goodbye_labels is not None:
-        out = (loss,) + out + (stats,)
-    if calculate_inform_mutual_info:
-        out = out + (mutual_info,)
-    return out
-
-
-# Convert binary scores and labels to 2 class classification problem for distribution distillation
-def convert_probs_to_logits(scores, labels):
-    # Convert single target probability p to distribution [1-p, p]
-    labels = labels.reshape(-1, labels.size(-1), 1)
-    labels = torch.cat([1 - labels, labels], -1)
-
-    # Convert input scores into predictive distribution [1-z, z]
-    scores = torch.sigmoid(scores).unsqueeze(1)
-    scores = torch.cat((1 - scores, scores), 1)
-    scores = -1.0 * torch.log((1 / (scores + 1e-8)) - 1)  # Inverse sigmoid
-
-    return scores, labels
diff --git a/convlab/dst/setsumbt/modeling/roberta_nbt.py b/convlab/dst/setsumbt/modeling/roberta_nbt.py
index 36920c5c..f72d17fa 100644
--- a/convlab/dst/setsumbt/modeling/roberta_nbt.py
+++ b/convlab/dst/setsumbt/modeling/roberta_nbt.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,16 +16,19 @@
 """RoBERTa SetSUMBT"""
 
 import torch
-import transformers
-from torch.autograd import Variable
 from transformers import RobertaModel, RobertaPreTrainedModel
 
-from convlab.dst.setsumbt.modeling.functional import _initialise, _nbt_forward
+from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead
 
 
 class RobertaSetSUMBT(RobertaPreTrainedModel):
+    """Roberta based SetSUMBT model"""
 
     def __init__(self, config):
+        """
+        Args:
+            config (configuration): Model configuration class
+        """
         super(RobertaSetSUMBT, self).__init__(config)
         self.config = config
 
@@ -35,60 +38,37 @@ class RobertaSetSUMBT(RobertaPreTrainedModel):
             for p in self.roberta.parameters():
                 p.requires_grad = False
 
-        _initialise(self, config)
+        self.setsumbt = SetSUMBTHead(config)
+        self.add_slot_candidates = self.setsumbt.add_slot_candidates
+        self.add_value_candidates = self.setsumbt.add_value_candidates
     
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                hidden_state: torch.Tensor = None,
+                state_labels: torch.Tensor = None,
+                request_labels: torch.Tensor = None,
+                active_domain_labels: torch.Tensor = None,
+                general_act_labels: torch.Tensor = None,
+                get_turn_pooled_representation: bool = False,
+                calculate_state_mutual_info: bool = False):
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator
+            hidden_state: Latent internal dialogue belief state
+            state_labels: Dialogue state labels
+            request_labels: User request action labels
+            active_domain_labels: Current active domain labels
+            general_act_labels: General user action labels
+            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
+            calculate_state_mutual_info: Return mutual information in the dialogue state
 
-    # Add new slot candidates to the model
-    def add_slot_candidates(self, slot_candidates):
-        """slot_candidates is a list of tuples for each slot.
-        - The tuples contains the slot embedding, informable value embeddings and a request indicator.
-        - If the informable value embeddings is None the slot is not informable
-        - If the request indicator is false the slot is not requestable"""
-        if self.slot_embeddings.size(0) != 0:
-            embeddings = self.slot_embeddings.detach()
-        else:
-            embeddings = torch.zeros(0)
-
-        for slot in slot_candidates:
-            if slot in self.slot_ids:
-                index = self.slot_ids[slot]
-                embeddings[index, :] = slot_candidates[slot][0]
-            else:
-                index = embeddings.size(0)
-                emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device)
-                embeddings = torch.cat((embeddings, emb), 0)
-                self.slot_ids[slot] = index
-                setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False))
-            # Add slot to relevant requestable and informable slot lists
-            if slot_candidates[slot][2]:
-                self.requestable_slot_ids[slot] = index
-            if slot_candidates[slot][1] is not None:
-                self.informable_slot_ids[slot] = index
-            
-            domain = slot.split('-', 1)[0]
-            if domain not in self.domain_ids:
-                self.domain_ids[domain] = []
-            self.domain_ids[domain].append(index)
-            self.domain_ids[domain] = list(set(self.domain_ids[domain]))
-        
-        self.slot_embeddings = Variable(embeddings, requires_grad=False)
-
-
-    # Add new value candidates to the model
-    def add_value_candidates(self, slot, value_candidates, replace=False):
-        embeddings = getattr(self, slot + '_value_embeddings')
-
-        if embeddings.size(0) == 0 or replace:
-            embeddings = value_candidates
-        else:
-            embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0)
-        
-        setattr(self, slot + '_value_embeddings', embeddings)
-        
-    
-    def forward(self, input_ids, attention_mask, token_type_ids=None, hidden_state=None, inform_labels=None,
-                request_labels=None, domain_labels=None, goodbye_labels=None,
-                get_turn_pooled_representation=False, calculate_inform_mutual_info=False):
+        Returns:
+            out: Tuple containing loss, predictive distributions, model statistics and state mutual information
+        """
         if token_type_ids is not None:
             token_type_ids = None
 
@@ -106,9 +86,10 @@ class RobertaSetSUMBT(RobertaPreTrainedModel):
         turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
         
         if get_turn_pooled_representation:
-            return _nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, dialogue_size,
-                                turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels,
-                                calculate_inform_mutual_info) + (roberta_output.pooler_output,)
-        return _nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, dialogue_size,
-                            turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels,
-                            calculate_inform_mutual_info)
+            return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask,
+                                 batch_size, dialogue_size, hidden_state, state_labels,
+                                 request_labels, active_domain_labels, general_act_labels,
+                                 calculate_state_mutual_info) + (roberta_output.pooler_output,)
+        return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size,
+                             dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
+                             general_act_labels, calculate_state_mutual_info)
diff --git a/convlab/dst/setsumbt/modeling/setsumbt.py b/convlab/dst/setsumbt/modeling/setsumbt.py
new file mode 100644
index 00000000..0249649f
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/setsumbt.py
@@ -0,0 +1,564 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SetSUMBT Prediction Head"""
+
+import torch
+from torch.autograd import Variable
+from torch.nn import (Module, MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout,
+                      CosineSimilarity, CrossEntropyLoss, PairwiseDistance,
+                      Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss)
+from torch.nn.init import (xavier_normal_, constant_)
+
+from convlab.dst.setsumbt.loss import (BayesianMatchingLoss, BinaryBayesianMatchingLoss,
+                                       KLDistillationLoss, BinaryKLDistillationLoss,
+                                       LabelSmoothingLoss, BinaryLabelSmoothingLoss,
+                                       RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss)
+
+
+class SlotUtteranceMatching(Module):
+    """Slot Utterance matching attention based information extractor"""
+
+    def __init__(self, hidden_size: int = 768, attention_heads: int = 12):
+        """
+        Args:
+            hidden_size (int): Dimension of token embeddings
+            attention_heads (int): Number of attention heads to use in attention module
+        """
+        super(SlotUtteranceMatching, self).__init__()
+
+        self.attention = MultiheadAttention(hidden_size, attention_heads)
+
+    def forward(self,
+                turn_embeddings: torch.Tensor,
+                attention_mask: torch.Tensor,
+                slot_embeddings: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            turn_embeddings: Embeddings for each token in each turn [n_turns, turn_length, hidden_size]
+            attention_mask: Padding mask for each turn [n_turns, turn_length, hidden_size]
+            slot_embeddings: Embeddings for each token in the slot descriptions
+
+        Returns:
+            hidden: Information extracted from turn related to slot descriptions
+        """
+        turn_embeddings = turn_embeddings.transpose(0, 1)
+
+        key_padding_mask = (attention_mask[:, :, 0] == 0.0)
+        key_padding_mask[key_padding_mask[:, 0], :] = False
+
+        hidden, _ = self.attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings,
+                                   key_padding_mask=key_padding_mask)
+
+        attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1))
+        hidden = hidden * attention_mask
+
+        return hidden
+
+
+class RecurrentNeuralBeliefTracker(Module):
+    """Recurrent latent neural belief tracking module"""
+
+    def __init__(self,
+                 nbt_type: str = 'gru',
+                 rnn_zero_init: bool = False,
+                 input_size: int = 768,
+                 hidden_size: int = 300,
+                 hidden_layers: int = 1,
+                 dropout_rate: float = 0.3):
+        """
+        Args:
+            nbt_type: Type of recurrent neural network (gru/lstm)
+            rnn_zero_init: Use zero initialised state for the RNN
+            input_size: Embedding size of the inputs
+            hidden_size: Hidden size of the RNN
+            hidden_layers: Number of RNN Layers
+            dropout_rate: Dropout rate
+        """
+        super(RecurrentNeuralBeliefTracker, self).__init__()
+
+        if rnn_zero_init:
+            self.belief_init = Sequential(Linear(input_size, hidden_size), ReLU(), Dropout(dropout_rate))
+        else:
+            self.belief_init = None
+
+        self.nbt_type = nbt_type
+        self.hidden_layers = hidden_layers
+        self.hidden_size = hidden_size
+        if nbt_type == 'gru':
+            self.nbt = GRU(input_size=input_size,
+                           hidden_size=hidden_size,
+                           num_layers=hidden_layers,
+                           dropout=0.0 if hidden_layers == 1 else dropout_rate,
+                           batch_first=True)
+        elif nbt_type == 'lstm':
+            self.nbt = LSTM(input_size=input_size,
+                            hidden_size=hidden_size,
+                            num_layers=hidden_layers,
+                            dropout=0.0 if hidden_layers == 1 else dropout_rate,
+                            batch_first=True)
+        else:
+            raise NameError('Not Implemented')
+
+        # Initialise Parameters
+        xavier_normal_(self.nbt.weight_ih_l0)
+        xavier_normal_(self.nbt.weight_hh_l0)
+        constant_(self.nbt.bias_ih_l0, 0.0)
+        constant_(self.nbt.bias_hh_l0, 0.0)
+
+        # Intermediate feature mapping and layer normalisation
+        self.intermediate = Linear(hidden_size, input_size)
+        self.layer_norm = LayerNorm(input_size)
+        self.dropout = Dropout(dropout_rate)
+
+    def forward(self, inputs: torch.Tensor, hidden_state: torch.Tensor = None) -> torch.Tensor:
+        """
+        Args:
+            inputs: Latent turn level information
+            hidden_state: Latent internal belief state
+
+        Returns:
+            belief_embedding: Belief state embeddings
+            context: Latent internal belief state
+        """
+        self.nbt.flatten_parameters()
+        if hidden_state is None:
+            if self.belief_init is None:
+                context = torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device)
+            else:
+                context = self.belief_init(inputs[:, 0, :]).unsqueeze(0).repeat((self.hidden_layers, 1, 1))
+            if self.nbt_type == "lstm":
+                context = (context, torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device))
+        else:
+            context = hidden_state.to(inputs.device)
+
+        # [batch_size, dialogue_size, nbt_hidden_size]
+        belief_embedding, context = self.nbt(inputs, context)
+
+        # Normalisation and regularisation
+        belief_embedding = self.layer_norm(self.intermediate(belief_embedding))
+        belief_embedding = self.dropout(belief_embedding)
+
+        return belief_embedding, context
+
+
+class SetPooler(Module):
+    """Token set pooler"""
+
+    def __init__(self, pooling_strategy: str = 'cnn', hidden_size: int = 768):
+        """
+        Args:
+            pooling_strategy: Type of set pooler (cnn/dan/mean)
+            hidden_size: Token embedding size
+        """
+        super(SetPooler, self).__init__()
+
+        self.pooling_strategy = pooling_strategy
+        if pooling_strategy == 'cnn':
+            self.cnn_filter_size = 3
+            self.pooler = Conv1d(hidden_size, hidden_size, self.cnn_filter_size)
+        elif pooling_strategy == 'dan':
+            self.pooler = Sequential(Linear(hidden_size, hidden_size), GELU(), Linear(2 * hidden_size, hidden_size))
+
+    def forward(self, inputs, attention_mask):
+        """
+        Args:
+            inputs: Token set embeddings
+            attention_mask: Padding mask for the set of tokens
+
+        Returns:
+
+        """
+        if self.pooling_strategy == "mean":
+            hidden = inputs.sum(1) / attention_mask.sum(1)
+        elif self.pooling_strategy == "cnn":
+            hidden = self.pooler(inputs.transpose(1, 2)).mean(-1)
+        elif self.pooling_strategy == 'dan':
+            hidden = inputs.sum(1) / torch.sqrt(torch.tensor(attention_mask.sum(1)))
+            hidden = self.pooler(hidden)
+
+        return hidden
+
+
+class SetSUMBTHead(Module):
+    """SetSUMBT Prediction Head for Language Models"""
+
+    def __init__(self, config):
+        """
+        Args:
+            config (configuration): Model configuration class
+        """
+        super(SetSUMBTHead, self).__init__()
+        self.config = config
+        # Slot Utterance matching attention
+        self.slot_utterance_matching = SlotUtteranceMatching(config.hidden_size, config.slot_attention_heads)
+
+        # Latent context tracker
+        self.nbt = RecurrentNeuralBeliefTracker(config.nbt_type, config.rnn_zero_init, config.hidden_size,
+                                                config.nbt_hidden_size, config.nbt_layers, config.dropout_rate)
+
+        # Set pooler for set similarity model
+        if self.config.set_similarity:
+            self.set_pooler = SetPooler(config.set_pooling, config.hidden_size)
+
+        # Model ontology placeholders
+        self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False)
+        self.slot_ids = dict()
+        self.requestable_slot_ids = dict()
+        self.informable_slot_ids = dict()
+        self.domain_ids = dict()
+
+        # Matching network similarity measure
+        if config.distance_measure == 'cosine':
+            self.distance = CosineSimilarity(dim=-1, eps=1e-8)
+        elif config.distance_measure == 'euclidean':
+            self.distance = PairwiseDistance(p=2.0, eps=1e-6, keepdim=False)
+        else:
+            raise NameError('NotImplemented')
+
+        # User goal prediction loss function
+        if config.loss_function == 'crossentropy':
+            self.loss = CrossEntropyLoss(ignore_index=-1)
+        elif config.loss_function == 'bayesianmatching':
+            self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+        elif config.loss_function == 'labelsmoothing':
+            self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
+        elif config.loss_function == 'distillation':
+            self.loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
+            self.temp = 1.0
+        elif config.loss_function == 'distribution_distillation':
+            self.loss = RKLDirichletMediatorLoss(ignore_index=-1)
+        else:
+            raise NameError('NotImplemented')
+
+        # Intent and domain prediction heads
+        if config.predict_actions:
+            self.request_gate = Linear(config.hidden_size, 1)
+            self.general_act_gate = Linear(config.hidden_size, 3)
+            self.active_domain_gate = Linear(config.hidden_size, 1)
+
+            # Intent and domain loss function
+            self.request_weight = float(self.config.user_request_loss_weight)
+            self.general_act_weight = float(self.config.user_general_act_loss_weight)
+            self.active_domain_weight = float(self.config.active_domain_loss_weight)
+            if config.loss_function == 'crossentropy':
+                self.request_loss = BCEWithLogitsLoss()
+                self.general_act_loss = CrossEntropyLoss(ignore_index=-1)
+                self.active_domain_loss = BCEWithLogitsLoss()
+            elif config.loss_function == 'labelsmoothing':
+                self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
+                self.general_act_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
+                self.active_domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
+            elif config.loss_function == 'bayesianmatching':
+                self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+                self.general_act_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+                self.active_domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+            elif config.loss_function == 'distillation':
+                self.request_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
+                self.general_act_loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
+                self.active_domain_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
+            elif config.loss_function == 'distribution_distillation':
+                self.request_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1)
+                self.general_act_loss = RKLDirichletMediatorLoss(ignore_index=-1)
+                self.active_domain_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1)
+
+    def add_slot_candidates(self, slot_candidates: tuple):
+        """
+        Add slots to the model ontology, the tuples should contain the slot embedding, informable value embeddings
+        and a request indicator, if the informable value embeddings is None the slot is not informable and if
+        the request indicator is false the slot is not requestable.
+
+        Args:
+            slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator
+        """
+        if self.slot_embeddings.size(0) != 0:
+            embeddings = self.slot_embeddings.detach()
+        else:
+            embeddings = torch.zeros(0)
+
+        for slot in slot_candidates:
+            if slot in self.slot_ids:
+                index = self.slot_ids[slot]
+                embeddings[index, :] = slot_candidates[slot][0]
+            else:
+                index = embeddings.size(0)
+                emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device)
+                embeddings = torch.cat((embeddings, emb), 0)
+                self.slot_ids[slot] = index
+                setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False))
+            # Add slot to relevant requestable and informable slot lists
+            if slot_candidates[slot][2]:
+                self.requestable_slot_ids[slot] = index
+            if slot_candidates[slot][1] is not None:
+                self.informable_slot_ids[slot] = index
+
+            domain = slot.split('-', 1)[0]
+            if domain not in self.domain_ids:
+                self.domain_ids[domain] = []
+            self.domain_ids[domain].append(index)
+            self.domain_ids[domain] = list(set(self.domain_ids[domain]))
+
+        self.slot_embeddings = Variable(embeddings, requires_grad=False)
+
+    def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
+        """
+        Add value candidates for a slot
+
+        Args:
+            slot: Slot name
+            value_candidates: Value candidate embeddings
+            replace: If true existing value candidates are replaced
+        """
+        embeddings = getattr(self, slot + '_value_embeddings')
+
+        if embeddings.size(0) == 0 or replace:
+            embeddings = value_candidates
+        else:
+            embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0)
+
+        setattr(self, slot + '_value_embeddings', embeddings)
+
+    def forward(self,
+                turn_embeddings: torch.Tensor,
+                turn_pooled_representation: torch.Tensor,
+                attention_mask: torch.Tensor,
+                batch_size: int,
+                dialogue_size: int,
+                hidden_state: torch.Tensor = None,
+                state_labels: torch.Tensor = None,
+                request_labels: torch.Tensor = None,
+                active_domain_labels: torch.Tensor = None,
+                general_act_labels: torch.Tensor = None,
+                calculate_state_mutual_info: bool = False):
+        """
+        Args:
+            turn_embeddings: Token embeddings in the current turn
+            turn_pooled_representation: Pooled representation of the current dialogue turn
+            attention_mask: Padding mask for the current dialogue turn
+            batch_size: Number of dialogues in the batch
+            dialogue_size: Number of turns in each dialogue
+            hidden_state: Latent internal dialogue belief state
+            state_labels: Dialogue state labels
+            request_labels: User request action labels
+            active_domain_labels: Current active domain labels
+            general_act_labels: General user action labels
+            calculate_state_mutual_info: Return mutual information in the dialogue state
+
+        Returns:
+            out: Tuple containing loss, predictive distributions, model statistics and state mutual information
+        """
+        hidden_size = turn_embeddings.size(-1)
+        # Initialise loss
+        loss = 0.0
+
+        # General Action predictions
+        general_act_probs = None
+        if self.config.predict_actions:
+            # General action prediction
+            general_act_logits = self.general_act_gate(turn_pooled_representation.reshape(batch_size * dialogue_size,
+                                                                                          hidden_size))
+
+            # Compute loss for general action predictions (weighted loss)
+            if general_act_labels is not None:
+                if self.config.loss_function == 'distillation':
+                    general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-1))
+                    loss += self.general_act_loss(general_act_logits, general_act_labels,
+                                                  self.temp) * self.general_act_weight
+                elif self.config.loss_function == 'distribution_distillation':
+                    general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-2),
+                                                                    general_act_labels.size(-1))
+                    loss += self.general_act_loss(general_act_logits, general_act_labels)[0] * self.general_act_weight
+                else:
+                    general_act_labels = general_act_labels.reshape(-1)
+                    loss += self.general_act_loss(general_act_logits, general_act_labels) * self.general_act_weight
+
+            # Compute general action probabilities
+            general_act_probs = torch.softmax(general_act_logits, -1).reshape(batch_size, dialogue_size, -1)
+
+        # Slot utterance matching
+        num_slots = self.slot_embeddings.size(0)
+        slot_embeddings = self.slot_embeddings.reshape(-1, hidden_size)
+        slot_embeddings = slot_embeddings.unsqueeze(1).repeat((1, batch_size * dialogue_size, 1))
+        slot_embeddings = slot_embeddings.to(turn_embeddings.device)
+
+        if self.config.set_similarity:
+            # Slot mask shape [num_slots * slot_len, batch_size * dialogue_size, 768]
+            slot_mask = (slot_embeddings != 0.0).float()
+
+        hidden = self.slot_utterance_matching(turn_embeddings, attention_mask, slot_embeddings)
+
+        if self.config.set_similarity:
+            hidden = hidden * slot_mask
+        # Hidden layer shape [num_dials, num_slots, num_turns, 768]
+        hidden = hidden.transpose(0, 1).reshape(batch_size, dialogue_size, slot_embeddings.size(0), -1).transpose(1, 2)
+
+        # Latent context tracking
+        # [batch_size * num_slots, dialogue_size, 768]
+        hidden = hidden.reshape(batch_size * slot_embeddings.size(0), dialogue_size, -1)
+        belief_embedding, hidden_state = self.nbt(hidden, hidden_state)
+
+        belief_embedding = belief_embedding.reshape(batch_size, slot_embeddings.size(0),
+                                                    dialogue_size, -1).transpose(1, 2)
+        if self.config.set_similarity:
+            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1,
+                                                        self.config.hidden_size)
+        # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768]
+
+        # Pooling of the set of latent context representation
+        if self.config.set_similarity:
+            slot_mask = slot_mask.transpose(0, 1).reshape(batch_size, dialogue_size, num_slots, -1, hidden_size)
+            belief_embedding = belief_embedding * slot_mask
+
+            belief_embedding = self.set_pooler(belief_embedding.reshape(-1, slot_mask.size(-2), hidden_size),
+                                               slot_mask.reshape(-1, slot_mask.size(-2), hidden_size))
+            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1)
+
+        # Perform classification
+        # Get padded batch, dialogue idx pairs
+        batches, dialogues = torch.where(attention_mask[:, 0, 0].reshape(batch_size, dialogue_size) == 0.0)
+        
+        if self.config.predict_actions:
+            # User request prediction
+            request_probs = dict()
+            for slot, slot_id in self.requestable_slot_ids.items():
+                request_logits = self.request_gate(belief_embedding[:, :, slot_id, :])
+
+                # Store output probabilities
+                request_logits = request_logits.reshape(batch_size, dialogue_size)
+                # Set request scores to 0.0 for padded turns
+                request_logits[batches, dialogues] = 0.0
+                request_probs[slot] = torch.sigmoid(request_logits)
+
+                if request_labels is not None:
+                    # Compute request gate loss
+                    request_logits = request_logits.reshape(-1)
+                    if self.config.loss_function == 'distillation':
+                        loss += self.request_loss(request_logits, request_labels[slot].reshape(-1),
+                                                  self.temp) * self.request_weight
+                    elif self.config.loss_function == 'distribution_distillation':
+                        loss += self.request_loss(request_logits, request_labels[slot])[0] * self.request_weight
+                    else:
+                        labs = request_labels[slot].reshape(-1)
+                        request_logits = request_logits[labs != -1]
+                        labs = labs[labs != -1].float()
+                        loss += self.request_loss(request_logits, labs) * self.request_weight
+
+            # Active domain prediction
+            active_domain_probs = dict()
+            for domain, slot_ids in self.domain_ids.items():
+                belief = belief_embedding[:, :, slot_ids, :]
+                if len(slot_ids) > 1:
+                    # SqrtN reduction across all slots within a domain
+                    belief = belief.sum(2) / ((belief != 0.0).float().sum(2) ** 0.5)
+                active_domain_logits = self.active_domain_gate(belief)
+
+                # Store output probabilities
+                active_domain_logits = active_domain_logits.reshape(batch_size, dialogue_size)
+                active_domain_logits[batches, dialogues] = 0.0
+                active_domain_probs[domain] = torch.sigmoid(active_domain_logits)
+
+                if active_domain_labels is not None and domain in active_domain_labels:
+                    # Compute domain prediction loss
+                    active_domain_logits = active_domain_logits.reshape(-1)
+                    if self.config.loss_function == 'distillation':
+                        loss += self.active_domain_loss(active_domain_logits, active_domain_labels[domain].reshape(-1),
+                                                        self.temp) * self.active_domain_weight
+                    elif self.config.loss_function == 'distribution_distillation':
+                        loss += self.active_domain_loss(active_domain_logits,
+                                                        active_domain_labels[domain])[0] * self.active_domain_weight
+                    else:
+                        labs = active_domain_labels[domain].reshape(-1)
+                        active_domain_logits = active_domain_logits[labs != -1]
+                        labs = labs[labs != -1].float()
+                        loss += self.active_domain_loss(active_domain_logits, labs) * self.active_domain_weight
+        else:
+            request_probs, active_domain_probs = None, None
+
+        # Dialogue state predictions
+        belief_state_probs = dict()
+        belief_state_mutual_info = dict()
+        belief_state_stats = dict()
+        for slot, slot_id in self.informable_slot_ids.items():
+            # Get slot belief embedding and value candidates
+            candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device)
+            belief = belief_embedding[:, :, slot_id, :]
+            slot_size = candidate_embeddings.size(0)
+
+            belief = belief.unsqueeze(2).repeat((1, 1, slot_size, 1))
+            belief = belief.reshape(-1, self.config.hidden_size)
+
+            if self.config.set_similarity:
+                candidate_embeddings = self.set_pooler(candidate_embeddings, (candidate_embeddings != 0.0).float())
+            candidate_embeddings = candidate_embeddings.unsqueeze(0).unsqueeze(0).repeat((batch_size,
+                                                                                          dialogue_size, 1, 1))
+            candidate_embeddings = candidate_embeddings.reshape(-1, self.config.hidden_size)
+
+            # Score value candidates
+            if self.config.distance_measure == 'cosine':
+                logits = self.distance(belief, candidate_embeddings)
+                # *27 here rescales the cosine similarity for better learning
+                logits = logits.reshape(batch_size * dialogue_size, -1) * 27.0
+            elif self.config.distance_measure == 'euclidean':
+                logits = -1.0 * self.distance(belief, candidate_embeddings)
+                logits = logits.reshape(batch_size * dialogue_size, -1)
+
+            # Calculate belief state
+            probs_ = torch.softmax(logits.reshape(batch_size, dialogue_size, -1), -1)
+
+            # Compute knowledge uncertainty in the beleif states
+            if calculate_state_mutual_info and self.config.loss_function == 'distribution_distillation':
+                belief_state_mutual_info[slot] = self.loss.logits_to_mutual_info(logits).reshape(batch_size, dialogue_size)
+
+            # Set padded turn probabilities to zero
+            probs_[batches, dialogues, :] = 0.0
+            belief_state_probs[slot] = probs_
+
+            # Calculate belief state loss
+            if state_labels is not None and slot in state_labels:
+                if self.config.loss_function == 'bayesianmatching':
+                    prior = torch.ones(logits.size(-1)).float().to(logits.device)
+                    prior = prior * self.config.prior_constant
+                    prior = prior.unsqueeze(0).repeat((logits.size(0), 1))
+
+                    loss += self.loss(logits, state_labels[slot].reshape(-1), prior=prior)
+                elif self.config.loss_function == 'distillation':
+                    labels = state_labels[slot]
+                    labels = labels.reshape(-1, labels.size(-1))
+                    loss += self.loss(logits, labels, self.temp)
+                elif self.config.loss_function == 'distribution_distillation':
+                    labels = state_labels[slot]
+                    labels = labels.reshape(-1, labels.size(-2), labels.size(-1))
+                    loss_, model_stats, ensemble_stats = self.loss(logits, labels)
+                    loss += loss_
+
+                    # Calculate stats regarding model precisions
+                    precision = model_stats['precision']
+                    ensemble_precision = ensemble_stats['precision']
+                    belief_state_stats[slot] = {'model_precision_min': precision.min(),
+                                                'model_precision_max': precision.max(),
+                                                'model_precision_mean': precision.mean(),
+                                                'ensemble_precision_min': ensemble_precision.min(),
+                                                'ensemble_precision_max': ensemble_precision.max(),
+                                                'ensemble_precision_mean': ensemble_precision.mean()}
+                else:
+                    loss += self.loss(logits, state_labels[slot].reshape(-1))
+
+        # Return model outputs
+        out = belief_state_probs, request_probs, active_domain_probs, general_act_probs, hidden_state
+        if state_labels is not None or request_labels is not None:
+            out = (loss,) + out + (belief_state_stats,)
+        if calculate_state_mutual_info:
+            out = out + (belief_state_mutual_info,)
+        return out
diff --git a/convlab/dst/setsumbt/modeling/temperature_scheduler.py b/convlab/dst/setsumbt/modeling/temperature_scheduler.py
index fab205be..654e83c5 100644
--- a/convlab/dst/setsumbt/modeling/temperature_scheduler.py
+++ b/convlab/dst/setsumbt/modeling/temperature_scheduler.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,50 +13,70 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Temperature Scheduler Class"""
-import torch
+"""Linear Temperature Scheduler Class"""
+
 
 # Temp scheduler class for ensemble distillation
-class TemperatureScheduler:
+class LinearTemperatureScheduler:
+    """
+    Temperature scheduler object used for distribution temperature scheduling in distillation
 
-    def __init__(self, total_steps, base_temp=2.5, cycle_len=0.1):
-        self.state = {}
+    Attributes:
+        state (dict): Internal state of scheduler
+    """
+    def __init__(self,
+                 total_steps: int,
+                 base_temp: float = 2.5,
+                 cycle_len: float = 0.1):
+        """
+        Args:
+            total_steps (int): Total number of training steps
+            base_temp (float): Starting temperature
+            cycle_len (float): Fraction of total steps used for scheduling cycle
+        """
+        self.state = dict()
         self.state['total_steps'] = total_steps
         self.state['current_step'] = 0
         self.state['base_temp'] = base_temp
         self.state['current_temp'] = base_temp
         self.state['cycles'] = [int(total_steps * cycle_len / 2), int(total_steps * cycle_len)]
+        self.state['rate'] = (self.state['base_temp'] - 1.0) / (self.state['cycles'][1] - self.state['cycles'][0])
     
     def step(self):
+        """
+        Update temperature based on the schedule
+        """
         self.state['current_step'] += 1
         assert self.state['current_step'] <= self.state['total_steps']
         if self.state['current_step'] > self.state['cycles'][0]:
             if self.state['current_step'] < self.state['cycles'][1]:
-                rate = (self.state['base_temp'] - 1.0) / (self.state['cycles'][1] - self.state['cycles'][0])
-                self.state['current_temp'] -= rate
+                self.state['current_temp'] -= self.state['rate']
             else:
                 self.state['current_temp'] = 1.0
     
     def temp(self):
+        """
+        Get current temperature
+
+        Returns:
+            temp (float): Current temperature for distribution scaling
+        """
         return float(self.state['current_temp'])
     
     def state_dict(self):
-        return self.state
-    
-    def load_state_dict(self, sd):
-        self.state = sd
+        """
+        Return scheduler state
 
-
-# if __name__ == "__main__":
-#     temp_scheduler = TemperatureScheduler(100)
-#     print(temp_scheduler.state_dict())
-
-#     temp = []
-#     for i in range(100):
-#         temp.append(temp_scheduler.temp())
-#         temp_scheduler.step()
+        Returns:
+            state (dict): Dictionary format state of the scheduler
+        """
+        return self.state
     
-#     temp_scheduler.load_state_dict(temp_scheduler.state_dict())
-#     print(temp_scheduler.state_dict())
+    def load_state_dict(self, state_dict: dict):
+        """
+        Load scheduler state from dictionary
 
-#     print(temp)
+        Args:
+            state_dict (dict): Dictionary format state of the scheduler
+        """
+        self.state = state_dict
diff --git a/convlab/dst/setsumbt/modeling/training.py b/convlab/dst/setsumbt/modeling/training.py
index 259c6e1d..77f41dc3 100644
--- a/convlab/dst/setsumbt/modeling/training.py
+++ b/convlab/dst/setsumbt/modeling/training.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,17 +13,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Training utils"""
+"""Training and evaluation utils"""
 
 import random
 import os
 import logging
+from copy import deepcopy
 
 import torch
 from torch.nn import DataParallel
 from torch.distributions import Categorical
 import numpy as np
-from transformers import AdamW, get_linear_schedule_with_warmup
+from transformers import get_linear_schedule_with_warmup
+from torch.optim import AdamW
 from tqdm import tqdm, trange
 try:
     from apex import amp
@@ -31,7 +33,7 @@ except:
     print('Apex not used')
 
 from convlab.dst.setsumbt.utils import clear_checkpoints
-from convlab.dst.setsumbt.modeling.temperature_scheduler import TemperatureScheduler
+from convlab.dst.setsumbt.modeling import LinearTemperatureScheduler
 
 
 # Load logger and tensorboard summary writer
@@ -59,18 +61,131 @@ def set_ontology_embeddings(model, slots, load_slots=True):
     if load_slots:
         slots = {slot: embs for slot, embs in slots.items()}
         model.add_slot_candidates(slots)
-    for slot in model.informable_slot_ids:
+    try:
+        informable_slot_ids = model.setsumbt.informable_slot_ids
+    except:
+        informable_slot_ids = model.informable_slot_ids
+    for slot in informable_slot_ids:
         model.add_value_candidates(slot, values[slot], replace=True)
 
 
-def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_dev, embeddings=None, tokenizer=None):
-    """Train model!"""
+def log_info(global_step, loss, jg_acc=None, sl_acc=None, req_f1=None, dom_f1=None, gen_f1=None, stats=None):
+    """
+    Log training statistics.
+
+    Args:
+        global_step: Number of global training steps completed
+        loss: Training loss
+        jg_acc: Joint goal accuracy
+        sl_acc: Slot accuracy
+        req_f1: Request prediction F1 score
+        dom_f1: Active domain prediction F1 score
+        gen_f1: General action prediction F1 score
+        stats: Uncertainty measure statistics of model
+    """
+    if type(global_step) == int:
+        info = f"{global_step} steps complete, "
+        info += f"Loss since last update: {loss}. Validation set stats: "
+    elif global_step == 'training_complete':
+        info = f"Training Complete, "
+        info += f"Validation set stats: "
+    elif global_step == 'dev':
+        info = f"Validation set stats: Loss: {loss}, "
+    elif global_step == 'test':
+        info = f"Test set stats: Loss: {loss}, "
+    info += f"Joint Goal Acc: {jg_acc}, Slot Acc: {sl_acc}, "
+    if req_f1 is not None:
+        info += f"Request F1 Score: {req_f1}, Active Domain F1 Score: {dom_f1}, "
+        info += f"General Action F1 Score: {gen_f1}"
+    logger.info(info)
+
+    if type(global_step) == int:
+        tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
+        tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)
+        if req_f1 is not None:
+            tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step)
+            tb_writer.add_scalar('ActiveDomainF1Score/Dev', dom_f1, global_step)
+            tb_writer.add_scalar('GeneralActionF1Score/Dev', gen_f1, global_step)
+        tb_writer.add_scalar('Loss/Dev', loss, global_step)
+
+        if stats:
+            for slot, stats_slot in stats.items():
+                for key, item in stats_slot.items():
+                    tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step)
+
+
+def get_input_dict(batch: dict,
+                   predict_actions: bool,
+                   model_informable_slot_ids: list,
+                   model_requestable_slot_ids: list = None,
+                   model_domain_ids: list = None,
+                   device = 'cpu') -> dict:
+    """
+    Produce model input arguments
+
+    Args:
+        batch: Batch of data from the dataloader
+        predict_actions: Model should predict user actions if set true
+        model_informable_slot_ids: List of model dialogue state slots
+        model_requestable_slot_ids: List of model requestable slots
+        model_domain_ids: List of model domains
+        device: Current torch device in use
+
+    Returns:
+        input_dict: Dictrionary containing model inputs for the batch
+    """
+    input_dict = dict()
+
+    input_dict['input_ids'] = batch['input_ids'].to(device)
+    input_dict['token_type_ids'] = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
+    input_dict['attention_mask'] = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
+
+    if any('belief_state' in key for key in batch):
+        input_dict['state_labels'] = {slot: batch['belief_state-' + slot].to(device)
+                                      for slot in model_informable_slot_ids
+                                      if ('belief_state-' + slot) in batch}
+        if predict_actions:
+            input_dict['request_labels'] = {slot: batch['request_probs-' + slot].to(device)
+                                            for slot in model_requestable_slot_ids
+                                            if ('request_probs-' + slot) in batch}
+            input_dict['active_domain_labels'] = {domain: batch['active_domain_probs-' + domain].to(device)
+                                                  for domain in model_domain_ids
+                                                  if ('active_domain_probs-' + domain) in batch}
+            input_dict['general_act_labels'] = batch['general_act_probs'].to(device)
+    else:
+        input_dict['state_labels'] = {slot: batch['state_labels-' + slot].to(device)
+                                      for slot in model_informable_slot_ids if ('state_labels-' + slot) in batch}
+        if predict_actions:
+            input_dict['request_labels'] = {slot: batch['request_labels-' + slot].to(device)
+                                            for slot in model_requestable_slot_ids
+                                            if ('request_labels-' + slot) in batch}
+            input_dict['active_domain_labels'] = {domain: batch['active_domain_labels-' + domain].to(device)
+                                                  for domain in model_domain_ids
+                                                  if ('active_domain_labels-' + domain) in batch}
+            input_dict['general_act_labels'] = batch['general_act_labels'].to(device)
+
+    return input_dict
+
+
+def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, slots_dev: dict):
+    """
+    Train the SetSUMBT model.
+
+    Args:
+        args: Runtime arguments
+        model: SetSUMBT Model instance to train
+        device: Torch device to use during training
+        train_dataloader: Dataloader containing the training data
+        dev_dataloader: Dataloader containing the validation set data
+        slots: Model ontology used for training
+        slots_dev: Model ontology used for evaluating on the validation set
+    """
 
     # Calculate the total number of training steps to be performed
     if args.max_training_steps > 0:
         t_total = args.max_training_steps
-        args.num_train_epochs = args.max_training_steps // (
-            (len(train_dataloader) // args.gradient_accumulation_steps) + 1)
+        args.num_train_epochs = (len(train_dataloader) // args.gradient_accumulation_steps) + 1
+        args.num_train_epochs = args.max_training_steps // args.num_train_epochs
     else:
         t_total = (len(train_dataloader) // args.gradient_accumulation_steps) * args.num_train_epochs
 
@@ -88,12 +203,12 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
         {
             "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0,
-            "lr":args.learning_rate
+            "lr": args.learning_rate
         },
     ]
 
     # Initialise the optimizer
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
+    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 
     # Initialise linear lr scheduler
     num_warmup_steps = int(t_total * args.warmup_proportion)
@@ -109,8 +224,7 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
 
     # Set up fp16 and multi gpu usage
     if args.fp16:
-        model, optimizer = amp.initialize(
-            model, optimizer, opt_level=args.fp16_opt_level)
+        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
     if args.n_gpu > 1:
         model = DataParallel(model)
 
@@ -118,7 +232,7 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
     best_model = {'joint goal accuracy': 0.0,
                   'request f1 score': 0.0,
                   'active domain f1 score': 0.0,
-                  'goodbye act f1 score': 0.0,
+                  'general act f1 score': 0.0,
                   'train loss': np.inf}
     if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')):
         logger.info("Optimizer loaded from previous run.")
@@ -136,27 +250,27 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
             model.eval()
             set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
 
-            jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader)
+            jg_acc, sl_acc, req_f1, dom_f1, gen_f1, _, _ = evaluate(args, model, device, dev_dataloader, is_train=True)
 
             # Set model back to training mode
             model.train()
             model.zero_grad()
             set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False)
         else:
-            jg_acc, req_f1, dom_f1, bye_f1 = 0.0, 0.0, 0.0, 0.0
+            jg_acc, req_f1, dom_f1, gen_f1 = 0.0, 0.0, 0.0, 0.0
 
         best_model['joint goal accuracy'] = jg_acc
         best_model['request f1 score'] = req_f1
         best_model['active domain f1 score'] = dom_f1
-        best_model['goodbye act f1 score'] = bye_f1
+        best_model['general act f1 score'] = gen_f1
 
     # Log training set up
-    logger.info("Device: %s, Number of GPUs: %s, FP16 training: %s" % (device, args.n_gpu, args.fp16))
+    logger.info(f"Device: {device}, Number of GPUs: {args.n_gpu}, FP16 training: {args.fp16}")
     logger.info("***** Running training *****")
-    logger.info("  Num Batches = %d" % len(train_dataloader))
-    logger.info("  Num Epochs = %d" % args.num_train_epochs)
-    logger.info("  Gradient Accumulation steps = %d" % args.gradient_accumulation_steps)
-    logger.info("  Total optimization steps = %d" % t_total)
+    logger.info(f"  Num Batches = {len(train_dataloader)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {t_total}")
 
     # Initialise training parameters
     global_step = 0
@@ -173,11 +287,11 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
             steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
 
             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info("  Continuing training from epoch %d" % epochs_trained)
-            logger.info("  Continuing training from global step %d" % global_step)
-            logger.info("  Will skip the first %d steps in the first epoch" % steps_trained_in_current_epoch)
+            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            logger.info(f"  Continuing training from global step {global_step}")
+            logger.info(f"  Will skip the first {steps_trained_in_current_epoch} steps in the first epoch")
         except ValueError:
-            logger.info("  Starting fine-tuning.")
+            logger.info(f"  Starting fine-tuning.")
 
     # Prepare model for training
     tr_loss, logging_loss = 0.0, 0.0
@@ -196,43 +310,15 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
                 continue
 
             # Extract all label dictionaries from the batch
-            if 'goodbye_belief' in batch:
-                labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('belief-' + slot) in batch}
-                request_labels = {slot: batch['request_belief-' + slot].to(device)
-                                  for slot in model.requestable_slot_ids
-                                  if ('request_belief-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids
-                                 if ('domain_belief-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye_belief'].to(
-                    device) if args.predict_actions else None
-            else:
-                labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('labels-' + slot) in batch}
-                request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids
-                                  if ('request-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids
-                                 if ('active-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye'].to(
-                    device) if args.predict_actions else None
-
-            # Extract all model inputs from batch
-            input_ids = batch['input_ids'].to(device)
-            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
-            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
+            input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids,
+                                        model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device)
 
             # Set up temperature scaling for the model
             if temp_scheduler is not None:
-                model.temp = temp_scheduler.temp()
+                model.setsumbt.temp = temp_scheduler.temp()
 
             # Forward pass to obtain loss
-            loss, _, _, _, _, _, stats = model(input_ids=input_ids,
-                                               token_type_ids=token_type_ids,
-                                               attention_mask=attention_mask,
-                                               inform_labels=labels,
-                                               request_labels=request_labels,
-                                               domain_labels=domain_labels,
-                                               goodbye_labels=goodbye_labels)
+            loss, _, _, _, _, _, stats = model(**input_dict)
 
             if args.n_gpu > 1:
                 loss = loss.mean()
@@ -258,7 +344,6 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
                 tb_writer.add_scalar('LearningRate', lr, global_step)
 
                 if stats:
-                    # print(stats.keys())
                     for slot, stats_slot in stats.items():
                         for key, item in stats_slot.items():
                             tb_writer.add_scalar(f'{key}_{slot}/Train', item, global_step)
@@ -273,7 +358,6 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
 
                 tr_loss += loss.float().item()
                 epoch_iterator.set_postfix(loss=loss.float().item())
-                loss = 0.0
                 global_step += 1
 
             # Save model checkpoint
@@ -286,52 +370,34 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
                     model.eval()
                     set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
 
-                    jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader)
+                    jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader,
+                                                                                   is_train=True)
                     # Log model eval information
-                    if req_f1 is not None:
-                        logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f, Dev Request F1 Score = %f, Dev Domain F1 Score = %f, Dev Goodbye F1 Score = %f'
-                                    % (global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
-                        tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
-                        tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)
-                        tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step)
-                        tb_writer.add_scalar('DomainF1Score/Dev', dom_f1, global_step)
-                        tb_writer.add_scalar('GoodbyeF1Score/Dev', bye_f1, global_step)
-                    else:
-                        logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f'
-                                    % (global_step, logging_loss / args.save_steps, jg_acc, sl_acc))
-                        tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
-                        tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)
-                    tb_writer.add_scalar('Loss/Dev', loss, global_step)
-                    if stats:
-                        for slot, stats_slot in stats.items():
-                            for key, item in stats_slot.items():
-                                tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step)
+                    log_info(global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, gen_f1, stats)
 
                     # Set model back to training mode
                     model.train()
                     model.zero_grad()
                     set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False)
                 else:
-                    jg_acc, req_f1 = 0.0, None
-                    logger.info('%i steps complete, Loss since last update = %f' % (global_step, logging_loss / args.save_steps))
+                    log_info(global_step, logging_loss / args.save_steps)
 
                 logging_loss = tr_loss
 
                 # Compute the score of the best model
                 try:
-                    best_score = (best_model['request f1 score'] * model.config.user_request_loss_weight) + \
-                        (best_model['active domain f1 score'] * model.config.active_domain_loss_weight) + \
-                        (best_model['goodbye act f1 score'] *
-                         model.config.user_general_act_loss_weight)
+                    best_score = best_model['request f1 score'] * model.config.user_request_loss_weight
+                    best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight
+                    best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight
                 except AttributeError:
                     best_score = 0.0
                 best_score += best_model['joint goal accuracy']
 
                 # Compute the score of the current model
                 try:
-                    current_score = (req_f1 * model.config.user_request_loss_weight) + \
-                        (dom_f1 * model.config.active_domain_loss_weight) + \
-                        (bye_f1 * model.config.user_general_act_loss_weight) if req_f1 is not None else 0.0
+                    current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0
+                    current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0
+                    current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0
                 except AttributeError:
                     current_score = 0.0
                 current_score += jg_acc
@@ -353,10 +419,10 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
                     if req_f1:
                         best_model['request f1 score'] = req_f1
                         best_model['active domain f1 score'] = dom_f1
-                        best_model['goodbye act f1 score'] = bye_f1
+                        best_model['general act f1 score'] = gen_f1
                     best_model['train loss'] = tr_loss / global_step
 
-                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
+                    output_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                     if not os.path.exists(output_dir):
                         os.makedirs(output_dir, exist_ok=True)
 
@@ -386,14 +452,15 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
                 epoch_iterator.close()
                 break
 
-        logger.info('Epoch %i complete, average training loss = %f' % (e + 1, tr_loss / global_step))
+        steps_trained_in_current_epoch = 0
+        logger.info(f'Epoch {e + 1} complete, average training loss = {tr_loss / global_step}')
 
         if args.max_training_steps > 0 and global_step > args.max_training_steps:
             train_iterator.close()
             break
         if args.patience > 0 and steps_since_last_update >= args.patience:
             train_iterator.close()
-            logger.info('Model has not improved for at least %i steps. Training stopped!' % args.patience)
+            logger.info(f'Model has not improved for at least {args.patience} steps. Training stopped!')
             break
 
     # Evaluate final model
@@ -401,30 +468,25 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
         model.eval()
         set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
 
-        jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader)
-        if req_f1 is not None:
-            logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f, Dev Request F1 Score = %f, Dev Domain F1 Score = %f, Dev Goodbye F1 Score = %f'
-                        % (tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
-        else:
-            logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f'
-                        % (tr_loss / global_step, jg_acc, sl_acc))
+        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader,
+                                                                       is_train=True)
+
+        log_info('training_complete', tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
     else:
-        jg_acc = 0.0
         logger.info('Training complete!')
 
     # Store final model
     try:
-        best_score = (best_model['request f1 score'] * model.config.user_request_loss_weight) + \
-            (best_model['active domain f1 score'] * model.config.active_domain_loss_weight) + \
-            (best_model['goodbye act f1 score'] *
-             model.config.user_general_act_loss_weight)
+        best_score = best_model['request f1 score'] * model.config.user_request_loss_weight
+        best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight
+        best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight
     except AttributeError:
         best_score = 0.0
     best_score += best_model['joint goal accuracy']
     try:
-        current_score = (req_f1 * model.config.user_request_loss_weight) + \
-                        (dom_f1 * model.config.active_domain_loss_weight) + \
-                        (bye_f1 * model.config.user_general_act_loss_weight) if req_f1 is not None else 0.0
+        current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0
+        current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0
+        current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0
     except AttributeError:
         current_score = 0.0
     current_score += jg_acc
@@ -456,225 +518,89 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
             torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt"))
         clear_checkpoints(args.output_dir)
     else:
-        logger.info(
-            'Final model not saved, since it is not the best performing model.')
+        logger.info('Final model not saved, as it is not the best performing model.')
 
 
-# Function for validation
-def train_eval(args, model, device, dev_dataloader):
-    """Evaluate Model during training!"""
-    accuracy_jg = []
-    accuracy_sl = []
-    accuracy_req = []
-    truepos_req, falsepos_req, falseneg_req = [], [], []
-    truepos_dom, falsepos_dom, falseneg_dom = [], [], []
-    truepos_bye, falsepos_bye, falseneg_bye = [], [], []
-    accuracy_dom = []
-    accuracy_bye = []
-    turns = []
-    for batch in dev_dataloader:
-        # Perform with no gradients stored
-        with torch.no_grad():
-            if 'goodbye_belief' in batch:
-                labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('belief-' + slot) in batch}
-                request_labels = {slot: batch['request_belief-' + slot].to(device) for slot in model.requestable_slot_ids
-                                  if ('request_belief-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids
-                                 if ('domain_belief-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye_belief'].to(
-                    device) if args.predict_actions else None
-            else:
-                labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('labels-' + slot) in batch}
-                request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids
-                                  if ('request-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids
-                                 if ('active-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye'].to(
-                    device) if args.predict_actions else None
-
-            input_ids = batch['input_ids'].to(device)
-            token_type_ids = batch['token_type_ids'].to(
-                device) if 'token_type_ids' in batch else None
-            attention_mask = batch['attention_mask'].to(
-                device) if 'attention_mask' in batch else None
-
-            loss, p, p_req, p_dom, p_bye, _, stats = model(input_ids=input_ids,
-                                                           token_type_ids=token_type_ids,
-                                                           attention_mask=attention_mask,
-                                                           inform_labels=labels,
-                                                           request_labels=request_labels,
-                                                           domain_labels=domain_labels,
-                                                           goodbye_labels=goodbye_labels)
+def evaluate(args, model, device, dataloader, return_eval_output=False, is_train=False):
+    """
+    Evaluate model
 
-        jg_acc = 0.0
-        req_acc = 0.0
-        req_tp, req_fp, req_fn = 0.0, 0.0, 0.0
-        dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0
-        dom_acc = 0.0
-        for slot in model.informable_slot_ids:
-            labels = batch['labels-' + slot].to(device)
-            p_ = p[slot]
-
-            acc = (p_.argmax(-1) == labels).reshape(-1).float()
-            jg_acc += acc
-
-        if model.config.predict_actions:
-            for slot in model.requestable_slot_ids:
-                p_req_ = p_req[slot]
-                request_labels = batch['request-' + slot].to(device)
+    Args:
+        args: Runtime arguments
+        model: SetSUMBT model instance
+        device: Torch device in use
+        dataloader: Dataloader of data to evaluate on
+        return_eval_output: If true return predicted and true states for all dialogues evaluated in semantic format
+        is_train: If true model is training and no logging is performed
 
-                acc = (p_req_.round().int() == request_labels).reshape(-1).float()
-                tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float()
-                fp = (p_req_.round().int() * (request_labels == 0)).reshape(-1).float()
-                fn = ((1 - p_req_.round().int()) * (request_labels == 1)).reshape(-1).float()
-                req_acc += acc
-                req_tp += tp
-                req_fp += fp
-                req_fn += fn
-
-            for domain in model.domain_ids:
-                p_dom_ = p_dom[domain]
-                domain_labels = batch['active-' + domain].to(device)
-
-                acc = (p_dom_.round().int() == domain_labels).reshape(-1).float()
-                tp = (p_dom_.round().int() * (domain_labels == 1)).reshape(-1).float()
-                fp = (p_dom_.round().int() * (domain_labels == 0)).reshape(-1).float()
-                fn = ((1 - p_dom_.round().int()) * (domain_labels == 1)).reshape(-1).float()
-                dom_acc += acc
-                dom_tp += tp
-                dom_fp += fp
-                dom_fn += fn
-
-            goodbye_labels = batch['goodbye'].to(device)
-            bye_acc = (p_bye.argmax(-1) == goodbye_labels).reshape(-1).float().sum()
-            bye_tp = ((p_bye.argmax(-1) > 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
-            bye_fp = ((p_bye.argmax(-1) > 0) * (goodbye_labels == 0)).reshape(-1).float().sum()
-            bye_fn = ((p_bye.argmax(-1) == 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
-        else:
-            req_acc, dom_acc, bye_acc = None, None, torch.tensor(0.0)
-            req_tp, req_fp, req_fn = None, None, None
-            dom_tp, dom_fp, dom_fn = None, None, None
-            bye_tp, bye_fp, bye_fn = torch.tensor(
-                0.0), torch.tensor(0.0), torch.tensor(0.0)
-
-        sl_acc = sum(jg_acc / len(model.informable_slot_ids)).float()
-        jg_acc = sum((jg_acc / len(model.informable_slot_ids)).int()).float()
-        req_acc = sum(req_acc / len(model.requestable_slot_ids)).float() if req_acc is not None else torch.tensor(0.0)
-        req_tp = sum(req_tp / len(model.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0)
-        req_fp = sum(req_fp / len(model.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0)
-        req_fn = sum(req_fn / len(model.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0)
-        dom_tp = sum(dom_tp / len(model.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0)
-        dom_fp = sum(dom_fp / len(model.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0)
-        dom_fn = sum(dom_fn / len(model.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0)
-        dom_acc = sum(dom_acc / len(model.domain_ids)).float() if dom_acc is not None else torch.tensor(0.0)
-        n_turns = (labels >= 0).reshape(-1).sum().float().item()
-
-        accuracy_jg.append(jg_acc.item())
-        accuracy_sl.append(sl_acc.item())
-        accuracy_req.append(req_acc.item())
-        truepos_req.append(req_tp.item())
-        falsepos_req.append(req_fp.item())
-        falseneg_req.append(req_fn.item())
-        accuracy_dom.append(dom_acc.item())
-        truepos_dom.append(dom_tp.item())
-        falsepos_dom.append(dom_fp.item())
-        falseneg_dom.append(dom_fn.item())
-        accuracy_bye.append(bye_acc.item())
-        truepos_bye.append(bye_tp.item())
-        falsepos_bye.append(bye_fp.item())
-        falseneg_bye.append(bye_fn.item())
-        turns.append(n_turns)
-
-    # Global accuracy reduction across batches
-    turns = sum(turns)
-    jg_acc = sum(accuracy_jg) / turns
-    sl_acc = sum(accuracy_sl) / turns
-    if model.config.predict_actions:
-        req_acc = sum(accuracy_req) / turns
-        req_tp = sum(truepos_req)
-        req_fp = sum(falsepos_req)
-        req_fn = sum(falseneg_req)
-        req_f1 = req_tp / (req_tp + 0.5 * (req_fp + req_fn))
-        dom_acc = sum(accuracy_dom) / turns
-        dom_tp = sum(truepos_dom)
-        dom_fp = sum(falsepos_dom)
-        dom_fn = sum(falseneg_dom)
-        dom_f1 = dom_tp / (dom_tp + 0.5 * (dom_fp + dom_fn))
-        bye_tp = sum(truepos_bye)
-        bye_fp = sum(falsepos_bye)
-        bye_fn = sum(falseneg_bye)
-        bye_f1 = bye_tp / (bye_tp + 0.5 * (bye_fp + bye_fn))
-        bye_acc = sum(accuracy_bye) / turns
-    else:
-        req_acc, dom_acc, bye_acc = None, None, None
-        req_f1, dom_f1, bye_f1 = None, None, None
-
-    return jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats
-
-
-def evaluate(args, model, device, dataloader):
-    """Evaluate Model!"""
-    # Evaluate!
-    logger.info("***** Running evaluation *****")
-    logger.info("  Num Batches = %d", len(dataloader))
+    Returns:
+        out: Evaluated model statistics
+    """
+    return_eval_output = False if is_train else return_eval_output
+    if not is_train:
+        logger.info("***** Running evaluation *****")
+        logger.info("  Num Batches = %d", len(dataloader))
 
     tr_loss = 0.0
     model.eval()
+    if return_eval_output:
+        ontology = dataloader.dataset.ontology
 
-    # logits = {slot: [] for slot in model.informable_slot_ids}
     accuracy_jg = []
     accuracy_sl = []
-    accuracy_req = []
     truepos_req, falsepos_req, falseneg_req = [], [], []
     truepos_dom, falsepos_dom, falseneg_dom = [], [], []
-    truepos_bye, falsepos_bye, falseneg_bye = [], [], []
-    accuracy_dom = []
-    accuracy_bye = []
+    truepos_gen, falsepos_gen, falseneg_gen = [], [], []
     turns = []
-    epoch_iterator = tqdm(dataloader, desc="Iteration")
+    if return_eval_output:
+        evaluation_output = []
+    epoch_iterator = tqdm(dataloader, desc="Iteration") if not is_train else dataloader
     for batch in epoch_iterator:
         with torch.no_grad():
-            if 'goodbye_belief' in batch:
-                labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('belief-' + slot) in batch}
-                request_labels = {slot: batch['request_belief-' + slot].to(device) for slot in model.requestable_slot_ids
-                                  if ('request_belief-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids
-                                 if ('domain_belief-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye_belief'].to(
-                    device) if args.predict_actions else None
-            else:
-                labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('labels-' + slot) in batch}
-                request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids
-                                  if ('request-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids
-                                 if ('active-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye'].to(
-                    device) if args.predict_actions else None
-
-            input_ids = batch['input_ids'].to(device)
-            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
-            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
-
-            loss, p, p_req, p_dom, p_bye, _, _ = model(input_ids=input_ids,
-                                                       token_type_ids=token_type_ids,
-                                                       attention_mask=attention_mask,
-                                                       inform_labels=labels,
-                                                       request_labels=request_labels,
-                                                       domain_labels=domain_labels,
-                                                       goodbye_labels=goodbye_labels)
+            input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids,
+                                        model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device)
+
+            loss, p, p_req, p_dom, p_gen, _, stats = model(**input_dict)
 
         jg_acc = 0.0
+        num_inform_slots = 0.0
         req_acc = 0.0
         req_tp, req_fp, req_fn = 0.0, 0.0, 0.0
         dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0
         dom_acc = 0.0
-        for slot in model.informable_slot_ids:
+
+        if return_eval_output:
+            eval_output_batch = []
+            for dial_id, dial in enumerate(input_dict['input_ids']):
+                for turn_id, turn in enumerate(dial):
+                    if turn.sum() != 0:
+                        eval_output_batch.append({'dial_idx': dial_id,
+                                                  'utt_idx': turn_id,
+                                                  'state': dict(),
+                                                  'predictions': {'state': dict()}
+                                                  })
+
+        for slot in model.setsumbt.informable_slot_ids:
             p_ = p[slot]
-            labels = batch['labels-' + slot].to(device)
+            state_labels = batch['state_labels-' + slot].to(device)
+
+            if return_eval_output:
+                prediction = p_.argmax(-1)
+
+                for sample in eval_output_batch:
+                    dom, slt = slot.split('-', 1)
+                    lab = state_labels[sample['dial_idx']][sample['utt_idx']].item()
+                    if lab != -1:
+                        lab = ontology[dom][slt]['possible_values'][lab]
+                        pred = prediction[sample['dial_idx']][sample['utt_idx']].item()
+                        pred = ontology[dom][slt]['possible_values'][pred]
+
+                        if dom not in sample['state']:
+                            sample['state'][dom] = dict()
+                            sample['predictions']['state'][dom] = dict()
+
+                        sample['state'][dom][slt] = lab if lab != 'none' else ''
+                        sample['predictions']['state'][dom][slt] = pred if pred != 'none' else ''
 
             if args.temp_scaling > 0.0:
                 p_ = torch.log(p_ + 1e-10) / args.temp_scaling
@@ -683,28 +609,19 @@ def evaluate(args, model, device, dataloader):
                 p_ = torch.log(p_ + 1e-10) / 1.0
                 p_ = torch.softmax(p_, -1)
 
-            # logits[slot].append(p_)
-
-            if args.accuracy_samples > 0:
-                dist = Categorical(probs=p_.reshape(-1, p_.size(-1)))
-                lab_sample = dist.sample((args.accuracy_samples,))
-                lab_sample = lab_sample.transpose(0, 1)
-                acc = [lab in s for lab, s in zip(labels.reshape(-1), lab_sample)]
-                acc = torch.tensor(acc).float()
-            elif args.accuracy_topn > 0:
-                labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True)
-                labs = labs[:, :args.accuracy_topn]
-                acc = [lab in s for lab, s in zip(labels.reshape(-1), labs)]
-                acc = torch.tensor(acc).float()
-            else:
-                acc = (p_.argmax(-1) == labels).reshape(-1).float()
+            acc = (p_.argmax(-1) == state_labels).reshape(-1).float()
 
             jg_acc += acc
+            num_inform_slots += (state_labels != -1).float().reshape(-1)
+
+        if return_eval_output:
+            evaluation_output += deepcopy(eval_output_batch)
+            eval_output_batch = []
 
         if model.config.predict_actions:
-            for slot in model.requestable_slot_ids:
+            for slot in model.setsumbt.requestable_slot_ids:
                 p_req_ = p_req[slot]
-                request_labels = batch['request-' + slot].to(device)
+                request_labels = batch['request_labels-' + slot].to(device)
 
                 acc = (p_req_.round().int() == request_labels).reshape(-1).float()
                 tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float()
@@ -715,85 +632,93 @@ def evaluate(args, model, device, dataloader):
                 req_fp += fp
                 req_fn += fn
 
-            for domain in model.domain_ids:
+            domains = [domain for domain in model.setsumbt.domain_ids if f'active_domain_labels-{domain}' in batch]
+            for domain in domains:
                 p_dom_ = p_dom[domain]
-                domain_labels = batch['active-' + domain].to(device)
+                active_domain_labels = batch['active_domain_labels-' + domain].to(device)
 
-                acc = (p_dom_.round().int() == domain_labels).reshape(-1).float()
-                tp = (p_dom_.round().int() * (domain_labels == 1)).reshape(-1).float()
-                fp = (p_dom_.round().int() * (domain_labels == 0)).reshape(-1).float()
-                fn = ((1 - p_dom_.round().int()) * (domain_labels == 1)).reshape(-1).float()
+                acc = (p_dom_.round().int() == active_domain_labels).reshape(-1).float()
+                tp = (p_dom_.round().int() * (active_domain_labels == 1)).reshape(-1).float()
+                fp = (p_dom_.round().int() * (active_domain_labels == 0)).reshape(-1).float()
+                fn = ((1 - p_dom_.round().int()) * (active_domain_labels == 1)).reshape(-1).float()
                 dom_acc += acc
                 dom_tp += tp
                 dom_fp += fp
                 dom_fn += fn
 
-            goodbye_labels = batch['goodbye'].to(device)
-            bye_acc = (p_bye.argmax(-1) == goodbye_labels).reshape(-1).float().sum()
-            bye_tp = ((p_bye.argmax(-1) > 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
-            bye_fp = ((p_bye.argmax(-1) > 0) * (goodbye_labels == 0)).reshape(-1).float().sum()
-            bye_fn = ((p_bye.argmax(-1) == 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
+            general_act_labels = batch['general_act_labels'].to(device)
+            gen_tp = ((p_gen.argmax(-1) > 0) * (general_act_labels > 0)).reshape(-1).float().sum()
+            gen_fp = ((p_gen.argmax(-1) > 0) * (general_act_labels == 0)).reshape(-1).float().sum()
+            gen_fn = ((p_gen.argmax(-1) == 0) * (general_act_labels > 0)).reshape(-1).float().sum()
         else:
-            req_acc, dom_acc, bye_acc = None, None, torch.tensor(0.0)
             req_tp, req_fp, req_fn = None, None, None
             dom_tp, dom_fp, dom_fn = None, None, None
-            bye_tp, bye_fp, bye_fn = torch.tensor(
-                0.0), torch.tensor(0.0), torch.tensor(0.0)
-
-        sl_acc = sum(jg_acc / len(model.informable_slot_ids)).float()
-        jg_acc = sum((jg_acc / len(model.informable_slot_ids)).int()).float()
-        req_acc = sum(req_acc / len(model.requestable_slot_ids)).float() if req_acc is not None else torch.tensor(0.0)
-        req_tp = sum(req_tp / len(model.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0)
-        req_fp = sum(req_fp / len(model.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0)
-        req_fn = sum(req_fn / len(model.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0)
-        dom_tp = sum(dom_tp / len(model.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0)
-        dom_fp = sum(dom_fp / len(model.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0)
-        dom_fn = sum(dom_fn / len(model.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0)
-        dom_acc = sum(dom_acc / len(model.domain_ids)).float() if dom_acc is not None else torch.tensor(0.0)
-        n_turns = (labels >= 0).reshape(-1).sum().float().item()
+            gen_tp, gen_fp, gen_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
+
+        jg_acc = jg_acc[num_inform_slots > 0]
+        num_inform_slots = num_inform_slots[num_inform_slots > 0]
+        sl_acc = sum(jg_acc / num_inform_slots).float()
+        jg_acc = sum((jg_acc == num_inform_slots).int()).float()
+        if req_tp is not None and model.setsumbt.requestable_slot_ids:
+            req_tp = sum(req_tp / len(model.setsumbt.requestable_slot_ids)).float()
+            req_fp = sum(req_fp / len(model.setsumbt.requestable_slot_ids)).float()
+            req_fn = sum(req_fn / len(model.setsumbt.requestable_slot_ids)).float()
+        else:
+            req_tp, req_fp, req_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
+        dom_tp = sum(dom_tp / len(model.setsumbt.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0)
+        dom_fp = sum(dom_fp / len(model.setsumbt.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0)
+        dom_fn = sum(dom_fn / len(model.setsumbt.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0)
+        n_turns = num_inform_slots.size(0)
 
         accuracy_jg.append(jg_acc.item())
         accuracy_sl.append(sl_acc.item())
-        accuracy_req.append(req_acc.item())
         truepos_req.append(req_tp.item())
         falsepos_req.append(req_fp.item())
         falseneg_req.append(req_fn.item())
-        accuracy_dom.append(dom_acc.item())
         truepos_dom.append(dom_tp.item())
         falsepos_dom.append(dom_fp.item())
         falseneg_dom.append(dom_fn.item())
-        accuracy_bye.append(bye_acc.item())
-        truepos_bye.append(bye_tp.item())
-        falsepos_bye.append(bye_fp.item())
-        falseneg_bye.append(bye_fn.item())
+        truepos_gen.append(gen_tp.item())
+        falsepos_gen.append(gen_fp.item())
+        falseneg_gen.append(gen_fn.item())
         turns.append(n_turns)
         tr_loss += loss.item()
 
-    # for slot in logits:
-    #     logits[slot] = torch.cat(logits[slot], 0)
-
     # Global accuracy reduction across batches
     turns = sum(turns)
     jg_acc = sum(accuracy_jg) / turns
     sl_acc = sum(accuracy_sl) / turns
     if model.config.predict_actions:
-        req_acc = sum(accuracy_req) / turns
         req_tp = sum(truepos_req)
         req_fp = sum(falsepos_req)
         req_fn = sum(falseneg_req)
-        req_f1 = req_tp / (req_tp + 0.5 * (req_fp + req_fn))
-        dom_acc = sum(accuracy_dom) / turns
+        req_f1 = req_tp + 0.5 * (req_fp + req_fn)
+        req_f1 = req_tp / req_f1 if req_f1 != 0.0 else 0.0
         dom_tp = sum(truepos_dom)
         dom_fp = sum(falsepos_dom)
         dom_fn = sum(falseneg_dom)
-        dom_f1 = dom_tp / (dom_tp + 0.5 * (dom_fp + dom_fn))
-        bye_tp = sum(truepos_bye)
-        bye_fp = sum(falsepos_bye)
-        bye_fn = sum(falseneg_bye)
-        bye_f1 = bye_tp / (bye_tp + 0.5 * (bye_fp + bye_fn))
-        bye_acc = sum(accuracy_bye) / turns
+        dom_f1 = dom_tp + 0.5 * (dom_fp + dom_fn)
+        dom_f1 = dom_tp / dom_f1 if dom_f1 != 0.0 else 0.0
+        gen_tp = sum(truepos_gen)
+        gen_fp = sum(falsepos_gen)
+        gen_fn = sum(falseneg_gen)
+        gen_f1 = gen_tp + 0.5 * (gen_fp + gen_fn)
+        gen_f1 = gen_tp / gen_f1 if gen_f1 != 0.0 else 0.0
     else:
-        req_acc, dom_acc, bye_acc = None, None, None
-        req_f1, dom_f1, bye_f1 = None, None, None
-
-    return jg_acc, sl_acc, req_f1, dom_f1, bye_f1, tr_loss / len(dataloader)
+        req_f1, dom_f1, gen_f1 = None, None, None
+
+    if return_eval_output:
+        dial_idx = 0
+        for sample in evaluation_output:
+            if dial_idx == 0 and sample['dial_idx'] == 0 and sample['utt_idx'] == 0:
+                dial_idx = 0
+            elif dial_idx == 0 and sample['dial_idx'] != 0 and sample['utt_idx'] == 0:
+                dial_idx += 1
+            elif sample['utt_idx'] == 0:
+                dial_idx += 1
+            sample['dial_idx'] = dial_idx
+
+        return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), evaluation_output
+    if is_train:
+        return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), stats
+    return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader)
diff --git a/convlab/dst/setsumbt/multiwoz/Tracker.py b/convlab/dst/setsumbt/multiwoz/Tracker.py
deleted file mode 100644
index fed1a1a6..00000000
--- a/convlab/dst/setsumbt/multiwoz/Tracker.py
+++ /dev/null
@@ -1,455 +0,0 @@
-import os
-import json
-import copy
-import logging
-
-import torch
-import transformers
-from transformers import (BertModel, BertConfig, BertTokenizer,
-                          RobertaModel, RobertaConfig, RobertaTokenizer)
-from convlab.dst.setsumbt.modeling import (RobertaSetSUMBT,
-                                            BertSetSUMBT)
-
-from convlab.dst.dst import DST
-from convlab.util.multiwoz.state import default_state
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
-from convlab.dst.rule.multiwoz import normalize_value
-from convlab.util.custom_util import model_downloader
-
-USE_CUDA = torch.cuda.is_available()
-
-# Map from SetSUMBT slot names to Convlab slot names
-SLOT_MAP = {'arrive by': 'arriveBy',
-            'leave at': 'leaveAt',
-            'price range': 'pricerange',
-            'trainid': 'trainID',
-            'reference': 'Ref',
-            'taxi types': 'car type'}
-
-
-class SetSUMBTTracker(DST):
-
-    def __init__(self, model_path="", model_type="roberta",
-                 get_turn_pooled_representation=False,
-                 get_confidence_scores=False,
-                 threshold='auto',
-                 return_entropy=False,
-                 return_mutual_info=False,
-                 store_full_belief_state=False):
-        super(SetSUMBTTracker, self).__init__()
-
-        self.model_type = model_type
-        self.model_path = model_path
-        self.get_turn_pooled_representation = get_turn_pooled_representation
-        self.get_confidence_scores = get_confidence_scores
-        self.threshold = threshold
-        self.return_entropy = return_entropy
-        self.return_mutual_info = return_mutual_info
-        self.store_full_belief_state = store_full_belief_state
-        if self.store_full_belief_state:
-            self.full_belief_state = {}
-        self.info_dict = {}
-
-        # Download model if needed
-        if not os.path.exists(self.model_path):
-            # Get path /.../convlab/dst/setsumbt/multiwoz/models
-            download_path = os.path.dirname(os.path.abspath(__file__))
-            download_path = os.path.join(download_path, 'models')
-            if not os.path.exists(download_path):
-                os.mkdir(download_path)
-            model_downloader(download_path, self.model_path)
-            # Downloadable model path format http://.../setsumbt_model_name.zip
-            self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '')
-            self.model_path = os.path.join(download_path, self.model_path)
-
-        # Select model type based on the encoder
-        if model_type == "roberta":
-            self.config = RobertaConfig.from_pretrained(self.model_path)
-            self.tokenizer = RobertaTokenizer
-            self.model = RobertaSetSUMBT
-        elif model_type == "bert":
-            self.config = BertConfig.from_pretrained(self.model_path)
-            self.tokenizer = BertTokenizer
-            self.model = BertSetSUMBT
-        else:
-            logging.debug("Name Error: Not Implemented")
-
-        self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu')
-
-        # Value dict for value normalisation
-        path = os.path.dirname(
-            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
-        path = os.path.join(path, 'data/multiwoz/value_dict.json')
-        self.value_dict = json.load(open(path))
-
-        self.load_weights()
-
-    def load_weights(self):
-        # Load tokenizer and model checkpoints
-        logging.info('Loading SetSUMBT pretrained model.')
-        self.tokenizer = self.tokenizer.from_pretrained(
-            self.config.tokenizer_name)
-        logging.info(
-            f'Model tokenizer loaded from {self.config.tokenizer_name}.')
-        self.model = self.model.from_pretrained(
-            self.model_path, config=self.config)
-        logging.info(f'Model loaded from {self.model_path}.')
-
-        # Transfer model to compute device and setup eval environment
-        self.model = self.model.to(self.device)
-        self.model.eval()
-        logging.info(f'Model transferred to device: {self.device}')
-
-        logging.info('Loading model ontology')
-        f = open(os.path.join(self.model_path, 'ontology.json'), 'r')
-        self.ontology = json.load(f)
-        f.close()
-
-        db = torch.load(os.path.join(self.model_path, 'ontology.db'))
-        # Get slot and value embeddings
-        slots = {slot: db[slot] for slot in db}
-        values = {slot: db[slot][1] for slot in db}
-        del db
-
-        # Load model ontology
-        self.model.add_slot_candidates(slots)
-        for slot in values:
-            self.model.add_value_candidates(slot, values[slot], replace=True)
-
-        if self.get_confidence_scores:
-            logging.info('Model will output action and state confidence scores.')
-        if self.get_confidence_scores:
-            self.get_thresholds(self.threshold)
-            logging.info('Uncertain Querying set up and thresholds set up at:')
-            logging.info(self.thresholds)
-        if self.return_entropy:
-            logging.info('Model will output state distribution entropy.')
-        if self.return_mutual_info:
-            logging.info('Model will output state distribution mutual information.')
-        logging.info('Ontology loaded successfully.')
-
-        self.det_dic = {}
-        for domain, dic in REF_USR_DA.items():
-            for key, value in dic.items():
-                assert '-' not in key
-                self.det_dic[key.lower()] = key + '-' + domain
-                self.det_dic[value.lower()] = key + '-' + domain
-
-    def get_thresholds(self, threshold='auto'):
-        self.thresholds = {}
-        for slot, value_candidates in self.ontology.items():
-            domain, slot = slot.split('-', 1)
-            slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
-            slot = slot.strip().split()[1] if 'book ' in slot else slot
-            slot = SLOT_MAP.get(slot, slot)
-
-            # Auto thresholds are set based on the number of value candidates per slot
-            if domain not in self.thresholds:
-                self.thresholds[domain] = {}
-            if threshold == 'auto':
-                thres = 1.0 / (float(len(value_candidates)) - 2.1)
-                self.thresholds[domain][slot] = max(0.05, thres)
-            else:
-                self.thresholds[domain][slot] = max(0.05, threshold)
-
-        return self.thresholds
-
-    def init_session(self):
-        self.state = default_state()
-        self.active_domains = {}
-        self.hidden_states = None
-        self.info_dict = {}
-
-    def update(self, user_act=''):
-        prev_state = self.state
-
-        # Convert dialogs into transformer input features (token_ids, masks, etc)
-        features = self.get_features(user_act)
-        # Model forward pass
-        pred_states, active_domains, user_acts, turn_pooled_representation, belief_state, entropy_, mutual_info_ = self.predict(
-            features)
-
-        if entropy_ is not None:
-            entropy = {}
-            for slot, e in entropy_.items():
-                domain, slot = slot.split('-', 1)
-                if domain not in entropy:
-                    entropy[domain] = {}
-                if 'book' in slot:
-                    assert slot.startswith('book ')
-                    slot = slot.strip().split()[1]
-                slot = SLOT_MAP.get(slot, slot)
-                entropy[domain][slot] = e
-            del entropy_
-        else:
-            entropy = None
-
-        if mutual_info_ is not None:
-            mutual_info = {}
-            for slot, mi in mutual_info_.items():
-                domain, slot = slot.split('-', 1)
-                if domain not in mutual_info:
-                    mutual_info[domain] = {}
-                if 'book' in slot:
-                    assert slot.startswith('book ')
-                    slot = slot.strip().split()[1]
-                slot = SLOT_MAP.get(slot, slot)
-                mutual_info[domain][slot] = mi[0, 0]
-        else:
-            mutual_info = None
-
-        if belief_state is not None:
-            bs_probs = {}
-            belief_state, request_dist, domain_dist, greeting_dist = belief_state
-            for slot, p in belief_state.items():
-                domain, slot = slot.split('-', 1)
-                if domain not in bs_probs:
-                    bs_probs[domain] = {}
-                if 'book' in slot:
-                    assert slot.startswith('book ')
-                    slot = slot.strip().split()[1]
-                slot = SLOT_MAP.get(slot, slot)
-                if slot not in bs_probs[domain]:
-                    bs_probs[domain][slot] = {}
-                bs_probs[domain][slot]['inform'] = p
-
-            for slot, p in request_dist.items():
-                domain, slot = slot.split('-', 1)
-                if domain not in bs_probs:
-                    bs_probs[domain] = {}
-                slot = SLOT_MAP.get(slot, slot)
-                if slot not in bs_probs[domain]:
-                    bs_probs[domain][slot] = {}
-                bs_probs[domain][slot]['request'] = p
-
-            for domain, p in domain_dist.items():
-                if domain not in bs_probs:
-                    bs_probs[domain] = {}
-                bs_probs[domain]['none'] = {'inform': p}
-
-            if 'general' not in bs_probs:
-                bs_probs['general'] = {}
-            bs_probs['general']['none'] = greeting_dist
-
-        new_domains = [d for d, active in active_domains.items() if active]
-        new_domains = [
-            d for d in new_domains if not self.active_domains.get(d, False)]
-        self.active_domains = active_domains
-
-        for domain in new_domains:
-            user_acts.append(['Inform', domain.capitalize(), 'none', 'none'])
-
-        new_belief_state = copy.deepcopy(prev_state['belief_state'])
-        # user_acts = []
-        for state, value in pred_states.items():
-            domain, slot = state.split('-', 1)
-            value = '' if value == 'none' else value
-            value = 'dontcare' if value == 'do not care' else value
-            value = 'guesthouse' if value == 'guest house' else value
-            if slot not in ['name', 'book']:
-                if domain not in new_belief_state:
-                    if domain == 'bus':
-                        continue
-                    else:
-                        logging.debug(
-                            'Error: domain <{}> not in belief state'.format(domain))
-            slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
-            assert 'semi' in new_belief_state[domain]
-            assert 'book' in new_belief_state[domain]
-            if 'book' in slot:
-                assert slot.startswith('book ')
-                slot = slot.strip().split()[1]
-            slot = SLOT_MAP.get(slot, slot)
-
-            # Uncertainty clipping of state
-            if belief_state is not None:
-                if bs_probs[domain][slot].get('inform', 1.0) < self.thresholds[domain][slot]:
-                    value = ''
-
-            domain_dic = new_belief_state[domain]
-            value = normalize_value(self.value_dict, domain, slot, value)
-            if slot in domain_dic['semi']:
-                new_belief_state[domain]['semi'][slot] = value
-                if prev_state['belief_state'][domain]['semi'][slot] != value:
-                    user_acts.append(['Inform', domain.capitalize(
-                    ), REF_USR_DA[domain.capitalize()].get(slot, slot), value])
-            elif slot in domain_dic['book']:
-                new_belief_state[domain]['book'][slot] = value
-                if prev_state['belief_state'][domain]['book'][slot] != value:
-                    user_acts.append(['Inform', domain.capitalize(
-                    ), REF_USR_DA[domain.capitalize()].get(slot, slot), value])
-            elif slot.lower() in domain_dic['book']:
-                new_belief_state[domain]['book'][slot.lower()] = value
-                if prev_state['belief_state'][domain]['book'][slot.lower()] != value:
-                    user_acts.append(['Inform', domain.capitalize(
-                    ), REF_USR_DA[domain.capitalize()].get(slot.lower(), slot.lower()), value])
-            else:
-                logging.debug(
-                    'unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format(
-                        slot, value, domain, state)
-                )
-
-        new_state = copy.deepcopy(dict(prev_state))
-        new_state['belief_state'] = new_belief_state
-        new_state['active_domains'] = self.active_domains
-        if belief_state is not None:
-            new_state['belief_state_probs'] = bs_probs
-        if entropy is not None:
-            new_state['entropy'] = entropy
-        if mutual_info is not None:
-            new_state['mutual_information'] = mutual_info
-
-        new_state['user_action'] = user_acts
-
-        user_requests = [[a, d, s, v]
-                         for a, d, s, v in user_acts if a == 'Request']
-        for act, domain, slot, value in user_requests:
-            k = REF_SYS_DA[domain].get(slot, slot)
-            domain = domain.lower()
-            if domain not in new_state['request_state']:
-                new_state['request_state'][domain] = {}
-            if k not in new_state['request_state'][domain]:
-                new_state['request_state'][domain][k] = 0
-
-        if turn_pooled_representation is not None:
-            new_state['turn_pooled_representation'] = turn_pooled_representation
-
-        self.state = new_state
-        self.info_dict = copy.deepcopy(dict(new_state))
-
-        return self.state
-
-    # Model prediction function
-
-    def predict(self, features):
-        # Forward Pass
-        mutual_info = None
-        with torch.no_grad():
-            turn_pooled_representation = None
-            if self.get_turn_pooled_representation:
-                belief_state, request, domain, goodbye, self.hidden_states, turn_pooled_representation = self.model(input_ids=features['input_ids'],
-                                                                                                                    token_type_ids=features[
-                                                                                                                        'token_type_ids'],
-                                                                                                                    attention_mask=features[
-                                                                                                                        'attention_mask'],
-                                                                                                                    hidden_state=self.hidden_states,
-                                                                                                                    get_turn_pooled_representation=True)
-            elif self.return_mutual_info:
-                belief_state, request, domain, goodbye, self.hidden_states, mutual_info = self.model(input_ids=features['input_ids'],
-                                                                                                     token_type_ids=features[
-                                                                                                         'token_type_ids'],
-                                                                                                     attention_mask=features[
-                                                                                                         'attention_mask'],
-                                                                                                     hidden_state=self.hidden_states,
-                                                                                                     get_turn_pooled_representation=False,
-                                                                                                     calculate_inform_mutual_info=True)
-            else:
-                belief_state, request, domain, goodbye, self.hidden_states = self.model(input_ids=features['input_ids'],
-                                                                                        token_type_ids=features['token_type_ids'],
-                                                                                        attention_mask=features['attention_mask'],
-                                                                                        hidden_state=self.hidden_states,
-                                                                                        get_turn_pooled_representation=False)
-
-        # Convert belief state into dialog state
-        predictions = {slot: state[0, 0, :].argmax().item()
-                       for slot, state in belief_state.items()}
-        predictions = {slot: self.ontology[slot][idx]
-                       for slot, idx in predictions.items()}
-        predictions = {s: v for s, v in predictions.items() if v != 'none'}
-
-        if self.store_full_belief_state:
-            self.full_belief_state = belief_state
-
-        # Obtain model output probabilities
-        if self.get_confidence_scores:
-            entropy = None
-            if self.return_entropy:
-                entropy = {slot: state[0, 0, :]
-                           for slot, state in belief_state.items()}
-                entropy = {slot: self.relative_entropy(
-                    p).item() for slot, p in entropy.items()}
-
-            # Confidence score is the max probability across all not "none" values candidates.
-            belief_state = {slot: state[0, 0, 1:].max().item()
-                            for slot, state in belief_state.items()}
-            request_dist = {SLOT_MAP.get(
-                slot, slot): p[0, 0].item() for slot, p in request.items()}
-            domain_dist = {domain: p[0, 0].item()
-                           for domain, p in domain.items()}
-            greeting_dist = {'bye': goodbye[0, 0, 1].item(
-            ), 'thank': goodbye[0, 0, 2].item()}
-            belief_state = (belief_state, request_dist,
-                            domain_dist, greeting_dist)
-        else:
-            belief_state = None
-            entropy = None
-
-        # Construct request action prediction
-        request = [slot for slot, p in request.items() if p[0, 0].item() > 0.5]
-        request = [slot.split('-', 1) for slot in request]
-        request = [[domain, SLOT_MAP.get(slot, slot)]
-                   for domain, slot in request]
-        request = [['Request', domain.capitalize(), REF_USR_DA[domain.capitalize()].get(
-            slot, slot), '?'] for domain, slot in request]
-
-        # Construct active domain set
-        domain = {domain: p[0, 0].item() > 0.5 for domain, p in domain.items()}
-
-        # Construct general domain action
-        goodbye = goodbye[0, 0, :].argmax(-1).item()
-        goodbye = [[], ['bye'], ['thank']][goodbye]
-        goodbye = [[act, 'general', 'none', 'none'] for act in goodbye]
-
-        user_acts = request + goodbye
-
-        return predictions, domain, user_acts, turn_pooled_representation, belief_state, entropy, mutual_info
-
-    def relative_entropy(self, probs):
-        entropy = probs * torch.log(probs + 1e-8)
-        entropy = -entropy.sum()
-        # Maximum entropy of a K dimentional distribution is ln(K)
-        entropy /= torch.log(torch.tensor(probs.size(-1)).float())
-
-        return entropy
-
-    # Convert dialog turns into model features
-    def get_features(self, user_act):
-        # Extract system utterance from dialog history
-        context = self.state['history']
-        if context:
-            if context[-1][0] != 'sys':
-                system_act = ''
-            else:
-                system_act = context[-1][-1]
-        else:
-            system_act = ''
-
-        # Tokenize dialog
-        features = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True, max_length=self.config.max_turn_len,
-                                              padding='max_length', truncation='longest_first')
-
-        input_ids = torch.tensor(features['input_ids']).reshape(
-            1, 1, -1).to(self.device) if 'input_ids' in features else None
-        token_type_ids = torch.tensor(features['token_type_ids']).reshape(
-            1, 1, -1).to(self.device) if 'token_type_ids' in features else None
-        attention_mask = torch.tensor(features['attention_mask']).reshape(
-            1, 1, -1).to(self.device) if 'attention_mask' in features else None
-        features = {'input_ids': input_ids,
-                    'token_type_ids': token_type_ids, 'attention_mask': attention_mask}
-
-        return features
-
-
-# if __name__ == "__main__":
-#     tracker = SetSUMBTTracker(model_type='roberta', model_path='/gpfs/project/niekerk/results/nbt/convlab_setsumbt_acts')
-#                         # nlu_path='/gpfs/project/niekerk/data/bert_multiwoz_all_context.zip')
-#     tracker.init_session()
-#     state = tracker.update('hey. I need a cheap restaurant.')
-#     # tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
-#     # tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
-#     # state = tracker.update('If you have something Asian that would be great.')
-#     # tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
-#     # tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
-#     # state = tracker.update('Great. Where are they located?')
-#     # tracker.state['history'].append(['usr', 'Great. Where are they located?'])
-#     print(tracker.state)
diff --git a/convlab/dst/setsumbt/multiwoz/__init__.py b/convlab/dst/setsumbt/multiwoz/__init__.py
deleted file mode 100644
index a1f1fb89..00000000
--- a/convlab/dst/setsumbt/multiwoz/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from convlab.dst.setsumbt.multiwoz.dataset import multiwoz21, ontology
-from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair b/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair
deleted file mode 100644
index 34df41d0..00000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair
+++ /dev/null
@@ -1,83 +0,0 @@
-it's	it is
-don't	do not
-doesn't	does not
-didn't	did not
-you'd	you would
-you're	you are
-you'll	you will
-i'm	i am
-they're	they are
-that's	that is
-what's	what is
-couldn't	could not
-i've	i have
-we've	we have
-can't	cannot
-i'd	i would
-i'd	i would
-aren't	are not
-isn't	is not
-wasn't	was not
-weren't	were not
-won't	will not
-there's	there is
-there're	there are
-. .	.
-restaurants	restaurant -s
-hotels	hotel -s
-laptops	laptop -s
-cheaper	cheap -er
-dinners	dinner -s
-lunches	lunch -s
-breakfasts	breakfast -s
-expensively	expensive -ly
-moderately	moderate -ly
-cheaply	cheap -ly
-prices	price -s
-places	place -s
-venues	venue -s
-ranges	range -s
-meals	meal -s
-locations	location -s
-areas	area -s
-policies	policy -s
-children	child -s
-kids	kid -s
-kidfriendly	kid friendly
-cards	card -s
-upmarket	expensive
-inpricey	cheap
-inches	inch -s
-uses	use -s
-dimensions	dimension -s
-driverange	drive range
-includes	include -s
-computers	computer -s
-machines	machine -s
-families	family -s
-ratings	rating -s
-constraints	constraint -s
-pricerange	price range
-batteryrating	battery rating
-requirements	requirement -s
-drives	drive -s
-specifications	specification -s
-weightrange	weight range
-harddrive	hard drive
-batterylife	battery life
-businesses	business -s
-hours	hour -s
-one	1
-two	2
-three	3
-four	4
-five	5
-six	6
-seven	7
-eight	8
-nine	9
-ten	10
-eleven	11
-twelve	12
-anywhere	any where
-good bye	goodbye
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py b/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py
deleted file mode 100644
index 2c8e98f3..00000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py
+++ /dev/null
@@ -1,502 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""MultiWOZ 2.1/2.3 Dialogue Dataset"""
-
-import os
-import json
-import requests
-import zipfile
-import io
-from shutil import copy2 as copy
-
-import torch
-from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
-from tqdm import tqdm
-
-from convlab.dst.setsumbt.multiwoz.dataset.utils import (clean_text, ACTIVE_DOMAINS, get_domains, set_util_domains,
-                                                        fix_delexicalisation, extract_dialogue, PRICERANGE,
-                                                        BOOLEAN, DAYS, QUANTITIES, TIME, VALUE_MAP, map_values)
-
-
-# Set up global data_directory
-def set_datadir(dir):
-    global DATA_DIR
-    DATA_DIR = dir
-
-
-def set_active_domains(domains):
-    global ACTIVE_DOMAINS
-    ACTIVE_DOMAINS = [d for d in domains if d in ACTIVE_DOMAINS]
-    set_util_domains(ACTIVE_DOMAINS)
-
-
-# MultiWOZ2.1 download link
-URL = 'https://github.com/budzianowski/multiwoz/raw/master/data/MultiWOZ_2.1.zip'
-def set_url(url):
-    global URL
-    URL = url
-
-
-# Create Dialogue examples from the dataset
-def create_examples(max_utt_len, get_requestable_slots=False, force_processing=False):
-
-    # Load or download Raw Data
-    if not os.path.exists(DATA_DIR):
-        os.mkdir(DATA_DIR)
-    if not os.path.exists(os.path.join(DATA_DIR, 'data_raw.json')):
-        # Download data archive and extract
-        archive = _download()
-        data = _extract(archive)
-
-        writer = open(os.path.join(DATA_DIR, 'data_raw.json'), 'w')
-        json.dump(data, writer, indent = 2)
-        del archive, writer
-    else:
-        reader = open(os.path.join(DATA_DIR, 'data_raw.json'), 'r')
-        data = json.load(reader)
-
-    if force_processing or not os.path.exists(os.path.join(DATA_DIR, 'data_train.json')):
-        # Preprocess all dialogues
-        data_processed = _process(data['data'], data['system_acts'])
-        # Format data and split train, test and devlopment sets
-        train, dev, test = _split_data(data_processed, data['testListFile'],
-                                                            data['valListFile'], max_utt_len)
-
-        # Write data
-        writer = open(os.path.join(DATA_DIR, 'data_train.json'), 'w')
-        json.dump(train, writer, indent = 2)
-        writer = open(os.path.join(DATA_DIR, 'data_test.json'), 'w')
-        json.dump(test, writer, indent = 2)
-        writer = open(os.path.join(DATA_DIR, 'data_dev.json'), 'w')
-        json.dump(dev, writer, indent = 2)
-        writer.flush()
-        writer.close()
-        del writer
-
-        # Extract slots and slot value candidates from the dataset
-        for set_type in ['train', 'dev', 'test']:
-            _get_ontology(set_type, get_requestable_slots)
-        
-        script_path = os.path.abspath(__file__).replace('/multiwoz21.py', '')
-        file_name = 'mwoz21_ont_request.json' if get_requestable_slots else 'mwoz21_ont.json'
-        copy(os.path.join(script_path, file_name), os.path.join(DATA_DIR, 'ontology_test.json'))
-        copy(os.path.join(script_path, 'mwoz21_slot_descriptions.json'), os.path.join(DATA_DIR, 'slot_descriptions.json'))
-
-
-# Extract slots and slot value candidates from the dataset
-def _get_ontology(set_type, get_requestable_slots=False):
-
-    datasets = ['train']
-    if set_type in ['test', 'dev']:
-        datasets.append('dev')
-        datasets.append('test')
-
-    # Load examples
-    data = []
-    for dataset in datasets:
-        reader = open(os.path.join(DATA_DIR, 'data_%s.json' % dataset), 'r')
-        data += json.load(reader)
-
-    ontology = dict()
-    for dial in data:
-        for turn in dial['dialogue']:
-            for state in turn['dialogue_state']:
-                slot, value = state
-                value = map_values(value)
-                if slot not in ontology:
-                    ontology[slot] = [value]
-                else:
-                    ontology[slot].append(value)
-
-    requestable_slots = []
-    if get_requestable_slots:
-        for dial in data:
-            for turn in dial['dialogue']:
-                for act, dom, slot, val in turn['user_acts']:
-                    if act == 'request':
-                        requestable_slots.append(f'{dom}-{slot}')
-        requestable_slots = list(set(requestable_slots))
-
-    for slot in ontology:
-        if 'price' in slot:
-            ontology[slot] = PRICERANGE
-        if 'parking' in slot or 'internet' in slot:
-            ontology[slot] = BOOLEAN
-        if 'day' in slot:
-            ontology[slot] = DAYS
-        if 'people' in slot or 'duration' in slot or 'stay' in slot:
-            ontology[slot] = QUANTITIES
-        if 'time' in slot or 'leave' in slot or 'arrive' in slot:
-            ontology[slot] = TIME
-        if 'stars' in slot:
-            ontology[slot] += [str(i) for i in range(5)]
-
-    # Sort slot values and add none and dontcare values
-    for slot in ontology:
-        ontology[slot] = list(set(ontology[slot]))
-        ontology[slot] = ['none', 'do not care'] + sorted([s for s in ontology[slot] if s not in ['none', 'do not care']])
-    for slot in requestable_slots:
-        if slot in ontology:
-            ontology[slot].append('request')
-        else:
-            ontology[slot] = ['request']
-
-    writer = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'w')
-    json.dump(ontology, writer, indent=2)
-    writer.close()
-
-
-# Convert dialogue examples to model input features and labels
-def convert_examples_to_features(set_type, tokenizer, max_turns=12, max_seq_len=64):
-
-    features = dict()
-
-    # Load examples
-    reader = open(os.path.join(DATA_DIR, 'data_%s.json' % set_type), 'r')
-    data = json.load(reader)
-
-    # Get encoder input for system, user utterance pairs
-    input_feats = []
-    for dial in data:
-        dial_feats = []
-        for turn in dial['dialogue']:
-            if len(turn['system_transcript']) == 0:
-                usr = turn['transcript']
-                dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens = True,
-                                                        max_length = max_seq_len, padding='max_length',
-                                                        truncation = 'longest_first'))
-            else:
-                usr = turn['transcript']
-                sys = turn['system_transcript']
-                dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens = True,
-                                                        max_length = max_seq_len, padding='max_length',
-                                                        truncation = 'longest_first'))
-            if len(dial_feats) >= max_turns:
-                break
-        input_feats.append(dial_feats)
-    del dial_feats
-
-    # Perform turn level padding
-    input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats]
-    if 'token_type_ids' in input_feats[0][0]:
-        token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats]
-    else:
-        token_type_ids = None
-    if 'attention_mask' in input_feats[0][0]:
-        attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats]
-    else:
-        attention_mask = None
-    del input_feats
-
-    # Create torch data tensors
-    features['input_ids'] = torch.tensor(input_ids)
-    features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
-    features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
-    del input_ids, token_type_ids, attention_mask
-
-    # Load ontology
-    reader = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r')
-    ontology = json.load(reader)
-    reader.close()
-
-    informable_slots = [slot for slot, values in ontology.items() if values != ['request']]
-    requestable_slots = [slot for slot, values in ontology.items() if 'request' in values]
-    for slot in requestable_slots:
-        ontology[slot].remove('request')
-    
-    domains = list(set(informable_slots + requestable_slots))
-    domains = list(set([slot.split('-', 1)[0] for slot in domains]))
-
-    # Create slot labels
-    for slot in informable_slots:
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial['dialogue']:
-                slots_active = [s for s, v in turn['dialogue_state']]
-                if slot in slots_active:
-                    value = [v for s, v in turn['dialogue_state'] if s == slot][0]
-                else:
-                    value = 'none'
-                if value in ontology[slot]:
-                    value = ontology[slot].index(value)
-                else:
-                    value = map_values(value)
-                    if value in ontology[slot]:
-                        value = ontology[slot].index(value)
-                    else:
-                        value = -1 # If value is not in ontology then we do not penalise the model
-                labs.append(value)
-                if len(labs) >= max_turns:
-                    break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-
-        labels = torch.tensor(labels)
-        features['labels-' + slot] = labels
-
-    for slot in requestable_slots:
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial['dialogue']:
-                slots_active = [[d, s] for i, d, s, v in turn['user_acts']]
-                if slot.split('-', 1) in slots_active:
-                    act_ = [i for i, d, s, v in turn['user_acts'] if f"{d}-{s}" == slot][0]
-                    if act_ == 'request':
-                        labs.append(1)
-                    else:
-                        labs.append(0)
-                else:
-                    labs.append(0)
-                if len(labs) >= max_turns:
-                    break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-        
-        labels = torch.tensor(labels)
-        features['request-' + slot] = labels
-    
-    # Greeting act labels (0-no greeting, 1-goodbye, 2-thank you)
-    labels = []
-    for dial in data:
-        labs = []
-        for turn in dial['dialogue']:
-            greeting_active = [i for i, d, s, v in turn['user_acts'] if i in ['bye', 'thank']]
-            if greeting_active:
-                if 'bye' in greeting_active:
-                    labs.append(1)
-                else :
-                    labs.append(2)
-            else:
-                labs.append(0)
-            if len(labs) >= max_turns:
-                break
-        labs = labs + [-1] * (max_turns - len(labs))
-        labels.append(labs)
-    
-    labels = torch.tensor(labels)
-    features['goodbye'] = labels
-
-    for domain in domains:
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial['dialogue']:
-                if domain == turn['domain']:
-                    labs.append(1)
-                else:
-                    labs.append(0)
-                if len(labs) >= max_turns:
-                        break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-        
-        labels = torch.tensor(labels)
-        features['active-' + domain] = labels
-
-    del labels
-
-    return features
-
-
-# MultiWOZ2.1 Dataset object
-class MultiWoz21(Dataset):
-
-    def __init__(self, set_type, tokenizer, max_turns=12, max_seq_len=64):
-        self.features = convert_examples_to_features(set_type, tokenizer, max_turns, max_seq_len)
-
-    def __getitem__(self, index):
-        return {label: self.features[label][index] for label in self.features
-                if self.features[label] is not None}
-
-    def __len__(self):
-        return self.features['input_ids'].size(0)
-
-    def resample(self, size=None):
-        n_dialogues = self.__len__()
-        if not size:
-            size = n_dialogues
-
-        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
-        self.features = {label: self.features[label][dialogues] for label in self.features
-                        if self.features[label] is not None}
-        
-        return self
-
-    def to(self, device):
-        self.device = device
-        self.features = {label: self.features[label].to(device) for label in self.features
-                         if self.features[label] is not None}
-
-
-# MultiWOZ2.1 Dataset object
-class EnsembleMultiWoz21(Dataset):
-    def __init__(self, data):
-        self.features = data
-
-    def __getitem__(self, index):
-        return {label: self.features[label][index] for label in self.features
-                if self.features[label] is not None}
-
-    def __len__(self):
-        return self.features['input_ids'].size(0)
-
-    def resample(self, size=None):
-        n_dialogues = self.__len__()
-        if not size:
-            size = n_dialogues
-
-        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
-        self.features = {label: self.features[label][dialogues] for label in self.features
-                        if self.features[label] is not None}
-
-    def to(self, device):
-        self.device = device
-        self.features = {label: self.features[label].to(device) for label in self.features
-                         if self.features[label] is not None}
-
-
-# Module to create torch dataloaders
-def get_dataloader(set_type, batch_size, tokenizer, max_turns=12, max_seq_len=64, device=None, resampled_size=None):
-    data = MultiWoz21(set_type, tokenizer, max_turns, max_seq_len)
-    data.to('cpu')
-
-    if resampled_size:
-        data.resample(resampled_size)
-
-    if set_type in ['test', 'dev']:
-        sampler = SequentialSampler(data)
-    else:
-        sampler = RandomSampler(data)
-    loader = DataLoader(data, sampler=sampler, batch_size=batch_size)
-
-    return loader
-
-
-def _download(chunk_size=1048576):
-    """Download data archive.
-
-    Parameters:
-        chunk_size (int): Download chunk size. (default=1048576)
-    Returns:
-        archive: ZipFile archive object.
-    """
-    # Download the archive byte string
-    req = requests.get(URL, stream=True)
-    archive = b''
-    for n_chunks, chunk in tqdm(enumerate(req.iter_content(chunk_size=chunk_size)), desc='Download Chunk'):
-        if chunk:
-            archive += chunk
-
-    # Convert the bytestring into a zipfile object
-    archive = io.BytesIO(archive)
-    archive = zipfile.ZipFile(archive)
-
-    return archive
-
-
-def _extract(archive):
-    """Extract the json dictionaries from the archive.
-
-    Parameters:
-        archive: ZipFile archive object.
-    Returns:
-        data: Data dictionary.
-    """
-    files = [file for file in archive.filelist if ('.json' in file.filename or '.txt' in file.filename)
-            and 'MACOSX' not in file.filename]
-    objects = []
-    for file in tqdm(files, desc='File'):
-        data = archive.open(file).read()
-        # Get data objects from the files
-        try:
-            data = json.loads(data)
-        except json.decoder.JSONDecodeError:
-            data = data.decode().split('\n')
-        objects.append(data)
-
-    files = [file.filename.split('/')[-1].split('.')[0] for file in files]
-
-    data = {file: data for file, data in zip(files, objects)}
-    return data
-
-
-# Process files
-def _process(dialogue_data, acts_data):
-    print('Processing Dialogues')
-    out = {}
-    for dial_name in tqdm(dialogue_data):
-        dialogue = dialogue_data[dial_name]
-
-        prev_dom = ''
-        for turn_id, turn in enumerate(dialogue['log']):
-            dialogue['log'][turn_id]['text'] = clean_text(turn['text'])
-            if len(turn['metadata']) != 0:
-                crnt_dom = get_domains(dialogue['log'], turn_id, prev_dom)
-                prev_dom = crnt_dom
-                dialogue['log'][turn_id - 1]['domain'] = crnt_dom
-
-            dialogue['log'][turn_id] = fix_delexicalisation(turn)
-
-        out[dial_name] = dialogue
-
-    return out
-
-
-# Split data (train, dev, test)
-def _split_data(dial_data, test, dev, max_utt_len):
-    train_dials, test_dials, dev_dials = [], [], []
-    print('Formatting and Splitting Data')
-    for name in tqdm(dial_data):
-        dialogue = dial_data[name]
-        domains = []
-
-        dial = extract_dialogue(dialogue, max_utt_len)
-        if dial:
-            dialogue = dict()
-            dialogue['dialogue_idx'] = name
-            dialogue['domains'] = []
-            dialogue['dialogue'] = []
-
-            for turn_id, turn in enumerate(dial):
-                turn_dialog = dict()
-                turn_dialog['system_transcript'] = dial[turn_id - 1]['sys'] if turn_id > 0 else ''
-                turn_dialog['turn_idx'] = turn_id
-                turn_dialog['dialogue_state'] = turn['ds']
-                turn_dialog['transcript'] = turn['usr']
-                # turn_dialog['system_acts'] = dial[turn_id - 1]['sys_a'] if turn_id > 0 else []
-                turn_dialog['user_acts'] = turn['usr_a']
-                turn_dialog['domain'] = turn['domain']
-                dialogue['domains'].append(turn['domain'])
-                dialogue['dialogue'].append(turn_dialog)
-
-            dialogue['domains'] = [d for d in list(set(dialogue['domains'])) if d != '']
-            if True in [dom not in ACTIVE_DOMAINS for dom in dialogue['domains']]:
-                dialogue['domains'] = []
-            dialogue['domains'] = [dom for dom in dialogue['domains'] if dom in ACTIVE_DOMAINS]
-
-            if dialogue['domains']:
-                if name in test:
-                    test_dials.append(dialogue)
-                elif name in dev:
-                    dev_dials.append(dialogue)
-                else:
-                    train_dials.append(dialogue)
-
-    print('Number of Dialogues:\nTrain: %i\nDev: %i\nTest: %i' % (len(train_dials), len(dev_dials), len(test_dials)))
-
-    return train_dials, dev_dials, test_dials
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json
deleted file mode 100644
index b703793d..00000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json
+++ /dev/null
@@ -1,2990 +0,0 @@
-{
-  "hotel-price range": [
-    "none",
-    "do not care",
-    "cheap",
-    "expensive",
-    "moderate"
-  ],
-  "hotel-type": [
-    "none",
-    "do not care",
-    "bed and breakfast",
-    "guest house",
-    "hotel"
-  ],
-  "hotel-parking": [
-    "none",
-    "do not care",
-    "no",
-    "yes"
-  ],
-  "hotel-book day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "hotel-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "hotel-book stay": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "train-destination": [
-    "none",
-    "do not care",
-    "bishops stortford",
-    "kings lynn",
-    "london liverpool street",
-    "centre",
-    "bishop stortford",
-    "liverpool",
-    "leicester",
-    "broxbourne",
-    "gourmet burger kitchen",
-    "copper kettle",
-    "bournemouth",
-    "stevenage",
-    "liverpool street",
-    "norwich",
-    "huntingdon marriott hotel",
-    "city centre north",
-    "taj tandoori",
-    "the copper kettle",
-    "peterborough",
-    "ely",
-    "lecester",
-    "london",
-    "willi",
-    "stansted airport",
-    "huntington marriott",
-    "cambridge",
-    "gonv",
-    "glastonbury",
-    "hol",
-    "north",
-    "birmingham new street",
-    "norway",
-    "petersborough",
-    "london kings cross",
-    "curry prince",
-    "bishops storford"
-  ],
-  "train-arrive by": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "train-departure": [
-    "none",
-    "do not care",
-    "bishops stortford",
-    "kings lynn",
-    "brookshite",
-    "london liverpool street",
-    "cam",
-    "liverpool",
-    "bro",
-    "leicester",
-    "broxbourne",
-    "norwhich",
-    "saint johns",
-    "stevenage",
-    "stansted",
-    "london liverpool",
-    "cambrid",
-    "city hall",
-    "rosas bed and breakfast",
-    "alpha-milton",
-    "wandlebury country park",
-    "norwich",
-    "liecester",
-    "stratford",
-    "peterborough",
-    "duxford",
-    "ely",
-    "london",
-    "stansted airport",
-    "lon",
-    "cambridge",
-    "panahar",
-    "cineworld",
-    "leicaster",
-    "birmingham",
-    "cafe uno",
-    "camboats",
-    "huntingdon",
-    "birmingham new street",
-    "arbu",
-    "alpha milton",
-    "east london",
-    "london kings cross",
-    "hamilton lodge",
-    "aylesbray lodge guest",
-    "el shaddai"
-  ],
-  "train-day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "train-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "hotel-stars": [
-    "none",
-    "do not care",
-    "0",
-    "1",
-    "2",
-    "3",
-    "4",
-    "5"
-  ],
-  "hotel-internet": [
-    "none",
-    "do not care",
-    "no",
-    "yes"
-  ],
-  "hotel-name": [
-    "a and b guest house",
-    "city roomz",
-    "carolina bed and breakfast",
-    "limehouse",
-    "anatolia",
-    "hamilton lodge",
-    "the lensfield hotel",
-    "rosa's bed and breakfast",
-    "gall",
-    "aylesbray lodge",
-    "kirkwood",
-    "cambridge belfry",
-    "warkworth house",
-    "gonville",
-    "belfy hotel",
-    "nus",
-    "alexander",
-    "super 5",
-    "aylesbray lodge guest house",
-    "the gonvile hotel",
-    "allenbell",
-    "nothamilton lodge",
-    "ashley hotel",
-    "autumn house",
-    "hobsons house",
-    "hotel",
-    "ashely hotel",
-    "caridge belfrey",
-    "el shaddia guest house",
-    "avalon",
-    "cote",
-    "city centre north bed and breakfast",
-    "the cambridge belfry",
-    "home from home",
-    "wandlebury coutn",
-    "wankworth house",
-    "city stop rest",
-    "the worth house",
-    "cityroomz",
-    "huntingdon marriottt hotel",
-    "none",
-    "lensfield",
-    "rosas bed and breakfast",
-    "leverton house",
-    "gonville hotel",
-    "holiday inn cambridge",
-    "do not care",
-    "archway house",
-    "lan hon",
-    "levert",
-    "acorn guest house",
-    "cambridge",
-    "the ashley hotel",
-    "el shaddai",
-    "sleeperz",
-    "alpha milton guest house",
-    "doubletree by hilton cambridge",
-    "tandoori palace",
-    "express by",
-    "express by holiday inn cambridge",
-    "north bed and breakfast",
-    "holiday inn",
-    "arbury lodge guest house",
-    "alexander bed and breakfast",
-    "huntingdon marriott hotel",
-    "royal spice",
-    "sou",
-    "finches bed and breakfast",
-    "the alpha milton",
-    "bridge guest house",
-    "the acorn guest house",
-    "kirkwood house",
-    "eraina",
-    "la margherit",
-    "lensfield hotel",
-    "marriott hotel",
-    "nusha",
-    "city centre bed and breakfast",
-    "the allenbell",
-    "university arms hotel",
-    "clare",
-    "cherr",
-    "wartworth",
-    "acorn place",
-    "lovell lodge",
-    "whale"
-  ],
-  "train-leave at": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "restaurant-price range": [
-    "none",
-    "do not care",
-    "cheap",
-    "expensive",
-    "moderate"
-  ],
-  "restaurant-food": [
-    "british food",
-    "steakhouse",
-    "turkish",
-    "sushi",
-    "north american",
-    "scottish",
-    "french",
-    "austrian",
-    "korean",
-    "eastern european",
-    "swedish",
-    "gastro pub",
-    "modern eclectic",
-    "afternoon tea",
-    "welsh",
-    "christmas",
-    "tuscan",
-    "gastropub",
-    "sri lankan",
-    "molecular gastronomy",
-    "traditional american",
-    "italian",
-    "pizza",
-    "thai",
-    "south african",
-    "creative",
-    "english",
-    "asian",
-    "lebanese",
-    "hungarian",
-    "halal",
-    "portugese",
-    "modern english",
-    "african",
-    "light bites",
-    "malaysian",
-    "venetian",
-    "traditional",
-    "chinese",
-    "vegetarian",
-    "persian",
-    "thai and chinese",
-    "scandinavian",
-    "catalan",
-    "polynesian",
-    "crossover",
-    "canapes",
-    "cantonese",
-    "north african",
-    "seafood",
-    "brazilian",
-    "south indian",
-    "australasian",
-    "belgian",
-    "barbeque",
-    "the americas",
-    "indonesian",
-    "singaporean",
-    "irish",
-    "middle eastern",
-    "dojo noodle bar",
-    "caribbean",
-    "vietnamese",
-    "modern european",
-    "russian",
-    "none",
-    "german",
-    "world",
-    "japanese",
-    "moroccan",
-    "modern global",
-    "do not care",
-    "indian",
-    "british",
-    "american",
-    "danish",
-    "panasian",
-    "swiss",
-    "basque",
-    "north indian",
-    "modern american",
-    "australian",
-    "european",
-    "corsica",
-    "greek",
-    "northern european",
-    "mediterranean",
-    "portuguese",
-    "romanian",
-    "jamaican",
-    "polish",
-    "international",
-    "unusual",
-    "latin american",
-    "asian oriental",
-    "mexican",
-    "bistro",
-    "cuban",
-    "fusion",
-    "new zealand",
-    "spanish",
-    "eritrean",
-    "afghan",
-    "kosher"
-  ],
-  "attraction-name": [
-    "downing college",
-    "fitzwilliam",
-    "clare college",
-    "ruskin gallery",
-    "sidney sussex college",
-    "great saint mary's church",
-    "cherry hinton water play park",
-    "wandlebury country park",
-    "cafe uno",
-    "place",
-    "broughton",
-    "cineworld cinema",
-    "jesus college",
-    "vue cinema",
-    "history of science museum",
-    "mumford theatre",
-    "whale of time",
-    "fitzbillies",
-    "christs church",
-    "churchill college",
-    "museum of classical archaeology",
-    "gonville and caius college",
-    "pizza",
-    "kirkwood",
-    "saint catharines college",
-    "kings college",
-    "parkside",
-    "by",
-    "st catharines college",
-    "saint john's college",
-    "cherry hinton water park",
-    "st christs college",
-    "christ's college",
-    "bangkok city",
-    "scudamores punti co",
-    "free",
-    "great saint marys church",
-    "milton country park",
-    "the fez club",
-    "soultree",
-    "autu",
-    "whipple museum of the history of science",
-    "aylesbray lodge guest house",
-    "broughton house gallery",
-    "peoples portraits exhibition",
-    "primavera",
-    "kettles yard",
-    "all saint's church",
-    "cinema cinema",
-    "regency gallery",
-    "corpus christi",
-    "corn cambridge exchange",
-    "da vinci pizzeria",
-    "school",
-    "hobsons house",
-    "cambride and country folk museum",
-    "north",
-    "da v",
-    "cambridge corn exchange",
-    "soul tree nightclub",
-    "cambridge arts theater",
-    "saint catharine's college",
-    "byard art",
-    "cambridge punter",
-    "cambridge university botanic gardens",
-    "castle galleries",
-    "museum of archaelogy and anthropogy",
-    "no specific location",
-    "cherry hinton hall",
-    "gallery at 12 a high street",
-    "parkside pools",
-    "queen's college",
-    "little saint mary's church",
-    "gallery",
-    "home from home",
-    "tenpin",
-    "the wandlebury",
-    "county folk museum",
-    "swimming pool",
-    "christs college",
-    "cafe jello museum",
-    "scott polar",
-    "christ college",
-    "cambridge museum of technology",
-    "abbey pool and astroturf pitch",
-    "king hedges learner pool",
-    "the cambridge arts theatre",
-    "the castle galleries",
-    "cambridge and country folk museum",
-    "kohinoor",
-    "scudamores punting co",
-    "sidney sussex",
-    "the man on the moon",
-    "little saint marys church",
-    "queens",
-    "the place",
-    "old school",
-    "churchill",
-    "churchills college",
-    "hughes hall",
-    "churchhill college",
-    "riverboat georgina",
-    "none",
-    "belf",
-    "cambridge temporary art",
-    "abc theatre",
-    "cambridge contemporary art museum",
-    "man on the moon",
-    "the junction",
-    "cherry hinton water play",
-    "adc theatre",
-    "gonville hotel",
-    "magdalene college",
-    "peoples portraits exhibition at girton college",
-    "boat",
-    "centre",
-    "sheep's green and lammas land park fen causeway",
-    "do not care",
-    "the mumford theatre",
-    "archway house",
-    "queens' college",
-    "williams art and antiques",
-    "funky fun house",
-    "cherry hinton village centre",
-    "camboats",
-    "cambridge",
-    "old schools",
-    "kettle's yard",
-    "whale of a time",
-    "the churchill college",
-    "cafe jello gallery",
-    "aut",
-    "salsa",
-    "city",
-    "clare hall",
-    "boating",
-    "pembroke college",
-    "kings hedges learner pool",
-    "caffe uno",
-    "lammas land park",
-    "museum",
-    "the fitzwilliam museum",
-    "the cherry hinton village centre",
-    "the cambridge corn exchange",
-    "fitzwilliam museum",
-    "museum of archaelogy and anthropology",
-    "fez club",
-    "the cambridge punter",
-    "saint johns college",
-    "emmanuel college",
-    "cambridge belf",
-    "scudamore",
-    "lynne strover gallery",
-    "king's college",
-    "whippple museum",
-    "trinity college",
-    "college in the north",
-    "sheep's green",
-    "kambar",
-    "museum of archaelogy",
-    "adc",
-    "garde",
-    "club salsa",
-    "people's portraits exhibition at girton college",
-    "botanic gardens",
-    "carol",
-    "college",
-    "gallery at twelve a high street",
-    "abbey pool and astroturf",
-    "cambridge book and print gallery",
-    "jesus green outdoor pool",
-    "scott polar museum",
-    "saint barnabas press gallery",
-    "cambridge artworks",
-    "older churches",
-    "cambridge contemporary art",
-    "cherry hinton hall and grounds",
-    "univ",
-    "jesus green",
-    "ballare",
-    "abbey pool",
-    "cambridge botanic gardens",
-    "nusha",
-    "worth house",
-    "thanh",
-    "university arms hotel",
-    "cambridge arts theatre",
-    "cafe jello",
-    "cambridge and county folk museum",
-    "the cambridge artworks",
-    "all saints church",
-    "holy trinity church",
-    "contemporary art museum",
-    "architectural churches",
-    "queens college",
-    "trinity street college"
-  ],
-  "restaurant-name": [
-    "none",
-    "do not care",
-    "hotel du vin and bistro",
-    "ask",
-    "gourmet formal kitchen",
-    "the meze bar",
-    "lan hong house",
-    "cow pizza",
-    "one seven",
-    "prezzo",
-    "maharajah tandoori restaurant",
-    "alex",
-    "shanghai",
-    "golden wok",
-    "restaurant",
-    "fitzbillies",
-    "nil",
-    "copper kettle",
-    "meghna",
-    "hk fusion",
-    "bangkok city",
-    "hobsons house",
-    "tang chinese",
-    "anatolia",
-    "ugly duckling",
-    "anatolia and efes restaurant",
-    "sitar tandoori",
-    "city stop",
-    "ashley",
-    "pizza express fen ditton",
-    "molecular gastronomy",
-    "autumn house",
-    "el shaddia guesthouse",
-    "the grafton hotel",
-    "limehouse",
-    "gardenia",
-    "not metioned",
-    "hakka",
-    "michaelhouse cafe",
-    "pipasha",
-    "meze bar",
-    "archway",
-    "molecular gastonomy",
-    "yipee noodle bar",
-    "the peking",
-    "curry prince",
-    "midsummer house restaurant",
-    "pizza hut cherry hinton",
-    "the lucky star",
-    "stazione restaurant and coffee bar",
-    "shanghi family restaurant",
-    "good luck",
-    "j restaurant",
-    "bedouin",
-    "cott",
-    "little seoul",
-    "south",
-    "thanh binh",
-    "el",
-    "efes restaurant",
-    "kohinoor",
-    "clowns",
-    "india",
-    "the slug and lettuce",
-    "shiraz",
-    "barbakan",
-    "zizzi cambridge",
-    "restaurant one seven",
-    "slug and lettuce",
-    "travellers rest",
-    "binh",
-    "worth house",
-    "broughton house gallery",
-    "chiquito",
-    "the river bar steakhouse and grill",
-    "ros",
-    "golden house",
-    "india west",
-    "cam",
-    "panahar",
-    "restaurant 22",
-    "adden",
-    "indian",
-    "hu",
-    "jinling noodle bar",
-    "darrys cookhouse and wine shop",
-    "hobson house",
-    "cambridge be",
-    "el shaddai",
-    "ac",
-    "nandos",
-    "cambridge lodge",
-    "the cow pizza kitchen and bar",
-    "charlie",
-    "rajmahal",
-    "kymmoy",
-    "cambri",
-    "backstreet bistro",
-    "galleria",
-    "restaurant 2 two",
-    "chiquito restaurant bar",
-    "royal standard",
-    "lucky star",
-    "curry king",
-    "grafton hotel restaurant",
-    "mahal of cambridge",
-    "the bedouin",
-    "nus",
-    "the kohinoor",
-    "pizza hut fenditton",
-    "camboats",
-    "the gardenia",
-    "de luca cucina and bar",
-    "nusha",
-    "european",
-    "taj tandoori",
-    "tandoori palace",
-    "golden curry",
-    "efes",
-    "loch fyne",
-    "the maharajah tandoor",
-    "lovel",
-    "restaurant 17",
-    "clowns cafe",
-    "cambridge punter",
-    "bloomsbury restaurant",
-    "la mimosa",
-    "the cambridge chop house",
-    "funky",
-    "cotto",
-    "oak bistro",
-    "restaurant two two",
-    "pipasha restaurant",
-    "river bar steakhouse and grill",
-    "royal spice",
-    "the copper kettle",
-    "graffiti",
-    "nandos city centre",
-    "saffron brasserie",
-    "cambridge chop house",
-    "sitar",
-    "kitchen and bar",
-    "the good luck chinese food takeaway",
-    "clu",
-    "la tasca",
-    "cafe uno",
-    "cote",
-    "the varsity restaurant",
-    "bri",
-    "eraina",
-    "bridge",
-    "fin",
-    "cambridge lodge restaurant",
-    "grafton",
-    "hotpot",
-    "sala thong",
-    "margherita",
-    "wise buddha",
-    "the missing sock",
-    "seasame restaurant and bar",
-    "the dojo noodle bar",
-    "restaurant alimentum",
-    "gastropub",
-    "saigon city",
-    "la margherita",
-    "pizza hut",
-    "curry garden",
-    "ashley hotel",
-    "eraina and michaelhouse cafe",
-    "the golden curry",
-    "curry queen",
-    "cow pizza kitchen and bar",
-    "the peking restaurant:",
-    "hamilton lodge",
-    "alimentum",
-    "yippee noodle bar",
-    "2 two and cote",
-    "shanghai family restaurant",
-    "grafton hotel",
-    "yes",
-    "ali baba",
-    "dif",
-    "fitzbillies restaurant",
-    "peking restaurant",
-    "lev",
-    "nirala",
-    "the alex",
-    "tandoori",
-    "city stop restaurant",
-    "rice house",
-    "cityr",
-    "yu garden",
-    "meze bar restaurant",
-    "the",
-    "don pasquale pizzeria",
-    "rice boat",
-    "the hotpot",
-    "old school",
-    "the oak bistro",
-    "sesame restaurant and bar",
-    "pizza express",
-    "the gandhi",
-    "pizza hut fen ditton",
-    "charlie chan",
-    "da vinci pizzeria",
-    "dojo noodle bar",
-    "gourmet burger kitchen",
-    "the golden house",
-    "india house",
-    "hobso",
-    "missing sock",
-    "pizza hut city centre",
-    "parkside pools",
-    "riverside brasserie",
-    "caffe uno",
-    "primavera",
-    "the nirala",
-    "wagamama",
-    "au",
-    "ian hong house",
-    "frankie and bennys",
-    "4 kings parade city centre",
-    "shiraz restaurant",
-    "scudamores punt",
-    "mahal",
-    "saint johns chop house",
-    "de luca cucina and bar riverside brasserie",
-    "cocum",
-    "la raza"
-  ],
-  "attraction-type": [
-    "none",
-    "do not care",
-    "architecture",
-    "boat",
-    "boating",
-    "camboats",
-    "church",
-    "churchills college",
-    "cinema",
-    "college",
-    "concert",
-    "concerthall",
-    "entertainment",
-    "gallery",
-    "gastropub",
-    "hiking",
-    "hotel",
-    "multiple sports",
-    "museum",
-    "museum kettles yard",
-    "night club",
-    "outdoor",
-    "park",
-    "pool",
-    "special",
-    "sports",
-    "swimming pool",
-    "theater",
-    "theatre",
-    "concert hall",
-    "local site",
-    "nightclub",
-    "hotspot"
-  ],
-  "taxi-leave at": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "taxi-destination": [
-    "none",
-    "do not care",
-    "a and b guest house",
-    "abbey pool and astroturf pitch",
-    "acorn guest house",
-    "adc theatre",
-    "addenbrookes hospital",
-    "alexander bed and breakfast",
-    "ali baba",
-    "all saints church",
-    "allenbell",
-    "alpha milton guest house",
-    "anatolia",
-    "arbury lodge guesthouse",
-    "archway house",
-    "ashley hotel",
-    "ask",
-    "attraction",
-    "autumn house",
-    "avalon",
-    "aylesbray lodge guest house",
-    "backstreet bistro",
-    "ballare",
-    "bangkok city",
-    "bedouin",
-    "birmingham new street train station",
-    "bishops stortford train station",
-    "bloomsbury restaurant",
-    "bridge guest house",
-    "broughton house gallery",
-    "broxbourne train station",
-    "byard art",
-    "cafe jello gallery",
-    "cafe uno",
-    "camboats",
-    "cambridge",
-    "cambridge and county folk museum",
-    "cambridge arts theatre",
-    "cambridge artworks",
-    "cambridge belfry",
-    "cambridge book and print gallery",
-    "cambridge chop house",
-    "cambridge contemporary art",
-    "cambridge county fair next to the city tourist museum",
-    "cambridge lodge restaurant",
-    "cambridge museum of technology",
-    "cambridge punter",
-    "cambridge road church of christ",
-    "cambridge train station",
-    "cambridge university botanic gardens",
-    "carolina bed and breakfast",
-    "castle galleries",
-    "charlie chan",
-    "cherry hinton hall and grounds",
-    "cherry hinton village centre",
-    "cherry hinton water park",
-    "cherry hinton water play",
-    "chiquito restaurant bar",
-    "christ college",
-    "churchills college",
-    "cineworld cinema",
-    "city centre north bed and breakfast",
-    "city stop restaurant",
-    "cityroomz",
-    "clare college",
-    "clare hall",
-    "clowns cafe",
-    "club salsa",
-    "cocum",
-    "copper kettle",
-    "corpus christi",
-    "cote",
-    "cotto",
-    "cow pizza kitchen and bar",
-    "curry garden",
-    "curry king",
-    "curry prince",
-    "da vinci pizzeria",
-    "darrys cookhouse and wine shop",
-    "de luca cucina and bar",
-    "dojo noodle bar",
-    "don pasquale pizzeria",
-    "downing college",
-    "efes restaurant",
-    "el shaddia guesthouse",
-    "ely train station",
-    "emmanuel college",
-    "eraina",
-    "express by holiday inn cambridge",
-    "finches bed and breakfast",
-    "finders corner newmarket road",
-    "fitzbillies restaurant",
-    "fitzwilliam museum",
-    "frankie and bennys",
-    "funky fun house",
-    "galleria",
-    "gallery at 12 a high street",
-    "gastropub",
-    "golden curry",
-    "golden house",
-    "golden wok",
-    "gonville and caius college",
-    "gonville hotel",
-    "good luck",
-    "gourmet burger kitchen",
-    "graffiti",
-    "grafton hotel restaurant",
-    "great saint marys church",
-    "hakka",
-    "hamilton lodge",
-    "hk fusion",
-    "hobsons house",
-    "holy trinity church",
-    "home from home",
-    "hotel du vin and bistro",
-    "hughes hall",
-    "huntingdon marriott hotel",
-    "ian hong",
-    "india house",
-    "j restaurant",
-    "jesus college",
-    "jesus green outdoor pool",
-    "jinling noodle bar",
-    "kambar",
-    "kettles yard",
-    "kings college",
-    "kings hedges learner pool",
-    "kirkwood house",
-    "kohinoor",
-    "kymmoy",
-    "la margherita",
-    "la mimosa",
-    "la raza",
-    "la tasca",
-    "lan hong house",
-    "leicester train station",
-    "lensfield hotel",
-    "limehouse",
-    "little saint marys church",
-    "little seoul",
-    "loch fyne",
-    "london kings cross train station",
-    "london liverpool street train station",
-    "lovell lodge",
-    "lynne strover gallery",
-    "magdalene college",
-    "mahal of cambridge",
-    "maharajah tandoori restaurant",
-    "meghna",
-    "meze bar",
-    "michaelhouse cafe",
-    "midsummer house restaurant",
-    "milton country park",
-    "mumford theatre",
-    "museum of archaelogy and anthropology",
-    "museum of classical archaeology",
-    "nandos",
-    "nandos city centre",
-    "nil",
-    "nirala",
-    "norwich train station",
-    "nusha",
-    "old schools",
-    "panahar",
-    "parkside police station",
-    "parkside pools",
-    "peking restaurant",
-    "pembroke college",
-    "peoples portraits exhibition at girton college",
-    "peterborough train station",
-    "pipasha restaurant",
-    "pizza express",
-    "pizza hut cherry hinton",
-    "pizza hut city centre",
-    "pizza hut fenditton",
-    "prezzo",
-    "primavera",
-    "queens college",
-    "rajmahal",
-    "regency gallery",
-    "restaurant 17",
-    "restaurant 2 two",
-    "restaurant alimentum",
-    "rice boat",
-    "rice house",
-    "riverboat georgina",
-    "riverside brasserie",
-    "rosas bed and breakfast",
-    "royal spice",
-    "royal standard",
-    "ruskin gallery",
-    "saffron brasserie",
-    "saigon city",
-    "saint barnabas",
-    "saint barnabas press gallery",
-    "saint catharines college",
-    "saint johns chop house",
-    "saint johns college",
-    "sala thong",
-    "scott polar museum",
-    "scudamores punting co",
-    "sesame restaurant and bar",
-    "shanghai family restaurant",
-    "sheeps green and lammas land park fen causeway",
-    "shiraz",
-    "sidney sussex college",
-    "sitar tandoori",
-    "sleeperz hotel",
-    "soul tree nightclub",
-    "st johns chop house",
-    "stansted airport train station",
-    "station road",
-    "stazione restaurant and coffee bar",
-    "stevenage train station",
-    "taj tandoori",
-    "tall monument",
-    "tandoori palace",
-    "tang chinese",
-    "tenpin",
-    "thanh binh",
-    "the anatolia",
-    "the cambridge corn exchange",
-    "the cambridge shop",
-    "the fez club",
-    "the gandhi",
-    "the gardenia",
-    "the hotpot",
-    "the junction",
-    "the lucky star",
-    "the man on the moon",
-    "the missing sock",
-    "the oak bistro",
-    "the place",
-    "the regent street city center",
-    "the river bar steakhouse and grill",
-    "the slug and lettuce",
-    "the varsity restaurant",
-    "travellers rest",
-    "trinity college",
-    "ugly duckling",
-    "university arms hotel",
-    "vue cinema",
-    "wagamama",
-    "wandlebury country park",
-    "wankworth hotel",
-    "warkworth house",
-    "whale of a time",
-    "whipple museum of the history of science",
-    "williams art and antiques",
-    "worth house",
-    "yippee noodle bar",
-    "yu garden",
-    "zizzi cambridge",
-    "leverton house",
-    "the cambridge chop house",
-    "saint john's college",
-    "churchill college",
-    "the nirala",
-    "the cow pizza kitchen and bar",
-    "christ's college",
-    "el shaddai",
-    "saint catharine's college",
-    "camb",
-    "the golden curry",
-    "little saint mary's church",
-    "country folk museum",
-    "meze bar restaurant",
-    "the cambridge belfry",
-    "the fitzwilliam museum",
-    "the lensfield hotel",
-    "pizza express fen ditton",
-    "the cambridge punter",
-    "king's college",
-    "the cherry hinton village centre",
-    "shiraz restaurant",
-    "sheep's green and lammas land park fen causeway",
-    "caffe uno",
-    "the ghandi",
-    "the copper kettle",
-    "man on the moon concert hall",
-    "alpha-milton guest house",
-    "queen's college",
-    "restaurant one seven",
-    "restaurant two two",
-    "city centre north b and b",
-    "rosa's bed and breakfast",
-    "the good luck chinese food takeaway",
-    "not museum of archaeology and anthropologymentioned",
-    "tandori in cambridge",
-    "kettle's yard",
-    "megna",
-    "grou",
-    "gallery at twelve a high street",
-    "maharajah tandoori restaurant",
-    "pizza hut fen ditton",
-    "gandhi",
-    "tranh binh",
-    "kambur",
-    "people's portraits exhibition at girton college",
-    "hotel",
-    "restaurant",
-    "the galleria",
-    "queens' college",
-    "great saint mary's church",
-    "theathre",
-    "cambridge artworks",
-    "acorn house",
-    "shiraz",
-    "riverboat georginawd",
-    "mic",
-    "the gallery at twelve",
-    "the soul tree",
-    "finches"
-  ],
-  "taxi-departure": [
-    "none",
-    "do not care",
-    "172 chestertown road",
-    "4455 woodbridge road",
-    "a and b guest house",
-    "abbey pool and astroturf pitch",
-    "acorn guest house",
-    "adc theatre",
-    "addenbrookes hospital",
-    "alexander bed and breakfast",
-    "ali baba",
-    "all saints church",
-    "allenbell",
-    "alpha milton guest house",
-    "alyesbray lodge hotel",
-    "ambridge",
-    "anatolia",
-    "arbury lodge guesthouse",
-    "archway house",
-    "ashley hotel",
-    "ask",
-    "autumn house",
-    "avalon",
-    "aylesbray lodge guest house",
-    "backstreet bistro",
-    "ballare",
-    "bangkok city",
-    "bedouin",
-    "birmingham new street train station",
-    "bishops stortford train station",
-    "bloomsbury restaurant",
-    "bridge guest house",
-    "broughton house gallery",
-    "broxbourne train station",
-    "byard art",
-    "cafe jello gallery",
-    "cafe uno",
-    "caffee uno",
-    "camboats",
-    "cambridge",
-    "cambridge and county folk museum",
-    "cambridge arts theatre",
-    "cambridge artworks",
-    "cambridge belfry",
-    "cambridge book and print gallery",
-    "cambridge chop house",
-    "cambridge contemporary art",
-    "cambridge lodge restaurant",
-    "cambridge museum of technology",
-    "cambridge punter",
-    "cambridge towninfo centre",
-    "cambridge train station",
-    "cambridge university botanic gardens",
-    "carolina bed and breakfast",
-    "castle galleries",
-    "centre of town at my hotel",
-    "charlie chan",
-    "cherry hinton hall and grounds",
-    "cherry hinton village center",
-    "cherry hinton village centre",
-    "cherry hinton water play",
-    "chiquito restaurant bar",
-    "christ college",
-    "churchills college",
-    "cineworld cinema",
-    "citiroomz",
-    "city centre north bed and breakfast",
-    "city stop restaurant",
-    "cityroomz",
-    "clair hall",
-    "clare college",
-    "clare hall",
-    "clowns cafe",
-    "club salsa",
-    "cocum",
-    "copper kettle",
-    "corpus christi",
-    "cote",
-    "cotto",
-    "cow pizza kitchen and bar",
-    "curry garden",
-    "curry king",
-    "curry prince",
-    "curry queen",
-    "da vinci pizzeria",
-    "darrys cookhouse and wine shop",
-    "de luca cucina and bar",
-    "dojo noodle bar",
-    "don pasquale pizzeria",
-    "downing college",
-    "downing street",
-    "el shaddia guesthouse",
-    "ely",
-    "ely train station",
-    "emmanuel college",
-    "eraina",
-    "express by holiday inn cambridge",
-    "finches bed and breakfast",
-    "fitzbillies restaurant",
-    "fitzwilliam museum",
-    "frankie and bennys",
-    "funky fun house",
-    "galleria",
-    "gallery at 12 a high street",
-    "girton college",
-    "golden curry",
-    "golden house",
-    "golden wok",
-    "gonville and caius college",
-    "gonville hotel",
-    "good luck",
-    "gourmet burger kitchen",
-    "graffiti",
-    "grafton hotel restaurant",
-    "great saint marys church",
-    "hakka",
-    "hamilton lodge",
-    "hobsons house",
-    "holy trinity church",
-    "home",
-    "home from home",
-    "hotel",
-    "hotel du vin and bistro",
-    "hughes hall",
-    "huntingdon marriott hotel",
-    "india house",
-    "j restaurant",
-    "jesus college",
-    "jesus green outdoor pool",
-    "jinling noodle bar",
-    "junction theatre",
-    "kambar",
-    "kettles yard",
-    "kings college",
-    "kings hedges learner pool",
-    "kings lynn train station",
-    "kirkwood house",
-    "kohinoor",
-    "kymmoy",
-    "la margherita",
-    "la mimosa",
-    "la raza",
-    "la tasca",
-    "lan hong house",
-    "lensfield hotel",
-    "leverton house",
-    "limehouse",
-    "little saint marys church",
-    "little seoul",
-    "loch fyne",
-    "london kings cross train station",
-    "london liverpool street",
-    "london liverpool street train station",
-    "lovell lodge",
-    "lynne strover gallery",
-    "magdalene college",
-    "mahal of cambridge",
-    "maharajah tandoori restaurant",
-    "meghna",
-    "meze bar",
-    "michaelhouse cafe",
-    "milton country park",
-    "mumford theatre",
-    "museum",
-    "museum of archaelogy and anthropology",
-    "museum of classical archaeology",
-    "nandos",
-    "nandos city centre",
-    "new england",
-    "nirala",
-    "norwich train station",
-    "nstaot mentioned",
-    "nusha",
-    "old schools",
-    "panahar",
-    "parkside police station",
-    "parkside pools",
-    "peking restaurant",
-    "pembroke college",
-    "peoples portraits exhibition at girton college",
-    "peterborough train station",
-    "pizza express",
-    "pizza hut cherry hinton",
-    "pizza hut city centre",
-    "pizza hut fenditton",
-    "prezzo",
-    "primavera",
-    "queens college",
-    "rajmahal",
-    "regency gallery",
-    "restaurant 17",
-    "restaurant 2 two",
-    "restaurant alimentum",
-    "rice boat",
-    "rice house",
-    "riverboat georgina",
-    "riverside brasserie",
-    "rosas bed and breakfast",
-    "royal spice",
-    "royal standard",
-    "ruskin gallery",
-    "saffron brasserie",
-    "saigon city",
-    "saint barnabas press gallery",
-    "saint catharines college",
-    "saint johns chop house",
-    "saint johns college",
-    "sala thong",
-    "scott polar museum",
-    "scudamores punting co",
-    "sesame restaurant and bar",
-    "sheeps green and lammas land park",
-    "sheeps green and lammas land park fen causeway",
-    "shiraz",
-    "sidney sussex college",
-    "sitar tandoori",
-    "soul tree nightclub",
-    "st johns college",
-    "stazione restaurant and coffee bar",
-    "stevenage train station",
-    "taj tandoori",
-    "tandoori palace",
-    "tang chinese",
-    "tenpin",
-    "thanh binh",
-    "the cambridge corn exchange",
-    "the fez club",
-    "the gallery at 12",
-    "the gandhi",
-    "the gardenia",
-    "the hotpot",
-    "the junction",
-    "the lucky star",
-    "the man on the moon",
-    "the missing sock",
-    "the oak bistro",
-    "the place",
-    "the river bar steakhouse and grill",
-    "the slug and lettuce",
-    "the varsity restaurant",
-    "travellers rest",
-    "trinity college",
-    "ugly duckling",
-    "university arms hotel",
-    "vue cinema",
-    "wagamama",
-    "wandlebury country park",
-    "warkworth house",
-    "whale of a time",
-    "whipple museum of the history of science",
-    "williams art and antiques",
-    "worth house",
-    "yippee noodle bar",
-    "yu garden",
-    "zizzi cambridge",
-    "christ's college",
-    "city centre north b and b",
-    "the lensfield hotel",
-    "alpha-milton guest house",
-    "el shaddai",
-    "churchill college",
-    "the cambridge belfry",
-    "king's college",
-    "great saint mary's church",
-    "restaurant two two",
-    "queens' college",
-    "little saint mary's church",
-    "chinese city centre",
-    "kettle's yard",
-    "pizza hut",
-    "the golden curry",
-    "rosa's bed and breakfast",
-    "the cambridge punter",
-    "the byard art museum",
-    "saint catharine's college",
-    "meze bar restaurant",
-    "the good luck chinese food takeaway",
-    "restaurant one seven",
-    "pizza hut fen ditton",
-    "the nirala",
-    "the fitzwilliam museum",
-    "st. john's college",
-    "gallery at twelve a high street",
-    "sheep's green and lammas land park fen causeway",
-    "the cherry hinton village centre",
-    "pizza express fen ditton",
-    "corpus cristi",
-    "cas",
-    "acorn house",
-    "lens",
-    "the cambridge chop house",
-    "the copper kettle",
-    "the avalon",
-    "saint john's college",
-    "aylesbray lodge",
-    "the alexander bed and breakfast",
-    "cambridge belfy",
-    "people's portraits exhibition at girton college",
-    "gonville",
-    "caffe uno",
-    "the cow pizza kitchen and bar",
-    "lovell ldoge",
-    "cinema",
-    "shiraz restaurant",
-    "park",
-    "the allenbell"
-  ],
-  "restaurant-book day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "restaurant-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "restaurant-book time": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "taxi-arrive by": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "restaurant-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west"
-  ],
-  "hotel-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west"
-  ],
-  "attraction-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west"
-  ]
-}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json
deleted file mode 100644
index b0dd00fd..00000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json
+++ /dev/null
@@ -1,3128 +0,0 @@
-{
-  "hotel-price range": [
-    "none",
-    "do not care",
-    "cheap",
-    "expensive",
-    "moderate",
-    "request"
-  ],
-  "hotel-type": [
-    "none",
-    "do not care",
-    "bed and breakfast",
-    "guest house",
-    "hotel",
-    "request"
-  ],
-  "hotel-parking": [
-    "none",
-    "do not care",
-    "no",
-    "yes",
-    "request"
-  ],
-  "hotel-book day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "hotel-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "hotel-book stay": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "train-destination": [
-    "none",
-    "do not care",
-    "bishops stortford",
-    "kings lynn",
-    "london liverpool street",
-    "centre",
-    "bishop stortford",
-    "liverpool",
-    "leicester",
-    "broxbourne",
-    "gourmet burger kitchen",
-    "copper kettle",
-    "bournemouth",
-    "stevenage",
-    "liverpool street",
-    "norwich",
-    "huntingdon marriott hotel",
-    "city centre north",
-    "taj tandoori",
-    "the copper kettle",
-    "peterborough",
-    "ely",
-    "lecester",
-    "london",
-    "willi",
-    "stansted airport",
-    "huntington marriott",
-    "cambridge",
-    "gonv",
-    "glastonbury",
-    "hol",
-    "north",
-    "birmingham new street",
-    "norway",
-    "petersborough",
-    "london kings cross",
-    "curry prince",
-    "bishops storford"
-  ],
-  "train-arrive by": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55",
-    "request"
-  ],
-  "train-departure": [
-    "none",
-    "do not care",
-    "bishops stortford",
-    "kings lynn",
-    "brookshite",
-    "london liverpool street",
-    "cam",
-    "liverpool",
-    "bro",
-    "leicester",
-    "broxbourne",
-    "norwhich",
-    "saint johns",
-    "stevenage",
-    "stansted",
-    "london liverpool",
-    "cambrid",
-    "city hall",
-    "rosas bed and breakfast",
-    "alpha-milton",
-    "wandlebury country park",
-    "norwich",
-    "liecester",
-    "stratford",
-    "peterborough",
-    "duxford",
-    "ely",
-    "london",
-    "stansted airport",
-    "lon",
-    "cambridge",
-    "panahar",
-    "cineworld",
-    "leicaster",
-    "birmingham",
-    "cafe uno",
-    "camboats",
-    "huntingdon",
-    "birmingham new street",
-    "arbu",
-    "alpha milton",
-    "east london",
-    "london kings cross",
-    "hamilton lodge",
-    "aylesbray lodge guest",
-    "el shaddai"
-  ],
-  "train-day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "train-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "hotel-stars": [
-    "none",
-    "do not care",
-    "0",
-    "1",
-    "2",
-    "3",
-    "4",
-    "5",
-    "request"
-  ],
-  "hotel-internet": [
-    "none",
-    "do not care",
-    "no",
-    "yes",
-    "request"
-  ],
-  "hotel-name": [
-    "none",
-    "do not care",
-    "a and b guest house",
-    "city roomz",
-    "carolina bed and breakfast",
-    "limehouse",
-    "anatolia",
-    "hamilton lodge",
-    "the lensfield hotel",
-    "rosa's bed and breakfast",
-    "gall",
-    "aylesbray lodge",
-    "kirkwood",
-    "cambridge belfry",
-    "warkworth house",
-    "gonville",
-    "belfy hotel",
-    "nus",
-    "alexander",
-    "super 5",
-    "aylesbray lodge guest house",
-    "the gonvile hotel",
-    "allenbell",
-    "nothamilton lodge",
-    "ashley hotel",
-    "autumn house",
-    "hobsons house",
-    "hotel",
-    "ashely hotel",
-    "caridge belfrey",
-    "el shaddia guest house",
-    "avalon",
-    "cote",
-    "city centre north bed and breakfast",
-    "the cambridge belfry",
-    "home from home",
-    "wandlebury coutn",
-    "wankworth house",
-    "city stop rest",
-    "the worth house",
-    "cityroomz",
-    "huntingdon marriottt hotel",
-    "lensfield",
-    "rosas bed and breakfast",
-    "leverton house",
-    "gonville hotel",
-    "holiday inn cambridge",
-    "archway house",
-    "lan hon",
-    "levert",
-    "acorn guest house",
-    "cambridge",
-    "the ashley hotel",
-    "el shaddai",
-    "sleeperz",
-    "alpha milton guest house",
-    "doubletree by hilton cambridge",
-    "tandoori palace",
-    "express by",
-    "express by holiday inn cambridge",
-    "north bed and breakfast",
-    "holiday inn",
-    "arbury lodge guest house",
-    "alexander bed and breakfast",
-    "huntingdon marriott hotel",
-    "royal spice",
-    "sou",
-    "finches bed and breakfast",
-    "the alpha milton",
-    "bridge guest house",
-    "the acorn guest house",
-    "kirkwood house",
-    "eraina",
-    "la margherit",
-    "lensfield hotel",
-    "marriott hotel",
-    "nusha",
-    "city centre bed and breakfast",
-    "the allenbell",
-    "university arms hotel",
-    "clare",
-    "cherr",
-    "wartworth",
-    "acorn place",
-    "lovell lodge",
-    "whale"
-  ],
-  "train-leave at": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55",
-    "request"
-  ],
-  "restaurant-price range": [
-    "none",
-    "do not care",
-    "cheap",
-    "expensive",
-    "moderate",
-    "request"
-  ],
-  "restaurant-food": [
-    "none",
-    "do not care",
-    "british food",
-    "steakhouse",
-    "turkish",
-    "sushi",
-    "north american",
-    "scottish",
-    "french",
-    "austrian",
-    "korean",
-    "eastern european",
-    "swedish",
-    "gastro pub",
-    "modern eclectic",
-    "afternoon tea",
-    "welsh",
-    "christmas",
-    "tuscan",
-    "gastropub",
-    "sri lankan",
-    "molecular gastronomy",
-    "traditional american",
-    "italian",
-    "pizza",
-    "thai",
-    "south african",
-    "creative",
-    "english",
-    "asian",
-    "lebanese",
-    "hungarian",
-    "halal",
-    "portugese",
-    "modern english",
-    "african",
-    "light bites",
-    "malaysian",
-    "venetian",
-    "traditional",
-    "chinese",
-    "vegetarian",
-    "persian",
-    "thai and chinese",
-    "scandinavian",
-    "catalan",
-    "polynesian",
-    "crossover",
-    "canapes",
-    "cantonese",
-    "north african",
-    "seafood",
-    "brazilian",
-    "south indian",
-    "australasian",
-    "belgian",
-    "barbeque",
-    "the americas",
-    "indonesian",
-    "singaporean",
-    "irish",
-    "middle eastern",
-    "dojo noodle bar",
-    "caribbean",
-    "vietnamese",
-    "modern european",
-    "russian",
-    "german",
-    "world",
-    "japanese",
-    "moroccan",
-    "modern global",
-    "indian",
-    "british",
-    "american",
-    "danish",
-    "panasian",
-    "swiss",
-    "basque",
-    "north indian",
-    "modern american",
-    "australian",
-    "european",
-    "corsica",
-    "greek",
-    "northern european",
-    "mediterranean",
-    "portuguese",
-    "romanian",
-    "jamaican",
-    "polish",
-    "international",
-    "unusual",
-    "latin american",
-    "asian oriental",
-    "mexican",
-    "bistro",
-    "cuban",
-    "fusion",
-    "new zealand",
-    "spanish",
-    "eritrean",
-    "afghan",
-    "kosher",
-    "request"
-  ],
-  "attraction-name": [
-    "none",
-    "do not care",
-    "downing college",
-    "fitzwilliam",
-    "clare college",
-    "ruskin gallery",
-    "sidney sussex college",
-    "great saint mary's church",
-    "cherry hinton water play park",
-    "wandlebury country park",
-    "cafe uno",
-    "place",
-    "broughton",
-    "cineworld cinema",
-    "jesus college",
-    "vue cinema",
-    "history of science museum",
-    "mumford theatre",
-    "whale of time",
-    "fitzbillies",
-    "christs church",
-    "churchill college",
-    "museum of classical archaeology",
-    "gonville and caius college",
-    "pizza",
-    "kirkwood",
-    "saint catharines college",
-    "kings college",
-    "parkside",
-    "by",
-    "st catharines college",
-    "saint john's college",
-    "cherry hinton water park",
-    "st christs college",
-    "christ's college",
-    "bangkok city",
-    "scudamores punti co",
-    "free",
-    "great saint marys church",
-    "milton country park",
-    "the fez club",
-    "soultree",
-    "autu",
-    "whipple museum of the history of science",
-    "aylesbray lodge guest house",
-    "broughton house gallery",
-    "peoples portraits exhibition",
-    "primavera",
-    "kettles yard",
-    "all saint's church",
-    "cinema cinema",
-    "regency gallery",
-    "corpus christi",
-    "corn cambridge exchange",
-    "da vinci pizzeria",
-    "school",
-    "hobsons house",
-    "cambride and country folk museum",
-    "north",
-    "da v",
-    "cambridge corn exchange",
-    "soul tree nightclub",
-    "cambridge arts theater",
-    "saint catharine's college",
-    "byard art",
-    "cambridge punter",
-    "cambridge university botanic gardens",
-    "castle galleries",
-    "museum of archaelogy and anthropogy",
-    "no specific location",
-    "cherry hinton hall",
-    "gallery at 12 a high street",
-    "parkside pools",
-    "queen's college",
-    "little saint mary's church",
-    "gallery",
-    "home from home",
-    "tenpin",
-    "the wandlebury",
-    "county folk museum",
-    "swimming pool",
-    "christs college",
-    "cafe jello museum",
-    "scott polar",
-    "christ college",
-    "cambridge museum of technology",
-    "abbey pool and astroturf pitch",
-    "king hedges learner pool",
-    "the cambridge arts theatre",
-    "the castle galleries",
-    "cambridge and country folk museum",
-    "kohinoor",
-    "scudamores punting co",
-    "sidney sussex",
-    "the man on the moon",
-    "little saint marys church",
-    "queens",
-    "the place",
-    "old school",
-    "churchill",
-    "churchills college",
-    "hughes hall",
-    "churchhill college",
-    "riverboat georgina",
-    "belf",
-    "cambridge temporary art",
-    "abc theatre",
-    "cambridge contemporary art museum",
-    "man on the moon",
-    "the junction",
-    "cherry hinton water play",
-    "adc theatre",
-    "gonville hotel",
-    "magdalene college",
-    "peoples portraits exhibition at girton college",
-    "boat",
-    "centre",
-    "sheep's green and lammas land park fen causeway",
-    "the mumford theatre",
-    "archway house",
-    "queens' college",
-    "williams art and antiques",
-    "funky fun house",
-    "cherry hinton village centre",
-    "camboats",
-    "cambridge",
-    "old schools",
-    "kettle's yard",
-    "whale of a time",
-    "the churchill college",
-    "cafe jello gallery",
-    "aut",
-    "salsa",
-    "city",
-    "clare hall",
-    "boating",
-    "pembroke college",
-    "kings hedges learner pool",
-    "caffe uno",
-    "lammas land park",
-    "museum",
-    "the fitzwilliam museum",
-    "the cherry hinton village centre",
-    "the cambridge corn exchange",
-    "fitzwilliam museum",
-    "museum of archaelogy and anthropology",
-    "fez club",
-    "the cambridge punter",
-    "saint johns college",
-    "emmanuel college",
-    "cambridge belf",
-    "scudamore",
-    "lynne strover gallery",
-    "king's college",
-    "whippple museum",
-    "trinity college",
-    "college in the north",
-    "sheep's green",
-    "kambar",
-    "museum of archaelogy",
-    "adc",
-    "garde",
-    "club salsa",
-    "people's portraits exhibition at girton college",
-    "botanic gardens",
-    "carol",
-    "college",
-    "gallery at twelve a high street",
-    "abbey pool and astroturf",
-    "cambridge book and print gallery",
-    "jesus green outdoor pool",
-    "scott polar museum",
-    "saint barnabas press gallery",
-    "cambridge artworks",
-    "older churches",
-    "cambridge contemporary art",
-    "cherry hinton hall and grounds",
-    "univ",
-    "jesus green",
-    "ballare",
-    "abbey pool",
-    "cambridge botanic gardens",
-    "nusha",
-    "worth house",
-    "thanh",
-    "university arms hotel",
-    "cambridge arts theatre",
-    "cafe jello",
-    "cambridge and county folk museum",
-    "the cambridge artworks",
-    "all saints church",
-    "holy trinity church",
-    "contemporary art museum",
-    "architectural churches",
-    "queens college",
-    "trinity street college"
-  ],
-  "restaurant-name": [
-    "none",
-    "do not care",
-    "hotel du vin and bistro",
-    "ask",
-    "gourmet formal kitchen",
-    "the meze bar",
-    "lan hong house",
-    "cow pizza",
-    "one seven",
-    "prezzo",
-    "maharajah tandoori restaurant",
-    "alex",
-    "shanghai",
-    "golden wok",
-    "restaurant",
-    "fitzbillies",
-    "nil",
-    "copper kettle",
-    "meghna",
-    "hk fusion",
-    "bangkok city",
-    "hobsons house",
-    "tang chinese",
-    "anatolia",
-    "ugly duckling",
-    "anatolia and efes restaurant",
-    "sitar tandoori",
-    "city stop",
-    "ashley",
-    "pizza express fen ditton",
-    "molecular gastronomy",
-    "autumn house",
-    "el shaddia guesthouse",
-    "the grafton hotel",
-    "limehouse",
-    "gardenia",
-    "not metioned",
-    "hakka",
-    "michaelhouse cafe",
-    "pipasha",
-    "meze bar",
-    "archway",
-    "molecular gastonomy",
-    "yipee noodle bar",
-    "the peking",
-    "curry prince",
-    "midsummer house restaurant",
-    "pizza hut cherry hinton",
-    "the lucky star",
-    "stazione restaurant and coffee bar",
-    "shanghi family restaurant",
-    "good luck",
-    "j restaurant",
-    "bedouin",
-    "cott",
-    "little seoul",
-    "south",
-    "thanh binh",
-    "el",
-    "efes restaurant",
-    "kohinoor",
-    "clowns",
-    "india",
-    "the slug and lettuce",
-    "shiraz",
-    "barbakan",
-    "zizzi cambridge",
-    "restaurant one seven",
-    "slug and lettuce",
-    "travellers rest",
-    "binh",
-    "worth house",
-    "broughton house gallery",
-    "chiquito",
-    "the river bar steakhouse and grill",
-    "ros",
-    "golden house",
-    "india west",
-    "cam",
-    "panahar",
-    "restaurant 22",
-    "adden",
-    "indian",
-    "hu",
-    "jinling noodle bar",
-    "darrys cookhouse and wine shop",
-    "hobson house",
-    "cambridge be",
-    "el shaddai",
-    "ac",
-    "nandos",
-    "cambridge lodge",
-    "the cow pizza kitchen and bar",
-    "charlie",
-    "rajmahal",
-    "kymmoy",
-    "cambri",
-    "backstreet bistro",
-    "galleria",
-    "restaurant 2 two",
-    "chiquito restaurant bar",
-    "royal standard",
-    "lucky star",
-    "curry king",
-    "grafton hotel restaurant",
-    "mahal of cambridge",
-    "the bedouin",
-    "nus",
-    "the kohinoor",
-    "pizza hut fenditton",
-    "camboats",
-    "the gardenia",
-    "de luca cucina and bar",
-    "nusha",
-    "european",
-    "taj tandoori",
-    "tandoori palace",
-    "golden curry",
-    "efes",
-    "loch fyne",
-    "the maharajah tandoor",
-    "lovel",
-    "restaurant 17",
-    "clowns cafe",
-    "cambridge punter",
-    "bloomsbury restaurant",
-    "la mimosa",
-    "the cambridge chop house",
-    "funky",
-    "cotto",
-    "oak bistro",
-    "restaurant two two",
-    "pipasha restaurant",
-    "river bar steakhouse and grill",
-    "royal spice",
-    "the copper kettle",
-    "graffiti",
-    "nandos city centre",
-    "saffron brasserie",
-    "cambridge chop house",
-    "sitar",
-    "kitchen and bar",
-    "the good luck chinese food takeaway",
-    "clu",
-    "la tasca",
-    "cafe uno",
-    "cote",
-    "the varsity restaurant",
-    "bri",
-    "eraina",
-    "bridge",
-    "fin",
-    "cambridge lodge restaurant",
-    "grafton",
-    "hotpot",
-    "sala thong",
-    "margherita",
-    "wise buddha",
-    "the missing sock",
-    "seasame restaurant and bar",
-    "the dojo noodle bar",
-    "restaurant alimentum",
-    "gastropub",
-    "saigon city",
-    "la margherita",
-    "pizza hut",
-    "curry garden",
-    "ashley hotel",
-    "eraina and michaelhouse cafe",
-    "the golden curry",
-    "curry queen",
-    "cow pizza kitchen and bar",
-    "the peking restaurant:",
-    "hamilton lodge",
-    "alimentum",
-    "yippee noodle bar",
-    "2 two and cote",
-    "shanghai family restaurant",
-    "grafton hotel",
-    "yes",
-    "ali baba",
-    "dif",
-    "fitzbillies restaurant",
-    "peking restaurant",
-    "lev",
-    "nirala",
-    "the alex",
-    "tandoori",
-    "city stop restaurant",
-    "rice house",
-    "cityr",
-    "yu garden",
-    "meze bar restaurant",
-    "the",
-    "don pasquale pizzeria",
-    "rice boat",
-    "the hotpot",
-    "old school",
-    "the oak bistro",
-    "sesame restaurant and bar",
-    "pizza express",
-    "the gandhi",
-    "pizza hut fen ditton",
-    "charlie chan",
-    "da vinci pizzeria",
-    "dojo noodle bar",
-    "gourmet burger kitchen",
-    "the golden house",
-    "india house",
-    "hobso",
-    "missing sock",
-    "pizza hut city centre",
-    "parkside pools",
-    "riverside brasserie",
-    "caffe uno",
-    "primavera",
-    "the nirala",
-    "wagamama",
-    "au",
-    "ian hong house",
-    "frankie and bennys",
-    "4 kings parade city centre",
-    "shiraz restaurant",
-    "scudamores punt",
-    "mahal",
-    "saint johns chop house",
-    "de luca cucina and bar riverside brasserie",
-    "cocum",
-    "la raza"
-  ],
-  "attraction-type": [
-    "none",
-    "do not care",
-    "architecture",
-    "boat",
-    "boating",
-    "camboats",
-    "church",
-    "churchills college",
-    "cinema",
-    "college",
-    "concert",
-    "concerthall",
-    "entertainment",
-    "gallery",
-    "gastropub",
-    "hiking",
-    "hotel",
-    "multiple sports",
-    "museum",
-    "museum kettles yard",
-    "night club",
-    "outdoor",
-    "park",
-    "pool",
-    "special",
-    "sports",
-    "swimming pool",
-    "theater",
-    "theatre",
-    "concert hall",
-    "local site",
-    "nightclub",
-    "hotspot",
-    "request"
-  ],
-  "taxi-leave at": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55",
-    "request"
-  ],
-  "taxi-destination": [
-    "none",
-    "do not care",
-    "a and b guest house",
-    "abbey pool and astroturf pitch",
-    "acorn guest house",
-    "adc theatre",
-    "addenbrookes hospital",
-    "alexander bed and breakfast",
-    "ali baba",
-    "all saints church",
-    "allenbell",
-    "alpha milton guest house",
-    "anatolia",
-    "arbury lodge guesthouse",
-    "archway house",
-    "ashley hotel",
-    "ask",
-    "attraction",
-    "autumn house",
-    "avalon",
-    "aylesbray lodge guest house",
-    "backstreet bistro",
-    "ballare",
-    "bangkok city",
-    "bedouin",
-    "birmingham new street train station",
-    "bishops stortford train station",
-    "bloomsbury restaurant",
-    "bridge guest house",
-    "broughton house gallery",
-    "broxbourne train station",
-    "byard art",
-    "cafe jello gallery",
-    "cafe uno",
-    "camboats",
-    "cambridge",
-    "cambridge and county folk museum",
-    "cambridge arts theatre",
-    "cambridge artworks",
-    "cambridge belfry",
-    "cambridge book and print gallery",
-    "cambridge chop house",
-    "cambridge contemporary art",
-    "cambridge county fair next to the city tourist museum",
-    "cambridge lodge restaurant",
-    "cambridge museum of technology",
-    "cambridge punter",
-    "cambridge road church of christ",
-    "cambridge train station",
-    "cambridge university botanic gardens",
-    "carolina bed and breakfast",
-    "castle galleries",
-    "charlie chan",
-    "cherry hinton hall and grounds",
-    "cherry hinton village centre",
-    "cherry hinton water park",
-    "cherry hinton water play",
-    "chiquito restaurant bar",
-    "christ college",
-    "churchills college",
-    "cineworld cinema",
-    "city centre north bed and breakfast",
-    "city stop restaurant",
-    "cityroomz",
-    "clare college",
-    "clare hall",
-    "clowns cafe",
-    "club salsa",
-    "cocum",
-    "copper kettle",
-    "corpus christi",
-    "cote",
-    "cotto",
-    "cow pizza kitchen and bar",
-    "curry garden",
-    "curry king",
-    "curry prince",
-    "da vinci pizzeria",
-    "darrys cookhouse and wine shop",
-    "de luca cucina and bar",
-    "dojo noodle bar",
-    "don pasquale pizzeria",
-    "downing college",
-    "efes restaurant",
-    "el shaddia guesthouse",
-    "ely train station",
-    "emmanuel college",
-    "eraina",
-    "express by holiday inn cambridge",
-    "finches bed and breakfast",
-    "finders corner newmarket road",
-    "fitzbillies restaurant",
-    "fitzwilliam museum",
-    "frankie and bennys",
-    "funky fun house",
-    "galleria",
-    "gallery at 12 a high street",
-    "gastropub",
-    "golden curry",
-    "golden house",
-    "golden wok",
-    "gonville and caius college",
-    "gonville hotel",
-    "good luck",
-    "gourmet burger kitchen",
-    "graffiti",
-    "grafton hotel restaurant",
-    "great saint marys church",
-    "hakka",
-    "hamilton lodge",
-    "hk fusion",
-    "hobsons house",
-    "holy trinity church",
-    "home from home",
-    "hotel du vin and bistro",
-    "hughes hall",
-    "huntingdon marriott hotel",
-    "ian hong",
-    "india house",
-    "j restaurant",
-    "jesus college",
-    "jesus green outdoor pool",
-    "jinling noodle bar",
-    "kambar",
-    "kettles yard",
-    "kings college",
-    "kings hedges learner pool",
-    "kirkwood house",
-    "kohinoor",
-    "kymmoy",
-    "la margherita",
-    "la mimosa",
-    "la raza",
-    "la tasca",
-    "lan hong house",
-    "leicester train station",
-    "lensfield hotel",
-    "limehouse",
-    "little saint marys church",
-    "little seoul",
-    "loch fyne",
-    "london kings cross train station",
-    "london liverpool street train station",
-    "lovell lodge",
-    "lynne strover gallery",
-    "magdalene college",
-    "mahal of cambridge",
-    "maharajah tandoori restaurant",
-    "meghna",
-    "meze bar",
-    "michaelhouse cafe",
-    "midsummer house restaurant",
-    "milton country park",
-    "mumford theatre",
-    "museum of archaelogy and anthropology",
-    "museum of classical archaeology",
-    "nandos",
-    "nandos city centre",
-    "nil",
-    "nirala",
-    "norwich train station",
-    "nusha",
-    "old schools",
-    "panahar",
-    "parkside police station",
-    "parkside pools",
-    "peking restaurant",
-    "pembroke college",
-    "peoples portraits exhibition at girton college",
-    "peterborough train station",
-    "pipasha restaurant",
-    "pizza express",
-    "pizza hut cherry hinton",
-    "pizza hut city centre",
-    "pizza hut fenditton",
-    "prezzo",
-    "primavera",
-    "queens college",
-    "rajmahal",
-    "regency gallery",
-    "restaurant 17",
-    "restaurant 2 two",
-    "restaurant alimentum",
-    "rice boat",
-    "rice house",
-    "riverboat georgina",
-    "riverside brasserie",
-    "rosas bed and breakfast",
-    "royal spice",
-    "royal standard",
-    "ruskin gallery",
-    "saffron brasserie",
-    "saigon city",
-    "saint barnabas",
-    "saint barnabas press gallery",
-    "saint catharines college",
-    "saint johns chop house",
-    "saint johns college",
-    "sala thong",
-    "scott polar museum",
-    "scudamores punting co",
-    "sesame restaurant and bar",
-    "shanghai family restaurant",
-    "sheeps green and lammas land park fen causeway",
-    "shiraz",
-    "sidney sussex college",
-    "sitar tandoori",
-    "sleeperz hotel",
-    "soul tree nightclub",
-    "st johns chop house",
-    "stansted airport train station",
-    "station road",
-    "stazione restaurant and coffee bar",
-    "stevenage train station",
-    "taj tandoori",
-    "tall monument",
-    "tandoori palace",
-    "tang chinese",
-    "tenpin",
-    "thanh binh",
-    "the anatolia",
-    "the cambridge corn exchange",
-    "the cambridge shop",
-    "the fez club",
-    "the gandhi",
-    "the gardenia",
-    "the hotpot",
-    "the junction",
-    "the lucky star",
-    "the man on the moon",
-    "the missing sock",
-    "the oak bistro",
-    "the place",
-    "the regent street city center",
-    "the river bar steakhouse and grill",
-    "the slug and lettuce",
-    "the varsity restaurant",
-    "travellers rest",
-    "trinity college",
-    "ugly duckling",
-    "university arms hotel",
-    "vue cinema",
-    "wagamama",
-    "wandlebury country park",
-    "wankworth hotel",
-    "warkworth house",
-    "whale of a time",
-    "whipple museum of the history of science",
-    "williams art and antiques",
-    "worth house",
-    "yippee noodle bar",
-    "yu garden",
-    "zizzi cambridge",
-    "leverton house",
-    "the cambridge chop house",
-    "saint john's college",
-    "churchill college",
-    "the nirala",
-    "the cow pizza kitchen and bar",
-    "christ's college",
-    "el shaddai",
-    "saint catharine's college",
-    "camb",
-    "the golden curry",
-    "little saint mary's church",
-    "country folk museum",
-    "meze bar restaurant",
-    "the cambridge belfry",
-    "the fitzwilliam museum",
-    "the lensfield hotel",
-    "pizza express fen ditton",
-    "the cambridge punter",
-    "king's college",
-    "the cherry hinton village centre",
-    "shiraz restaurant",
-    "sheep's green and lammas land park fen causeway",
-    "caffe uno",
-    "the ghandi",
-    "the copper kettle",
-    "man on the moon concert hall",
-    "alpha-milton guest house",
-    "queen's college",
-    "restaurant one seven",
-    "restaurant two two",
-    "city centre north b and b",
-    "rosa's bed and breakfast",
-    "the good luck chinese food takeaway",
-    "not museum of archaeology and anthropologymentioned",
-    "tandori in cambridge",
-    "kettle's yard",
-    "megna",
-    "grou",
-    "gallery at twelve a high street",
-    "maharajah tandoori restaurant",
-    "pizza hut fen ditton",
-    "gandhi",
-    "tranh binh",
-    "kambur",
-    "people's portraits exhibition at girton college",
-    "hotel",
-    "restaurant",
-    "the galleria",
-    "queens' college",
-    "great saint mary's church",
-    "theathre",
-    "cambridge artworks",
-    "acorn house",
-    "shiraz",
-    "riverboat georginawd",
-    "mic",
-    "the gallery at twelve",
-    "the soul tree",
-    "finches"
-  ],
-  "taxi-departure": [
-    "none",
-    "do not care",
-    "172 chestertown road",
-    "4455 woodbridge road",
-    "a and b guest house",
-    "abbey pool and astroturf pitch",
-    "acorn guest house",
-    "adc theatre",
-    "addenbrookes hospital",
-    "alexander bed and breakfast",
-    "ali baba",
-    "all saints church",
-    "allenbell",
-    "alpha milton guest house",
-    "alyesbray lodge hotel",
-    "ambridge",
-    "anatolia",
-    "arbury lodge guesthouse",
-    "archway house",
-    "ashley hotel",
-    "ask",
-    "autumn house",
-    "avalon",
-    "aylesbray lodge guest house",
-    "backstreet bistro",
-    "ballare",
-    "bangkok city",
-    "bedouin",
-    "birmingham new street train station",
-    "bishops stortford train station",
-    "bloomsbury restaurant",
-    "bridge guest house",
-    "broughton house gallery",
-    "broxbourne train station",
-    "byard art",
-    "cafe jello gallery",
-    "cafe uno",
-    "caffee uno",
-    "camboats",
-    "cambridge",
-    "cambridge and county folk museum",
-    "cambridge arts theatre",
-    "cambridge artworks",
-    "cambridge belfry",
-    "cambridge book and print gallery",
-    "cambridge chop house",
-    "cambridge contemporary art",
-    "cambridge lodge restaurant",
-    "cambridge museum of technology",
-    "cambridge punter",
-    "cambridge towninfo centre",
-    "cambridge train station",
-    "cambridge university botanic gardens",
-    "carolina bed and breakfast",
-    "castle galleries",
-    "centre of town at my hotel",
-    "charlie chan",
-    "cherry hinton hall and grounds",
-    "cherry hinton village center",
-    "cherry hinton village centre",
-    "cherry hinton water play",
-    "chiquito restaurant bar",
-    "christ college",
-    "churchills college",
-    "cineworld cinema",
-    "citiroomz",
-    "city centre north bed and breakfast",
-    "city stop restaurant",
-    "cityroomz",
-    "clair hall",
-    "clare college",
-    "clare hall",
-    "clowns cafe",
-    "club salsa",
-    "cocum",
-    "copper kettle",
-    "corpus christi",
-    "cote",
-    "cotto",
-    "cow pizza kitchen and bar",
-    "curry garden",
-    "curry king",
-    "curry prince",
-    "curry queen",
-    "da vinci pizzeria",
-    "darrys cookhouse and wine shop",
-    "de luca cucina and bar",
-    "dojo noodle bar",
-    "don pasquale pizzeria",
-    "downing college",
-    "downing street",
-    "el shaddia guesthouse",
-    "ely",
-    "ely train station",
-    "emmanuel college",
-    "eraina",
-    "express by holiday inn cambridge",
-    "finches bed and breakfast",
-    "fitzbillies restaurant",
-    "fitzwilliam museum",
-    "frankie and bennys",
-    "funky fun house",
-    "galleria",
-    "gallery at 12 a high street",
-    "girton college",
-    "golden curry",
-    "golden house",
-    "golden wok",
-    "gonville and caius college",
-    "gonville hotel",
-    "good luck",
-    "gourmet burger kitchen",
-    "graffiti",
-    "grafton hotel restaurant",
-    "great saint marys church",
-    "hakka",
-    "hamilton lodge",
-    "hobsons house",
-    "holy trinity church",
-    "home",
-    "home from home",
-    "hotel",
-    "hotel du vin and bistro",
-    "hughes hall",
-    "huntingdon marriott hotel",
-    "india house",
-    "j restaurant",
-    "jesus college",
-    "jesus green outdoor pool",
-    "jinling noodle bar",
-    "junction theatre",
-    "kambar",
-    "kettles yard",
-    "kings college",
-    "kings hedges learner pool",
-    "kings lynn train station",
-    "kirkwood house",
-    "kohinoor",
-    "kymmoy",
-    "la margherita",
-    "la mimosa",
-    "la raza",
-    "la tasca",
-    "lan hong house",
-    "lensfield hotel",
-    "leverton house",
-    "limehouse",
-    "little saint marys church",
-    "little seoul",
-    "loch fyne",
-    "london kings cross train station",
-    "london liverpool street",
-    "london liverpool street train station",
-    "lovell lodge",
-    "lynne strover gallery",
-    "magdalene college",
-    "mahal of cambridge",
-    "maharajah tandoori restaurant",
-    "meghna",
-    "meze bar",
-    "michaelhouse cafe",
-    "milton country park",
-    "mumford theatre",
-    "museum",
-    "museum of archaelogy and anthropology",
-    "museum of classical archaeology",
-    "nandos",
-    "nandos city centre",
-    "new england",
-    "nirala",
-    "norwich train station",
-    "nstaot mentioned",
-    "nusha",
-    "old schools",
-    "panahar",
-    "parkside police station",
-    "parkside pools",
-    "peking restaurant",
-    "pembroke college",
-    "peoples portraits exhibition at girton college",
-    "peterborough train station",
-    "pizza express",
-    "pizza hut cherry hinton",
-    "pizza hut city centre",
-    "pizza hut fenditton",
-    "prezzo",
-    "primavera",
-    "queens college",
-    "rajmahal",
-    "regency gallery",
-    "restaurant 17",
-    "restaurant 2 two",
-    "restaurant alimentum",
-    "rice boat",
-    "rice house",
-    "riverboat georgina",
-    "riverside brasserie",
-    "rosas bed and breakfast",
-    "royal spice",
-    "royal standard",
-    "ruskin gallery",
-    "saffron brasserie",
-    "saigon city",
-    "saint barnabas press gallery",
-    "saint catharines college",
-    "saint johns chop house",
-    "saint johns college",
-    "sala thong",
-    "scott polar museum",
-    "scudamores punting co",
-    "sesame restaurant and bar",
-    "sheeps green and lammas land park",
-    "sheeps green and lammas land park fen causeway",
-    "shiraz",
-    "sidney sussex college",
-    "sitar tandoori",
-    "soul tree nightclub",
-    "st johns college",
-    "stazione restaurant and coffee bar",
-    "stevenage train station",
-    "taj tandoori",
-    "tandoori palace",
-    "tang chinese",
-    "tenpin",
-    "thanh binh",
-    "the cambridge corn exchange",
-    "the fez club",
-    "the gallery at 12",
-    "the gandhi",
-    "the gardenia",
-    "the hotpot",
-    "the junction",
-    "the lucky star",
-    "the man on the moon",
-    "the missing sock",
-    "the oak bistro",
-    "the place",
-    "the river bar steakhouse and grill",
-    "the slug and lettuce",
-    "the varsity restaurant",
-    "travellers rest",
-    "trinity college",
-    "ugly duckling",
-    "university arms hotel",
-    "vue cinema",
-    "wagamama",
-    "wandlebury country park",
-    "warkworth house",
-    "whale of a time",
-    "whipple museum of the history of science",
-    "williams art and antiques",
-    "worth house",
-    "yippee noodle bar",
-    "yu garden",
-    "zizzi cambridge",
-    "christ's college",
-    "city centre north b and b",
-    "the lensfield hotel",
-    "alpha-milton guest house",
-    "el shaddai",
-    "churchill college",
-    "the cambridge belfry",
-    "king's college",
-    "great saint mary's church",
-    "restaurant two two",
-    "queens' college",
-    "little saint mary's church",
-    "chinese city centre",
-    "kettle's yard",
-    "pizza hut",
-    "the golden curry",
-    "rosa's bed and breakfast",
-    "the cambridge punter",
-    "the byard art museum",
-    "saint catharine's college",
-    "meze bar restaurant",
-    "the good luck chinese food takeaway",
-    "restaurant one seven",
-    "pizza hut fen ditton",
-    "the nirala",
-    "the fitzwilliam museum",
-    "st. john's college",
-    "gallery at twelve a high street",
-    "sheep's green and lammas land park fen causeway",
-    "the cherry hinton village centre",
-    "pizza express fen ditton",
-    "corpus cristi",
-    "cas",
-    "acorn house",
-    "lens",
-    "the cambridge chop house",
-    "the copper kettle",
-    "the avalon",
-    "saint john's college",
-    "aylesbray lodge",
-    "the alexander bed and breakfast",
-    "cambridge belfy",
-    "people's portraits exhibition at girton college",
-    "gonville",
-    "caffe uno",
-    "the cow pizza kitchen and bar",
-    "lovell ldoge",
-    "cinema",
-    "shiraz restaurant",
-    "park",
-    "the allenbell"
-  ],
-  "restaurant-book day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "restaurant-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "restaurant-book time": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "taxi-arrive by": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55",
-    "request"
-  ],
-  "restaurant-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west",
-    "request"
-  ],
-  "hotel-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west",
-    "request"
-  ],
-  "attraction-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west",
-    "request"
-  ],
-  "hospital-department": [
-    "none",
-    "do not care",
-    "acute medical assessment unit",
-    "acute medicine for the elderly",
-    "antenatal",
-    "cambridge eye unit",
-    "cardiology",
-    "cardiology and coronary care unit",
-    "childrens oncology and haematology",
-    "childrens surgical and medicine",
-    "clinical decisions unit",
-    "clinical research facility",
-    "coronary care unit",
-    "diabetes and endocrinology",
-    "emergency department",
-    "gastroenterology",
-    "gynaecology",
-    "haematology",
-    "haematology and haematological oncology",
-    "haematology day unit",
-    "hepatobillary and gastrointestinal surgery regional referral centre",
-    "hepatology",
-    "infectious diseases",
-    "infusion services",
-    "inpatient occupational therapy",
-    "intermediate dependancy area",
-    "john farman intensive care unit",
-    "medical decisions unit",
-    "medicine for the elderly",
-    "neonatal unit",
-    "neurology",
-    "neurology neurosurgery",
-    "neurosciences",
-    "neurosciences critical care unit",
-    "oncology",
-    "oral and maxillofacial surgery and ent",
-    "paediatric clinic",
-    "paediatric day unit",
-    "paediatric intensive care unit",
-    "plastic and vascular surgery plastics",
-    "psychiatry",
-    "respiratory medicine",
-    "surgery",
-    "teenage cancer trust unit",
-    "transitional care",
-    "transplant high dependency unit",
-    "trauma and orthopaedics",
-    "trauma high dependency unit",
-    "urology"
-  ],
-  "police-postcode": [
-    "request"
-  ],
-  "restaurant-postcode": [
-    "request"
-  ],
-  "train-duration": [
-    "request"
-  ],
-  "train-trainid": [
-    "request"
-  ],
-  "hospital-address": [
-    "request"
-  ],
-  "restaurant-phone": [
-    "request"
-  ],
-  "hotel-phone": [
-    "request"
-  ],
-  "restaurant-address": [
-    "request"
-  ],
-  "hotel-postcode": [
-    "request"
-  ],
-  "attraction-phone": [
-    "request"
-  ],
-  "attraction-entrance fee": [
-    "request"
-  ],
-  "hotel-reference": [
-    "request"
-  ],
-  "taxi-taxi types": [
-    "request"
-  ],
-  "attraction-address": [
-    "request"
-  ],
-  "hospital-phone": [
-    "request"
-  ],
-  "attraction-postcode": [
-    "request"
-  ],
-  "police-address": [
-    "request"
-  ],
-  "taxi-taxi phone": [
-    "request"
-  ],
-  "train-price": [
-    "request"
-  ],
-  "hospital-postcode": [
-    "request"
-  ],
-  "police-phone": [
-    "request"
-  ],
-  "hotel-address": [
-    "request"
-  ],
-  "restaurant-reference": [
-    "request"
-  ],
-  "train-reference": [
-    "request"
-  ]
-}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json
deleted file mode 100644
index 87e31536..00000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json
+++ /dev/null
@@ -1,57 +0,0 @@
-{
-  "hotel-price range": "preferred cost or price of the hotel",
-  "hotel-type": "what is the type of the hotel",
-  "hotel-parking": "does the hotel have parking",
-  "hotel-book stay": "number of nights for the hotel reservation",
-  "hotel-book day": "starting day of the hotel booking",
-  "hotel-book people": "number of people for the hotel booking",
-  "hotel-area": "area or place of the hotel",
-  "hotel-stars": "star rating of the hotel",
-  "hotel-internet": "does the hotel have internet or wifi",
-  "hotel-name": "name of the hotel",
-  "hotel-phone": "phone number of the hotel",
-  "hotel-postcode": "postcode of the hotel",
-  "hotel-reference": "booking reference of the hotel booking",
-  "hotel-address": "street address of the hotel",
-  "train-destination": "train station you want to travel to",
-  "train-day": "day of the train booking",
-  "train-departure": "train station you want to leave from",
-  "train-arrive by": "arrival time of the train",
-  "train-book people": "number of people for the train booking",
-  "train-leave at": "departure time for the train",
-  "train-duration": "duration of the train journey",
-  "train-trainid": "train identifier or number",
-  "train-price": "how much does the train trip cost",
-  "train-reference": "booking reference of the train booking",
-  "attraction-type": "type of attraction or point of interest",
-  "attraction-area": "area or place of the attraction",
-  "attraction-name": "name of the attraction",
-  "attraction-phone": "phone number of the attraction",
-  "attraction-entrance fee": "entrace fee at the attraction",
-  "attraction-address": "street address of the attraction",
-  "attraction-postcode": "postcode of the attraction",
-  "restaurant-book people": "number of people for the restaurant booking",
-  "restaurant-book day": "weekday for the restaurant booking",
-  "restaurant-book time": "time of the restaurant booking",
-  "restaurant-food": "type of food served at the restaurant",
-  "restaurant-price range": "preferred cost or price of the restaurant",
-  "restaurant-name": "name of the restaurant",
-  "restaurant-area": "area or place of the restaurant",
-  "restaurant-postcode": "postcode of the restaurant",
-  "restaurant-phone": "phone number of the restaurant",
-  "restaurant-address": "street address of the restaurant",
-  "restaurant-reference": "booking reference of the hotel booking",
-  "taxi-leave at": "what time you want the taxi to leave by",
-  "taxi-destination": "where you want the taxi to drop you off",
-  "taxi-departure": "where you want the taxi to pick you up",
-  "taxi-arrive by": "what time you to arrive at your destination",
-  "taxi-taxi types": "vehicle type of the taxi",
-  "taxi-taxi phone": "phone number of the taxi",
-  "hospital-department": "name of hospital department",
-  "hospital-address": "street address of the hospital",
-  "hospital-phone": "phone number of the hospital",
-  "hospital-postcode": "postcode of the hospital",
-  "police-postcode": "postcode of the police station",
-  "police-address": "street address of the police station",
-  "police-phone": "phone number of the police station"
-}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/ontology.py b/convlab/dst/setsumbt/multiwoz/dataset/ontology.py
deleted file mode 100644
index c6b9c336..00000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/ontology.py
+++ /dev/null
@@ -1,168 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Create Ontology Embeddings"""
-
-import json
-import os
-import random
-
-import torch
-import numpy as np
-
-
-# Slot mapping table for description extractions
-# SLOT_NAME_MAPPINGS = {
-#     'arrive at': 'arriveAt',
-#     'arrive by': 'arriveBy',
-#     'leave at': 'leaveAt',
-#     'leave by': 'leaveBy',
-#     'arriveby': 'arriveBy',
-#     'arriveat': 'arriveAt',
-#     'leaveat': 'leaveAt',
-#     'leaveby': 'leaveBy',
-#     'price range': 'pricerange'
-# }
-
-# Set up global data directory
-def set_datadir(dir):
-    global DATA_DIR
-    DATA_DIR = dir
-
-
-# Set seeds
-def set_seed(args):
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-
-# Get embeddings for slots and candidates
-def get_slot_candidate_embeddings(set_type, args, tokenizer, embedding_model, save_to_file=True):
-    # Get set alots and candidates
-    reader = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r')
-    ontology = json.load(reader)
-    reader.close()
-
-    reader = open(os.path.join(DATA_DIR, 'slot_descriptions.json'), 'r')
-    slot_descriptions = json.load(reader)
-    reader.close()
-
-    embedding_model.eval()
-
-    slots = dict()
-    for slot in ontology:
-        if args.use_descriptions:
-            # d, s = slot.split('-', 1)
-            # s = SLOT_NAME_MAPPINGS[s] if s in SLOT_NAME_MAPPINGS else s
-            # s = d + '-' + s
-            # if slot in slot_descriptions:
-            desc = slot_descriptions[slot]
-            # elif slot.lower() in slot_descriptions:
-            #     desc = slot_descriptions[s.lower()]
-            # else:
-            #     desc = slot.replace('-', ' ')
-        else:
-            desc = slot
-
-        # Tokenize slot and get embeddings
-        feats = tokenizer.encode_plus(desc, add_special_tokens = True,
-                                            max_length = args.max_slot_len, padding='max_length',
-                                            truncation = 'longest_first')
-
-        with torch.no_grad():
-            input_ids = torch.tensor([feats['input_ids']]).to(embedding_model.device) # [1, max_slot_len]
-            if 'token_type_ids' in feats:
-                token_type_ids = torch.tensor([feats['token_type_ids']]).to(embedding_model.device) # [1, max_slot_len]
-                if 'attention_mask' in feats:
-                    attention_mask = torch.tensor([feats['attention_mask']]).to(embedding_model.device) # [1, max_slot_len]
-                    feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids,
-                                            attention_mask=attention_mask)
-                    attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1)))
-                    feats = feats * attention_mask # [1, max_slot_len, hidden_dim]
-                else:
-                    feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids)
-            else:
-                if 'attention_mask' in feats:
-                    attention_mask = torch.tensor([feats['attention_mask']]).to(embedding_model.device)
-                    feats, pooled_feats = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
-                    attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1)))
-                    feats = feats * attention_mask # [1, max_slot_len, hidden_dim]
-                else:
-                    feats, pooled_feats = embedding_model(input_ids=input_ids) # [1, max_slot_len, hidden_dim]
-        
-        if args.set_similarity:
-            slot_emb = feats[0, :, :].detach().cpu() # [seq_len, hidden_dim]
-        else:
-            if args.candidate_pooling == 'cls' and pooled_feats is not None:
-                slot_emb = pooled_feats[0, :].detach().cpu() # [hidden_dim]
-            elif args.candidate_pooling == 'mean':
-                feats = feats.sum(1)
-                feats = torch.nn.functional.layer_norm(feats, feats.size())
-                slot_emb = feats[0, :].detach().cpu() # [hidden_dim]
-
-        # Tokenize value candidates and get embeddings
-        values = ontology[slot]
-        is_requestable = False
-        if 'request' in values:
-            is_requestable = True
-            values.remove('request')
-        if values:
-            feats = [tokenizer.encode_plus(val, add_special_tokens = True,
-                                                max_length = args.max_candidate_len, padding='max_length',
-                                                truncation = 'longest_first')
-                    for val in values]
-            with torch.no_grad():
-                input_ids = torch.tensor([f['input_ids'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len]
-                if 'token_type_ids' in feats[0]:
-                    token_type_ids = torch.tensor([f['token_type_ids'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len]
-                    if 'attention_mask' in feats[0]:
-                        attention_mask = torch.tensor([f['attention_mask'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len]
-                        feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids,
-                                                attention_mask=attention_mask)
-                        attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1)))
-                        feats = feats * attention_mask # [num_candidates, max_candidate_len, hidden_dim]
-                    else:
-                        feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids) # [num_candidates, max_candidate_len, hidden_dim]
-                else:
-                    if 'attention_mask' in feats[0]:
-                        attention_mask = torch.tensor([f['attention_mask'] for f in feats]).to(embedding_model.device)
-                        feats, pooled_feats = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
-                        attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1)))
-                        feats = feats * attention_mask # [num_candidates, max_candidate_len, hidden_dim]
-                    else:
-                        feats, pooled_feats = embedding_model(input_ids=input_ids) # [num_candidates, max_candidate_len, hidden_dim]
-            
-            if args.set_similarity:
-                feats = feats.detach().cpu() # [num_candidates, max_candidate_len, hidden_dim]
-            else:
-                if args.candidate_pooling == 'cls' and pooled_feats is not None:
-                    feats = pooled_feats.detach().cpu()
-                elif args.candidate_pooling == "mean":
-                    feats = feats.sum(1)
-                    feats = torch.nn.functional.layer_norm(feats, feats.size())
-                    feats = feats.detach().cpu()
-        else:
-            feats = None
-        slots[slot] = (slot_emb, feats, is_requestable)
-
-    # Dump tensors for use in training
-    if save_to_file:
-        writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type)
-        torch.save(slots, writer)
-    
-    return slots
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/utils.py b/convlab/dst/setsumbt/multiwoz/dataset/utils.py
deleted file mode 100644
index 485dee64..00000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/utils.py
+++ /dev/null
@@ -1,446 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Code adapted from the TRADE preprocessing code (https://github.com/jasonwu0731/trade-dst)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""MultiWOZ2.1/3 data processing utilities"""
-
-import re
-import os
-
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
-from convlab.dst.rule.multiwoz import normalize_value
-
-# ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train']
-ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train', 'hospital', 'police']
-def set_util_domains(domains):
-    global ACTIVE_DOMAINS
-    ACTIVE_DOMAINS = [d for d in domains if d in ACTIVE_DOMAINS]
-
-MAPPING_PATH = os.path.abspath(__file__).replace('utils.py', 'mapping.pair')
-# Read replacement pairs from the mapping.pair file
-REPLACEMENTS = []
-for line in open(MAPPING_PATH).readlines():
-    tok_from, tok_to = line.replace('\n', '').split('\t')
-    REPLACEMENTS.append((' ' + tok_from + ' ', ' ' + tok_to + ' '))
-
-# Extract belief state from mturk annotations
-def build_dialoguestate(metadata, get_domains=False):
-    domains_list = [dom for dom in ACTIVE_DOMAINS if dom in metadata]
-    dialogue_state, domains = [], []
-    for domain in domains_list:
-        active = False
-        # Extract booking information
-        booking = []
-        for slot in sorted(metadata[domain]['book'].keys()):
-            if slot != 'booked':
-                if metadata[domain]['book'][slot] == 'not mentioned':
-                    continue
-                if metadata[domain]['book'][slot] != '':
-                    val = ['%s-book %s' % (domain, slot.strip().lower()), clean_text(metadata[domain]['book'][slot])]
-                    dialogue_state.append(val)
-                    active = True
-
-        for slot in metadata[domain]['semi']:
-            if metadata[domain]['semi'][slot] == 'not mentioned':
-                continue
-            elif metadata[domain]['semi'][slot] in ['dont care', 'dontcare', "don't care", 'don not care',
-                                                    'do not care', 'does not care']:
-                dialogue_state.append(['%s-%s' % (domain, slot.strip().lower()), 'do not care'])
-                active = True
-            elif metadata[domain]['semi'][slot]:
-                dialogue_state.append(['%s-%s' % (domain, slot.strip().lower()), clean_text(metadata[domain]['semi'][slot])])
-                active = True
-
-        if active:
-            domains.append(domain)
-
-    if get_domains:
-        return domains
-    return clean_dialoguestate(dialogue_state)
-
-
-PRICERANGE = ['do not care', 'cheap', 'moderate', 'expensive']
-BOOLEAN = ['do not care', 'yes', 'no']
-DAYS = ['do not care', 'monday', 'tuesday', 'wednesday', 'thursday',
-        'friday', 'saterday', 'sunday']
-QUANTITIES = ['do not care', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10 or more']
-TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)]
-TIME = ['do not care'] + ['%02i:%02i' % t for l in TIME for t in l]
-
-VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast',
-            'cityroomz': 'city roomz', '  ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott',
-            'worth house': 'the worth house', 'alesbray lodge guest house': 'aylesbray lodge',
-            'huntingdon hotel': 'huntingdon marriott hotel', 'huntingd': 'huntingdon marriott hotel',
-            'jamaicanchinese': 'chinese', 'barbequemodern european': 'modern european',
-            'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"}
-
-def map_values(value):
-    for old, new in VALUE_MAP.items():
-        value = value.replace(old, new)
-    return value
-
-def clean_dialoguestate(states, is_acts=False):
-    # path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))))
-    # path = os.path.join(path, 'data/multiwoz/value_dict.json')
-    # value_dict = json.load(open(path))
-    clean_state = []
-    for slot, value in states:
-        if 'pricerange' in slot:
-            d, s = slot.split('-', 1)
-            s = 'price range'
-            slot = f'{d}-{s}'
-            if value in PRICERANGE:
-                clean_state.append([slot, value])
-            elif True in [v in value for v in PRICERANGE]:
-                value = [v for v in PRICERANGE if v in value][0]
-                clean_state.append([slot, value])
-            elif value == '?' and is_acts:
-                clean_state.append([slot, value])
-            else:
-                continue
-        elif 'parking' in slot or 'internet' in slot:
-            if value in BOOLEAN:
-                clean_state.append([slot, value])
-            if value == 'free':
-                value = 'yes'
-                clean_state.append([slot, value])
-            elif True in [v in value for v in BOOLEAN]:
-                value = [v for v in BOOLEAN if v in value][0]
-                clean_state.append([slot, value])
-            elif value == '?' and is_acts:
-                clean_state.append([slot, value])
-            else:
-                continue
-        elif 'day' in slot:
-            if value in DAYS:
-                clean_state.append([slot, value])
-            elif True in [v in value for v in DAYS]:
-                value = [v for v in DAYS if v in value][0]
-                clean_state.append([slot, value])
-            else:
-                continue
-        elif 'people' in slot or 'duration' in slot or 'stay' in slot:
-            if value in QUANTITIES:
-                clean_state.append([slot, value])
-            elif True in [v in value for v in QUANTITIES]:
-                value = [v for v in QUANTITIES if v in value][0]
-                clean_state.append([slot, value])
-            elif value == '?' and is_acts:
-                clean_state.append([slot, value])
-            else:
-                try:
-                    value = int(value)
-                    if value >= 10:
-                        value = '10 or more'
-                        clean_state.append([slot, value])
-                    else:
-                        continue
-                except:
-                    continue
-        elif 'time' in slot or 'leaveat' in slot or 'arriveby' in slot:
-            if 'leaveat' in slot:
-                d, s = slot.split('-', 1)
-                s = 'leave at'
-                slot = f'{d}-{s}'
-            if 'arriveby' in slot:
-                d, s = slot.split('-', 1)
-                s = 'arrive by'
-                slot = f'{d}-{s}'
-            if value in TIME:
-                if value == 'do not care':
-                    clean_state.append([slot, value])
-                else:
-                    h, m = value.split(':')
-                    if int(m) % 5 == 0:
-                        clean_state.append([slot, value])
-                    else:
-                        m = round(int(m) / 5) * 5
-                        h = int(h)
-                        if m == 60:
-                            m = 0
-                            h += 1
-                        if h >= 24:
-                            h -= 24
-                        value = '%02i:%02i' % (h, m)
-                        clean_state.append([slot, value])
-            elif True in [v in value for v in TIME]:
-                value = [v for v in TIME if v in value][0]
-                h, m = value.split(':')
-                if int(m) % 5 == 0:
-                    clean_state.append([slot, value])
-                else:
-                    m = round(int(m) / 5) * 5
-                    h = int(h)
-                    if m == 60:
-                        m = 0
-                        h += 1
-                    if h >= 24:
-                        h -= 24
-                    value = '%02i:%02i' % (h, m)
-                    clean_state.append([slot, value])
-            elif value == '?' and is_acts:
-                clean_state.append([slot, value])
-            else:
-                continue
-        elif 'stars' in slot:
-            if len(value) == 1 or value == 'do not care':
-                clean_state.append([slot, value])
-            elif value == '?' and is_acts:
-                clean_state.append([slot, value])
-            elif len(value) > 1:
-                try:
-                    value = int(value[0])
-                    value = str(value)
-                    clean_state.append([slot, value])
-                except:
-                    continue
-        elif 'area' in slot:
-            if '|' in value:
-                value = value.split('|', 1)[0]
-            clean_state.append([slot, value])
-        else:
-            if '|' in value:
-                value = value.split('|', 1)[0]
-                value = map_values(value)
-                # d, s = slot.split('-', 1)
-                # value = normalize_value(value_dict, d, s, value)
-            clean_state.append([slot, value])
-    
-    return clean_state
-
-
-# Module to process a dialogue and check its validity
-def process_dialogue(dialogue, max_utt_len=128):
-    if len(dialogue['log']) % 2 != 0:
-        return None
-
-    # Extract user and system utterances
-    usr_utts, sys_utts = [], []
-    avg_len = sum(len(utt['text'].split(' ')) for utt in dialogue['log'])
-    avg_len = avg_len / len(dialogue['log'])
-    if avg_len > max_utt_len:
-        return None
-
-    # If the first term is a system turn then ignore dialogue
-    if dialogue['log'][0]['metadata']:
-        return None
-
-    usr, sys = None, None
-    for turn in dialogue['log']:
-        if not is_ascii(turn['text']):
-            return None
-
-        if not usr or not sys:
-            if len(turn['metadata']) == 0:
-                usr = turn
-            else:
-                sys = turn
-        
-        if usr and sys:
-            states = build_dialoguestate(sys['metadata'], get_domains = False)
-            sys['dialogue_states'] = states
-
-            usr_utts.append(usr)
-            sys_utts.append(sys)
-            usr, sys = None, None
-
-    dial_clean = dict()
-    dial_clean['usr_log'] = usr_utts
-    dial_clean['sys_log'] = sys_utts
-    return dial_clean
-
-
-# Get new domains
-def get_act_domains(prev, crnt):
-    diff = {}
-    if not prev or not crnt:
-        return diff
-
-    for ((prev_dom, prev_val), (crnt_dom, crnt_val)) in zip(prev.items(), crnt.items()):
-        assert prev_dom == crnt_dom
-        if prev_val != crnt_val:
-            diff[crnt_dom] = crnt_val
-    return diff
-
-
-# Get current domains
-def get_domains(dial_log, turn_id, prev_domain):
-    if turn_id == 1:
-        active = build_dialoguestate(dial_log[turn_id]['metadata'], get_domains=True)
-        acts = format_acts(dial_log[turn_id].get('dialog_act', {})) if not active else []
-        acts = [domain for intent, domain, slot, value in acts if domain not in ['', 'general']]
-        active += acts
-        crnt = active[0] if active else ''
-    else:
-        active = get_act_domains(dial_log[turn_id - 2]['metadata'], dial_log[turn_id]['metadata'])
-        active = list(active.keys())
-        acts = format_acts(dial_log[turn_id].get('dialog_act', {})) if not active else []
-        acts = [domain for intent, domain, slot, value in acts if domain not in ['', 'general']]
-        active += acts
-        crnt = [prev_domain] if not active else active
-        crnt = crnt[0]
-
-    return crnt
-
-
-# Function to extract dialogue info from data
-def extract_dialogue(dialogue, max_utt_len=50):
-    dialogue = process_dialogue(dialogue, max_utt_len)
-    if not dialogue:
-        return None
-
-    usr_utts = [turn['text'] for turn in dialogue['usr_log']]
-    sys_utts = [turn['text'] for turn in dialogue['sys_log']]
-    # sys_acts = [format_acts(turn['dialog_act']) if 'dialog_act' in turn else [] for turn in dialogue['sys_log']]
-    usr_acts = [format_acts(turn['dialog_act']) if 'dialog_act' in turn else [] for turn in dialogue['usr_log']]
-    dialogue_states = [turn['dialogue_states'] for turn in dialogue['sys_log']]
-    domains = [turn['domain'] for turn in dialogue['usr_log']]
-
-    # dial = [{'usr': u,'sys': s, 'usr_a': ua, 'sys_a': a, 'domain': d, 'ds': v}
-    #         for u, s, ua, a, d, v in zip(usr_utts, sys_utts, usr_acts, sys_acts, domains, dialogue_states)]
-    dial = [{'usr': u,'sys': s, 'usr_a': ua, 'domain': d, 'ds': v}
-            for u, s, ua, d, v in zip(usr_utts, sys_utts, usr_acts, domains, dialogue_states)]    
-    return dial
-
-
-def format_acts(acts):
-    new_acts = []
-    for key, item in acts.items():
-        domain, intent = key.split('-', 1)
-        if domain.lower() in ACTIVE_DOMAINS + ['general']:
-            state = []
-            for slot, value in item:
-                slot = str(REF_SYS_DA[domain].get(slot, slot)).lower() if domain in REF_SYS_DA else slot
-                value = clean_text(value)
-                slot = slot.replace('_', ' ').replace('ref', 'reference')
-                state.append([f'{domain.lower()}-{slot}', value])
-            state = clean_dialoguestate(state, is_acts=True)
-            if domain == 'general':
-                if intent in ['thank', 'bye']:
-                    state = [['general-none', 'none']]
-                else:
-                    state = []
-            for slot, value in state:
-                if slot not in ['train-people']:
-                    slot = slot.split('-', 1)[-1]
-                    new_acts.append([intent.lower(), domain.lower(), slot, value])
-    
-    return new_acts
-                
-
-# Fix act labels
-def fix_delexicalisation(turn):
-    if 'dialog_act' in turn:
-        for dom, act in turn['dialog_act'].items():
-            if 'Attraction' in dom:
-                if 'restaurant_' in turn['text']:
-                    turn['text'] = turn['text'].replace("restaurant", "attraction")
-                if 'hotel_' in turn['text']:
-                    turn['text'] = turn['text'].replace("hotel", "attraction")
-            if 'Hotel' in dom:
-                if 'attraction_' in turn['text']:
-                    turn['text'] = turn['text'].replace("attraction", "hotel")
-                if 'restaurant_' in turn['text']:
-                    turn['text'] = turn['text'].replace("restaurant", "hotel")
-            if 'Restaurant' in dom:
-                if 'attraction_' in turn['text']:
-                    turn['text'] = turn['text'].replace("attraction", "restaurant")
-                if 'hotel_' in turn['text']:
-                    turn['text'] = turn['text'].replace("hotel", "restaurant")
-
-    return turn
-
-
-# Check if a character is an ascii character
-def is_ascii(s):
-    return all(ord(c) < 128 for c in s)
-
-
-# Insert white space
-def separate_token(token, text):
-    sidx = 0
-    while True:
-        # Find next instance of token
-        sidx = text.find(token, sidx)
-        if sidx == -1:
-            break
-        # If the token is already seperated continue to next
-        if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
-                re.match('[0-9]', text[sidx + 1]):
-            sidx += 1
-            continue
-        # Create white space separation around token
-        if text[sidx - 1] != ' ':
-            text = text[:sidx] + ' ' + text[sidx:]
-            sidx += 1
-        if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
-            text = text[:sidx + 1] + ' ' + text[sidx + 1:]
-        sidx += 1
-    return text
-
-
-def clean_text(text):
-    # Replace white spaces in front and end
-    text = re.sub(r'^\s*|\s*$', '', text.strip().lower())
-
-    # Replace b&v or 'b and b' with 'bed and breakfast'
-    text = re.sub(r"b&b", "bed and breakfast", text)
-    text = re.sub(r"b and b", "bed and breakfast", text)
-
-    # Fix apostrophies
-    text = re.sub(u"(\u2018|\u2019)", "'", text)
-
-    # Correct punctuation
-    text = text.replace(';', ',')
-    text = re.sub('$\/', '', text)
-    text = text.replace('/', ' and ')
-
-    # Replace special characters
-    text = text.replace('-', ' ')
-    text = re.sub('[\"\<>@\(\)]', '', text)
-
-    # Insert white space around special tokens:
-    for token in ['?', '.', ',', '!']:
-        text = separate_token(token, text)
-
-    # insert white space for 's
-    text = separate_token('\'s', text)
-
-    # replace it's, does't, you'd ... etc
-    text = re.sub('^\'', '', text)
-    text = re.sub('\'$', '', text)
-    text = re.sub('\'\s', ' ', text)
-    text = re.sub('\s\'', ' ', text)
-
-    # Perform pair replacements listed in the mapping.pair file
-    for fromx, tox in REPLACEMENTS:
-        text = ' ' + text + ' '
-        text = text.replace(fromx, tox)[1:-1]
-
-    # Remove multiple spaces
-    text = re.sub(' +', ' ', text)
-
-    # Concatenate numbers eg '1 3' -> '13'
-    tokens = text.split()
-    i = 1
-    while i < len(tokens):
-        if re.match(u'^\d+$', tokens[i]) and \
-                re.match(u'\d+$', tokens[i - 1]):
-            tokens[i - 1] += tokens[i]
-            del tokens[i]
-        else:
-            i += 1
-    text = ' '.join(tokens)
-
-    return text
diff --git a/convlab/dst/setsumbt/predict_user_actions.py b/convlab/dst/setsumbt/predict_user_actions.py
new file mode 100644
index 00000000..2c304a56
--- /dev/null
+++ b/convlab/dst/setsumbt/predict_user_actions.py
@@ -0,0 +1,178 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Predict dataset user action using SetSUMBT Model"""
+
+from copy import deepcopy
+import os
+import json
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+from convlab.util.custom_util import flatten_acts as flatten
+from convlab.util import load_dataset, load_policy_data
+from convlab.dst.setsumbt import SetSUMBTTracker
+
+
+def flatten_acts(acts: dict) -> list:
+    """
+    Flatten dictionary actions.
+
+    Args:
+        acts: Dictionary acts
+
+    Returns:
+        flat_acts: Flattened actions
+    """
+    acts = flatten(acts)
+    flat_acts = []
+    for intent, domain, slot, value in acts:
+        flat_acts.append([intent,
+                          domain,
+                          slot if slot != 'none' else '',
+                          value.lower() if value != 'none' else ''])
+
+    return flat_acts
+
+
+def get_user_actions(context: list, system_acts: list) -> list:
+    """
+    Extract user actions from the data.
+
+    Args:
+        context: Previous dialogue turns.
+        system_acts: List of flattened system actions.
+
+    Returns:
+        user_acts: List of flattened user actions.
+    """
+    user_acts = context[-1]['dialogue_acts']
+    user_acts = flatten_acts(user_acts)
+    if len(context) == 3:
+        prev_state = context[-3]['state']
+        cur_state = context[-1]['state']
+        for domain, substate in cur_state.items():
+            for slot, value in substate.items():
+                if prev_state[domain][slot] != value:
+                    act = ['inform', domain, slot, value]
+                    if act not in user_acts and act not in system_acts:
+                        user_acts.append(act)
+
+    return user_acts
+
+
+def extract_dataset(dataset: str = 'multiwoz21') -> list:
+    """
+    Extract acts and utterances from the dataset.
+
+    Args:
+        dataset: Dataset name
+
+    Returns:
+        data: Extracted data
+    """
+    data = load_dataset(dataset_name=dataset)
+    raw_data = load_policy_data(data, data_split='test', context_window_size=3)['test']
+
+    dialogue = list()
+    data = list()
+    for turn in raw_data:
+        state = dict()
+        state['system_utterance'] = turn['context'][-2]['utterance'] if len(turn['context']) > 1 else ''
+        state['utterance'] = turn['context'][-1]['utterance']
+        state['system_actions'] = turn['context'][-2]['dialogue_acts'] if len(turn['context']) > 1 else {}
+        state['system_actions'] = flatten_acts(state['system_actions'])
+        state['user_actions'] = get_user_actions(turn['context'], state['system_actions'])
+        dialogue.append(state)
+        if turn['terminated']:
+            data.append(dialogue)
+            dialogue = list()
+
+    return data
+
+
+def unflatten_acts(acts: list) -> dict:
+    """
+    Convert acts from flat list format to dict format.
+
+    Args:
+        acts: List of flat actions.
+
+    Returns:
+        unflat_acts: Dictionary of acts.
+    """
+    binary_acts = []
+    cat_acts = []
+    for intent, domain, slot, value in acts:
+        include = True if (domain == 'general') or (slot != 'none') else False
+        if include and (value == '' or value == 'none' or intent == 'request'):
+            binary_acts.append({'intent': intent,
+                                'domain': domain,
+                                'slot': slot if slot != 'none' else ''})
+        elif include:
+            cat_acts.append({'intent': intent,
+                             'domain': domain,
+                             'slot': slot if slot != 'none' else '',
+                             'value': value})
+
+    unflat_acts = {'categorical': cat_acts, 'binary': binary_acts, 'non-categorical': list()}
+
+    return unflat_acts
+
+
+def predict_user_acts(data: list, tracker: SetSUMBTTracker) -> list:
+    """
+    Predict the user actions using the SetSUMBT Tracker.
+
+    Args:
+        data: List of dialogues.
+        tracker: SetSUMBT Tracker
+
+    Returns:
+        predict_result: List of turns containing predictions and true user actions.
+    """
+    tracker.init_session()
+    predict_result = []
+    for dial_idx, dialogue in enumerate(data):
+        for turn_idx, state in enumerate(dialogue):
+            sample = {'dial_idx': dial_idx, 'turn_idx': turn_idx}
+
+            tracker.state['history'].append(['sys', state['system_utterance']])
+            predicted_state = deepcopy(tracker.update(state['utterance']))
+            tracker.state['history'].append(['usr', state['utterance']])
+            tracker.state['system_action'] = state['system_actions']
+
+            sample['predictions'] = {'dialogue_acts': unflatten_acts(predicted_state['user_action'])}
+            sample['dialogue_acts'] = unflatten_acts(state['user_actions'])
+
+            predict_result.append(sample)
+
+        tracker.init_session()
+
+    return predict_result
+
+
+if __name__ =="__main__":
+    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
+    parser.add_argument('--model_path', type=str, help='Path to model dir')
+    args = parser.parse_args()
+
+    dataset = extract_dataset(args.dataset_name)
+    tracker = SetSUMBTTracker(args.model_path)
+    predict_results = predict_user_acts(dataset, tracker)
+
+    with open(os.path.join(args.model_path, 'predictions', 'test_nlu.json'), 'w') as writer:
+        json.dump(predict_results, writer, indent=2)
+        writer.close()
diff --git a/convlab/dst/setsumbt/process_mwoz_data.py b/convlab/dst/setsumbt/process_mwoz_data.py
deleted file mode 100755
index 701a5236..00000000
--- a/convlab/dst/setsumbt/process_mwoz_data.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import os
-import json
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-import torch
-from tqdm import tqdm
-
-from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker
-from convlab.util.multiwoz.lexicalize import deflat_da, flat_da
-
-
-def load_data(path):
-    with open(path, 'r') as reader:
-        data = json.load(reader)
-        reader.close()
-    
-    return data
-
-
-def load_tracker(model_checkpoint):
-    model = SetSUMBTTracker(model_path=model_checkpoint)
-    model.init_session()
-
-    return model
-
-
-def process_dialogue(dial, model, get_full_belief_state):
-    model.store_full_belief_state = get_full_belief_state
-    model.init_session()
-
-    model.state['history'].append(['sys', ''])
-    processed_dial = []
-    belief_state = {}
-    for turn in dial:
-        if not turn['metadata']:
-            state = model.update(turn['text'])
-            model.state['history'].append(['usr', turn['text']])
-            
-            acts = model.state['user_action']
-            acts = [[val.replace('-', ' ') for val in act] for act in acts]
-            acts = flat_da(acts)
-            acts = deflat_da(acts)
-            turn['dialog_act'] = acts
-        else:
-            model.state['history'].append(['sys', turn['text']])
-            turn['metadata'] = model.state['belief_state']
-        
-        if get_full_belief_state:
-            for slot, probs in model.full_belief_state.items():
-                if slot not in belief_state:
-                    belief_state[slot] = [probs[0]]
-                else:
-                    belief_state[slot].append(probs[0])
-        
-        processed_dial.append(turn)
-    
-    if get_full_belief_state:
-        belief_state = {slot: torch.cat(probs, 0).cpu() for slot, probs in belief_state.items()}
-
-    return processed_dial, belief_state
-
-
-def process_dialogues(data, model, get_full_belief_state=False):
-    processed_data = {}
-    belief_states = {}
-    for dial_id, dial in tqdm(data.items()):
-        dial['log'], bs = process_dialogue(dial['log'], model, get_full_belief_state)
-        processed_data[dial_id] = dial
-        if get_full_belief_state:
-            belief_states[dial_id] = bs
-
-    return processed_data, belief_states
-
-
-def get_arguments():
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-
-    parser.add_argument('--model_path')
-    parser.add_argument('--data_path')
-    parser.add_argument('--get_full_belief_state', action='store_true')
-
-    return parser.parse_args()
-
-
-if __name__ == "__main__":
-    args = get_arguments()
-
-    print('Loading data and model...')
-    data = load_data(os.path.join(args.data_path, 'data.json'))
-    model = load_tracker(args.model_path)
-
-    print('Processing data...\n')
-    data, belief_states = process_dialogues(data, model, get_full_belief_state=args.get_full_belief_state)
-    
-    print('Saving results...\n')
-    torch.save(belief_states, os.path.join(args.data_path, 'setsumbt_belief_states.bin'))
-    with open(os.path.join(args.data_path, 'setsumbt_data.json'), 'w') as writer:
-        json.dump(data, writer, indent=2)
-        writer.close()
diff --git a/convlab/dst/setsumbt/run.py b/convlab/dst/setsumbt/run.py
index b9c9a75b..e45bf129 100644
--- a/convlab/dst/setsumbt/run.py
+++ b/convlab/dst/setsumbt/run.py
@@ -33,8 +33,8 @@ def main():
     if args.run_nbt:
         from convlab.dst.setsumbt.do.nbt import main
         main(args, config)
-    if args.run_calibration:
-        from convlab.dst.setsumbt.do.calibration import main
+    if args.run_evaluation:
+        from convlab.dst.setsumbt.do.evaluate import main
         main(args, config)
 
 
diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py
new file mode 100644
index 00000000..6b620247
--- /dev/null
+++ b/convlab/dst/setsumbt/tracker.py
@@ -0,0 +1,446 @@
+import os
+import json
+import copy
+import logging
+
+import torch
+import transformers
+from transformers import BertModel, BertConfig, BertTokenizer, RobertaModel, RobertaConfig, RobertaTokenizer
+
+from convlab.dst.setsumbt.modeling import RobertaSetSUMBT, BertSetSUMBT
+from convlab.dst.setsumbt.modeling.training import set_ontology_embeddings
+from convlab.dst.dst import DST
+from convlab.util.custom_util import model_downloader
+
+USE_CUDA = torch.cuda.is_available()
+transformers.logging.set_verbosity_error()
+
+
+class SetSUMBTTracker(DST):
+    """SetSUMBT Tracker object for Convlab dialogue system"""
+
+    def __init__(self,
+                 model_path: str = "",
+                 model_type: str = "roberta",
+                 return_turn_pooled_representation: bool = False,
+                 return_confidence_scores: bool = False,
+                 confidence_threshold='auto',
+                 return_belief_state_entropy: bool = False,
+                 return_belief_state_mutual_info: bool = False,
+                 store_full_belief_state: bool = False):
+        """
+        Args:
+            model_path: Model path or download URL
+            model_type: Transformer type (roberta/bert)
+            return_turn_pooled_representation: If true a turn level pooled representation is returned
+            return_confidence_scores: If true act confidence scores are included in the state
+            confidence_threshold: Confidence threshold value for constraints or option auto
+            return_belief_state_entropy: If true belief state distribution entropies are included in the state
+            return_belief_state_mutual_info: If true belief state distribution mutual infos are included in the state
+            store_full_belief_state: If true full belief state is stored within tracker object
+        """
+        super(SetSUMBTTracker, self).__init__()
+
+        self.model_type = model_type
+        self.model_path = model_path
+        self.return_turn_pooled_representation = return_turn_pooled_representation
+        self.return_confidence_scores = return_confidence_scores
+        self.confidence_threshold = confidence_threshold
+        self.return_belief_state_entropy = return_belief_state_entropy
+        self.return_belief_state_mutual_info = return_belief_state_mutual_info
+        self.store_full_belief_state = store_full_belief_state
+        if self.store_full_belief_state:
+            self.full_belief_state = {}
+        self.info_dict = {}
+
+        # Download model if needed
+        if not os.path.exists(self.model_path):
+            # Get path /.../convlab/dst/setsumbt/multiwoz/models
+            download_path = os.path.dirname(os.path.abspath(__file__))
+            download_path = os.path.join(download_path, 'models')
+            if not os.path.exists(download_path):
+                os.mkdir(download_path)
+            model_downloader(download_path, self.model_path)
+            # Downloadable model path format http://.../setsumbt_model_name.zip
+            self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '')
+            self.model_path = os.path.join(download_path, self.model_path)
+
+        # Select model type based on the encoder
+        if model_type == "roberta":
+            self.config = RobertaConfig.from_pretrained(self.model_path)
+            self.tokenizer = RobertaTokenizer
+            self.model = RobertaSetSUMBT
+        elif model_type == "bert":
+            self.config = BertConfig.from_pretrained(self.model_path)
+            self.tokenizer = BertTokenizer
+            self.model = BertSetSUMBT
+        else:
+            logging.debug("Name Error: Not Implemented")
+
+        self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu')
+
+        self.load_weights()
+
+    def load_weights(self):
+        """Load model weights and model ontology"""
+        logging.info('Loading SetSUMBT pretrained model.')
+        self.tokenizer = self.tokenizer.from_pretrained(self.config.tokenizer_name)
+        logging.info(f'Model tokenizer loaded from {self.config.tokenizer_name}.')
+        self.model = self.model.from_pretrained(self.model_path, config=self.config)
+        logging.info(f'Model loaded from {self.model_path}.')
+
+        # Transfer model to compute device and setup eval environment
+        self.model = self.model.to(self.device)
+        self.model.eval()
+        logging.info(f'Model transferred to device: {self.device}')
+
+        logging.info('Loading model ontology')
+        f = open(os.path.join(self.model_path, 'database', 'test.json'), 'r')
+        self.ontology = json.load(f)
+        f.close()
+
+        db = torch.load(os.path.join(self.model_path, 'database', 'test.db'))
+        set_ontology_embeddings(self.model, db)
+
+        if self.return_confidence_scores:
+            logging.info('Model returns user action and belief state confidence scores.')
+            self.get_thresholds(self.confidence_threshold)
+            logging.info('Uncertain Querying set up and thresholds set up at:')
+            logging.info(self.confidence_thresholds)
+        if self.return_belief_state_entropy:
+            logging.info('Model returns belief state distribution entropy scores (Total uncertainty).')
+        if self.return_belief_state_mutual_info:
+            logging.info('Model returns belief state distribution mutual information scores (Knowledge uncertainty).')
+        logging.info('Ontology loaded successfully.')
+
+    def get_thresholds(self, threshold='auto') -> dict:
+        """
+        Setup dictionary of domain specific confidence thresholds
+
+        Args:
+            threshold: Threshold value or option auto
+
+        Returns:
+            confidence_thresholds: Domain specific confidence thresholds
+        """
+        self.confidence_thresholds = dict()
+        for domain, substate in self.ontology.items():
+            for slot, slot_info in substate.items():
+                # Auto thresholds are set based on the number of value candidates per slot
+                if domain not in self.confidence_thresholds:
+                    self.confidence_thresholds[domain] = dict()
+                if threshold == 'auto':
+                    thres = 1.0 / (float(len(slot_info['possible_values'])) - 2.1)
+                    self.confidence_thresholds[domain][slot] = max(0.05, thres)
+                else:
+                    self.confidence_thresholds[domain][slot] = max(0.05, threshold)
+
+        return self.confidence_thresholds
+
+    def init_session(self):
+        self.state = dict()
+        self.state['belief_state'] = dict()
+        self.state['booked'] = dict()
+        for domain, substate in self.ontology.items():
+            self.state['belief_state'][domain] = dict()
+            for slot, slot_info in substate.items():
+                if slot_info['possible_values'] and slot_info['possible_values'] != ['?']:
+                    self.state['belief_state'][domain][slot] = ''
+            self.state['booked'][domain] = list()
+        self.state['history'] = []
+        self.state['system_action'] = []
+        self.state['user_action'] = []
+        self.state['terminated'] = False
+        self.active_domains = {}
+        self.hidden_states = None
+        self.info_dict = {}
+
+    def update(self, user_act: str = '') -> dict:
+        """
+        Update user actions and dialogue and belief states.
+
+        Args:
+            user_act:
+
+        Returns:
+
+        """
+        prev_state = self.state
+        _output = self.predict(self.get_features(user_act))
+
+        # Format state entropy
+        if _output[5] is not None:
+            state_entropy = dict()
+            for slot, e in _output[5].items():
+                domain, slot = slot.split('-', 1)
+                if domain not in state_entropy:
+                    state_entropy[domain] = dict()
+                state_entropy[domain][slot] = e
+        else:
+            state_entropy = None
+
+        # Format state mutual information
+        if _output[6] is not None:
+            state_mutual_info = dict()
+            for slot, mi in _output[6].items():
+                domain, slot = slot.split('-', 1)
+                if domain not in state_mutual_info:
+                    state_mutual_info[domain] = dict()
+                state_mutual_info[domain][slot] = mi[0, 0]
+        else:
+            state_mutual_info = None
+
+        # Format all confidence scores
+        belief_state_confidence = None
+        if _output[4] is not None:
+            belief_state_confidence = dict()
+            belief_state_conf, request_probs, active_domain_probs, general_act_probs = _output[4]
+            for slot, p in belief_state_conf.items():
+                domain, slot = slot.split('-', 1)
+                if domain not in belief_state_confidence:
+                    belief_state_confidence[domain] = dict()
+                if slot not in belief_state_confidence[domain]:
+                    belief_state_confidence[domain][slot] = dict()
+                belief_state_confidence[domain][slot]['inform'] = p
+
+            for slot, p in request_probs.items():
+                domain, slot = slot.split('-', 1)
+                if domain not in belief_state_confidence:
+                    belief_state_confidence[domain] = dict()
+                if slot not in belief_state_confidence[domain]:
+                    belief_state_confidence[domain][slot] = dict()
+                belief_state_confidence[domain][slot]['request'] = p
+
+            for domain, p in active_domain_probs.items():
+                if domain not in belief_state_confidence:
+                    belief_state_confidence[domain] = dict()
+                belief_state_confidence[domain]['none'] = {'inform': p}
+
+            if 'general' not in belief_state_confidence:
+                belief_state_confidence['general'] = dict()
+            belief_state_confidence['general']['none'] = general_act_probs
+
+        # Get new domain activation actions
+        new_domains = [d for d, active in _output[1].items() if active]
+        new_domains = [d for d in new_domains if not self.active_domains.get(d, False)]
+        self.active_domains = _output[1]
+
+        user_acts = _output[2]
+        for domain in new_domains:
+            user_acts.append(['inform', domain, 'none', 'none'])
+
+        new_belief_state = copy.deepcopy(prev_state['belief_state'])
+        for domain, substate in _output[0].items():
+            for slot, value in substate.items():
+                value = '' if value == 'none' else value
+                value = 'dontcare' if value == 'do not care' else value
+                value = 'guesthouse' if value == 'guest house' else value
+
+                if domain not in new_belief_state:
+                    if domain == 'bus':
+                        continue
+                    else:
+                        logging.debug('Error: domain <{}> not in belief state'.format(domain))
+
+                # Uncertainty clipping of state
+                if belief_state_confidence is not None:
+                    threshold = self.confidence_thresholds[domain][slot]
+                    if belief_state_confidence[domain][slot].get('inform', 1.0) < threshold:
+                        value = ''
+
+                new_belief_state[domain][slot] = value
+                if prev_state['belief_state'][domain][slot] != value:
+                    user_acts.append(['inform', domain, slot, value])
+                else:
+                    bug = f'Unknown slot name <{slot}> with value <{value}> of domain <{domain}>'
+                    logging.debug(bug)
+
+        new_state = copy.deepcopy(dict(prev_state))
+        new_state['belief_state'] = new_belief_state
+        new_state['active_domains'] = self.active_domains
+        if belief_state_confidence is not None:
+            new_state['belief_state_probs'] = belief_state_confidence
+        if state_entropy is not None:
+            new_state['entropy'] = state_entropy
+        if state_mutual_info is not None:
+            new_state['mutual_information'] = state_mutual_info
+
+        user_acts = [act for act in user_acts if act not in new_state['system_action']]
+        new_state['user_action'] = user_acts
+
+        if _output[3] is not None:
+            new_state['turn_pooled_representation'] = _output[3]
+
+        self.state = new_state
+        self.info_dict = copy.deepcopy(dict(new_state))
+
+        return self.state
+
+    def predict(self, features: dict) -> tuple:
+        """
+        Model forward pass and prediction post processing.
+
+        Args:
+            features: Dictionary of model input features
+
+        Returns:
+            out: Model predictions and uncertainty features
+        """
+        state_mutual_info = None
+        with torch.no_grad():
+            turn_pooled_representation = None
+            if self.return_turn_pooled_representation:
+                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
+                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
+                                      get_turn_pooled_representation=True)
+                belief_state = _outputs[0]
+                request_probs = _outputs[1]
+                active_domain_probs = _outputs[2]
+                general_act_probs = _outputs[3]
+                self.hidden_states = _outputs[4]
+                turn_pooled_representation = _outputs[5]
+            elif self.return_belief_state_mutual_info:
+                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
+                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
+                                      get_turn_pooled_representation=True, calculate_state_mutual_info=True)
+                belief_state = _outputs[0]
+                request_probs = _outputs[1]
+                active_domain_probs = _outputs[2]
+                general_act_probs = _outputs[3]
+                self.hidden_states = _outputs[4]
+                state_mutual_info = _outputs[5]
+            else:
+                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
+                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
+                                      get_turn_pooled_representation=False)
+                belief_state, request_probs, active_domain_probs, general_act_probs, self.hidden_states = _outputs
+
+        # Convert belief state into dialog state
+        dialogue_state = dict()
+        for slot, probs in belief_state.items():
+            dom, slot = slot.split('-', 1)
+            if dom not in dialogue_state:
+                dialogue_state[dom] = dict()
+            val = self.ontology[dom][slot]['possible_values'][probs[0, 0, :].argmax().item()]
+            if val != 'none':
+                dialogue_state[dom][slot] = val
+
+        if self.store_full_belief_state:
+            self.full_belief_state = belief_state
+
+        # Obtain model output probabilities
+        if self.return_confidence_scores:
+            state_entropy = None
+            if self.return_belief_state_entropy:
+                state_entropy = {slot: probs[0, 0, :] for slot, probs in belief_state.items()}
+                state_entropy = {slot: self.relative_entropy(p).item() for slot, p in state_entropy.items()}
+
+            # Confidence score is the max probability across all not "none" values candidates.
+            belief_state_conf = {slot: probs[0, 0, 1:].max().item() for slot, probs in belief_state.items()}
+            _request_probs = {slot: p[0, 0].item() for slot, p in request_probs.items()}
+            _active_domain_probs = {domain: p[0, 0].item() for domain, p in active_domain_probs.items()}
+            _general_act_probs = {'bye': general_act_probs[0, 0, 1].item(), 'thank': general_act_probs[0, 0, 2].item()}
+            confidence_scores = (belief_state_conf, _request_probs, _active_domain_probs, _general_act_probs)
+        else:
+            confidence_scores = None
+            state_entropy = None
+
+        # Construct request action prediction
+        request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5]
+        request_acts = [slot.split('-', 1) for slot in request_acts]
+        request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts]
+
+        # Construct active domain set
+        active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()}
+
+        # Construct general domain action
+        general_acts = general_act_probs[0, 0, :].argmax(-1).item()
+        general_acts = [[], ['bye'], ['thank']][general_acts]
+        general_acts = [[act, 'general', 'none', 'none'] for act in general_acts]
+
+        user_acts = request_acts + general_acts
+
+        out = (dialogue_state, active_domains, user_acts, turn_pooled_representation, confidence_scores)
+        out += (state_entropy, state_mutual_info)
+        return out
+
+    def relative_entropy(self, probs: torch.Tensor) -> torch.Tensor:
+        """
+        Compute relative entrop for a probability distribution
+
+        Args:
+            probs: Probability distributions
+
+        Returns:
+            entropy: Relative entropy
+        """
+        entropy = probs * torch.log(probs + 1e-8)
+        entropy = -entropy.sum()
+        # Maximum entropy of a K dimentional distribution is ln(K)
+        entropy /= torch.log(torch.tensor(probs.size(-1)).float())
+
+        return entropy
+
+    def get_features(self, user_act: str) -> dict:
+        """
+        Tokenize utterances and construct model input features
+
+        Args:
+            user_act: User action string
+
+        Returns:
+            features: Model input features
+        """
+        # Extract system utterance from dialog history
+        context = self.state['history']
+        if context:
+            if context[-1][0] != 'sys':
+                system_act = ''
+            else:
+                system_act = context[-1][-1]
+        else:
+            system_act = ''
+
+        # Tokenize dialog
+        features = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True,
+                                              max_length=self.config.max_turn_len, padding='max_length',
+                                              truncation='longest_first')
+
+        input_ids = torch.tensor(features['input_ids']).reshape(
+            1, 1, -1).to(self.device) if 'input_ids' in features else None
+        token_type_ids = torch.tensor(features['token_type_ids']).reshape(
+            1, 1, -1).to(self.device) if 'token_type_ids' in features else None
+        attention_mask = torch.tensor(features['attention_mask']).reshape(
+            1, 1, -1).to(self.device) if 'attention_mask' in features else None
+        features = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask}
+
+        return features
+
+
+# if __name__ == "__main__":
+#     from convlab.policy.vector.vector_uncertainty import VectorUncertainty
+#     # from convlab.policy.vector.vector_binary import VectorBinary
+#     tracker = SetSUMBTTracker(model_path='/gpfs/project/niekerk/src/SetSUMBT/models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-labelsmoothing-Seed0-10-08-22-12-42',
+#                               return_confidence_scores=True, confidence_threshold='auto',
+#                               return_belief_state_entropy=True)
+#     vector = VectorUncertainty(use_state_total_uncertainty=True, confidence_thresholds=tracker.confidence_thresholds,
+#                                use_masking=True)
+#     # vector = VectorBinary()
+#     tracker.init_session()
+#
+#     state = tracker.update('hey. I need a cheap restaurant.')
+#     tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
+#     tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
+#     state = tracker.update('If you have something Asian that would be great.')
+#     tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
+#     tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
+#     tracker.state['system_action'] = [['inform', 'restaurant', 'food', 'chinese'],
+#                                       ['inform', 'restaurant', 'name', 'the golden wok']]
+#     state = tracker.update('Great. Where are they located?')
+#     tracker.state['history'].append(['usr', 'Great. Where are they located?'])
+#     state = tracker.state
+#     state['terminated'] = False
+#     state['booked'] = {}
+#
+#     print(state)
+#     print(vector.state_vectorize(state))
diff --git a/convlab/dst/setsumbt/utils.py b/convlab/dst/setsumbt/utils.py
index 75a6a1fe..51839552 100644
--- a/convlab/dst/setsumbt/utils.py
+++ b/convlab/dst/setsumbt/utils.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,57 +15,43 @@
 # limitations under the License.
 """SetSUMBT utils"""
 
-import re
 import os
 import shutil
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-from glob import glob
 from datetime import datetime
 
-from google.cloud import storage
+from git import Repo
 
 
-def get_args(MODELS):
+def get_args(base_models: dict):
     # Get arguments
     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
 
     # Optional
-    parser.add_argument('--tensorboard_path',
-                        help='Path to tensorboard', default='')
+    parser.add_argument('--tensorboard_path', help='Path to tensorboard', default='')
     parser.add_argument('--logging_path', help='Path for log file', default='')
-    parser.add_argument(
-        '--seed', help='Seed value for reproducability', default=0, type=int)
+    parser.add_argument('--seed', help='Seed value for reproducibility', default=0, type=int)
 
     # DATASET (Optional)
-    parser.add_argument(
-        '--dataset', help='Dataset Name: multiwoz21/simr', default='multiwoz21')
-    parser.add_argument('--shrink_active_domains', help='Shrink active domains to only well represented test set domains',
-                        action='store_true')
-    parser.add_argument(
-        '--data_dir', help='Data storage directory', default=None)
-    parser.add_argument(
-        '--max_dialogue_len', help='Maximum number of turns per dialogue', default=12, type=int)
-    parser.add_argument(
-        '--max_turn_len', help='Maximum number of tokens per turn', default=64, type=int)
-    parser.add_argument(
-        '--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int)
-    parser.add_argument('--max_candidate_len',
-                        help='Maximum number of tokens per value candidate', default=12, type=int)
-    parser.add_argument('--force_processing', action='store_true',
-                        help='Force preprocessing of data.')
-    parser.add_argument('--data_sampling_size',
-                        help='Resampled dataset size', default=-1, type=int)
-    parser.add_argument('--use_descriptions', help='Use slot descriptions rather than slot names for embeddings',
+    parser.add_argument('--dataset', help='Dataset Name (See Convlab 3 unified format for possible datasets',
+                        default='multiwoz21')
+    parser.add_argument('--dataset_train_ratio', help='Fraction of training set to use in training', default=1.0,
+                        type=float)
+    parser.add_argument('--max_dialogue_len', help='Maximum number of turns per dialogue', default=12, type=int)
+    parser.add_argument('--max_turn_len', help='Maximum number of tokens per turn', default=64, type=int)
+    parser.add_argument('--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int)
+    parser.add_argument('--max_candidate_len', help='Maximum number of tokens per value candidate', default=12,
+                        type=int)
+    parser.add_argument('--force_processing', action='store_true', help='Force preprocessing of data.')
+    parser.add_argument('--data_sampling_size', help='Resampled dataset size', default=-1, type=int)
+    parser.add_argument('--no_descriptions', help='Do not use slot descriptions rather than slot names for embeddings',
                         action='store_true')
 
     # MODEL
     # Environment
-    parser.add_argument(
-        '--output_dir', help='Output storage directory', default=None)
-    parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta',
-                        default='roberta')
-    parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.',
-                        default=None)
+    parser.add_argument('--output_dir', help='Output storage directory', default=None)
+    parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta', default='roberta')
+    parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.', default=None)
     parser.add_argument('--candidate_embedding_model_name', default=None,
                         help='Name of the pretrained candidate embedding model.')
 
@@ -74,92 +60,73 @@ def get_args(MODELS):
                         action='store_true')
     parser.add_argument('--slot_attention_heads', help='Number of attention heads for slot conditioning',
                         default=12, type=int)
-    parser.add_argument('--dropout_rate', help='Dropout Rate',
-                        default=0.3, type=float)
-    parser.add_argument(
-        '--nbt_type', help='Belief Tracker type: gru/lstm', default='gru')
+    parser.add_argument('--dropout_rate', help='Dropout Rate', default=0.3, type=float)
+    parser.add_argument('--nbt_type', help='Belief Tracker type: gru/lstm', default='gru')
     parser.add_argument('--nbt_hidden_size', help='Hidden embedding size for the Neural Belief Tracker',
                         default=300, type=int)
-    parser.add_argument(
-        '--nbt_layers', help='Number of RNN layers in the NBT', default=1, type=int)
-    parser.add_argument(
-        '--rnn_zero_init', help='Zero Initialise RNN hidden states', action='store_true')
+    parser.add_argument('--nbt_layers', help='Number of RNN layers in the NBT', default=1, type=int)
+    parser.add_argument('--rnn_zero_init', help='Zero Initialise RNN hidden states', action='store_true')
     parser.add_argument('--distance_measure', default='cosine',
                         help='Similarity measure for candidate scoring: cosine/euclidean')
-    parser.add_argument(
-        '--ensemble_size', help='Number of models in ensemble', default=-1, type=int)
-    parser.add_argument('--set_similarity', action='store_true',
-                        help='Set True to not use set similarity (Model tracks latent belief state as sequence and performs semantic similarity of sets)')
-    parser.add_argument('--set_pooling', help='Set pooling method for set similarity model using single embedding distances',
+    parser.add_argument('--ensemble_size', help='Number of models in ensemble', default=-1, type=int)
+    parser.add_argument('--no_set_similarity', action='store_true', help='Set True to not use set similarity')
+    parser.add_argument('--set_pooling',
+                        help='Set pooling method for set similarity model using single embedding distances',
                         default='cnn')
-    parser.add_argument('--candidate_pooling', help='Pooling approach for non set based candidate representations: cls/mean',
+    parser.add_argument('--candidate_pooling',
+                        help='Pooling approach for non set based candidate representations: cls/mean',
                         default='mean')
-    parser.add_argument('--predict_actions', help='Model predicts user actions and active domain',
+    parser.add_argument('--no_action_prediction', help='Model does not predicts user actions and active domain',
                         action='store_true')
 
     # Loss
-    parser.add_argument('--loss_function', help='Loss Function for training: crossentropy/bayesianmatching/labelsmoothing/distillation/distribution_distillation',
+    parser.add_argument('--loss_function',
+                        help='Loss Function for training: crossentropy/bayesianmatching/labelsmoothing/...',
                         default='labelsmoothing')
     parser.add_argument('--kl_scaling_factor', help='Scaling factor for KL divergence in bayesian matching loss',
                         type=float)
     parser.add_argument('--prior_constant', help='Constant parameter for prior in bayesian matching loss',
                         type=float)
-    parser.add_argument('--ensemble_smoothing',
-                        help='Ensemble distribution smoothing constant', type=float)
-    parser.add_argument('--annealing_base_temp', help='Ensemble Distribution destillation temp annealing base temp',
+    parser.add_argument('--ensemble_smoothing', help='Ensemble distribution smoothing constant', type=float)
+    parser.add_argument('--annealing_base_temp', help='Ensemble Distribution distillation temp annealing base temp',
                         type=float)
-    parser.add_argument('--annealing_cycle_len', help='Ensemble Distribution destillation temp annealing cycle length',
+    parser.add_argument('--annealing_cycle_len', help='Ensemble Distribution distillation temp annealing cycle length',
                         type=float)
-    parser.add_argument('--inhibiting_factor',
-                        help='Inhibiting factor for Inhibited Softmax CE', type=float)
-    parser.add_argument('--label_smoothing',
-                        help='Label smoothing coefficient.', type=float)
-    parser.add_argument(
-        '--user_goal_loss_weight', help='Weight of the user goal prediction loss. 0.0<weight<=1.0', type=float)
-    parser.add_argument(
-        '--user_request_loss_weight', help='Weight of the user request prediction loss. 0.0<weight<=1.0', type=float)
-    parser.add_argument(
-        '--user_general_act_loss_weight', help='Weight of the user general act prediction loss. 0.0<weight<=1.0', type=float)
-    parser.add_argument(
-        '--active_domain_loss_weight', help='Weight of the active domain prediction loss. 0.0<weight<=1.0', type=float)
+    parser.add_argument('--label_smoothing', help='Label smoothing coefficient.', type=float)
+    parser.add_argument('--user_goal_loss_weight', help='Weight of the user goal prediction loss. 0.0<weight<=1.0',
+                        type=float)
+    parser.add_argument('--user_request_loss_weight',
+                        help='Weight of the user request prediction loss. 0.0<weight<=1.0', type=float)
+    parser.add_argument('--user_general_act_loss_weight',
+                        help='Weight of the user general act prediction loss. 0.0<weight<=1.0', type=float)
+    parser.add_argument('--active_domain_loss_weight',
+                        help='Weight of the active domain prediction loss. 0.0<weight<=1.0', type=float)
 
     # TRAINING
-    parser.add_argument('--train_batch_size',
-                        help='Training Set Batch Size', default=4, type=int)
-    parser.add_argument('--max_training_steps', help='Maximum number of training update steps',
-                        default=-1, type=int)
+    parser.add_argument('--train_batch_size', help='Training Set Batch Size', default=8, type=int)
+    parser.add_argument('--max_training_steps', help='Maximum number of training update steps', default=-1, type=int)
     parser.add_argument('--gradient_accumulation_steps', default=1, type=int,
                         help='Number of batches accumulated for one update step')
-    parser.add_argument('--num_train_epochs',
-                        help='Number of training epochs', default=50, type=int)
+    parser.add_argument('--num_train_epochs', help='Number of training epochs', default=50, type=int)
     parser.add_argument('--patience', help='Number of training steps without improving model before stopping.',
-                        default=25, type=int)
-    parser.add_argument(
-        '--weight_decay', help='Weight decay rate', default=0.01, type=float)
-    parser.add_argument('--learning_rate',
-                        help='Initial Learning Rate', default=5e-5, type=float)
-    parser.add_argument('--warmup_proportion', help='Warmup proportion for linear scheduler',
-                        default=0.2, type=float)
-    parser.add_argument(
-        '--max_grad_norm', help='Maximum norm of the loss gradients', default=1.0, type=float)
-    parser.add_argument(
-        '--save_steps', help='Number of update steps between saving model', default=-1, type=int)
-    parser.add_argument(
-        '--keep_models', help='How many model checkpoints should be kept during training', default=1, type=int)
+                        default=20, type=int)
+    parser.add_argument('--weight_decay', help='Weight decay rate', default=0.01, type=float)
+    parser.add_argument('--learning_rate', help='Initial Learning Rate', default=5e-5, type=float)
+    parser.add_argument('--warmup_proportion', help='Warmup proportion for linear scheduler', default=0.2, type=float)
+    parser.add_argument('--max_grad_norm', help='Maximum norm of the loss gradients', default=1.0, type=float)
+    parser.add_argument('--save_steps', help='Number of update steps between saving model', default=-1, type=int)
+    parser.add_argument('--keep_models', help='How many model checkpoints should be kept during training',
+                        default=1, type=int)
 
     # CALIBRATION
-    parser.add_argument(
-        '--temp_scaling', help='Temperature scaling coefficient', default=1.0, type=float)
+    parser.add_argument('--temp_scaling', help='Temperature scaling coefficient', default=1.0, type=float)
 
     # EVALUATION
-    parser.add_argument('--dev_batch_size',
-                        help='Dev Set Batch Size', default=16, type=int)
-    parser.add_argument('--test_batch_size',
-                        help='Test Set Batch Size', default=16, type=int)
+    parser.add_argument('--dev_batch_size', help='Dev Set Batch Size', default=16, type=int)
+    parser.add_argument('--test_batch_size', help='Test Set Batch Size', default=16, type=int)
 
     # COMPUTING
-    parser.add_argument(
-        '--n_gpu', help='Number of GPUs to use', default=1, type=int)
+    parser.add_argument('--n_gpu', help='Number of GPUs to use', default=1, type=int)
     parser.add_argument('--fp16', action='store_true',
                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
     parser.add_argument('--fp16_opt_level', type=str, default='O1',
@@ -167,32 +134,29 @@ def get_args(MODELS):
                              "See details at https://nvidia.github.io/apex/amp.html")
 
     # ACTIONS
-    parser.add_argument('--run_nbt', help='Run NBT script',
-                        action='store_true')
-    parser.add_argument('--run_calibration',
-                        help='Run calibration', action='store_true')
+    parser.add_argument('--run_nbt', help='Run NBT script', action='store_true')
+    parser.add_argument('--run_evaluation', help='Run evaluation script', action='store_true')
 
     # RUN_NBT ACTIONS
-    parser.add_argument(
-        '--do_train', help='Perform training', action='store_true')
-    parser.add_argument(
-        '--do_eval', help='Perform model evaluation during training', action='store_true')
-    parser.add_argument(
-        '--do_test', help='Evaulate model on test data', action='store_true')
+    parser.add_argument('--do_train', help='Perform training', action='store_true')
+    parser.add_argument('--do_eval', help='Perform model evaluation during training', action='store_true')
+    parser.add_argument('--do_test', help='Evaluate model on test data', action='store_true')
     args = parser.parse_args()
 
-    # Setup default directories
-    if not args.data_dir:
-        args.data_dir = os.path.dirname(os.path.abspath(__file__))
-        args.data_dir = os.path.join(args.data_dir, 'data')
-        os.makedirs(args.data_dir, exist_ok=True)
+    # Simplify args
+    args.set_similarity = not args.no_set_similarity
+    args.use_descriptions = not args.no_descriptions
+    args.predict_actions = not args.no_action_prediction
 
+    # Setup default directories
     if not args.output_dir:
         args.output_dir = os.path.dirname(os.path.abspath(__file__))
         args.output_dir = os.path.join(args.output_dir, 'models')
 
-        name = 'SetSUMBT'
-        name += '-Acts' if args.predict_actions else ''
+        name = 'SetSUMBT' if args.set_similarity else 'SUMBT'
+        name += '+ActPrediction' if args.predict_actions else ''
+        name += '-' + args.dataset
+        name += '-' + str(round(args.dataset_train_ratio*100)) + '%' if args.dataset_train_ratio != 1.0 else ''
         name += '-' + args.model_type
         name += '-' + args.nbt_type
         name += '-' + args.distance_measure
@@ -208,9 +172,6 @@ def get_args(MODELS):
             args.kl_scaling_factor = 0.001
         if not args.prior_constant:
             args.prior_constant = 1.0
-    if args.loss_function == 'inhibitedce':
-        if not args.inhibiting_factor:
-            args.inhibiting_factor = 1.0
     if args.loss_function == 'labelsmoothing':
         if not args.label_smoothing:
             args.label_smoothing = 0.05
@@ -233,10 +194,8 @@ def get_args(MODELS):
         if not args.active_domain_loss_weight:
             args.active_domain_loss_weight = 0.2
 
-    args.tensorboard_path = args.tensorboard_path if args.tensorboard_path else os.path.join(
-        args.output_dir, 'tb_logs')
-    args.logging_path = args.logging_path if args.logging_path else os.path.join(
-        args.output_dir, 'run.log')
+    args.tensorboard_path = args.tensorboard_path if args.tensorboard_path else os.path.join(args.output_dir, 'tb_logs')
+    args.logging_path = args.logging_path if args.logging_path else os.path.join(args.output_dir, 'run.log')
 
     # Default model_name's
     if not args.model_name_or_path:
@@ -250,30 +209,37 @@ def get_args(MODELS):
     if not args.candidate_embedding_model_name:
         args.candidate_embedding_model_name = args.model_name_or_path
 
-    if args.model_type in MODELS:
-        configClass = MODELS[args.model_type][-2]
+    if args.model_type in base_models:
+        config_class = base_models[args.model_type][-2]
     else:
         raise NameError('NotImplemented')
-    config = build_config(configClass, args)
+    config = build_config(config_class, args)
     return args, config
 
 
-def build_config(configClass, args):
-    if args.model_type == 'fasttext':
-        config = configClass.from_pretrained('bert-base-uncased')
-        config.model_type == 'fasttext'
-        config.fasttext_path = args.model_name_or_path
-        config.vocab_size = None
-    elif not os.path.exists(args.model_name_or_path):
-        config = configClass.from_pretrained(args.model_name_or_path)
+def get_git_info():
+    repo = Repo(os.path.dirname(os.path.realpath(__file__)), search_parent_directories=True)
+    branch_name = repo.active_branch.name
+    commit_hex = repo.head.object.hexsha
+
+    info = f"{branch_name}/{commit_hex}"
+    return info
+
+
+def build_config(config_class, args):
+    config = config_class.from_pretrained(args.model_name_or_path)
+    config.code_version = get_git_info()
+    if not os.path.exists(args.model_name_or_path):
         config.tokenizer_name = args.model_name_or_path
-    elif 'tod-bert' in args.model_name_or_path.lower():
-        config = configClass.from_pretrained(args.model_name_or_path)
+    try:
+        config.tokenizer_name = config.tokenizer_name
+    except AttributeError:
         config.tokenizer_name = args.model_name_or_path
-    else:
-        config = configClass.from_pretrained(args.model_name_or_path)
-    if args.candidate_embedding_model_name:
-        config.candidate_embedding_model_name = args.candidate_embedding_model_name
+    try:
+        config.candidate_embedding_model_name = config.candidate_embedding_model_name
+    except:
+        if args.candidate_embedding_model_name:
+            config.candidate_embedding_model_name = args.candidate_embedding_model_name
     config.max_dialogue_len = args.max_dialogue_len
     config.max_turn_len = args.max_turn_len
     config.max_slot_len = args.max_slot_len
diff --git a/convlab/policy/mle/loader.py b/convlab/policy/mle/loader.py
index ebc01a01..bb898ab4 100755
--- a/convlab/policy/mle/loader.py
+++ b/convlab/policy/mle/loader.py
@@ -2,6 +2,9 @@ import os
 import pickle
 import torch
 import torch.utils.data as data
+from copy import deepcopy
+
+from tqdm import tqdm
 
 from convlab.policy.vector.vector_binary import VectorBinary
 from convlab.util import load_policy_data, load_dataset
@@ -12,18 +15,20 @@ from convlab.policy.vector.dataset import ActDataset
 
 class PolicyDataVectorizer:
     
-    def __init__(self, dataset_name='multiwoz21', vector=None):
+    def __init__(self, dataset_name='multiwoz21', vector=None, dst=None):
         self.dataset_name = dataset_name
         if vector is None:
             self.vector = VectorBinary(dataset_name)
         else:
             self.vector = vector
+        self.dst = dst
         self.process_data()
 
     def process_data(self):
-
-        processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
-                                     f'processed_data/{self.dataset_name}_{type(self.vector).__name__}')
+        name = f"{self.dataset_name}_"
+        name += f"{type(self.dst).__name__}_" if self.dst is not None else ""
+        name += f"{type(self.vector).__name__}"
+        processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), name)
         if os.path.exists(processed_dir):
             print('Load processed data file')
             self._load_data(processed_dir)
@@ -42,15 +47,27 @@ class PolicyDataVectorizer:
             self.data[split] = []
             raw_data = data_split[split]
 
-            for data_point in raw_data:
-                state = default_state()
+            if self.dst is not None:
+                self.dst.init_session()
+
+            for data_point in tqdm(raw_data):
+                if self.dst is None:
+                    state = default_state()
+
+                    state['belief_state'] = data_point['context'][-1]['state']
+                    state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts'])
+                else:
+                    last_system_utt = data_point['context'][-2]['utterance'] if len(data_point['context']) > 1 else ''
+                    self.dst.state['history'].append(['sys', last_system_utt])
 
-                state['belief_state'] = data_point['context'][-1]['state']
-                state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts'])
-                last_system_act = data_point['context'][-2]['dialogue_acts'] \
-                    if len(data_point['context']) > 1 else {}
+                    usr_utt = data_point['context'][-1]['utterance']
+                    state = deepcopy(self.dst.update(usr_utt))
+                    self.dst.state['history'].append(['usr', usr_utt])
+                last_system_act = data_point['context'][-2]['dialogue_acts'] if len(data_point['context']) > 1 else {}
                 state['system_action'] = flatten_acts(last_system_act)
                 state['terminated'] = data_point['terminated']
+                if self.dst is not None and state['terminated']:
+                    self.dst.init_session()
                 state['booked'] = data_point['booked']
                 dialogue_act = flatten_acts(data_point['dialogue_acts'])
 
diff --git a/convlab/policy/mle/train.py b/convlab/policy/mle/train.py
index 2b82a476..c2477760 100755
--- a/convlab/policy/mle/train.py
+++ b/convlab/policy/mle/train.py
@@ -137,15 +137,6 @@ class MLE_Trainer(MLE_Trainer_Abstract):
     def __init__(self, manager, vector, cfg):
         self._init_data(manager, cfg)
 
-        try:
-            self.use_entropy = manager.use_entropy
-            self.use_mutual_info = manager.use_mutual_info
-            self.use_confidence_scores = manager.use_confidence_scores
-        except:
-            self.use_entropy = False
-            self.use_mutual_info = False
-            self.use_confidence_scores = False
-
         # override the loss defined in the MLE_Trainer_Abstract to support pos_weight
         pos_weight = cfg['pos_weight'] * torch.ones(vector.da_dim).to(device=DEVICE)
         self.multi_entropy_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
@@ -161,6 +152,10 @@ def arg_parser():
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--eval_freq", type=int, default=1)
     parser.add_argument("--dataset_name", type=str, default="multiwoz21")
+    parser.add_argument("--use_masking", action='store_true')
+
+    parser.add_argument("--dst", type=str, default=None)
+    parser.add_argument("--dst_args", type=str, default=None)
 
     args = parser.parse_args()
     return args
@@ -181,8 +176,28 @@ if __name__ == '__main__':
     set_seed(args.seed)
     logging.info(f"Seed used: {args.seed}")
 
-    vector = VectorBinary(dataset_name=args.dataset_name, use_masking=False)
-    manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector)
+    if args.dst is None:
+        vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking)
+        dst = None
+    elif args.dst == "setsumbt":
+        dst_args = [arg.split('=', 1) for arg in args.dst_args.split(', ')
+                    if '=' in arg] if args.dst_args is not None else []
+        dst_args = {key: eval(value) for key, value in dst_args}
+        from convlab.dst.setsumbt import SetSUMBTTracker
+        dst = SetSUMBTTracker(**dst_args)
+        if dst.return_confidence_scores:
+            from convlab.policy.vector.vector_uncertainty import VectorUncertainty
+            vector = VectorUncertainty(dataset_name=args.dataset_name, use_masking=args.use_masking,
+                                       manually_add_entity_names=False,
+                                       use_confidence_scores=dst.return_confidence_scores,
+                                       confidence_thresholds=dst.confidence_thresholds,
+                                       use_state_total_uncertainty=dst.return_belief_state_entropy,
+                                       use_state_knowledge_uncertainty=dst.return_belief_state_mutual_info)
+        else:
+            vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking)
+    else:
+        raise NameError(f"Tracker: {args.tracker} not implemented.")
+    manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector, dst=dst)
     agent = MLE_Trainer(manager, vector, cfg)
 
     logging.info('Start training')
diff --git a/convlab/policy/ppo/setsumbt_end_baseline_config.json b/convlab/policy/ppo/setsumbt_config.json
similarity index 53%
rename from convlab/policy/ppo/setsumbt_end_baseline_config.json
rename to convlab/policy/ppo/setsumbt_config.json
index ea84dd76..5a13ee82 100644
--- a/convlab/policy/ppo/setsumbt_end_baseline_config.json
+++ b/convlab/policy/ppo/setsumbt_config.json
@@ -1,22 +1,22 @@
 {
 	"model": {
-		"load_path": "supervised",
+		"load_path": "/gpfs/project/niekerk/src/ConvLab3/convlab/policy/mle/experiments/experiment_2022-11-13-12-56-34/save/supervised",
 		"pretrained_load_path": "",
 		"use_pretrained_initialisation": false,
 		"batchsz": 1000,
 		"seed": 0,
 		"epoch": 50,
 		"eval_frequency": 5,
-		"process_num": 4,
+		"process_num": 2,
 		"num_eval_dialogues": 500,
-		"sys_semantic_to_usr": false
+		"sys_semantic_to_usr": true
 	},
 	"vectorizer_sys": {
 		"uncertainty_vector_mul": {
-			"class_path": "convlab.policy.vector.vector_multiwoz_uncertainty.MultiWozVector",
+			"class_path": "convlab.policy.vector.vector_binary.VectorBinary",
 			"ini_params": {
 				"use_masking": false,
-				"manually_add_entity_names": false,
+				"manually_add_entity_names": true,
 				"seed": 0
 			}
 		}
@@ -24,12 +24,9 @@
 	"nlu_sys": {},
 	"dst_sys": {
 		"setsumbt-mul": {
-			"class_path": "convlab.dst.setsumbt.multiwoz.Tracker.SetSUMBTTracker",
+			"class_path": "convlab.dst.setsumbt.SetSUMBTTracker",
 			"ini_params": {
-				"model_path": "https://zenodo.org/record/5497808/files/setsumbt_end.zip",
-				"get_confidence_scores": true,
-				"return_mutual_info": false,
-				"return_entropy": true
+				"model_path": "/gpfs/project/niekerk/models/setsumbt_models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-distribution_distillation-Seed0-30-08-22-15-00"
 			}
 		}
 	},
@@ -41,16 +38,7 @@
 			}
 		}
 	},
-	"nlu_usr": {
-		"BERTNLU": {
-			"class_path": "convlab.nlu.jointBERT.multiwoz.BERTNLU",
-			"ini_params": {
-				"mode": "sys",
-				"config_file": "multiwoz_sys_context.json",
-				"model_file": "https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip"
-			}
-		}
-	},
+	"nlu_usr": {},
 	"dst_usr": {},
 	"policy_usr": {
 		"RulePolicy": {
@@ -65,7 +53,7 @@
 			"class_path": "convlab.nlg.template.multiwoz.TemplateNLG",
 			"ini_params": {
 				"is_user": true,
-				"label_noise": 0.0,
+				"label_noise": 0.05,
 				"text_noise": 0.0
 			}
 		}
diff --git a/convlab/policy/ppo/setsumbt_unc_config.json b/convlab/policy/ppo/setsumbt_unc_config.json
new file mode 100644
index 00000000..6b7d115a
--- /dev/null
+++ b/convlab/policy/ppo/setsumbt_unc_config.json
@@ -0,0 +1,65 @@
+{
+	"model": {
+		"load_path": "/gpfs/project/niekerk/src/ConvLab3/convlab/policy/mle/experiments/experiment_2022-11-10-10-37-30/save/supervised",
+		"pretrained_load_path": "",
+		"use_pretrained_initialisation": false,
+		"batchsz": 1000,
+		"seed": 0,
+		"epoch": 50,
+		"eval_frequency": 5,
+		"process_num": 2,
+		"num_eval_dialogues": 500,
+		"sys_semantic_to_usr": true
+	},
+	"vectorizer_sys": {
+		"uncertainty_vector_mul": {
+			"class_path": "convlab.policy.vector.vector_uncertainty.VectorUncertainty",
+			"ini_params": {
+				"use_masking": false,
+				"manually_add_entity_names": true,
+				"seed": 0,
+				"use_confidence_scores": true,
+				"use_state_knowledge_uncertainty": true
+			}
+		}
+	},
+	"nlu_sys": {},
+	"dst_sys": {
+		"setsumbt-mul": {
+			"class_path": "convlab.dst.setsumbt.SetSUMBTTracker",
+			"ini_params": {
+				"model_path": "/gpfs/project/niekerk/models/setsumbt_models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-distribution_distillation-Seed0-30-08-22-15-00",
+				"return_confidence_scores": true,
+				"return_belief_state_mutual_info": true
+			}
+		}
+	},
+	"sys_nlg": {
+		"TemplateNLG": {
+			"class_path": "convlab.nlg.template.multiwoz.TemplateNLG",
+			"ini_params": {
+				"is_user": false
+			}
+		}
+	},
+	"nlu_usr": {},
+	"dst_usr": {},
+	"policy_usr": {
+		"RulePolicy": {
+			"class_path": "convlab.policy.rule.multiwoz.RulePolicy",
+			"ini_params": {
+				"character": "usr"
+			}
+		}
+	},
+	"usr_nlg": {
+		"TemplateNLG": {
+			"class_path": "convlab.nlg.template.multiwoz.TemplateNLG",
+			"ini_params": {
+				"is_user": true,
+				"label_noise": 0.05,
+				"text_noise": 0.0
+			}
+		}
+	}
+}
\ No newline at end of file
diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py
index 50b06aab..45681169 100755
--- a/convlab/policy/ppo/train.py
+++ b/convlab/policy/ppo/train.py
@@ -199,7 +199,7 @@ if __name__ == '__main__':
     logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
         init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
 
-    args = [('model', 'seed', seed)] if seed is not None else list()
+    args = [('model', 'seed', seed)] if seed else list()
 
     environment_config = load_config_file(path)
     save_config(vars(parser.parse_args()), environment_config, config_save_path)
@@ -228,14 +228,6 @@ if __name__ == '__main__':
 
     env, sess = env_config(conf, policy_sys)
 
-    # Setup uncertainty thresholding
-    if env.sys_dst:
-        try:
-            if env.sys_dst.use_confidence_scores:
-                policy_sys.vector.setup_uncertain_query(env.sys_dst.thresholds)
-        except:
-            logging.info('Uncertainty threshold not set.')
-
     policy_sys.current_time = current_time
     policy_sys.log_dir = config_save_path.replace('configs', 'logs')
     policy_sys.save_dir = save_path
@@ -261,7 +253,7 @@ if __name__ == '__main__':
 
         if idx % conf['model']['eval_frequency'] == 0 and idx != 0:
             time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
-            logging.info(f"Evaluating after Dialogues: {idx * conf['model']['batchsz']} - {time_now}" + '-' * 60)
+            logging.info(f"Evaluating at Epoch: {idx} - {time_now}" + '-'*60)
 
             eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path)
 
diff --git a/convlab/policy/vector/dataset.py b/convlab/policy/vector/dataset.py
index b1481854..0aa1b7ad 100755
--- a/convlab/policy/vector/dataset.py
+++ b/convlab/policy/vector/dataset.py
@@ -18,26 +18,6 @@ class ActDataset(data.Dataset):
         return self.num_total
 
 
-class ActDatasetKG(data.Dataset):
-    def __init__(self, action_batch, a_masks, current_domain_mask_batch, non_current_domain_mask_batch):
-        self.action_batch = action_batch
-        self.a_masks = a_masks
-        self.current_domain_mask_batch = current_domain_mask_batch
-        self.non_current_domain_mask_batch = non_current_domain_mask_batch
-        self.num_total = len(action_batch)
-
-    def __getitem__(self, index):
-        action = self.action_batch[index]
-        action_mask = self.a_masks[index]
-        current_domain_mask = self.current_domain_mask_batch[index]
-        non_current_domain_mask = self.non_current_domain_mask_batch[index]
-
-        return action, action_mask, current_domain_mask, non_current_domain_mask, index
-
-    def __len__(self):
-        return self.num_total
-
-
 class ActStateDataset(data.Dataset):
     def __init__(self, s_s, a_s, next_s):
         self.s_s = s_s
diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py
index 89f22203..8b7d8ff0 100644
--- a/convlab/policy/vector/vector_base.py
+++ b/convlab/policy/vector/vector_base.py
@@ -2,11 +2,10 @@
 import os
 import sys
 import numpy as np
-import logging
 
 from copy import deepcopy
 from convlab.policy.vec import Vector
-from convlab.util.custom_util import flatten_acts, timeout
+from convlab.util.custom_util import flatten_acts
 from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da
 from convlab.util import load_ontology, load_database, load_dataset
 
@@ -23,20 +22,18 @@ class VectorBase(Vector):
 
         super().__init__()
 
-        logging.info(f"Vectorizer: Data set used is {dataset_name}")
         self.set_seed(seed)
         self.ontology = load_ontology(dataset_name)
         try:
             # execute to make sure that the database exists or is downloaded otherwise
-            if dataset_name == "multiwoz21":
-                load_database(dataset_name)
+            load_database(dataset_name)
             # the following two lines are needed for pickling correctly during multi-processing
             exec(f'from data.unified_datasets.{dataset_name}.database import Database')
             self.db = eval('Database()')
             self.db_domains = self.db.domains
         except Exception as e:
             self.db = None
-            self.db_domains = []
+            self.db_domains = None
             print(f"VectorBase: {e}")
 
         self.dataset_name = dataset_name
@@ -275,10 +272,6 @@ class VectorBase(Vector):
         2. If there is an entity available, can not say NoOffer or NoBook
         '''
         mask_list = np.zeros(self.da_dim)
-
-        if number_entities_dict is None:
-            return mask_list
-
         for i in range(self.da_dim):
             action = self.vec2act[i]
             domain, intent, slot, value = action.split('-')
diff --git a/convlab/policy/vector/vector_binary.py b/convlab/policy/vector/vector_binary.py
index 3671178b..e780dc64 100755
--- a/convlab/policy/vector/vector_binary.py
+++ b/convlab/policy/vector/vector_binary.py
@@ -8,7 +8,7 @@ from .vector_base import VectorBase
 class VectorBinary(VectorBase):
 
     def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True,
-                 seed=0):
+                 seed=0, **kwargs):
 
         super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed)
 
diff --git a/convlab/policy/vector/vector_multiwoz_uncertainty.py b/convlab/policy/vector/vector_multiwoz_uncertainty.py
deleted file mode 100644
index 6a0850f4..00000000
--- a/convlab/policy/vector/vector_multiwoz_uncertainty.py
+++ /dev/null
@@ -1,238 +0,0 @@
-# -*- coding: utf-8 -*-
-import sys
-import os
-import numpy as np
-import logging
-from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da
-from convlab.util.multiwoz.state import default_state
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
-from .vector_binary import VectorBinary as VectorBase
-
-DEFAULT_INTENT_FILEPATH = os.path.join(
-    os.path.dirname(os.path.dirname(os.path.dirname(
-        os.path.dirname(os.path.abspath(__file__))))),
-    'data/multiwoz/trackable_intent.json'
-)
-
-
-SLOT_MAP = {'taxi_types': 'car type'}
-
-
-class MultiWozVector(VectorBase):
-
-    def __init__(self, voc_file=None, voc_opp_file=None, character='sys',
-                 intent_file=DEFAULT_INTENT_FILEPATH,
-                 use_confidence_scores=False,
-                 use_entropy=False,
-                 use_mutual_info=False,
-                 use_masking=False,
-                 manually_add_entity_names=False,
-                 seed=0,
-                 shrink=False):
-
-        self.use_confidence_scores = use_confidence_scores
-        self.use_entropy = use_entropy
-        self.use_mutual_info = use_mutual_info
-        self.thresholds = None
-
-        super().__init__(voc_file, voc_opp_file, character, intent_file, use_masking, manually_add_entity_names, seed)
-
-    def get_state_dim(self):
-        self.belief_state_dim = 0
-        for domain in self.belief_domains:
-            for slot in default_state()['belief_state'][domain.lower()]['semi']:
-                # Dim 1 - indicator/confidence score
-                # Dim 2 - Entropy (Total uncertainty) / Mutual information (knowledge unc)
-                slot_dim = 1 if not self.use_entropy else 2
-                slot_dim += 1 if self.use_mutual_info else 0
-                self.belief_state_dim += slot_dim
-
-        self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \
-            len(self.db_domains) + 6 * len(self.db_domains) + 1
-
-    def dbquery_domain(self, domain):
-        """
-        query entities of specified domain
-        Args:
-            domain string:
-                domain to query
-        Returns:
-            entities list:
-                list of entities of the specified domain
-        """
-        # Get all user constraints
-        constraint = self.state[domain.lower()]['semi']
-        constraint = {k: i for k, i in constraint.items() if i and i not in ['dontcare', "do n't care", "do not care"]}
-
-        # Remove constraints for which the uncertainty is high
-        if self.confidence_scores is not None and self.use_confidence_scores and self.thresholds != None:
-            # Collect threshold values for each domain-slot pair
-            thres = self.thresholds.get(domain.lower(), {})
-            thres = {k: thres.get(k, 0.05) for k in constraint}
-            # Get confidence scores for each constraint
-            probs = self.confidence_scores.get(domain.lower(), {})
-            probs = {k: probs.get(k, {}).get('inform', 1.0)
-                     for k in constraint}
-
-            # Filter out constraints for which confidence is lower than threshold
-            constraint = {k: i for k, i in constraint.items()
-                          if probs[k] >= thres[k]}
-
-        return self.db.query(domain.lower(), constraint.items())
-
-    # Add thresholds for db_queries
-    def setup_uncertain_query(self, thresholds):
-        self.use_confidence_scores = True
-        self.thresholds = thresholds
-        logging.info('DB Search uncertainty activated.')
-
-    def vectorize_user_act_confidence_scores(self, state, opp_action):
-        """Return confidence scores for the user actions"""
-        opp_act_vec = np.zeros(self.da_opp_dim)
-        for da in self.opp2vec:
-            domain, intent, slot, value = da.split('-')
-            if domain.lower() in state['belief_state_probs']:
-                # Map slot name to match user actions
-                slot = REF_SYS_DA[domain].get(
-                    slot, slot) if domain in REF_SYS_DA else slot
-                slot = slot if slot else 'none'
-                slot = SLOT_MAP.get(slot, slot)
-                domain = domain.lower()
-
-                if slot in state['belief_state_probs'][domain]:
-                    prob = state['belief_state_probs'][domain][slot]
-                elif slot.lower() in state['belief_state_probs'][domain]:
-                    prob = state['belief_state_probs'][domain][slot.lower()]
-                else:
-                    prob = {}
-
-                intent = intent.lower()
-                if intent in prob:
-                    prob = float(prob[intent])
-                elif da in opp_action:
-                    prob = 1.0
-                else:
-                    prob = 0.0
-            elif da in opp_action:
-                prob = 1.0
-            else:
-                prob = 0.0
-            opp_act_vec[self.opp2vec[da]] = prob
-
-        return opp_act_vec
-
-    def state_vectorize(self, state):
-        """vectorize a state
-
-        Args:
-            state (dict):
-                Dialog state
-            action (tuple):
-                Dialog act
-        Returns:
-            state_vec (np.array):
-                Dialog state vector
-        """
-        self.state = state['belief_state']
-        self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None
-        domain_active_dict = {}
-        for domain in self.belief_domains:
-            domain_active_dict[domain] = False
-
-        # when character is sys, to help query database when da is booking-book
-        # update current domain according to user action
-        if self.character == 'sys':
-            action = state['user_action']
-            for intent, domain, slot, value in action:
-                domain_active_dict[domain] = True
-
-        action = state['user_action'] if self.character == 'sys' else state['system_action']
-        opp_action = delexicalize_da(action, self.requestable)
-        opp_action = flat_da(opp_action)
-        if 'belief_state_probs' in state and self.use_confidence_scores:
-            opp_act_vec = self.vectorize_user_act_confidence_scores(
-                state, opp_action)
-        else:
-            opp_act_vec = np.zeros(self.da_opp_dim)
-            for da in opp_action:
-                if da in self.opp2vec:
-                    prob = 1.0
-                    opp_act_vec[self.opp2vec[da]] = prob
-
-        action = state['system_action'] if self.character == 'sys' else state['user_action']
-        action = delexicalize_da(action, self.requestable)
-        action = flat_da(action)
-        last_act_vec = np.zeros(self.da_dim)
-        for da in action:
-            if da in self.act2vec:
-                last_act_vec[self.act2vec[da]] = 1.
-
-        belief_state = np.zeros(self.belief_state_dim)
-        i = 0
-        for domain in self.belief_domains:
-            if self.use_confidence_scores and 'belief_state_probs' in state:
-                for slot in state['belief_state'][domain.lower()]['semi']:
-                    if slot in state['belief_state_probs'][domain.lower()]:
-                        prob = state['belief_state_probs'][domain.lower()
-                                                           ][slot]
-                        prob = prob['inform'] if 'inform' in prob else None
-                    if prob:
-                        belief_state[i] = float(prob)
-                    i += 1
-            else:
-                for slot, value in state['belief_state'][domain.lower()]['semi'].items():
-                    if value and value != 'not mentioned':
-                        belief_state[i] = 1.
-                    i += 1
-            if 'active_domains' in state:
-                domain_active = state['active_domains'][domain.lower()]
-                domain_active_dict[domain] = domain_active
-            else:
-                if [slot for slot, value in state['belief_state'][domain.lower()]['semi'].items() if value]:
-                    domain_active_dict[domain] = True
-
-        # Add knowledge and/or total uncertainty to the belief state
-        if self.use_entropy and 'entropy' in state:
-            for domain in self.belief_domains:
-                for slot in state['belief_state'][domain.lower()]['semi']:
-                    if slot in state['entropy'][domain.lower()]:
-                        belief_state[i] = float(
-                            state['entropy'][domain.lower()][slot])
-                    i += 1
-
-        if self.use_mutual_info and 'mutual_information' in state:
-            for domain in self.belief_domains:
-                for slot in state['belief_state'][domain.lower()]['semi']:
-                    if slot in state['mutual_information'][domain.lower()]:
-                        belief_state[i] = float(
-                            state['mutual_information'][domain.lower()][slot])
-                    i += 1
-
-        book = np.zeros(len(self.db_domains))
-        for i, domain in enumerate(self.db_domains):
-            if state['belief_state'][domain.lower()]['book']['booked']:
-                book[i] = 1.
-
-        degree, number_entities_dict = self.pointer()
-
-        final = 1. if state['terminated'] else 0.
-
-        state_vec = np.r_[opp_act_vec, last_act_vec,
-                          belief_state, book, degree, final]
-        assert len(state_vec) == self.state_dim
-
-        if self.use_mask is not None:
-            # None covers the case for policies that don't use masking at all, so do not expect an output "state_vec, mask"
-            if self.use_mask:
-                domain_mask = self.compute_domain_mask(domain_active_dict)
-                entity_mask = self.compute_entity_mask(number_entities_dict)
-                general_mask = self.compute_general_mask()
-                mask = domain_mask + entity_mask + general_mask
-                for i in range(self.da_dim):
-                    mask[i] = -int(bool(mask[i])) * sys.maxsize
-            else:
-                mask = np.zeros(self.da_dim)
-
-            return state_vec, mask
-        else:
-            return state_vec
diff --git a/convlab/policy/vector/vector_nodes.py b/convlab/policy/vector/vector_nodes.py
index 2e073669..c2f6258f 100644
--- a/convlab/policy/vector/vector_nodes.py
+++ b/convlab/policy/vector/vector_nodes.py
@@ -1,8 +1,6 @@
 # -*- coding: utf-8 -*-
 import sys
 import numpy as np
-import logging
-
 from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da
 from .vector_base import VectorBase
 
@@ -10,11 +8,9 @@ from .vector_base import VectorBase
 class VectorNodes(VectorBase):
 
     def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True,
-                 seed=0, filter_state=True):
+                 seed=0):
 
         super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed)
-        self.filter_state = filter_state
-        logging.info(f"We filter state by active domains: {self.filter_state}")
 
     def get_state_dim(self):
         self.belief_state_dim = 0
@@ -60,16 +56,9 @@ class VectorNodes(VectorBase):
         self.get_user_act_feature(state)
         self.get_sys_act_feature(state)
         domain_active_dict = self.get_user_goal_feature(state, domain_active_dict)
+        number_entities_dict = self.get_db_features()
         self.get_general_features(state, domain_active_dict)
 
-        if self.db is not None:
-            number_entities_dict = self.get_db_features()
-        else:
-            number_entities_dict = None
-
-        if self.filter_state:
-            self.kg_info = self.filter_inactive_domains(domain_active_dict)
-
         if self.use_mask:
             mask = self.get_mask(domain_active_dict, number_entities_dict)
             for i in range(self.da_dim):
@@ -100,15 +89,13 @@ class VectorNodes(VectorBase):
 
         feature_type = 'user goal'
         for domain in self.belief_domains:
-            # the if case is needed because SGD only saves the dialogue state info for active domains
-            if domain in state['belief_state']:
-                for slot, value in state['belief_state'][domain].items():
-                    description = f"user goal-{domain}-{slot}".lower()
-                    value = 1.0 if (value and value != "not mentioned") else 0.0
-                    self.add_graph_node(domain, feature_type, description, value)
-
-                if [slot for slot, value in state['belief_state'][domain].items() if value]:
-                    domain_active_dict[domain] = True
+            for slot, value in state['belief_state'][domain].items():
+                description = f"user goal-{domain}-{slot}".lower()
+                value = 1.0 if (value and value != "not mentioned") else 0.0
+                self.add_graph_node(domain, feature_type, description, value)
+
+            if [slot for slot, value in state['belief_state'][domain].items() if value]:
+                domain_active_dict[domain] = True
         return domain_active_dict
 
     def get_sys_act_feature(self, state):
@@ -141,12 +128,11 @@ class VectorNodes(VectorBase):
     def get_general_features(self, state, domain_active_dict):
 
         feature_type = 'general'
-        if 'booked' in state:
-            for i, domain in enumerate(self.db_domains):
-                if domain in state['booked']:
-                    description = f"general-{domain}-booked".lower()
-                    value = 1.0 if state['booked'][domain] else 0.0
-                    self.add_graph_node(domain, feature_type, description, value)
+        for i, domain in enumerate(self.db_domains):
+            if domain in state['booked']:
+                description = f"general-{domain}-booked".lower()
+                value = 1.0 if state['booked'][domain] else 0.0
+                self.add_graph_node(domain, feature_type, description, value)
 
         for domain in self.domains:
             if domain == 'general':
@@ -154,17 +140,3 @@ class VectorNodes(VectorBase):
             value = 1.0 if domain_active_dict[domain] else 0
             description = f"general-{domain}".lower()
             self.add_graph_node(domain, feature_type, description, value)
-
-    def filter_inactive_domains(self, domain_active_dict):
-
-        kg_filtered = []
-        for node in self.kg_info:
-            domain = node['domain']
-            if domain in domain_active_dict:
-                if domain_active_dict[domain]:
-                    kg_filtered.append(node)
-            else:
-                kg_filtered.append(node)
-
-        return kg_filtered
-
diff --git a/convlab/policy/vector/vector_uncertainty.py b/convlab/policy/vector/vector_uncertainty.py
new file mode 100644
index 00000000..7da05449
--- /dev/null
+++ b/convlab/policy/vector/vector_uncertainty.py
@@ -0,0 +1,166 @@
+# -*- coding: utf-8 -*-
+import sys
+import numpy as np
+import logging
+
+from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da
+from convlab.policy.vector.vector_binary import VectorBinary
+
+
+class VectorUncertainty(VectorBinary):
+    """Vectorise state and state uncertainty predictions"""
+
+    def __init__(self,
+                 dataset_name: str = 'multiwoz21',
+                 character: str = 'sys',
+                 use_masking: bool = False,
+                 manually_add_entity_names: bool = True,
+                 seed: str = 0,
+                 use_confidence_scores: bool = True,
+                 confidence_thresholds: dict = None,
+                 use_state_total_uncertainty: bool = False,
+                 use_state_knowledge_uncertainty: bool = False):
+        """
+        Args:
+            dataset_name: Name of environment dataset
+            character: Character of the agent (sys/usr)
+            use_masking: If true certain actions are masked during devectorisation
+            manually_add_entity_names: If true inform entity name actions are manually added
+            seed: Seed
+            use_confidence_scores: If true confidence scores are used in state vectorisation
+            confidence_thresholds: If true confidence thresholds are used in database querying
+            use_state_total_uncertainty: If true state entropy is added to the state vector
+            use_state_knowledge_uncertainty: If true state mutual information is added to the state vector
+        """
+
+        self.use_confidence_scores = use_confidence_scores
+        self.use_state_total_uncertainty = use_state_total_uncertainty
+        self.use_state_knowledge_uncertainty = use_state_knowledge_uncertainty
+        if confidence_thresholds is not None:
+            self.setup_uncertain_query(confidence_thresholds)
+
+        super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed)
+
+    def get_state_dim(self):
+        self.belief_state_dim = 0
+
+        for domain in self.ontology['state']:
+            for slot in self.ontology['state'][domain]:
+                # Dim 1 - indicator/confidence score
+                # Dim 2 - Entropy (Total uncertainty) / Mutual information (knowledge unc)
+                slot_dim = 1 if not self.use_state_total_uncertainty else 2
+                slot_dim += 1 if self.use_state_knowledge_uncertainty else 0
+                self.belief_state_dim += slot_dim
+
+        self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \
+            len(self.db_domains) + 6 * len(self.db_domains) + 1
+
+    # Add thresholds for db_queries
+    def setup_uncertain_query(self, confidence_thresholds):
+        self.use_confidence_scores = True
+        self.confidence_thresholds = confidence_thresholds
+        logging.info('DB Search uncertainty activated.')
+
+    def dbquery_domain(self, domain):
+        """
+        query entities of specified domain
+        Args:
+            domain string:
+                domain to query
+        Returns:
+            entities list:
+                list of entities of the specified domain
+        """
+        # Get all user constraints
+        constraints = {slot: value for slot, value in self.state[domain].items()
+                       if slot and value not in ['dontcare',
+                                                 "do n't care", "do not care"]} if domain in self.state else dict()
+
+        # Remove constraints for which the uncertainty is high
+        if self.confidence_scores is not None and self.use_confidence_scores and self.confidence_thresholds is not None:
+            # Collect threshold values for each domain-slot pair
+            threshold = self.confidence_thresholds.get(domain, dict())
+            threshold = {slot: threshold.get(slot, 0.05) for slot in constraints}
+            # Get confidence scores for each constraint
+            probs = self.confidence_scores.get(domain, dict())
+            probs = {slot: probs.get(slot, {}).get('inform', 1.0) for slot in constraints}
+
+            # Filter out constraints for which confidence is lower than threshold
+            constraints = {slot: value for slot, value in constraints.items() if probs[slot] >= threshold[slot]}
+
+        return self.db.query(domain, constraints.items(), topk=10)
+
+    def vectorize_user_act(self, state):
+        """Return confidence scores for the user actions"""
+        self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None
+        action = state['user_action'] if self.character == 'sys' else state['system_action']
+        opp_action = delexicalize_da(action, self.requestable)
+        opp_action = flat_da(opp_action)
+        opp_act_vec = np.zeros(self.da_opp_dim)
+        for da in opp_action:
+            if da in self.opp2vec:
+                if 'belief_state_probs' in state and self.use_confidence_scores:
+                    domain, intent, slot, value = da.split('-')
+                    if domain in state['belief_state_probs']:
+                        slot = slot if slot else 'none'
+                        if slot in state['belief_state_probs'][domain]:
+                            prob = state['belief_state_probs'][domain][slot]
+                        elif slot.lower() in state['belief_state_probs'][domain]:
+                            prob = state['belief_state_probs'][domain][slot.lower()]
+                        else:
+                            prob = dict()
+
+                        if intent in prob:
+                            prob = float(prob[intent])
+                        else:
+                            prob = 1.0
+                    else:
+                        prob = 1.0
+                else:
+                    prob = 1.0
+                opp_act_vec[self.opp2vec[da]] = prob
+
+        return opp_act_vec
+
+    def vectorize_belief_state(self, state, domain_active_dict):
+        belief_state = np.zeros(self.belief_state_dim)
+        i = 0
+        for domain in self.belief_domains:
+            if self.use_confidence_scores and 'belief_state_probs' in state:
+                for slot in state['belief_state'][domain]:
+                    prob = None
+                    if slot in state['belief_state_probs'][domain]:
+                        prob = state['belief_state_probs'][domain][slot]
+                        prob = prob['inform'] if 'inform' in prob else None
+                    if prob:
+                        belief_state[i] = float(prob)
+                    i += 1
+            else:
+                for slot, value in state['belief_state'][domain].items():
+                    if value and value != 'not mentioned':
+                        belief_state[i] = 1.
+                    i += 1
+
+            if 'active_domains' in state:
+                domain_active = state['active_domains'][domain]
+                domain_active_dict[domain] = domain_active
+            else:
+                if [slot for slot, value in state['belief_state'][domain].items() if value]:
+                    domain_active_dict[domain] = True
+
+        # Add knowledge and/or total uncertainty to the belief state
+        if self.use_state_total_uncertainty and 'entropy' in state:
+            for domain in self.belief_domains:
+                for slot in state['belief_state'][domain]:
+                    if slot in state['entropy'][domain]:
+                        belief_state[i] = float(state['entropy'][domain][slot])
+                    i += 1
+
+        if self.use_state_knowledge_uncertainty and 'mutual_information' in state:
+            for domain in self.belief_domains:
+                for slot in state['belief_state'][domain]:
+                    if slot in state['mutual_information'][domain]:
+                        belief_state[i] = float(state['mutual_information'][domain][slot])
+                    i += 1
+
+        return belief_state, domain_active_dict
diff --git a/convlab/util/custom_util.py b/convlab/util/custom_util.py
index aad6c4cd..c79c6f0d 100644
--- a/convlab/util/custom_util.py
+++ b/convlab/util/custom_util.py
@@ -21,7 +21,6 @@ from convlab.evaluator.multiwoz_eval import MultiWozEvaluator
 from convlab.util import load_dataset
 
 import shutil
-import signal
 
 
 slot_mapping = {"pricerange": "price range", "post": "postcode", "arriveBy": "arrive by", "leaveAt": "leave at",
@@ -35,22 +34,6 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 device = DEVICE
 
 
-class timeout:
-    def __init__(self, seconds=10, error_message='Timeout'):
-        self.seconds = seconds
-        self.error_message = error_message
-
-    def handle_timeout(self, signum, frame):
-        raise TimeoutError(self.error_message)
-
-    def __enter__(self):
-        signal.signal(signal.SIGALRM, self.handle_timeout)
-        signal.alarm(self.seconds)
-
-    def __exit__(self, type, value, traceback):
-        signal.alarm(0)
-
-
 class NumpyEncoder(json.JSONEncoder):
     """ Special json encoder for numpy types """
 
@@ -171,20 +154,20 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do
     if conf['model']['process_num'] == 1:
         complete_rate, success_rate, success_rate_strict, avg_return, turns, \
             avg_actions, task_success, book_acts, inform_acts, request_acts, \
-                select_acts, offer_acts, recommend_acts = evaluate(sess,
+                select_acts, offer_acts = evaluate(sess,
                                                 num_dialogues=conf['model']['num_eval_dialogues'],
                                                 sys_semantic_to_usr=conf['model'][
                                                     'sys_semantic_to_usr'],
                                                 save_flag=save_eval, save_path=log_save_path, goals=goals)
 
-        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts
+        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts
     else:
         complete_rate, success_rate, success_rate_strict, avg_return, turns, \
             avg_actions, task_success, book_acts, inform_acts, request_acts, \
-            select_acts, offer_acts, recommend_acts = \
+            select_acts, offer_acts = \
             evaluate_distributed(sess, list(range(1000, 1000 + conf['model']['num_eval_dialogues'])),
                                  conf['model']['process_num'], goals)
-        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts
+        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts
 
         task_success_gathered = {}
         for task_dict in task_success:
@@ -195,40 +178,22 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do
         task_success = task_success_gathered
 
     policy_sys.is_train = True
-
-    mean_complete, err_complete = np.average(complete_rate), np.std(complete_rate) / np.sqrt(len(complete_rate))
-    mean_success, err_success = np.average(success_rate), np.std(success_rate) / np.sqrt(len(success_rate))
-    mean_success_strict, err_success_strict = np.average(success_rate_strict), np.std(success_rate_strict) / np.sqrt(len(success_rate_strict))
-    mean_return, err_return = np.average(avg_return), np.std(avg_return) / np.sqrt(len(avg_return))
-    mean_turns, err_turns = np.average(turns), np.std(turns) / np.sqrt(len(turns))
-    mean_actions, err_actions = np.average(avg_actions), np.std(avg_actions) / np.sqrt(len(avg_actions))
-
-    logging.info(f"Complete: {mean_complete}+-{round(err_complete, 2)}, "
-                 f"Success: {mean_success}+-{round(err_success, 2)}, "
-                 f"Success strict: {mean_success_strict}+-{round(err_success_strict, 2)}, "
-                 f"Average Return: {mean_return}+-{round(err_return, 2)}, "
-                 f"Turns: {mean_turns}+-{round(err_turns, 2)}, "
-                 f"Average Actions: {mean_actions}+-{round(err_actions, 2)}, "
+    logging.info(f"Complete: {complete_rate}, Success: {success_rate}, Success strict: {success_rate_strict}, "
+                 f"Average Return: {avg_return}, Turns: {turns}, Average Actions: {avg_actions}, "
                  f"Book Actions: {book_acts/total_acts}, Inform Actions: {inform_acts/total_acts}, "
                  f"Request Actions: {request_acts/total_acts}, Select Actions: {select_acts/total_acts}, "
-                 f"Offer Actions: {offer_acts/total_acts}, Recommend Actions: {recommend_acts/total_acts}")
+                 f"Offer Actions: {offer_acts/total_acts}")
 
     for key in task_success:
         logging.info(
             f"{key}: Num: {len(task_success[key])} Success: {np.average(task_success[key]) if len(task_success[key]) > 0 else 0}")
 
-    return {"complete_rate": mean_complete,
-            "success_rate": mean_success,
-            "success_rate_strict": mean_success_strict,
-            "avg_return": mean_return,
-            "turns": mean_turns,
-            "avg_actions": mean_actions,
-            "book_acts": book_acts/total_acts,
-            "inform_acts": inform_acts/total_acts,
-            "request_acts": request_acts/total_acts,
-            "select_acts": select_acts/total_acts,
-            "offer_acts": offer_acts/total_acts,
-            "recommend_acts": recommend_acts/total_acts}
+    return {"complete_rate": complete_rate,
+            "success_rate": success_rate,
+            "success_rate_strict": success_rate_strict,
+            "avg_return": avg_return,
+            "turns": turns,
+            "avg_actions": avg_actions}
 
 
 def env_config(conf, policy_sys, check_book_constraints=True):
@@ -240,6 +205,14 @@ def env_config(conf, policy_sys, check_book_constraints=True):
     policy_usr = conf['policy_usr_activated']
     usr_nlg = conf['usr_nlg_activated']
 
+    # Setup uncertainty thresholding
+    if dst_sys:
+        try:
+            if dst_sys.return_confidence_scores:
+                policy_sys.vector.setup_uncertain_query(dst_sys.confidence_thresholds)
+        except:
+            logging.info('Uncertainty threshold not set.')
+
     simulator = PipelineAgent(nlu_usr, dst_usr, policy_usr, usr_nlg, 'user')
     system_pipeline = PipelineAgent(nlu_sys, dst_sys, policy_sys, sys_nlg,
                                     'sys', return_semantic_acts=conf['model']['sys_semantic_to_usr'])
@@ -321,7 +294,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
     task_success = {'All_user_sim': [], 'All_evaluator': [], "All_evaluator_strict": [],
                     'total_return': [], 'turns': [], 'avg_actions': [],
                     'total_booking_acts': [], 'total_inform_acts': [], 'total_request_acts': [],
-                    'total_select_acts': [], 'total_offer_acts': [], 'total_recommend_acts': []}
+                    'total_select_acts': [], 'total_offer_acts': []}
     dial_count = 0
     for seed in range(1000, 1000 + num_dialogues):
         set_seed(seed)
@@ -337,7 +310,6 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
         request = 0
         select = 0
         offer = 0
-        recommend = 0
         # this 40 represents the max turn of dialogue
         for i in range(40):
             sys_response, user_response, session_over, reward = sess.next_turn(
@@ -360,8 +332,6 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
                     select += 1
                 if intent.lower() == 'offerbook':
                     offer += 1
-                if intent.lower() == 'recommend':
-                    recommend += 1
             avg_actions += len(acts)
             turn_counter += 1
             turns += 1
@@ -398,8 +368,6 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
         task_success['total_request_acts'].append(request)
         task_success['total_select_acts'].append(select)
         task_success['total_offer_acts'].append(offer)
-        task_success['total_offer_acts'].append(offer)
-        task_success['total_recommend_acts'].append(recommend)
 
         # print(agent_sys.agent_saves)
         eval_save['Conversation {}'.format(str(dial_count))] = [
@@ -415,11 +383,12 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
         save_file.close()
     # save dialogue_info and clear mem
 
-    return task_success['All_user_sim'], task_success['All_evaluator'], task_success['All_evaluator_strict'], \
-           task_success['total_return'], task_success['turns'], task_success['avg_actions'], task_success, \
+    return np.average(task_success['All_user_sim']), np.average(task_success['All_evaluator']), \
+        np.average(task_success['All_evaluator_strict']), np.average(task_success['total_return']), \
+        np.average(task_success['turns']), np.average(task_success['avg_actions']), task_success, \
         np.average(task_success['total_booking_acts']), np.average(task_success['total_inform_acts']), \
         np.average(task_success['total_request_acts']), np.average(task_success['total_select_acts']), \
-        np.average(task_success['total_offer_acts']), np.average(task_success['total_recommend_acts'])
+        np.average(task_success['total_offer_acts'])
 
 
 def model_downloader(download_dir, model_path):
@@ -570,18 +539,21 @@ def get_config(filepath, args) -> dict:
     vec_name = [model for model in conf['vectorizer_sys']]
     vec_name = vec_name[0] if vec_name else None
     if dst_name and 'setsumbt' in dst_name.lower():
-        if 'get_confidence_scores' in conf['dst_sys'][dst_name]['ini_params']:
-            conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = conf['dst_sys'][dst_name]['ini_params']['get_confidence_scores']
+        if 'return_confidence_scores' in conf['dst_sys'][dst_name]['ini_params']:
+            param = conf['dst_sys'][dst_name]['ini_params']['return_confidence_scores']
+            conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = param
         else:
             conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = False
-        if 'return_mutual_info' in conf['dst_sys'][dst_name]['ini_params']:
-            conf['vectorizer_sys'][vec_name]['ini_params']['use_mutual_info'] = conf['dst_sys'][dst_name]['ini_params']['return_mutual_info']
+        if 'return_belief_state_mutual_info' in conf['dst_sys'][dst_name]['ini_params']:
+            param = conf['dst_sys'][dst_name]['ini_params']['return_belief_state_mutual_info']
+            conf['vectorizer_sys'][vec_name]['ini_params']['use_state_knowledge_uncertainty'] = param
         else:
-            conf['vectorizer_sys'][vec_name]['ini_params']['use_mutual_info'] = False
-        if 'return_entropy' in conf['dst_sys'][dst_name]['ini_params']:
-            conf['vectorizer_sys'][vec_name]['ini_params']['use_entropy'] = conf['dst_sys'][dst_name]['ini_params']['return_entropy']
+            conf['vectorizer_sys'][vec_name]['ini_params']['use_state_knowledge_uncertainty'] = False
+        if 'return_belief_state_entropy' in conf['dst_sys'][dst_name]['ini_params']:
+            param = conf['dst_sys'][dst_name]['ini_params']['return_belief_state_entropy']
+            conf['vectorizer_sys'][vec_name]['ini_params']['use_state_total_uncertainty'] = param
         else:
-            conf['vectorizer_sys'][vec_name]['ini_params']['use_entropy'] = False
+            conf['vectorizer_sys'][vec_name]['ini_params']['use_state_total_uncertainty'] = False
 
     from convlab.nlu import NLU
     from convlab.dst import DST
@@ -610,8 +582,7 @@ def get_config(filepath, args) -> dict:
                 cls_path = infos.get('class_path', '')
                 cls = map_class(cls_path)
                 conf[unit + '_class'] = cls
-                conf[unit + '_activated'] = conf[unit +
-                                                 '_class'](**conf[unit][model]['ini_params'])
+                conf[unit + '_activated'] = conf[unit + '_class'](**conf[unit][model]['ini_params'])
                 print("Loaded " + model + " for " + unit)
     return conf
 
-- 
GitLab