diff --git a/.gitignore b/.gitignore
index 136bbce75fa69edcf426d6a0ef9aaf0e85e70662..c879a524e95127b0e91aeb95734c4bbfade592fc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -102,8 +102,6 @@ convlab/dst/trade/multiwoz_config/
 convlab/deploy/bert_multiwoz_all.zip
 convlab/deploy/templates/dialog_eg.html
 test.py
-convlab/dst/setsumbt/models/*
-
 *budget*.pdf
 *convlab/policy/vector/action_dicts
 
diff --git a/convlab/__init__.py b/convlab/__init__.py
index 4b14a707ee43021fc97936c5f64d95c2cb7aa659..43cc4bc75aab6c4576a4097a549079e521d20e20 100755
--- a/convlab/__init__.py
+++ b/convlab/__init__.py
@@ -1,3 +1,4 @@
+
 import os
 from convlab.nlu import NLU
 from convlab.dst import DST
diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py
index 6216eaaac9fb81615f903f048d6d85766ce663c5..508d06b56080f72146d65a98286a0fd099f9b4c2 100755
--- a/convlab/dialog_agent/env.py
+++ b/convlab/dialog_agent/env.py
@@ -50,10 +50,10 @@ class Environment():
             observation) if self.sys_nlu else observation
         self.sys_dst.state['user_action'] = dialog_act
         state = self.sys_dst.update(dialog_act)
-        state = deepcopy(state)
+        self.sys_dst.state['history'].append(["sys", model_response])
+        self.sys_dst.state['history'].append(["usr", observation])
 
-        state['history'].append(["sys", model_response])
-        state['history'].append(["usr", observation])
+        state = deepcopy(state)
 
         terminated = self.usr.is_terminated()
 
diff --git a/convlab/dst/evaluate_unified_datasets.py b/convlab/dst/evaluate_unified_datasets.py
index 4b4accf1bb78b1580c8235bebde1160ed771bd59..d4e0720dc02bf90efcfebb1780630211f0722f7f 100644
--- a/convlab/dst/evaluate_unified_datasets.py
+++ b/convlab/dst/evaluate_unified_datasets.py
@@ -1,73 +1,42 @@
 import json
 from pprint import pprint
 
-import numpy as np
-
 
 def evaluate(predict_result):
     predict_result = json.load(open(predict_result))
 
-    metrics = {'TP': 0, 'FP': 0, 'FN': 0}
-    jga = []
-    aga = []
-    fga = []
-    l2_err = []
-    lamb = [0.25, 0.5, 0.75, 1.0]
+    metrics = {'TP':0, 'FP':0, 'FN':0}
+    acc = []
 
     for sample in predict_result:
         pred_state = sample['predictions']['state']
         gold_state = sample['state']
-        utt_idx = sample['utt_idx']
-
-        predicts = {(domain, slot, ''.join(value.split()).lower()) for domain in pred_state
-                    for slot, value in pred_state[domain].items() if value}
-        labels = {(domain, slot, ''.join(value.split()).lower()) for domain in gold_state
-                  for slot, value in gold_state[domain].items() if value}
-        predicts, labels = sorted(list(predicts)), sorted(list(labels))
-
-        # Flexible goal accuracy (see https://arxiv.org/pdf/2204.03375.pdf)
-        weighted_err = [1] * len(lamb)
-        if utt_idx == 0:
-            err_idx = -999999
-            predicts_prev = []
-            labels_prev = []
-
-            if predicts != labels:
-                err_idx = utt_idx
-                weighted_err = [0] * len(lamb)
-        else:
-            if predicts != labels:
-                predicts_changes = [ele for ele in predicts if ele not in predicts_prev]
-                labels_changes = [ele for ele in labels if ele not in labels_prev]
-
-                new_predict_err = [ele for ele in predicts_changes if ele not in labels]
-                new_predict_miss = [ele for ele in labels_changes if ele not in predicts]
-
-                if new_predict_err or new_predict_miss:
-                    weighted_err = [0] * len(lamb)
-                    err_idx = utt_idx
+        flag = True
+        for domain in gold_state:
+            for slot, values in gold_state[domain].items():
+                if domain not in pred_state or slot not in pred_state[domain]:
+                    predict_values = ''
                 else:
-                    err_age = utt_idx - err_idx
-                    weighted_err = [1 - np.exp(-l * err_age) for l in lamb]
-            predicts_prev = predicts
-            labels_prev = labels
-        fga.append(weighted_err)
-
-        _l2 = 2.0 * len([ele for ele in labels if ele not in predicts])
-        _l2 += 2.0 * len([ele for ele in predicts if ele not in labels])
-        l2_err.append(_l2)
+                    predict_values = ''.join(pred_state[domain][slot].split()).lower()
+                if len(values) > 0:
+                    if len(predict_values) > 0:
+                        values = [''.join(value.split()).lower() for value in values.split('|')]
+                        predict_values = [''.join(value.split()).lower() for value in predict_values.split('|')]
+                        if any([value in values for value in predict_values]):
+                            metrics['TP'] += 1
+                        else:
+                            metrics['FP'] += 1
+                            metrics['FN'] += 1
+                            flag = False
+                    else:
+                        metrics['FN'] += 1
+                        flag = False
+                else:
+                    if len(predict_values) > 0:
+                        metrics['FP'] += 1
+                        flag = False
 
-        flag = True
-        for ele in predicts:
-            if ele in labels:
-                metrics['TP'] += 1
-            else:
-                metrics['FP'] += 1
-        for ele in labels:
-            if ele not in predicts:
-                metrics['FN'] += 1
-        flag &= (predicts == labels)
-        jga.append(flag)
+        acc.append(flag)
     
     TP = metrics.pop('TP')
     FP = metrics.pop('FP')
@@ -78,10 +47,7 @@ def evaluate(predict_result):
     metrics[f'slot_f1'] = f1
     metrics[f'slot_precision'] = precision
     metrics[f'slot_recall'] = recall
-    metrics['joint_goal_accuracy'] = sum(jga) / len(jga)
-    for i, l in enumerate(lamb):
-        metrics[f'flexible_goal_accuracy_{l}'] = sum(weighted_err[i] for weighted_err in fga)/len(fga)
-    metrics['l2_norm_error'] = sum(l2_err) / len(l2_err)
+    metrics['accuracy'] = sum(acc)/len(acc)
 
     return metrics
 
diff --git a/convlab/dst/setsumbt/__init__.py b/convlab/dst/setsumbt/__init__.py
index 9492faa9c9a20d1c476819bb995900ca71d56607..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/convlab/dst/setsumbt/__init__.py
+++ b/convlab/dst/setsumbt/__init__.py
@@ -1 +0,0 @@
-from convlab.dst.setsumbt.tracker import SetSUMBTTracker
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/calibration_plots.py b/convlab/dst/setsumbt/calibration_plots.py
index a41f280d3349164a2a67333d0ab176a37cbe50ea..379057e6411082d466b10a81027bb57e4131bb9b 100644
--- a/convlab/dst/setsumbt/calibration_plots.py
+++ b/convlab/dst/setsumbt/calibration_plots.py
@@ -35,7 +35,7 @@ def main():
     path = args.data_dir
 
     models = os.listdir(path)
-    models = [os.path.join(path, model, 'test.predictions') for model in models]
+    models = [os.path.join(path, model, 'test.belief') for model in models]
 
     fig = plt.figure(figsize=(14,8))
     font=20
@@ -56,16 +56,16 @@ def main():
 
 
 def get_calibration(path, device, n_bins=10, temperature=1.00):
-    probs = torch.load(path, map_location=device)
-    y_true = probs['state_labels']
-    probs = probs['belief_states']
+    logits = torch.load(path, map_location=device)
+    y_true = logits['labels']
+    logits = logits['belief_states']
 
-    y_pred = {slot: probs[slot].reshape(-1, probs[slot].size(-1)).argmax(-1) for slot in probs}
+    y_pred = {slot: logits[slot].reshape(-1, logits[slot].size(-1)).argmax(-1) for slot in logits}
     goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred}
     goal_acc = sum([goal_acc[slot] for slot in goal_acc])
     goal_acc = (goal_acc == len(y_true)).int()
 
-    scores = [probs[slot].reshape(-1, probs[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in probs]
+    scores = [logits[slot].reshape(-1, logits[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in logits]
     scores = torch.cat(scores, 0).min(0)[0]
 
     step = 1.0 / float(n_bins)
diff --git a/convlab/dst/setsumbt/dataset/__init__.py b/convlab/dst/setsumbt/dataset/__init__.py
deleted file mode 100644
index 17b1f93b3b39f95827cf6c09e8826383cd00b805..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/dataset/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from convlab.dst.setsumbt.dataset.unified_format import get_dataloader, change_batch_size
-from convlab.dst.setsumbt.dataset.ontology import get_slot_candidate_embeddings
diff --git a/convlab/dst/setsumbt/dataset/ontology.py b/convlab/dst/setsumbt/dataset/ontology.py
deleted file mode 100644
index 81e207805c47dcde8b7194ff5f7fac8ff10b1c2f..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/dataset/ontology.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Create Ontology Embeddings"""
-
-import json
-import os
-import random
-from copy import deepcopy
-
-import torch
-import numpy as np
-
-
-def set_seed(args):
-    """
-    Set random seeds
-
-    Args:
-        args (Arguments class): Arguments class containing seed and number of gpus to use
-    """
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-
-def encode_candidates(candidates: list, args, tokenizer, embedding_model) -> torch.tensor:
-    """
-    Embed candidates
-
-    Args:
-        candidates (list): List of candidate descriptions
-        args (argument class): Runtime arguments
-        tokenizer (transformers Tokenizer): Tokenizer for the embedding_model
-        embedding_model (transformer Model): Transformer model for embedding candidate descriptions
-
-    Returns:
-        feats (torch.tensor): Embeddings of the candidate descriptions
-    """
-    # Tokenize candidate descriptions
-    feats = [tokenizer.encode_plus(val, add_special_tokens=True,max_length=args.max_candidate_len,
-                                   padding='max_length', truncation='longest_first')
-             for val in candidates]
-
-    # Encode tokenized descriptions
-    with torch.no_grad():
-        feats = {key: torch.tensor([f[key] for f in feats]).to(embedding_model.device) for key in feats[0]}
-        embedded_feats = embedding_model(**feats)  # [num_candidates, max_candidate_len, hidden_dim]
-
-    # Reduce/pool descriptions embeddings if required
-    if args.set_similarity:
-        feats = embedded_feats.last_hidden_state.detach().cpu()  # [num_candidates, max_candidate_len, hidden_dim]
-    elif args.candidate_pooling == 'cls':
-        feats = embedded_feats.pooler_output.detach().cpu()  # [num_candidates, hidden_dim]
-    elif args.candidate_pooling == "mean":
-        feats = embedded_feats.last_hidden_state.detach().cpu()
-        feats = feats.sum(1)
-        feats = torch.nn.functional.layer_norm(feats, feats.size())
-        feats = feats.detach().cpu()  # [num_candidates, hidden_dim]
-
-    return feats
-
-
-def get_slot_candidate_embeddings(ontology: dict, set_type: str, args, tokenizer, embedding_model, save_to_file=True):
-    """
-    Get embeddings for slots and candidates
-
-    Args:
-        ontology (dict): Dictionary of domain-slot pair descriptions and possible value sets
-        set_type (str): Subset of the dataset being used (train/validation/test)
-        args (argument class): Runtime arguments
-        tokenizer (transformers Tokenizer): Tokenizer for the embedding_model
-        embedding_model (transformer Model): Transormer model for embedding candidate descriptions
-        save_to_file (bool): Indication of whether to save information to file
-
-    Returns:
-        slots (dict): domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
-    """
-    # Set model to eval mode
-    embedding_model.eval()
-
-    slots = dict()
-    for domain, subset in ontology.items():
-        for slot, slot_info in subset.items():
-            # Get description or use "domain-slot"
-            if args.use_descriptions:
-                desc = slot_info['description']
-            else:
-                desc = f"{domain}-{slot}"
-
-            # Encode domain-slot pair description
-            slot_emb = encode_candidates([desc], args, tokenizer, embedding_model)[0]
-
-            # Obtain possible value set and discard requestable value
-            values = deepcopy(slot_info['possible_values'])
-            is_requestable = False
-            if '?' in values:
-                is_requestable = True
-                values.remove('?')
-
-            # Encode value candidates
-            if values:
-                feats = encode_candidates(values, args, tokenizer, embedding_model)
-            else:
-                feats = None
-
-            # Store domain-slot description embeddings, candidate embeddings and requestabke flag for each domain-slot
-            slots[f"{domain}-{slot}"] = (slot_emb, feats, is_requestable)
-
-    # Dump tensors and ontology for use in training and evaluation
-    if save_to_file:
-        writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type)
-        torch.save(slots, writer)
-
-        writer = open(os.path.join(args.output_dir, 'database', '%s.json' % set_type), 'w')
-        json.dump(ontology, writer, indent=2)
-        writer.close()
-    
-    return slots
diff --git a/convlab/dst/setsumbt/dataset/unified_format.py b/convlab/dst/setsumbt/dataset/unified_format.py
deleted file mode 100644
index 83d8f776c135ac596fed322cd39c1c31c456a481..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/dataset/unified_format.py
+++ /dev/null
@@ -1,385 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Convlab3 Unified Format Dialogue Datasets"""
-
-from copy import deepcopy
-
-import torch
-import transformers
-from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
-from transformers.tokenization_utils import PreTrainedTokenizer
-from tqdm import tqdm
-
-from convlab.util import load_dataset
-from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values,
-                                                get_values_from_data, ontology_add_requestable_slots,
-                                                get_requestable_slots, load_dst_data, extract_dialogues,
-                                                combine_value_sets)
-
-transformers.logging.set_verbosity_error()
-
-
-def convert_examples_to_features(data: list,
-                                 ontology: dict,
-                                 tokenizer: PreTrainedTokenizer,
-                                 max_turns: int = 12,
-                                 max_seq_len: int = 64) -> dict:
-    """
-    Convert dialogue examples to model input features and labels
-
-    Args:
-        data (list): List of all extracted dialogues
-        ontology (dict): Ontology dictionary containing slots, slot descriptions and
-        possible value sets including requests
-        tokenizer (PreTrainedTokenizer): Tokenizer for the encoder model used
-        max_turns (int): Maximum numbers of turns in a dialogue
-        max_seq_len (int): Maximum number of tokens in a dialogue turn
-
-    Returns:
-        features (dict): All inputs and labels required to train the model
-    """
-    features = dict()
-    ontology = deepcopy(ontology)
-
-    # Get encoder input for system, user utterance pairs
-    input_feats = []
-    for dial in tqdm(data):
-        dial_feats = []
-        for turn in dial:
-            if len(turn['system_utterance']) == 0:
-                usr = turn['user_utterance']
-                dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens=True,
-                                                        max_length=max_seq_len, padding='max_length',
-                                                        truncation='longest_first'))
-            else:
-                usr = turn['user_utterance']
-                sys = turn['system_utterance']
-                dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens=True,
-                                                        max_length=max_seq_len, padding='max_length',
-                                                        truncation='longest_first'))
-            # Truncate
-            if len(dial_feats) >= max_turns:
-                break
-        input_feats.append(dial_feats)
-    del dial_feats
-
-    # Perform turn level padding
-    input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
-                 for dial in input_feats]
-    if 'token_type_ids' in input_feats[0][0]:
-        token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
-                          for dial in input_feats]
-    else:
-        token_type_ids = None
-    if 'attention_mask' in input_feats[0][0]:
-        attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
-                          for dial in input_feats]
-    else:
-        attention_mask = None
-    del input_feats
-
-    # Create torch data tensors
-    features['input_ids'] = torch.tensor(input_ids)
-    features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
-    features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
-    del input_ids, token_type_ids, attention_mask
-
-    # Extract all informable and requestable slots from the ontology
-    informable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain]
-                        if ontology[domain][slot]['possible_values']
-                        and ontology[domain][slot]['possible_values'] != ['?']]
-    requestable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain]
-                         if '?' in ontology[domain][slot]['possible_values']]
-    for slot in requestable_slots:
-        domain, slot = slot.split('-', 1)
-        ontology[domain][slot]['possible_values'].remove('?')
-
-    # Extract a list of domains from the ontology slots
-    domains = list(set(informable_slots + requestable_slots))
-    domains = list(set([slot.split('-', 1)[0] for slot in domains]))
-
-    # Create slot labels
-    for domslot in tqdm(informable_slots):
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial:
-                value = [v for d, substate in turn['state'].items() for s, v in substate.items()
-                         if f'{d}-{s}' == domslot]
-                value = value[0] if value else 'none'
-                domain, slot = domslot.split('-', 1)
-                if value in ontology[domain][slot]['possible_values']:
-                    value = ontology[domain][slot]['possible_values'].index(value)
-                else:
-                    value = -1  # If value is not in ontology then we do not penalise the model
-                labs.append(value)
-                if len(labs) >= max_turns:
-                    break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-
-        labels = torch.tensor(labels)
-        features['state_labels-' + domslot] = labels
-
-    # Create requestable slot labels
-    for domslot in tqdm(requestable_slots):
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial:
-                domain, slot = domslot.split('-', 1)
-                acts = [act['intent'] for act in turn['dialogue_acts']
-                        if act['domain'] == domain and act['slot'] == slot]
-                if acts:
-                    act_ = acts[0]
-                    if act_ == 'request':
-                        labs.append(1)
-                    else:
-                        labs.append(0)
-                else:
-                    labs.append(0)
-                if len(labs) >= max_turns:
-                    break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-
-        labels = torch.tensor(labels)
-        features['request_labels-' + domslot] = labels
-
-    # General act labels (1-goodbye, 2-thank you)
-    labels = []
-    for dial in tqdm(data):
-        labs = []
-        for turn in dial:
-            acts = [act['intent'] for act in turn['dialogue_acts'] if act['intent'] in ['bye', 'thank']]
-            if acts:
-                if 'bye' in acts:
-                    labs.append(1)
-                else:
-                    labs.append(2)
-            else:
-                labs.append(0)
-            if len(labs) >= max_turns:
-                break
-        labs = labs + [-1] * (max_turns - len(labs))
-        labels.append(labs)
-
-    labels = torch.tensor(labels)
-    features['general_act_labels'] = labels
-
-    # Create active domain labels
-    for domain in tqdm(domains):
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial:
-                if domain in turn['active_domains']:
-                    labs.append(1)
-                else:
-                    labs.append(0)
-                if len(labs) >= max_turns:
-                    break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-
-        labels = torch.tensor(labels)
-        features['active_domain_labels-' + domain] = labels
-
-    del labels
-
-    return features
-
-
-class UnifiedFormatDataset(Dataset):
-    """
-    Class for preprocessing, and storing data easily from the Convlab3 unified format.
-
-    Attributes:
-        dataset_dict (dict): Dictionary containing all the data in dataset
-        ontology (dict): Set of all domain-slot-value triplets in the ontology of the model
-        features (dict): Set of numeric features containing all inputs and labels formatted for the SetSUMBT model
-    """
-    def __init__(self,
-                 dataset_name: str,
-                 set_type: str,
-                 tokenizer: PreTrainedTokenizer,
-                 max_turns: int = 12,
-                 max_seq_len: int = 64,
-                 train_ratio: float = 1.0,
-                 seed: int = 0,
-                 data: dict = None,
-                 ontology: dict = None):
-        """
-        Args:
-            dataset_name (str): Name of the dataset/s to load (multiple to be seperated by +)
-            set_type (str): Subset of the dataset to load (train, validation or test)
-            tokenizer (transformers tokenizer): Tokenizer for the encoder model used
-            max_turns (int): Maximum numbers of turns in a dialogue
-            max_seq_len (int): Maximum number of tokens in a dialogue turn
-            train_ratio (float): Fraction of training data to use during training
-            seed (int): Seed governing random order of ids for subsampling
-            data (dict): Dataset features for loading from dict
-            ontology (dict): Ontology dict for loading from dict
-        """
-        if data is not None:
-            self.ontology = ontology
-            self.features = data
-        else:
-            if '+' in dataset_name:
-                dataset_args = [{"dataset_name": name} for name in dataset_name.split('+')]
-            else:
-                dataset_args = [{"dataset_name": dataset_name}]
-            if train_ratio != 1.0:
-                for dataset_args_ in dataset_args:
-                    dataset_args_['dial_ids_order'] = seed
-                    dataset_args_['split2ratio'] = {'train': train_ratio, 'validation': train_ratio}
-            self.dataset_dicts = [load_dataset(**dataset_args_) for dataset_args_ in dataset_args]
-            self.ontology = get_ontology_slots(dataset_name)
-            values = [get_values_from_data(dataset) for dataset in self.dataset_dicts]
-            self.ontology = ontology_add_values(self.ontology, combine_value_sets(values))
-            self.ontology = ontology_add_requestable_slots(self.ontology, get_requestable_slots(self.dataset_dicts))
-
-            data = [load_dst_data(dataset_dict, data_split=set_type, speaker='all',
-                                  dialogue_acts=True, split_to_turn=False)
-                    for dataset_dict in self.dataset_dicts]
-            data_list = [data_[set_type] for data_ in data]
-
-            data = []
-            for data_ in data_list:
-                data += extract_dialogues(data_)
-            self.features = convert_examples_to_features(data, self.ontology, tokenizer, max_turns, max_seq_len)
-
-    def __getitem__(self, index: int) -> dict:
-        """
-        Obtain dialogues with specific ids from dataset
-
-        Args:
-            index (int/list/tensor): Index/indices of dialogues to get
-
-        Returns:
-            features (dict): All inputs and labels required to train the model
-        """
-        return {label: self.features[label][index] for label in self.features
-                if self.features[label] is not None}
-
-    def __len__(self):
-        """
-        Get number of dialogues in the dataset
-
-        Returns:
-            len (int): Number of dialogues in the dataset object
-        """
-        return self.features['input_ids'].size(0)
-
-    def resample(self, size: int = None) -> Dataset:
-        """
-        Resample subset of the dataset
-
-        Args:
-            size (int): Number of dialogues to sample
-
-        Returns:
-            self (Dataset): Dataset object
-        """
-        # If no subset size is specified we resample a set with the same size as the full dataset
-        n_dialogues = self.__len__()
-        if not size:
-            size = n_dialogues
-
-        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
-        self.features = self.__getitem__(dialogues)
-        
-        return self
-
-    def to(self, device):
-        """
-        Map all data to a device
-
-        Args:
-            device (torch device): Device to map data to
-        """
-        self.device = device
-        self.features = {label: self.features[label].to(device) for label in self.features
-                         if self.features[label] is not None}
-
-    @classmethod
-    def from_datadict(cls, data: dict, ontology: dict):
-        return cls(None, None, None, data=data, ontology=ontology)
-
-
-def get_dataloader(dataset_name: str,
-                   set_type: str,
-                   batch_size: int,
-                   tokenizer: PreTrainedTokenizer,
-                   max_turns: int = 12,
-                   max_seq_len: int = 64,
-                   device='cpu',
-                   resampled_size: int = None,
-                   train_ratio: float = 1.0,
-                   seed: int = 0) -> DataLoader:
-    '''
-    Module to create torch dataloaders
-
-    Args:
-        dataset_name (str): Name of the dataset to load
-        set_type (str): Subset of the dataset to load (train, validation or test)
-        batch_size (int): Batch size for the dataloader
-        tokenizer (transformers tokenizer): Tokenizer for the encoder model used
-        max_turns (int): Maximum numbers of turns in a dialogue
-        max_seq_len (int): Maximum number of tokens in a dialogue turn
-        device (torch device): Device to map data to
-        resampled_size (int): Number of dialogues to sample
-        train_ratio (float): Ratio of training data to use for training
-        seed (int): Seed governing random order of ids for subsampling
-
-    Returns:
-        loader (torch dataloader): Dataloader to train and evaluate the setsumbt model
-    '''
-    data = UnifiedFormatDataset(dataset_name, set_type, tokenizer, max_turns, max_seq_len, train_ratio=train_ratio,
-                                seed=seed)
-    data.to(device)
-
-    if resampled_size:
-        data = data.resample(resampled_size)
-
-    if set_type in ['test', 'validation']:
-        sampler = SequentialSampler(data)
-    else:
-        sampler = RandomSampler(data)
-    loader = DataLoader(data, sampler=sampler, batch_size=batch_size)
-
-    return loader
-
-
-def change_batch_size(loader: DataLoader, batch_size: int) -> DataLoader:
-    """
-    Change the batch size of a preloaded loader
-
-    Args:
-        loader (DataLoader): Dataloader to train and evaluate the setsumbt model
-        batch_size (int): Batch size for the dataloader
-
-    Returns:
-        loader (DataLoader): Dataloader to train and evaluate the setsumbt model
-    """
-
-    if 'SequentialSampler' in str(loader.sampler):
-        sampler = SequentialSampler(loader.dataset)
-    else:
-        sampler = RandomSampler(loader.dataset)
-    loader = DataLoader(loader.dataset, sampler=sampler, batch_size=batch_size)
-
-    return loader
diff --git a/convlab/dst/setsumbt/dataset/utils.py b/convlab/dst/setsumbt/dataset/utils.py
deleted file mode 100644
index dded04aac31102d9e81f5de288b9399b0c190d0c..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/dataset/utils.py
+++ /dev/null
@@ -1,404 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Convlab3 Unified dataset data processing utilities"""
-
-from convlab.util import load_ontology, load_dst_data, load_nlu_data
-from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
-
-
-def get_ontology_slots(dataset_name: str) -> dict:
-    """
-    Function to extract slots, slot descriptions and categorical slot values from the dataset ontology.
-
-    Args:
-        dataset_name (str): Dataset name
-
-    Returns:
-        ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
-    """
-    dataset_names = dataset_name.split('+') if '+' in dataset_name else [dataset_name]
-    ontology_slots = dict()
-    for dataset_name in dataset_names:
-        ontology = load_ontology(dataset_name)
-        domains = [domain for domain in ontology['domains'] if domain not in ['booking', 'general']]
-        for domain in domains:
-            domain_name = DOMAINS_MAP.get(domain, domain.lower())
-            if domain_name not in ontology_slots:
-                ontology_slots[domain_name] = dict()
-            for slot, slot_info in ontology['domains'][domain]['slots'].items():
-                if slot not in ontology_slots[domain_name]:
-                    ontology_slots[domain_name][slot] = {'description': slot_info['description'],
-                                                         'possible_values': list()}
-                if slot_info['is_categorical']:
-                    ontology_slots[domain_name][slot]['possible_values'] += slot_info['possible_values']
-
-                ontology_slots[domain_name][slot]['possible_values'] = list(set(ontology_slots[domain_name][slot]['possible_values']))
-
-    return ontology_slots
-
-
-def get_values_from_data(dataset: dict) -> dict:
-    """
-    Function to extract slots, slot descriptions and categorical slot values from the dataset ontology.
-
-    Args:
-        dataset (dict): Dataset dictionary obtained using the load_dataset function
-
-    Returns:
-        value_sets (dict): Dictionary containing possible values obtained from dataset
-    """
-    data = load_dst_data(dataset, data_split='all', speaker='user')
-    value_sets = {}
-    for set_type, dataset in data.items():
-        for turn in dataset:
-            for domain, substate in turn['state'].items():
-                domain_name = DOMAINS_MAP.get(domain, domain.lower())
-                if domain not in value_sets:
-                    value_sets[domain_name] = {}
-                for slot, value in substate.items():
-                    if slot not in value_sets[domain_name]:
-                        value_sets[domain_name][slot] = []
-                    if value and value not in value_sets[domain_name][slot]:
-                        value_sets[domain_name][slot].append(value)
-
-    return clean_values(value_sets)
-
-
-def combine_value_sets(value_sets: list) -> dict:
-    """
-    Function to combine value sets extracted from different datasets
-
-    Args:
-        value_sets (list): List of value sets extracted using the get_values_from_data function
-
-    Returns:
-        value_set (dict): Dictionary containing possible values obtained from datasets
-    """
-    value_set = value_sets[0]
-    for _value_set in value_sets[1:]:
-        for domain, domain_info in _value_set.items():
-            for slot, possible_values in domain_info.items():
-                if domain not in value_set:
-                    value_set[domain] = dict()
-                if slot not in value_set[domain]:
-                    value_set[domain][slot] = list()
-                value_set[domain][slot] += _value_set[domain][slot]
-                value_set[domain][slot] = list(set(value_set[domain][slot]))
-
-    return value_set
-
-
-def clean_values(value_sets: dict, value_map: dict = VALUE_MAP) -> dict:
-    """
-    Function to clean up the possible value sets extracted from the states in the dataset
-
-    Args:
-        value_sets (dict): Dictionary containing possible values obtained from dataset
-        value_map (dict): Label map to avoid duplication and typos in values
-
-    Returns:
-        clean_vals (dict): Cleaned Dictionary containing possible values obtained from dataset
-    """
-    clean_vals = {}
-    for domain, subset in value_sets.items():
-        clean_vals[domain] = {}
-        for slot, values in subset.items():
-            # Remove pipe separated values
-            values = list(set([val.split('|', 1)[0] for val in values]))
-
-            # Map values using value_map
-            for old, new in value_map.items():
-                values = list(set([val.replace(old, new) for val in values]))
-
-            # Remove empty and dontcare from possible value sets
-            values = [val for val in values if val not in ['', 'dontcare']]
-
-            # MultiWOZ specific value sets for quantity, time and boolean slots
-            if 'people' in slot or 'duration' in slot or 'stay' in slot:
-                values = QUANTITIES
-            elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
-                values = TIME
-            elif 'parking' in slot or 'internet' in slot:
-                values = ['yes', 'no']
-
-            clean_vals[domain][slot] = values
-
-    return clean_vals
-
-
-def ontology_add_values(ontology_slots: dict, value_sets: dict) -> dict:
-    """
-    Add value sets obtained from the dataset to the ontology
-    Args:
-        ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
-        value_sets (dict): Cleaned Dictionary containing possible values obtained from dataset
-
-    Returns:
-        ontology_slots (dict): Ontology dictionary containing slots, slot descriptions and possible value sets
-    """
-    ontology = {}
-    for domain in sorted(ontology_slots):
-        ontology[domain] = {}
-        for slot in sorted(ontology_slots[domain]):
-            if not ontology_slots[domain][slot]['possible_values']:
-                if domain in value_sets:
-                    if slot in value_sets[domain]:
-                        ontology_slots[domain][slot]['possible_values'] = value_sets[domain][slot]
-            if ontology_slots[domain][slot]['possible_values']:
-                values = sorted(ontology_slots[domain][slot]['possible_values'])
-                ontology_slots[domain][slot]['possible_values'] = ['none', 'do not care'] + values
-
-            ontology[domain][slot] = ontology_slots[domain][slot]
-
-    return ontology
-
-
-def get_requestable_slots(datasets: list) -> dict:
-    """
-    Function to get set of requestable slots from the dataset action labels.
-    Args:
-        dataset (dict): Dataset dictionary obtained using the load_dataset function
-
-    Returns:
-        slots (dict): Dictionary containing requestable domain-slot pairs
-    """
-    datasets = [load_nlu_data(dataset, data_split='all', speaker='user') for dataset in datasets]
-
-    slots = {}
-    for data in datasets:
-        for set_type, subset in data.items():
-            for turn in subset:
-                requests = [act for act in turn['dialogue_acts']['categorical'] if act['intent'] == 'request']
-                requests += [act for act in turn['dialogue_acts']['non-categorical'] if act['intent'] == 'request']
-                requests += [act for act in turn['dialogue_acts']['binary'] if act['intent'] == 'request']
-                requests = [(act['domain'], act['slot']) for act in requests]
-                for domain, slot in requests:
-                    domain_name = DOMAINS_MAP.get(domain, domain.lower())
-                    if domain_name not in slots:
-                        slots[domain_name] = []
-                    slots[domain_name].append(slot)
-
-    slots = {domain: list(set(slot_list)) for domain, slot_list in slots.items()}
-
-    return slots
-
-
-def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict) -> dict:
-    """
-    Add requestable slots obtained from the dataset to the ontology
-    Args:
-        ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
-        requestable_slots (dict): Dictionary containing requestable domain-slot pairs
-
-    Returns:
-        ontology_slots (dict): Ontology dictionary containing slots, slot descriptions and
-        possible value sets including requests
-    """
-    for domain in ontology_slots:
-        for slot in ontology_slots[domain]:
-            if domain in requestable_slots:
-                if slot in requestable_slots[domain]:
-                    ontology_slots[domain][slot]['possible_values'].append('?')
-
-    return ontology_slots
-
-
-def extract_turns(dialogue: list) -> list:
-    """
-    Extract the required information from the data provided by unified loader
-    Args:
-        dialogue (list): List of turns within a dialogue
-
-    Returns:
-        turns (list): List of turns within a dialogue
-    """
-    turns = []
-    turn_info = {}
-    for turn in dialogue:
-        if turn['speaker'] == 'system':
-            turn_info['system_utterance'] = turn['utterance']
-
-        # System utterance in the first turn is always empty as conversation is initiated by the user
-        if turn['utt_idx'] == 1:
-            turn_info['system_utterance'] = ''
-
-        if turn['speaker'] == 'user':
-            turn_info['user_utterance'] = turn['utterance']
-
-            # Inform acts not required by model
-            turn_info['dialogue_acts'] = [act for act in turn['dialogue_acts']['categorical']
-                                          if act['intent'] not in ['inform']]
-            turn_info['dialogue_acts'] += [act for act in turn['dialogue_acts']['non-categorical']
-                                           if act['intent'] not in ['inform']]
-            turn_info['dialogue_acts'] += [act for act in turn['dialogue_acts']['binary']
-                                           if act['intent'] not in ['inform']]
-
-            turn_info['state'] = turn['state']
-
-        if 'system_utterance' in turn_info and 'user_utterance' in turn_info:
-            turns.append(turn_info)
-            turn_info = {}
-
-    return turns
-
-
-def clean_states(turns: list) -> list:
-    """
-    Clean the state within each turn of a dialogue (cleaning values and mapping to options used in ontology)
-    Args:
-        turns (list): List of turns within a dialogue
-
-    Returns:
-        clean_turns (list): List of turns within a dialogue
-    """
-    clean_turns = []
-    for turn in turns:
-        clean_state = {}
-        clean_acts = []
-        for act in turn['dialogue_acts']:
-            domain = act['domain']
-            act['domain'] = DOMAINS_MAP.get(domain, domain.lower())
-            clean_acts.append(act)
-        for domain, subset in turn['state'].items():
-            domain_name = DOMAINS_MAP.get(domain, domain.lower())
-            clean_state[domain_name] = {}
-            for slot, value in subset.items():
-                # Remove pipe separated values
-                value = value.split('|', 1)[0]
-
-                # Map values using value_map
-                for old, new in VALUE_MAP.items():
-                    value = value.replace(old, new)
-
-                # Map dontcare to "do not care" and empty to 'none'
-                value = value.replace('dontcare', 'do not care')
-                value = value if value else 'none'
-
-                # Map quantity values to the integer quantity value
-                if 'people' in slot or 'duration' in slot or 'stay' in slot:
-                    try:
-                        if value not in ['do not care', 'none']:
-                            value = int(value)
-                            value = str(value) if value < 10 else QUANTITIES[-1]
-                    except:
-                        value = value
-                # Map time values to the most appropriate value in the standard time set
-                elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
-                    try:
-                        if value not in ['do not care', 'none']:
-                            # Strip after/before from time value
-                            value = value.replace('after ', '').replace('before ', '')
-                            # Extract hours and minutes from different possible formats
-                            if ':' not in value and len(value) == 4:
-                                h, m = value[:2], value[2:]
-                            elif len(value) == 1:
-                                h = int(value)
-                                m = 0
-                            elif 'pm' in value:
-                                h = int(value.replace('pm', '')) + 12
-                                m = 0
-                            elif 'am' in value:
-                                h = int(value.replace('pm', ''))
-                                m = 0
-                            elif ':' in value:
-                                h, m = value.split(':')
-                            elif ';' in value:
-                                h, m = value.split(';')
-                            # Map to closest 5 minutes
-                            if int(m) % 5 != 0:
-                                m = round(int(m) / 5) * 5
-                                h = int(h)
-                                if m == 60:
-                                    m = 0
-                                    h += 1
-                                if h >= 24:
-                                    h -= 24
-                            # Set in standard 24 hour format
-                            h, m = int(h), int(m)
-                            value = '%02i:%02i' % (h, m)
-                    except:
-                        value = value
-                # Map boolean slots to yes/no value
-                elif 'parking' in slot or 'internet' in slot:
-                    if value not in ['do not care', 'none']:
-                        if value == 'free':
-                            value = 'yes'
-                        elif True in [v in value.lower() for v in ['yes', 'no']]:
-                            value = [v for v in ['yes', 'no'] if v in value][0]
-
-                clean_state[domain_name][slot] = value
-        turn['state'] = clean_state
-        turn['dialogue_acts'] = clean_acts
-        clean_turns.append(turn)
-
-    return clean_turns
-
-
-def get_active_domains(turns: list) -> list:
-    """
-    Get active domains at each turn in a dialogue
-    Args:
-        turns (list): List of turns within a dialogue
-
-    Returns:
-        turns (list): List of turns within a dialogue
-    """
-    for turn_id in range(len(turns)):
-        # At first turn all domains with not none values in the state are active
-        if turn_id == 0:
-            domains = [d for d, substate in turns[turn_id]['state'].items() for s, v in substate.items() if v != 'none']
-            domains += [act['domain'] for act in turns[turn_id]['dialogue_acts'] if act['domain'] in turns[turn_id]['state']]
-            domains = [DOMAINS_MAP.get(domain, domain.lower()) for domain in domains]
-            turns[turn_id]['active_domains'] = list(set(domains))
-        else:
-            # Use changes in domains to identify active domains
-            domains = []
-            for domain, substate in turns[turn_id]['state'].items():
-                domain_name = DOMAINS_MAP.get(domain, domain.lower())
-                for slot, value in substate.items():
-                    if value != turns[turn_id - 1]['state'][domain][slot]:
-                        val = value
-                    else:
-                        val = 'none'
-                    if value == 'none':
-                        val = 'none'
-                    if val != 'none':
-                        domains.append(domain_name)
-            # Add all domains activated by a user action
-            domains += [act['domain'] for act in turns[turn_id]['dialogue_acts']
-                        if act['domain'] in turns[turn_id]['state']]
-            turns[turn_id]['active_domains'] = list(set(domains))
-
-    return turns
-
-
-def extract_dialogues(data: list) -> list:
-    """
-    Extract all dialogues from dataset
-    Args:
-        data (list): List of all dialogues in a subset of the data
-
-    Returns:
-        dialogues (list): List of all extracted dialogues
-    """
-    dialogues = []
-    for dial in data:
-        turns = extract_turns(dial['turns'])
-        turns = clean_states(turns)
-        turns = get_active_domains(turns)
-        dialogues.append(turns)
-
-    return dialogues
diff --git a/convlab/dst/setsumbt/dataset/value_maps.py b/convlab/dst/setsumbt/dataset/value_maps.py
deleted file mode 100644
index 619600a7b0a57096918058ff117aa2ca5aac864a..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/dataset/value_maps.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Convlab3 Unified dataset value maps"""
-
-
-# MultiWOZ specific label map to avoid duplication and typos in values
-VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast',
-             'cityroomz': 'city roomz', '  ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott',
-             'worth house': 'the worth house', 'alesbray lodge guest house': 'aylesbray lodge',
-             'huntingdon hotel': 'huntingdon marriott hotel', 'huntingd': 'huntingdon marriott hotel',
-             'jamaicanchinese': 'chinese', 'barbequemodern european': 'modern european',
-             'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"}
-
-
-# Domain map for SGD and TM Data
-DOMAINS_MAP = {'Alarm_1': 'alarm', 'Banks_1': 'banks', 'Banks_2': 'banks', 'Buses_1': 'bus', 'Buses_2': 'bus',
-               'Buses_3': 'bus', 'Calendar_1': 'calendar', 'Events_1': 'events', 'Events_2': 'events',
-               'Events_3': 'events', 'Flights_1': 'flights', 'Flights_2': 'flights', 'Flights_3': 'flights',
-               'Flights_4': 'flights', 'Homes_1': 'homes', 'Homes_2': 'homes', 'Hotels_1': 'hotel',
-               'Hotels_2': 'hotel', 'Hotels_3': 'hotel', 'Hotels_4': 'hotel', 'Media_1': 'media',
-               'Media_2': 'media', 'Media_3': 'media', 'Messaging_1': 'messaging', 'Movies_1': 'movies',
-               'Movies_2': 'movies', 'Movies_3': 'movies', 'Music_1': 'music', 'Music_2': 'music', 'Music_3': 'music',
-               'Payment_1': 'payment', 'RentalCars_1': 'rentalcars', 'RentalCars_2': 'rentalcars',
-               'RentalCars_3': 'rentalcars', 'Restaurants_1': 'restaurant', 'Restaurants_2': 'restaurant',
-               'RideSharing_1': 'ridesharing', 'RideSharing_2': 'ridesharing', 'Services_1': 'services',
-               'Services_2': 'services', 'Services_3': 'services', 'Services_4': 'services', 'Trains_1': 'train',
-               'Travel_1': 'travel', 'Weather_1': 'weather', 'movie_ticket': 'movies',
-               'restaurant_reservation': 'restaurant', 'coffee_ordering': 'coffee', 'pizza_ordering': 'takeout',
-               'auto_repair': 'car_repairs', 'flights': 'flights', 'food-ordering': 'takeout', 'hotels': 'hotel',
-               'movies': 'movies', 'music': 'music', 'restaurant-search': 'restaurant', 'sports': 'sports',
-               'movie': 'movies'}
-
-
-# Generic value sets for quantity and time slots
-QUANTITIES = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10 or more']
-TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)]
-TIME = ['%02i:%02i' % t for l in TIME for t in l]
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/distillation_setup.py b/convlab/dst/setsumbt/distillation_setup.py
index e54ccaa224d80aa7a009f7cf538900e289d5dd4f..e0d87bb964041cae19d58351cd6b31f6d836f125 100644
--- a/convlab/dst/setsumbt/distillation_setup.py
+++ b/convlab/dst/setsumbt/distillation_setup.py
@@ -1,57 +1,42 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Get ensemble predictions and build distillation dataloaders"""
-
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
 import os
-import json
 
 import torch
-from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+import transformers
+from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
+from transformers import RobertaConfig, BertConfig
 from tqdm import tqdm
 
-from convlab.dst.setsumbt.dataset.unified_format import UnifiedFormatDataset
+import convlab
+from convlab.dst.setsumbt.multiwoz.dataset.multiwoz21 import EnsembleMultiWoz21
 from convlab.dst.setsumbt.modeling import EnsembleSetSUMBT
-from convlab.dst.setsumbt.modeling import training
 
-DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+DEVICE = 'cuda'
 
 
-def main():
+def args_parser():
     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
     parser.add_argument('--model_path', type=str)
     parser.add_argument('--model_type', type=str)
     parser.add_argument('--set_type', type=str)
     parser.add_argument('--batch_size', type=int)
+    parser.add_argument('--ensemble_size', type=int)
     parser.add_argument('--reduction', type=str, default='mean')
     parser.add_argument('--get_ensemble_distributions', action='store_true')
     parser.add_argument('--build_dataloaders', action='store_true')
-    args = parser.parse_args()
+    
+    return parser.parse_args()
+
+
+def main():
+    args = args_parser()
 
     if args.get_ensemble_distributions:
         get_ensemble_distributions(args)
     elif args.build_dataloaders:
         path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
         data = torch.load(path)
-
-        reader = open(os.path.join(args.model_path, 'database', f'{args.set_type}.json'), 'r')
-        ontology = json.load(reader)
-        reader.close()
-
-        loader = get_loader(data, ontology, args.set_type, args.batch_size)
+        loader = get_loader(data, args.set_type, args.batch_size)
 
         path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
         torch.save(loader, path)
@@ -59,22 +44,10 @@ def main():
         raise NameError("NotImplemented")
 
 
-def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size: int = 3) -> DataLoader:
-    """
-    Build dataloader from ensemble prediction data
-
-    Args:
-        data: Dictionary of ensemble predictions
-        ontology: Data ontology
-        set_type: Data subset (train/validation/test)
-        batch_size: Number of dialogues per batch
-
-    Returns:
-        loader: Data loader object
-    """
+def get_loader(data, set_type='train', batch_size=3):
     data = flatten_data(data)
     data = do_label_padding(data)
-    data = UnifiedFormatDataset.from_datadict(data, ontology)
+    data = EnsembleMultiWoz21(data)
     if set_type == 'train':
         sampler = RandomSampler(data)
     else:
@@ -84,16 +57,7 @@ def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size:
     return loader
 
 
-def do_label_padding(data: dict) -> dict:
-    """
-    Add padding to the ensemble predictions (used as labels in distillation)
-
-    Args:
-        data: Dictionary of ensemble predictions
-
-    Returns:
-        data: Padded ensemble predictions
-    """
+def do_label_padding(data):
     if 'attention_mask' in data:
         dialogs, turns = torch.where(data['attention_mask'].sum(-1) == 0.0)
     else:
@@ -106,17 +70,13 @@ def do_label_padding(data: dict) -> dict:
     return data
 
 
-def flatten_data(data: dict) -> dict:
-    """
-    Map data to flattened feature format used in training
-    Args:
-        data: Ensemble prediction data
-
-    Returns:
-        data: Flattened ensemble prediction data
-    """
+map_dict = {'belief_state': 'belief', 'greeting_act_belief': 'goodbye_belief',
+            'state_labels': 'labels', 'request_labels': 'request',
+            'domain_labels': 'active', 'greeting_labels': 'goodbye'}
+def flatten_data(data):
     data_new = dict()
     for label, feats in data.items():
+        label = map_dict.get(label, label)
         if type(feats) == dict:
             for label_, feats_ in feats.items():
                 data_new[label + '-' + label_] = feats_
@@ -127,11 +87,13 @@ def flatten_data(data: dict) -> dict:
 
 
 def get_ensemble_distributions(args):
-    """
-    Load data and get ensemble predictions
-    Args:
-        args: Runtime arguments
-    """
+    if args.model_type == 'roberta':
+        config = RobertaConfig
+    elif args.model_type == 'bert':
+        config = BertConfig
+    config = config.from_pretrained(args.model_path)
+    config.ensemble_size = args.ensemble_size
+
     device = DEVICE
 
     model = EnsembleSetSUMBT.from_pretrained(args.model_path)
@@ -145,7 +107,16 @@ def get_ensemble_distributions(args):
     dataloader = torch.load(dataloader)
     database = torch.load(database)
 
-    training.set_ontology_embeddings(model, database)
+    # Get slot and value embeddings
+    slots = {slot: val for slot, val in database.items()}
+    values = {slot: val[1] for slot, val in database.items()}
+    del database
+
+    # Load model ontology
+    model.add_slot_candidates(slots)
+    for slot in model.informable_slot_ids:
+        model.add_value_candidates(slot, values[slot], replace=True)
+    del slots, values
 
     print('Environment set up.')
 
@@ -154,24 +125,18 @@ def get_ensemble_distributions(args):
     attention_mask = []
     state_labels = {slot: [] for slot in model.informable_slot_ids}
     request_labels = {slot: [] for slot in model.requestable_slot_ids}
-    active_domain_labels = {domain: [] for domain in model.domain_ids}
-    general_act_labels = []
-
-    is_noisy = [] if 'is_noisy' in dataloader.dataset.features else None
-
+    domain_labels = {domain: [] for domain in model.domain_ids}
+    greeting_labels = []
     belief_state = {slot: [] for slot in model.informable_slot_ids}
-    request_probs = {slot: [] for slot in model.requestable_slot_ids}
-    active_domain_probs = {domain: [] for domain in model.domain_ids}
-    general_act_probs = []
+    request_belief = {slot: [] for slot in model.requestable_slot_ids}
+    domain_belief = {domain: [] for domain in model.domain_ids}
+    greeting_act_belief = []
     model.eval()
     for batch in tqdm(dataloader, desc='Batch:'):
         ids = batch['input_ids']
         tt_ids = batch['token_type_ids'] if 'token_type_ids' in batch else None
         mask = batch['attention_mask'] if 'attention_mask' in batch else None
 
-        if 'is_noisy' in batch:
-            is_noisy.append(batch['is_noisy'])
-
         input_ids.append(ids)
         token_type_ids.append(tt_ids)
         attention_mask.append(mask)
@@ -181,59 +146,57 @@ def get_ensemble_distributions(args):
         mask = mask.to(device) if mask is not None else None
 
         for slot in state_labels:
-            state_labels[slot].append(batch['state_labels-' + slot])
-        if model.config.predict_actions:
+            state_labels[slot].append(batch['labels-' + slot])
+        if model.config.predict_intents:
             for slot in request_labels:
-                request_labels[slot].append(batch['request_labels-' + slot])
+                request_labels[slot].append(batch['request-' + slot])
             for domain in domain_labels:
-                domain_labels[domain].append(batch['active_domain_labels-' + domain])
-            greeting_labels.append(batch['general_act_labels'])
+                domain_labels[domain].append(batch['active-' + domain])
+            greeting_labels.append(batch['goodbye'])
 
         with torch.no_grad():
-            p, p_req, p_dom, p_gen, _ = model(ids, mask, tt_ids, reduction=args.reduction)
+            p, p_req, p_dom, p_bye, _ = model(ids, mask, tt_ids,
+                                            reduction=args.reduction)
 
         for slot in belief_state:
             belief_state[slot].append(p[slot].cpu())
-        if model.config.predict_actions:
-            for slot in request_probs:
-                request_probs[slot].append(p_req[slot].cpu())
-            for domain in active_domain_probs:
-                active_domain_probs[domain].append(p_dom[domain].cpu())
-            general_act_probs.append(p_gen.cpu())
+        if model.config.predict_intents:
+            for slot in request_belief:
+                request_belief[slot].append(p_req[slot].cpu())
+            for domain in domain_belief:
+                domain_belief[domain].append(p_dom[domain].cpu())
+            greeting_act_belief.append(p_bye.cpu())
     
     input_ids = torch.cat(input_ids, 0) if input_ids[0] is not None else None
     token_type_ids = torch.cat(token_type_ids, 0) if token_type_ids[0] is not None else None
     attention_mask = torch.cat(attention_mask, 0) if attention_mask[0] is not None else None
-    is_noisy = torch.cat(is_noisy, 0) if is_noisy is not None else None
 
     state_labels = {slot: torch.cat(l, 0) for slot, l in state_labels.items()}
-    if model.config.predict_actions:
+    if model.config.predict_intents:
         request_labels = {slot: torch.cat(l, 0) for slot, l in request_labels.items()}
-        active_domain_labels = {domain: torch.cat(l, 0) for domain, l in active_domain_labels.items()}
-        general_act_labels = torch.cat(general_act_labels, 0)
+        domain_labels = {domain: torch.cat(l, 0) for domain, l in domain_labels.items()}
+        greeting_labels = torch.cat(greeting_labels, 0)
     
     belief_state = {slot: torch.cat(p, 0) for slot, p in belief_state.items()}
-    if model.config.predict_actions:
-        request_probs = {slot: torch.cat(p, 0) for slot, p in request_probs.items()}
-        active_domain_probs = {domain: torch.cat(p, 0) for domain, p in active_domain_probs.items()}
-        general_act_probs = torch.cat(general_act_probs, 0)
+    if model.config.predict_intents:
+        request_belief = {slot: torch.cat(p, 0) for slot, p in request_belief.items()}
+        domain_belief = {domain: torch.cat(p, 0) for domain, p in domain_belief.items()}
+        greeting_act_belief = torch.cat(greeting_act_belief, 0)
 
     data = {'input_ids': input_ids}
     if token_type_ids is not None:
         data['token_type_ids'] = token_type_ids
     if attention_mask is not None:
         data['attention_mask'] = attention_mask
-    if is_noisy is not None:
-        data['is_noisy'] = is_noisy
     data['state_labels'] = state_labels
     data['belief_state'] = belief_state
-    if model.config.predict_actions:
+    if model.config.predict_intents:
         data['request_labels'] = request_labels
-        data['active_domain_labels'] = active_domain_labels
-        data['general_act_labels'] = general_act_labels
-        data['request_probs'] = request_probs
-        data['active_domain_probs'] = active_domain_probs
-        data['general_act_probs'] = general_act_probs
+        data['domain_labels'] = domain_labels
+        data['greeting_labels'] = greeting_labels
+        data['request_belief'] = request_belief
+        data['domain_belief'] = domain_belief
+        data['greeting_act_belief'] = greeting_act_belief
 
     file = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
     torch.save(data, file)
diff --git a/convlab/dst/setsumbt/do/calibration.py b/convlab/dst/setsumbt/do/calibration.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ee058eca882ce7e10937f9640d143b88e57f5e
--- /dev/null
+++ b/convlab/dst/setsumbt/do/calibration.py
@@ -0,0 +1,481 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Run SetSUMBT Calibration"""
+
+import logging
+import random
+import os
+from shutil import copy2 as copy
+
+import torch
+from transformers import (BertModel, BertConfig, BertTokenizer,
+                          RobertaModel, RobertaConfig, RobertaTokenizer,
+                          AdamW, get_linear_schedule_with_warmup)
+from tqdm import tqdm, trange
+from tensorboardX import SummaryWriter
+from torch.distributions import Categorical
+
+from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
+from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
+from convlab.dst.setsumbt.multiwoz import multiwoz21
+from convlab.dst.setsumbt.multiwoz import ontology as embeddings
+from convlab.dst.setsumbt.utils import get_args, upload_local_directory_to_gcs, update_args
+from convlab.dst.setsumbt.modeling import calibration_utils
+from convlab.dst.setsumbt.modeling import ensemble_utils
+from convlab.dst.setsumbt.loss.ece import ece, jg_ece, l2_acc
+
+
+# Dataset modules selectable through args.dataset (looked up in main)
+DATASETS = {
+    'multiwoz21': multiwoz21
+}
+# Per model type: (SetSUMBT model, candidate encoder, config class, tokenizer), unpacked in that order by main
+MODELS = {
+    'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
+    'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
+}
+
+
+def main(args=None, config=None):
+    """Log test-set calibration metrics (ECE, accuracy, F1, L2) for a SetSUMBT model; args/config default to get_args(MODELS)."""
+    if args is None:
+        args, config = get_args(MODELS)
+
+    # Select Dataset object
+    if args.dataset in DATASETS:
+        Dataset = DATASETS[args.dataset]
+    else:
+        raise NameError('NotImplemented')
+
+    if args.model_type in MODELS:
+        SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
+    else:
+        raise NameError('NotImplemented')
+
+    # Set up output directory
+    OUTPUT_DIR = args.output_dir
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+    args.output_dir = OUTPUT_DIR
+    # exist_ok keeps reruns idempotent; 'dataloaders' is written to further down
+    os.makedirs(os.path.join(OUTPUT_DIR, 'predictions'), exist_ok=True)
+    os.makedirs(os.path.join(OUTPUT_DIR, 'dataloaders'), exist_ok=True)
+
+    paths = os.listdir(args.output_dir) if os.path.exists(
+        args.output_dir) else []
+    if 'pytorch_model.bin' in paths and 'config.json' in paths:
+        args.model_name_or_path = args.output_dir
+        config = ConfigClass.from_pretrained(args.model_name_or_path)
+    else:
+        paths = os.listdir(args.output_dir) if os.path.exists(
+            args.output_dir) else []
+        paths = [os.path.join(args.output_dir, p)
+                 for p in paths if 'checkpoint-' in p]
+        if paths:
+            paths = paths[0]
+            args.model_name_or_path = paths
+            config = ConfigClass.from_pretrained(args.model_name_or_path)
+
+    if args.ensemble_size > 0:
+        paths = os.listdir(args.output_dir) if os.path.exists(
+            args.output_dir) else []
+        paths = [os.path.join(args.output_dir, p)
+                 for p in paths if 'ensemble_' in p]
+        if paths:
+            args.model_name_or_path = args.output_dir
+            config = ConfigClass.from_pretrained(args.model_name_or_path)
+
+    args = update_args(args, config)
+
+    # Set up data directory
+    DATA_DIR = args.data_dir
+    Dataset.set_datadir(DATA_DIR)
+    embeddings.set_datadir(DATA_DIR)
+
+    if args.shrink_active_domains and args.dataset == 'multiwoz21':
+        Dataset.set_active_domains(
+            ['attraction', 'hotel', 'restaurant', 'taxi', 'train'])
+
+    # Download and preprocess
+    Dataset.create_examples(
+        args.max_turn_len, args.predict_intents, args.force_processing)
+
+    # Create logger
+    global logger
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+
+    formatter = logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+    if 'stream' not in args.logging_path:
+        fh = logging.FileHandler(args.logging_path)
+        fh.setLevel(logging.INFO)
+        fh.setFormatter(formatter)
+        logger.addHandler(fh)
+    else:
+        ch = logging.StreamHandler()
+        ch.setLevel(level=logging.INFO)
+        ch.setFormatter(formatter)
+        logger.addHandler(ch)
+
+    # Get device
+    if torch.cuda.is_available() and args.n_gpu > 0:
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+        args.n_gpu = 0
+
+    if args.n_gpu == 0:
+        args.fp16 = False
+
+    # Set up model training/evaluation
+    calibration_utils.set_logger(logger, None)
+    calibration_utils.set_seed(args)
+
+    if args.ensemble_size > 0:
+        ensemble_utils.set_logger(logger, None)
+        ensemble_utils.set_seed(args)
+
+    # Perform tasks (request_belief defaults to None so the ensemble path skips intent metrics)
+    request_belief = None
+    if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')):
+        pred = torch.load(os.path.join(
+            OUTPUT_DIR, 'predictions', 'test.predictions'))
+        labels = pred['labels']
+        belief_states = pred['belief_states']
+        if 'request_labels' in pred:
+            request_labels = pred['request_labels']
+            request_belief = pred['request_belief']
+            domain_labels = pred['domain_labels']
+            domain_belief = pred['domain_belief']
+            greeting_labels = pred['greeting_labels']
+            greeting_belief = pred['greeting_belief']
+        else:
+            request_belief = None
+        del pred
+    elif args.ensemble_size > 0:
+        # Get test batch loaders and ontology embeddings
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
+            test_slots = torch.load(os.path.join(
+                OUTPUT_DIR, 'database', 'test.db'))
+        else:
+            # Create Tokenizer and embedding model for Data Loaders and ontology
+            encoder = CandidateEncoderModel.from_pretrained(
+                config.candidate_embedding_model_name)
+            tokenizer = Tokenizer(config.candidate_embedding_model_name)
+            embeddings.get_slot_candidate_embeddings(
+                'test', args, tokenizer, encoder)
+            test_slots = torch.load(os.path.join(
+                OUTPUT_DIR, 'database', 'test.db'))
+
+        exists = False
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
+            test_dataloader = torch.load(os.path.join(
+                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+            if test_dataloader.batch_size == args.test_batch_size:
+                exists = True
+        if not exists:
+            tokenizer = Tokenizer(config.candidate_embedding_model_name)
+            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
+                                                     config.max_turn_len)
+            torch.save(test_dataloader, os.path.join(
+                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+
+        config, models = ensemble_utils.get_models(
+            args.model_name_or_path, device, ConfigClass, SetSumbtModel)
+
+        belief_states, labels = ensemble_utils.get_predictions(
+            args, models, device, test_dataloader, test_slots)
+        torch.save({'belief_states': belief_states, 'labels': labels},
+                   os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
+    else:
+        # Get test batch loaders and ontology embeddings
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
+            test_slots = torch.load(os.path.join(
+                OUTPUT_DIR, 'database', 'test.db'))
+        else:
+            # Create Tokenizer and embedding model for Data Loaders and ontology
+            encoder = CandidateEncoderModel.from_pretrained(
+                config.candidate_embedding_model_name)
+            tokenizer = Tokenizer(config.candidate_embedding_model_name)
+            embeddings.get_slot_candidate_embeddings(
+                'test', args, tokenizer, encoder)
+            test_slots = torch.load(os.path.join(
+                OUTPUT_DIR, 'database', 'test.db'))
+
+        exists = False
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
+            test_dataloader = torch.load(os.path.join(
+                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+            if test_dataloader.batch_size == args.test_batch_size:
+                exists = True
+        if not exists:
+            tokenizer = Tokenizer(config.candidate_embedding_model_name)
+            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
+                                                     config.max_turn_len)
+            torch.save(test_dataloader, os.path.join(
+                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+
+        # Initialise Model
+        model = SetSumbtModel.from_pretrained(
+            args.model_name_or_path, config=config)
+        model = model.to(device)
+
+        # Get slot and value embeddings
+        slots = {slot: test_slots[slot] for slot in test_slots}
+        values = {slot: test_slots[slot][1] for slot in test_slots}
+
+        # Load model ontology
+        model.add_slot_candidates(slots)
+        for slot in model.informable_slot_ids:
+            model.add_value_candidates(slot, values[slot], replace=True)
+
+        belief_states = calibration_utils.get_predictions(
+            args, model, device, test_dataloader)
+        belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = belief_states
+        out = {'belief_states': belief_states, 'labels': labels,
+               'request_belief': request_belief, 'request_labels': request_labels,
+               'domain_belief': domain_belief, 'domain_labels': domain_labels,
+               'greeting_belief': greeting_belief, 'greeting_labels': greeting_labels}
+        torch.save(out, os.path.join(
+            OUTPUT_DIR, 'predictions', 'test.predictions'))
+
+    # NOTE: the ensemble branch yields only belief_states/labels, so
+    # request_belief stays None and the intent metrics below are skipped.
+    # belief_states and labels are used as per-slot mappings from here on
+    # (iterated with .items() and indexed by slot).
+
+    # Calculate calibration metrics
+
+    jg = jg_ece(belief_states, labels, 10)
+    logger.info('Joint Goal ECE: %f' % jg)
+
+    binary_states = {}
+    for slot, p in belief_states.items():
+        shp = p.shape
+        p = p.reshape(-1, p.size(-1))
+        p_ = torch.ones(p.shape).to(p.device) * 1e-8
+        p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
+        binary_states[slot] = p_.reshape(shp)
+    jg = jg_ece(binary_states, labels, 10)
+    logger.info('Joint Goal Binary ECE: %f' % jg)
+
+    bs = {slot: torch.cat((p[:, :, 0].unsqueeze(-1), p[:, :, 1:].max(-1)
+                          [0].unsqueeze(-1)), -1) for slot, p in belief_states.items()}
+    ls = {}
+    for slot, l in labels.items():
+        y = torch.zeros((l.size(0), l.size(1))).to(l.device)
+        dials, turns = torch.where(l > 0)
+        y[dials, turns] = 1.0
+        dials, turns = torch.where(l < 0)
+        y[dials, turns] = -1.0
+        ls[slot] = y
+
+    jg = jg_ece(bs, ls, 10)
+    logger.info('Slot presence ECE: %f' % jg)
+
+    binary_states = {}
+    for slot, p in bs.items():
+        shp = p.shape
+        p = p.reshape(-1, p.size(-1))
+        p_ = torch.ones(p.shape).to(p.device) * 1e-8
+        p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
+        binary_states[slot] = p_.reshape(shp)
+    jg = jg_ece(binary_states, ls, 10)
+    logger.info('Slot presence Binary ECE: %f' % jg)
+
+    jg_acc = 0.0
+    padding = torch.cat([item.unsqueeze(-1)
+                        for _, item in labels.items()], -1).sum(-1) * -1.0
+    padding = (padding == len(labels))
+    padding = padding.reshape(-1)
+    for slot in belief_states:
+        topn = args.accuracy_topn
+        p_ = belief_states[slot]
+        gold = labels[slot]
+
+        if p_.size(-1) <= topn:
+            topn = p_.size(-1) - 1
+        if topn <= 0:
+            topn = 1
+
+        if topn > 1:
+            labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True)
+            labs = labs[:, :topn]
+        else:
+            labs = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1)
+        acc = [lab in s for lab, s, pad in zip(
+            gold.reshape(-1), labs, padding) if not pad]
+        acc = torch.tensor(acc).float()
+
+        jg_acc += acc
+
+    n_turns = jg_acc.size(0)
+    sl_acc = sum(jg_acc / len(belief_states)).float()
+    jg_acc = sum((jg_acc / len(belief_states)).int()).float()
+
+    sl_acc /= n_turns
+    jg_acc /= n_turns
+
+    logger.info('Joint Goal Accuracy: %f, Slot Accuracy %f' % (jg_acc, sl_acc))
+
+    l2 = l2_acc(belief_states, labels, remove_belief=False)
+    logger.info(f'Model L2 Norm Goal Accuracy: {l2}')
+    l2 = l2_acc(belief_states, labels, remove_belief=True)
+    logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}')
+
+    for slot in belief_states:
+        p = belief_states[slot]
+        p = p.reshape(-1, p.size(-1))
+        p = torch.cat(
+            (p[:, 0].unsqueeze(-1), p[:, 1:].max(-1)[0].unsqueeze(-1)), -1)
+        belief_states[slot] = p
+
+        l = labels[slot].reshape(-1)
+        l[l > 0] = 1
+        labels[slot] = l
+
+    f1 = 0.0
+    for slot in belief_states:
+        prd = belief_states[slot].argmax(-1)
+        tp = ((prd == 1) * (labels[slot] == 1)).sum()
+        fp = ((prd == 1) * (labels[slot] == 0)).sum()
+        fn = ((prd == 0) * (labels[slot] == 1)).sum()
+        if tp > 0:
+            f1 += tp / (tp + 0.5 * (fp + fn))
+    f1 /= len(belief_states)
+    logger.info(f'Trucated Goal F1 Score: {f1}')
+
+    l2 = l2_acc(belief_states, labels, remove_belief=False)
+    logger.info(f'Model L2 Norm Trucated Goal Accuracy: {l2}')
+    l2 = l2_acc(belief_states, labels, remove_belief=True)
+    logger.info(f'Binary Model L2 Norm Trucated Goal Accuracy: {l2}')
+
+    if request_belief is not None:
+        tp, fp, fn = 0.0, 0.0, 0.0
+        for slot in request_belief:
+            p = request_belief[slot]
+            l = request_labels[slot]
+
+            tp += (p.round().int() * (l == 1)).reshape(-1).float()
+            fp += (p.round().int() * (l == 0)).reshape(-1).float()
+            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
+        tp /= len(request_belief)
+        fp /= len(request_belief)
+        fn /= len(request_belief)
+        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
+        logger.info('Request F1 Score: %f' % f1.item())
+
+        for slot in request_belief:
+            p = request_belief[slot]
+            p = p.unsqueeze(-1)
+            p = torch.cat((1 - p, p), -1)
+            request_belief[slot] = p
+        jg = jg_ece(request_belief, request_labels, 10)
+        logger.info('Request Joint Goal ECE: %f' % jg)
+
+        binary_states = {}
+        for slot, p in request_belief.items():
+            shp = p.shape
+            p = p.reshape(-1, p.size(-1))
+            p_ = torch.ones(p.shape).to(p.device) * 1e-8
+            p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
+            binary_states[slot] = p_.reshape(shp)
+        jg = jg_ece(binary_states, request_labels, 10)
+        logger.info('Request Joint Goal Binary ECE: %f' % jg)
+
+        tp, fp, fn = 0.0, 0.0, 0.0
+        for dom in domain_belief:
+            p = domain_belief[dom]
+            l = domain_labels[dom]
+
+            tp += (p.round().int() * (l == 1)).reshape(-1).float()
+            fp += (p.round().int() * (l == 0)).reshape(-1).float()
+            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
+        tp /= len(domain_belief)
+        fp /= len(domain_belief)
+        fn /= len(domain_belief)
+        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
+        logger.info('Domain F1 Score: %f' % f1.item())
+
+        for dom in domain_belief:
+            p = domain_belief[dom]
+            p = p.unsqueeze(-1)
+            p = torch.cat((1 - p, p), -1)
+            domain_belief[dom] = p
+        jg = jg_ece(domain_belief, domain_labels, 10)
+        logger.info('Domain Joint Goal ECE: %f' % jg)
+
+        binary_states = {}
+        for slot, p in domain_belief.items():
+            shp = p.shape
+            p = p.reshape(-1, p.size(-1))
+            p_ = torch.ones(p.shape).to(p.device) * 1e-8
+            p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
+            binary_states[slot] = p_.reshape(shp)
+        jg = jg_ece(binary_states, domain_labels, 10)
+        logger.info('Domain Joint Goal Binary ECE: %f' % jg)
+
+        tp = ((greeting_belief.argmax(-1) > 0) *
+              (greeting_labels > 0)).reshape(-1).float().sum()
+        fp = ((greeting_belief.argmax(-1) > 0) *
+              (greeting_labels == 0)).reshape(-1).float().sum()
+        fn = ((greeting_belief.argmax(-1) == 0) *
+              (greeting_labels > 0)).reshape(-1).float().sum()
+        f1 = tp / (tp + 0.5 * (fp + fn))
+        logger.info('Greeting F1 Score: %f' % f1.item())
+
+        err = ece(greeting_belief.reshape(-1, greeting_belief.size(-1)),
+                  greeting_labels.reshape(-1), 10)
+        logger.info('Greetings ECE: %f' % err)
+
+        greeting_belief = greeting_belief.reshape(-1, greeting_belief.size(-1))
+        binary_states = torch.ones(greeting_belief.shape).to(
+            greeting_belief.device) * 1e-8
+        binary_states[range(greeting_belief.size(0)),
+                      greeting_belief.argmax(-1)] = 1.0 - 1e-8
+        err = ece(binary_states, greeting_labels.reshape(-1), 10)
+        logger.info('Greetings Binary ECE: %f' % err)
+
+        for slot in request_belief:
+            p = request_belief[slot].unsqueeze(-1)
+            request_belief[slot] = torch.cat((1 - p, p), -1)
+
+        l2 = l2_acc(request_belief, request_labels, remove_belief=False)
+        logger.info(f'Model L2 Norm Request Accuracy: {l2}')
+        l2 = l2_acc(request_belief, request_labels, remove_belief=True)
+        logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}')
+
+        for slot in domain_belief:
+            p = domain_belief[slot].unsqueeze(-1)
+            domain_belief[slot] = torch.cat((1 - p, p), -1)
+
+        l2 = l2_acc(domain_belief, domain_labels, remove_belief=False)
+        logger.info(f'Model L2 Norm Domain Accuracy: {l2}')
+        l2 = l2_acc(domain_belief, domain_labels, remove_belief=True)
+        logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}')
+
+        greeting_labels = {'bye': greeting_labels}
+        greeting_belief = {'bye': greeting_belief}
+
+        l2 = l2_acc(greeting_belief, greeting_labels, remove_belief=False)
+        logger.info(f'Model L2 Norm Greeting Accuracy: {l2}')
+        l2 = l2_acc(greeting_belief, greeting_labels, remove_belief=True)
+        logger.info(f'Binary Model L2 Norm Greeting Accuracy: {l2}')
+
+if __name__ == "__main__":
+    main()
diff --git a/convlab/dst/setsumbt/do/evaluate.py b/convlab/dst/setsumbt/do/evaluate.py
deleted file mode 100644
index 2fe351b3d5c2af187da58ffcc46e8184013bbcdb..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/do/evaluate.py
+++ /dev/null
@@ -1,296 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Run SetSUMBT Calibration"""
-
-import logging
-import os
-
-import torch
-from transformers import (BertModel, BertConfig, BertTokenizer,
-                          RobertaModel, RobertaConfig, RobertaTokenizer)
-
-from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
-from convlab.dst.setsumbt.dataset import unified_format
-from convlab.dst.setsumbt.dataset import ontology as embeddings
-from convlab.dst.setsumbt.utils import get_args, update_args
-from convlab.dst.setsumbt.modeling import evaluation_utils
-from convlab.dst.setsumbt.loss.uncertainty_measures import ece, jg_ece, l2_acc
-from convlab.dst.setsumbt.modeling import training
-
-
-# Available model
-MODELS = {
-    'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
-    'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
-}
-
-
-def main(args=None, config=None):
-    # Get arguments
-    if args is None:
-        args, config = get_args(MODELS)
-
-    if args.model_type in MODELS:
-        SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
-    else:
-        raise NameError('NotImplemented')
-
-    # Set up output directory
-    OUTPUT_DIR = args.output_dir
-    args.output_dir = OUTPUT_DIR
-    if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
-        os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
-
-    # Set pretrained model path to the trained checkpoint
-    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
-    if 'pytorch_model.bin' in paths and 'config.json' in paths:
-        args.model_name_or_path = args.output_dir
-        config = ConfigClass.from_pretrained(args.model_name_or_path)
-    else:
-        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
-        if paths:
-            paths = paths[0]
-            args.model_name_or_path = paths
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
-
-    args = update_args(args, config)
-
-    # Create logger
-    global logger
-    logger = logging.getLogger(__name__)
-    logger.setLevel(logging.INFO)
-
-    formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')
-
-    fh = logging.FileHandler(args.logging_path)
-    fh.setLevel(logging.INFO)
-    fh.setFormatter(formatter)
-    logger.addHandler(fh)
-
-    # Get device
-    if torch.cuda.is_available() and args.n_gpu > 0:
-        device = torch.device('cuda')
-    else:
-        device = torch.device('cpu')
-        args.n_gpu = 0
-
-    if args.n_gpu == 0:
-        args.fp16 = False
-
-    # Set up model training/evaluation
-    evaluation_utils.set_seed(args)
-
-    # Perform tasks
-    if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')):
-        pred = torch.load(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
-        state_labels = pred['state_labels']
-        belief_states = pred['belief_states']
-        if 'request_labels' in pred:
-            request_labels = pred['request_labels']
-            request_probs = pred['request_probs']
-            active_domain_labels = pred['active_domain_labels']
-            active_domain_probs = pred['active_domain_probs']
-            general_act_labels = pred['general_act_labels']
-            general_act_probs = pred['general_act_probs']
-        else:
-            request_probs = None
-        del pred
-    else:
-        # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size != args.test_batch_size:
-                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
-        else:
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
-                                                            args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                            config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
-        else:
-            encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
-            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
-                                                                  'test', args, tokenizer, encoder)
-
-        # Initialise Model
-        model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
-        model = model.to(device)
-
-        training.set_ontology_embeddings(model, test_slots)
-
-        belief_states = evaluation_utils.get_predictions(args, model, device, test_dataloader)
-        state_labels = belief_states[1]
-        request_probs = belief_states[2]
-        request_labels = belief_states[3]
-        active_domain_probs = belief_states[4]
-        active_domain_labels = belief_states[5]
-        general_act_probs = belief_states[6]
-        general_act_labels = belief_states[7]
-        belief_states = belief_states[0]
-        out = {'belief_states': belief_states, 'state_labels': state_labels, 'request_probs': request_probs,
-               'request_labels': request_labels, 'active_domain_probs': active_domain_probs,
-               'active_domain_labels': active_domain_labels, 'general_act_probs': general_act_probs,
-               'general_act_labels': general_act_labels}
-        torch.save(out, os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
-
-    # Calculate calibration metrics
-    jg = jg_ece(belief_states, state_labels, 10)
-    logger.info('Joint Goal ECE: %f' % jg)
-
-    jg_acc = 0.0
-    padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0
-    padding = (padding == len(state_labels))
-    padding = padding.reshape(-1)
-    for slot in belief_states:
-        p_ = belief_states[slot]
-        gold = state_labels[slot]
-
-        pred = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1)
-        acc = [lab in s for lab, s, pad in zip(gold.reshape(-1), pred, padding) if not pad]
-        acc = torch.tensor(acc).float()
-
-        jg_acc += acc
-
-    n_turns = jg_acc.size(0)
-    jg_acc = sum((jg_acc / len(belief_states)).int()).float()
-
-    jg_acc /= n_turns
-
-    logger.info(f'Joint Goal Accuracy: {jg_acc}')
-
-    l2 = l2_acc(belief_states, state_labels, remove_belief=False)
-    logger.info(f'Model L2 Norm Goal Accuracy: {l2}')
-    l2 = l2_acc(belief_states, state_labels, remove_belief=True)
-    logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}')
-
-    padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0
-    padding = (padding == len(state_labels))
-    padding = padding.reshape(-1)
-
-    tp, fp, fn, tn, n = 0.0, 0.0, 0.0, 0.0, 0.0
-    for slot in belief_states:
-        p_ = belief_states[slot]
-        gold = state_labels[slot].reshape(-1)
-        p_ = p_.reshape(-1, p_.size(-1))
-
-        p_ = p_[~padding].argmax(-1)
-        gold = gold[~padding]
-
-        tp += (p_ == gold)[gold != 0].int().sum().item()
-        fp += (p_ != 0)[gold == 0].int().sum().item()
-        fp += (p_ != gold)[gold != 0].int().sum().item()
-        fp -= (p_ == 0)[gold != 0].int().sum().item()
-        fn += (p_ == 0)[gold != 0].int().sum().item()
-        tn += (p_ == 0)[gold == 0].int().sum().item()
-        n += p_.size(0)
-
-    acc = (tp + tn) / n
-    prec = tp / (tp + fp)
-    rec = tp / (tp + fn)
-    f1 = 2 * (prec * rec) / (prec + rec)
-
-    logger.info(f"Slot Accuracy: {acc}, Slot F1: {f1}, Slot Precision: {prec}, Slot Recall: {rec}")
-
-    if request_probs is not None:
-        tp, fp, fn = 0.0, 0.0, 0.0
-        for slot in request_probs:
-            p = request_probs[slot]
-            l = request_labels[slot]
-
-            tp += (p.round().int() * (l == 1)).reshape(-1).float()
-            fp += (p.round().int() * (l == 0)).reshape(-1).float()
-            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
-        tp /= len(request_probs)
-        fp /= len(request_probs)
-        fn /= len(request_probs)
-        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
-        logger.info('Request F1 Score: %f' % f1.item())
-
-        for slot in request_probs:
-            p = request_probs[slot]
-            p = p.unsqueeze(-1)
-            p = torch.cat((1 - p, p), -1)
-            request_probs[slot] = p
-        jg = jg_ece(request_probs, request_labels, 10)
-        logger.info('Request Joint Goal ECE: %f' % jg)
-
-        tp, fp, fn = 0.0, 0.0, 0.0
-        for dom in active_domain_probs:
-            p = active_domain_probs[dom]
-            l = active_domain_labels[dom]
-
-            tp += (p.round().int() * (l == 1)).reshape(-1).float()
-            fp += (p.round().int() * (l == 0)).reshape(-1).float()
-            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
-        tp /= len(active_domain_probs)
-        fp /= len(active_domain_probs)
-        fn /= len(active_domain_probs)
-        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
-        logger.info('Domain F1 Score: %f' % f1.item())
-
-        for dom in active_domain_probs:
-            p = active_domain_probs[dom]
-            p = p.unsqueeze(-1)
-            p = torch.cat((1 - p, p), -1)
-            active_domain_probs[dom] = p
-        jg = jg_ece(active_domain_probs, active_domain_labels, 10)
-        logger.info('Domain Joint Goal ECE: %f' % jg)
-
-        tp = ((general_act_probs.argmax(-1) > 0) *
-              (general_act_labels > 0)).reshape(-1).float().sum()
-        fp = ((general_act_probs.argmax(-1) > 0) *
-              (general_act_labels == 0)).reshape(-1).float().sum()
-        fn = ((general_act_probs.argmax(-1) == 0) *
-              (general_act_labels > 0)).reshape(-1).float().sum()
-        f1 = tp / (tp + 0.5 * (fp + fn))
-        logger.info('General Act F1 Score: %f' % f1.item())
-
-        err = ece(general_act_probs.reshape(-1, general_act_probs.size(-1)),
-                  general_act_labels.reshape(-1), 10)
-        logger.info('General Act ECE: %f' % err)
-
-        for slot in request_probs:
-            p = request_probs[slot].unsqueeze(-1)
-            request_probs[slot] = torch.cat((1 - p, p), -1)
-
-        l2 = l2_acc(request_probs, request_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm Request Accuracy: {l2}')
-        l2 = l2_acc(request_probs, request_labels, remove_belief=True)
-        logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}')
-
-        for slot in active_domain_probs:
-            p = active_domain_probs[slot].unsqueeze(-1)
-            active_domain_probs[slot] = torch.cat((1 - p, p), -1)
-
-        l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm Domain Accuracy: {l2}')
-        l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=True)
-        logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}')
-
-        general_act_labels = {'general': general_act_labels}
-        general_act_probs = {'general': general_act_probs}
-
-        l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm General Act Accuracy: {l2}')
-        l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False)
-        logger.info(f'Binary Model L2 Norm General Act Accuracy: {l2}')
-
-
-if __name__ == "__main__":
-    main()
diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py
index e3b4f7fff907fc9bf4880bcc16250f075bee16ca..821dca598c814240f39e359cedec7ef795a341b5 100644
--- a/convlab/dst/setsumbt/do/nbt.py
+++ b/convlab/dst/setsumbt/do/nbt.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,23 +16,33 @@
 """Run SetSUMBT training/eval"""
 
 import logging
+import random
 import os
 from shutil import copy2 as copy
-import json
 
 import torch
+from torch.nn import DataParallel
 from transformers import (BertModel, BertConfig, BertTokenizer,
-                          RobertaModel, RobertaConfig, RobertaTokenizer)
+                          RobertaModel, RobertaConfig, RobertaTokenizer,
+                          AdamW, get_linear_schedule_with_warmup)
+from tqdm import tqdm, trange
+import numpy as np
 from tensorboardX import SummaryWriter
 
-from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
-from convlab.dst.setsumbt.dataset import unified_format
+from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
+from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
+from convlab.dst.setsumbt.multiwoz import multiwoz21
 from convlab.dst.setsumbt.modeling import training
-from convlab.dst.setsumbt.dataset import ontology as embeddings
+from convlab.dst.setsumbt.multiwoz import ontology as embeddings
 from convlab.dst.setsumbt.utils import get_args, update_args
+from convlab.dst.setsumbt.modeling import ensemble_utils
 
 
-# Available model
+# Datasets
+DATASETS = {
+    'multiwoz21': multiwoz21
+}
+
 MODELS = {
     'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
     'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
@@ -44,6 +54,12 @@ def main(args=None, config=None):
     if args is None:
         args, config = get_args(MODELS)
 
+    # Select Dataset object
+    if args.dataset in DATASETS:
+        Dataset = DATASETS[args.dataset]
+    else:
+        raise NameError('NotImplemented')
+
     if args.model_type in MODELS:
         SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
     else:
@@ -58,19 +74,53 @@ def main(args=None, config=None):
     args.output_dir = OUTPUT_DIR
 
     # Set pretrained model path to the trained checkpoint
-    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
-    if 'pytorch_model.bin' in paths and 'config.json' in paths:
-        args.model_name_or_path = args.output_dir
-        config = ConfigClass.from_pretrained(args.model_name_or_path)
-    else:
-        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
+    if args.do_train:
+        paths = os.listdir(args.output_dir) if os.path.exists(
+            args.output_dir) else []
+        paths = [os.path.join(args.output_dir, p)
+                 for p in paths if 'checkpoint-' in p]
         if paths:
             paths = paths[0]
             args.model_name_or_path = paths
             config = ConfigClass.from_pretrained(args.model_name_or_path)
+        else:
+            paths = os.listdir(args.output_dir) if os.path.exists(
+                args.output_dir) else []
+            if 'pytorch_model.bin' in paths and 'config.json' in paths:
+                args.model_name_or_path = args.output_dir
+                config = ConfigClass.from_pretrained(args.model_name_or_path)
+    else:
+        paths = os.listdir(args.output_dir) if os.path.exists(
+            args.output_dir) else []
+        if 'pytorch_model.bin' in paths and 'config.json' in paths:
+            args.model_name_or_path = args.output_dir
+            config = ConfigClass.from_pretrained(args.model_name_or_path)
+        else:
+            paths = os.listdir(args.output_dir) if os.path.exists(
+                args.output_dir) else []
+            paths = [os.path.join(args.output_dir, p)
+                     for p in paths if 'checkpoint-' in p]
+            if paths:
+                paths = paths[0]
+                args.model_name_or_path = paths
+                config = ConfigClass.from_pretrained(args.model_name_or_path)
 
     args = update_args(args, config)
 
+    # Set up data directory
+    DATA_DIR = args.data_dir
+    Dataset.set_datadir(DATA_DIR)
+    embeddings.set_datadir(DATA_DIR)
+
+    # If using shrunk active domains, remove the bus and hospital domains from the training data and model ontology
+    if args.shrink_active_domains and args.dataset == 'multiwoz21':
+        Dataset.set_active_domains(
+            ['attraction', 'hotel', 'restaurant', 'taxi', 'train'])
+
+    # Download and preprocess
+    Dataset.create_examples(
+        args.max_turn_len, args.predict_actions, args.force_processing)
+
     # Create TensorboardX writer
     tb_writer = SummaryWriter(logdir=args.tensorboard_path)
 
@@ -79,12 +129,19 @@ def main(args=None, config=None):
     logger = logging.getLogger(__name__)
     logger.setLevel(logging.INFO)
 
-    formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')
+    formatter = logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
-    fh = logging.FileHandler(args.logging_path)
-    fh.setLevel(logging.INFO)
-    fh.setFormatter(formatter)
-    logger.addHandler(fh)
+    if 'stream' not in args.logging_path:
+        fh = logging.FileHandler(args.logging_path)
+        fh.setLevel(logging.INFO)
+        fh.setFormatter(formatter)
+        logger.addHandler(fh)
+    else:
+        ch = logging.StreamHandler()
+        ch.setLevel(level=logging.INFO)
+        ch.setFormatter(formatter)
+        logger.addHandler(ch)
 
     # Get device
     if torch.cuda.is_available() and args.n_gpu > 0:
@@ -97,11 +154,14 @@ def main(args=None, config=None):
         args.fp16 = False
 
     # Initialise Model
-    model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
+    model = SetSumbtModel.from_pretrained(
+        args.model_name_or_path, config=config)
     model = model.to(device)
 
     # Create Tokenizer and embedding model for Data Loaders and ontology
-    encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
+    encoder = model.roberta if args.model_type == 'roberta' else None
+    encoder = model.bert if args.model_type == 'bert' else encoder
+
     tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config)
 
     # Set up model training/evaluation
@@ -110,62 +170,39 @@ def main(args=None, config=None):
     embeddings.set_seed(args)
 
     if args.ensemble_size > 1:
+        ensemble_utils.set_logger(logger, tb_writer)
+        ensemble_utils.set_seed(args)
         logger.info('Building %i resampled dataloaders each of size %i' % (args.ensemble_size,
                                                                            args.data_sampling_size))
-        dataloaders = [unified_format.get_dataloader(args.dataset,
-                                                     'train',
-                                                     args.train_batch_size,
-                                                     tokenizer,
-                                                     args.max_dialogue_len,
-                                                     args.max_turn_len,
-                                                     resampled_size=args.data_sampling_size,
-                                                     train_ratio=args.dataset_train_ratio,
-                                                     seed=args.seed)
-                       for _ in range(args.ensemble_size)]
+        dataloaders = ensemble_utils.build_train_loaders(args, tokenizer, Dataset)
         logger.info('Dataloaders built.')
-
         for i, loader in enumerate(dataloaders):
-            path = os.path.join(OUTPUT_DIR, 'ens-%i' % i)
+            path = os.path.join(OUTPUT_DIR, 'ensemble-%i' % i)
             if not os.path.exists(path):
                 os.mkdir(path)
             path = os.path.join(path, 'train.dataloader')
             torch.save(loader, path)
         logger.info('Dataloaders saved.')
 
-        train_dataloader = unified_format.get_dataloader(args.dataset,
-                                                         'train',
-                                                         args.train_batch_size,
-                                                         tokenizer,
-                                                         args.max_dialogue_len,
-                                                         args.max_turn_len,
-                                                         resampled_size=args.data_sampling_size,
-                                                         train_ratio=args.dataset_train_ratio,
-                                                         seed=args.seed)
-        torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
-        dev_dataloader = unified_format.get_dataloader(args.dataset,
-                                                       'validation',
-                                                       args.train_batch_size,
-                                                       tokenizer,
-                                                       args.max_dialogue_len,
-                                                       args.max_turn_len,
-                                                       resampled_size=args.data_sampling_size,
-                                                       train_ratio=args.dataset_train_ratio,
-                                                       seed=args.seed)
-        torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-        test_dataloader = unified_format.get_dataloader(args.dataset,
-                                                        'test',
-                                                        args.train_batch_size,
-                                                        tokenizer,
-                                                        args.max_dialogue_len,
-                                                        args.max_turn_len,
-                                                        resampled_size=args.data_sampling_size,
-                                                        train_ratio=args.dataset_train_ratio,
-                                                        seed=args.seed)
-        torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
-        embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology, 'train', args, tokenizer, encoder)
-        embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, 'dev', args, tokenizer, encoder)
-        embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, 'test', args, tokenizer, encoder)
+        train_slots = embeddings.get_slot_candidate_embeddings(
+            'train', args, tokenizer, encoder)
+        dev_slots = embeddings.get_slot_candidate_embeddings(
+            'dev', args, tokenizer, encoder)
+        test_slots = embeddings.get_slot_candidate_embeddings(
+            'test', args, tokenizer, encoder)
+
+        train_dataloader = Dataset.get_dataloader(
+            'train', args.train_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
+        torch.save(train_dataloader, os.path.join(
+            OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+        dev_dataloader = Dataset.get_dataloader(
+            'dev', args.dev_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
+        torch.save(dev_dataloader, os.path.join(
+            OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+        test_dataloader = Dataset.get_dataloader(
+            'test', args.test_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
+        torch.save(test_dataloader, os.path.join(
+            OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
 
         # Do not perform standard training after ensemble setup is created
         return 0
@@ -173,51 +210,47 @@ def main(args=None, config=None):
     # Perform tasks
     # TRAINING
     if args.do_train:
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')):
-            train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
-            if train_dataloader.batch_size != args.train_batch_size:
-                train_dataloader = unified_format.change_batch_size(train_dataloader, args.train_batch_size)
-        else:
-            if args.data_sampling_size <= 0:
-                args.data_sampling_size = None
-            train_dataloader = unified_format.get_dataloader(args.dataset,
-                                                             'train',
-                                                             args.train_batch_size,
-                                                             tokenizer,
-                                                             args.max_dialogue_len,
-                                                             config.max_turn_len,
-                                                             resampled_size=args.data_sampling_size,
-                                                             train_ratio=args.dataset_train_ratio,
-                                                             seed=args.seed)
-            torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
-
         # Get training batch loaders and ontology embeddings
         if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')):
-            train_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'train.db'))
+            train_slots = torch.load(os.path.join(
+                OUTPUT_DIR, 'database', 'train.db'))
         else:
-            train_slots = embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology,
-                                                                   'train', args, tokenizer, encoder)
+            train_slots = embeddings.get_slot_candidate_embeddings(
+                'train', args, tokenizer, encoder)
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
+            dev_slots = torch.load(os.path.join(
+                OUTPUT_DIR, 'database', 'dev.db'))
+        else:
+            dev_slots = embeddings.get_slot_candidate_embeddings(
+                'dev', args, tokenizer, encoder)
+
+        exists = False
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')):
+            train_dataloader = torch.load(os.path.join(
+                OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+            if train_dataloader.batch_size == args.train_batch_size:
+                exists = True
+        if not exists:
+            if args.data_sampling_size <= 0:
+                args.data_sampling_size = None
+            train_dataloader = Dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len,
+                                                      config.max_turn_len, resampled_size=args.data_sampling_size)
+            torch.save(train_dataloader, os.path.join(
+                OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
 
         # Get development set batch loaders= and ontology embeddings
         if args.do_eval:
+            exists = False
             if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
-                dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-                if dev_dataloader.batch_size != args.dev_batch_size:
-                    dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size)
-            else:
-                dev_dataloader = unified_format.get_dataloader(args.dataset,
-                                                               'validation',
-                                                               args.dev_batch_size,
-                                                               tokenizer,
-                                                               args.max_dialogue_len,
-                                                               config.max_turn_len)
-                torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-
-            if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
-                dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
-            else:
-                dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
-                                                                     'dev', args, tokenizer, encoder)
+                dev_dataloader = torch.load(os.path.join(
+                    OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+                if dev_dataloader.batch_size == args.dev_batch_size:
+                    exists = True
+            if not exists:
+                dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len,
+                                                        config.max_turn_len)
+                torch.save(dev_dataloader, os.path.join(
+                    OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
         else:
             dev_dataloader = None
             dev_slots = None
@@ -226,80 +259,94 @@ def main(args=None, config=None):
         training.set_ontology_embeddings(model, train_slots)
 
         # TRAINING !!!!!!!!!!!!!!!!!!
-        training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots)
+        training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots,
+                       embeddings=embeddings, tokenizer=tokenizer)
 
         # Copy final best model to the output dir
         checkpoints = os.listdir(OUTPUT_DIR)
         checkpoints = [p for p in checkpoints if 'checkpoint' in p]
         checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints])
-        best_checkpoint = os.path.join(OUTPUT_DIR, f'checkpoint-{checkpoints[-1]}')
-        copy(os.path.join(best_checkpoint, 'pytorch_model.bin'), os.path.join(OUTPUT_DIR, 'pytorch_model.bin'))
-        copy(os.path.join(best_checkpoint, 'config.json'), os.path.join(OUTPUT_DIR, 'config.json'))
+        best_checkpoint = checkpoints[-1]
+        best_checkpoint = os.path.join(
+            OUTPUT_DIR, f'checkpoint-{best_checkpoint}')
+        copy(os.path.join(best_checkpoint, 'pytorch_model.bin'),
+             os.path.join(OUTPUT_DIR, 'pytorch_model.bin'))
+        copy(os.path.join(best_checkpoint, 'config.json'),
+             os.path.join(OUTPUT_DIR, 'config.json'))
 
         # Load best model for evaluation
-        model = SetSumbtModel.from_pretrained(OUTPUT_DIR)
+        model = SetSumbtModel.from_pretrained(OUTPUT_DIR)
         model = model.to(device)
 
     # Evaluation on the development set
     if args.do_eval:
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
-            dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-            if dev_dataloader.batch_size != args.dev_batch_size:
-                dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size)
-        else:
-            dev_dataloader = unified_format.get_dataloader(args.dataset,
-                                                           'validation',
-                                                           args.dev_batch_size,
-                                                           tokenizer,
-                                                           args.max_dialogue_len,
-                                                           config.max_turn_len)
-            torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-
+        # Get development set batch loaders and ontology embeddings
         if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
-            dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
+            dev_slots = torch.load(os.path.join(
+                OUTPUT_DIR, 'database', 'dev.db'))
         else:
-            dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
-                                                                 'dev', args, tokenizer, encoder)
+            dev_slots = embeddings.get_slot_candidate_embeddings(
+                'dev', args, tokenizer, encoder)
+
+        exists = False
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
+            dev_dataloader = torch.load(os.path.join(
+                OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+            if dev_dataloader.batch_size == args.dev_batch_size:
+                exists = True
+        if not exists:
+            dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len,
+                                                    config.max_turn_len)
+            torch.save(dev_dataloader, os.path.join(
+                OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
 
         # Load model ontology
         training.set_ontology_embeddings(model, dev_slots)
 
         # EVALUATION
-        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss = training.evaluate(args, model, device, dev_dataloader)
-        training.log_info('dev', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
+        jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate(
+            args, model, device, dev_dataloader)
+        if req_f1:
+            logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f'
+                        % (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
+        else:
+            logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f'
+                        % (loss, jg_acc, sl_acc))
 
     # Evaluation on the test set
     if args.do_test:
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size != args.test_batch_size:
-                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
-        else:
-            test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
-                                                            args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                            config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
+        # Get test set batch loaders and ontology embeddings
         if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
+            test_slots = torch.load(os.path.join(
+                OUTPUT_DIR, 'database', 'test.db'))
         else:
-            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
-                                                                  'test', args, tokenizer, encoder)
+            test_slots = embeddings.get_slot_candidate_embeddings(
+                'test', args, tokenizer, encoder)
+
+        exists = False
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
+            test_dataloader = torch.load(os.path.join(
+                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+            if test_dataloader.batch_size == args.test_batch_size:
+                exists = True
+        if not exists:
+            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
+                                                     config.max_turn_len)
+            torch.save(test_dataloader, os.path.join(
+                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
 
         # Load model ontology
         training.set_ontology_embeddings(model, test_slots)
 
         # TESTING
-        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, output = training.evaluate(args, model, device, test_dataloader,
-                                                                                 return_eval_output=True)
-
-        if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
-            os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
-        writer = open(os.path.join(OUTPUT_DIR, 'predictions', 'test.json'), 'w')
-        json.dump(output, writer)
-        writer.close()
-
-        training.log_info('test', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
+        jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate(
+            args, model, device, test_dataloader)
+        if req_f1:
+            logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f'
+                        % (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
+        else:
+            logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f'
+                        % (loss, jg_acc, sl_acc))
 
     tb_writer.close()
 
diff --git a/convlab/dst/setsumbt/loss/__init__.py b/convlab/dst/setsumbt/loss/__init__.py
deleted file mode 100644
index 475f7646126ea03b630efcbbc688f86c5a8ec16e..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/loss/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from convlab.dst.setsumbt.loss.bayesian_matching import BayesianMatchingLoss, BinaryBayesianMatchingLoss
-from convlab.dst.setsumbt.loss.kl_distillation import KLDistillationLoss, BinaryKLDistillationLoss
-from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss
-from convlab.dst.setsumbt.loss.endd_loss import RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss
diff --git a/convlab/dst/setsumbt/loss/bayesian.py b/convlab/dst/setsumbt/loss/bayesian.py
new file mode 100644
index 0000000000000000000000000000000000000000..e52d8d07733383c7f95b6825b4ab5e5e1c7a0977
--- /dev/null
+++ b/convlab/dst/setsumbt/loss/bayesian.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Bayesian Matching Activation and Loss Functions"""
+
+import torch
+from torch import digamma, lgamma
+from torch.nn import Module
+
+
+# Inverse-linear activation: 1 / (1 - x) where x < 0 and 1 + x where x >= 0, selected by boolean masks
+def invlinear(x):
+    z = (1.0 / (1.0 - x)) * (x < 0)  # NOTE(review): evaluated for every x, so x == 1 divides by zero before masking
+    z += (1.0 + x) * (x >= 0)
+    return z
+
+# Exponential activation function: returns torch.exp(x)
+def exponential(x):
+    return torch.exp(x)
+
+
+# Dirichlet activation for the model: exponentiate a, then divide by the sum over the last axis
+def dirichlet(a):
+    p = exponential(a)
+    repeat_dim = (1,)*(len(p.shape)-1) + (p.size(-1),)  # repeat factor of 1 on every axis except the last
+    p = p / p.sum(-1).unsqueeze(-1).repeat(repeat_dim)
+    return p
+
+
+# Bayesian matching loss (PyTorch nn.Module) for multi-class predictions
+class BayesianMatchingLoss(Module):
+    """lamb-weighted KL(Dirichlet(exp(alpha)) || Dirichlet(prior)) minus the expected log likelihood of the labels."""
+    def __init__(self, lamb=0.01, ignore_index=-1):
+        """Store the KL weight `lamb` and the label value `ignore_index` marking observations to drop."""
+        super(BayesianMatchingLoss, self).__init__()
+        self.lamb = lamb
+        self.ignore_index = ignore_index
+
+    def forward(self, alpha, labels, prior=None):
+        """Return the mean loss; `alpha` is 2-D (exponentiated here), `labels` 1-D, `prior` defaults to all ones."""
+        assert alpha.dim() == 2                 # Observations, predictive distribution
+        assert labels.dim() == 1                # Label for each observation
+        assert labels.size(0) == alpha.size(0)  # Equal number of observations
+
+        # Labels index the last axis of alpha, so the largest label must be strictly below its size
+        if labels.max() < alpha.size(-1):
+            dimension = alpha.size(-1)
+        else:
+            raise NameError('Label index %i is out of range for prediction dimension %i.' % (labels.max(), alpha.size(-1)))
+
+        # Remove observations with no labels
+        if prior is not None:
+            prior = prior[labels != self.ignore_index]
+        alpha = exponential(alpha[labels != self.ignore_index])
+        labels = labels[labels != self.ignore_index]
+
+        # Default to a flat (all-ones) prior and move it to alpha's device
+        if prior is None:
+            prior = torch.ones(dimension)
+        prior = prior.to(alpha.device)
+
+        # KL divergence term
+        lb = lgamma(alpha.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(alpha)).sum(-1)
+        e = digamma(alpha) - digamma(alpha.sum(-1)).unsqueeze(-1).repeat((1, alpha.size(-1)))
+        e = ((alpha - prior) * e).sum(-1)
+        kl = lb + e
+        kl *= self.lamb
+        del lb, e, prior
+
+        # Expected log likelihood
+        expected_likelihood = digamma(alpha[range(labels.size(0)), labels]) - digamma(alpha.sum(1))
+        del alpha, labels
+
+        # Apply ELBO loss and mean reduction
+        loss = (kl - expected_likelihood).mean()
+        del kl, expected_likelihood
+
+        return loss
+
+
+# Bayesian matching loss (PyTorch nn.Module) for binary predictions
+class BinaryBayesianMatchingLoss(Module):
+    """Binary variant: sigmoid(alpha) splits a total concentration of 1 + 1/lamb between the two classes."""
+    def __init__(self, lamb=0.01, ignore_index=-1):
+        """Store the KL weight `lamb` and the label value `ignore_index` marking observations to drop."""
+        super(BinaryBayesianMatchingLoss, self).__init__()
+        self.lamb = lamb
+        self.ignore_index = ignore_index
+
+    def forward(self, alpha, labels, prior=None):
+        """Return the mean loss; `alpha` is 1-D (passed through sigmoid), `labels` 1-D, `prior` defaults to ones."""
+        assert alpha.dim() == 1                 # Observations, predictive distribution
+        assert labels.dim() == 1                # Label for each observation
+        assert labels.size(0) == alpha.size(0)  # Equal number of observations
+
+        # Labels index the two-column parameter matrix built below, so only values below 2 are valid
+        if labels.max() < 2:
+            dimension = 2
+        else:
+            raise NameError('Label index %i is out of range for prediction dimension %i.' % (labels.max(), 2))
+
+        # Remove observations with no labels
+        if prior is not None:
+            prior = prior[labels != self.ignore_index]
+        alpha = alpha[labels != self.ignore_index]
+        alpha_sum = 1 + (1 / self.lamb)
+        alpha = (torch.sigmoid(alpha) * alpha_sum).reshape(-1, 1)
+        alpha = torch.cat((alpha_sum - alpha, alpha), 1)
+        labels = labels[labels != self.ignore_index]
+
+        # Default to a flat (all-ones) prior and move it to alpha's device
+        if prior is None:
+            prior = torch.ones(dimension)
+        prior = prior.to(alpha.device)
+
+        # KL divergence term
+        lb = lgamma(alpha.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(alpha)).sum(-1)
+        e = digamma(alpha) - digamma(alpha.sum(-1)).unsqueeze(-1).repeat((1, alpha.size(-1)))
+        e = ((alpha - prior) * e).sum(-1)
+        kl = lb + e
+        kl *= self.lamb
+        del lb, e, prior
+
+        # Expected log likelihood
+        expected_likelihood = digamma(alpha[range(labels.size(0)), labels.long()]) - digamma(alpha.sum(1))
+        del alpha, labels
+
+        # Apply ELBO loss and mean reduction
+        loss = (kl - expected_likelihood).mean()
+        del kl, expected_likelihood
+
+        return loss
diff --git a/convlab/dst/setsumbt/loss/bayesian_matching.py b/convlab/dst/setsumbt/loss/bayesian_matching.py
deleted file mode 100644
index 3e91444d60afeeb6e2ca54192dd2283810fc5135..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/loss/bayesian_matching.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Bayesian Matching Activation and Loss Functions (see https://arxiv.org/pdf/2002.07965.pdf for details)"""
-
-import torch
-from torch import digamma, lgamma
-from torch.nn import Module
-
-
-class BayesianMatchingLoss(Module):
-    """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation"""
-
-    def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module:
-        """
-        Args:
-            lamb (float): Weighting factor for the KL Divergence component
-            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
-        """
-        super(BayesianMatchingLoss, self).__init__()
-
-        self.lamb = lamb
-        self.ignore_index = ignore_index
-    
-    def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor:
-        """
-        Args:
-            inputs (Tensor): Predictive distribution
-            labels (Tensor): Label indices
-            prior (Tensor): Prior distribution over label classes
-
-        Returns:
-            loss (Tensor): Loss value
-        """
-        # Assert input sizes
-        assert inputs.dim() == 2                 # Observations, predictive distribution
-        assert labels.dim() == 1                # Label for each observation
-        assert labels.size(0) == inputs.size(0)  # Equal number of observation
-
-        # Confirm predictive distribution dimension
-        if labels.max() <= inputs.size(-1):
-            dimension = inputs.size(-1)
-        else:
-            raise NameError(f'Label dimension {labels.max()} is larger than prediction dimension {inputs.size(-1)}.')
-        
-        # Remove observations to be ignored in loss calculation
-        if prior is not None:
-            prior = prior[labels != self.ignore_index]
-        inputs = torch.exp(inputs[labels != self.ignore_index])
-        labels = labels[labels != self.ignore_index]
-        
-        # Initialise and reshape prior parameters
-        if prior is None:
-            prior = torch.ones(dimension).to(inputs.device)
-        prior = prior.to(inputs.device)
-
-        # KL divergence term (divergence of predictive distribution from prior over label classes - regularisation term)
-        log_gamma_term = lgamma(inputs.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(inputs)).sum(-1)
-        div_term = digamma(inputs) - digamma(inputs.sum(-1)).unsqueeze(-1).repeat((1, inputs.size(-1)))
-        div_term = ((inputs - prior) * div_term).sum(-1)
-        kl_term = log_gamma_term + div_term
-        kl_term *= self.lamb
-        del log_gamma_term, div_term, prior
-
-        # Expected log likelihood
-        expected_likelihood = digamma(inputs[range(labels.size(0)), labels]) - digamma(inputs.sum(-1))
-        del inputs, labels
-
-        # Apply ELBO loss and mean reduction
-        loss = (kl_term - expected_likelihood).mean()
-        del kl_term, expected_likelihood
-
-        return loss
-
-
-class BinaryBayesianMatchingLoss(BayesianMatchingLoss):
-    """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation"""
-
-    def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module:
-        """
-        Args:
-            lamb (float): Weighting factor for the KL Divergence component
-            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
-        """
-        super(BinaryBayesianMatchingLoss, self).__init__(lamb, ignore_index)
-
-    def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor:
-        """
-        Args:
-            inputs (Tensor): Predictive distribution
-            labels (Tensor): Label indices
-            prior (Tensor): Prior distribution over label classes
-
-        Returns:
-            loss (Tensor): Loss value
-        """
-        
-        # Create 2D input dirichlet distribution
-        input_sum = 1 + (1 / self.lamb)
-        inputs = (torch.sigmoid(inputs) * input_sum).reshape(-1, 1)
-        inputs = torch.cat((input_sum - inputs, inputs), 1)
-
-        return super().forward(inputs, labels, prior=prior)
diff --git a/convlab/dst/setsumbt/loss/distillation.py b/convlab/dst/setsumbt/loss/distillation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cf13f10635376467f3b137adaa83367f26603ef
--- /dev/null
+++ b/convlab/dst/setsumbt/loss/distillation.py
@@ -0,0 +1,201 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Distillation loss functions based on the KL divergence"""
+
+import torch
+from torch import lgamma, log
+from torch.nn import Module
+from torch.nn.functional import kl_div
+
+from convlab.dst.setsumbt.loss.bayesian import BayesianMatchingLoss
+
+
+# KL divergence distillation loss for categorical (softmax) predictions
+class DistillationKL(Module):
+    """KL divergence between smoothed target distributions and the softmax of the predicted logits."""
+
+    def __init__(self, lamb=1e-4, ignore_index=-1):
+        super(DistillationKL, self).__init__()
+
+        self.lamb = lamb  # Weight of the uniform distribution mixed into the targets
+        self.ignore_index = ignore_index  # Value in the first label column marking rows to skip
+
+    def forward(self, alpha, labels, temp=1.0):
+        """
+        Args:
+            alpha (Tensor): Logits of size [observations, classes]
+            labels (Tensor): Target distributions of size [observations, classes]
+            temp (float): Temperature dividing the logits before the softmax; must be non-zero
+
+        Returns:
+            kl (Tensor): Mean KL divergence over the observations that are not ignored
+        """
+        assert alpha.dim() == 2 and labels.dim() == 2  # Observations x classes
+        assert labels.size(0) == alpha.size(0)         # Equal number of observations
+        if labels.size(-1) != alpha.size(-1):
+            raise NameError('Label dimension %i does not match prediction dimension %i.' % (labels.size(-1), alpha.size(-1)))
+
+        ids = torch.where(labels[:, 0] != self.ignore_index)[0]
+        alpha = torch.log_softmax(alpha[ids] / temp, -1)  # Stable form of log(softmax(x)), avoids log(0)
+        labels = ((1 - self.lamb) * labels[ids]) + (self.lamb * (1 / labels.size(-1)))
+
+        return kl_div(alpha, labels, reduction='none').sum(-1).mean()
+
+
+# KL divergence distillation loss for binary (sigmoid) predictions
+class BinaryDistillationKL(Module):
+    """KL divergence between smoothed binary target distributions and the sigmoid of the predicted logits."""
+
+    def __init__(self, lamb=1e-4, ignore_index=-1):
+        super(BinaryDistillationKL, self).__init__()
+
+        self.lamb = lamb  # Weight of the uniform distribution mixed into the targets
+        self.ignore_index = ignore_index  # Label value marking observations to skip
+
+    def forward(self, alpha, labels, temp=1.0):
+        """
+        Args:
+            alpha (Tensor): 1-D tensor of logits, one per observation
+            labels (Tensor): 1-D tensor of targets, used as the probability of the second class (index 1)
+            temp (float): Temperature dividing the logits before the sigmoid; must be non-zero
+
+        Returns:
+            kl (Tensor): Mean KL divergence over the observations that are not ignored
+        """
+        # Assert input sizes
+        assert alpha.dim() == 1                 # One logit per observation
+        assert labels.dim() == 1                # One label per observation
+        assert labels.size(0) == alpha.size(0)  # Equal number of observations
+
+        ids = torch.where(labels != self.ignore_index)[0]
+        logits = (alpha[ids] / temp).unsqueeze(-1)
+        log_sigmoid = torch.nn.functional.logsigmoid  # log(1 - sigmoid(x)) == logsigmoid(-x), avoids log(0)
+        alpha = torch.cat((log_sigmoid(-logits), log_sigmoid(logits)), 1)
+
+        labels = labels[ids].unsqueeze(-1)
+        labels = torch.cat((1 - labels, labels), -1)
+        labels = ((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1)))
+        return kl_div(alpha, labels, reduction='none').sum(-1).mean()
+
+
+# def smart_sort(x, permutation):
+#     assert x.dim() == permutation.dim()
+#     if x.dim() == 3:
+#         d1, d2, d3 = x.size()
+#         ret = x[torch.arange(d1).unsqueeze(-1).unsqueeze(-1).repeat((1, d2, d3)).flatten(),
+#                 torch.arange(d2).unsqueeze(0).unsqueeze(-1).repeat((d1, 1, d3)).flatten(),
+#                 permutation.flatten()].view(d1, d2, d3)
+#         return ret
+#     elif x.dim() == 2:
+#         d1, d2 = x.size()
+#         ret = x[torch.arange(d1).unsqueeze(-1).repeat((1, d2)).flatten(),
+#                 permutation.flatten()].view(d1, d2)
+#         return ret
+
+
+# # Pytorch BayesianMatchingLoss nn.Module
+# class DistillationNLL(Module):
+
+#     def __init__(self, lamb=1e-4, ignore_index=-1):
+#         super(DistillationNLL, self).__init__()
+
+#         self.lamb = lamb
+#         self.ignore_index = ignore_index
+#         self.loss_add = BayesianMatchingLoss(lamb=0.001)
+    
+#     def forward(self, alpha, labels, temp=1.0):
+#         # Assert input sizes
+#         assert alpha.dim() == 2                 # Observations, predictive distribution
+#         assert labels.dim() == 3                # Label for each observation
+#         assert labels.size(0) == alpha.size(0)  # Equal number of observation
+
+#         # Confirm predictive distribution dimension
+#         if labels.size(-1) == alpha.size(-1):
+#             dimension = alpha.size(-1)
+#         else:
+#             raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1)))
+        
+#         alpha = torch.exp(alpha / temp)
+#         ids = torch.where(labels[:, 0, 0] != self.ignore_index)[0]
+#         alpha = alpha[ids]
+#         labels = labels[ids]
+
+#         best_labels = labels.mean(-2).argmax(-1)
+#         loss2 = self.loss_add(alpha, best_labels)
+
+#         topn = labels.mean(-2).argsort(-1, descending=True)
+#         n = 10
+#         alpha = smart_sort(alpha, topn)[:, :n]
+#         labels = smart_sort(labels, topn.unsqueeze(-2).repeat((1, labels.size(-2), 1)))
+#         labels = labels[:, :, :n]
+#         labels = labels / labels.sum(-1).unsqueeze(-1).repeat((1, 1, labels.size(-1)))
+
+#         labels = log(((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1))))
+
+#         loss = (alpha - 1) * labels.mean(-2)
+#         # loss = (alpha - 1) * labels
+#         loss = lgamma(alpha.sum(-1)) - lgamma(alpha).sum(-1) + loss.sum(-1) 
+#         loss = -1.0 * loss.mean()
+#         # loss = -1.0 * loss.mean() / alpha.size(-1)
+
+#         return loss      
+
+
+# # Pytorch BayesianMatchingLoss nn.Module
+# class BinaryDistillationNLL(Module):
+
+#     def __init__(self, lamb=1e-4, ignore_index=-1):
+#         super(BinaryDistillationNLL, self).__init__()
+
+#         self.lamb = lamb
+#         self.ignore_index = ignore_index
+    
+#     def forward(self, alpha, labels, temp=0.0):
+#         # Assert input sizes
+#         assert alpha.dim() == 1                 # Observations, predictive distribution
+#         assert labels.dim() == 2                # Label for each observation
+#         assert labels.size(0) == alpha.size(0)  # Equal number of observation
+
+#         # Confirm predictive distribution dimension
+#         # if labels.size(-1) == alpha.size(-1):
+#         #     dimension = alpha.size(-1)
+#         # else:
+#         #     raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1)))
+        
+#         # Remove observations with no labels
+#         ids = torch.where(labels[:, 0] != self.ignore_index)[0]
+#         # alpha_sum = 1 + (1 / self.lamb)
+#         alpha_sum = 10.0
+#         alpha = (torch.sigmoid(alpha) * alpha_sum).reshape(-1, 1)
+#         alpha = alpha[ids]
+#         labels = labels[ids]
+
+#         if temp != 1.0:
+#             alpha = torch.log(alpha + 1e-4)
+#             alpha = torch.exp(alpha / temp)
+
+#         alpha = torch.cat((alpha_sum - alpha, alpha), 1)
+        
+#         labels = labels.unsqueeze(-1)
+#         labels = torch.cat((1 - labels, labels), -1)
+#         # labels[labels[:, 0, 0] == self.ignore_index] = 1
+#         labels = log(((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1))))
+
+#         loss = (alpha - 1) * labels.mean(-2)
+#         loss = lgamma(alpha.sum(-1)) - lgamma(alpha).sum(-1) + loss.sum(-1)
+#         loss = -1.0 * loss.mean()
+
+#         return loss    
diff --git a/convlab/dst/setsumbt/loss/uncertainty_measures.py b/convlab/dst/setsumbt/loss/ece.py
similarity index 50%
rename from convlab/dst/setsumbt/loss/uncertainty_measures.py
rename to convlab/dst/setsumbt/loss/ece.py
index 87c89dd31c724cc7d599230c6d4a15faee9b680e..034b9aa0bf5882aea49b08a64d7f93164208b5d9 100644
--- a/convlab/dst/setsumbt/loss/uncertainty_measures.py
+++ b/convlab/dst/setsumbt/loss/ece.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,24 +13,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Uncertainty evaluation metrics for dialogue belief tracking"""
+"""Expected calibration error"""
 
 import torch
 
 
-def fill_bins(n_bins: int, probs: torch.Tensor) -> list:
-    """
-    Function to split observations into bins based on predictive probabilities
-
-    Args:
-        n_bins (int): Number of bins
-        probs (Tensor): Predictive probabilities for the observations
-
-    Returns:
-        bins (list): List of observation ids for each bin
-    """
-    assert probs.dim() == 2
-    probs = probs.max(-1)[0]
+def fill_bins(n_bins, logits):
+    assert logits.dim() == 2
+    logits = logits.max(-1)[0]
 
     step = 1.0 / n_bins
     bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
@@ -38,49 +28,29 @@ def fill_bins(n_bins: int, probs: torch.Tensor) -> list:
     for b in range(n_bins):
         lower, upper = bin_ranges[b], bin_ranges[b + 1]
         if b == 0:
-            ids = torch.where((probs >= lower) * (probs <= upper))[0]
+            ids = torch.where((logits >= lower) * (logits <= upper))[0]
         else:
-            ids = torch.where((probs > lower) * (probs <= upper))[0]
+            ids = torch.where((logits > lower) * (logits <= upper))[0]
         bins.append(ids)
     return bins
 
 
-def bin_confidence(bins: list, probs: torch.Tensor) -> torch.Tensor:
-    """
-    Compute the confidence score within each bin
-
-    Args:
-        bins (list): List of observation ids for each bin
-        probs (Tensor): Predictive probabilities for the observations
-
-    Returns:
-        scores (Tensor): Average confidence score within each bin
-    """
-    probs = probs.max(-1)[0]
+def bin_confidence(bins, logits):
+    logits = logits.max(-1)[0]
 
     scores = []
     for b in bins:
         if b is not None:
-            scores.append(probs[b].mean())
+            l = logits[b]
+            scores.append(l.mean())
         else:
             scores.append(-1)
     scores = torch.tensor(scores)
     return scores
 
 
-def bin_accuracy(bins: list, probs: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
-    """
-    Compute the accuracy score for observations in each bin
-
-    Args:
-        bins (list): List of observation ids for each bin
-        probs (Tensor): Predictive probabilities for the observations
-        y_true (Tensor): Labels for the observations
-
-    Returns:
-        acc (Tensor): Accuracies for the observations in each bin
-    """
-    y_pred = probs.argmax(-1)
+def bin_accuracy(bins, logits, y_true):
+    y_pred = logits.argmax(-1)
 
     acc = []
     for b in bins:
@@ -98,24 +68,13 @@ def bin_accuracy(bins: list, probs: torch.Tensor, y_true: torch.Tensor) -> torch
     return acc
 
 
-def ece(probs: torch.Tensor, y_true: torch.Tensor, n_bins: int) -> float:
-    """
-    Expected calibration error calculation
+def ece(logits, y_true, n_bins):
+    bins = fill_bins(n_bins, logits)
 
-    Args:
-        probs (Tensor): Predictive probabilities for the observations
-        y_true (Tensor): Labels for the observations
-        n_bins (int): Number of bins
+    scores = bin_confidence(bins, logits)
+    acc = bin_accuracy(bins, logits, y_true)
 
-    Returns:
-        ece (float): Expected calibration error
-    """
-    bins = fill_bins(n_bins, probs)
-
-    scores = bin_confidence(bins, probs)
-    acc = bin_accuracy(bins, probs, y_true)
-
-    n = probs.size(0)
+    n = logits.size(0)
     bk = torch.tensor([b.size(0) for b in bins])
 
     ece = torch.abs(scores - acc) * bk / n
@@ -125,30 +84,34 @@ def ece(probs: torch.Tensor, y_true: torch.Tensor, n_bins: int) -> float:
     return ece
 
 
-def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float:
-    """
-        Joint goal expected calibration error calculation
-
-        Args:
-            belief_state (dict): Belief state probabilities for the dialogue turns
-            y_true (dict): Labels for the state in dialogue turns
-            n_bins (int): Number of bins
-
-        Returns:
-            ece (float): Joint goal expected calibration error
-        """
-    y_pred = {slot: bs.reshape(-1, bs.size(-1)).argmax(-1) for slot, bs in belief_state.items()}
+def jg_ece(logits, y_true, n_bins):
+    y_pred = {slot: logits[slot].reshape(-1, logits[slot].size(-1)).argmax(-1) for slot in logits}
     goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred}
     goal_acc = sum([goal_acc[slot] for slot in goal_acc])
     goal_acc = (goal_acc == len(y_true)).int()
 
-    # Confidence score is minimum across slots as a single bad predictions leads to incorrect prediction in state
-    scores = [bs.reshape(-1, bs.size(-1)).max(-1)[0].unsqueeze(0) for slot, bs in belief_state.items()]
+    scores = [logits[slot].reshape(-1, logits[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in logits]
     scores = torch.cat(scores, 0).min(0)[0]
 
-    bins = fill_bins(n_bins, scores.unsqueeze(-1))
+    step = 1.0 / n_bins
+    bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
+    bins = []
+    for b in range(n_bins):
+        lower, upper = bin_ranges[b], bin_ranges[b + 1]
+        if b == 0:
+            ids = torch.where((scores >= lower) * (scores <= upper))[0]
+        else:
+            ids = torch.where((scores > lower) * (scores <= upper))[0]
+        bins.append(ids)
 
-    conf = bin_confidence(bins, scores.unsqueeze(-1))
+    conf = []
+    for b in bins:
+        if b is not None:
+            l = scores[b]
+            conf.append(l.mean())
+        else:
+            conf.append(-1)
+    conf = torch.tensor(conf)
 
     slot = [s for s in y_true][0]
     acc = []
@@ -164,7 +127,7 @@ def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float:
             acc.append(-1)
     acc = torch.tensor(acc)
 
-    n = belief_state[slot].reshape(-1, belief_state[slot].size(-1)).size(0)
+    n = logits[slot].reshape(-1, logits[slot].size(-1)).size(0)
     bk = torch.tensor([b.size(0) for b in bins])
 
     ece = torch.abs(conf - acc) * bk / n
@@ -174,22 +137,12 @@ def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float:
     return ece
 
 
-def l2_acc(belief_state: dict, labels: dict, remove_belief: bool = False) -> float:
-    """
-    Compute L2 Error of belief state prediction
-
-    Args:
-        belief_state (dict): Belief state probabilities for the dialogue turns
-        labels (dict): Labels for the state in dialogue turns
-        remove_belief (bool): Convert belief state to dialogue state
-
-    Returns:
-        err (float): L2 Error of belief state prediction
-    """
+def l2_acc(belief_state, labels, remove_belief=False):
     # Get ids used for removing padding turns.
     padding = labels[list(labels.keys())[0]].reshape(-1)
     padding = torch.where(padding != -1)[0]
 
+    # l2 = []
     state = []
     labs = []
     for slot, bs in belief_state.items():
@@ -210,8 +163,13 @@ def l2_acc(belief_state: dict, labels: dict, remove_belief: bool = False) -> flo
         y = torch.zeros(bs.shape).cuda()
         y[range(y.size(0)), lab] = 1.0
 
+        # err = torch.sqrt(((y - bs) ** 2).sum(-1))
+        # l2.append(err.unsqueeze(-1))
+
         state.append(bs)
         labs.append(y)
+    
+    # err = torch.cat(l2, -1).max(-1)[0]
 
     # Concatenate all slots into a single belief state
     state = torch.cat(state, -1)
diff --git a/convlab/dst/setsumbt/loss/endd_loss.py b/convlab/dst/setsumbt/loss/endd_loss.py
index c3b2497373fda86a9e4e2b847daf2081555a0952..d84c3f720d4970c520d97aa9e65a12647468d3ae 100644
--- a/convlab/dst/setsumbt/loss/endd_loss.py
+++ b/convlab/dst/setsumbt/loss/endd_loss.py
@@ -1,46 +1,30 @@
 import torch
-from torch.nn import Module
-from torch.nn.functional import kl_div
 
 EPS = torch.finfo(torch.float32).eps
 
-
 @torch.no_grad()
-def compute_mkl(ensemble_mean_probs: torch.Tensor, ensemble_logprobs: torch.Tensor) -> torch.Tensor:
-    """
-    Computing MKL in ensemble.
-
-    Args:
-        ensemble_mean_probs (Tensor): Marginal predictive distribution of the ensemble
-        ensemble_logprobs (Tensor): Log predictive distributions of individual ensemble members
-
-    Returns:
-        mkl (Tensor): MKL
-    """
-    mkl = kl_div(ensemble_logprobs, ensemble_mean_probs.unsqueeze(1).expand_as(ensemble_logprobs),reduction='none')
-    return mkl.sum(-1).mean(1)
-
+def compute_mkl(ensemble_probs, ensemble_mean_probs, ensemble_logprobs):
+    mkl = torch.nn.functional.kl_div(ensemble_logprobs, ensemble_mean_probs.unsqueeze(1).expand_as(ensemble_probs),
+                                     reduction='none').sum(-1).mean(1)
+    return mkl
 
 @torch.no_grad()
-def compute_ensemble_stats(ensemble_probs: torch.Tensor) -> dict:
-    """
-    Compute a range of ensemble uncertainty measures
-
-    Args:
-        ensemble_probs (Tensor): Predictive distributions of the ensemble members
-
-    Returns:
-        stats (dict): Dictionary of ensemble uncertainty measures
-    """
+def compute_ensemble_stats(ensemble_logits):
+    # ensemble_probs = torch.softmax(ensemble_logits, dim=-1)
+    # ensemble_mean_probs = ensemble_probs.mean(dim=1)
+    # ensemble_logprobs = torch.log_softmax(ensemble_logits, dim=-1)
+    ensemble_probs = ensemble_logits
     ensemble_mean_probs = ensemble_probs.mean(dim=1)
-    num_classes = ensemble_probs.size(-1)
-    ensemble_logprobs = torch.log(ensemble_probs + (1e-4 / num_classes))
+    num_classes = ensemble_logits.size(-1)
+    ensemble_logprobs = torch.log(ensemble_logits + (1e-4 / num_classes))
 
     entropy_of_expected = torch.distributions.Categorical(probs=ensemble_mean_probs).entropy()
     expected_entropy = torch.distributions.Categorical(probs=ensemble_probs).entropy().mean(dim=1)
     mutual_info = entropy_of_expected - expected_entropy
 
-    mkl = compute_mkl(ensemble_mean_probs, ensemble_logprobs)
+    mkl = compute_mkl(ensemble_probs, ensemble_mean_probs, ensemble_logprobs)
+
+    # num_classes = ensemble_logits.size(-1)
 
     ensemble_precision = (num_classes - 1) / (2 * mkl.unsqueeze(1) + EPS)
 
@@ -55,226 +39,108 @@ def compute_ensemble_stats(ensemble_probs: torch.Tensor) -> dict:
     }
     return stats
 
-
-def entropy(probs: torch.Tensor, dim: int = -1) -> torch.Tensor:
-    """
-    Compute entropy in a predictive distribution
-
-    Args:
-        probs (Tensor): Predictive distributions
-        dim (int): Dimension representing the predictive probabilities for a single prediction
-
-    Returns:
-        entropy (Tensor): Entropy
-    """
+def entropy(probs, dim: int = -1):
     return -(probs * (probs + EPS).log()).sum(dim=dim)
 
 
-def compute_dirichlet_uncertainties(dirichlet_params: torch.Tensor,
-                                    precisions: torch.Tensor,
-                                    expected_dirichlet: torch.Tensor) -> tuple:
+def compute_dirichlet_uncertainties(dirichlet_params, precisions, expected_dirichlet):
     """
     Function which computes measures of uncertainty for Dirichlet model.
-
-    Args:
-        dirichlet_params (Tensor): Dirichlet concentration parameters.
-        precisions (Tensor): Dirichlet Precisions
-        expected_dirichlet (Tensor): Probabities of expected categorical under Dirichlet.
-
-    Returns:
-        stats (tuple): Token level uncertainties
+    :param dirichlet_params:  Tensor of size [batch_size, n_classes] of Dirichlet concentration parameters.
+    :param precisions: Tensor of size [batch_size, 1] of Dirichlet Precisions
+    :param expected_dirichlet: Tensor of size [batch_size, n_classes] of probabilities of expected categorical under Dirichlet.
+    :return: Tensors of token level uncertainties of size [batch_size]
     """
     batch_size, n_classes = dirichlet_params.size()
 
     entropy_of_expected = entropy(expected_dirichlet)
 
-    expected_entropy = -expected_dirichlet * (torch.digamma(dirichlet_params + 1) - torch.digamma(precisions + 1))
-    expected_entropy = expected_entropy.sum(dim=-1)
+    expected_entropy = (
+            -expected_dirichlet * (torch.digamma(dirichlet_params + 1) - torch.digamma(precisions + 1))).sum(dim=-1)
 
-    mutual_information = torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + 1 + EPS)
-    mutual_information += torch.digamma(precisions + 1 + EPS)
-    mutual_information *= -(expected_dirichlet + EPS)
-    mutual_information = mutual_information.sum(dim=-1)
+    mutual_information = -((expected_dirichlet + EPS) * (
+            torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + 1 + EPS) + torch.digamma(
+        precisions + 1 + EPS))).sum(dim=-1)
+    # assert torch.allclose(mutual_information, entropy_of_expected - expected_entropy, atol=1e-4, rtol=0)
 
     epkl = (n_classes - 1) / precisions.squeeze(-1)
 
-    mkl = torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + EPS)
-    mkl += torch.digamma(precisions + EPS)
-    mkl *= expected_dirichlet
-    mkl = mkl.sum(dim=-1)
-
-    stats = (entropy_of_expected.clamp(min=0), expected_entropy.clamp(min=0), mutual_information.clamp(min=0))
-    stats += (epkl.clamp(min=0), mkl.clamp(min=0))
-
-    return stats
-
-
-def get_dirichlet_parameters(logits: torch.Tensor,
-                             parametrization,
-                             add_to_alphas: float = 0,
-                             dtype=torch.double) -> tuple:
-    """
-    Get dirichlet parameters from model logits
+    mkl = (expected_dirichlet * (
+            torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + EPS) + torch.digamma(
+        precisions + EPS))).sum(dim=-1)
 
-    Args:
-        logits (Tensor): Model logits
-        parametrization (function): Mapping from logits to concentration parameters
-        add_to_alphas (float): Addition constant for stability
-        dtype (data type): Data type of the parameters
+    return entropy_of_expected.clamp(min=0), \
+           expected_entropy.clamp(min=0), \
+           mutual_information.clamp(min=0), \
+           epkl.clamp(min=0), \
+           mkl.clamp(min=0)
 
-    Return:
-        params (tuple): Concentration and precision parameters of the model Dirichlet
-    """
+def get_dirichlet_parameters(logits, parametrization, add_to_alphas=0, dtype=torch.double):
     max_val = torch.finfo(dtype).max / logits.size(-1) - 1
     alphas = torch.clip(parametrization(logits.to(dtype=dtype)) + add_to_alphas, max=max_val)
     precision = torch.sum(alphas, dim=-1, dtype=dtype)
     return alphas, precision
 
 
-def logits_to_mutual_info(logits: torch.Tensor) -> torch.Tensor:
-    """
-    Map modfel logits to mutual information of model Dirichlet
-
-    Args:
-        logits (Tensor): Model logits
-
-    Returns:
-        mutual_information (Tensor): Mutual information of the model Dirichlet
-    """
+def logits_to_mutual_info(logits):
     alphas, precision = get_dirichlet_parameters(logits, torch.exp, 1.0)
 
-    normalized_probs = alphas / precision.unsqueeze(1)
+    unsqueezed_precision = precision.unsqueeze(1)
+    normalized_probs = alphas / unsqueezed_precision
 
-    _, _, mutual_information, _, _ = compute_dirichlet_uncertainties(alphas, precision.unsqueeze(1), normalized_probs)
+    entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = compute_dirichlet_uncertainties(alphas,
+                                                                                                           unsqueezed_precision,
+                                                                                                           normalized_probs)
+    
+    # Max entropy is log(K) for K classes. Hence relative MI is calculated as MI/log(K)
+    # mutual_information /= torch.log(torch.tensor(logits.size(-1)))
     
     return mutual_information
 
 
-class RKLDirichletMediatorLoss(Module):
-    """Reverse KL Dirichlet Mediator Loss (https://arxiv.org/abs/2105.06987)"""
-
-    def __init__(self,
-                 model_offset: float = 1.0,
-                 target_offset: float = 1,
-                 ignore_index: int = -1,
-                 parameterization=torch.exp):
-        """
-        Args:
-            model_offset (float): Offset of model Dirichlet for stability
-            target_offset (float): Offset of target Dirichlet for stability
-            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
-            parameterization (function): Mapping from logits to concentration parameters
-        """
-        super(RKLDirichletMediatorLoss, self).__init__()
-
-        self.model_offset = model_offset
-        self.target_offset = target_offset
-        self.ignore_index = ignore_index
-        self.parameterization = parameterization
-
-    def logits_to_mutual_info(self, logits: torch.Tensor) -> torch.Tensor:
-        """
-        Map modfel logits to mutual information of model Dirichlet
-
-        Args:
-            logits (Tensor): Model logits
-
-        Returns:
-            mutual_information (Tensor): Mutual information of the model Dirichlet
-        """
-        return logits_to_mutual_info(logits)
-
-    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            logits (Tensor): Model logits
-            targets (Tensor): Ensemble predictive distributions
-
-        Returns:
-            loss (Tensor): RKL dirichlet mediator loss value
-        """
-
-        # Remove padding
-        turns = torch.where(targets[:, 0, 0] != self.ignore_index)[0]
-        logits = logits[turns]
-        targets = targets[turns]
-
-        ensemble_stats = compute_ensemble_stats(targets)
-
-        alphas, precision = get_dirichlet_parameters(logits, self.parametrization, self.model_offset)
-
-        normalized_probs = alphas / precision.unsqueeze(1)
-
-        stats = compute_dirichlet_uncertainties(alphas, precision.unsqueeze(1), normalized_probs)
-        entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = stats
-
-        stats = {
-            'alpha_min': alphas.min(),
-            'alpha_mean': alphas.mean(),
-            'precision': precision,
-            'entropy_of_expected': entropy_of_expected,
-            'mutual_info': mutual_information,
-            'mkl': mkl,
-        }
-
-        num_classes = alphas.size(-1)
-
-        ensemble_precision = ensemble_stats['precision']
-
-        ensemble_precision += self.target_offset * num_classes
-        ensemble_probs = ensemble_stats['mean_probs']
-
-        expected_kl_term = torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS)
-        expected_kl_term = -1.0 * torch.sum(ensemble_probs * expected_kl_term, dim=-1)
-        assert torch.isfinite(expected_kl_term).all(), (torch.max(alphas), torch.max(precision), alphas.dtype)
-
-        differential_negentropy_term_ = torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS)
-        differential_negentropy_term_ *= alphas - 1.0
-        differential_negentropy_term = torch.sum(torch.lgamma(alphas + EPS), dim=-1) - torch.lgamma(precision + EPS)
-        differential_negentropy_term -= torch.sum(differential_negentropy_term_, dim=-1)
-        assert torch.isfinite(differential_negentropy_term).all()
-
-        loss = expected_kl_term - differential_negentropy_term / ensemble_precision.squeeze(-1)
-        assert torch.isfinite(loss).all()
-
-        return torch.mean(loss), stats, ensemble_stats
-
-
-class BinaryRKLDirichletMediatorLoss(RKLDirichletMediatorLoss):
-    """Reverse KL Dirichlet Mediator Loss (https://arxiv.org/abs/2105.06987)"""
-
-    def __init__(self,
-                 model_offset: float = 1.0,
-                 target_offset: float = 1,
-                 ignore_index: int = -1,
-                 parameterization=torch.exp):
-        """
-        Args:
-            model_offset (float): Offset of model Dirichlet for stability
-            target_offset (float): Offset of target Dirichlet for stability
-            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
-            parameterization (function): Mapping from logits to concentration parameters
-        """
-        super(BinaryRKLDirichletMediatorLoss, self).__init__(model_offset, target_offset,
-                                                             ignore_index, parameterization)
-
-    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            logits (Tensor): Model logits
-            targets (Tensor): Ensemble predictive distributions
-
-        Returns:
-            loss (Tensor): RKL dirichlet mediator loss value
-        """
-        # Convert single target probability p to distribution [1-p, p]
-        targets = targets.reshape(-1, targets.size(-1), 1)
-        targets = torch.cat([1 - targets, targets], -1)
-        targets[targets[:, 1] == self.ignore_index] = self.ignore_index
-
-        # Convert input logits into predictive distribution [1-z, z]
-        logits = torch.sigmoid(logits).unsqueeze(1)
-        logits = torch.cat((1 - logits, logits), 1)
-        logits = -1.0 * torch.log((1 / (logits + 1e-8)) - 1)  # Inverse sigmoid
-
-        return super().forward(logits, targets)
+def rkl_dirichlet_mediator_loss(logits, ensemble_stats, model_offset, target_offset, parametrization=torch.exp):
+    turns = torch.where(ensemble_stats[:, 0, 0] != -1)[0]
+    logits = logits[turns]
+    ensemble_stats = ensemble_stats[turns]
+    
+    ensemble_stats = compute_ensemble_stats(ensemble_stats)
+
+    alphas, precision = get_dirichlet_parameters(logits, parametrization, model_offset)
+
+    unsqueezed_precision = precision.unsqueeze(1)
+    normalized_probs = alphas / unsqueezed_precision
+
+    entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = compute_dirichlet_uncertainties(alphas,
+                                                                                                           unsqueezed_precision,
+                                                                                                           normalized_probs)
+
+    stats = {
+        'alpha_min': alphas.min(),
+        'alpha_mean': alphas.mean(),
+        'precision': precision,
+        'entropy_of_expected': entropy_of_expected,
+        'mutual_info': mutual_information,
+        'mkl': mkl,
+    }
+
+    num_classes = alphas.size(-1)
+
+    ensemble_precision = ensemble_stats['precision']
+
+    ensemble_precision += target_offset * num_classes
+    ensemble_probs = ensemble_stats['mean_probs']
+
+    expected_KL_term = -1.0 * torch.sum(ensemble_probs * (torch.digamma(alphas + EPS)
+                                                          - torch.digamma(precision.unsqueeze(-1) + EPS)), dim=-1)
+    assert torch.isfinite(expected_KL_term).all(), (torch.max(alphas), torch.max(precision), alphas.dtype)
+
+    differential_negentropy_term = torch.sum(torch.lgamma(alphas + EPS), dim=-1) - torch.lgamma(precision + EPS) \
+                                   - torch.sum(
+        (alphas - 1) * (torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS)), dim=-1)
+    assert torch.isfinite(differential_negentropy_term).all()
+
+    cost = expected_KL_term - differential_negentropy_term / ensemble_precision.squeeze(-1)
+
+    assert torch.isfinite(cost).all()
+    return torch.mean(cost), stats, ensemble_stats
+
diff --git a/convlab/dst/setsumbt/loss/kl_distillation.py b/convlab/dst/setsumbt/loss/kl_distillation.py
deleted file mode 100644
index 9aee234ab68054f2b4a83d6feb5e453384d89e94..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/loss/kl_distillation.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""KL Divergence Ensemble Distillation loss"""
-
-import torch
-from torch.nn import Module
-from torch.nn.functional import kl_div
-
-
-class KLDistillationLoss(Module):
-    """Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation"""
-
-    def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module:
-        """
-        Args:
-            lamb (float): Target smoothing parameter
-            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
-        """
-        super(KLDistillationLoss, self).__init__()
-
-        self.lamb = lamb
-        self.ignore_index = ignore_index
-    
-    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
-        """
-        Args:
-            inputs (Tensor): Predictive distribution
-            targets (Tensor): Target distribution (ensemble marginal)
-            temp (float): Temperature scaling coefficient for predictive distribution
-
-        Returns:
-            loss (Tensor): Loss value
-        """
-        # Assert input sizes
-        assert inputs.dim() == 2                  # Observations, predictive distribution
-        assert targets.dim() == 2                # Label for each observation
-        assert targets.size(0) == inputs.size(0)  # Equal number of observation
-
-        # Confirm predictive distribution dimension
-        if targets.size(-1) != inputs.size(-1):
-            name_error = f'Target dimension {targets.size(-1)} is not the same as the prediction dimension '
-            name_error += f'{inputs.size(-1)}.'
-            raise NameError(name_error)
-
-        # Remove observations to be ignored in loss calculation
-        inputs = torch.log(torch.softmax(inputs / temp, -1))
-        ids = torch.where(targets[:, 0] != self.ignore_index)[0]
-        inputs = inputs[ids]
-        targets = targets[ids]
-
-        # Target smoothing
-        targets = ((1 - self.lamb) * targets) + (self.lamb / targets.size(-1))
-
-        return kl_div(inputs, targets, reduction='none').sum(-1).mean()
-
-
-# Pytorch BayesianMatchingLoss nn.Module
-class BinaryKLDistillationLoss(KLDistillationLoss):
-    """Binary Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation"""
-
-    def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module:
-        """
-        Args:
-            lamb (float): Target smoothing parameter
-            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
-        """
-        super(BinaryKLDistillationLoss, self).__init__(lamb, ignore_index)
-
-    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
-        """
-        Args:
-            inputs (Tensor): Predictive distribution
-            targets (Tensor): Target distribution (ensemble marginal)
-            temp (float): Temperature scaling coefficient for predictive distribution
-
-        Returns:
-            loss (Tensor): Loss value
-        """
-        # Assert input sizes
-        assert inputs.dim() == 1                 # Observations, predictive distribution
-        assert targets.dim() == 1                # Label for each observation
-        assert targets.size(0) == inputs.size(0)  # Equal number of observation
-        
-        # Convert input and target to 2D binary distribution for KL divergence computation
-        inputs = torch.sigmoid(inputs / temp).unsqueeze(-1)
-        inputs = torch.log(torch.cat((1 - inputs, inputs), 1))
-
-        targets = targets.unsqueeze(-1)
-        targets = torch.cat((1 - targets, targets), -1)
-
-        return super().forward(input, targets, temp)
diff --git a/convlab/dst/setsumbt/loss/labelsmoothing.py b/convlab/dst/setsumbt/loss/labelsmoothing.py
index 70cef73170c0f539c3eeab888c50ae571b36c882..8fcc60afd50603cb5c2c84fd698fd11ce7fb7415 100644
--- a/convlab/dst/setsumbt/loss/labelsmoothing.py
+++ b/convlab/dst/setsumbt/loss/labelsmoothing.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Label smoothing loss function"""
+"""Label smoothing loss functions"""
 
 
 import torch
@@ -23,87 +23,66 @@ from torch.nn.functional import kl_div
 
 class LabelSmoothingLoss(Module):
     """
-    Label smoothing loss minimises the KL-divergence between q_{smoothed ground truth prob}(w)
-    and p_{prob. computed by model}(w).
+    With label smoothing,
+    KL-divergence between q_{smoothed ground truth prob.}(w)
+    and p_{prob. computed by model}(w) is minimized.
     """
-
-    def __init__(self, label_smoothing: float = 0.05, ignore_index: int = -1) -> Module:
-        """
-        Args:
-            label_smoothing (float): Label smoothing constant
-            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
-        """
+    def __init__(self, label_smoothing=0.05, ignore_index=-1):
         super(LabelSmoothingLoss, self).__init__()
 
         assert 0.0 < label_smoothing <= 1.0
         self.ignore_index = ignore_index
         self.label_smoothing = float(label_smoothing)
 
-    def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    def forward(self, logits, targets):
         """
-        Args:
-            input (Tensor): Predictive distribution
-            labels (Tensor): Label indices
-
-        Returns:
-            loss (Tensor): Loss value
+        logits (FloatTensor): batch_size x n_classes
+        targets (LongTensor): batch_size
         """
-        # Assert input sizes
-        assert inputs.dim() == 2
-        assert labels.dim() == 1
-        assert self.label_smoothing <= ((inputs.size(-1) - 1) / inputs.size(-1))
-
-        # Confirm predictive distribution dimension
-        if labels.max() <= inputs.size(-1):
-            dimension = inputs.size(-1)
-        else:
-            raise NameError(f'Label dimension {labels.max()} is larger than prediction dimension {inputs.size(-1)}.')
+        assert logits.dim() == 2
+        assert targets.dim() == 1
+        assert self.label_smoothing <= ((logits.size(-1) - 1) / logits.size(-1))
 
-        # Remove observations to be ignored in loss calculation
-        inputs = inputs[labels != self.ignore_index]
-        labels = labels[labels != self.ignore_index]
+        logits = logits[targets != self.ignore_index]
+        targets = targets[targets != self.ignore_index]
 
-        # Create target distribution
-        inputs = torch.log(torch.softmax(inputs, -1))
-        targets = torch.ones(inputs.size()).float().to(inputs.device)
-        targets *= self.label_smoothing / (dimension - 1)
-        targets[range(labels.size(0)), labels] = 1.0 - self.label_smoothing
+        logits = torch.log(torch.softmax(logits, -1))
+        labels = torch.ones(logits.size()).float().to(logits.device)
+        labels *= self.label_smoothing / (logits.size(-1) - 1)
+        labels[range(labels.size(0)), targets] = 1.0 - self.label_smoothing
 
-        return kl_div(inputs, targets, reduction='none').sum(-1).mean()
+        kl = kl_div(logits, labels, reduction='none').sum(-1).mean()
+        del logits, targets, labels
+        return kl
 
 
-class BinaryLabelSmoothingLoss(LabelSmoothingLoss):
+class BinaryLabelSmoothingLoss(Module):
     """
-    Label smoothing loss minimises the KL-divergence between q_{smoothed ground truth prob}(w)
-    and p_{prob. computed by model}(w).
+    With label smoothing,
+    KL-divergence between q_{smoothed ground truth prob.}(w)
+    and p_{prob. computed by model}(w) is minimized.
     """
+    def __init__(self, label_smoothing=0.05):
+        super(BinaryLabelSmoothingLoss, self).__init__()
 
-    def __init__(self, label_smoothing: float = 0.05, ignore_index: int = -1) -> Module:
-        """
-        Args:
-            label_smoothing (float): Label smoothing constant
-            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
-        """
-        super(BinaryLabelSmoothingLoss, self).__init__(label_smoothing, ignore_index)
+        assert 0.0 < label_smoothing <= 1.0
+        self.label_smoothing = float(label_smoothing)
 
-    def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    def forward(self, logits, targets):
         """
-        Args:
-            input (Tensor): Predictive distribution
-            labels (Tensor): Label indices
-
-        Returns:
-            loss (Tensor): Loss value
+        logits (FloatTensor): batch_size
+        targets (Tensor): batch_size, class index 0 or 1 (cast to long)
         """
-        # Assert input sizes
-        assert inputs.dim() == 1
-        assert labels.dim() == 1
+        assert logits.dim() == 1
+        assert targets.dim() == 1
         assert self.label_smoothing <= 0.5
 
-        inputs = torch.sigmoid(inputs).reshape(-1, 1)
-        inputs = torch.log(torch.cat((1 - inputs, inputs), 1))
-        targets = torch.ones(inputs.size()).float().to(inputs.device)
-        targets *= self.label_smoothing
-        targets[range(labels.size(0)), labels.long()] = 1.0 - self.label_smoothing
+        logits = torch.sigmoid(logits).reshape(-1, 1)
+        logits = torch.log(torch.cat((1 - logits, logits), 1))
+        labels = torch.ones(logits.size()).float().to(logits.device)
+        labels *= self.label_smoothing
+        labels[range(labels.size(0)), targets.long()] = 1.0 - self.label_smoothing
 
-        return kl_div(inputs, targets, reduction='none').sum(-1).mean()
+        kl = kl_div(logits, labels, reduction='none').sum(-1).mean()
+        del logits, targets
+        return kl
diff --git a/convlab/dst/setsumbt/modeling/__init__.py b/convlab/dst/setsumbt/modeling/__init__.py
index 59f1439948421ac365e4602b7800c94d3b8b32dd..011a1a774e2d1a22e46e242d1812549895f2246b 100644
--- a/convlab/dst/setsumbt/modeling/__init__.py
+++ b/convlab/dst/setsumbt/modeling/__init__.py
@@ -1,5 +1,3 @@
 from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
 from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
-from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT
-
-from convlab.dst.setsumbt.modeling.temperature_scheduler import LinearTemperatureScheduler
+from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT, DropoutEnsembleSetSUMBT
diff --git a/convlab/dst/setsumbt/modeling/bert_nbt.py b/convlab/dst/setsumbt/modeling/bert_nbt.py
index 6762fb3891b4720c3889d8c0809b8791f3bf7633..8b402b6be09684b27bb73acf17e578bc0e3b4bbd 100644
--- a/convlab/dst/setsumbt/modeling/bert_nbt.py
+++ b/convlab/dst/setsumbt/modeling/bert_nbt.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,10 +16,11 @@
 """BERT SetSUMBT"""
 
 import torch
+import transformers
 from torch.autograd import Variable
 from transformers import BertModel, BertPreTrainedModel
 
-from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead
+from convlab.dst.setsumbt.modeling.functional import _initialise, _nbt_forward
 
 
 class BertSetSUMBT(BertPreTrainedModel):
@@ -34,37 +35,59 @@ class BertSetSUMBT(BertPreTrainedModel):
             for p in self.bert.parameters():
                 p.requires_grad = False
 
-        self.setsumbt = SetSUMBTHead(config)
-        self.add_slot_candidates = self.setsumbt.add_slot_candidates
-        self.add_value_candidates = self.setsumbt.add_value_candidates
-
-    def forward(self,
-                input_ids: torch.Tensor,
-                attention_mask: torch.Tensor,
-                token_type_ids: torch.Tensor = None,
-                hidden_state: torch.Tensor = None,
-                state_labels: torch.Tensor = None,
-                request_labels: torch.Tensor = None,
-                active_domain_labels: torch.Tensor = None,
-                general_act_labels: torch.Tensor = None,
-                get_turn_pooled_representation: bool = False,
-                calculate_state_mutual_info: bool = False):
-        """
-        Args:
-            input_ids: Input token ids
-            attention_mask: Input padding mask
-            token_type_ids: Token type indicator
-            hidden_state: Latent internal dialogue belief state
-            state_labels: Dialogue state labels
-            request_labels: User request action labels
-            active_domain_labels: Current active domain labels
-            general_act_labels: General user action labels
-            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
-            calculate_state_mutual_info: Return mutual information in the dialogue state
-
-        Returns:
-            out: Tuple containing loss, predictive distributions, model statistics and state mutual information
-        """
+        _initialise(self, config)
+
+    # Add new slot candidates to the model
+    def add_slot_candidates(self, slot_candidates):
+        """slot_candidates maps each slot name to a tuple.
+        - Each tuple contains the slot embedding, informable value embeddings and a request indicator.
+        - If the informable value embeddings are None the slot is not informable
+        - If the request indicator is false the slot is not requestable"""
+        if self.slot_embeddings.size(0) != 0:
+            embeddings = self.slot_embeddings.detach()
+        else:
+            embeddings = torch.zeros(0)
+
+        for slot in slot_candidates:
+            if slot in self.slot_ids:
+                index = self.slot_ids[slot]
+                embeddings[index, :] = slot_candidates[slot][0]
+            else:
+                index = embeddings.size(0)
+                emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device)
+                embeddings = torch.cat((embeddings, emb), 0)
+                self.slot_ids[slot] = index
+                setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False))
+            # Add slot to relevant requestable and informable slot lists
+            if slot_candidates[slot][2]:
+                self.requestable_slot_ids[slot] = index
+            if slot_candidates[slot][1] is not None:
+                self.informable_slot_ids[slot] = index
+            
+            domain = slot.split('-', 1)[0]
+            if domain not in self.domain_ids:
+                self.domain_ids[domain] = []
+            self.domain_ids[domain].append(index)
+            self.domain_ids[domain] = list(set(self.domain_ids[domain]))
+        
+        self.slot_embeddings = Variable(embeddings, requires_grad=False)
+
+
+    # Add new value candidates to the model
+    def add_value_candidates(self, slot, value_candidates, replace=False):
+        embeddings = getattr(self, slot + '_value_embeddings')
+
+        if embeddings.size(0) == 0 or replace:
+            embeddings = value_candidates
+        else:
+            embeddings = torch.cat((embeddings, value_candidates), 0)
+        
+        setattr(self, slot + '_value_embeddings', embeddings)
+
+    
+    def forward(self, input_ids, token_type_ids, attention_mask, hidden_state=None, inform_labels=None,
+                request_labels=None, domain_labels=None, goodbye_labels=None,
+                get_turn_pooled_representation=False, calculate_inform_mutual_info=False):
 
         # Encode Dialogues
         batch_size, dialogue_size, turn_size = input_ids.size()
@@ -80,10 +103,9 @@ class BertSetSUMBT(BertPreTrainedModel):
         turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
 
         if get_turn_pooled_representation:
-            return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask,
-                                 batch_size, dialogue_size, hidden_state, state_labels,
-                                 request_labels, active_domain_labels, general_act_labels,
-                                 calculate_state_mutual_info) + (bert_output.pooler_output,)
-        return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask, batch_size,
-                             dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
-                             general_act_labels, calculate_state_mutual_info)
+            return _nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size,
+                                dialogue_size, turn_size, hidden_state, inform_labels, request_labels, domain_labels,
+                                goodbye_labels, calculate_inform_mutual_info) + (bert_output.pooler_output,)
+        return _nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size, dialogue_size,
+                            turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels,
+                            calculate_inform_mutual_info)
diff --git a/convlab/dst/setsumbt/modeling/calibration_utils.py b/convlab/dst/setsumbt/modeling/calibration_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8514ac8d259162c5bcc55607bb8356de6d4b47c7
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/calibration_utils.py
@@ -0,0 +1,134 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Discriminative models calibration"""
+
+import random
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+
+# Load logger and tensorboard summary writer
+def set_logger(logger_, tb_writer_):
+    global logger, tb_writer
+    logger = logger_
+    tb_writer = tb_writer_
+
+
+# Set seeds
+def set_seed(args):
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+    logger.info('Seed set to %d.' % args.seed)
+
+
+def get_predictions(args, model, device, dataloader):
+    logger.info("  Num Batches = %d", len(dataloader))
+
+    model.eval()
+    if args.dropout_iterations > 1:
+        model.train()
+    
+    belief_states = {slot: [] for slot in model.informable_slot_ids}
+    request_belief = {slot: [] for slot in model.requestable_slot_ids}
+    domain_belief = {dom: [] for dom in model.domain_ids}
+    greeting_belief = []
+    labels = {slot: [] for slot in model.informable_slot_ids}
+    request_labels = {slot: [] for slot in model.requestable_slot_ids}
+    domain_labels = {dom: [] for dom in model.domain_ids}
+    greeting_labels = []
+    epoch_iterator = tqdm(dataloader, desc="Iteration")
+    for step, batch in enumerate(epoch_iterator):
+        with torch.no_grad():    
+            input_ids = batch['input_ids'].to(device)
+            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
+            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
+
+            if args.dropout_iterations > 1:
+                p = {slot: [] for slot in model.informable_slot_ids}
+                for _ in range(args.dropout_iterations):
+                    p_, p_req_, p_dom_, p_bye_, _ = model(input_ids=input_ids,
+                                                        token_type_ids=token_type_ids,
+                                                        attention_mask=attention_mask)
+                    for slot in model.informable_slot_ids:
+                        p[slot].append(p_[slot].unsqueeze(0))
+                
+                mu = {slot: torch.cat(p[slot], 0).mean(0) for slot in model.informable_slot_ids}
+                sig = {slot: torch.cat(p[slot], 0).var(0) for slot in model.informable_slot_ids}
+                p = {slot: mu[slot] / torch.sqrt(1 + sig[slot]) for slot in model.informable_slot_ids}
+                p = {slot: normalise(p[slot]) for slot in model.informable_slot_ids}
+            else:
+                p, p_req, p_dom, p_bye, _ = model(input_ids=input_ids,
+                                                token_type_ids=token_type_ids,
+                                                attention_mask=attention_mask)
+            
+            for slot in model.informable_slot_ids:
+                p_ = p[slot]
+                labs = batch['labels-' + slot].to(device)
+                
+                belief_states[slot].append(p_)
+                labels[slot].append(labs)
+            
+            if p_req is not None:
+                for slot in model.requestable_slot_ids:
+                    p_ = p_req[slot]
+                    labs = batch['request-' + slot].to(device)
+
+                    request_belief[slot].append(p_)
+                    request_labels[slot].append(labs)
+                
+                for domain in model.domain_ids:
+                    p_ = p_dom[domain]
+                    labs = batch['active-' + domain].to(device)
+
+                    domain_belief[domain].append(p_)
+                    domain_labels[domain].append(labs)
+                
+                greeting_belief.append(p_bye)
+                greeting_labels.append(batch['goodbye'].to(device))
+    
+    for slot in belief_states:
+        belief_states[slot] = torch.cat(belief_states[slot], 0)
+        labels[slot] = torch.cat(labels[slot], 0)
+    if p_req is not None:
+        for slot in request_belief:
+            request_belief[slot] = torch.cat(request_belief[slot], 0)
+            request_labels[slot] = torch.cat(request_labels[slot], 0)
+        for domain in domain_belief:
+            domain_belief[domain] = torch.cat(domain_belief[domain], 0)
+            domain_labels[domain] = torch.cat(domain_labels[domain], 0)
+        greeting_belief = torch.cat(greeting_belief, 0)
+        greeting_labels = torch.cat(greeting_labels, 0)
+    else:
+        request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = [None]*6
+
+    return belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels
+
+
+def normalise(p):
+    p_shape = p.size()
+
+    p = p.reshape(-1, p_shape[-1]) + 1e-10
+    p_sum = p.sum(-1).unsqueeze(1).repeat((1, p_shape[-1]))
+    p /= p_sum
+
+    p = p.reshape(p_shape)
+
+    return p
diff --git a/convlab/dst/setsumbt/modeling/ensemble_nbt.py b/convlab/dst/setsumbt/modeling/ensemble_nbt.py
index 0ae03a60b1a754c7305028158adece48d71fee03..9f101d128c6b8c9093c4959834cf5aac35b322a2 100644
--- a/convlab/dst/setsumbt/modeling/ensemble_nbt.py
+++ b/convlab/dst/setsumbt/modeling/ensemble_nbt.py
@@ -18,6 +18,7 @@
 import os
 
 import torch
+import transformers
 from torch.nn import Module
 from transformers import RobertaConfig, BertConfig
 
@@ -28,13 +29,8 @@ MODELS = {'bert': BertSetSUMBT, 'roberta': RobertaSetSUMBT}
 
 
 class EnsembleSetSUMBT(Module):
-    """Ensemble SetSUMBT Model for joint ensemble prediction"""
 
     def __init__(self, config):
-        """
-        Args:
-            config (configuration): Model configuration class
-        """
         super(EnsembleSetSUMBT, self).__init__()
         self.config = config
 
@@ -42,98 +38,75 @@ class EnsembleSetSUMBT(Module):
         model_cls = MODELS[self.config.model_type]
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             setattr(self, attr, model_cls(config))
+    
 
-    def _load(self, path: str):
-        """
-        Load parameters
-        Args:
-            path: Location of model parameters
-        """
+    # Load all ensemble member parameters
+    def load(self, path, config=None):
+        if config is None:
+            config = self.config
+        
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             idx = attr.split('_', 1)[-1]
             state_dict = torch.load(os.path.join(path, f'pytorch_model_{idx}.bin'))
             getattr(self, attr).load_state_dict(state_dict)
+    
 
-    def add_slot_candidates(self, slot_candidates: tuple):
-        """
-        Add slots to the model ontology, the tuples should contain the slot embedding, informable value embeddings
-        and a request indicator, if the informable value embeddings is None the slot is not informable and if
-        the request indicator is false the slot is not requestable.
-
-        Args:
-            slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator
-        """
+    # Add new slot candidates to the ensemble members
+    def add_slot_candidates(self, slot_candidates):
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             getattr(self, attr).add_slot_candidates(slot_candidates)
-        self.requestable_slot_ids = self.model_0.setsumbt.requestable_slot_ids
-        self.informable_slot_ids = self.model_0.setsumbt.informable_slot_ids
-        self.domain_ids = self.setsumbt.model_0.domain_ids
-
-    def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
-        """
-        Add value candidates for a slot
-
-        Args:
-            slot: Slot name
-            value_candidates: Value candidate embeddings
-            replace: If true existing value candidates are replaced
-        """
+        self.requestable_slot_ids = self.model_0.requestable_slot_ids
+        self.informable_slot_ids = self.model_0.informable_slot_ids
+        self.domain_ids = self.model_0.domain_ids
+
+
+    # Add new value candidates to the ensemble members
+    def add_value_candidates(self, slot, value_candidates, replace=False):
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             getattr(self, attr).add_value_candidates(slot, value_candidates, replace)
+        
 
-    def forward(self,
-                input_ids: torch.Tensor,
-                attention_mask: torch.Tensor,
-                token_type_ids: torch.Tensor = None,
-                reduction: str = 'mean') -> tuple:
-        """
-        Args:
-            input_ids: Input token ids
-            attention_mask: Input padding mask
-            token_type_ids: Token type indicator
-            reduction: Reduction of ensemble member predictive distributions (mean, none)
-
-        Returns:
-
-        """
-        belief_state_probs = {slot: [] for slot in self.model_0.informable_slot_ids}
-        request_probs = {slot: [] for slot in self.model_0.requestable_slot_ids}
-        active_domain_probs = {dom: [] for dom in self.model_0.domain_ids}
-        general_act_probs = []
+    # Forward pass of full ensemble
+    def forward(self, input_ids, attention_mask, token_type_ids=None, reduction='mean'):
+        logits, request_logits, domain_logits, goodbye_scores = [], [], [], []
+        logits = {slot: [] for slot in self.model_0.informable_slot_ids}
+        request_logits = {slot: [] for slot in self.model_0.requestable_slot_ids}
+        domain_logits = {dom: [] for dom in self.model_0.domain_ids}
+        goodbye_scores = []
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             # Prediction from each ensemble member
-            b, r, d, g, _ = getattr(self, attr)(input_ids=input_ids,
+            l, r, d, g, _ = getattr(self, attr)(input_ids=input_ids,
                                                 token_type_ids=token_type_ids,
                                                 attention_mask=attention_mask)
-            for slot in belief_state_probs:
-                belief_state_probs[slot].append(b[slot].unsqueeze(-2))
+            for slot in logits:
+                logits[slot].append(l[slot].unsqueeze(-2))
             if self.config.predict_intents:
-                for slot in request_probs:
-                    request_probs[slot].append(r[slot].unsqueeze(-1))
-                for dom in active_domain_probs:
-                    active_domain_probs[dom].append(d[dom].unsqueeze(-1))
-                general_act_probs.append(g.unsqueeze(-2))
+                for slot in request_logits:
+                    request_logits[slot].append(r[slot].unsqueeze(-1))
+                for dom in domain_logits:
+                    domain_logits[dom].append(d[dom].unsqueeze(-1))
+                goodbye_scores.append(g.unsqueeze(-2))
         
-        belief_state_probs = {slot: torch.cat(l, -2) for slot, l in belief_state_probs.items()}
+        logits = {slot: torch.cat(l, -2) for slot, l in logits.items()}
         if self.config.predict_intents:
-            request_probs = {slot: torch.cat(l, -1) for slot, l in request_probs.items()}
-            active_domain_probs = {dom: torch.cat(l, -1) for dom, l in active_domain_probs.items()}
-            general_act_probs = torch.cat(general_act_probs, -2)
+            request_logits = {slot: torch.cat(l, -1) for slot, l in request_logits.items()}
+            domain_logits = {dom: torch.cat(l, -1) for dom, l in domain_logits.items()}
+            goodbye_scores = torch.cat(goodbye_scores, -2)
         else:
-            request_probs = {}
-            active_domain_probs = {}
-            general_act_probs = torch.tensor(0.0)
+            request_logits = {}
+            domain_logits = {}
+            goodbye_scores = torch.tensor(0.0)
 
         # Apply reduction of ensemble to single posterior
         if reduction == 'mean':
-            belief_state_probs = {slot: l.mean(-2) for slot, l in belief_state_probs.items()}
-            request_probs = {slot: l.mean(-1) for slot, l in request_probs.items()}
-            active_domain_probs = {dom: l.mean(-1) for dom, l in active_domain_probs.items()}
-            general_act_probs = general_act_probs.mean(-2)
+            logits = {slot: l.mean(-2) for slot, l in logits.items()}
+            request_logits = {slot: l.mean(-1) for slot, l in request_logits.items()}
+            domain_logits = {dom: l.mean(-1) for dom, l in domain_logits.items()}
+            goodbye_scores = goodbye_scores.mean(-2)
         elif reduction != 'none':
             raise(NameError('Not Implemented!'))
 
-        return belief_state_probs, request_probs, active_domain_probs, general_act_probs, _
+        return logits, request_logits, domain_logits, goodbye_scores, _
     
 
     @classmethod
@@ -152,3 +125,88 @@ class EnsembleSetSUMBT(Module):
         model.load(path)
 
         return model
+
+
+class DropoutEnsembleSetSUMBT(Module):
+    """Dropout ensemble: one wrapped model evaluated ``config.ensemble_size`` times per input in train mode."""
+
+    def __init__(self, config):
+        """Build the model class selected by ``config.model_type`` and put it in train mode."""
+        super(DropoutEnsembleSetSUMBT, self).__init__()
+        self.config = config
+
+        model_cls = MODELS[self.config.model_type]
+        self.model = model_cls(config)
+        self.model.train()
+
+    def load(self, path, config=None):
+        """Load ``pytorch_model.bin`` from ``path``; ``config`` is accepted but never read afterwards."""
+        if config is None:
+            config = self.config
+        state_dict = torch.load(os.path.join(path, 'pytorch_model.bin'))
+        self.model.load_state_dict(state_dict)
+
+    def add_slot_candidates(self, slot_candidates):
+        """Add new slot candidates to the model and mirror its slot/domain id maps on this wrapper."""
+        self.model.add_slot_candidates(slot_candidates)
+        self.requestable_slot_ids = self.model.requestable_slot_ids
+        self.informable_slot_ids = self.model.informable_slot_ids
+        self.domain_ids = self.model.domain_ids
+
+    def add_value_candidates(self, slot, value_candidates, replace=False):
+        """Add new value candidates for ``slot`` to the model."""
+        self.model.add_value_candidates(slot, value_candidates, replace)
+
+    def forward(self, input_ids, attention_mask, token_type_ids=None, reduction='mean'):
+        """Stack ``ensemble_size`` copies of the batch, run one pass, then reduce ('mean') or not ('none')."""
+        # Every input must be tiled by the same factor or the stacked batch sizes disagree.
+        input_ids = input_ids.unsqueeze(0).repeat((self.config.ensemble_size, 1, 1, 1))
+        input_ids = input_ids.reshape(-1, input_ids.size(-2), input_ids.size(-1))
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(0).repeat((self.config.ensemble_size, 1, 1, 1))
+            attention_mask = attention_mask.reshape(-1, attention_mask.size(-2), attention_mask.size(-1))
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.unsqueeze(0).repeat((self.config.ensemble_size, 1, 1, 1))
+            token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-2), token_type_ids.size(-1))
+
+        # Train mode keeps dropout layers active, so the stacked copies can yield different samples.
+        self.model.train()
+        logits, request_logits, domain_logits, goodbye_scores, _ = self.model(input_ids=input_ids,
+                                                                              attention_mask=attention_mask,
+                                                                              token_type_ids=token_type_ids)
+
+        logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-2), l.size(-1)).transpose(0, 1).transpose(1, 2)
+                  for s, l in logits.items()}
+        request_logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-1)).transpose(0, 1).transpose(1, 2)
+                          for s, l in request_logits.items()}
+        domain_logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-1)).transpose(0, 1).transpose(1, 2)
+                         for s, l in domain_logits.items()}
+        goodbye_scores = goodbye_scores.reshape(self.config.ensemble_size, -1, goodbye_scores.size(-2), goodbye_scores.size(-1))
+        goodbye_scores = goodbye_scores.transpose(0, 1).transpose(1, 2)
+
+        if reduction == 'mean':
+            logits = {slot: l.mean(-2) for slot, l in logits.items()}
+            request_logits = {slot: l.mean(-1) for slot, l in request_logits.items()}
+            domain_logits = {dom: l.mean(-1) for dom, l in domain_logits.items()}
+            goodbye_scores = goodbye_scores.mean(-2)
+        elif reduction != 'none':
+            raise(NameError('Not Implemented!'))
+
+        return logits, request_logits, domain_logits, goodbye_scores, _
+
+    @classmethod
+    def from_pretrained(cls, path):
+        """Build from ``config.json`` in ``path`` (RoBERTa config first, BERT as fallback) and load the weights."""
+        if not os.path.exists(os.path.join(path, 'config.json')):
+            raise(NameError('Could not find config.json in model path.'))
+        if not os.path.exists(os.path.join(path, 'pytorch_model.bin')):
+            raise(NameError('Could not find a model binary in the model path.'))
+
+        try:
+            config = RobertaConfig.from_pretrained(path)
+        except Exception:
+            config = BertConfig.from_pretrained(path)
+
+        model = cls(config)
+        model.load(path)
+        return model
diff --git a/convlab/dst/setsumbt/modeling/ensemble_utils.py b/convlab/dst/setsumbt/modeling/ensemble_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..19f6abf81a4070b9498310adfab93d50f5a692f5
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/ensemble_utils.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Discriminative models calibration"""
+
+import random
+import os
+
+import torch
+import numpy as np
+from torch.distributions import Categorical
+from torch.nn.functional import kl_div
+from torch.nn import Module
+from tqdm import tqdm
+
+
+# Load logger and tensorboard summary writer
+def set_logger(logger_, tb_writer_):
+    """Store the logger and tensorboard summary writer as module-level globals."""
+    global logger, tb_writer
+    logger, tb_writer = logger_, tb_writer_
+
+
+# Set seeds
+def set_seed(args):
+    """Seed random, NumPy and torch (and CUDA when ``args.n_gpu > 0``) from ``args.seed``, then log it."""
+    for seeder in (random.seed, np.random.seed, torch.manual_seed):
+        seeder(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+    logger.info('Seed set to %d.' % args.seed)
+
+
+def build_train_loaders(args, tokenizer, dataset):
+    """Return ``args.ensemble_size`` train dataloaders, each from its own ``get_dataloader`` call."""
+    return [dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len,
+                                   args.max_turn_len, resampled_size=args.data_sampling_size)
+            for _ in range(args.ensemble_size)]
diff --git a/convlab/dst/setsumbt/modeling/evaluation_utils.py b/convlab/dst/setsumbt/modeling/evaluation_utils.py
deleted file mode 100644
index c73d4b6d32a485a2cf2b5948dbd6a9a4d7f346cb..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/modeling/evaluation_utils.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Evaluation Utilities"""
-
-import random
-
-import torch
-import numpy as np
-from tqdm import tqdm
-
-
-def set_seed(args):
-    """
-    Set random seeds
-
-    Args:
-        args (Arguments class): Arguments class containing seed and number of gpus to use
-    """
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-
-def get_predictions(args, model, device: torch.device, dataloader: torch.utils.data.DataLoader) -> tuple:
-    """
-    Get model predictions
-
-    Args:
-        args: Runtime arguments
-        model: SetSUMBT Model
-        device: Torch device
-        dataloader: Dataloader containing eval data
-    """
-    model.eval()
-    
-    belief_states = {slot: [] for slot in model.setsumbt.informable_slot_ids}
-    request_probs = {slot: [] for slot in model.setsumbt.requestable_slot_ids}
-    active_domain_probs = {dom: [] for dom in model.setsumbt.domain_ids}
-    general_act_probs = []
-    state_labels = {slot: [] for slot in model.setsumbt.informable_slot_ids}
-    request_labels = {slot: [] for slot in model.setsumbt.requestable_slot_ids}
-    active_domain_labels = {dom: [] for dom in model.setsumbt.domain_ids}
-    general_act_labels = []
-    epoch_iterator = tqdm(dataloader, desc="Iteration")
-    for step, batch in enumerate(epoch_iterator):
-        with torch.no_grad():    
-            input_ids = batch['input_ids'].to(device)
-            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
-            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
-
-            p, p_req, p_dom, p_gen, _ = model(input_ids=input_ids, token_type_ids=token_type_ids,
-                                              attention_mask=attention_mask)
-
-            for slot in belief_states:
-                p_ = p[slot]
-                labs = batch['state_labels-' + slot].to(device)
-                
-                belief_states[slot].append(p_)
-                state_labels[slot].append(labs)
-            
-            if p_req is not None:
-                for slot in request_probs:
-                    p_ = p_req[slot]
-                    labs = batch['request_labels-' + slot].to(device)
-
-                    request_probs[slot].append(p_)
-                    request_labels[slot].append(labs)
-                
-                for domain in active_domain_probs:
-                    p_ = p_dom[domain]
-                    labs = batch['active_domain_labels-' + domain].to(device)
-
-                    active_domain_probs[domain].append(p_)
-                    active_domain_labels[domain].append(labs)
-                
-                general_act_probs.append(p_gen)
-                general_act_labels.append(batch['general_act_labels'].to(device))
-    
-    for slot in belief_states:
-        belief_states[slot] = torch.cat(belief_states[slot], 0)
-        state_labels[slot] = torch.cat(state_labels[slot], 0)
-    if p_req is not None:
-        for slot in request_probs:
-            request_probs[slot] = torch.cat(request_probs[slot], 0)
-            request_labels[slot] = torch.cat(request_labels[slot], 0)
-        for domain in active_domain_probs:
-            active_domain_probs[domain] = torch.cat(active_domain_probs[domain], 0)
-            active_domain_labels[domain] = torch.cat(active_domain_labels[domain], 0)
-        general_act_probs = torch.cat(general_act_probs, 0)
-        general_act_labels = torch.cat(general_act_labels, 0)
-    else:
-        request_probs, request_labels, active_domain_probs, active_domain_labels = [None] * 4
-        general_act_probs, general_act_labels = [None] * 2
-
-    out = (belief_states, state_labels, request_probs, request_labels)
-    out += (active_domain_probs, active_domain_labels, general_act_probs, general_act_labels)
-    return out
diff --git a/convlab/dst/setsumbt/modeling/functional.py b/convlab/dst/setsumbt/modeling/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dd083d0da080ca089e81d9ae01e5f0954243f61
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/functional.py
@@ -0,0 +1,456 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SetSUMBT functionals"""
+
+import torch
+import transformers
+from torch.autograd import Variable
+from torch.nn import (MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout,
+                      CosineSimilarity, CrossEntropyLoss, PairwiseDistance,
+                      Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss)
+from torch.nn.init import (xavier_normal_, constant_)
+
+from convlab.dst.setsumbt.loss.bayesian import BayesianMatchingLoss, BinaryBayesianMatchingLoss, dirichlet
+from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss
+from convlab.dst.setsumbt.loss.distillation import DistillationKL, BinaryDistillationKL
+from convlab.dst.setsumbt.loss.endd_loss import rkl_dirichlet_mediator_loss, logits_to_mutual_info
+
+
+# Default belief tracker model initialisation function
+def _initialise(self, config):
+    """Attach to ``self`` the attention, tracker, pooling, distance, loss and head attributes chosen by ``config``."""
+    # Slot Utterance matching attention
+    self.slot_attention = MultiheadAttention(
+        config.hidden_size, config.slot_attention_heads)
+
+    # Latent context tracker: learned initial-state projection, only built when rnn_zero_init is false
+    if not config.rnn_zero_init and config.nbt_type in ['gru', 'lstm']:
+        self.belief_init = Sequential(Linear(config.hidden_size, config.nbt_hidden_size),
+                                      ReLU(), Dropout(config.dropout_rate))
+
+    # Recurrent context tracker setup
+    if config.nbt_type == 'gru':
+        self.nbt = GRU(input_size=config.hidden_size,
+                       hidden_size=config.nbt_hidden_size,
+                       num_layers=config.nbt_layers,
+                       dropout=0.0 if config.nbt_layers == 1 else config.dropout_rate,
+                       batch_first=True)
+        # Initialise parameters (only the first-layer ``*_l0`` tensors are touched here)
+        xavier_normal_(self.nbt.weight_ih_l0)
+        xavier_normal_(self.nbt.weight_hh_l0)
+        constant_(self.nbt.bias_ih_l0, 0.0)
+        constant_(self.nbt.bias_hh_l0, 0.0)
+    elif config.nbt_type == 'lstm':
+        self.nbt = LSTM(input_size=config.hidden_size,
+                        hidden_size=config.nbt_hidden_size,
+                        num_layers=config.nbt_layers,
+                        dropout=0.0 if config.nbt_layers == 1 else config.dropout_rate,
+                        batch_first=True)
+        # Initialise parameters (only the first-layer ``*_l0`` tensors are touched here)
+        xavier_normal_(self.nbt.weight_ih_l0)
+        xavier_normal_(self.nbt.weight_hh_l0)
+        constant_(self.nbt.bias_ih_l0, 0.0)
+        constant_(self.nbt.bias_hh_l0, 0.0)
+    else:
+        raise NameError('Not Implemented')
+
+    # Feature decoder and layer norm
+    self.intermediate = Linear(config.nbt_hidden_size, config.hidden_size)
+    self.layer_norm = LayerNorm(config.hidden_size)
+
+    # Dropout
+    self.dropout = Dropout(config.dropout_rate)
+
+    # Set pooler for set similarity model (read from self.config, which this function never assigns)
+    if self.config.set_similarity:
+        # 1D convolutional set pooler
+        if self.config.set_pooling == 'cnn':
+            self.conv_pooler = Conv1d(
+                self.config.hidden_size, self.config.hidden_size, 3)
+        # Deep averaging network set pooler
+        elif self.config.set_pooling == 'dan':
+            self.avg_net = Sequential(Linear(self.config.hidden_size, 2 * self.config.hidden_size), GELU(),
+                                      Linear(2 * self.config.hidden_size, self.config.hidden_size))
+
+    # Model ontology placeholders (empty here; presumably filled when slot candidates are added - TODO confirm)
+    self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False)
+    self.slot_ids = dict()
+    self.requestable_slot_ids = dict()
+    self.informable_slot_ids = dict()
+    self.domain_ids = dict()
+
+    # Matching network similarity measure
+    if config.distance_measure == 'cosine':
+        self.distance = CosineSimilarity(dim=-1, eps=1e-8)
+    elif config.distance_measure == 'euclidean':
+        self.distance = PairwiseDistance(p=2.0, eps=1e-06, keepdim=False)
+    else:
+        raise NameError('NotImplemented')
+
+    # Belief state loss function (the two distillation modes also set self.temp to 1.0)
+    if config.loss_function == 'crossentropy':
+        self.loss = CrossEntropyLoss(ignore_index=-1)
+    elif config.loss_function == 'bayesianmatching':
+        self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+    elif config.loss_function == 'labelsmoothing':
+        self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
+    elif config.loss_function == 'distillation':
+        self.loss = DistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
+        self.temp = 1.0
+    elif config.loss_function == 'distribution_distillation':
+        self.loss = rkl_dirichlet_mediator_loss
+        self.temp = 1.0
+    else:
+        raise NameError('NotImplemented')
+
+    # Intent and domain prediction heads
+    if config.predict_actions:
+        self.request_gate = Linear(config.hidden_size, 1)
+        self.goodbye_gate = Linear(config.hidden_size, 3)
+        self.domain_gate = Linear(config.hidden_size, 1)
+
+        # Intent and domain loss weights and functions (no head loss is set for 'distribution_distillation')
+        self.request_weight = float(self.config.user_request_loss_weight)
+        self.goodbye_weight = float(self.config.user_general_act_loss_weight)
+        self.domain_weight = float(self.config.active_domain_loss_weight)
+        if config.loss_function == 'crossentropy':
+            self.request_loss = BCEWithLogitsLoss()
+            self.goodbye_loss = CrossEntropyLoss(ignore_index=-1)
+            self.domain_loss = BCEWithLogitsLoss()
+        elif config.loss_function == 'labelsmoothing':
+            self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
+            self.goodbye_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
+            self.domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
+        elif config.loss_function == 'bayesianmatching':
+            self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+            self.goodbye_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+            self.domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+        elif config.loss_function == 'distillation':
+            self.request_loss = BinaryDistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
+            self.goodbye_loss = DistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
+            self.domain_loss = BinaryDistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
+
+
+# Default belief tracker forward pass.
+def _nbt_forward(self, turn_embeddings,
+                 turn_pooled_representation,
+                 attention_mask,
+                 batch_size,
+                 dialogue_size,
+                 turn_size,
+                 hidden_state,
+                 inform_labels,
+                 request_labels,
+                 domain_labels,
+                 goodbye_labels,
+                 calculate_inform_mutual_info):
+    hidden_size = turn_embeddings.size(-1)
+    # Initialise loss
+    loss = 0.0
+
+    # Goodbye predictions
+    goodbye_probs = None
+    if self.config.predict_actions:
+        # General action prediction
+        goodbye_scores = self.goodbye_gate(
+            turn_pooled_representation.reshape(batch_size * dialogue_size, hidden_size))
+
+        # Compute loss for general action predictions (weighted loss)
+        if goodbye_labels is not None:
+            if self.config.loss_function == 'distillation':
+                goodbye_labels = goodbye_labels.reshape(-1, goodbye_labels.size(-1))
+                loss += self.goodbye_loss(goodbye_scores, goodbye_labels, self.temp) * self.goodbye_weight
+            elif self.config.loss_function == 'distribution_distillation':
+                goodbye_labels = goodbye_labels.reshape(-1, goodbye_labels.size(-2), goodbye_labels.size(-1))
+                loss += self.loss(goodbye_scores, goodbye_labels, 1.0, 1.0)[0] * self.goodbye_weight
+            else:
+                goodbye_labels = goodbye_labels.reshape(-1)
+                loss += self.goodbye_loss(goodbye_scores, goodbye_labels) * self.request_weight
+
+        # Compute general action probabilities
+        if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'distillation', 'distribution_distillation']:
+            goodbye_probs = torch.softmax(goodbye_scores, -1).reshape(batch_size, dialogue_size, -1)
+        elif self.config.loss_function in ['bayesianmatching']:
+            goodbye_probs = dirichlet(goodbye_scores.reshape(batch_size, dialogue_size, -1))
+
+    # Slot utterance matching
+    num_slots = self.slot_embeddings.size(0)
+    slot_embeddings = self.slot_embeddings.reshape(-1, hidden_size)
+    slot_embeddings = slot_embeddings.unsqueeze(1).repeat((1, batch_size * dialogue_size, 1)).to(turn_embeddings.device)
+
+    if self.config.set_similarity:
+        # Slot mask shape [num_slots * slot_len, batch_size * dialogue_size, 768]
+        slot_mask = (slot_embeddings != 0.0).float()
+
+    # Turn embeddings shape [turn_size, batch_size * dialogue_size, 768]
+    turn_embeddings = turn_embeddings.transpose(0, 1)
+    # Compute key padding mask
+    key_padding_mask = (attention_mask[:, :, 0] == 0.0)
+    key_padding_mask[key_padding_mask[:, 0] == True, :] = False
+    # Multi head attention of slot over tokens
+    hidden, _ = self.slot_attention(query=slot_embeddings,
+                                    key=turn_embeddings,
+                                    value=turn_embeddings,
+                                    key_padding_mask=key_padding_mask)  # [num_slots, batch_size * dialogue_size, 768]
+
+    # Set embeddings for all masked tokens to 0
+    attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1))
+    hidden = hidden * attention_mask
+    if self.config.set_similarity:
+        hidden = hidden * slot_mask
+    # Hidden layer shape [num_dials, num_slots, num_turns, 768]
+    hidden = hidden.transpose(0, 1).reshape(batch_size, dialogue_size, slot_embeddings.size(0), -1).transpose(1, 2)
+
+    # Latent context tracking
+    # [batch_size * num_slots, dialogue_size, 768]
+    hidden = hidden.reshape(batch_size * slot_embeddings.size(0), dialogue_size, -1)
+
+    if self.config.nbt_type == 'gru':
+        self.nbt.flatten_parameters()
+        if hidden_state is None:
+            if self.config.rnn_zero_init:
+                context = torch.zeros(self.config.nbt_layers, batch_size * slot_embeddings.size(0),
+                                      self.config.nbt_hidden_size)
+                context = context.to(turn_embeddings.device)
+            else:
+                context = self.belief_init(hidden[:, 0, :]).unsqueeze(0).repeat((self.config.nbt_layers, 1, 1))
+        else:
+            context = hidden_state.to(hidden.device)
+
+        # [batch_size, dialogue_size, nbt_hidden_size]
+        belief_embedding, context = self.nbt(hidden, context)
+    elif self.config.nbt_type == 'lstm':
+        self.nbt.flatten_parameters()
+        if self.config.rnn_zero_init:
+            context = (torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size),
+                       torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size))
+            context = (context[0].to(turn_embeddings.device),
+                       context[1].to(turn_embeddings.device))
+        else:
+            context = (self.belief_init(hidden[:, 0, :]).unsqueeze(0).repeat((self.config.nbt_layers, 1, 1)),
+                       torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size))
+            context = (context[0], context[1].to(turn_embeddings.device))
+
+        # [batch_size, dialogue_size, nbt_hidden_size]
+        belief_embedding, context = self.nbt(hidden, context)
+
+    # Decode features
+    belief_embedding = belief_embedding.reshape(batch_size, slot_embeddings.size(0), dialogue_size, -1).transpose(1, 2)
+    if self.config.set_similarity:
+        belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1,
+                                                    self.config.nbt_hidden_size)
+    # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768]
+    # Normalisation and regularisation
+    belief_embedding = self.layer_norm(self.intermediate(belief_embedding))
+    belief_embedding = self.dropout(belief_embedding)
+
+    # Pooling of the set of latent context representation
+    if self.config.set_similarity:
+        slot_mask = slot_mask.transpose(0, 1).reshape(batch_size, dialogue_size, num_slots, -1, hidden_size)
+        belief_embedding = belief_embedding * slot_mask
+
+        # Apply pooler to latent context sequence
+        if self.config.set_pooling == 'mean':
+            belief_embedding = belief_embedding.sum(-2) / slot_mask.sum(-2)
+            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1)
+        elif self.config.set_pooling == 'cnn':
+            belief_embedding = belief_embedding.reshape(-1, slot_mask.size(-2), hidden_size).transpose(1, 2)
+            belief_embedding = self.conv_pooler(belief_embedding)
+            # Mean pooling after CNN
+            belief_embedding = belief_embedding.mean(-1).reshape(batch_size, dialogue_size, num_slots, -1)
+        elif self.config.set_pooling == 'dan':
+            # sqrt N reduction
+            belief_embedding = belief_embedding.sum(-2) / torch.sqrt(torch.tensor(slot_mask.sum(-2)))
+            # Deep averaging feature extractor
+            belief_embedding = self.avg_net(belief_embedding)
+            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1)
+
+    # Perform classification
+    if self.config.predict_actions:
+        # User request prediction
+        request_probs = dict()
+        for slot, slot_id in self.requestable_slot_ids.items():
+            request_scores = self.request_gate(belief_embedding[:, :, slot_id, :])
+
+            # Store output probabilities
+            request_scores = request_scores.reshape(batch_size, dialogue_size)
+            mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size)
+            batches, dialogues = torch.where(mask == 0.0)
+            # Set request scores to 0.0 for padded turns
+            request_scores[batches, dialogues] = 0.0
+            if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'bayesianmatching',
+                                             'distillation', 'distribution_distillation']:
+                request_probs[slot] = torch.sigmoid(request_scores)
+
+            if request_labels is not None:
+                # Compute request gate loss
+                request_scores = request_scores.reshape(-1)
+                if self.config.loss_function == 'distillation':
+                    loss += self.request_loss(request_scores, request_labels[slot].reshape(-1),
+                                              self.temp) * self.request_weight
+                elif self.config.loss_function == 'distribution_distillation':
+                    scores, labs = convert_probs_to_logits(request_scores, request_labels[slot])
+                    loss += self.loss(scores, labs, 1.0, 1.0)[0] * self.request_weight
+                else:
+                    labs = request_labels[slot].reshape(-1)
+                    request_scores = request_scores[labs != -1]
+                    labs = labs[labs != -1].float()
+                    loss += self.request_loss(request_scores, labs) * self.request_weight
+
+        # Active domain prediction
+        domain_probs = dict()
+        for domain, slot_ids in self.domain_ids.items():
+            belief = belief_embedding[:, :, slot_ids, :]
+            if len(slot_ids) > 1:
+                # SqrtN reduction across all slots within a domain
+                belief = belief.sum(2) / ((belief != 0.0).float().sum(2) ** 0.5)
+            domain_scores = self.domain_gate(belief)
+
+            # Store output probabilities
+            domain_scores = domain_scores.reshape(batch_size, dialogue_size)
+            mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size)
+            batches, dialogues = torch.where(mask == 0.0)
+            domain_scores[batches, dialogues] = 0.0
+            if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'bayesianmatching', 'distillation',
+                                             'distribution_distillation']:
+                domain_probs[domain] = torch.sigmoid(domain_scores)
+
+            if domain_labels is not None:
+                # Compute domain prediction loss
+                domain_scores = domain_scores.reshape(-1)
+                if self.config.loss_function == 'distillation':
+                    loss += self.domain_loss(domain_scores, domain_labels[domain].reshape(-1),
+                                             self.temp) * self.domain_weight
+                elif self.config.loss_function == 'distribution_distillation':
+                    scores, labs = convert_probs_to_logits(domain_scores, domain_labels[domain])
+                    loss += self.loss(scores, labs, 1.0, 1.0)[0] * self.request_weight
+                else:
+                    labs = domain_labels[domain].reshape(-1)
+                    domain_scores = domain_scores[labs != -1]
+                    labs = labs[labs != -1].float()
+                    loss += self.domain_loss(domain_scores, labs) * self.domain_weight
+    else:
+        request_probs, domain_probs = None, None
+
+    # Informable slot predictions
+    inform_probs = dict()
+    out_dict = dict()
+    mutual_info = dict()
+    stats = dict()
+    for slot, slot_id in self.informable_slot_ids.items():
+        # Get slot belief embedding and value candidates
+        candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device)
+        belief = belief_embedding[:, :, slot_id, :]
+        slot_size = candidate_embeddings.size(0)
+
+        # Use similarity matching to produce belief state
+        if self.config.distance_measure in ['cosine', 'euclidean']:
+            belief = belief.unsqueeze(2).repeat((1, 1, slot_size, 1))
+            belief = belief.reshape(-1, self.config.hidden_size)
+
+            # Pooling of set of value candidate description representation
+            if self.config.set_similarity and self.config.set_pooling == 'mean':
+                candidate_mask = (candidate_embeddings != 0.0).float()
+                candidate_embeddings = candidate_embeddings.sum(1) / candidate_mask.sum(1)
+            elif self.config.set_similarity and self.config.set_pooling == 'cnn':
+                candidate_embeddings = candidate_embeddings.transpose(1, 2)
+                candidate_embeddings = self.conv_pooler(candidate_embeddings).mean(-1)
+            elif self.config.set_similarity and self.config.set_pooling == 'dan':
+                candidate_mask = (candidate_embeddings != 0.0).float()
+                candidate_embeddings = candidate_embeddings.sum(1) / torch.sqrt(torch.tensor(candidate_mask.sum(1)))
+                candidate_embeddings = self.avg_net(candidate_embeddings)
+
+            candidate_embeddings = candidate_embeddings.unsqueeze(0).unsqueeze(0).repeat((batch_size,
+                                                                                          dialogue_size, 1, 1))
+            candidate_embeddings = candidate_embeddings.reshape(-1, self.config.hidden_size)
+
+        # Score value candidates
+        if self.config.distance_measure == 'cosine':
+            scores = self.distance(belief, candidate_embeddings)
+            # *27 here rescales the cosine similarity for better learning
+            scores = scores.reshape(batch_size * dialogue_size, -1) * 27.0
+        elif self.config.distance_measure == 'euclidean':
+            scores = -1.0 * self.distance(belief, candidate_embeddings)
+            scores = scores.reshape(batch_size * dialogue_size, -1)
+
+        # Calculate belief state
+        if self.config.loss_function in ['crossentropy', 'inhibitedce',
+                                         'labelsmoothing', 'distillation', 'distribution_distillation']:
+            probs_ = torch.softmax(scores.reshape(batch_size, dialogue_size, -1), -1)
+        elif self.config.loss_function in ['bayesianmatching']:
+            probs_ = dirichlet(scores.reshape(batch_size, dialogue_size, -1))
+
+        # Compute knowledge uncertainty in the belief states
+        if calculate_inform_mutual_info and self.config.loss_function == 'distribution_distillation':
+            mutual_info[slot] = logits_to_mutual_info(scores).reshape(batch_size, dialogue_size)
+
+        # Set padded turn probabilities to zero
+        mask = attention_mask[self.slot_ids[slot],:, 0].reshape(batch_size, dialogue_size)
+        batches, dialogues = torch.where(mask == 0.0)
+        probs_[batches, dialogues, :] = 0.0
+        inform_probs[slot] = probs_
+
+        # Calculate belief state loss
+        if inform_labels is not None and slot in inform_labels:
+            if self.config.loss_function == 'bayesianmatching':
+                prior = torch.ones(scores.size(-1)).float().to(scores.device)
+                prior = prior * self.config.prior_constant
+                prior = prior.unsqueeze(0).repeat((scores.size(0), 1))
+
+                loss += self.loss(scores, inform_labels[slot].reshape(-1), prior=prior)
+            elif self.config.loss_function == 'distillation':
+                labels = inform_labels[slot]
+                labels = labels.reshape(-1, labels.size(-1))
+                loss += self.loss(scores, labels, self.temp)
+            elif self.config.loss_function == 'distribution_distillation':
+                labels = inform_labels[slot]
+                labels = labels.reshape(-1, labels.size(-2), labels.size(-1))
+                loss_, model_stats, ensemble_stats = self.loss(scores, labels, 1.0, 1.0)
+                loss += loss_
+
+                # Calculate stats regarding model precisions
+                precision = model_stats['precision']
+                ensemble_precision = ensemble_stats['precision']
+                stats[slot] = {'model_precision_min': precision.min(),
+                               'model_precision_max': precision.max(),
+                               'model_precision_mean': precision.mean(),
+                               'ensemble_precision_min': ensemble_precision.min(),
+                               'ensemble_precision_max': ensemble_precision.max(),
+                               'ensemble_precision_mean': ensemble_precision.mean()}
+            else:
+                loss += self.loss(scores, inform_labels[slot].reshape(-1))
+
+    # Return model outputs
+    out = inform_probs, request_probs, domain_probs, goodbye_probs, context
+    if inform_labels is not None or request_labels is not None or domain_labels is not None or goodbye_labels is not None:
+        out = (loss,) + out + (stats,)
+    if calculate_inform_mutual_info:
+        out = out + (mutual_info,)
+    return out
+
+
+# Convert binary scores and labels to 2 class classification problem for distribution distillation
+def convert_probs_to_logits(scores, labels):
+    # Convert single target probability p to distribution [1-p, p]
+    labels = labels.reshape(-1, labels.size(-1), 1)
+    labels = torch.cat([1 - labels, labels], -1)
+
+    # Convert input scores into predictive distribution [1-z, z]
+    scores = torch.sigmoid(scores).unsqueeze(1)
+    scores = torch.cat((1 - scores, scores), 1)
+    scores = -1.0 * torch.log((1 / (scores + 1e-8)) - 1)  # Inverse sigmoid
+
+    return scores, labels
diff --git a/convlab/dst/setsumbt/modeling/roberta_nbt.py b/convlab/dst/setsumbt/modeling/roberta_nbt.py
index f72d17fafa50553434b6d4dcd20b8e53d143892f..36920c5ca550a3295f31aaca53c2ebed8c22be37 100644
--- a/convlab/dst/setsumbt/modeling/roberta_nbt.py
+++ b/convlab/dst/setsumbt/modeling/roberta_nbt.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,19 +16,16 @@
 """RoBERTa SetSUMBT"""
 
 import torch
+import transformers
+from torch.autograd import Variable
 from transformers import RobertaModel, RobertaPreTrainedModel
 
-from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead
+from convlab.dst.setsumbt.modeling.functional import _initialise, _nbt_forward
 
 
 class RobertaSetSUMBT(RobertaPreTrainedModel):
-    """Roberta based SetSUMBT model"""
 
     def __init__(self, config):
-        """
-        Args:
-            config (configuration): Model configuration class
-        """
         super(RobertaSetSUMBT, self).__init__(config)
         self.config = config
 
@@ -38,37 +35,60 @@ class RobertaSetSUMBT(RobertaPreTrainedModel):
             for p in self.roberta.parameters():
                 p.requires_grad = False
 
-        self.setsumbt = SetSUMBTHead(config)
-        self.add_slot_candidates = self.setsumbt.add_slot_candidates
-        self.add_value_candidates = self.setsumbt.add_value_candidates
+        _initialise(self, config)
     
-    def forward(self,
-                input_ids: torch.Tensor,
-                attention_mask: torch.Tensor,
-                token_type_ids: torch.Tensor = None,
-                hidden_state: torch.Tensor = None,
-                state_labels: torch.Tensor = None,
-                request_labels: torch.Tensor = None,
-                active_domain_labels: torch.Tensor = None,
-                general_act_labels: torch.Tensor = None,
-                get_turn_pooled_representation: bool = False,
-                calculate_state_mutual_info: bool = False):
-        """
-        Args:
-            input_ids: Input token ids
-            attention_mask: Input padding mask
-            token_type_ids: Token type indicator
-            hidden_state: Latent internal dialogue belief state
-            state_labels: Dialogue state labels
-            request_labels: User request action labels
-            active_domain_labels: Current active domain labels
-            general_act_labels: General user action labels
-            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
-            calculate_state_mutual_info: Return mutual information in the dialogue state
 
-        Returns:
-            out: Tuple containing loss, predictive distributions, model statistics and state mutual information
-        """
+    # Add new slot candidates to the model
+    def add_slot_candidates(self, slot_candidates):
+        """slot_candidates maps each slot name to a tuple.
+        - Each tuple contains the slot embedding, informable value embeddings and a request indicator.
+        - If the informable value embeddings are None the slot is not informable.
+        - If the request indicator is false the slot is not requestable."""
+        if self.slot_embeddings.size(0) != 0:
+            embeddings = self.slot_embeddings.detach()
+        else:
+            embeddings = torch.zeros(0)
+
+        for slot in slot_candidates:
+            if slot in self.slot_ids:
+                index = self.slot_ids[slot]
+                embeddings[index, :] = slot_candidates[slot][0]
+            else:
+                index = embeddings.size(0)
+                emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device)
+                embeddings = torch.cat((embeddings, emb), 0)
+                self.slot_ids[slot] = index
+                setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False))
+            # Add slot to relevant requestable and informable slot lists
+            if slot_candidates[slot][2]:
+                self.requestable_slot_ids[slot] = index
+            if slot_candidates[slot][1] is not None:
+                self.informable_slot_ids[slot] = index
+            
+            domain = slot.split('-', 1)[0]
+            if domain not in self.domain_ids:
+                self.domain_ids[domain] = []
+            self.domain_ids[domain].append(index)
+            self.domain_ids[domain] = list(set(self.domain_ids[domain]))
+        
+        self.slot_embeddings = Variable(embeddings, requires_grad=False)
+
+
+    # Add new value candidates to the model
+    def add_value_candidates(self, slot, value_candidates, replace=False):
+        embeddings = getattr(self, slot + '_value_embeddings')
+
+        if embeddings.size(0) == 0 or replace:
+            embeddings = value_candidates
+        else:
+            embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0)
+        
+        setattr(self, slot + '_value_embeddings', embeddings)
+        
+    
+    def forward(self, input_ids, attention_mask, token_type_ids=None, hidden_state=None, inform_labels=None,
+                request_labels=None, domain_labels=None, goodbye_labels=None,
+                get_turn_pooled_representation=False, calculate_inform_mutual_info=False):
         if token_type_ids is not None:
             token_type_ids = None
 
@@ -86,10 +106,9 @@ class RobertaSetSUMBT(RobertaPreTrainedModel):
         turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
         
         if get_turn_pooled_representation:
-            return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask,
-                                 batch_size, dialogue_size, hidden_state, state_labels,
-                                 request_labels, active_domain_labels, general_act_labels,
-                                 calculate_state_mutual_info) + (roberta_output.pooler_output,)
-        return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size,
-                             dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
-                             general_act_labels, calculate_state_mutual_info)
+            return _nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, dialogue_size,
+                                turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels,
+                                calculate_inform_mutual_info) + (roberta_output.pooler_output,)
+        return _nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, dialogue_size,
+                            turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels,
+                            calculate_inform_mutual_info)
diff --git a/convlab/dst/setsumbt/modeling/setsumbt.py b/convlab/dst/setsumbt/modeling/setsumbt.py
deleted file mode 100644
index 0249649f0840d66b0cec8a65c91aded906f62f85..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/modeling/setsumbt.py
+++ /dev/null
@@ -1,564 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""SetSUMBT Prediction Head"""
-
-import torch
-from torch.autograd import Variable
-from torch.nn import (Module, MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout,
-                      CosineSimilarity, CrossEntropyLoss, PairwiseDistance,
-                      Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss)
-from torch.nn.init import (xavier_normal_, constant_)
-
-from convlab.dst.setsumbt.loss import (BayesianMatchingLoss, BinaryBayesianMatchingLoss,
-                                       KLDistillationLoss, BinaryKLDistillationLoss,
-                                       LabelSmoothingLoss, BinaryLabelSmoothingLoss,
-                                       RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss)
-
-
-class SlotUtteranceMatching(Module):
-    """Slot Utterance matching attention based information extractor"""
-
-    def __init__(self, hidden_size: int = 768, attention_heads: int = 12):
-        """
-        Args:
-            hidden_size (int): Dimension of token embeddings
-            attention_heads (int): Number of attention heads to use in attention module
-        """
-        super(SlotUtteranceMatching, self).__init__()
-
-        self.attention = MultiheadAttention(hidden_size, attention_heads)
-
-    def forward(self,
-                turn_embeddings: torch.Tensor,
-                attention_mask: torch.Tensor,
-                slot_embeddings: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            turn_embeddings: Embeddings for each token in each turn [n_turns, turn_length, hidden_size]
-            attention_mask: Padding mask for each turn [n_turns, turn_length, hidden_size]
-            slot_embeddings: Embeddings for each token in the slot descriptions
-
-        Returns:
-            hidden: Information extracted from turn related to slot descriptions
-        """
-        turn_embeddings = turn_embeddings.transpose(0, 1)
-
-        key_padding_mask = (attention_mask[:, :, 0] == 0.0)
-        key_padding_mask[key_padding_mask[:, 0], :] = False
-
-        hidden, _ = self.attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings,
-                                   key_padding_mask=key_padding_mask)
-
-        attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1))
-        hidden = hidden * attention_mask
-
-        return hidden
-
-
-class RecurrentNeuralBeliefTracker(Module):
-    """Recurrent latent neural belief tracking module"""
-
-    def __init__(self,
-                 nbt_type: str = 'gru',
-                 rnn_zero_init: bool = False,
-                 input_size: int = 768,
-                 hidden_size: int = 300,
-                 hidden_layers: int = 1,
-                 dropout_rate: float = 0.3):
-        """
-        Args:
-            nbt_type: Type of recurrent neural network (gru/lstm)
-            rnn_zero_init: Use zero initialised state for the RNN
-            input_size: Embedding size of the inputs
-            hidden_size: Hidden size of the RNN
-            hidden_layers: Number of RNN Layers
-            dropout_rate: Dropout rate
-        """
-        super(RecurrentNeuralBeliefTracker, self).__init__()
-
-        if rnn_zero_init:
-            self.belief_init = Sequential(Linear(input_size, hidden_size), ReLU(), Dropout(dropout_rate))
-        else:
-            self.belief_init = None
-
-        self.nbt_type = nbt_type
-        self.hidden_layers = hidden_layers
-        self.hidden_size = hidden_size
-        if nbt_type == 'gru':
-            self.nbt = GRU(input_size=input_size,
-                           hidden_size=hidden_size,
-                           num_layers=hidden_layers,
-                           dropout=0.0 if hidden_layers == 1 else dropout_rate,
-                           batch_first=True)
-        elif nbt_type == 'lstm':
-            self.nbt = LSTM(input_size=input_size,
-                            hidden_size=hidden_size,
-                            num_layers=hidden_layers,
-                            dropout=0.0 if hidden_layers == 1 else dropout_rate,
-                            batch_first=True)
-        else:
-            raise NameError('Not Implemented')
-
-        # Initialise Parameters
-        xavier_normal_(self.nbt.weight_ih_l0)
-        xavier_normal_(self.nbt.weight_hh_l0)
-        constant_(self.nbt.bias_ih_l0, 0.0)
-        constant_(self.nbt.bias_hh_l0, 0.0)
-
-        # Intermediate feature mapping and layer normalisation
-        self.intermediate = Linear(hidden_size, input_size)
-        self.layer_norm = LayerNorm(input_size)
-        self.dropout = Dropout(dropout_rate)
-
-    def forward(self, inputs: torch.Tensor, hidden_state: torch.Tensor = None) -> torch.Tensor:
-        """
-        Args:
-            inputs: Latent turn level information
-            hidden_state: Latent internal belief state
-
-        Returns:
-            belief_embedding: Belief state embeddings
-            context: Latent internal belief state
-        """
-        self.nbt.flatten_parameters()
-        if hidden_state is None:
-            if self.belief_init is None:
-                context = torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device)
-            else:
-                context = self.belief_init(inputs[:, 0, :]).unsqueeze(0).repeat((self.hidden_layers, 1, 1))
-            if self.nbt_type == "lstm":
-                context = (context, torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device))
-        else:
-            context = hidden_state.to(inputs.device)
-
-        # [batch_size, dialogue_size, nbt_hidden_size]
-        belief_embedding, context = self.nbt(inputs, context)
-
-        # Normalisation and regularisation
-        belief_embedding = self.layer_norm(self.intermediate(belief_embedding))
-        belief_embedding = self.dropout(belief_embedding)
-
-        return belief_embedding, context
-
-
-class SetPooler(Module):
-    """Token set pooler"""
-
-    def __init__(self, pooling_strategy: str = 'cnn', hidden_size: int = 768):
-        """
-        Args:
-            pooling_strategy: Type of set pooler (cnn/dan/mean)
-            hidden_size: Token embedding size
-        """
-        super(SetPooler, self).__init__()
-
-        self.pooling_strategy = pooling_strategy
-        if pooling_strategy == 'cnn':
-            self.cnn_filter_size = 3
-            self.pooler = Conv1d(hidden_size, hidden_size, self.cnn_filter_size)
-        elif pooling_strategy == 'dan':
-            self.pooler = Sequential(Linear(hidden_size, hidden_size), GELU(), Linear(2 * hidden_size, hidden_size))
-
-    def forward(self, inputs, attention_mask):
-        """
-        Args:
-            inputs: Token set embeddings
-            attention_mask: Padding mask for the set of tokens
-
-        Returns:
-
-        """
-        if self.pooling_strategy == "mean":
-            hidden = inputs.sum(1) / attention_mask.sum(1)
-        elif self.pooling_strategy == "cnn":
-            hidden = self.pooler(inputs.transpose(1, 2)).mean(-1)
-        elif self.pooling_strategy == 'dan':
-            hidden = inputs.sum(1) / torch.sqrt(torch.tensor(attention_mask.sum(1)))
-            hidden = self.pooler(hidden)
-
-        return hidden
-
-
-class SetSUMBTHead(Module):
-    """SetSUMBT Prediction Head for Language Models"""
-
-    def __init__(self, config):
-        """
-        Args:
-            config (configuration): Model configuration class
-        """
-        super(SetSUMBTHead, self).__init__()
-        self.config = config
-        # Slot Utterance matching attention
-        self.slot_utterance_matching = SlotUtteranceMatching(config.hidden_size, config.slot_attention_heads)
-
-        # Latent context tracker
-        self.nbt = RecurrentNeuralBeliefTracker(config.nbt_type, config.rnn_zero_init, config.hidden_size,
-                                                config.nbt_hidden_size, config.nbt_layers, config.dropout_rate)
-
-        # Set pooler for set similarity model
-        if self.config.set_similarity:
-            self.set_pooler = SetPooler(config.set_pooling, config.hidden_size)
-
-        # Model ontology placeholders
-        self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False)
-        self.slot_ids = dict()
-        self.requestable_slot_ids = dict()
-        self.informable_slot_ids = dict()
-        self.domain_ids = dict()
-
-        # Matching network similarity measure
-        if config.distance_measure == 'cosine':
-            self.distance = CosineSimilarity(dim=-1, eps=1e-8)
-        elif config.distance_measure == 'euclidean':
-            self.distance = PairwiseDistance(p=2.0, eps=1e-6, keepdim=False)
-        else:
-            raise NameError('NotImplemented')
-
-        # User goal prediction loss function
-        if config.loss_function == 'crossentropy':
-            self.loss = CrossEntropyLoss(ignore_index=-1)
-        elif config.loss_function == 'bayesianmatching':
-            self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-        elif config.loss_function == 'labelsmoothing':
-            self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
-        elif config.loss_function == 'distillation':
-            self.loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
-            self.temp = 1.0
-        elif config.loss_function == 'distribution_distillation':
-            self.loss = RKLDirichletMediatorLoss(ignore_index=-1)
-        else:
-            raise NameError('NotImplemented')
-
-        # Intent and domain prediction heads
-        if config.predict_actions:
-            self.request_gate = Linear(config.hidden_size, 1)
-            self.general_act_gate = Linear(config.hidden_size, 3)
-            self.active_domain_gate = Linear(config.hidden_size, 1)
-
-            # Intent and domain loss function
-            self.request_weight = float(self.config.user_request_loss_weight)
-            self.general_act_weight = float(self.config.user_general_act_loss_weight)
-            self.active_domain_weight = float(self.config.active_domain_loss_weight)
-            if config.loss_function == 'crossentropy':
-                self.request_loss = BCEWithLogitsLoss()
-                self.general_act_loss = CrossEntropyLoss(ignore_index=-1)
-                self.active_domain_loss = BCEWithLogitsLoss()
-            elif config.loss_function == 'labelsmoothing':
-                self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
-                self.general_act_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
-                self.active_domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
-            elif config.loss_function == 'bayesianmatching':
-                self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-                self.general_act_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-                self.active_domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-            elif config.loss_function == 'distillation':
-                self.request_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
-                self.general_act_loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
-                self.active_domain_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
-            elif config.loss_function == 'distribution_distillation':
-                self.request_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1)
-                self.general_act_loss = RKLDirichletMediatorLoss(ignore_index=-1)
-                self.active_domain_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1)
-
-    def add_slot_candidates(self, slot_candidates: tuple):
-        """
-        Add slots to the model ontology, the tuples should contain the slot embedding, informable value embeddings
-        and a request indicator, if the informable value embeddings is None the slot is not informable and if
-        the request indicator is false the slot is not requestable.
-
-        Args:
-            slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator
-        """
-        if self.slot_embeddings.size(0) != 0:
-            embeddings = self.slot_embeddings.detach()
-        else:
-            embeddings = torch.zeros(0)
-
-        for slot in slot_candidates:
-            if slot in self.slot_ids:
-                index = self.slot_ids[slot]
-                embeddings[index, :] = slot_candidates[slot][0]
-            else:
-                index = embeddings.size(0)
-                emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device)
-                embeddings = torch.cat((embeddings, emb), 0)
-                self.slot_ids[slot] = index
-                setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False))
-            # Add slot to relevant requestable and informable slot lists
-            if slot_candidates[slot][2]:
-                self.requestable_slot_ids[slot] = index
-            if slot_candidates[slot][1] is not None:
-                self.informable_slot_ids[slot] = index
-
-            domain = slot.split('-', 1)[0]
-            if domain not in self.domain_ids:
-                self.domain_ids[domain] = []
-            self.domain_ids[domain].append(index)
-            self.domain_ids[domain] = list(set(self.domain_ids[domain]))
-
-        self.slot_embeddings = Variable(embeddings, requires_grad=False)
-
-    def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
-        """
-        Add value candidates for a slot
-
-        Args:
-            slot: Slot name
-            value_candidates: Value candidate embeddings
-            replace: If true existing value candidates are replaced
-        """
-        embeddings = getattr(self, slot + '_value_embeddings')
-
-        if embeddings.size(0) == 0 or replace:
-            embeddings = value_candidates
-        else:
-            embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0)
-
-        setattr(self, slot + '_value_embeddings', embeddings)
-
-    def forward(self,
-                turn_embeddings: torch.Tensor,
-                turn_pooled_representation: torch.Tensor,
-                attention_mask: torch.Tensor,
-                batch_size: int,
-                dialogue_size: int,
-                hidden_state: torch.Tensor = None,
-                state_labels: torch.Tensor = None,
-                request_labels: torch.Tensor = None,
-                active_domain_labels: torch.Tensor = None,
-                general_act_labels: torch.Tensor = None,
-                calculate_state_mutual_info: bool = False):
-        """
-        Args:
-            turn_embeddings: Token embeddings in the current turn
-            turn_pooled_representation: Pooled representation of the current dialogue turn
-            attention_mask: Padding mask for the current dialogue turn
-            batch_size: Number of dialogues in the batch
-            dialogue_size: Number of turns in each dialogue
-            hidden_state: Latent internal dialogue belief state
-            state_labels: Dialogue state labels
-            request_labels: User request action labels
-            active_domain_labels: Current active domain labels
-            general_act_labels: General user action labels
-            calculate_state_mutual_info: Return mutual information in the dialogue state
-
-        Returns:
-            out: Tuple containing loss, predictive distributions, model statistics and state mutual information
-        """
-        hidden_size = turn_embeddings.size(-1)
-        # Initialise loss
-        loss = 0.0
-
-        # General Action predictions
-        general_act_probs = None
-        if self.config.predict_actions:
-            # General action prediction
-            general_act_logits = self.general_act_gate(turn_pooled_representation.reshape(batch_size * dialogue_size,
-                                                                                          hidden_size))
-
-            # Compute loss for general action predictions (weighted loss)
-            if general_act_labels is not None:
-                if self.config.loss_function == 'distillation':
-                    general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-1))
-                    loss += self.general_act_loss(general_act_logits, general_act_labels,
-                                                  self.temp) * self.general_act_weight
-                elif self.config.loss_function == 'distribution_distillation':
-                    general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-2),
-                                                                    general_act_labels.size(-1))
-                    loss += self.general_act_loss(general_act_logits, general_act_labels)[0] * self.general_act_weight
-                else:
-                    general_act_labels = general_act_labels.reshape(-1)
-                    loss += self.general_act_loss(general_act_logits, general_act_labels) * self.general_act_weight
-
-            # Compute general action probabilities
-            general_act_probs = torch.softmax(general_act_logits, -1).reshape(batch_size, dialogue_size, -1)
-
-        # Slot utterance matching
-        num_slots = self.slot_embeddings.size(0)
-        slot_embeddings = self.slot_embeddings.reshape(-1, hidden_size)
-        slot_embeddings = slot_embeddings.unsqueeze(1).repeat((1, batch_size * dialogue_size, 1))
-        slot_embeddings = slot_embeddings.to(turn_embeddings.device)
-
-        if self.config.set_similarity:
-            # Slot mask shape [num_slots * slot_len, batch_size * dialogue_size, 768]
-            slot_mask = (slot_embeddings != 0.0).float()
-
-        hidden = self.slot_utterance_matching(turn_embeddings, attention_mask, slot_embeddings)
-
-        if self.config.set_similarity:
-            hidden = hidden * slot_mask
-        # Hidden layer shape [num_dials, num_slots, num_turns, 768]
-        hidden = hidden.transpose(0, 1).reshape(batch_size, dialogue_size, slot_embeddings.size(0), -1).transpose(1, 2)
-
-        # Latent context tracking
-        # [batch_size * num_slots, dialogue_size, 768]
-        hidden = hidden.reshape(batch_size * slot_embeddings.size(0), dialogue_size, -1)
-        belief_embedding, hidden_state = self.nbt(hidden, hidden_state)
-
-        belief_embedding = belief_embedding.reshape(batch_size, slot_embeddings.size(0),
-                                                    dialogue_size, -1).transpose(1, 2)
-        if self.config.set_similarity:
-            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1,
-                                                        self.config.hidden_size)
-        # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768]
-
-        # Pooling of the set of latent context representation
-        if self.config.set_similarity:
-            slot_mask = slot_mask.transpose(0, 1).reshape(batch_size, dialogue_size, num_slots, -1, hidden_size)
-            belief_embedding = belief_embedding * slot_mask
-
-            belief_embedding = self.set_pooler(belief_embedding.reshape(-1, slot_mask.size(-2), hidden_size),
-                                               slot_mask.reshape(-1, slot_mask.size(-2), hidden_size))
-            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1)
-
-        # Perform classification
-        # Get padded batch, dialogue idx pairs
-        batches, dialogues = torch.where(attention_mask[:, 0, 0].reshape(batch_size, dialogue_size) == 0.0)
-        
-        if self.config.predict_actions:
-            # User request prediction
-            request_probs = dict()
-            for slot, slot_id in self.requestable_slot_ids.items():
-                request_logits = self.request_gate(belief_embedding[:, :, slot_id, :])
-
-                # Store output probabilities
-                request_logits = request_logits.reshape(batch_size, dialogue_size)
-                # Set request scores to 0.0 for padded turns
-                request_logits[batches, dialogues] = 0.0
-                request_probs[slot] = torch.sigmoid(request_logits)
-
-                if request_labels is not None:
-                    # Compute request gate loss
-                    request_logits = request_logits.reshape(-1)
-                    if self.config.loss_function == 'distillation':
-                        loss += self.request_loss(request_logits, request_labels[slot].reshape(-1),
-                                                  self.temp) * self.request_weight
-                    elif self.config.loss_function == 'distribution_distillation':
-                        loss += self.request_loss(request_logits, request_labels[slot])[0] * self.request_weight
-                    else:
-                        labs = request_labels[slot].reshape(-1)
-                        request_logits = request_logits[labs != -1]
-                        labs = labs[labs != -1].float()
-                        loss += self.request_loss(request_logits, labs) * self.request_weight
-
-            # Active domain prediction
-            active_domain_probs = dict()
-            for domain, slot_ids in self.domain_ids.items():
-                belief = belief_embedding[:, :, slot_ids, :]
-                if len(slot_ids) > 1:
-                    # SqrtN reduction across all slots within a domain
-                    belief = belief.sum(2) / ((belief != 0.0).float().sum(2) ** 0.5)
-                active_domain_logits = self.active_domain_gate(belief)
-
-                # Store output probabilities
-                active_domain_logits = active_domain_logits.reshape(batch_size, dialogue_size)
-                active_domain_logits[batches, dialogues] = 0.0
-                active_domain_probs[domain] = torch.sigmoid(active_domain_logits)
-
-                if active_domain_labels is not None and domain in active_domain_labels:
-                    # Compute domain prediction loss
-                    active_domain_logits = active_domain_logits.reshape(-1)
-                    if self.config.loss_function == 'distillation':
-                        loss += self.active_domain_loss(active_domain_logits, active_domain_labels[domain].reshape(-1),
-                                                        self.temp) * self.active_domain_weight
-                    elif self.config.loss_function == 'distribution_distillation':
-                        loss += self.active_domain_loss(active_domain_logits,
-                                                        active_domain_labels[domain])[0] * self.active_domain_weight
-                    else:
-                        labs = active_domain_labels[domain].reshape(-1)
-                        active_domain_logits = active_domain_logits[labs != -1]
-                        labs = labs[labs != -1].float()
-                        loss += self.active_domain_loss(active_domain_logits, labs) * self.active_domain_weight
-        else:
-            request_probs, active_domain_probs = None, None
-
-        # Dialogue state predictions
-        belief_state_probs = dict()
-        belief_state_mutual_info = dict()
-        belief_state_stats = dict()
-        for slot, slot_id in self.informable_slot_ids.items():
-            # Get slot belief embedding and value candidates
-            candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device)
-            belief = belief_embedding[:, :, slot_id, :]
-            slot_size = candidate_embeddings.size(0)
-
-            belief = belief.unsqueeze(2).repeat((1, 1, slot_size, 1))
-            belief = belief.reshape(-1, self.config.hidden_size)
-
-            if self.config.set_similarity:
-                candidate_embeddings = self.set_pooler(candidate_embeddings, (candidate_embeddings != 0.0).float())
-            candidate_embeddings = candidate_embeddings.unsqueeze(0).unsqueeze(0).repeat((batch_size,
-                                                                                          dialogue_size, 1, 1))
-            candidate_embeddings = candidate_embeddings.reshape(-1, self.config.hidden_size)
-
-            # Score value candidates
-            if self.config.distance_measure == 'cosine':
-                logits = self.distance(belief, candidate_embeddings)
-                # *27 here rescales the cosine similarity for better learning
-                logits = logits.reshape(batch_size * dialogue_size, -1) * 27.0
-            elif self.config.distance_measure == 'euclidean':
-                logits = -1.0 * self.distance(belief, candidate_embeddings)
-                logits = logits.reshape(batch_size * dialogue_size, -1)
-
-            # Calculate belief state
-            probs_ = torch.softmax(logits.reshape(batch_size, dialogue_size, -1), -1)
-
-            # Compute knowledge uncertainty in the beleif states
-            if calculate_state_mutual_info and self.config.loss_function == 'distribution_distillation':
-                belief_state_mutual_info[slot] = self.loss.logits_to_mutual_info(logits).reshape(batch_size, dialogue_size)
-
-            # Set padded turn probabilities to zero
-            probs_[batches, dialogues, :] = 0.0
-            belief_state_probs[slot] = probs_
-
-            # Calculate belief state loss
-            if state_labels is not None and slot in state_labels:
-                if self.config.loss_function == 'bayesianmatching':
-                    prior = torch.ones(logits.size(-1)).float().to(logits.device)
-                    prior = prior * self.config.prior_constant
-                    prior = prior.unsqueeze(0).repeat((logits.size(0), 1))
-
-                    loss += self.loss(logits, state_labels[slot].reshape(-1), prior=prior)
-                elif self.config.loss_function == 'distillation':
-                    labels = state_labels[slot]
-                    labels = labels.reshape(-1, labels.size(-1))
-                    loss += self.loss(logits, labels, self.temp)
-                elif self.config.loss_function == 'distribution_distillation':
-                    labels = state_labels[slot]
-                    labels = labels.reshape(-1, labels.size(-2), labels.size(-1))
-                    loss_, model_stats, ensemble_stats = self.loss(logits, labels)
-                    loss += loss_
-
-                    # Calculate stats regarding model precisions
-                    precision = model_stats['precision']
-                    ensemble_precision = ensemble_stats['precision']
-                    belief_state_stats[slot] = {'model_precision_min': precision.min(),
-                                                'model_precision_max': precision.max(),
-                                                'model_precision_mean': precision.mean(),
-                                                'ensemble_precision_min': ensemble_precision.min(),
-                                                'ensemble_precision_max': ensemble_precision.max(),
-                                                'ensemble_precision_mean': ensemble_precision.mean()}
-                else:
-                    loss += self.loss(logits, state_labels[slot].reshape(-1))
-
-        # Return model outputs
-        out = belief_state_probs, request_probs, active_domain_probs, general_act_probs, hidden_state
-        if state_labels is not None or request_labels is not None:
-            out = (loss,) + out + (belief_state_stats,)
-        if calculate_state_mutual_info:
-            out = out + (belief_state_mutual_info,)
-        return out
diff --git a/convlab/dst/setsumbt/modeling/temperature_scheduler.py b/convlab/dst/setsumbt/modeling/temperature_scheduler.py
index 654e83c5d1ad9dc908213cca8967a84893395b04..fab205befe3350c9beb9d81566c813cf00b55cf2 100644
--- a/convlab/dst/setsumbt/modeling/temperature_scheduler.py
+++ b/convlab/dst/setsumbt/modeling/temperature_scheduler.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,70 +13,50 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Linear Temperature Scheduler Class"""
-
+"""Temperature Scheduler Class"""
+import torch
 
 # Temp scheduler class for ensemble distillation
-class LinearTemperatureScheduler:
-    """
-    Temperature scheduler object used for distribution temperature scheduling in distillation
+class TemperatureScheduler:
 
-    Attributes:
-        state (dict): Internal state of scheduler
-    """
-    def __init__(self,
-                 total_steps: int,
-                 base_temp: float = 2.5,
-                 cycle_len: float = 0.1):
-        """
-        Args:
-            total_steps (int): Total number of training steps
-            base_temp (float): Starting temperature
-            cycle_len (float): Fraction of total steps used for scheduling cycle
-        """
-        self.state = dict()
+    def __init__(self, total_steps, base_temp=2.5, cycle_len=0.1):
+        self.state = {}
         self.state['total_steps'] = total_steps
         self.state['current_step'] = 0
         self.state['base_temp'] = base_temp
         self.state['current_temp'] = base_temp
         self.state['cycles'] = [int(total_steps * cycle_len / 2), int(total_steps * cycle_len)]
-        self.state['rate'] = (self.state['base_temp'] - 1.0) / (self.state['cycles'][1] - self.state['cycles'][0])
     
     def step(self):
-        """
-        Update temperature based on the schedule
-        """
         self.state['current_step'] += 1
         assert self.state['current_step'] <= self.state['total_steps']
         if self.state['current_step'] > self.state['cycles'][0]:
             if self.state['current_step'] < self.state['cycles'][1]:
-                self.state['current_temp'] -= self.state['rate']
+                rate = (self.state['base_temp'] - 1.0) / (self.state['cycles'][1] - self.state['cycles'][0])
+                self.state['current_temp'] -= rate
             else:
                 self.state['current_temp'] = 1.0
     
     def temp(self):
-        """
-        Get current temperature
-
-        Returns:
-            temp (float): Current temperature for distribution scaling
-        """
         return float(self.state['current_temp'])
     
     def state_dict(self):
-        """
-        Return scheduler state
-
-        Returns:
-            state (dict): Dictionary format state of the scheduler
-        """
         return self.state
     
-    def load_state_dict(self, state_dict: dict):
-        """
-        Load scheduler state from dictionary
+    def load_state_dict(self, sd):
+        self.state = sd
+
+
+# if __name__ == "__main__":
+#     temp_scheduler = TemperatureScheduler(100)
+#     print(temp_scheduler.state_dict())
+
+#     temp = []
+#     for i in range(100):
+#         temp.append(temp_scheduler.temp())
+#         temp_scheduler.step()
+    
+#     temp_scheduler.load_state_dict(temp_scheduler.state_dict())
+#     print(temp_scheduler.state_dict())
 
-        Args:
-            state_dict (dict): Dictionary format state of the scheduler
-        """
-        self.state = state_dict
+#     print(temp)
diff --git a/convlab/dst/setsumbt/modeling/training.py b/convlab/dst/setsumbt/modeling/training.py
index a898a42a9f7321942d940e246d24b25f7b15eedc..259c6e1da061ad6800f92e9237324181678459cb 100644
--- a/convlab/dst/setsumbt/modeling/training.py
+++ b/convlab/dst/setsumbt/modeling/training.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Training and evaluation utils"""
+"""Training utils"""
 
 import random
 import os
 import logging
-from copy import deepcopy
 
 import torch
 from torch.nn import DataParallel
 from torch.distributions import Categorical
 import numpy as np
-from transformers import get_linear_schedule_with_warmup
-from torch.optim import AdamW
+from transformers import AdamW, get_linear_schedule_with_warmup
 from tqdm import tqdm, trange
 try:
     from apex import amp
@@ -33,7 +31,7 @@ except:
     print('Apex not used')
 
 from convlab.dst.setsumbt.utils import clear_checkpoints
-from convlab.dst.setsumbt.modeling import LinearTemperatureScheduler
+from convlab.dst.setsumbt.modeling.temperature_scheduler import TemperatureScheduler
 
 
 # Load logger and tensorboard summary writer
@@ -61,124 +59,18 @@ def set_ontology_embeddings(model, slots, load_slots=True):
     if load_slots:
         slots = {slot: embs for slot, embs in slots.items()}
         model.add_slot_candidates(slots)
-    for slot in model.setsumbt.informable_slot_ids:
+    for slot in model.informable_slot_ids:
         model.add_value_candidates(slot, values[slot], replace=True)
 
 
-def log_info(global_step, loss, jg_acc=None, sl_acc=None, req_f1=None, dom_f1=None, gen_f1=None, stats=None):
-    """
-    Log training statistics.
-
-    Args:
-        global_step: Number of global training steps completed
-        loss: Training loss
-        jg_acc: Joint goal accuracy
-        sl_acc: Slot accuracy
-        req_f1: Request prediction F1 score
-        dom_f1: Active domain prediction F1 score
-        gen_f1: General action prediction F1 score
-        stats: Uncertainty measure statistics of model
-    """
-    info = f"{global_step} steps complete, " if type(global_step) == int else ""
-    if global_step == 'training_complete':
-        info += f"Training Complete"
-        info += f"Loss since last update: {loss}. Validation set stats: "
-    if global_step == 'dev':
-        info += f"Validation set stats: Loss: {loss}, "
-    if global_step == 'test':
-        info += f"Test set stats: Loss: {loss}, "
-    info += f"Joint Goal Acc: {jg_acc}, Slot Acc: {sl_acc}, "
-    if req_f1 is not None:
-        info += f"Request F1 Score: {req_f1}, Active Domain F1 Score: {dom_f1}, "
-        info += f"General Action F1 Score: {gen_f1}"
-    logger.info(info)
-
-    if type(global_step) == int:
-        tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
-        tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)
-        if req_f1 is not None:
-            tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step)
-            tb_writer.add_scalar('ActiveDomainF1Score/Dev', dom_f1, global_step)
-            tb_writer.add_scalar('GeneralActionF1Score/Dev', gen_f1, global_step)
-        tb_writer.add_scalar('Loss/Dev', loss, global_step)
-
-        if stats:
-            for slot, stats_slot in stats.items():
-                for key, item in stats_slot.items():
-                    tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step)
-
-
-def get_input_dict(batch: dict,
-                   predict_actions: bool,
-                   model_informable_slot_ids: list,
-                   model_requestable_slot_ids: list = None,
-                   model_domain_ids: list = None,
-                   device = 'cpu') -> dict:
-    """
-    Produce model input arguments
-
-    Args:
-        batch: Batch of data from the dataloader
-        predict_actions: Model should predict user actions if set true
-        model_informable_slot_ids: List of model dialogue state slots
-        model_requestable_slot_ids: List of model requestable slots
-        model_domain_ids: List of model domains
-        device: Current torch device in use
-
-    Returns:
-        input_dict: Dictrionary containing model inputs for the batch
-    """
-    input_dict = dict()
-
-    input_dict['input_ids'] = batch['input_ids'].to(device)
-    input_dict['token_type_ids'] = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
-    input_dict['attention_mask'] = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
-
-    if 'belief_state' in batch:
-        input_dict['state_labels'] = {slot: batch['belief_state-' + slot].to(device) for slot in model_informable_slot_ids
-                                      if ('belief_state-' + slot) in batch}
-        if predict_actions:
-            input_dict['request_labels'] = {slot: batch['request_probs-' + slot].to(device)
-                                            for slot in model_requestable_slot_ids
-                                            if ('request_probs-' + slot) in batch}
-            input_dict['active_domain_labels'] = {domain: batch['active_domain_probs-' + domain].to(device)
-                                                  for domain in model_domain_ids
-                                                  if ('active_domain_probs-' + domain) in batch}
-            input_dict['general_act_labels'] = batch['general_act_probs'].to(device)
-    else:
-        input_dict['state_labels'] = {slot: batch['state_labels-' + slot].to(device)
-                                      for slot in model_informable_slot_ids if ('state_labels-' + slot) in batch}
-        if predict_actions:
-            input_dict['request_labels'] = {slot: batch['request_labels-' + slot].to(device)
-                                            for slot in model_requestable_slot_ids
-                                            if ('request_labels-' + slot) in batch}
-            input_dict['active_domain_labels'] = {domain: batch['active_domain_labels-' + domain].to(device)
-                                                  for domain in model_domain_ids
-                                                  if ('active_domain_labels-' + domain) in batch}
-            input_dict['general_act_labels'] = batch['general_act_labels'].to(device)
-
-    return input_dict
-
-
-def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, slots_dev: dict):
-    """
-    Train the SetSUMBT model.
-
-    Args:
-        args: Runtime arguments
-        model: SetSUMBT Model instance to train
-        device: Torch device to use during training
-        train_dataloader: Dataloader containing the training data
-        dev_dataloader: Dataloader containing the validation set data
-        slots: Model ontology used for training
-        slots_dev: Model ontology used for evaluating on the validation set
-    """
+def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_dev, embeddings=None, tokenizer=None):
+    """Train model!"""
 
     # Calculate the total number of training steps to be performed
     if args.max_training_steps > 0:
         t_total = args.max_training_steps
-        args.num_train_epochs = (len(train_dataloader) // args.gradient_accumulation_steps) + 1
-        args.num_train_epochs = args.max_training_steps // args.num_train_epochs
+        args.num_train_epochs = args.max_training_steps // (
+            (len(train_dataloader) // args.gradient_accumulation_steps) + 1)
     else:
         t_total = (len(train_dataloader) // args.gradient_accumulation_steps) * args.num_train_epochs
 
@@ -196,12 +88,12 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
         {
             "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0,
-            "lr": args.learning_rate
+            "lr":args.learning_rate
         },
     ]
 
     # Initialise the optimizer
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
 
     # Initialise linear lr scheduler
     num_warmup_steps = int(t_total * args.warmup_proportion)
@@ -217,7 +109,8 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
 
     # Set up fp16 and multi gpu usage
     if args.fp16:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
+        model, optimizer = amp.initialize(
+            model, optimizer, opt_level=args.fp16_opt_level)
     if args.n_gpu > 1:
         model = DataParallel(model)
 
@@ -225,7 +118,7 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
     best_model = {'joint goal accuracy': 0.0,
                   'request f1 score': 0.0,
                   'active domain f1 score': 0.0,
-                  'general act f1 score': 0.0,
+                  'goodbye act f1 score': 0.0,
                   'train loss': np.inf}
     if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')):
         logger.info("Optimizer loaded from previous run.")
@@ -243,27 +136,27 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
             model.eval()
             set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
 
-            jg_acc, sl_acc, req_f1, dom_f1, gen_f1, _, _ = evaluate(args, model, device, dev_dataloader, is_train=True)
+            jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader)
 
             # Set model back to training mode
             model.train()
             model.zero_grad()
             set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False)
         else:
-            jg_acc, req_f1, dom_f1, gen_f1 = 0.0, 0.0, 0.0, 0.0
+            jg_acc, req_f1, dom_f1, bye_f1 = 0.0, 0.0, 0.0, 0.0
 
         best_model['joint goal accuracy'] = jg_acc
         best_model['request f1 score'] = req_f1
         best_model['active domain f1 score'] = dom_f1
-        best_model['general act f1 score'] = gen_f1
+        best_model['goodbye act f1 score'] = bye_f1
 
     # Log training set up
-    logger.info(f"Device: {device}, Number of GPUs: {args.n_gpu}, FP16 training: {args.fp16}")
+    logger.info("Device: %s, Number of GPUs: %s, FP16 training: %s" % (device, args.n_gpu, args.fp16))
     logger.info("***** Running training *****")
-    logger.info(f"  Num Batches = {len(train_dataloader)}")
-    logger.info(f"  Num Epochs = {args.num_train_epochs}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {t_total}")
+    logger.info("  Num Batches = %d" % len(train_dataloader))
+    logger.info("  Num Epochs = %d" % args.num_train_epochs)
+    logger.info("  Gradient Accumulation steps = %d" % args.gradient_accumulation_steps)
+    logger.info("  Total optimization steps = %d" % t_total)
 
     # Initialise training parameters
     global_step = 0
@@ -280,11 +173,11 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
             steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
 
             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info(f"  Continuing training from epoch {epochs_trained}")
-            logger.info(f"  Continuing training from global step {global_step}")
-            logger.info(f"  Will skip the first {steps_trained_in_current_epoch} steps in the first epoch")
+            logger.info("  Continuing training from epoch %d" % epochs_trained)
+            logger.info("  Continuing training from global step %d" % global_step)
+            logger.info("  Will skip the first %d steps in the first epoch" % steps_trained_in_current_epoch)
         except ValueError:
-            logger.info(f"  Starting fine-tuning.")
+            logger.info("  Starting fine-tuning.")
 
     # Prepare model for training
     tr_loss, logging_loss = 0.0, 0.0
@@ -303,15 +196,43 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
                 continue
 
             # Extract all label dictionaries from the batch
-            input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids,
-                                        model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device)
+            if 'goodbye_belief' in batch:
+                labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids
+                          if ('belief-' + slot) in batch}
+                request_labels = {slot: batch['request_belief-' + slot].to(device)
+                                  for slot in model.requestable_slot_ids
+                                  if ('request_belief-' + slot) in batch} if args.predict_actions else None
+                domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids
+                                 if ('domain_belief-' + domain) in batch} if args.predict_actions else None
+                goodbye_labels = batch['goodbye_belief'].to(
+                    device) if args.predict_actions else None
+            else:
+                labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids
+                          if ('labels-' + slot) in batch}
+                request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids
+                                  if ('request-' + slot) in batch} if args.predict_actions else None
+                domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids
+                                 if ('active-' + domain) in batch} if args.predict_actions else None
+                goodbye_labels = batch['goodbye'].to(
+                    device) if args.predict_actions else None
+
+            # Extract all model inputs from batch
+            input_ids = batch['input_ids'].to(device)
+            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
+            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
 
             # Set up temperature scaling for the model
             if temp_scheduler is not None:
-                model.setsumbt.temp = temp_scheduler.temp()
+                model.temp = temp_scheduler.temp()
 
             # Forward pass to obtain loss
-            loss, _, _, _, _, _, stats = model(**input_dict)
+            loss, _, _, _, _, _, stats = model(input_ids=input_ids,
+                                               token_type_ids=token_type_ids,
+                                               attention_mask=attention_mask,
+                                               inform_labels=labels,
+                                               request_labels=request_labels,
+                                               domain_labels=domain_labels,
+                                               goodbye_labels=goodbye_labels)
 
             if args.n_gpu > 1:
                 loss = loss.mean()
@@ -337,6 +258,7 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
                 tb_writer.add_scalar('LearningRate', lr, global_step)
 
                 if stats:
+                    # print(stats.keys())
                     for slot, stats_slot in stats.items():
                         for key, item in stats_slot.items():
                             tb_writer.add_scalar(f'{key}_{slot}/Train', item, global_step)
@@ -351,6 +273,7 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
 
                 tr_loss += loss.float().item()
                 epoch_iterator.set_postfix(loss=loss.float().item())
+                loss = 0.0
                 global_step += 1
 
             # Save model checkpoint
@@ -363,34 +286,52 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
                     model.eval()
                     set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
 
-                    jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader,
-                                                                                   is_train=True)
+                    jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader)
                     # Log model eval information
-                    log_info(global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, gen_f1, stats)
+                    if req_f1 is not None:
+                        logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f, Dev Request F1 Score = %f, Dev Domain F1 Score = %f, Dev Goodbye F1 Score = %f'
+                                    % (global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
+                        tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
+                        tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)
+                        tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step)
+                        tb_writer.add_scalar('DomainF1Score/Dev', dom_f1, global_step)
+                        tb_writer.add_scalar('GoodbyeF1Score/Dev', bye_f1, global_step)
+                    else:
+                        logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f'
+                                    % (global_step, logging_loss / args.save_steps, jg_acc, sl_acc))
+                        tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
+                        tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)
+                    tb_writer.add_scalar('Loss/Dev', loss, global_step)
+                    if stats:
+                        for slot, stats_slot in stats.items():
+                            for key, item in stats_slot.items():
+                                tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step)
 
                     # Set model back to training mode
                     model.train()
                     model.zero_grad()
                     set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False)
                 else:
-                    log_info(global_step, logging_loss / args.save_steps)
+                    jg_acc, req_f1 = 0.0, None
+                    logger.info('%i steps complete, Loss since last update = %f' % (global_step, logging_loss / args.save_steps))
 
                 logging_loss = tr_loss
 
                 # Compute the score of the best model
                 try:
-                    best_score = best_model['request f1 score'] * model.config.user_request_loss_weight
-                    best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight
-                    best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight
+                    best_score = (best_model['request f1 score'] * model.config.user_request_loss_weight) + \
+                        (best_model['active domain f1 score'] * model.config.active_domain_loss_weight) + \
+                        (best_model['goodbye act f1 score'] *
+                         model.config.user_general_act_loss_weight)
                 except AttributeError:
                     best_score = 0.0
                 best_score += best_model['joint goal accuracy']
 
                 # Compute the score of the current model
                 try:
-                    current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0
-                    current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0
-                    current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0
+                    current_score = (req_f1 * model.config.user_request_loss_weight) + \
+                        (dom_f1 * model.config.active_domain_loss_weight) + \
+                        (bye_f1 * model.config.user_general_act_loss_weight) if req_f1 is not None else 0.0
                 except AttributeError:
                     current_score = 0.0
                 current_score += jg_acc
@@ -412,10 +353,10 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
                     if req_f1:
                         best_model['request f1 score'] = req_f1
                         best_model['active domain f1 score'] = dom_f1
-                        best_model['general act f1 score'] = gen_f1
+                        best_model['goodbye act f1 score'] = bye_f1
                     best_model['train loss'] = tr_loss / global_step
 
-                    output_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                     if not os.path.exists(output_dir):
                         os.makedirs(output_dir, exist_ok=True)
 
@@ -445,14 +386,14 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
                 epoch_iterator.close()
                 break
 
-        logger.info(f'Epoch {e + 1} complete, average training loss = {tr_loss / global_step}')
+        logger.info('Epoch %i complete, average training loss = %f' % (e + 1, tr_loss / global_step))
 
         if args.max_training_steps > 0 and global_step > args.max_training_steps:
             train_iterator.close()
             break
         if args.patience > 0 and steps_since_last_update >= args.patience:
             train_iterator.close()
-            logger.info(f'Model has not improved for at least {args.patience} steps. Training stopped!')
+            logger.info('Model has not improved for at least %i steps. Training stopped!' % args.patience)
             break
 
     # Evaluate final model
@@ -460,25 +401,30 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
         model.eval()
         set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
 
-        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader,
-                                                                       is_train=True)
-
-        log_info('training_complete', tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
+        jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader)
+        if req_f1 is not None:
+            logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f, Dev Request F1 Score = %f, Dev Domain F1 Score = %f, Dev Goodbye F1 Score = %f'
+                        % (tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
+        else:
+            logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f'
+                        % (tr_loss / global_step, jg_acc, sl_acc))
     else:
+        jg_acc = 0.0
         logger.info('Training complete!')
 
     # Store final model
     try:
-        best_score = best_model['request f1 score'] * model.config.user_request_loss_weight
-        best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight
-        best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight
+        best_score = (best_model['request f1 score'] * model.config.user_request_loss_weight) + \
+            (best_model['active domain f1 score'] * model.config.active_domain_loss_weight) + \
+            (best_model['goodbye act f1 score'] *
+             model.config.user_general_act_loss_weight)
     except AttributeError:
         best_score = 0.0
     best_score += best_model['joint goal accuracy']
     try:
-        current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0
-        current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0
-        current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0
+        current_score = (req_f1 * model.config.user_request_loss_weight) + \
+                        (dom_f1 * model.config.active_domain_loss_weight) + \
+                        (bye_f1 * model.config.user_general_act_loss_weight) if req_f1 is not None else 0.0
     except AttributeError:
         current_score = 0.0
     current_score += jg_acc
@@ -510,85 +456,225 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl
             torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt"))
         clear_checkpoints(args.output_dir)
     else:
-        logger.info('Final model not saved, as it is not the best performing model.')
+        logger.info(
+            'Final model not saved, since it is not the best performing model.')
+
+
+# Function for validation
+def train_eval(args, model, device, dev_dataloader):
+    """Score the model on dev_dataloader without gradients; returns (jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats), with the F1 values None unless model.config.predict_actions."""
+    accuracy_jg = []
+    accuracy_sl = []
+    accuracy_req = []
+    truepos_req, falsepos_req, falseneg_req = [], [], []
+    truepos_dom, falsepos_dom, falseneg_dom = [], [], []
+    truepos_bye, falsepos_bye, falseneg_bye = [], [], []
+    accuracy_dom = []
+    accuracy_bye = []
+    turns = []
+    for batch in dev_dataloader:
+        # Forward pass only: no gradients are stored for the dev batches
+        with torch.no_grad():
+            if 'goodbye_belief' in batch:  # batch carries the '*belief*' label keys; feed those to the model
+                labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids
+                          if ('belief-' + slot) in batch}
+                request_labels = {slot: batch['request_belief-' + slot].to(device) for slot in model.requestable_slot_ids
+                                  if ('request_belief-' + slot) in batch} if args.predict_actions else None
+                domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids
+                                 if ('domain_belief-' + domain) in batch} if args.predict_actions else None
+                goodbye_labels = batch['goodbye_belief'].to(
+                    device) if args.predict_actions else None
+            else:  # otherwise use the 'labels-' / 'request-' / 'active-' / 'goodbye' keys
+                labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids
+                          if ('labels-' + slot) in batch}
+                request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids
+                                  if ('request-' + slot) in batch} if args.predict_actions else None
+                domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids
+                                 if ('active-' + domain) in batch} if args.predict_actions else None
+                goodbye_labels = batch['goodbye'].to(
+                    device) if args.predict_actions else None
+
+            input_ids = batch['input_ids'].to(device)
+            token_type_ids = batch['token_type_ids'].to(
+                device) if 'token_type_ids' in batch else None
+            attention_mask = batch['attention_mask'].to(
+                device) if 'attention_mask' in batch else None
+
+            loss, p, p_req, p_dom, p_bye, _, stats = model(input_ids=input_ids,
+                                                           token_type_ids=token_type_ids,
+                                                           attention_mask=attention_mask,
+                                                           inform_labels=labels,
+                                                           request_labels=request_labels,
+                                                           domain_labels=domain_labels,
+                                                           goodbye_labels=goodbye_labels)
+
+        jg_acc = 0.0  # per-batch accumulators, summed over slots / domains below
+        req_acc = 0.0
+        req_tp, req_fp, req_fn = 0.0, 0.0, 0.0
+        dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0
+        dom_acc = 0.0
+        for slot in model.informable_slot_ids:
+            labels = batch['labels-' + slot].to(device)  # scored against the 'labels-' keys, whichever keys fed the model above
+            p_ = p[slot]
+
+            acc = (p_.argmax(-1) == labels).reshape(-1).float()
+            jg_acc += acc
+
+        if model.config.predict_actions:
+            for slot in model.requestable_slot_ids:
+                p_req_ = p_req[slot]
+                request_labels = batch['request-' + slot].to(device)
+
+                acc = (p_req_.round().int() == request_labels).reshape(-1).float()
+                tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float()
+                fp = (p_req_.round().int() * (request_labels == 0)).reshape(-1).float()
+                fn = ((1 - p_req_.round().int()) * (request_labels == 1)).reshape(-1).float()
+                req_acc += acc
+                req_tp += tp
+                req_fp += fp
+                req_fn += fn
+
+            for domain in model.domain_ids:
+                p_dom_ = p_dom[domain]
+                domain_labels = batch['active-' + domain].to(device)
 
+                acc = (p_dom_.round().int() == domain_labels).reshape(-1).float()
+                tp = (p_dom_.round().int() * (domain_labels == 1)).reshape(-1).float()
+                fp = (p_dom_.round().int() * (domain_labels == 0)).reshape(-1).float()
+                fn = ((1 - p_dom_.round().int()) * (domain_labels == 1)).reshape(-1).float()
+                dom_acc += acc
+                dom_tp += tp
+                dom_fp += fp
+                dom_fn += fn
 
-def evaluate(args, model, device, dataloader, return_eval_output=False, is_train=False):
-    """
-    Evaluate model
+            goodbye_labels = batch['goodbye'].to(device)
+            bye_acc = (p_bye.argmax(-1) == goodbye_labels).reshape(-1).float().sum()
+            bye_tp = ((p_bye.argmax(-1) > 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
+            bye_fp = ((p_bye.argmax(-1) > 0) * (goodbye_labels == 0)).reshape(-1).float().sum()
+            bye_fn = ((p_bye.argmax(-1) == 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
+        else:  # actions are not predicted: None / zero placeholders keep the reductions below uniform
+            req_acc, dom_acc, bye_acc = None, None, torch.tensor(0.0)
+            req_tp, req_fp, req_fn = None, None, None
+            dom_tp, dom_fp, dom_fn = None, None, None
+            bye_tp, bye_fp, bye_fn = torch.tensor(
+                0.0), torch.tensor(0.0), torch.tensor(0.0)
+
+        sl_acc = sum(jg_acc / len(model.informable_slot_ids)).float()  # summed per-turn fraction of slots correct
+        jg_acc = sum((jg_acc / len(model.informable_slot_ids)).int()).float()  # int() truncates: a turn counts only if every slot is correct
+        req_acc = sum(req_acc / len(model.requestable_slot_ids)).float() if req_acc is not None else torch.tensor(0.0)
+        req_tp = sum(req_tp / len(model.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0)
+        req_fp = sum(req_fp / len(model.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0)
+        req_fn = sum(req_fn / len(model.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0)
+        dom_tp = sum(dom_tp / len(model.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0)
+        dom_fp = sum(dom_fp / len(model.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0)
+        dom_fn = sum(dom_fn / len(model.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0)
+        dom_acc = sum(dom_acc / len(model.domain_ids)).float() if dom_acc is not None else torch.tensor(0.0)
+        n_turns = (labels >= 0).reshape(-1).sum().float().item()  # entries of the last slot's labels that are >= 0 (presumably negatives mark padding -- confirm)
 
-    Args:
-        args: Runtime arguments
-        model: SetSUMBT model instance
-        device: Torch device in use
-        dataloader: Dataloader of data to evaluate on
-        return_eval_output: If true return predicted and true states for all dialogues evaluated in semantic format
-        is_train: If true model is training and no logging is performed
+        accuracy_jg.append(jg_acc.item())
+        accuracy_sl.append(sl_acc.item())
+        accuracy_req.append(req_acc.item())
+        truepos_req.append(req_tp.item())
+        falsepos_req.append(req_fp.item())
+        falseneg_req.append(req_fn.item())
+        accuracy_dom.append(dom_acc.item())
+        truepos_dom.append(dom_tp.item())
+        falsepos_dom.append(dom_fp.item())
+        falseneg_dom.append(dom_fn.item())
+        accuracy_bye.append(bye_acc.item())
+        truepos_bye.append(bye_tp.item())
+        falsepos_bye.append(bye_fp.item())
+        falseneg_bye.append(bye_fn.item())
+        turns.append(n_turns)
 
-    Returns:
-        out: Evaluted model statistics
-    """
-    return_eval_output = False if is_train else return_eval_output
-    if not is_train:
-        logger.info("***** Running evaluation *****")
-        logger.info("  Num Batches = %d", len(dataloader))
+    # Global accuracy reduction across batches
+    turns = sum(turns)
+    jg_acc = sum(accuracy_jg) / turns
+    sl_acc = sum(accuracy_sl) / turns
+    if model.config.predict_actions:  # F1 = tp / (tp + 0.5 * (fp + fn)); reported as 0.0 when tp = fp = fn = 0 rather than dividing by zero
+        req_acc = sum(accuracy_req) / turns
+        req_tp = sum(truepos_req)
+        req_fp = sum(falsepos_req)
+        req_fn = sum(falseneg_req)
+        req_f1 = req_tp / (req_tp + 0.5 * (req_fp + req_fn)) if (req_tp + req_fp + req_fn) > 0 else 0.0
+        dom_acc = sum(accuracy_dom) / turns
+        dom_tp = sum(truepos_dom)
+        dom_fp = sum(falsepos_dom)
+        dom_fn = sum(falseneg_dom)
+        dom_f1 = dom_tp / (dom_tp + 0.5 * (dom_fp + dom_fn)) if (dom_tp + dom_fp + dom_fn) > 0 else 0.0
+        bye_tp = sum(truepos_bye)
+        bye_fp = sum(falsepos_bye)
+        bye_fn = sum(falseneg_bye)
+        bye_f1 = bye_tp / (bye_tp + 0.5 * (bye_fp + bye_fn)) if (bye_tp + bye_fp + bye_fn) > 0 else 0.0
+        bye_acc = sum(accuracy_bye) / turns
+    else:
+        req_acc, dom_acc, bye_acc = None, None, None
+        req_f1, dom_f1, bye_f1 = None, None, None
+
+    return jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats  # loss and stats are those of the last batch only
+
+
+def evaluate(args, model, device, dataloader):
+    """Evaluate Model!"""
+    # Evaluate!
+    logger.info("***** Running evaluation *****")
+    logger.info("  Num Batches = %d", len(dataloader))
 
     tr_loss = 0.0
     model.eval()
-    if return_eval_output:
-        ontology = dataloader.dataset.ontology
 
+    # logits = {slot: [] for slot in model.informable_slot_ids}
     accuracy_jg = []
     accuracy_sl = []
+    accuracy_req = []
     truepos_req, falsepos_req, falseneg_req = [], [], []
     truepos_dom, falsepos_dom, falseneg_dom = [], [], []
-    truepos_gen, falsepos_gen, falseneg_gen = [], [], []
+    truepos_bye, falsepos_bye, falseneg_bye = [], [], []
+    accuracy_dom = []
+    accuracy_bye = []
     turns = []
-    if return_eval_output:
-        evaluation_output = []
-    epoch_iterator = tqdm(dataloader, desc="Iteration") if not is_train else dataloader
+    epoch_iterator = tqdm(dataloader, desc="Iteration")
     for batch in epoch_iterator:
         with torch.no_grad():
-            input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids,
-                                        model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device)
-
-            loss, p, p_req, p_dom, p_gen, _, stats = model(**input_dict)
+            if 'goodbye_belief' in batch:
+                labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids
+                          if ('belief-' + slot) in batch}
+                request_labels = {slot: batch['request_belief-' + slot].to(device) for slot in model.requestable_slot_ids
+                                  if ('request_belief-' + slot) in batch} if args.predict_actions else None
+                domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids
+                                 if ('domain_belief-' + domain) in batch} if args.predict_actions else None
+                goodbye_labels = batch['goodbye_belief'].to(
+                    device) if args.predict_actions else None
+            else:
+                labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids
+                          if ('labels-' + slot) in batch}
+                request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids
+                                  if ('request-' + slot) in batch} if args.predict_actions else None
+                domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids
+                                 if ('active-' + domain) in batch} if args.predict_actions else None
+                goodbye_labels = batch['goodbye'].to(
+                    device) if args.predict_actions else None
+
+            input_ids = batch['input_ids'].to(device)
+            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
+            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
+
+            loss, p, p_req, p_dom, p_bye, _, _ = model(input_ids=input_ids,
+                                                       token_type_ids=token_type_ids,
+                                                       attention_mask=attention_mask,
+                                                       inform_labels=labels,
+                                                       request_labels=request_labels,
+                                                       domain_labels=domain_labels,
+                                                       goodbye_labels=goodbye_labels)
 
         jg_acc = 0.0
         req_acc = 0.0
         req_tp, req_fp, req_fn = 0.0, 0.0, 0.0
         dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0
         dom_acc = 0.0
-
-        if return_eval_output:
-            eval_output_batch = []
-            for dial_id, dial in enumerate(input_dict['input_ids']):
-                for turn_id, turn in enumerate(dial):
-                    if turn.sum() != 0:
-                        eval_output_batch.append({'dial_idx': dial_id,
-                                                  'utt_idx': turn_id,
-                                                  'state': {domain: {slot: '' for slot in substate}
-                                                            for domain, substate in ontology.items()},
-                                                  'predictions': {'state': {domain: {slot: '' for slot in substate}
-                                                                            for domain, substate in ontology.items()}}
-                                                  })
-
-        for slot in model.setsumbt.informable_slot_ids:
+        for slot in model.informable_slot_ids:
             p_ = p[slot]
-            state_labels = batch['state_labels-' + slot].to(device)
-
-            if return_eval_output:
-                prediction = p_.argmax(-1)
-
-                for sample in eval_output_batch:
-                    dom, slt = slot.split('-', 1)
-                    pred = prediction[sample['dial_idx']][sample['utt_idx']].item()
-                    pred = ontology[dom][slt]['possible_values'][pred]
-                    lab = state_labels[sample['dial_idx']][sample['utt_idx']].item()
-                    lab = ontology[dom][slt]['possible_values'][lab]
-
-                    sample['state'][dom][slt] = lab if lab != 'none' else ''
-                    sample['predictions']['state'][dom][slt] = pred if pred != 'none' else ''
+            labels = batch['labels-' + slot].to(device)
 
             if args.temp_scaling > 0.0:
                 p_ = torch.log(p_ + 1e-10) / args.temp_scaling
@@ -597,18 +683,28 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
                 p_ = torch.log(p_ + 1e-10) / 1.0
                 p_ = torch.softmax(p_, -1)
 
-            acc = (p_.argmax(-1) == state_labels).reshape(-1).float()
+            # logits[slot].append(p_)
+
+            if args.accuracy_samples > 0:
+                dist = Categorical(probs=p_.reshape(-1, p_.size(-1)))
+                lab_sample = dist.sample((args.accuracy_samples,))
+                lab_sample = lab_sample.transpose(0, 1)
+                acc = [lab in s for lab, s in zip(labels.reshape(-1), lab_sample)]
+                acc = torch.tensor(acc).float()
+            elif args.accuracy_topn > 0:
+                labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True)
+                labs = labs[:, :args.accuracy_topn]
+                acc = [lab in s for lab, s in zip(labels.reshape(-1), labs)]
+                acc = torch.tensor(acc).float()
+            else:
+                acc = (p_.argmax(-1) == labels).reshape(-1).float()
 
             jg_acc += acc
 
-        if return_eval_output:
-            evaluation_output += deepcopy(eval_output_batch)
-            eval_output_batch = []
-
         if model.config.predict_actions:
-            for slot in model.setsumbt.requestable_slot_ids:
+            for slot in model.requestable_slot_ids:
                 p_req_ = p_req[slot]
-                request_labels = batch['request_labels-' + slot].to(device)
+                request_labels = batch['request-' + slot].to(device)
 
                 acc = (p_req_.round().int() == request_labels).reshape(-1).float()
                 tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float()
@@ -619,88 +715,85 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
                 req_fp += fp
                 req_fn += fn
 
-            domains = [domain for domain in model.setsumbt.domain_ids if f'active_domain_labels-{domain}' in batch]
-            for domain in domains:
+            for domain in model.domain_ids:
                 p_dom_ = p_dom[domain]
-                active_domain_labels = batch['active_domain_labels-' + domain].to(device)
+                domain_labels = batch['active-' + domain].to(device)
 
-                acc = (p_dom_.round().int() == active_domain_labels).reshape(-1).float()
-                tp = (p_dom_.round().int() * (active_domain_labels == 1)).reshape(-1).float()
-                fp = (p_dom_.round().int() * (active_domain_labels == 0)).reshape(-1).float()
-                fn = ((1 - p_dom_.round().int()) * (active_domain_labels == 1)).reshape(-1).float()
+                acc = (p_dom_.round().int() == domain_labels).reshape(-1).float()
+                tp = (p_dom_.round().int() * (domain_labels == 1)).reshape(-1).float()
+                fp = (p_dom_.round().int() * (domain_labels == 0)).reshape(-1).float()
+                fn = ((1 - p_dom_.round().int()) * (domain_labels == 1)).reshape(-1).float()
                 dom_acc += acc
                 dom_tp += tp
                 dom_fp += fp
                 dom_fn += fn
 
-            general_act_labels = batch['general_act_labels'].to(device)
-            gen_tp = ((p_gen.argmax(-1) > 0) * (general_act_labels > 0)).reshape(-1).float().sum()
-            gen_fp = ((p_gen.argmax(-1) > 0) * (general_act_labels == 0)).reshape(-1).float().sum()
-            gen_fn = ((p_gen.argmax(-1) == 0) * (general_act_labels > 0)).reshape(-1).float().sum()
+            goodbye_labels = batch['goodbye'].to(device)
+            bye_acc = (p_bye.argmax(-1) == goodbye_labels).reshape(-1).float().sum()
+            bye_tp = ((p_bye.argmax(-1) > 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
+            bye_fp = ((p_bye.argmax(-1) > 0) * (goodbye_labels == 0)).reshape(-1).float().sum()
+            bye_fn = ((p_bye.argmax(-1) == 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
         else:
+            req_acc, dom_acc, bye_acc = None, None, torch.tensor(0.0)
             req_tp, req_fp, req_fn = None, None, None
             dom_tp, dom_fp, dom_fn = None, None, None
-            gen_tp, gen_fp, gen_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
-
-        sl_acc = sum(jg_acc / len(model.setsumbt.informable_slot_ids)).float()
-        jg_acc = sum((jg_acc == len(model.setsumbt.informable_slot_ids)).int()).float()
-        req_tp = sum(req_tp / len(model.setsumbt.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0)
-        req_fp = sum(req_fp / len(model.setsumbt.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0)
-        req_fn = sum(req_fn / len(model.setsumbt.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0)
-        dom_tp = sum(dom_tp / len(model.setsumbt.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0)
-        dom_fp = sum(dom_fp / len(model.setsumbt.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0)
-        dom_fn = sum(dom_fn / len(model.setsumbt.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0)
-        n_turns = (state_labels >= 0).reshape(-1).sum().float().item()
+            bye_tp, bye_fp, bye_fn = torch.tensor(
+                0.0), torch.tensor(0.0), torch.tensor(0.0)
+
+        sl_acc = sum(jg_acc / len(model.informable_slot_ids)).float()
+        jg_acc = sum((jg_acc / len(model.informable_slot_ids)).int()).float()
+        req_acc = sum(req_acc / len(model.requestable_slot_ids)).float() if req_acc is not None else torch.tensor(0.0)
+        req_tp = sum(req_tp / len(model.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0)
+        req_fp = sum(req_fp / len(model.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0)
+        req_fn = sum(req_fn / len(model.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0)
+        dom_tp = sum(dom_tp / len(model.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0)
+        dom_fp = sum(dom_fp / len(model.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0)
+        dom_fn = sum(dom_fn / len(model.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0)
+        dom_acc = sum(dom_acc / len(model.domain_ids)).float() if dom_acc is not None else torch.tensor(0.0)
+        n_turns = (labels >= 0).reshape(-1).sum().float().item()
 
         accuracy_jg.append(jg_acc.item())
         accuracy_sl.append(sl_acc.item())
+        accuracy_req.append(req_acc.item())
         truepos_req.append(req_tp.item())
         falsepos_req.append(req_fp.item())
         falseneg_req.append(req_fn.item())
+        accuracy_dom.append(dom_acc.item())
         truepos_dom.append(dom_tp.item())
         falsepos_dom.append(dom_fp.item())
         falseneg_dom.append(dom_fn.item())
-        truepos_gen.append(gen_tp.item())
-        falsepos_gen.append(gen_fp.item())
-        falseneg_gen.append(gen_fn.item())
+        accuracy_bye.append(bye_acc.item())
+        truepos_bye.append(bye_tp.item())
+        falsepos_bye.append(bye_fp.item())
+        falseneg_bye.append(bye_fn.item())
         turns.append(n_turns)
         tr_loss += loss.item()
 
+    # for slot in logits:
+    #     logits[slot] = torch.cat(logits[slot], 0)
+
     # Global accuracy reduction across batches
     turns = sum(turns)
     jg_acc = sum(accuracy_jg) / turns
     sl_acc = sum(accuracy_sl) / turns
     if model.config.predict_actions:
+        req_acc = sum(accuracy_req) / turns
         req_tp = sum(truepos_req)
         req_fp = sum(falsepos_req)
         req_fn = sum(falseneg_req)
-        req_f1 = req_tp + 0.5 * (req_fp + req_fn)
-        req_f1 = req_tp / req_f1 if req_f1 != 0.0 else 0.0
+        req_f1 = req_tp / (req_tp + 0.5 * (req_fp + req_fn))
+        dom_acc = sum(accuracy_dom) / turns
         dom_tp = sum(truepos_dom)
         dom_fp = sum(falsepos_dom)
         dom_fn = sum(falseneg_dom)
-        dom_f1 = dom_tp + 0.5 * (dom_fp + dom_fn)
-        dom_f1 = dom_tp / dom_f1 if dom_f1 != 0.0 else 0.0
-        gen_tp = sum(truepos_gen)
-        gen_fp = sum(falsepos_gen)
-        gen_fn = sum(falseneg_gen)
-        gen_f1 = gen_tp + 0.5 * (gen_fp + gen_fn)
-        gen_f1 = gen_tp / gen_f1 if gen_f1 != 0.0 else 0.0
+        dom_f1 = dom_tp / (dom_tp + 0.5 * (dom_fp + dom_fn))
+        bye_tp = sum(truepos_bye)
+        bye_fp = sum(falsepos_bye)
+        bye_fn = sum(falseneg_bye)
+        bye_f1 = bye_tp / (bye_tp + 0.5 * (bye_fp + bye_fn))
+        bye_acc = sum(accuracy_bye) / turns
     else:
-        req_f1, dom_f1, gen_f1 = None, None, None
-
-    if return_eval_output:
-        dial_idx = 0
-        for sample in evaluation_output:
-            if dial_idx == 0 and sample['dial_idx'] == 0 and sample['utt_idx'] == 0:
-                dial_idx = 0
-            elif dial_idx == 0 and sample['dial_idx'] != 0 and sample['utt_idx'] == 0:
-                dial_idx += 1
-            elif sample['utt_idx'] == 0:
-                dial_idx += 1
-            sample['dial_idx'] = dial_idx
-
-        return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), evaluation_output
-    if is_train:
-        return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), stats
-    return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader)
+        req_acc, dom_acc, bye_acc = None, None, None
+        req_f1, dom_f1, bye_f1 = None, None, None
+
+    return jg_acc, sl_acc, req_f1, dom_f1, bye_f1, tr_loss / len(dataloader)
diff --git a/convlab/dst/setsumbt/multiwoz/Tracker.py b/convlab/dst/setsumbt/multiwoz/Tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed1a1a62334e9a721c1f5b7de442b17bfd14b76
--- /dev/null
+++ b/convlab/dst/setsumbt/multiwoz/Tracker.py
@@ -0,0 +1,455 @@
+import os
+import json
+import copy
+import logging
+
+import torch
+import transformers
+from transformers import (BertModel, BertConfig, BertTokenizer,
+                          RobertaModel, RobertaConfig, RobertaTokenizer)
+from convlab.dst.setsumbt.modeling import (RobertaSetSUMBT,
+                                            BertSetSUMBT)
+
+from convlab.dst.dst import DST
+from convlab.util.multiwoz.state import default_state
+from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
+from convlab.dst.rule.multiwoz import normalize_value
+from convlab.util.custom_util import model_downloader
+
+USE_CUDA = torch.cuda.is_available()
+
+# Map from SetSUMBT slot names to Convlab slot names
+SLOT_MAP = {'arrive by': 'arriveBy',
+            'leave at': 'leaveAt',
+            'price range': 'pricerange',
+            'trainid': 'trainID',
+            'reference': 'Ref',
+            'taxi types': 'car type'}
+
+
+class SetSUMBTTracker(DST):
+
+    def __init__(self, model_path="", model_type="roberta",
+                 get_turn_pooled_representation=False,
+                 get_confidence_scores=False,
+                 threshold='auto',
+                 return_entropy=False,
+                 return_mutual_info=False,
+                 store_full_belief_state=False):
+        """Set up the tracker and load the model, downloading it when model_path does not exist locally."""
+        super(SetSUMBTTracker, self).__init__()
+
+        self.model_type = model_type
+        self.model_path = model_path
+        self.get_turn_pooled_representation = get_turn_pooled_representation
+        self.get_confidence_scores = get_confidence_scores
+        self.threshold = threshold
+        self.return_entropy = return_entropy
+        self.return_mutual_info = return_mutual_info
+        self.store_full_belief_state = store_full_belief_state
+        if self.store_full_belief_state:
+            self.full_belief_state = {}
+        self.info_dict = {}
+
+        # Download model if needed
+        if not os.path.exists(self.model_path):
+            # Get path /.../convlab/dst/setsumbt/multiwoz/models
+            download_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
+            os.makedirs(download_path, exist_ok=True)
+            model_downloader(download_path, self.model_path)
+            # Downloadable model path format http://.../setsumbt_model_name.zip
+            self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '')
+            self.model_path = os.path.join(download_path, self.model_path)
+
+        # Select model type based on the encoder
+        if model_type == "roberta":
+            self.config = RobertaConfig.from_pretrained(self.model_path)
+            self.tokenizer = RobertaTokenizer
+            self.model = RobertaSetSUMBT
+        elif model_type == "bert":
+            self.config = BertConfig.from_pretrained(self.model_path)
+            self.tokenizer = BertTokenizer
+            self.model = BertSetSUMBT
+        else:
+            raise ValueError(f"Unsupported model_type '{model_type}': expected 'roberta' or 'bert'.")
+
+        self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu')
+
+        # Value dict for value normalisation
+        path = os.path.dirname(
+            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+        path = os.path.join(path, 'data/multiwoz/value_dict.json')
+        with open(path) as value_file:
+            self.value_dict = json.load(value_file)
+
+        self.load_weights()
+
+    def load_weights(self):
+        """Instantiate the tokenizer and model from the checkpoint and attach the ontology to the model."""
+        # Load tokenizer and model checkpoints
+        logging.info('Loading SetSUMBT pretrained model.')
+        self.tokenizer = self.tokenizer.from_pretrained(
+            self.config.tokenizer_name)
+        logging.info(
+            f'Model tokenizer loaded from {self.config.tokenizer_name}.')
+        self.model = self.model.from_pretrained(
+            self.model_path, config=self.config)
+        logging.info(f'Model loaded from {self.model_path}.')
+
+        # Transfer model to compute device and setup eval environment
+        self.model = self.model.to(self.device)
+        self.model.eval()
+        logging.info(f'Model transferred to device: {self.device}')
+
+        logging.info('Loading model ontology')
+        with open(os.path.join(self.model_path, 'ontology.json'), 'r') as f:
+            self.ontology = json.load(f)
+
+        # NOTE(review): torch.load unpickles its input; only point model_path at trusted checkpoints.
+        db = torch.load(os.path.join(self.model_path, 'ontology.db'))
+        # Get slot and value embeddings
+        slots = {slot: db[slot] for slot in db}
+        values = {slot: db[slot][1] for slot in db}
+        del db
+
+        # Load model ontology
+        self.model.add_slot_candidates(slots)
+        for slot in values:
+            self.model.add_value_candidates(slot, values[slot], replace=True)
+
+        if self.get_confidence_scores:
+            logging.info('Model will output action and state confidence scores.')
+            self.get_thresholds(self.threshold)
+            logging.info('Uncertain Querying set up and thresholds set up at:')
+            logging.info(self.thresholds)
+        if self.return_entropy:
+            logging.info('Model will output state distribution entropy.')
+        if self.return_mutual_info:
+            logging.info('Model will output state distribution mutual information.')
+        logging.info('Ontology loaded successfully.')
+
+        self.det_dic = {}
+        for domain, dic in REF_USR_DA.items():
+            for key, value in dic.items():
+                assert '-' not in key
+                self.det_dic[key.lower()] = key + '-' + domain
+                self.det_dic[value.lower()] = key + '-' + domain
+
+    def get_thresholds(self, threshold='auto'):
+        """Populate and return self.thresholds[domain][slot] for every slot of the ontology."""
+        self.thresholds = {}
+        for name, candidates in self.ontology.items():
+            domain, slot = name.split('-', 1)
+            slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
+            if 'book ' in slot:
+                slot = slot.strip().split()[1]
+            slot = SLOT_MAP.get(slot, slot)
+
+            per_domain = self.thresholds.setdefault(domain, {})
+            # Auto thresholds depend on the number of value candidates; both modes are floored at 0.05
+            if threshold == 'auto':
+                level = 1.0 / (float(len(candidates)) - 2.1)
+            else:
+                level = threshold
+            per_domain[slot] = max(0.05, level)
+        return self.thresholds
+
+    def init_session(self):
+        """Reset the dialogue state, active domains, info dict and model hidden states."""
+        self.state = default_state()
+        self.active_domains, self.info_dict = {}, {}
+        self.hidden_states = None
+
+    def update(self, user_act=''):
+        """Run the model on the latest user turn and return the updated dialogue state."""
+        prev_state = self.state
+
+        # Convert dialogs into transformer input features (token_ids, masks, etc)
+        features = self.get_features(user_act)
+        # Model forward pass
+        pred_states, active_domains, user_acts, turn_pooled_representation, belief_state, entropy_, mutual_info_ = self.predict(
+            features)
+
+        if entropy_ is not None:
+            entropy = {}
+            for slot, e in entropy_.items():
+                domain, slot = slot.split('-', 1)
+                if domain not in entropy:
+                    entropy[domain] = {}
+                if 'book' in slot:
+                    assert slot.startswith('book ')
+                    slot = slot.strip().split()[1]
+                slot = SLOT_MAP.get(slot, slot)
+                entropy[domain][slot] = e
+            del entropy_
+        else:
+            entropy = None
+
+        if mutual_info_ is not None:
+            mutual_info = {}
+            for slot, mi in mutual_info_.items():
+                domain, slot = slot.split('-', 1)
+                if domain not in mutual_info:
+                    mutual_info[domain] = {}
+                if 'book' in slot:
+                    assert slot.startswith('book ')
+                    slot = slot.strip().split()[1]
+                slot = SLOT_MAP.get(slot, slot)
+                mutual_info[domain][slot] = mi[0, 0]
+        else:
+            mutual_info = None
+
+        if belief_state is not None:
+            bs_probs = {}
+            belief_state, request_dist, domain_dist, greeting_dist = belief_state
+            for slot, p in belief_state.items():
+                domain, slot = slot.split('-', 1)
+                if domain not in bs_probs:
+                    bs_probs[domain] = {}
+                if 'book' in slot:
+                    assert slot.startswith('book ')
+                    slot = slot.strip().split()[1]
+                slot = SLOT_MAP.get(slot, slot)
+                if slot not in bs_probs[domain]:
+                    bs_probs[domain][slot] = {}
+                bs_probs[domain][slot]['inform'] = p
+
+            for slot, p in request_dist.items():
+                domain, slot = slot.split('-', 1)
+                if domain not in bs_probs:
+                    bs_probs[domain] = {}
+                slot = SLOT_MAP.get(slot, slot)
+                if slot not in bs_probs[domain]:
+                    bs_probs[domain][slot] = {}
+                bs_probs[domain][slot]['request'] = p
+
+            for domain, p in domain_dist.items():
+                if domain not in bs_probs:
+                    bs_probs[domain] = {}
+                bs_probs[domain]['none'] = {'inform': p}
+
+            if 'general' not in bs_probs:
+                bs_probs['general'] = {}
+            bs_probs['general']['none'] = greeting_dist
+
+        new_domains = [d for d, active in active_domains.items() if active]
+        new_domains = [
+            d for d in new_domains if not self.active_domains.get(d, False)]
+        self.active_domains = active_domains
+
+        for domain in new_domains:
+            user_acts.append(['Inform', domain.capitalize(), 'none', 'none'])
+
+        new_belief_state = copy.deepcopy(prev_state['belief_state'])
+        for state, value in pred_states.items():
+            domain, slot = state.split('-', 1)
+            value = '' if value == 'none' else value
+            value = 'dontcare' if value == 'do not care' else value
+            value = 'guesthouse' if value == 'guest house' else value
+            if domain not in new_belief_state:
+                # There is nowhere to store this domain, so skip it rather than
+                # fail on the lookups below; bus is skipped without logging.
+                if domain != 'bus':
+                    logging.debug(
+                        'Error: domain <{}> not in belief state'.format(domain))
+                continue
+            slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
+            assert 'semi' in new_belief_state[domain]
+            assert 'book' in new_belief_state[domain]
+            if 'book' in slot:
+                assert slot.startswith('book ')
+                slot = slot.strip().split()[1]
+            slot = SLOT_MAP.get(slot, slot)
+
+            # Uncertainty clipping of state
+            if belief_state is not None:
+                if bs_probs[domain][slot].get('inform', 1.0) < self.thresholds[domain][slot]:
+                    value = ''
+
+            domain_dic = new_belief_state[domain]
+            value = normalize_value(self.value_dict, domain, slot, value)
+            if slot in domain_dic['semi']:
+                new_belief_state[domain]['semi'][slot] = value
+                if prev_state['belief_state'][domain]['semi'][slot] != value:
+                    user_acts.append(['Inform', domain.capitalize(
+                    ), REF_USR_DA[domain.capitalize()].get(slot, slot), value])
+            elif slot in domain_dic['book']:
+                new_belief_state[domain]['book'][slot] = value
+                if prev_state['belief_state'][domain]['book'][slot] != value:
+                    user_acts.append(['Inform', domain.capitalize(
+                    ), REF_USR_DA[domain.capitalize()].get(slot, slot), value])
+            elif slot.lower() in domain_dic['book']:
+                new_belief_state[domain]['book'][slot.lower()] = value
+                if prev_state['belief_state'][domain]['book'][slot.lower()] != value:
+                    user_acts.append(['Inform', domain.capitalize(
+                    ), REF_USR_DA[domain.capitalize()].get(slot.lower(), slot.lower()), value])
+            else:
+                logging.debug(
+                    'unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format(
+                        slot, value, domain, state)
+                )
+
+        new_state = copy.deepcopy(dict(prev_state))
+        new_state['belief_state'] = new_belief_state
+        new_state['active_domains'] = self.active_domains
+        if belief_state is not None:
+            new_state['belief_state_probs'] = bs_probs
+        if entropy is not None:
+            new_state['entropy'] = entropy
+        if mutual_info is not None:
+            new_state['mutual_information'] = mutual_info
+
+        new_state['user_action'] = user_acts
+
+        user_requests = [[a, d, s, v]
+                         for a, d, s, v in user_acts if a == 'Request']
+        for act, domain, slot, value in user_requests:
+            k = REF_SYS_DA[domain].get(slot, slot)
+            domain = domain.lower()
+            if domain not in new_state['request_state']:
+                new_state['request_state'][domain] = {}
+            if k not in new_state['request_state'][domain]:
+                new_state['request_state'][domain][k] = 0
+
+        if turn_pooled_representation is not None:
+            new_state['turn_pooled_representation'] = turn_pooled_representation
+
+        self.state = new_state
+        self.info_dict = copy.deepcopy(dict(new_state))
+
+        return self.state
+
+    # Model prediction function
+
+    def predict(self, features):
+        """Run one forward pass and decode the model outputs.
+
+        Returns a 7-tuple: predicted slot values (``none`` filtered out), the
+        active-domain flags, the user acts (requests plus general acts), the
+        turn pooled representation, the confidence tuple, the entropies and the
+        mutual information; optional entries are None when not enabled.
+        """
+
+        # Forward Pass
+        mutual_info = None
+        turn_pooled_representation = None
+        # Arguments shared by every variant of the forward call
+        model_inputs = {
+            'input_ids': features['input_ids'],
+            'token_type_ids': features['token_type_ids'],
+            'attention_mask': features['attention_mask'],
+            'hidden_state': self.hidden_states,
+        }
+        with torch.no_grad():
+            if self.get_turn_pooled_representation:
+                outputs = self.model(get_turn_pooled_representation=True, **model_inputs)
+                (belief_state, request, domain, goodbye,
+                 self.hidden_states, turn_pooled_representation) = outputs
+            elif self.return_mutual_info:
+                outputs = self.model(get_turn_pooled_representation=False,
+                                     calculate_inform_mutual_info=True, **model_inputs)
+                (belief_state, request, domain, goodbye,
+                 self.hidden_states, mutual_info) = outputs
+            else:
+                outputs = self.model(get_turn_pooled_representation=False, **model_inputs)
+                belief_state, request, domain, goodbye, self.hidden_states = outputs
+
+        # Convert belief state into dialog state
+        best_idx = {slot: dist[0, 0, :].argmax().item()
+                    for slot, dist in belief_state.items()}
+        predictions = {}
+        for slot, idx in best_idx.items():
+            value = self.ontology[slot][idx]
+            if value != 'none':
+                predictions[slot] = value
+
+        if self.store_full_belief_state:
+            self.full_belief_state = belief_state
+
+        # Obtain model output probabilities
+        entropy = None
+        if not self.get_confidence_scores:
+            belief_state = None
+        else:
+            if self.return_entropy:
+                entropy = {slot: self.relative_entropy(dist[0, 0, :]).item()
+                           for slot, dist in belief_state.items()}
+
+            # Confidence score is the max probability across all not "none" values candidates.
+            slot_conf = {slot: dist[0, 0, 1:].max().item()
+                         for slot, dist in belief_state.items()}
+            request_dist = {SLOT_MAP.get(slot, slot): p[0, 0].item()
+                            for slot, p in request.items()}
+            domain_dist = {dom: p[0, 0].item() for dom, p in domain.items()}
+            greeting_dist = {'bye': goodbye[0, 0, 1].item(),
+                             'thank': goodbye[0, 0, 2].item()}
+            belief_state = (slot_conf, request_dist, domain_dist, greeting_dist)
+
+        # Construct request action prediction
+        requested = [name.split('-', 1) for name, p in request.items()
+                     if p[0, 0].item() > 0.5]
+        requested = [[dom, SLOT_MAP.get(slt, slt)] for dom, slt in requested]
+        user_acts = [['Request', dom.capitalize(),
+                      REF_USR_DA[dom.capitalize()].get(slt, slt), '?']
+                     for dom, slt in requested]
+
+        # Construct active domain set
+        active = {dom: p[0, 0].item() > 0.5 for dom, p in domain.items()}
+
+        # Construct general domain action
+        general_idx = goodbye[0, 0, :].argmax(-1).item()
+        general_acts = [[], ['bye'], ['thank']][general_idx]
+        user_acts += [[act, 'general', 'none', 'none'] for act in general_acts]
+
+        return (predictions, active, user_acts, turn_pooled_representation,
+                belief_state, entropy, mutual_info)
+
+    def relative_entropy(self, probs):
+        """Return the entropy of probs divided by ln(K), K being the size of the last dimension."""
+        # Maximum entropy of a K dimensional distribution is ln(K)
+        max_entropy = torch.log(torch.tensor(probs.size(-1)).float())
+        entropy = -(probs * torch.log(probs + 1e-8)).sum()
+
+        return entropy / max_entropy
+
+    # Convert dialog turns into model features
+    def get_features(self, user_act):
+        """Encode the current turn as model input tensors.
+
+        The user utterance is paired with the most recent history entry when
+        that entry was produced by the system, otherwise with an empty string.
+        Returns a dict with input_ids, token_type_ids and attention_mask, each
+        reshaped to (1, 1, -1) on self.device, or None when the tokenizer did
+        not produce that field.
+        """
+        # Extract system utterance from dialog history
+        history = self.state['history']
+        system_act = history[-1][-1] if history and history[-1][0] == 'sys' else ''
+
+        # Tokenize dialog
+        encoded = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True,
+                                             max_length=self.config.max_turn_len,
+                                             padding='max_length', truncation='longest_first')
+
+        def as_batch(key):
+            # Missing tokenizer fields are reported as None instead of a tensor
+            if key not in encoded:
+                return None
+            return torch.tensor(encoded[key]).reshape(1, 1, -1).to(self.device)
+
+        return {key: as_batch(key) for key in ('input_ids', 'token_type_ids', 'attention_mask')}
+
+
+# if __name__ == "__main__":
+#     tracker = SetSUMBTTracker(model_type='roberta', model_path='/gpfs/project/niekerk/results/nbt/convlab_setsumbt_acts')
+#                         # nlu_path='/gpfs/project/niekerk/data/bert_multiwoz_all_context.zip')
+#     tracker.init_session()
+#     state = tracker.update('hey. I need a cheap restaurant.')
+#     # tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
+#     # tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
+#     # state = tracker.update('If you have something Asian that would be great.')
+#     # tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
+#     # tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
+#     # state = tracker.update('Great. Where are they located?')
+#     # tracker.state['history'].append(['usr', 'Great. Where are they located?'])
+#     print(tracker.state)
diff --git a/convlab/dst/setsumbt/multiwoz/__init__.py b/convlab/dst/setsumbt/multiwoz/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1f1fb894a0430545c62b78f9a6f4786c4e328a8
--- /dev/null
+++ b/convlab/dst/setsumbt/multiwoz/__init__.py
@@ -0,0 +1,2 @@
+from convlab.dst.setsumbt.multiwoz.dataset import multiwoz21, ontology
+from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair b/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair
new file mode 100644
index 0000000000000000000000000000000000000000..34df41d01e93ce27039e721e1ffb55bf9267e5a2
--- /dev/null
+++ b/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair
@@ -0,0 +1,83 @@
+it's	it is
+don't	do not
+doesn't	does not
+didn't	did not
+you'd	you would
+you're	you are
+you'll	you will
+i'm	i am
+they're	they are
+that's	that is
+what's	what is
+couldn't	could not
+i've	i have
+we've	we have
+can't	cannot
+i'd	i would
+i'd	i would
+aren't	are not
+isn't	is not
+wasn't	was not
+weren't	were not
+won't	will not
+there's	there is
+there're	there are
+. .	.
+restaurants	restaurant -s
+hotels	hotel -s
+laptops	laptop -s
+cheaper	cheap -er
+dinners	dinner -s
+lunches	lunch -s
+breakfasts	breakfast -s
+expensively	expensive -ly
+moderately	moderate -ly
+cheaply	cheap -ly
+prices	price -s
+places	place -s
+venues	venue -s
+ranges	range -s
+meals	meal -s
+locations	location -s
+areas	area -s
+policies	policy -s
+children	child -s
+kids	kid -s
+kidfriendly	kid friendly
+cards	card -s
+upmarket	expensive
+inpricey	cheap
+inches	inch -s
+uses	use -s
+dimensions	dimension -s
+driverange	drive range
+includes	include -s
+computers	computer -s
+machines	machine -s
+families	family -s
+ratings	rating -s
+constraints	constraint -s
+pricerange	price range
+batteryrating	battery rating
+requirements	requirement -s
+drives	drive -s
+specifications	specification -s
+weightrange	weight range
+harddrive	hard drive
+batterylife	battery life
+businesses	business -s
+hours	hour -s
+one	1
+two	2
+three	3
+four	4
+five	5
+six	6
+seven	7
+eight	8
+nine	9
+ten	10
+eleven	11
+twelve	12
+anywhere	any where
+good bye	goodbye
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py b/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c8e98f35429ce5194d82d07a1cf0de8fee54515
--- /dev/null
+++ b/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py
@@ -0,0 +1,502 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MultiWOZ 2.1/2.3 Dialogue Dataset"""
+
+import os
+import json
+import requests
+import zipfile
+import io
+from shutil import copy2 as copy
+
+import torch
+from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
+from tqdm import tqdm
+
+from convlab.dst.setsumbt.multiwoz.dataset.utils import (clean_text, ACTIVE_DOMAINS, get_domains, set_util_domains,
+                                                        fix_delexicalisation, extract_dialogue, PRICERANGE,
+                                                        BOOLEAN, DAYS, QUANTITIES, TIME, VALUE_MAP, map_values)
+
+
+# Set up global data_directory
+def set_datadir(dir):
+    """Set the directory used for all dataset files.
+
+    Parameters:
+        dir: Directory path stored in the module-level ``DATA_DIR`` variable, which
+            the loading and preprocessing functions of this module read.
+    """
+    global DATA_DIR
+    DATA_DIR = dir
+
+
+def set_active_domains(domains):
+    """Restrict the active domains to the requested ones.
+
+    Only the entries of `domains` that are contained in the current
+    ``ACTIVE_DOMAINS`` are kept (in the order given), so repeated calls can only
+    narrow the selection. The resulting list is also handed to ``set_util_domains``.
+
+    Parameters:
+        domains: Iterable of domain names to keep active.
+    """
+    global ACTIVE_DOMAINS
+    currently_active = ACTIVE_DOMAINS
+    selected = []
+    for domain in domains:
+        if domain in currently_active:
+            selected.append(domain)
+    ACTIVE_DOMAINS = selected
+    set_util_domains(selected)
+
+
+# Default download location of the MultiWOZ2.1 data archive
+URL = 'https://github.com/budzianowski/multiwoz/raw/master/data/MultiWOZ_2.1.zip'
+
+
+def set_url(url):
+    """Replace the download link read by ``_download``.
+
+    Parameters:
+        url: New value for the module-level ``URL`` variable.
+    """
+    global URL
+    URL = url
+
+
+# Create Dialogue examples from the dataset
+def create_examples(max_utt_len, get_requestable_slots=False, force_processing=False):
+    """Prepare the dataset files in ``DATA_DIR``.
+
+    The raw data is downloaded and cached as ``data_raw.json`` unless that file
+    already exists. If ``data_train.json`` is missing (or `force_processing` is set)
+    the raw dialogues are preprocessed, split and written to ``data_train.json``,
+    ``data_dev.json`` and ``data_test.json``, the ontologies are extracted, and the
+    ontology and slot description files stored next to this script are copied in.
+
+    Parameters:
+        max_utt_len: Passed on to ``_split_data``.
+        get_requestable_slots (bool): Also collect requestable slots and copy the
+            request variant of the packaged ontology. (default=False)
+        force_processing (bool): Redo the preprocessing even if the processed files
+            already exist. (default=False)
+    """
+    # Load or download Raw Data
+    os.makedirs(DATA_DIR, exist_ok=True)
+    raw_path = os.path.join(DATA_DIR, 'data_raw.json')
+    if not os.path.exists(raw_path):
+        # Download data archive and extract
+        archive = _download()
+        data = _extract(archive)
+        del archive
+
+        with open(raw_path, 'w') as writer:
+            json.dump(data, writer, indent=2)
+    else:
+        with open(raw_path, 'r') as reader:
+            data = json.load(reader)
+
+    if force_processing or not os.path.exists(os.path.join(DATA_DIR, 'data_train.json')):
+        # Preprocess all dialogues
+        data_processed = _process(data['data'], data['system_acts'])
+        # Format data and split train, test and development sets
+        train, dev, test = _split_data(data_processed, data['testListFile'],
+                                       data['valListFile'], max_utt_len)
+
+        # Write data; every file is closed before _get_ontology reads it back
+        for set_type, dialogues in (('train', train), ('test', test), ('dev', dev)):
+            with open(os.path.join(DATA_DIR, 'data_%s.json' % set_type), 'w') as writer:
+                json.dump(dialogues, writer, indent=2)
+
+        # Extract slots and slot value candidates from the dataset
+        for set_type in ['train', 'dev', 'test']:
+            _get_ontology(set_type, get_requestable_slots)
+
+        script_path = os.path.dirname(os.path.abspath(__file__))
+        file_name = 'mwoz21_ont_request.json' if get_requestable_slots else 'mwoz21_ont.json'
+        # NOTE: this overwrites the ontology_test.json generated just above with the
+        # ontology file stored next to this script
+        copy(os.path.join(script_path, file_name), os.path.join(DATA_DIR, 'ontology_test.json'))
+        copy(os.path.join(script_path, 'mwoz21_slot_descriptions.json'),
+             os.path.join(DATA_DIR, 'slot_descriptions.json'))
+
+
+# Extract slots and slot value candidates from the dataset
+def _get_ontology(set_type, get_requestable_slots=False):
+    """Collect the slots and their candidate values and write ``ontology_<set_type>.json``.
+
+    For ``'train'`` only the training dialogues are read; for ``'dev'`` and ``'test'``
+    the train, dev and test dialogues are combined. Slots whose name contains one of
+    a few keywords (price, parking, day, ...) get a fixed candidate list instead of
+    the values observed in the data.
+
+    Parameters:
+        set_type (str): Name of the split the ontology is written for.
+        get_requestable_slots (bool): Also mark every ``<domain>-<slot>`` that occurs
+            in a user ``request`` act with a ``'request'`` entry. (default=False)
+    """
+    datasets = ['train']
+    if set_type in ['test', 'dev']:
+        datasets += ['dev', 'test']
+
+    # Load examples
+    data = []
+    for dataset in datasets:
+        with open(os.path.join(DATA_DIR, 'data_%s.json' % dataset), 'r') as reader:
+            data += json.load(reader)
+
+    # Gather every (mapped) value observed for each slot
+    ontology = dict()
+    for dial in data:
+        for turn in dial['dialogue']:
+            for slot, value in turn['dialogue_state']:
+                ontology.setdefault(slot, []).append(map_values(value))
+
+    requestable_slots = []
+    if get_requestable_slots:
+        # Sorted so that the key order of the written file is reproducible
+        requestable_slots = sorted({f'{dom}-{slot}'
+                                    for dial in data
+                                    for turn in dial['dialogue']
+                                    for act, dom, slot, val in turn['user_acts']
+                                    if act == 'request'})
+
+    # Independent checks (not elif): a later match overrides an earlier one
+    for slot in ontology:
+        if 'price' in slot:
+            ontology[slot] = PRICERANGE
+        if 'parking' in slot or 'internet' in slot:
+            ontology[slot] = BOOLEAN
+        if 'day' in slot:
+            ontology[slot] = DAYS
+        if 'people' in slot or 'duration' in slot or 'stay' in slot:
+            ontology[slot] = QUANTITIES
+        if 'time' in slot or 'leave' in slot or 'arrive' in slot:
+            ontology[slot] = TIME
+        if 'stars' in slot:
+            # Adds '0' to '4' to the observed values
+            ontology[slot] += [str(i) for i in range(5)]
+
+    # Sort slot values and add none and dontcare values
+    for slot in ontology:
+        values = set(ontology[slot]) - {'none', 'do not care'}
+        ontology[slot] = ['none', 'do not care'] + sorted(values)
+    for slot in requestable_slots:
+        ontology.setdefault(slot, []).append('request')
+
+    with open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'w') as writer:
+        json.dump(ontology, writer, indent=2)
+
+
+# Convert dialogue examples to model input features and labels
+def _collect_turn_labels(dialogue, label_fn, max_turns):
+    """Label at most `max_turns` turns of one dialogue and pad with -1 up to `max_turns`."""
+    labs = []
+    for turn in dialogue['dialogue']:
+        labs.append(label_fn(turn))
+        if len(labs) >= max_turns:
+            break
+    return labs + [-1] * (max_turns - len(labs))
+
+
+def _state_label(turn, slot, candidates):
+    """Return the index of the turn's value for `slot` in `candidates`, or -1 if unknown."""
+    value = 'none'
+    for active_slot, active_value in turn['dialogue_state']:
+        if active_slot == slot:
+            value = active_value
+            break
+    if value not in candidates:
+        value = map_values(value)
+    # If value is not in ontology then we do not penalise the model
+    return candidates.index(value) if value in candidates else -1
+
+
+def _request_label(turn, slot):
+    """Return 1 if the first user act on `slot` is a request, otherwise 0."""
+    target = slot.split('-', 1)
+    for act, dom, slt, _ in turn['user_acts']:
+        if [dom, slt] == target:
+            return 1 if act == 'request' else 0
+    return 0
+
+
+def _greeting_label(turn):
+    """Return 1 for a 'bye' user act, 2 for 'thank' without 'bye' and 0 otherwise."""
+    acts = [act for act, _, _, _ in turn['user_acts']]
+    if 'bye' in acts:
+        return 1
+    return 2 if 'thank' in acts else 0
+
+
+# Convert dialogue examples to model input features and labels
+def convert_examples_to_features(set_type, tokenizer, max_turns=12, max_seq_len=64):
+    """Convert the dialogues of one split into model input tensors and label tensors.
+
+    Reads ``data_<set_type>.json`` and ``ontology_<set_type>.json`` from ``DATA_DIR``.
+    Dialogues are cut after `max_turns` turns; shorter dialogues are padded with
+    all-zero input rows and with -1 labels.
+
+    Parameters:
+        set_type (str): Name of the split to load.
+        tokenizer: Object providing ``encode_plus`` (presumably a Hugging Face
+            tokenizer - confirm against the caller).
+        max_turns (int): Number of turns kept per dialogue. (default=12)
+        max_seq_len (int): ``max_length`` handed to the tokenizer. (default=64)
+    Returns:
+        features: Dictionary with ``input_ids``, ``token_type_ids`` and
+            ``attention_mask`` (the latter two are None when the tokenizer does not
+            return them), ``labels-<slot>`` for informable slots, ``request-<slot>``
+            for requestable slots, ``goodbye`` and ``active-<domain>``.
+    Raises:
+        ValueError: If the split contains no dialogues.
+    """
+    features = dict()
+
+    # Load examples
+    data_path = os.path.join(DATA_DIR, 'data_%s.json' % set_type)
+    with open(data_path, 'r') as reader:
+        data = json.load(reader)
+    if not data:
+        raise ValueError('No dialogues found in %s' % data_path)
+
+    # Get encoder input for system, user utterance pairs
+    input_feats = []
+    for dial in data:
+        dial_feats = []
+        for turn in dial['dialogue']:
+            # Turns without a system utterance are encoded on their own
+            texts = [turn['transcript']]
+            if len(turn['system_transcript']) != 0:
+                texts.append(turn['system_transcript'])
+            dial_feats.append(tokenizer.encode_plus(*texts, add_special_tokens=True,
+                                                    max_length=max_seq_len, padding='max_length',
+                                                    truncation='longest_first'))
+            if len(dial_feats) >= max_turns:
+                break
+        input_feats.append(dial_feats)
+
+    # Perform turn level padding
+    def pad_turns(key):
+        empty_turn = [0] * max_seq_len
+        return [[turn[key] for turn in dial] + [empty_turn] * (max_turns - len(dial))
+                for dial in input_feats]
+
+    # Create torch data tensors; the first encoded turn tells which optional
+    # outputs this tokenizer provides
+    first_turn = input_feats[0][0]
+    features['input_ids'] = torch.tensor(pad_turns('input_ids'))
+    for key in ('token_type_ids', 'attention_mask'):
+        features[key] = torch.tensor(pad_turns(key)) if key in first_turn else None
+    del input_feats
+
+    # Load ontology
+    with open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r') as reader:
+        ontology = json.load(reader)
+
+    informable_slots = [slot for slot, values in ontology.items() if values != ['request']]
+    requestable_slots = [slot for slot, values in ontology.items() if 'request' in values]
+    for slot in requestable_slots:
+        ontology[slot].remove('request')
+
+    # Sorted so that the order of the active-<domain> entries is reproducible
+    domains = sorted({slot.split('-', 1)[0] for slot in informable_slots + requestable_slots})
+
+    # Create slot labels
+    for slot in informable_slots:
+        candidates = ontology[slot]
+        features['labels-' + slot] = torch.tensor(
+            [_collect_turn_labels(dial, lambda turn: _state_label(turn, slot, candidates), max_turns)
+             for dial in data])
+
+    for slot in requestable_slots:
+        features['request-' + slot] = torch.tensor(
+            [_collect_turn_labels(dial, lambda turn: _request_label(turn, slot), max_turns)
+             for dial in data])
+
+    # Greeting act labels (0-no greeting, 1-goodbye, 2-thank you)
+    features['goodbye'] = torch.tensor(
+        [_collect_turn_labels(dial, _greeting_label, max_turns) for dial in data])
+
+    for domain in domains:
+        features['active-' + domain] = torch.tensor(
+            [_collect_turn_labels(dial, lambda turn: 1 if domain == turn['domain'] else 0, max_turns)
+             for dial in data])
+
+    return features
+
+
+# MultiWOZ2.1 Dataset object
+class MultiWoz21(Dataset):
+    """Dataset over the feature dictionary built by ``convert_examples_to_features``.
+
+    One item corresponds to one index along the first dimension of every feature
+    tensor; features that are None are left out of the returned items.
+    """
+
+    def __init__(self, set_type, tokenizer, max_turns=12, max_seq_len=64):
+        """Build the features of the split `set_type` with the given tokenizer."""
+        self.features = convert_examples_to_features(set_type, tokenizer, max_turns, max_seq_len)
+
+    def __getitem__(self, index):
+        """Return a dictionary with entry `index` of every available feature."""
+        item = {}
+        for label, values in self.features.items():
+            if values is not None:
+                item[label] = values[index]
+        return item
+
+    def __len__(self):
+        """Return the size of the first dimension of ``input_ids``."""
+        return self.features['input_ids'].size(0)
+
+    def resample(self, size=None):
+        """Replace the features by `size` entries drawn at random with replacement.
+
+        Parameters:
+            size (int): Number of entries to draw; the current length when not given.
+        Returns:
+            self: This dataset, holding the resampled features.
+        """
+        n_dialogues = len(self)
+        size = size or n_dialogues
+
+        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
+        resampled = {}
+        for label, values in self.features.items():
+            if values is not None:
+                resampled[label] = values[dialogues]
+        self.features = resampled
+
+        return self
+
+    def to(self, device):
+        """Remember `device` and move every available feature tensor to it."""
+        self.device = device
+        moved = {}
+        for label, values in self.features.items():
+            if values is not None:
+                moved[label] = values.to(device)
+        self.features = moved
+
+
+# MultiWOZ2.1 Dataset object
+class EnsembleMultiWoz21(Dataset):
+    """Dataset over an already computed feature dictionary.
+
+    Behaves like ``MultiWoz21`` but takes the feature dictionary directly instead of
+    building it from the dataset files. Features that are None are left out of the
+    returned items.
+    """
+
+    def __init__(self, data):
+        """Store `data`, a dictionary of feature tensors containing ``input_ids``."""
+        self.features = data
+
+    def __getitem__(self, index):
+        """Return a dictionary with entry `index` of every available feature."""
+        return {label: self.features[label][index] for label in self.features
+                if self.features[label] is not None}
+
+    def __len__(self):
+        """Return the size of the first dimension of ``input_ids``."""
+        return self.features['input_ids'].size(0)
+
+    def resample(self, size=None):
+        """Replace the features by `size` entries drawn at random with replacement.
+
+        Parameters:
+            size (int): Number of entries to draw; the current length when not given.
+        Returns:
+            self: This dataset, holding the resampled features (same contract as
+                ``MultiWoz21.resample``).
+        """
+        n_dialogues = self.__len__()
+        if not size:
+            size = n_dialogues
+
+        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
+        self.features = {label: self.features[label][dialogues] for label in self.features
+                         if self.features[label] is not None}
+
+        return self
+
+    def to(self, device):
+        """Remember `device` and move every available feature tensor to it."""
+        self.device = device
+        self.features = {label: self.features[label].to(device) for label in self.features
+                         if self.features[label] is not None}
+
+
+# Module to create torch dataloaders
+def get_dataloader(set_type, batch_size, tokenizer, max_turns=12, max_seq_len=64, device=None, resampled_size=None):
+    """Create a torch DataLoader for one split of the dataset.
+
+    Parameters:
+        set_type (str): Split to load; ``'test'`` and ``'dev'`` are read in order,
+            any other split is sampled randomly.
+        batch_size (int): Batch size of the loader.
+        tokenizer: Tokenizer handed to ``MultiWoz21``.
+        max_turns (int): Handed to ``MultiWoz21``. (default=12)
+        max_seq_len (int): Handed to ``MultiWoz21``. (default=64)
+        device: Accepted but not used by this function; the features are moved
+            to ``'cpu'``. (default=None)
+        resampled_size (int): If given, the dataset is resampled to this size
+            before the loader is built. (default=None)
+    Returns:
+        loader: DataLoader over the dataset.
+    """
+    dataset = MultiWoz21(set_type, tokenizer, max_turns, max_seq_len)
+    dataset.to('cpu')
+
+    if resampled_size:
+        dataset.resample(resampled_size)
+
+    sampler_cls = SequentialSampler if set_type in ['test', 'dev'] else RandomSampler
+    return DataLoader(dataset, sampler=sampler_cls(dataset), batch_size=batch_size)
+
+
+def _download(chunk_size=1048576):
+    """Download data archive.
+
+    Parameters:
+        chunk_size (int): Download chunk size. (default=1048576)
+    Returns:
+        archive: ZipFile archive object.
+    Raises:
+        requests.HTTPError: If the server answers with an error status, instead of
+            failing later while parsing the error page as a zip file.
+    """
+    # Download the archive into an in-memory buffer (avoids re-copying the
+    # bytes received so far for every chunk)
+    buffer = io.BytesIO()
+    with requests.get(URL, stream=True) as req:
+        req.raise_for_status()
+        for chunk in tqdm(req.iter_content(chunk_size=chunk_size), desc='Download Chunk'):
+            if chunk:
+                buffer.write(chunk)
+
+    # Convert the downloaded bytes into a zipfile object
+    buffer.seek(0)
+    return zipfile.ZipFile(buffer)
+
+
+def _extract(archive):
+    """Extract the json dictionaries from the archive.
+
+    Every archive member whose name contains ``.json`` or ``.txt`` (and not
+    ``MACOSX``) is read. Members that are not valid JSON are returned as a list of
+    their lines.
+
+    Parameters:
+        archive: ZipFile archive object.
+    Returns:
+        data: Data dictionary, keyed by the member's file name without directory
+            and extension.
+    """
+    members = [info for info in archive.filelist
+               if ('.json' in info.filename or '.txt' in info.filename)
+               and 'MACOSX' not in info.filename]
+
+    data = {}
+    for info in tqdm(members, desc='File'):
+        # ZipFile.read opens and closes the member itself
+        raw = archive.read(info)
+        # Get data objects from the files
+        try:
+            content = json.loads(raw)
+        except json.decoder.JSONDecodeError:
+            content = raw.decode().split('\n')
+        data[info.filename.split('/')[-1].split('.')[0]] = content
+
+    return data
+
+
+# Process files
+def _process(dialogue_data, acts_data):
+    """Clean the turns of every dialogue and annotate them with their domain.
+
+    The turns are modified in place: the text is cleaned, each turn goes through
+    ``fix_delexicalisation`` and, for every turn with non-empty metadata, the
+    domain found by ``get_domains`` is stored on the turn before it.
+
+    Parameters:
+        dialogue_data: Dictionary of dialogues, each with a ``log`` list of turns.
+        acts_data: Not used by this function.
+    Returns:
+        out: Dictionary of the processed dialogues, keyed like `dialogue_data`.
+    """
+    print('Processing Dialogues')
+    processed = {}
+    for dial_name, dialogue in tqdm(dialogue_data.items()):
+        log = dialogue['log']
+
+        last_domain = ''
+        for turn_id, turn in enumerate(log):
+            turn['text'] = clean_text(turn['text'])
+            if len(turn['metadata']) != 0:
+                # The domain is attached to the preceding turn, not to this one
+                last_domain = get_domains(log, turn_id, last_domain)
+                log[turn_id - 1]['domain'] = last_domain
+
+            log[turn_id] = fix_delexicalisation(turn)
+
+        processed[dial_name] = dialogue
+
+    return processed
+
+
+# Split data (train, dev, test)
+def _split_data(dial_data, test, dev, max_utt_len):
+    """Format the dialogues and split them into train, dev and test sets.
+
+    Dialogues for which ``extract_dialogue`` returns nothing, dialogues without any
+    domain and dialogues touching a domain outside ``ACTIVE_DOMAINS`` are dropped.
+
+    Parameters:
+        dial_data: Dictionary of processed dialogues keyed by dialogue name.
+        test: Names of the dialogues belonging to the test set.
+        dev: Names of the dialogues belonging to the development set.
+        max_utt_len: Passed on to ``extract_dialogue``.
+    Returns:
+        train_dials, dev_dials, test_dials: Lists of formatted dialogues.
+    """
+    # Sets give constant-time membership tests for the split lookup below
+    test_names, dev_names = set(test), set(dev)
+    train_dials, test_dials, dev_dials = [], [], []
+    print('Formatting and Splitting Data')
+    for name in tqdm(dial_data):
+        dial = extract_dialogue(dial_data[name], max_utt_len)
+        if not dial:
+            continue
+
+        turns = []
+        for turn_id, turn in enumerate(dial):
+            turns.append({
+                # System utterance preceding this user turn (empty for the first turn)
+                'system_transcript': dial[turn_id - 1]['sys'] if turn_id > 0 else '',
+                'turn_idx': turn_id,
+                'dialogue_state': turn['ds'],
+                'transcript': turn['usr'],
+                'user_acts': turn['usr_a'],
+                'domain': turn['domain'],
+            })
+
+        # Sorted so that the written files are reproducible
+        domains = sorted({turn['domain'] for turn in dial} - {''})
+        # A single inactive domain excludes the whole dialogue
+        if not domains or any(dom not in ACTIVE_DOMAINS for dom in domains):
+            continue
+
+        dialogue = {'dialogue_idx': name, 'domains': domains, 'dialogue': turns}
+        if name in test_names:
+            test_dials.append(dialogue)
+        elif name in dev_names:
+            dev_dials.append(dialogue)
+        else:
+            train_dials.append(dialogue)
+
+    print('Number of Dialogues:\nTrain: %i\nDev: %i\nTest: %i' % (len(train_dials), len(dev_dials), len(test_dials)))
+
+    return train_dials, dev_dials, test_dials
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json
new file mode 100644
index 0000000000000000000000000000000000000000..b703793dc747535132b748d8b69838b4c151d8d5
--- /dev/null
+++ b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json
@@ -0,0 +1,2990 @@
+{
+  "hotel-price range": [
+    "none",
+    "do not care",
+    "cheap",
+    "expensive",
+    "moderate"
+  ],
+  "hotel-type": [
+    "none",
+    "do not care",
+    "bed and breakfast",
+    "guest house",
+    "hotel"
+  ],
+  "hotel-parking": [
+    "none",
+    "do not care",
+    "no",
+    "yes"
+  ],
+  "hotel-book day": [
+    "none",
+    "do not care",
+    "friday",
+    "monday",
+    "saterday",
+    "sunday",
+    "thursday",
+    "tuesday",
+    "wednesday"
+  ],
+  "hotel-book people": [
+    "none",
+    "do not care",
+    "1",
+    "10 or more",
+    "2",
+    "3",
+    "4",
+    "5",
+    "6",
+    "7",
+    "8",
+    "9"
+  ],
+  "hotel-book stay": [
+    "none",
+    "do not care",
+    "1",
+    "10 or more",
+    "2",
+    "3",
+    "4",
+    "5",
+    "6",
+    "7",
+    "8",
+    "9"
+  ],
+  "train-destination": [
+    "none",
+    "do not care",
+    "bishops stortford",
+    "kings lynn",
+    "london liverpool street",
+    "centre",
+    "bishop stortford",
+    "liverpool",
+    "leicester",
+    "broxbourne",
+    "gourmet burger kitchen",
+    "copper kettle",
+    "bournemouth",
+    "stevenage",
+    "liverpool street",
+    "norwich",
+    "huntingdon marriott hotel",
+    "city centre north",
+    "taj tandoori",
+    "the copper kettle",
+    "peterborough",
+    "ely",
+    "lecester",
+    "london",
+    "willi",
+    "stansted airport",
+    "huntington marriott",
+    "cambridge",
+    "gonv",
+    "glastonbury",
+    "hol",
+    "north",
+    "birmingham new street",
+    "norway",
+    "petersborough",
+    "london kings cross",
+    "curry prince",
+    "bishops storford"
+  ],
+  "train-arrive by": [
+    "none",
+    "do not care",
+    "00:00",
+    "00:05",
+    "00:10",
+    "00:15",
+    "00:20",
+    "00:25",
+    "00:30",
+    "00:35",
+    "00:40",
+    "00:45",
+    "00:50",
+    "00:55",
+    "01:00",
+    "01:05",
+    "01:10",
+    "01:15",
+    "01:20",
+    "01:25",
+    "01:30",
+    "01:35",
+    "01:40",
+    "01:45",
+    "01:50",
+    "01:55",
+    "02:00",
+    "02:05",
+    "02:10",
+    "02:15",
+    "02:20",
+    "02:25",
+    "02:30",
+    "02:35",
+    "02:40",
+    "02:45",
+    "02:50",
+    "02:55",
+    "03:00",
+    "03:05",
+    "03:10",
+    "03:15",
+    "03:20",
+    "03:25",
+    "03:30",
+    "03:35",
+    "03:40",
+    "03:45",
+    "03:50",
+    "03:55",
+    "04:00",
+    "04:05",
+    "04:10",
+    "04:15",
+    "04:20",
+    "04:25",
+    "04:30",
+    "04:35",
+    "04:40",
+    "04:45",
+    "04:50",
+    "04:55",
+    "05:00",
+    "05:05",
+    "05:10",
+    "05:15",
+    "05:20",
+    "05:25",
+    "05:30",
+    "05:35",
+    "05:40",
+    "05:45",
+    "05:50",
+    "05:55",
+    "06:00",
+    "06:05",
+    "06:10",
+    "06:15",
+    "06:20",
+    "06:25",
+    "06:30",
+    "06:35",
+    "06:40",
+    "06:45",
+    "06:50",
+    "06:55",
+    "07:00",
+    "07:05",
+    "07:10",
+    "07:15",
+    "07:20",
+    "07:25",
+    "07:30",
+    "07:35",
+    "07:40",
+    "07:45",
+    "07:50",
+    "07:55",
+    "08:00",
+    "08:05",
+    "08:10",
+    "08:15",
+    "08:20",
+    "08:25",
+    "08:30",
+    "08:35",
+    "08:40",
+    "08:45",
+    "08:50",
+    "08:55",
+    "09:00",
+    "09:05",
+    "09:10",
+    "09:15",
+    "09:20",
+    "09:25",
+    "09:30",
+    "09:35",
+    "09:40",
+    "09:45",
+    "09:50",
+    "09:55",
+    "10:00",
+    "10:05",
+    "10:10",
+    "10:15",
+    "10:20",
+    "10:25",
+    "10:30",
+    "10:35",
+    "10:40",
+    "10:45",
+    "10:50",
+    "10:55",
+    "11:00",
+    "11:05",
+    "11:10",
+    "11:15",
+    "11:20",
+    "11:25",
+    "11:30",
+    "11:35",
+    "11:40",
+    "11:45",
+    "11:50",
+    "11:55",
+    "12:00",
+    "12:05",
+    "12:10",
+    "12:15",
+    "12:20",
+    "12:25",
+    "12:30",
+    "12:35",
+    "12:40",
+    "12:45",
+    "12:50",
+    "12:55",
+    "13:00",
+    "13:05",
+    "13:10",
+    "13:15",
+    "13:20",
+    "13:25",
+    "13:30",
+    "13:35",
+    "13:40",
+    "13:45",
+    "13:50",
+    "13:55",
+    "14:00",
+    "14:05",
+    "14:10",
+    "14:15",
+    "14:20",
+    "14:25",
+    "14:30",
+    "14:35",
+    "14:40",
+    "14:45",
+    "14:50",
+    "14:55",
+    "15:00",
+    "15:05",
+    "15:10",
+    "15:15",
+    "15:20",
+    "15:25",
+    "15:30",
+    "15:35",
+    "15:40",
+    "15:45",
+    "15:50",
+    "15:55",
+    "16:00",
+    "16:05",
+    "16:10",
+    "16:15",
+    "16:20",
+    "16:25",
+    "16:30",
+    "16:35",
+    "16:40",
+    "16:45",
+    "16:50",
+    "16:55",
+    "17:00",
+    "17:05",
+    "17:10",
+    "17:15",
+    "17:20",
+    "17:25",
+    "17:30",
+    "17:35",
+    "17:40",
+    "17:45",
+    "17:50",
+    "17:55",
+    "18:00",
+    "18:05",
+    "18:10",
+    "18:15",
+    "18:20",
+    "18:25",
+    "18:30",
+    "18:35",
+    "18:40",
+    "18:45",
+    "18:50",
+    "18:55",
+    "19:00",
+    "19:05",
+    "19:10",
+    "19:15",
+    "19:20",
+    "19:25",
+    "19:30",
+    "19:35",
+    "19:40",
+    "19:45",
+    "19:50",
+    "19:55",
+    "20:00",
+    "20:05",
+    "20:10",
+    "20:15",
+    "20:20",
+    "20:25",
+    "20:30",
+    "20:35",
+    "20:40",
+    "20:45",
+    "20:50",
+    "20:55",
+    "21:00",
+    "21:05",
+    "21:10",
+    "21:15",
+    "21:20",
+    "21:25",
+    "21:30",
+    "21:35",
+    "21:40",
+    "21:45",
+    "21:50",
+    "21:55",
+    "22:00",
+    "22:05",
+    "22:10",
+    "22:15",
+    "22:20",
+    "22:25",
+    "22:30",
+    "22:35",
+    "22:40",
+    "22:45",
+    "22:50",
+    "22:55",
+    "23:00",
+    "23:05",
+    "23:10",
+    "23:15",
+    "23:20",
+    "23:25",
+    "23:30",
+    "23:35",
+    "23:40",
+    "23:45",
+    "23:50",
+    "23:55"
+  ],
+  "train-departure": [
+    "none",
+    "do not care",
+    "bishops stortford",
+    "kings lynn",
+    "brookshite",
+    "london liverpool street",
+    "cam",
+    "liverpool",
+    "bro",
+    "leicester",
+    "broxbourne",
+    "norwhich",
+    "saint johns",
+    "stevenage",
+    "stansted",
+    "london liverpool",
+    "cambrid",
+    "city hall",
+    "rosas bed and breakfast",
+    "alpha-milton",
+    "wandlebury country park",
+    "norwich",
+    "liecester",
+    "stratford",
+    "peterborough",
+    "duxford",
+    "ely",
+    "london",
+    "stansted airport",
+    "lon",
+    "cambridge",
+    "panahar",
+    "cineworld",
+    "leicaster",
+    "birmingham",
+    "cafe uno",
+    "camboats",
+    "huntingdon",
+    "birmingham new street",
+    "arbu",
+    "alpha milton",
+    "east london",
+    "london kings cross",
+    "hamilton lodge",
+    "aylesbray lodge guest",
+    "el shaddai"
+  ],
+  "train-day": [
+    "none",
+    "do not care",
+    "friday",
+    "monday",
+    "saterday",
+    "sunday",
+    "thursday",
+    "tuesday",
+    "wednesday"
+  ],
+  "train-book people": [
+    "none",
+    "do not care",
+    "1",
+    "10 or more",
+    "2",
+    "3",
+    "4",
+    "5",
+    "6",
+    "7",
+    "8",
+    "9"
+  ],
+  "hotel-stars": [
+    "none",
+    "do not care",
+    "0",
+    "1",
+    "2",
+    "3",
+    "4",
+    "5"
+  ],
+  "hotel-internet": [
+    "none",
+    "do not care",
+    "no",
+    "yes"
+  ],
+  "hotel-name": [
+    "a and b guest house",
+    "city roomz",
+    "carolina bed and breakfast",
+    "limehouse",
+    "anatolia",
+    "hamilton lodge",
+    "the lensfield hotel",
+    "rosa's bed and breakfast",
+    "gall",
+    "aylesbray lodge",
+    "kirkwood",
+    "cambridge belfry",
+    "warkworth house",
+    "gonville",
+    "belfy hotel",
+    "nus",
+    "alexander",
+    "super 5",
+    "aylesbray lodge guest house",
+    "the gonvile hotel",
+    "allenbell",
+    "nothamilton lodge",
+    "ashley hotel",
+    "autumn house",
+    "hobsons house",
+    "hotel",
+    "ashely hotel",
+    "caridge belfrey",
+    "el shaddia guest house",
+    "avalon",
+    "cote",
+    "city centre north bed and breakfast",
+    "the cambridge belfry",
+    "home from home",
+    "wandlebury coutn",
+    "wankworth house",
+    "city stop rest",
+    "the worth house",
+    "cityroomz",
+    "huntingdon marriottt hotel",
+    "none",
+    "lensfield",
+    "rosas bed and breakfast",
+    "leverton house",
+    "gonville hotel",
+    "holiday inn cambridge",
+    "do not care",
+    "archway house",
+    "lan hon",
+    "levert",
+    "acorn guest house",
+    "cambridge",
+    "the ashley hotel",
+    "el shaddai",
+    "sleeperz",
+    "alpha milton guest house",
+    "doubletree by hilton cambridge",
+    "tandoori palace",
+    "express by",
+    "express by holiday inn cambridge",
+    "north bed and breakfast",
+    "holiday inn",
+    "arbury lodge guest house",
+    "alexander bed and breakfast",
+    "huntingdon marriott hotel",
+    "royal spice",
+    "sou",
+    "finches bed and breakfast",
+    "the alpha milton",
+    "bridge guest house",
+    "the acorn guest house",
+    "kirkwood house",
+    "eraina",
+    "la margherit",
+    "lensfield hotel",
+    "marriott hotel",
+    "nusha",
+    "city centre bed and breakfast",
+    "the allenbell",
+    "university arms hotel",
+    "clare",
+    "cherr",
+    "wartworth",
+    "acorn place",
+    "lovell lodge",
+    "whale"
+  ],
+  "train-leave at": [
+    "none",
+    "do not care",
+    "00:00",
+    "00:05",
+    "00:10",
+    "00:15",
+    "00:20",
+    "00:25",
+    "00:30",
+    "00:35",
+    "00:40",
+    "00:45",
+    "00:50",
+    "00:55",
+    "01:00",
+    "01:05",
+    "01:10",
+    "01:15",
+    "01:20",
+    "01:25",
+    "01:30",
+    "01:35",
+    "01:40",
+    "01:45",
+    "01:50",
+    "01:55",
+    "02:00",
+    "02:05",
+    "02:10",
+    "02:15",
+    "02:20",
+    "02:25",
+    "02:30",
+    "02:35",
+    "02:40",
+    "02:45",
+    "02:50",
+    "02:55",
+    "03:00",
+    "03:05",
+    "03:10",
+    "03:15",
+    "03:20",
+    "03:25",
+    "03:30",
+    "03:35",
+    "03:40",
+    "03:45",
+    "03:50",
+    "03:55",
+    "04:00",
+    "04:05",
+    "04:10",
+    "04:15",
+    "04:20",
+    "04:25",
+    "04:30",
+    "04:35",
+    "04:40",
+    "04:45",
+    "04:50",
+    "04:55",
+    "05:00",
+    "05:05",
+    "05:10",
+    "05:15",
+    "05:20",
+    "05:25",
+    "05:30",
+    "05:35",
+    "05:40",
+    "05:45",
+    "05:50",
+    "05:55",
+    "06:00",
+    "06:05",
+    "06:10",
+    "06:15",
+    "06:20",
+    "06:25",
+    "06:30",
+    "06:35",
+    "06:40",
+    "06:45",
+    "06:50",
+    "06:55",
+    "07:00",
+    "07:05",
+    "07:10",
+    "07:15",
+    "07:20",
+    "07:25",
+    "07:30",
+    "07:35",
+    "07:40",
+    "07:45",
+    "07:50",
+    "07:55",
+    "08:00",
+    "08:05",
+    "08:10",
+    "08:15",
+    "08:20",
+    "08:25",
+    "08:30",
+    "08:35",
+    "08:40",
+    "08:45",
+    "08:50",
+    "08:55",
+    "09:00",
+    "09:05",
+    "09:10",
+    "09:15",
+    "09:20",
+    "09:25",
+    "09:30",
+    "09:35",
+    "09:40",
+    "09:45",
+    "09:50",
+    "09:55",
+    "10:00",
+    "10:05",
+    "10:10",
+    "10:15",
+    "10:20",
+    "10:25",
+    "10:30",
+    "10:35",
+    "10:40",
+    "10:45",
+    "10:50",
+    "10:55",
+    "11:00",
+    "11:05",
+    "11:10",
+    "11:15",
+    "11:20",
+    "11:25",
+    "11:30",
+    "11:35",
+    "11:40",
+    "11:45",
+    "11:50",
+    "11:55",
+    "12:00",
+    "12:05",
+    "12:10",
+    "12:15",
+    "12:20",
+    "12:25",
+    "12:30",
+    "12:35",
+    "12:40",
+    "12:45",
+    "12:50",
+    "12:55",
+    "13:00",
+    "13:05",
+    "13:10",
+    "13:15",
+    "13:20",
+    "13:25",
+    "13:30",
+    "13:35",
+    "13:40",
+    "13:45",
+    "13:50",
+    "13:55",
+    "14:00",
+    "14:05",
+    "14:10",
+    "14:15",
+    "14:20",
+    "14:25",
+    "14:30",
+    "14:35",
+    "14:40",
+    "14:45",
+    "14:50",
+    "14:55",
+    "15:00",
+    "15:05",
+    "15:10",
+    "15:15",
+    "15:20",
+    "15:25",
+    "15:30",
+    "15:35",
+    "15:40",
+    "15:45",
+    "15:50",
+    "15:55",
+    "16:00",
+    "16:05",
+    "16:10",
+    "16:15",
+    "16:20",
+    "16:25",
+    "16:30",
+    "16:35",
+    "16:40",
+    "16:45",
+    "16:50",
+    "16:55",
+    "17:00",
+    "17:05",
+    "17:10",
+    "17:15",
+    "17:20",
+    "17:25",
+    "17:30",
+    "17:35",
+    "17:40",
+    "17:45",
+    "17:50",
+    "17:55",
+    "18:00",
+    "18:05",
+    "18:10",
+    "18:15",
+    "18:20",
+    "18:25",
+    "18:30",
+    "18:35",
+    "18:40",
+    "18:45",
+    "18:50",
+    "18:55",
+    "19:00",
+    "19:05",
+    "19:10",
+    "19:15",
+    "19:20",
+    "19:25",
+    "19:30",
+    "19:35",
+    "19:40",
+    "19:45",
+    "19:50",
+    "19:55",
+    "20:00",
+    "20:05",
+    "20:10",
+    "20:15",
+    "20:20",
+    "20:25",
+    "20:30",
+    "20:35",
+    "20:40",
+    "20:45",
+    "20:50",
+    "20:55",
+    "21:00",
+    "21:05",
+    "21:10",
+    "21:15",
+    "21:20",
+    "21:25",
+    "21:30",
+    "21:35",
+    "21:40",
+    "21:45",
+    "21:50",
+    "21:55",
+    "22:00",
+    "22:05",
+    "22:10",
+    "22:15",
+    "22:20",
+    "22:25",
+    "22:30",
+    "22:35",
+    "22:40",
+    "22:45",
+    "22:50",
+    "22:55",
+    "23:00",
+    "23:05",
+    "23:10",
+    "23:15",
+    "23:20",
+    "23:25",
+    "23:30",
+    "23:35",
+    "23:40",
+    "23:45",
+    "23:50",
+    "23:55"
+  ],
+  "restaurant-price range": [
+    "none",
+    "do not care",
+    "cheap",
+    "expensive",
+    "moderate"
+  ],
+  "restaurant-food": [
+    "british food",
+    "steakhouse",
+    "turkish",
+    "sushi",
+    "north american",
+    "scottish",
+    "french",
+    "austrian",
+    "korean",
+    "eastern european",
+    "swedish",
+    "gastro pub",
+    "modern eclectic",
+    "afternoon tea",
+    "welsh",
+    "christmas",
+    "tuscan",
+    "gastropub",
+    "sri lankan",
+    "molecular gastronomy",
+    "traditional american",
+    "italian",
+    "pizza",
+    "thai",
+    "south african",
+    "creative",
+    "english",
+    "asian",
+    "lebanese",
+    "hungarian",
+    "halal",
+    "portugese",
+    "modern english",
+    "african",
+    "light bites",
+    "malaysian",
+    "venetian",
+    "traditional",
+    "chinese",
+    "vegetarian",
+    "persian",
+    "thai and chinese",
+    "scandinavian",
+    "catalan",
+    "polynesian",
+    "crossover",
+    "canapes",
+    "cantonese",
+    "north african",
+    "seafood",
+    "brazilian",
+    "south indian",
+    "australasian",
+    "belgian",
+    "barbeque",
+    "the americas",
+    "indonesian",
+    "singaporean",
+    "irish",
+    "middle eastern",
+    "dojo noodle bar",
+    "caribbean",
+    "vietnamese",
+    "modern european",
+    "russian",
+    "none",
+    "german",
+    "world",
+    "japanese",
+    "moroccan",
+    "modern global",
+    "do not care",
+    "indian",
+    "british",
+    "american",
+    "danish",
+    "panasian",
+    "swiss",
+    "basque",
+    "north indian",
+    "modern american",
+    "australian",
+    "european",
+    "corsica",
+    "greek",
+    "northern european",
+    "mediterranean",
+    "portuguese",
+    "romanian",
+    "jamaican",
+    "polish",
+    "international",
+    "unusual",
+    "latin american",
+    "asian oriental",
+    "mexican",
+    "bistro",
+    "cuban",
+    "fusion",
+    "new zealand",
+    "spanish",
+    "eritrean",
+    "afghan",
+    "kosher"
+  ],
+  "attraction-name": [
+    "downing college",
+    "fitzwilliam",
+    "clare college",
+    "ruskin gallery",
+    "sidney sussex college",
+    "great saint mary's church",
+    "cherry hinton water play park",
+    "wandlebury country park",
+    "cafe uno",
+    "place",
+    "broughton",
+    "cineworld cinema",
+    "jesus college",
+    "vue cinema",
+    "history of science museum",
+    "mumford theatre",
+    "whale of time",
+    "fitzbillies",
+    "christs church",
+    "churchill college",
+    "museum of classical archaeology",
+    "gonville and caius college",
+    "pizza",
+    "kirkwood",
+    "saint catharines college",
+    "kings college",
+    "parkside",
+    "by",
+    "st catharines college",
+    "saint john's college",
+    "cherry hinton water park",
+    "st christs college",
+    "christ's college",
+    "bangkok city",
+    "scudamores punti co",
+    "free",
+    "great saint marys church",
+    "milton country park",
+    "the fez club",
+    "soultree",
+    "autu",
+    "whipple museum of the history of science",
+    "aylesbray lodge guest house",
+    "broughton house gallery",
+    "peoples portraits exhibition",
+    "primavera",
+    "kettles yard",
+    "all saint's church",
+    "cinema cinema",
+    "regency gallery",
+    "corpus christi",
+    "corn cambridge exchange",
+    "da vinci pizzeria",
+    "school",
+    "hobsons house",
+    "cambride and country folk museum",
+    "north",
+    "da v",
+    "cambridge corn exchange",
+    "soul tree nightclub",
+    "cambridge arts theater",
+    "saint catharine's college",
+    "byard art",
+    "cambridge punter",
+    "cambridge university botanic gardens",
+    "castle galleries",
+    "museum of archaelogy and anthropogy",
+    "no specific location",
+    "cherry hinton hall",
+    "gallery at 12 a high street",
+    "parkside pools",
+    "queen's college",
+    "little saint mary's church",
+    "gallery",
+    "home from home",
+    "tenpin",
+    "the wandlebury",
+    "county folk museum",
+    "swimming pool",
+    "christs college",
+    "cafe jello museum",
+    "scott polar",
+    "christ college",
+    "cambridge museum of technology",
+    "abbey pool and astroturf pitch",
+    "king hedges learner pool",
+    "the cambridge arts theatre",
+    "the castle galleries",
+    "cambridge and country folk museum",
+    "kohinoor",
+    "scudamores punting co",
+    "sidney sussex",
+    "the man on the moon",
+    "little saint marys church",
+    "queens",
+    "the place",
+    "old school",
+    "churchill",
+    "churchills college",
+    "hughes hall",
+    "churchhill college",
+    "riverboat georgina",
+    "none",
+    "belf",
+    "cambridge temporary art",
+    "abc theatre",
+    "cambridge contemporary art museum",
+    "man on the moon",
+    "the junction",
+    "cherry hinton water play",
+    "adc theatre",
+    "gonville hotel",
+    "magdalene college",
+    "peoples portraits exhibition at girton college",
+    "boat",
+    "centre",
+    "sheep's green and lammas land park fen causeway",
+    "do not care",
+    "the mumford theatre",
+    "archway house",
+    "queens' college",
+    "williams art and antiques",
+    "funky fun house",
+    "cherry hinton village centre",
+    "camboats",
+    "cambridge",
+    "old schools",
+    "kettle's yard",
+    "whale of a time",
+    "the churchill college",
+    "cafe jello gallery",
+    "aut",
+    "salsa",
+    "city",
+    "clare hall",
+    "boating",
+    "pembroke college",
+    "kings hedges learner pool",
+    "caffe uno",
+    "lammas land park",
+    "museum",
+    "the fitzwilliam museum",
+    "the cherry hinton village centre",
+    "the cambridge corn exchange",
+    "fitzwilliam museum",
+    "museum of archaelogy and anthropology",
+    "fez club",
+    "the cambridge punter",
+    "saint johns college",
+    "emmanuel college",
+    "cambridge belf",
+    "scudamore",
+    "lynne strover gallery",
+    "king's college",
+    "whippple museum",
+    "trinity college",
+    "college in the north",
+    "sheep's green",
+    "kambar",
+    "museum of archaelogy",
+    "adc",
+    "garde",
+    "club salsa",
+    "people's portraits exhibition at girton college",
+    "botanic gardens",
+    "carol",
+    "college",
+    "gallery at twelve a high street",
+    "abbey pool and astroturf",
+    "cambridge book and print gallery",
+    "jesus green outdoor pool",
+    "scott polar museum",
+    "saint barnabas press gallery",
+    "cambridge artworks",
+    "older churches",
+    "cambridge contemporary art",
+    "cherry hinton hall and grounds",
+    "univ",
+    "jesus green",
+    "ballare",
+    "abbey pool",
+    "cambridge botanic gardens",
+    "nusha",
+    "worth house",
+    "thanh",
+    "university arms hotel",
+    "cambridge arts theatre",
+    "cafe jello",
+    "cambridge and county folk museum",
+    "the cambridge artworks",
+    "all saints church",
+    "holy trinity church",
+    "contemporary art museum",
+    "architectural churches",
+    "queens college",
+    "trinity street college"
+  ],
+  "restaurant-name": [
+    "none",
+    "do not care",
+    "hotel du vin and bistro",
+    "ask",
+    "gourmet formal kitchen",
+    "the meze bar",
+    "lan hong house",
+    "cow pizza",
+    "one seven",
+    "prezzo",
+    "maharajah tandoori restaurant",
+    "alex",
+    "shanghai",
+    "golden wok",
+    "restaurant",
+    "fitzbillies",
+    "nil",
+    "copper kettle",
+    "meghna",
+    "hk fusion",
+    "bangkok city",
+    "hobsons house",
+    "tang chinese",
+    "anatolia",
+    "ugly duckling",
+    "anatolia and efes restaurant",
+    "sitar tandoori",
+    "city stop",
+    "ashley",
+    "pizza express fen ditton",
+    "molecular gastronomy",
+    "autumn house",
+    "el shaddia guesthouse",
+    "the grafton hotel",
+    "limehouse",
+    "gardenia",
+    "not metioned",
+    "hakka",
+    "michaelhouse cafe",
+    "pipasha",
+    "meze bar",
+    "archway",
+    "molecular gastonomy",
+    "yipee noodle bar",
+    "the peking",
+    "curry prince",
+    "midsummer house restaurant",
+    "pizza hut cherry hinton",
+    "the lucky star",
+    "stazione restaurant and coffee bar",
+    "shanghi family restaurant",
+    "good luck",
+    "j restaurant",
+    "bedouin",
+    "cott",
+    "little seoul",
+    "south",
+    "thanh binh",
+    "el",
+    "efes restaurant",
+    "kohinoor",
+    "clowns",
+    "india",
+    "the slug and lettuce",
+    "shiraz",
+    "barbakan",
+    "zizzi cambridge",
+    "restaurant one seven",
+    "slug and lettuce",
+    "travellers rest",
+    "binh",
+    "worth house",
+    "broughton house gallery",
+    "chiquito",
+    "the river bar steakhouse and grill",
+    "ros",
+    "golden house",
+    "india west",
+    "cam",
+    "panahar",
+    "restaurant 22",
+    "adden",
+    "indian",
+    "hu",
+    "jinling noodle bar",
+    "darrys cookhouse and wine shop",
+    "hobson house",
+    "cambridge be",
+    "el shaddai",
+    "ac",
+    "nandos",
+    "cambridge lodge",
+    "the cow pizza kitchen and bar",
+    "charlie",
+    "rajmahal",
+    "kymmoy",
+    "cambri",
+    "backstreet bistro",
+    "galleria",
+    "restaurant 2 two",
+    "chiquito restaurant bar",
+    "royal standard",
+    "lucky star",
+    "curry king",
+    "grafton hotel restaurant",
+    "mahal of cambridge",
+    "the bedouin",
+    "nus",
+    "the kohinoor",
+    "pizza hut fenditton",
+    "camboats",
+    "the gardenia",
+    "de luca cucina and bar",
+    "nusha",
+    "european",
+    "taj tandoori",
+    "tandoori palace",
+    "golden curry",
+    "efes",
+    "loch fyne",
+    "the maharajah tandoor",
+    "lovel",
+    "restaurant 17",
+    "clowns cafe",
+    "cambridge punter",
+    "bloomsbury restaurant",
+    "la mimosa",
+    "the cambridge chop house",
+    "funky",
+    "cotto",
+    "oak bistro",
+    "restaurant two two",
+    "pipasha restaurant",
+    "river bar steakhouse and grill",
+    "royal spice",
+    "the copper kettle",
+    "graffiti",
+    "nandos city centre",
+    "saffron brasserie",
+    "cambridge chop house",
+    "sitar",
+    "kitchen and bar",
+    "the good luck chinese food takeaway",
+    "clu",
+    "la tasca",
+    "cafe uno",
+    "cote",
+    "the varsity restaurant",
+    "bri",
+    "eraina",
+    "bridge",
+    "fin",
+    "cambridge lodge restaurant",
+    "grafton",
+    "hotpot",
+    "sala thong",
+    "margherita",
+    "wise buddha",
+    "the missing sock",
+    "seasame restaurant and bar",
+    "the dojo noodle bar",
+    "restaurant alimentum",
+    "gastropub",
+    "saigon city",
+    "la margherita",
+    "pizza hut",
+    "curry garden",
+    "ashley hotel",
+    "eraina and michaelhouse cafe",
+    "the golden curry",
+    "curry queen",
+    "cow pizza kitchen and bar",
+    "the peking restaurant:",
+    "hamilton lodge",
+    "alimentum",
+    "yippee noodle bar",
+    "2 two and cote",
+    "shanghai family restaurant",
+    "grafton hotel",
+    "yes",
+    "ali baba",
+    "dif",
+    "fitzbillies restaurant",
+    "peking restaurant",
+    "lev",
+    "nirala",
+    "the alex",
+    "tandoori",
+    "city stop restaurant",
+    "rice house",
+    "cityr",
+    "yu garden",
+    "meze bar restaurant",
+    "the",
+    "don pasquale pizzeria",
+    "rice boat",
+    "the hotpot",
+    "old school",
+    "the oak bistro",
+    "sesame restaurant and bar",
+    "pizza express",
+    "the gandhi",
+    "pizza hut fen ditton",
+    "charlie chan",
+    "da vinci pizzeria",
+    "dojo noodle bar",
+    "gourmet burger kitchen",
+    "the golden house",
+    "india house",
+    "hobso",
+    "missing sock",
+    "pizza hut city centre",
+    "parkside pools",
+    "riverside brasserie",
+    "caffe uno",
+    "primavera",
+    "the nirala",
+    "wagamama",
+    "au",
+    "ian hong house",
+    "frankie and bennys",
+    "4 kings parade city centre",
+    "shiraz restaurant",
+    "scudamores punt",
+    "mahal",
+    "saint johns chop house",
+    "de luca cucina and bar riverside brasserie",
+    "cocum",
+    "la raza"
+  ],
+  "attraction-type": [
+    "none",
+    "do not care",
+    "architecture",
+    "boat",
+    "boating",
+    "camboats",
+    "church",
+    "churchills college",
+    "cinema",
+    "college",
+    "concert",
+    "concerthall",
+    "entertainment",
+    "gallery",
+    "gastropub",
+    "hiking",
+    "hotel",
+    "multiple sports",
+    "museum",
+    "museum kettles yard",
+    "night club",
+    "outdoor",
+    "park",
+    "pool",
+    "special",
+    "sports",
+    "swimming pool",
+    "theater",
+    "theatre",
+    "concert hall",
+    "local site",
+    "nightclub",
+    "hotspot"
+  ],
+  "taxi-leave at": [
+    "none",
+    "do not care",
+    "00:00",
+    "00:05",
+    "00:10",
+    "00:15",
+    "00:20",
+    "00:25",
+    "00:30",
+    "00:35",
+    "00:40",
+    "00:45",
+    "00:50",
+    "00:55",
+    "01:00",
+    "01:05",
+    "01:10",
+    "01:15",
+    "01:20",
+    "01:25",
+    "01:30",
+    "01:35",
+    "01:40",
+    "01:45",
+    "01:50",
+    "01:55",
+    "02:00",
+    "02:05",
+    "02:10",
+    "02:15",
+    "02:20",
+    "02:25",
+    "02:30",
+    "02:35",
+    "02:40",
+    "02:45",
+    "02:50",
+    "02:55",
+    "03:00",
+    "03:05",
+    "03:10",
+    "03:15",
+    "03:20",
+    "03:25",
+    "03:30",
+    "03:35",
+    "03:40",
+    "03:45",
+    "03:50",
+    "03:55",
+    "04:00",
+    "04:05",
+    "04:10",
+    "04:15",
+    "04:20",
+    "04:25",
+    "04:30",
+    "04:35",
+    "04:40",
+    "04:45",
+    "04:50",
+    "04:55",
+    "05:00",
+    "05:05",
+    "05:10",
+    "05:15",
+    "05:20",
+    "05:25",
+    "05:30",
+    "05:35",
+    "05:40",
+    "05:45",
+    "05:50",
+    "05:55",
+    "06:00",
+    "06:05",
+    "06:10",
+    "06:15",
+    "06:20",
+    "06:25",
+    "06:30",
+    "06:35",
+    "06:40",
+    "06:45",
+    "06:50",
+    "06:55",
+    "07:00",
+    "07:05",
+    "07:10",
+    "07:15",
+    "07:20",
+    "07:25",
+    "07:30",
+    "07:35",
+    "07:40",
+    "07:45",
+    "07:50",
+    "07:55",
+    "08:00",
+    "08:05",
+    "08:10",
+    "08:15",
+    "08:20",
+    "08:25",
+    "08:30",
+    "08:35",
+    "08:40",
+    "08:45",
+    "08:50",
+    "08:55",
+    "09:00",
+    "09:05",
+    "09:10",
+    "09:15",
+    "09:20",
+    "09:25",
+    "09:30",
+    "09:35",
+    "09:40",
+    "09:45",
+    "09:50",
+    "09:55",
+    "10:00",
+    "10:05",
+    "10:10",
+    "10:15",
+    "10:20",
+    "10:25",
+    "10:30",
+    "10:35",
+    "10:40",
+    "10:45",
+    "10:50",
+    "10:55",
+    "11:00",
+    "11:05",
+    "11:10",
+    "11:15",
+    "11:20",
+    "11:25",
+    "11:30",
+    "11:35",
+    "11:40",
+    "11:45",
+    "11:50",
+    "11:55",
+    "12:00",
+    "12:05",
+    "12:10",
+    "12:15",
+    "12:20",
+    "12:25",
+    "12:30",
+    "12:35",
+    "12:40",
+    "12:45",
+    "12:50",
+    "12:55",
+    "13:00",
+    "13:05",
+    "13:10",
+    "13:15",
+    "13:20",
+    "13:25",
+    "13:30",
+    "13:35",
+    "13:40",
+    "13:45",
+    "13:50",
+    "13:55",
+    "14:00",
+    "14:05",
+    "14:10",
+    "14:15",
+    "14:20",
+    "14:25",
+    "14:30",
+    "14:35",
+    "14:40",
+    "14:45",
+    "14:50",
+    "14:55",
+    "15:00",
+    "15:05",
+    "15:10",
+    "15:15",
+    "15:20",
+    "15:25",
+    "15:30",
+    "15:35",
+    "15:40",
+    "15:45",
+    "15:50",
+    "15:55",
+    "16:00",
+    "16:05",
+    "16:10",
+    "16:15",
+    "16:20",
+    "16:25",
+    "16:30",
+    "16:35",
+    "16:40",
+    "16:45",
+    "16:50",
+    "16:55",
+    "17:00",
+    "17:05",
+    "17:10",
+    "17:15",
+    "17:20",
+    "17:25",
+    "17:30",
+    "17:35",
+    "17:40",
+    "17:45",
+    "17:50",
+    "17:55",
+    "18:00",
+    "18:05",
+    "18:10",
+    "18:15",
+    "18:20",
+    "18:25",
+    "18:30",
+    "18:35",
+    "18:40",
+    "18:45",
+    "18:50",
+    "18:55",
+    "19:00",
+    "19:05",
+    "19:10",
+    "19:15",
+    "19:20",
+    "19:25",
+    "19:30",
+    "19:35",
+    "19:40",
+    "19:45",
+    "19:50",
+    "19:55",
+    "20:00",
+    "20:05",
+    "20:10",
+    "20:15",
+    "20:20",
+    "20:25",
+    "20:30",
+    "20:35",
+    "20:40",
+    "20:45",
+    "20:50",
+    "20:55",
+    "21:00",
+    "21:05",
+    "21:10",
+    "21:15",
+    "21:20",
+    "21:25",
+    "21:30",
+    "21:35",
+    "21:40",
+    "21:45",
+    "21:50",
+    "21:55",
+    "22:00",
+    "22:05",
+    "22:10",
+    "22:15",
+    "22:20",
+    "22:25",
+    "22:30",
+    "22:35",
+    "22:40",
+    "22:45",
+    "22:50",
+    "22:55",
+    "23:00",
+    "23:05",
+    "23:10",
+    "23:15",
+    "23:20",
+    "23:25",
+    "23:30",
+    "23:35",
+    "23:40",
+    "23:45",
+    "23:50",
+    "23:55"
+  ],
+  "taxi-destination": [
+    "none",
+    "do not care",
+    "a and b guest house",
+    "abbey pool and astroturf pitch",
+    "acorn guest house",
+    "adc theatre",
+    "addenbrookes hospital",
+    "alexander bed and breakfast",
+    "ali baba",
+    "all saints church",
+    "allenbell",
+    "alpha milton guest house",
+    "anatolia",
+    "arbury lodge guesthouse",
+    "archway house",
+    "ashley hotel",
+    "ask",
+    "attraction",
+    "autumn house",
+    "avalon",
+    "aylesbray lodge guest house",
+    "backstreet bistro",
+    "ballare",
+    "bangkok city",
+    "bedouin",
+    "birmingham new street train station",
+    "bishops stortford train station",
+    "bloomsbury restaurant",
+    "bridge guest house",
+    "broughton house gallery",
+    "broxbourne train station",
+    "byard art",
+    "cafe jello gallery",
+    "cafe uno",
+    "camboats",
+    "cambridge",
+    "cambridge and county folk museum",
+    "cambridge arts theatre",
+    "cambridge artworks",
+    "cambridge belfry",
+    "cambridge book and print gallery",
+    "cambridge chop house",
+    "cambridge contemporary art",
+    "cambridge county fair next to the city tourist museum",
+    "cambridge lodge restaurant",
+    "cambridge museum of technology",
+    "cambridge punter",
+    "cambridge road church of christ",
+    "cambridge train station",
+    "cambridge university botanic gardens",
+    "carolina bed and breakfast",
+    "castle galleries",
+    "charlie chan",
+    "cherry hinton hall and grounds",
+    "cherry hinton village centre",
+    "cherry hinton water park",
+    "cherry hinton water play",
+    "chiquito restaurant bar",
+    "christ college",
+    "churchills college",
+    "cineworld cinema",
+    "city centre north bed and breakfast",
+    "city stop restaurant",
+    "cityroomz",
+    "clare college",
+    "clare hall",
+    "clowns cafe",
+    "club salsa",
+    "cocum",
+    "copper kettle",
+    "corpus christi",
+    "cote",
+    "cotto",
+    "cow pizza kitchen and bar",
+    "curry garden",
+    "curry king",
+    "curry prince",
+    "da vinci pizzeria",
+    "darrys cookhouse and wine shop",
+    "de luca cucina and bar",
+    "dojo noodle bar",
+    "don pasquale pizzeria",
+    "downing college",
+    "efes restaurant",
+    "el shaddia guesthouse",
+    "ely train station",
+    "emmanuel college",
+    "eraina",
+    "express by holiday inn cambridge",
+    "finches bed and breakfast",
+    "finders corner newmarket road",
+    "fitzbillies restaurant",
+    "fitzwilliam museum",
+    "frankie and bennys",
+    "funky fun house",
+    "galleria",
+    "gallery at 12 a high street",
+    "gastropub",
+    "golden curry",
+    "golden house",
+    "golden wok",
+    "gonville and caius college",
+    "gonville hotel",
+    "good luck",
+    "gourmet burger kitchen",
+    "graffiti",
+    "grafton hotel restaurant",
+    "great saint marys church",
+    "hakka",
+    "hamilton lodge",
+    "hk fusion",
+    "hobsons house",
+    "holy trinity church",
+    "home from home",
+    "hotel du vin and bistro",
+    "hughes hall",
+    "huntingdon marriott hotel",
+    "ian hong",
+    "india house",
+    "j restaurant",
+    "jesus college",
+    "jesus green outdoor pool",
+    "jinling noodle bar",
+    "kambar",
+    "kettles yard",
+    "kings college",
+    "kings hedges learner pool",
+    "kirkwood house",
+    "kohinoor",
+    "kymmoy",
+    "la margherita",
+    "la mimosa",
+    "la raza",
+    "la tasca",
+    "lan hong house",
+    "leicester train station",
+    "lensfield hotel",
+    "limehouse",
+    "little saint marys church",
+    "little seoul",
+    "loch fyne",
+    "london kings cross train station",
+    "london liverpool street train station",
+    "lovell lodge",
+    "lynne strover gallery",
+    "magdalene college",
+    "mahal of cambridge",
+    "maharajah tandoori restaurant",
+    "meghna",
+    "meze bar",
+    "michaelhouse cafe",
+    "midsummer house restaurant",
+    "milton country park",
+    "mumford theatre",
+    "museum of archaelogy and anthropology",
+    "museum of classical archaeology",
+    "nandos",
+    "nandos city centre",
+    "nil",
+    "nirala",
+    "norwich train station",
+    "nusha",
+    "old schools",
+    "panahar",
+    "parkside police station",
+    "parkside pools",
+    "peking restaurant",
+    "pembroke college",
+    "peoples portraits exhibition at girton college",
+    "peterborough train station",
+    "pipasha restaurant",
+    "pizza express",
+    "pizza hut cherry hinton",
+    "pizza hut city centre",
+    "pizza hut fenditton",
+    "prezzo",
+    "primavera",
+    "queens college",
+    "rajmahal",
+    "regency gallery",
+    "restaurant 17",
+    "restaurant 2 two",
+    "restaurant alimentum",
+    "rice boat",
+    "rice house",
+    "riverboat georgina",
+    "riverside brasserie",
+    "rosas bed and breakfast",
+    "royal spice",
+    "royal standard",
+    "ruskin gallery",
+    "saffron brasserie",
+    "saigon city",
+    "saint barnabas",
+    "saint barnabas press gallery",
+    "saint catharines college",
+    "saint johns chop house",
+    "saint johns college",
+    "sala thong",
+    "scott polar museum",
+    "scudamores punting co",
+    "sesame restaurant and bar",
+    "shanghai family restaurant",
+    "sheeps green and lammas land park fen causeway",
+    "shiraz",
+    "sidney sussex college",
+    "sitar tandoori",
+    "sleeperz hotel",
+    "soul tree nightclub",
+    "st johns chop house",
+    "stansted airport train station",
+    "station road",
+    "stazione restaurant and coffee bar",
+    "stevenage train station",
+    "taj tandoori",
+    "tall monument",
+    "tandoori palace",
+    "tang chinese",
+    "tenpin",
+    "thanh binh",
+    "the anatolia",
+    "the cambridge corn exchange",
+    "the cambridge shop",
+    "the fez club",
+    "the gandhi",
+    "the gardenia",
+    "the hotpot",
+    "the junction",
+    "the lucky star",
+    "the man on the moon",
+    "the missing sock",
+    "the oak bistro",
+    "the place",
+    "the regent street city center",
+    "the river bar steakhouse and grill",
+    "the slug and lettuce",
+    "the varsity restaurant",
+    "travellers rest",
+    "trinity college",
+    "ugly duckling",
+    "university arms hotel",
+    "vue cinema",
+    "wagamama",
+    "wandlebury country park",
+    "wankworth hotel",
+    "warkworth house",
+    "whale of a time",
+    "whipple museum of the history of science",
+    "williams art and antiques",
+    "worth house",
+    "yippee noodle bar",
+    "yu garden",
+    "zizzi cambridge",
+    "leverton house",
+    "the cambridge chop house",
+    "saint john's college",
+    "churchill college",
+    "the nirala",
+    "the cow pizza kitchen and bar",
+    "christ's college",
+    "el shaddai",
+    "saint catharine's college",
+    "camb",
+    "the golden curry",
+    "little saint mary's church",
+    "country folk museum",
+    "meze bar restaurant",
+    "the cambridge belfry",
+    "the fitzwilliam museum",
+    "the lensfield hotel",
+    "pizza express fen ditton",
+    "the cambridge punter",
+    "king's college",
+    "the cherry hinton village centre",
+    "shiraz restaurant",
+    "sheep's green and lammas land park fen causeway",
+    "caffe uno",
+    "the ghandi",
+    "the copper kettle",
+    "man on the moon concert hall",
+    "alpha-milton guest house",
+    "queen's college",
+    "restaurant one seven",
+    "restaurant two two",
+    "city centre north b and b",
+    "rosa's bed and breakfast",
+    "the good luck chinese food takeaway",
+    "not museum of archaeology and anthropologymentioned",
+    "tandori in cambridge",
+    "kettle's yard",
+    "megna",
+    "grou",
+    "gallery at twelve a high street",
+    "maharajah tandoori restaurant",
+    "pizza hut fen ditton",
+    "gandhi",
+    "tranh binh",
+    "kambur",
+    "people's portraits exhibition at girton college",
+    "hotel",
+    "restaurant",
+    "the galleria",
+    "queens' college",
+    "great saint mary's church",
+    "theathre",
+    "cambridge artworks",
+    "acorn house",
+    "shiraz",
+    "riverboat georginawd",
+    "mic",
+    "the gallery at twelve",
+    "the soul tree",
+    "finches"
+  ],
+  "taxi-departure": [
+    "none",
+    "do not care",
+    "172 chestertown road",
+    "4455 woodbridge road",
+    "a and b guest house",
+    "abbey pool and astroturf pitch",
+    "acorn guest house",
+    "adc theatre",
+    "addenbrookes hospital",
+    "alexander bed and breakfast",
+    "ali baba",
+    "all saints church",
+    "allenbell",
+    "alpha milton guest house",
+    "alyesbray lodge hotel",
+    "ambridge",
+    "anatolia",
+    "arbury lodge guesthouse",
+    "archway house",
+    "ashley hotel",
+    "ask",
+    "autumn house",
+    "avalon",
+    "aylesbray lodge guest house",
+    "backstreet bistro",
+    "ballare",
+    "bangkok city",
+    "bedouin",
+    "birmingham new street train station",
+    "bishops stortford train station",
+    "bloomsbury restaurant",
+    "bridge guest house",
+    "broughton house gallery",
+    "broxbourne train station",
+    "byard art",
+    "cafe jello gallery",
+    "cafe uno",
+    "caffee uno",
+    "camboats",
+    "cambridge",
+    "cambridge and county folk museum",
+    "cambridge arts theatre",
+    "cambridge artworks",
+    "cambridge belfry",
+    "cambridge book and print gallery",
+    "cambridge chop house",
+    "cambridge contemporary art",
+    "cambridge lodge restaurant",
+    "cambridge museum of technology",
+    "cambridge punter",
+    "cambridge towninfo centre",
+    "cambridge train station",
+    "cambridge university botanic gardens",
+    "carolina bed and breakfast",
+    "castle galleries",
+    "centre of town at my hotel",
+    "charlie chan",
+    "cherry hinton hall and grounds",
+    "cherry hinton village center",
+    "cherry hinton village centre",
+    "cherry hinton water play",
+    "chiquito restaurant bar",
+    "christ college",
+    "churchills college",
+    "cineworld cinema",
+    "citiroomz",
+    "city centre north bed and breakfast",
+    "city stop restaurant",
+    "cityroomz",
+    "clair hall",
+    "clare college",
+    "clare hall",
+    "clowns cafe",
+    "club salsa",
+    "cocum",
+    "copper kettle",
+    "corpus christi",
+    "cote",
+    "cotto",
+    "cow pizza kitchen and bar",
+    "curry garden",
+    "curry king",
+    "curry prince",
+    "curry queen",
+    "da vinci pizzeria",
+    "darrys cookhouse and wine shop",
+    "de luca cucina and bar",
+    "dojo noodle bar",
+    "don pasquale pizzeria",
+    "downing college",
+    "downing street",
+    "el shaddia guesthouse",
+    "ely",
+    "ely train station",
+    "emmanuel college",
+    "eraina",
+    "express by holiday inn cambridge",
+    "finches bed and breakfast",
+    "fitzbillies restaurant",
+    "fitzwilliam museum",
+    "frankie and bennys",
+    "funky fun house",
+    "galleria",
+    "gallery at 12 a high street",
+    "girton college",
+    "golden curry",
+    "golden house",
+    "golden wok",
+    "gonville and caius college",
+    "gonville hotel",
+    "good luck",
+    "gourmet burger kitchen",
+    "graffiti",
+    "grafton hotel restaurant",
+    "great saint marys church",
+    "hakka",
+    "hamilton lodge",
+    "hobsons house",
+    "holy trinity church",
+    "home",
+    "home from home",
+    "hotel",
+    "hotel du vin and bistro",
+    "hughes hall",
+    "huntingdon marriott hotel",
+    "india house",
+    "j restaurant",
+    "jesus college",
+    "jesus green outdoor pool",
+    "jinling noodle bar",
+    "junction theatre",
+    "kambar",
+    "kettles yard",
+    "kings college",
+    "kings hedges learner pool",
+    "kings lynn train station",
+    "kirkwood house",
+    "kohinoor",
+    "kymmoy",
+    "la margherita",
+    "la mimosa",
+    "la raza",
+    "la tasca",
+    "lan hong house",
+    "lensfield hotel",
+    "leverton house",
+    "limehouse",
+    "little saint marys church",
+    "little seoul",
+    "loch fyne",
+    "london kings cross train station",
+    "london liverpool street",
+    "london liverpool street train station",
+    "lovell lodge",
+    "lynne strover gallery",
+    "magdalene college",
+    "mahal of cambridge",
+    "maharajah tandoori restaurant",
+    "meghna",
+    "meze bar",
+    "michaelhouse cafe",
+    "milton country park",
+    "mumford theatre",
+    "museum",
+    "museum of archaelogy and anthropology",
+    "museum of classical archaeology",
+    "nandos",
+    "nandos city centre",
+    "new england",
+    "nirala",
+    "norwich train station",
+    "nstaot mentioned",
+    "nusha",
+    "old schools",
+    "panahar",
+    "parkside police station",
+    "parkside pools",
+    "peking restaurant",
+    "pembroke college",
+    "peoples portraits exhibition at girton college",
+    "peterborough train station",
+    "pizza express",
+    "pizza hut cherry hinton",
+    "pizza hut city centre",
+    "pizza hut fenditton",
+    "prezzo",
+    "primavera",
+    "queens college",
+    "rajmahal",
+    "regency gallery",
+    "restaurant 17",
+    "restaurant 2 two",
+    "restaurant alimentum",
+    "rice boat",
+    "rice house",
+    "riverboat georgina",
+    "riverside brasserie",
+    "rosas bed and breakfast",
+    "royal spice",
+    "royal standard",
+    "ruskin gallery",
+    "saffron brasserie",
+    "saigon city",
+    "saint barnabas press gallery",
+    "saint catharines college",
+    "saint johns chop house",
+    "saint johns college",
+    "sala thong",
+    "scott polar museum",
+    "scudamores punting co",
+    "sesame restaurant and bar",
+    "sheeps green and lammas land park",
+    "sheeps green and lammas land park fen causeway",
+    "shiraz",
+    "sidney sussex college",
+    "sitar tandoori",
+    "soul tree nightclub",
+    "st johns college",
+    "stazione restaurant and coffee bar",
+    "stevenage train station",
+    "taj tandoori",
+    "tandoori palace",
+    "tang chinese",
+    "tenpin",
+    "thanh binh",
+    "the cambridge corn exchange",
+    "the fez club",
+    "the gallery at 12",
+    "the gandhi",
+    "the gardenia",
+    "the hotpot",
+    "the junction",
+    "the lucky star",
+    "the man on the moon",
+    "the missing sock",
+    "the oak bistro",
+    "the place",
+    "the river bar steakhouse and grill",
+    "the slug and lettuce",
+    "the varsity restaurant",
+    "travellers rest",
+    "trinity college",
+    "ugly duckling",
+    "university arms hotel",
+    "vue cinema",
+    "wagamama",
+    "wandlebury country park",
+    "warkworth house",
+    "whale of a time",
+    "whipple museum of the history of science",
+    "williams art and antiques",
+    "worth house",
+    "yippee noodle bar",
+    "yu garden",
+    "zizzi cambridge",
+    "christ's college",
+    "city centre north b and b",
+    "the lensfield hotel",
+    "alpha-milton guest house",
+    "el shaddai",
+    "churchill college",
+    "the cambridge belfry",
+    "king's college",
+    "great saint mary's church",
+    "restaurant two two",
+    "queens' college",
+    "little saint mary's church",
+    "chinese city centre",
+    "kettle's yard",
+    "pizza hut",
+    "the golden curry",
+    "rosa's bed and breakfast",
+    "the cambridge punter",
+    "the byard art museum",
+    "saint catharine's college",
+    "meze bar restaurant",
+    "the good luck chinese food takeaway",
+    "restaurant one seven",
+    "pizza hut fen ditton",
+    "the nirala",
+    "the fitzwilliam museum",
+    "st. john's college",
+    "gallery at twelve a high street",
+    "sheep's green and lammas land park fen causeway",
+    "the cherry hinton village centre",
+    "pizza express fen ditton",
+    "corpus cristi",
+    "cas",
+    "acorn house",
+    "lens",
+    "the cambridge chop house",
+    "the copper kettle",
+    "the avalon",
+    "saint john's college",
+    "aylesbray lodge",
+    "the alexander bed and breakfast",
+    "cambridge belfy",
+    "people's portraits exhibition at girton college",
+    "gonville",
+    "caffe uno",
+    "the cow pizza kitchen and bar",
+    "lovell ldoge",
+    "cinema",
+    "shiraz restaurant",
+    "park",
+    "the allenbell"
+  ],
+  "restaurant-book day": [
+    "none",
+    "do not care",
+    "friday",
+    "monday",
+    "saterday",
+    "sunday",
+    "thursday",
+    "tuesday",
+    "wednesday"
+  ],
+  "restaurant-book people": [
+    "none",
+    "do not care",
+    "1",
+    "10 or more",
+    "2",
+    "3",
+    "4",
+    "5",
+    "6",
+    "7",
+    "8",
+    "9"
+  ],
+  "restaurant-book time": [
+    "none",
+    "do not care",
+    "00:00",
+    "00:05",
+    "00:10",
+    "00:15",
+    "00:20",
+    "00:25",
+    "00:30",
+    "00:35",
+    "00:40",
+    "00:45",
+    "00:50",
+    "00:55",
+    "01:00",
+    "01:05",
+    "01:10",
+    "01:15",
+    "01:20",
+    "01:25",
+    "01:30",
+    "01:35",
+    "01:40",
+    "01:45",
+    "01:50",
+    "01:55",
+    "02:00",
+    "02:05",
+    "02:10",
+    "02:15",
+    "02:20",
+    "02:25",
+    "02:30",
+    "02:35",
+    "02:40",
+    "02:45",
+    "02:50",
+    "02:55",
+    "03:00",
+    "03:05",
+    "03:10",
+    "03:15",
+    "03:20",
+    "03:25",
+    "03:30",
+    "03:35",
+    "03:40",
+    "03:45",
+    "03:50",
+    "03:55",
+    "04:00",
+    "04:05",
+    "04:10",
+    "04:15",
+    "04:20",
+    "04:25",
+    "04:30",
+    "04:35",
+    "04:40",
+    "04:45",
+    "04:50",
+    "04:55",
+    "05:00",
+    "05:05",
+    "05:10",
+    "05:15",
+    "05:20",
+    "05:25",
+    "05:30",
+    "05:35",
+    "05:40",
+    "05:45",
+    "05:50",
+    "05:55",
+    "06:00",
+    "06:05",
+    "06:10",
+    "06:15",
+    "06:20",
+    "06:25",
+    "06:30",
+    "06:35",
+    "06:40",
+    "06:45",
+    "06:50",
+    "06:55",
+    "07:00",
+    "07:05",
+    "07:10",
+    "07:15",
+    "07:20",
+    "07:25",
+    "07:30",
+    "07:35",
+    "07:40",
+    "07:45",
+    "07:50",
+    "07:55",
+    "08:00",
+    "08:05",
+    "08:10",
+    "08:15",
+    "08:20",
+    "08:25",
+    "08:30",
+    "08:35",
+    "08:40",
+    "08:45",
+    "08:50",
+    "08:55",
+    "09:00",
+    "09:05",
+    "09:10",
+    "09:15",
+    "09:20",
+    "09:25",
+    "09:30",
+    "09:35",
+    "09:40",
+    "09:45",
+    "09:50",
+    "09:55",
+    "10:00",
+    "10:05",
+    "10:10",
+    "10:15",
+    "10:20",
+    "10:25",
+    "10:30",
+    "10:35",
+    "10:40",
+    "10:45",
+    "10:50",
+    "10:55",
+    "11:00",
+    "11:05",
+    "11:10",
+    "11:15",
+    "11:20",
+    "11:25",
+    "11:30",
+    "11:35",
+    "11:40",
+    "11:45",
+    "11:50",
+    "11:55",
+    "12:00",
+    "12:05",
+    "12:10",
+    "12:15",
+    "12:20",
+    "12:25",
+    "12:30",
+    "12:35",
+    "12:40",
+    "12:45",
+    "12:50",
+    "12:55",
+    "13:00",
+    "13:05",
+    "13:10",
+    "13:15",
+    "13:20",
+    "13:25",
+    "13:30",
+    "13:35",
+    "13:40",
+    "13:45",
+    "13:50",
+    "13:55",
+    "14:00",
+    "14:05",
+    "14:10",
+    "14:15",
+    "14:20",
+    "14:25",
+    "14:30",
+    "14:35",
+    "14:40",
+    "14:45",
+    "14:50",
+    "14:55",
+    "15:00",
+    "15:05",
+    "15:10",
+    "15:15",
+    "15:20",
+    "15:25",
+    "15:30",
+    "15:35",
+    "15:40",
+    "15:45",
+    "15:50",
+    "15:55",
+    "16:00",
+    "16:05",
+    "16:10",
+    "16:15",
+    "16:20",
+    "16:25",
+    "16:30",
+    "16:35",
+    "16:40",
+    "16:45",
+    "16:50",
+    "16:55",
+    "17:00",
+    "17:05",
+    "17:10",
+    "17:15",
+    "17:20",
+    "17:25",
+    "17:30",
+    "17:35",
+    "17:40",
+    "17:45",
+    "17:50",
+    "17:55",
+    "18:00",
+    "18:05",
+    "18:10",
+    "18:15",
+    "18:20",
+    "18:25",
+    "18:30",
+    "18:35",
+    "18:40",
+    "18:45",
+    "18:50",
+    "18:55",
+    "19:00",
+    "19:05",
+    "19:10",
+    "19:15",
+    "19:20",
+    "19:25",
+    "19:30",
+    "19:35",
+    "19:40",
+    "19:45",
+    "19:50",
+    "19:55",
+    "20:00",
+    "20:05",
+    "20:10",
+    "20:15",
+    "20:20",
+    "20:25",
+    "20:30",
+    "20:35",
+    "20:40",
+    "20:45",
+    "20:50",
+    "20:55",
+    "21:00",
+    "21:05",
+    "21:10",
+    "21:15",
+    "21:20",
+    "21:25",
+    "21:30",
+    "21:35",
+    "21:40",
+    "21:45",
+    "21:50",
+    "21:55",
+    "22:00",
+    "22:05",
+    "22:10",
+    "22:15",
+    "22:20",
+    "22:25",
+    "22:30",
+    "22:35",
+    "22:40",
+    "22:45",
+    "22:50",
+    "22:55",
+    "23:00",
+    "23:05",
+    "23:10",
+    "23:15",
+    "23:20",
+    "23:25",
+    "23:30",
+    "23:35",
+    "23:40",
+    "23:45",
+    "23:50",
+    "23:55"
+  ],
+  "taxi-arrive by": [
+    "none",
+    "do not care",
+    "00:00",
+    "00:05",
+    "00:10",
+    "00:15",
+    "00:20",
+    "00:25",
+    "00:30",
+    "00:35",
+    "00:40",
+    "00:45",
+    "00:50",
+    "00:55",
+    "01:00",
+    "01:05",
+    "01:10",
+    "01:15",
+    "01:20",
+    "01:25",
+    "01:30",
+    "01:35",
+    "01:40",
+    "01:45",
+    "01:50",
+    "01:55",
+    "02:00",
+    "02:05",
+    "02:10",
+    "02:15",
+    "02:20",
+    "02:25",
+    "02:30",
+    "02:35",
+    "02:40",
+    "02:45",
+    "02:50",
+    "02:55",
+    "03:00",
+    "03:05",
+    "03:10",
+    "03:15",
+    "03:20",
+    "03:25",
+    "03:30",
+    "03:35",
+    "03:40",
+    "03:45",
+    "03:50",
+    "03:55",
+    "04:00",
+    "04:05",
+    "04:10",
+    "04:15",
+    "04:20",
+    "04:25",
+    "04:30",
+    "04:35",
+    "04:40",
+    "04:45",
+    "04:50",
+    "04:55",
+    "05:00",
+    "05:05",
+    "05:10",
+    "05:15",
+    "05:20",
+    "05:25",
+    "05:30",
+    "05:35",
+    "05:40",
+    "05:45",
+    "05:50",
+    "05:55",
+    "06:00",
+    "06:05",
+    "06:10",
+    "06:15",
+    "06:20",
+    "06:25",
+    "06:30",
+    "06:35",
+    "06:40",
+    "06:45",
+    "06:50",
+    "06:55",
+    "07:00",
+    "07:05",
+    "07:10",
+    "07:15",
+    "07:20",
+    "07:25",
+    "07:30",
+    "07:35",
+    "07:40",
+    "07:45",
+    "07:50",
+    "07:55",
+    "08:00",
+    "08:05",
+    "08:10",
+    "08:15",
+    "08:20",
+    "08:25",
+    "08:30",
+    "08:35",
+    "08:40",
+    "08:45",
+    "08:50",
+    "08:55",
+    "09:00",
+    "09:05",
+    "09:10",
+    "09:15",
+    "09:20",
+    "09:25",
+    "09:30",
+    "09:35",
+    "09:40",
+    "09:45",
+    "09:50",
+    "09:55",
+    "10:00",
+    "10:05",
+    "10:10",
+    "10:15",
+    "10:20",
+    "10:25",
+    "10:30",
+    "10:35",
+    "10:40",
+    "10:45",
+    "10:50",
+    "10:55",
+    "11:00",
+    "11:05",
+    "11:10",
+    "11:15",
+    "11:20",
+    "11:25",
+    "11:30",
+    "11:35",
+    "11:40",
+    "11:45",
+    "11:50",
+    "11:55",
+    "12:00",
+    "12:05",
+    "12:10",
+    "12:15",
+    "12:20",
+    "12:25",
+    "12:30",
+    "12:35",
+    "12:40",
+    "12:45",
+    "12:50",
+    "12:55",
+    "13:00",
+    "13:05",
+    "13:10",
+    "13:15",
+    "13:20",
+    "13:25",
+    "13:30",
+    "13:35",
+    "13:40",
+    "13:45",
+    "13:50",
+    "13:55",
+    "14:00",
+    "14:05",
+    "14:10",
+    "14:15",
+    "14:20",
+    "14:25",
+    "14:30",
+    "14:35",
+    "14:40",
+    "14:45",
+    "14:50",
+    "14:55",
+    "15:00",
+    "15:05",
+    "15:10",
+    "15:15",
+    "15:20",
+    "15:25",
+    "15:30",
+    "15:35",
+    "15:40",
+    "15:45",
+    "15:50",
+    "15:55",
+    "16:00",
+    "16:05",
+    "16:10",
+    "16:15",
+    "16:20",
+    "16:25",
+    "16:30",
+    "16:35",
+    "16:40",
+    "16:45",
+    "16:50",
+    "16:55",
+    "17:00",
+    "17:05",
+    "17:10",
+    "17:15",
+    "17:20",
+    "17:25",
+    "17:30",
+    "17:35",
+    "17:40",
+    "17:45",
+    "17:50",
+    "17:55",
+    "18:00",
+    "18:05",
+    "18:10",
+    "18:15",
+    "18:20",
+    "18:25",
+    "18:30",
+    "18:35",
+    "18:40",
+    "18:45",
+    "18:50",
+    "18:55",
+    "19:00",
+    "19:05",
+    "19:10",
+    "19:15",
+    "19:20",
+    "19:25",
+    "19:30",
+    "19:35",
+    "19:40",
+    "19:45",
+    "19:50",
+    "19:55",
+    "20:00",
+    "20:05",
+    "20:10",
+    "20:15",
+    "20:20",
+    "20:25",
+    "20:30",
+    "20:35",
+    "20:40",
+    "20:45",
+    "20:50",
+    "20:55",
+    "21:00",
+    "21:05",
+    "21:10",
+    "21:15",
+    "21:20",
+    "21:25",
+    "21:30",
+    "21:35",
+    "21:40",
+    "21:45",
+    "21:50",
+    "21:55",
+    "22:00",
+    "22:05",
+    "22:10",
+    "22:15",
+    "22:20",
+    "22:25",
+    "22:30",
+    "22:35",
+    "22:40",
+    "22:45",
+    "22:50",
+    "22:55",
+    "23:00",
+    "23:05",
+    "23:10",
+    "23:15",
+    "23:20",
+    "23:25",
+    "23:30",
+    "23:35",
+    "23:40",
+    "23:45",
+    "23:50",
+    "23:55"
+  ],
+  "restaurant-area": [
+    "none",
+    "do not care",
+    "centre",
+    "east",
+    "north",
+    "south",
+    "west"
+  ],
+  "hotel-area": [
+    "none",
+    "do not care",
+    "centre",
+    "east",
+    "north",
+    "south",
+    "west"
+  ],
+  "attraction-area": [
+    "none",
+    "do not care",
+    "centre",
+    "east",
+    "north",
+    "south",
+    "west"
+  ]
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json
new file mode 100644
index 0000000000000000000000000000000000000000..b0dd00fdf6dc2b824f1f50a44e776c63ce72f14b
--- /dev/null
+++ b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json
@@ -0,0 +1,3128 @@
+{
+  "hotel-price range": [
+    "none",
+    "do not care",
+    "cheap",
+    "expensive",
+    "moderate",
+    "request"
+  ],
+  "hotel-type": [
+    "none",
+    "do not care",
+    "bed and breakfast",
+    "guest house",
+    "hotel",
+    "request"
+  ],
+  "hotel-parking": [
+    "none",
+    "do not care",
+    "no",
+    "yes",
+    "request"
+  ],
+  "hotel-book day": [
+    "none",
+    "do not care",
+    "friday",
+    "monday",
+    "saterday",
+    "sunday",
+    "thursday",
+    "tuesday",
+    "wednesday"
+  ],
+  "hotel-book people": [
+    "none",
+    "do not care",
+    "1",
+    "10 or more",
+    "2",
+    "3",
+    "4",
+    "5",
+    "6",
+    "7",
+    "8",
+    "9"
+  ],
+  "hotel-book stay": [
+    "none",
+    "do not care",
+    "1",
+    "10 or more",
+    "2",
+    "3",
+    "4",
+    "5",
+    "6",
+    "7",
+    "8",
+    "9"
+  ],
+  "train-destination": [
+    "none",
+    "do not care",
+    "bishops stortford",
+    "kings lynn",
+    "london liverpool street",
+    "centre",
+    "bishop stortford",
+    "liverpool",
+    "leicester",
+    "broxbourne",
+    "gourmet burger kitchen",
+    "copper kettle",
+    "bournemouth",
+    "stevenage",
+    "liverpool street",
+    "norwich",
+    "huntingdon marriott hotel",
+    "city centre north",
+    "taj tandoori",
+    "the copper kettle",
+    "peterborough",
+    "ely",
+    "lecester",
+    "london",
+    "willi",
+    "stansted airport",
+    "huntington marriott",
+    "cambridge",
+    "gonv",
+    "glastonbury",
+    "hol",
+    "north",
+    "birmingham new street",
+    "norway",
+    "petersborough",
+    "london kings cross",
+    "curry prince",
+    "bishops storford"
+  ],
+  "train-arrive by": [
+    "none",
+    "do not care",
+    "00:00",
+    "00:05",
+    "00:10",
+    "00:15",
+    "00:20",
+    "00:25",
+    "00:30",
+    "00:35",
+    "00:40",
+    "00:45",
+    "00:50",
+    "00:55",
+    "01:00",
+    "01:05",
+    "01:10",
+    "01:15",
+    "01:20",
+    "01:25",
+    "01:30",
+    "01:35",
+    "01:40",
+    "01:45",
+    "01:50",
+    "01:55",
+    "02:00",
+    "02:05",
+    "02:10",
+    "02:15",
+    "02:20",
+    "02:25",
+    "02:30",
+    "02:35",
+    "02:40",
+    "02:45",
+    "02:50",
+    "02:55",
+    "03:00",
+    "03:05",
+    "03:10",
+    "03:15",
+    "03:20",
+    "03:25",
+    "03:30",
+    "03:35",
+    "03:40",
+    "03:45",
+    "03:50",
+    "03:55",
+    "04:00",
+    "04:05",
+    "04:10",
+    "04:15",
+    "04:20",
+    "04:25",
+    "04:30",
+    "04:35",
+    "04:40",
+    "04:45",
+    "04:50",
+    "04:55",
+    "05:00",
+    "05:05",
+    "05:10",
+    "05:15",
+    "05:20",
+    "05:25",
+    "05:30",
+    "05:35",
+    "05:40",
+    "05:45",
+    "05:50",
+    "05:55",
+    "06:00",
+    "06:05",
+    "06:10",
+    "06:15",
+    "06:20",
+    "06:25",
+    "06:30",
+    "06:35",
+    "06:40",
+    "06:45",
+    "06:50",
+    "06:55",
+    "07:00",
+    "07:05",
+    "07:10",
+    "07:15",
+    "07:20",
+    "07:25",
+    "07:30",
+    "07:35",
+    "07:40",
+    "07:45",
+    "07:50",
+    "07:55",
+    "08:00",
+    "08:05",
+    "08:10",
+    "08:15",
+    "08:20",
+    "08:25",
+    "08:30",
+    "08:35",
+    "08:40",
+    "08:45",
+    "08:50",
+    "08:55",
+    "09:00",
+    "09:05",
+    "09:10",
+    "09:15",
+    "09:20",
+    "09:25",
+    "09:30",
+    "09:35",
+    "09:40",
+    "09:45",
+    "09:50",
+    "09:55",
+    "10:00",
+    "10:05",
+    "10:10",
+    "10:15",
+    "10:20",
+    "10:25",
+    "10:30",
+    "10:35",
+    "10:40",
+    "10:45",
+    "10:50",
+    "10:55",
+    "11:00",
+    "11:05",
+    "11:10",
+    "11:15",
+    "11:20",
+    "11:25",
+    "11:30",
+    "11:35",
+    "11:40",
+    "11:45",
+    "11:50",
+    "11:55",
+    "12:00",
+    "12:05",
+    "12:10",
+    "12:15",
+    "12:20",
+    "12:25",
+    "12:30",
+    "12:35",
+    "12:40",
+    "12:45",
+    "12:50",
+    "12:55",
+    "13:00",
+    "13:05",
+    "13:10",
+    "13:15",
+    "13:20",
+    "13:25",
+    "13:30",
+    "13:35",
+    "13:40",
+    "13:45",
+    "13:50",
+    "13:55",
+    "14:00",
+    "14:05",
+    "14:10",
+    "14:15",
+    "14:20",
+    "14:25",
+    "14:30",
+    "14:35",
+    "14:40",
+    "14:45",
+    "14:50",
+    "14:55",
+    "15:00",
+    "15:05",
+    "15:10",
+    "15:15",
+    "15:20",
+    "15:25",
+    "15:30",
+    "15:35",
+    "15:40",
+    "15:45",
+    "15:50",
+    "15:55",
+    "16:00",
+    "16:05",
+    "16:10",
+    "16:15",
+    "16:20",
+    "16:25",
+    "16:30",
+    "16:35",
+    "16:40",
+    "16:45",
+    "16:50",
+    "16:55",
+    "17:00",
+    "17:05",
+    "17:10",
+    "17:15",
+    "17:20",
+    "17:25",
+    "17:30",
+    "17:35",
+    "17:40",
+    "17:45",
+    "17:50",
+    "17:55",
+    "18:00",
+    "18:05",
+    "18:10",
+    "18:15",
+    "18:20",
+    "18:25",
+    "18:30",
+    "18:35",
+    "18:40",
+    "18:45",
+    "18:50",
+    "18:55",
+    "19:00",
+    "19:05",
+    "19:10",
+    "19:15",
+    "19:20",
+    "19:25",
+    "19:30",
+    "19:35",
+    "19:40",
+    "19:45",
+    "19:50",
+    "19:55",
+    "20:00",
+    "20:05",
+    "20:10",
+    "20:15",
+    "20:20",
+    "20:25",
+    "20:30",
+    "20:35",
+    "20:40",
+    "20:45",
+    "20:50",
+    "20:55",
+    "21:00",
+    "21:05",
+    "21:10",
+    "21:15",
+    "21:20",
+    "21:25",
+    "21:30",
+    "21:35",
+    "21:40",
+    "21:45",
+    "21:50",
+    "21:55",
+    "22:00",
+    "22:05",
+    "22:10",
+    "22:15",
+    "22:20",
+    "22:25",
+    "22:30",
+    "22:35",
+    "22:40",
+    "22:45",
+    "22:50",
+    "22:55",
+    "23:00",
+    "23:05",
+    "23:10",
+    "23:15",
+    "23:20",
+    "23:25",
+    "23:30",
+    "23:35",
+    "23:40",
+    "23:45",
+    "23:50",
+    "23:55",
+    "request"
+  ],
+  "train-departure": [
+    "none",
+    "do not care",
+    "bishops stortford",
+    "kings lynn",
+    "brookshite",
+    "london liverpool street",
+    "cam",
+    "liverpool",
+    "bro",
+    "leicester",
+    "broxbourne",
+    "norwhich",
+    "saint johns",
+    "stevenage",
+    "stansted",
+    "london liverpool",
+    "cambrid",
+    "city hall",
+    "rosas bed and breakfast",
+    "alpha-milton",
+    "wandlebury country park",
+    "norwich",
+    "liecester",
+    "stratford",
+    "peterborough",
+    "duxford",
+    "ely",
+    "london",
+    "stansted airport",
+    "lon",
+    "cambridge",
+    "panahar",
+    "cineworld",
+    "leicaster",
+    "birmingham",
+    "cafe uno",
+    "camboats",
+    "huntingdon",
+    "birmingham new street",
+    "arbu",
+    "alpha milton",
+    "east london",
+    "london kings cross",
+    "hamilton lodge",
+    "aylesbray lodge guest",
+    "el shaddai"
+  ],
+  "train-day": [
+    "none",
+    "do not care",
+    "friday",
+    "monday",
+    "saterday",
+    "sunday",
+    "thursday",
+    "tuesday",
+    "wednesday"
+  ],
+  "train-book people": [
+    "none",
+    "do not care",
+    "1",
+    "10 or more",
+    "2",
+    "3",
+    "4",
+    "5",
+    "6",
+    "7",
+    "8",
+    "9"
+  ],
+  "hotel-stars": [
+    "none",
+    "do not care",
+    "0",
+    "1",
+    "2",
+    "3",
+    "4",
+    "5",
+    "request"
+  ],
+  "hotel-internet": [
+    "none",
+    "do not care",
+    "no",
+    "yes",
+    "request"
+  ],
+  "hotel-name": [
+    "none",
+    "do not care",
+    "a and b guest house",
+    "city roomz",
+    "carolina bed and breakfast",
+    "limehouse",
+    "anatolia",
+    "hamilton lodge",
+    "the lensfield hotel",
+    "rosa's bed and breakfast",
+    "gall",
+    "aylesbray lodge",
+    "kirkwood",
+    "cambridge belfry",
+    "warkworth house",
+    "gonville",
+    "belfy hotel",
+    "nus",
+    "alexander",
+    "super 5",
+    "aylesbray lodge guest house",
+    "the gonvile hotel",
+    "allenbell",
+    "nothamilton lodge",
+    "ashley hotel",
+    "autumn house",
+    "hobsons house",
+    "hotel",
+    "ashely hotel",
+    "caridge belfrey",
+    "el shaddia guest house",
+    "avalon",
+    "cote",
+    "city centre north bed and breakfast",
+    "the cambridge belfry",
+    "home from home",
+    "wandlebury coutn",
+    "wankworth house",
+    "city stop rest",
+    "the worth house",
+    "cityroomz",
+    "huntingdon marriottt hotel",
+    "lensfield",
+    "rosas bed and breakfast",
+    "leverton house",
+    "gonville hotel",
+    "holiday inn cambridge",
+    "archway house",
+    "lan hon",
+    "levert",
+    "acorn guest house",
+    "cambridge",
+    "the ashley hotel",
+    "el shaddai",
+    "sleeperz",
+    "alpha milton guest house",
+    "doubletree by hilton cambridge",
+    "tandoori palace",
+    "express by",
+    "express by holiday inn cambridge",
+    "north bed and breakfast",
+    "holiday inn",
+    "arbury lodge guest house",
+    "alexander bed and breakfast",
+    "huntingdon marriott hotel",
+    "royal spice",
+    "sou",
+    "finches bed and breakfast",
+    "the alpha milton",
+    "bridge guest house",
+    "the acorn guest house",
+    "kirkwood house",
+    "eraina",
+    "la margherit",
+    "lensfield hotel",
+    "marriott hotel",
+    "nusha",
+    "city centre bed and breakfast",
+    "the allenbell",
+    "university arms hotel",
+    "clare",
+    "cherr",
+    "wartworth",
+    "acorn place",
+    "lovell lodge",
+    "whale"
+  ],
+  "train-leave at": [
+    "none",
+    "do not care",
+    "00:00",
+    "00:05",
+    "00:10",
+    "00:15",
+    "00:20",
+    "00:25",
+    "00:30",
+    "00:35",
+    "00:40",
+    "00:45",
+    "00:50",
+    "00:55",
+    "01:00",
+    "01:05",
+    "01:10",
+    "01:15",
+    "01:20",
+    "01:25",
+    "01:30",
+    "01:35",
+    "01:40",
+    "01:45",
+    "01:50",
+    "01:55",
+    "02:00",
+    "02:05",
+    "02:10",
+    "02:15",
+    "02:20",
+    "02:25",
+    "02:30",
+    "02:35",
+    "02:40",
+    "02:45",
+    "02:50",
+    "02:55",
+    "03:00",
+    "03:05",
+    "03:10",
+    "03:15",
+    "03:20",
+    "03:25",
+    "03:30",
+    "03:35",
+    "03:40",
+    "03:45",
+    "03:50",
+    "03:55",
+    "04:00",
+    "04:05",
+    "04:10",
+    "04:15",
+    "04:20",
+    "04:25",
+    "04:30",
+    "04:35",
+    "04:40",
+    "04:45",
+    "04:50",
+    "04:55",
+    "05:00",
+    "05:05",
+    "05:10",
+    "05:15",
+    "05:20",
+    "05:25",
+    "05:30",
+    "05:35",
+    "05:40",
+    "05:45",
+    "05:50",
+    "05:55",
+    "06:00",
+    "06:05",
+    "06:10",
+    "06:15",
+    "06:20",
+    "06:25",
+    "06:30",
+    "06:35",
+    "06:40",
+    "06:45",
+    "06:50",
+    "06:55",
+    "07:00",
+    "07:05",
+    "07:10",
+    "07:15",
+    "07:20",
+    "07:25",
+    "07:30",
+    "07:35",
+    "07:40",
+    "07:45",
+    "07:50",
+    "07:55",
+    "08:00",
+    "08:05",
+    "08:10",
+    "08:15",
+    "08:20",
+    "08:25",
+    "08:30",
+    "08:35",
+    "08:40",
+    "08:45",
+    "08:50",
+    "08:55",
+    "09:00",
+    "09:05",
+    "09:10",
+    "09:15",
+    "09:20",
+    "09:25",
+    "09:30",
+    "09:35",
+    "09:40",
+    "09:45",
+    "09:50",
+    "09:55",
+    "10:00",
+    "10:05",
+    "10:10",
+    "10:15",
+    "10:20",
+    "10:25",
+    "10:30",
+    "10:35",
+    "10:40",
+    "10:45",
+    "10:50",
+    "10:55",
+    "11:00",
+    "11:05",
+    "11:10",
+    "11:15",
+    "11:20",
+    "11:25",
+    "11:30",
+    "11:35",
+    "11:40",
+    "11:45",
+    "11:50",
+    "11:55",
+    "12:00",
+    "12:05",
+    "12:10",
+    "12:15",
+    "12:20",
+    "12:25",
+    "12:30",
+    "12:35",
+    "12:40",
+    "12:45",
+    "12:50",
+    "12:55",
+    "13:00",
+    "13:05",
+    "13:10",
+    "13:15",
+    "13:20",
+    "13:25",
+    "13:30",
+    "13:35",
+    "13:40",
+    "13:45",
+    "13:50",
+    "13:55",
+    "14:00",
+    "14:05",
+    "14:10",
+    "14:15",
+    "14:20",
+    "14:25",
+    "14:30",
+    "14:35",
+    "14:40",
+    "14:45",
+    "14:50",
+    "14:55",
+    "15:00",
+    "15:05",
+    "15:10",
+    "15:15",
+    "15:20",
+    "15:25",
+    "15:30",
+    "15:35",
+    "15:40",
+    "15:45",
+    "15:50",
+    "15:55",
+    "16:00",
+    "16:05",
+    "16:10",
+    "16:15",
+    "16:20",
+    "16:25",
+    "16:30",
+    "16:35",
+    "16:40",
+    "16:45",
+    "16:50",
+    "16:55",
+    "17:00",
+    "17:05",
+    "17:10",
+    "17:15",
+    "17:20",
+    "17:25",
+    "17:30",
+    "17:35",
+    "17:40",
+    "17:45",
+    "17:50",
+    "17:55",
+    "18:00",
+    "18:05",
+    "18:10",
+    "18:15",
+    "18:20",
+    "18:25",
+    "18:30",
+    "18:35",
+    "18:40",
+    "18:45",
+    "18:50",
+    "18:55",
+    "19:00",
+    "19:05",
+    "19:10",
+    "19:15",
+    "19:20",
+    "19:25",
+    "19:30",
+    "19:35",
+    "19:40",
+    "19:45",
+    "19:50",
+    "19:55",
+    "20:00",
+    "20:05",
+    "20:10",
+    "20:15",
+    "20:20",
+    "20:25",
+    "20:30",
+    "20:35",
+    "20:40",
+    "20:45",
+    "20:50",
+    "20:55",
+    "21:00",
+    "21:05",
+    "21:10",
+    "21:15",
+    "21:20",
+    "21:25",
+    "21:30",
+    "21:35",
+    "21:40",
+    "21:45",
+    "21:50",
+    "21:55",
+    "22:00",
+    "22:05",
+    "22:10",
+    "22:15",
+    "22:20",
+    "22:25",
+    "22:30",
+    "22:35",
+    "22:40",
+    "22:45",
+    "22:50",
+    "22:55",
+    "23:00",
+    "23:05",
+    "23:10",
+    "23:15",
+    "23:20",
+    "23:25",
+    "23:30",
+    "23:35",
+    "23:40",
+    "23:45",
+    "23:50",
+    "23:55",
+    "request"
+  ],
+  "restaurant-price range": [
+    "none",
+    "do not care",
+    "cheap",
+    "expensive",
+    "moderate",
+    "request"
+  ],
+  "restaurant-food": [
+    "none",
+    "do not care",
+    "british food",
+    "steakhouse",
+    "turkish",
+    "sushi",
+    "north american",
+    "scottish",
+    "french",
+    "austrian",
+    "korean",
+    "eastern european",
+    "swedish",
+    "gastro pub",
+    "modern eclectic",
+    "afternoon tea",
+    "welsh",
+    "christmas",
+    "tuscan",
+    "gastropub",
+    "sri lankan",
+    "molecular gastronomy",
+    "traditional american",
+    "italian",
+    "pizza",
+    "thai",
+    "south african",
+    "creative",
+    "english",
+    "asian",
+    "lebanese",
+    "hungarian",
+    "halal",
+    "portugese",
+    "modern english",
+    "african",
+    "light bites",
+    "malaysian",
+    "venetian",
+    "traditional",
+    "chinese",
+    "vegetarian",
+    "persian",
+    "thai and chinese",
+    "scandinavian",
+    "catalan",
+    "polynesian",
+    "crossover",
+    "canapes",
+    "cantonese",
+    "north african",
+    "seafood",
+    "brazilian",
+    "south indian",
+    "australasian",
+    "belgian",
+    "barbeque",
+    "the americas",
+    "indonesian",
+    "singaporean",
+    "irish",
+    "middle eastern",
+    "dojo noodle bar",
+    "caribbean",
+    "vietnamese",
+    "modern european",
+    "russian",
+    "german",
+    "world",
+    "japanese",
+    "moroccan",
+    "modern global",
+    "indian",
+    "british",
+    "american",
+    "danish",
+    "panasian",
+    "swiss",
+    "basque",
+    "north indian",
+    "modern american",
+    "australian",
+    "european",
+    "corsica",
+    "greek",
+    "northern european",
+    "mediterranean",
+    "portuguese",
+    "romanian",
+    "jamaican",
+    "polish",
+    "international",
+    "unusual",
+    "latin american",
+    "asian oriental",
+    "mexican",
+    "bistro",
+    "cuban",
+    "fusion",
+    "new zealand",
+    "spanish",
+    "eritrean",
+    "afghan",
+    "kosher",
+    "request"
+  ],
+  "attraction-name": [
+    "none",
+    "do not care",
+    "downing college",
+    "fitzwilliam",
+    "clare college",
+    "ruskin gallery",
+    "sidney sussex college",
+    "great saint mary's church",
+    "cherry hinton water play park",
+    "wandlebury country park",
+    "cafe uno",
+    "place",
+    "broughton",
+    "cineworld cinema",
+    "jesus college",
+    "vue cinema",
+    "history of science museum",
+    "mumford theatre",
+    "whale of time",
+    "fitzbillies",
+    "christs church",
+    "churchill college",
+    "museum of classical archaeology",
+    "gonville and caius college",
+    "pizza",
+    "kirkwood",
+    "saint catharines college",
+    "kings college",
+    "parkside",
+    "by",
+    "st catharines college",
+    "saint john's college",
+    "cherry hinton water park",
+    "st christs college",
+    "christ's college",
+    "bangkok city",
+    "scudamores punti co",
+    "free",
+    "great saint marys church",
+    "milton country park",
+    "the fez club",
+    "soultree",
+    "autu",
+    "whipple museum of the history of science",
+    "aylesbray lodge guest house",
+    "broughton house gallery",
+    "peoples portraits exhibition",
+    "primavera",
+    "kettles yard",
+    "all saint's church",
+    "cinema cinema",
+    "regency gallery",
+    "corpus christi",
+    "corn cambridge exchange",
+    "da vinci pizzeria",
+    "school",
+    "hobsons house",
+    "cambride and country folk museum",
+    "north",
+    "da v",
+    "cambridge corn exchange",
+    "soul tree nightclub",
+    "cambridge arts theater",
+    "saint catharine's college",
+    "byard art",
+    "cambridge punter",
+    "cambridge university botanic gardens",
+    "castle galleries",
+    "museum of archaelogy and anthropogy",
+    "no specific location",
+    "cherry hinton hall",
+    "gallery at 12 a high street",
+    "parkside pools",
+    "queen's college",
+    "little saint mary's church",
+    "gallery",
+    "home from home",
+    "tenpin",
+    "the wandlebury",
+    "county folk museum",
+    "swimming pool",
+    "christs college",
+    "cafe jello museum",
+    "scott polar",
+    "christ college",
+    "cambridge museum of technology",
+    "abbey pool and astroturf pitch",
+    "king hedges learner pool",
+    "the cambridge arts theatre",
+    "the castle galleries",
+    "cambridge and country folk museum",
+    "kohinoor",
+    "scudamores punting co",
+    "sidney sussex",
+    "the man on the moon",
+    "little saint marys church",
+    "queens",
+    "the place",
+    "old school",
+    "churchill",
+    "churchills college",
+    "hughes hall",
+    "churchhill college",
+    "riverboat georgina",
+    "belf",
+    "cambridge temporary art",
+    "abc theatre",
+    "cambridge contemporary art museum",
+    "man on the moon",
+    "the junction",
+    "cherry hinton water play",
+    "adc theatre",
+    "gonville hotel",
+    "magdalene college",
+    "peoples portraits exhibition at girton college",
+    "boat",
+    "centre",
+    "sheep's green and lammas land park fen causeway",
+    "the mumford theatre",
+    "archway house",
+    "queens' college",
+    "williams art and antiques",
+    "funky fun house",
+    "cherry hinton village centre",
+    "camboats",
+    "cambridge",
+    "old schools",
+    "kettle's yard",
+    "whale of a time",
+    "the churchill college",
+    "cafe jello gallery",
+    "aut",
+    "salsa",
+    "city",
+    "clare hall",
+    "boating",
+    "pembroke college",
+    "kings hedges learner pool",
+    "caffe uno",
+    "lammas land park",
+    "museum",
+    "the fitzwilliam museum",
+    "the cherry hinton village centre",
+    "the cambridge corn exchange",
+    "fitzwilliam museum",
+    "museum of archaelogy and anthropology",
+    "fez club",
+    "the cambridge punter",
+    "saint johns college",
+    "emmanuel college",
+    "cambridge belf",
+    "scudamore",
+    "lynne strover gallery",
+    "king's college",
+    "whippple museum",
+    "trinity college",
+    "college in the north",
+    "sheep's green",
+    "kambar",
+    "museum of archaelogy",
+    "adc",
+    "garde",
+    "club salsa",
+    "people's portraits exhibition at girton college",
+    "botanic gardens",
+    "carol",
+    "college",
+    "gallery at twelve a high street",
+    "abbey pool and astroturf",
+    "cambridge book and print gallery",
+    "jesus green outdoor pool",
+    "scott polar museum",
+    "saint barnabas press gallery",
+    "cambridge artworks",
+    "older churches",
+    "cambridge contemporary art",
+    "cherry hinton hall and grounds",
+    "univ",
+    "jesus green",
+    "ballare",
+    "abbey pool",
+    "cambridge botanic gardens",
+    "nusha",
+    "worth house",
+    "thanh",
+    "university arms hotel",
+    "cambridge arts theatre",
+    "cafe jello",
+    "cambridge and county folk museum",
+    "the cambridge artworks",
+    "all saints church",
+    "holy trinity church",
+    "contemporary art museum",
+    "architectural churches",
+    "queens college",
+    "trinity street college"
+  ],
+  "restaurant-name": [
+    "none",
+    "do not care",
+    "hotel du vin and bistro",
+    "ask",
+    "gourmet formal kitchen",
+    "the meze bar",
+    "lan hong house",
+    "cow pizza",
+    "one seven",
+    "prezzo",
+    "maharajah tandoori restaurant",
+    "alex",
+    "shanghai",
+    "golden wok",
+    "restaurant",
+    "fitzbillies",
+    "nil",
+    "copper kettle",
+    "meghna",
+    "hk fusion",
+    "bangkok city",
+    "hobsons house",
+    "tang chinese",
+    "anatolia",
+    "ugly duckling",
+    "anatolia and efes restaurant",
+    "sitar tandoori",
+    "city stop",
+    "ashley",
+    "pizza express fen ditton",
+    "molecular gastronomy",
+    "autumn house",
+    "el shaddia guesthouse",
+    "the grafton hotel",
+    "limehouse",
+    "gardenia",
+    "not metioned",
+    "hakka",
+    "michaelhouse cafe",
+    "pipasha",
+    "meze bar",
+    "archway",
+    "molecular gastonomy",
+    "yipee noodle bar",
+    "the peking",
+    "curry prince",
+    "midsummer house restaurant",
+    "pizza hut cherry hinton",
+    "the lucky star",
+    "stazione restaurant and coffee bar",
+    "shanghi family restaurant",
+    "good luck",
+    "j restaurant",
+    "bedouin",
+    "cott",
+    "little seoul",
+    "south",
+    "thanh binh",
+    "el",
+    "efes restaurant",
+    "kohinoor",
+    "clowns",
+    "india",
+    "the slug and lettuce",
+    "shiraz",
+    "barbakan",
+    "zizzi cambridge",
+    "restaurant one seven",
+    "slug and lettuce",
+    "travellers rest",
+    "binh",
+    "worth house",
+    "broughton house gallery",
+    "chiquito",
+    "the river bar steakhouse and grill",
+    "ros",
+    "golden house",
+    "india west",
+    "cam",
+    "panahar",
+    "restaurant 22",
+    "adden",
+    "indian",
+    "hu",
+    "jinling noodle bar",
+    "darrys cookhouse and wine shop",
+    "hobson house",
+    "cambridge be",
+    "el shaddai",
+    "ac",
+    "nandos",
+    "cambridge lodge",
+    "the cow pizza kitchen and bar",
+    "charlie",
+    "rajmahal",
+    "kymmoy",
+    "cambri",
+    "backstreet bistro",
+    "galleria",
+    "restaurant 2 two",
+    "chiquito restaurant bar",
+    "royal standard",
+    "lucky star",
+    "curry king",
+    "grafton hotel restaurant",
+    "mahal of cambridge",
+    "the bedouin",
+    "nus",
+    "the kohinoor",
+    "pizza hut fenditton",
+    "camboats",
+    "the gardenia",
+    "de luca cucina and bar",
+    "nusha",
+    "european",
+    "taj tandoori",
+    "tandoori palace",
+    "golden curry",
+    "efes",
+    "loch fyne",
+    "the maharajah tandoor",
+    "lovel",
+    "restaurant 17",
+    "clowns cafe",
+    "cambridge punter",
+    "bloomsbury restaurant",
+    "la mimosa",
+    "the cambridge chop house",
+    "funky",
+    "cotto",
+    "oak bistro",
+    "restaurant two two",
+    "pipasha restaurant",
+    "river bar steakhouse and grill",
+    "royal spice",
+    "the copper kettle",
+    "graffiti",
+    "nandos city centre",
+    "saffron brasserie",
+    "cambridge chop house",
+    "sitar",
+    "kitchen and bar",
+    "the good luck chinese food takeaway",
+    "clu",
+    "la tasca",
+    "cafe uno",
+    "cote",
+    "the varsity restaurant",
+    "bri",
+    "eraina",
+    "bridge",
+    "fin",
+    "cambridge lodge restaurant",
+    "grafton",
+    "hotpot",
+    "sala thong",
+    "margherita",
+    "wise buddha",
+    "the missing sock",
+    "seasame restaurant and bar",
+    "the dojo noodle bar",
+    "restaurant alimentum",
+    "gastropub",
+    "saigon city",
+    "la margherita",
+    "pizza hut",
+    "curry garden",
+    "ashley hotel",
+    "eraina and michaelhouse cafe",
+    "the golden curry",
+    "curry queen",
+    "cow pizza kitchen and bar",
+    "the peking restaurant:",
+    "hamilton lodge",
+    "alimentum",
+    "yippee noodle bar",
+    "2 two and cote",
+    "shanghai family restaurant",
+    "grafton hotel",
+    "yes",
+    "ali baba",
+    "dif",
+    "fitzbillies restaurant",
+    "peking restaurant",
+    "lev",
+    "nirala",
+    "the alex",
+    "tandoori",
+    "city stop restaurant",
+    "rice house",
+    "cityr",
+    "yu garden",
+    "meze bar restaurant",
+    "the",
+    "don pasquale pizzeria",
+    "rice boat",
+    "the hotpot",
+    "old school",
+    "the oak bistro",
+    "sesame restaurant and bar",
+    "pizza express",
+    "the gandhi",
+    "pizza hut fen ditton",
+    "charlie chan",
+    "da vinci pizzeria",
+    "dojo noodle bar",
+    "gourmet burger kitchen",
+    "the golden house",
+    "india house",
+    "hobso",
+    "missing sock",
+    "pizza hut city centre",
+    "parkside pools",
+    "riverside brasserie",
+    "caffe uno",
+    "primavera",
+    "the nirala",
+    "wagamama",
+    "au",
+    "ian hong house",
+    "frankie and bennys",
+    "4 kings parade city centre",
+    "shiraz restaurant",
+    "scudamores punt",
+    "mahal",
+    "saint johns chop house",
+    "de luca cucina and bar riverside brasserie",
+    "cocum",
+    "la raza"
+  ],
+  "attraction-type": [
+    "none",
+    "do not care",
+    "architecture",
+    "boat",
+    "boating",
+    "camboats",
+    "church",
+    "churchills college",
+    "cinema",
+    "college",
+    "concert",
+    "concerthall",
+    "entertainment",
+    "gallery",
+    "gastropub",
+    "hiking",
+    "hotel",
+    "multiple sports",
+    "museum",
+    "museum kettles yard",
+    "night club",
+    "outdoor",
+    "park",
+    "pool",
+    "special",
+    "sports",
+    "swimming pool",
+    "theater",
+    "theatre",
+    "concert hall",
+    "local site",
+    "nightclub",
+    "hotspot",
+    "request"
+  ],
+  "taxi-leave at": [
+    "none",
+    "do not care",
+    "00:00",
+    "00:05",
+    "00:10",
+    "00:15",
+    "00:20",
+    "00:25",
+    "00:30",
+    "00:35",
+    "00:40",
+    "00:45",
+    "00:50",
+    "00:55",
+    "01:00",
+    "01:05",
+    "01:10",
+    "01:15",
+    "01:20",
+    "01:25",
+    "01:30",
+    "01:35",
+    "01:40",
+    "01:45",
+    "01:50",
+    "01:55",
+    "02:00",
+    "02:05",
+    "02:10",
+    "02:15",
+    "02:20",
+    "02:25",
+    "02:30",
+    "02:35",
+    "02:40",
+    "02:45",
+    "02:50",
+    "02:55",
+    "03:00",
+    "03:05",
+    "03:10",
+    "03:15",
+    "03:20",
+    "03:25",
+    "03:30",
+    "03:35",
+    "03:40",
+    "03:45",
+    "03:50",
+    "03:55",
+    "04:00",
+    "04:05",
+    "04:10",
+    "04:15",
+    "04:20",
+    "04:25",
+    "04:30",
+    "04:35",
+    "04:40",
+    "04:45",
+    "04:50",
+    "04:55",
+    "05:00",
+    "05:05",
+    "05:10",
+    "05:15",
+    "05:20",
+    "05:25",
+    "05:30",
+    "05:35",
+    "05:40",
+    "05:45",
+    "05:50",
+    "05:55",
+    "06:00",
+    "06:05",
+    "06:10",
+    "06:15",
+    "06:20",
+    "06:25",
+    "06:30",
+    "06:35",
+    "06:40",
+    "06:45",
+    "06:50",
+    "06:55",
+    "07:00",
+    "07:05",
+    "07:10",
+    "07:15",
+    "07:20",
+    "07:25",
+    "07:30",
+    "07:35",
+    "07:40",
+    "07:45",
+    "07:50",
+    "07:55",
+    "08:00",
+    "08:05",
+    "08:10",
+    "08:15",
+    "08:20",
+    "08:25",
+    "08:30",
+    "08:35",
+    "08:40",
+    "08:45",
+    "08:50",
+    "08:55",
+    "09:00",
+    "09:05",
+    "09:10",
+    "09:15",
+    "09:20",
+    "09:25",
+    "09:30",
+    "09:35",
+    "09:40",
+    "09:45",
+    "09:50",
+    "09:55",
+    "10:00",
+    "10:05",
+    "10:10",
+    "10:15",
+    "10:20",
+    "10:25",
+    "10:30",
+    "10:35",
+    "10:40",
+    "10:45",
+    "10:50",
+    "10:55",
+    "11:00",
+    "11:05",
+    "11:10",
+    "11:15",
+    "11:20",
+    "11:25",
+    "11:30",
+    "11:35",
+    "11:40",
+    "11:45",
+    "11:50",
+    "11:55",
+    "12:00",
+    "12:05",
+    "12:10",
+    "12:15",
+    "12:20",
+    "12:25",
+    "12:30",
+    "12:35",
+    "12:40",
+    "12:45",
+    "12:50",
+    "12:55",
+    "13:00",
+    "13:05",
+    "13:10",
+    "13:15",
+    "13:20",
+    "13:25",
+    "13:30",
+    "13:35",
+    "13:40",
+    "13:45",
+    "13:50",
+    "13:55",
+    "14:00",
+    "14:05",
+    "14:10",
+    "14:15",
+    "14:20",
+    "14:25",
+    "14:30",
+    "14:35",
+    "14:40",
+    "14:45",
+    "14:50",
+    "14:55",
+    "15:00",
+    "15:05",
+    "15:10",
+    "15:15",
+    "15:20",
+    "15:25",
+    "15:30",
+    "15:35",
+    "15:40",
+    "15:45",
+    "15:50",
+    "15:55",
+    "16:00",
+    "16:05",
+    "16:10",
+    "16:15",
+    "16:20",
+    "16:25",
+    "16:30",
+    "16:35",
+    "16:40",
+    "16:45",
+    "16:50",
+    "16:55",
+    "17:00",
+    "17:05",
+    "17:10",
+    "17:15",
+    "17:20",
+    "17:25",
+    "17:30",
+    "17:35",
+    "17:40",
+    "17:45",
+    "17:50",
+    "17:55",
+    "18:00",
+    "18:05",
+    "18:10",
+    "18:15",
+    "18:20",
+    "18:25",
+    "18:30",
+    "18:35",
+    "18:40",
+    "18:45",
+    "18:50",
+    "18:55",
+    "19:00",
+    "19:05",
+    "19:10",
+    "19:15",
+    "19:20",
+    "19:25",
+    "19:30",
+    "19:35",
+    "19:40",
+    "19:45",
+    "19:50",
+    "19:55",
+    "20:00",
+    "20:05",
+    "20:10",
+    "20:15",
+    "20:20",
+    "20:25",
+    "20:30",
+    "20:35",
+    "20:40",
+    "20:45",
+    "20:50",
+    "20:55",
+    "21:00",
+    "21:05",
+    "21:10",
+    "21:15",
+    "21:20",
+    "21:25",
+    "21:30",
+    "21:35",
+    "21:40",
+    "21:45",
+    "21:50",
+    "21:55",
+    "22:00",
+    "22:05",
+    "22:10",
+    "22:15",
+    "22:20",
+    "22:25",
+    "22:30",
+    "22:35",
+    "22:40",
+    "22:45",
+    "22:50",
+    "22:55",
+    "23:00",
+    "23:05",
+    "23:10",
+    "23:15",
+    "23:20",
+    "23:25",
+    "23:30",
+    "23:35",
+    "23:40",
+    "23:45",
+    "23:50",
+    "23:55",
+    "request"
+  ],
+  "taxi-destination": [
+    "none",
+    "do not care",
+    "a and b guest house",
+    "abbey pool and astroturf pitch",
+    "acorn guest house",
+    "adc theatre",
+    "addenbrookes hospital",
+    "alexander bed and breakfast",
+    "ali baba",
+    "all saints church",
+    "allenbell",
+    "alpha milton guest house",
+    "anatolia",
+    "arbury lodge guesthouse",
+    "archway house",
+    "ashley hotel",
+    "ask",
+    "attraction",
+    "autumn house",
+    "avalon",
+    "aylesbray lodge guest house",
+    "backstreet bistro",
+    "ballare",
+    "bangkok city",
+    "bedouin",
+    "birmingham new street train station",
+    "bishops stortford train station",
+    "bloomsbury restaurant",
+    "bridge guest house",
+    "broughton house gallery",
+    "broxbourne train station",
+    "byard art",
+    "cafe jello gallery",
+    "cafe uno",
+    "camboats",
+    "cambridge",
+    "cambridge and county folk museum",
+    "cambridge arts theatre",
+    "cambridge artworks",
+    "cambridge belfry",
+    "cambridge book and print gallery",
+    "cambridge chop house",
+    "cambridge contemporary art",
+    "cambridge county fair next to the city tourist museum",
+    "cambridge lodge restaurant",
+    "cambridge museum of technology",
+    "cambridge punter",
+    "cambridge road church of christ",
+    "cambridge train station",
+    "cambridge university botanic gardens",
+    "carolina bed and breakfast",
+    "castle galleries",
+    "charlie chan",
+    "cherry hinton hall and grounds",
+    "cherry hinton village centre",
+    "cherry hinton water park",
+    "cherry hinton water play",
+    "chiquito restaurant bar",
+    "christ college",
+    "churchills college",
+    "cineworld cinema",
+    "city centre north bed and breakfast",
+    "city stop restaurant",
+    "cityroomz",
+    "clare college",
+    "clare hall",
+    "clowns cafe",
+    "club salsa",
+    "cocum",
+    "copper kettle",
+    "corpus christi",
+    "cote",
+    "cotto",
+    "cow pizza kitchen and bar",
+    "curry garden",
+    "curry king",
+    "curry prince",
+    "da vinci pizzeria",
+    "darrys cookhouse and wine shop",
+    "de luca cucina and bar",
+    "dojo noodle bar",
+    "don pasquale pizzeria",
+    "downing college",
+    "efes restaurant",
+    "el shaddia guesthouse",
+    "ely train station",
+    "emmanuel college",
+    "eraina",
+    "express by holiday inn cambridge",
+    "finches bed and breakfast",
+    "finders corner newmarket road",
+    "fitzbillies restaurant",
+    "fitzwilliam museum",
+    "frankie and bennys",
+    "funky fun house",
+    "galleria",
+    "gallery at 12 a high street",
+    "gastropub",
+    "golden curry",
+    "golden house",
+    "golden wok",
+    "gonville and caius college",
+    "gonville hotel",
+    "good luck",
+    "gourmet burger kitchen",
+    "graffiti",
+    "grafton hotel restaurant",
+    "great saint marys church",
+    "hakka",
+    "hamilton lodge",
+    "hk fusion",
+    "hobsons house",
+    "holy trinity church",
+    "home from home",
+    "hotel du vin and bistro",
+    "hughes hall",
+    "huntingdon marriott hotel",
+    "ian hong",
+    "india house",
+    "j restaurant",
+    "jesus college",
+    "jesus green outdoor pool",
+    "jinling noodle bar",
+    "kambar",
+    "kettles yard",
+    "kings college",
+    "kings hedges learner pool",
+    "kirkwood house",
+    "kohinoor",
+    "kymmoy",
+    "la margherita",
+    "la mimosa",
+    "la raza",
+    "la tasca",
+    "lan hong house",
+    "leicester train station",
+    "lensfield hotel",
+    "limehouse",
+    "little saint marys church",
+    "little seoul",
+    "loch fyne",
+    "london kings cross train station",
+    "london liverpool street train station",
+    "lovell lodge",
+    "lynne strover gallery",
+    "magdalene college",
+    "mahal of cambridge",
+    "maharajah tandoori restaurant",
+    "meghna",
+    "meze bar",
+    "michaelhouse cafe",
+    "midsummer house restaurant",
+    "milton country park",
+    "mumford theatre",
+    "museum of archaelogy and anthropology",
+    "museum of classical archaeology",
+    "nandos",
+    "nandos city centre",
+    "nil",
+    "nirala",
+    "norwich train station",
+    "nusha",
+    "old schools",
+    "panahar",
+    "parkside police station",
+    "parkside pools",
+    "peking restaurant",
+    "pembroke college",
+    "peoples portraits exhibition at girton college",
+    "peterborough train station",
+    "pipasha restaurant",
+    "pizza express",
+    "pizza hut cherry hinton",
+    "pizza hut city centre",
+    "pizza hut fenditton",
+    "prezzo",
+    "primavera",
+    "queens college",
+    "rajmahal",
+    "regency gallery",
+    "restaurant 17",
+    "restaurant 2 two",
+    "restaurant alimentum",
+    "rice boat",
+    "rice house",
+    "riverboat georgina",
+    "riverside brasserie",
+    "rosas bed and breakfast",
+    "royal spice",
+    "royal standard",
+    "ruskin gallery",
+    "saffron brasserie",
+    "saigon city",
+    "saint barnabas",
+    "saint barnabas press gallery",
+    "saint catharines college",
+    "saint johns chop house",
+    "saint johns college",
+    "sala thong",
+    "scott polar museum",
+    "scudamores punting co",
+    "sesame restaurant and bar",
+    "shanghai family restaurant",
+    "sheeps green and lammas land park fen causeway",
+    "shiraz",
+    "sidney sussex college",
+    "sitar tandoori",
+    "sleeperz hotel",
+    "soul tree nightclub",
+    "st johns chop house",
+    "stansted airport train station",
+    "station road",
+    "stazione restaurant and coffee bar",
+    "stevenage train station",
+    "taj tandoori",
+    "tall monument",
+    "tandoori palace",
+    "tang chinese",
+    "tenpin",
+    "thanh binh",
+    "the anatolia",
+    "the cambridge corn exchange",
+    "the cambridge shop",
+    "the fez club",
+    "the gandhi",
+    "the gardenia",
+    "the hotpot",
+    "the junction",
+    "the lucky star",
+    "the man on the moon",
+    "the missing sock",
+    "the oak bistro",
+    "the place",
+    "the regent street city center",
+    "the river bar steakhouse and grill",
+    "the slug and lettuce",
+    "the varsity restaurant",
+    "travellers rest",
+    "trinity college",
+    "ugly duckling",
+    "university arms hotel",
+    "vue cinema",
+    "wagamama",
+    "wandlebury country park",
+    "wankworth hotel",
+    "warkworth house",
+    "whale of a time",
+    "whipple museum of the history of science",
+    "williams art and antiques",
+    "worth house",
+    "yippee noodle bar",
+    "yu garden",
+    "zizzi cambridge",
+    "leverton house",
+    "the cambridge chop house",
+    "saint john's college",
+    "churchill college",
+    "the nirala",
+    "the cow pizza kitchen and bar",
+    "christ's college",
+    "el shaddai",
+    "saint catharine's college",
+    "camb",
+    "the golden curry",
+    "little saint mary's church",
+    "country folk museum",
+    "meze bar restaurant",
+    "the cambridge belfry",
+    "the fitzwilliam museum",
+    "the lensfield hotel",
+    "pizza express fen ditton",
+    "the cambridge punter",
+    "king's college",
+    "the cherry hinton village centre",
+    "shiraz restaurant",
+    "sheep's green and lammas land park fen causeway",
+    "caffe uno",
+    "the ghandi",
+    "the copper kettle",
+    "man on the moon concert hall",
+    "alpha-milton guest house",
+    "queen's college",
+    "restaurant one seven",
+    "restaurant two two",
+    "city centre north b and b",
+    "rosa's bed and breakfast",
+    "the good luck chinese food takeaway",
+    "not museum of archaeology and anthropologymentioned",
+    "tandori in cambridge",
+    "kettle's yard",
+    "megna",
+    "grou",
+    "gallery at twelve a high street",
+    "maharajah tandoori restaurant",
+    "pizza hut fen ditton",
+    "gandhi",
+    "tranh binh",
+    "kambur",
+    "people's portraits exhibition at girton college",
+    "hotel",
+    "restaurant",
+    "the galleria",
+    "queens' college",
+    "great saint mary's church",
+    "theathre",
+    "cambridge artworks",
+    "acorn house",
+    "shiraz",
+    "riverboat georginawd",
+    "mic",
+    "the gallery at twelve",
+    "the soul tree",
+    "finches"
+  ],
+  "taxi-departure": [
+    "none",
+    "do not care",
+    "172 chestertown road",
+    "4455 woodbridge road",
+    "a and b guest house",
+    "abbey pool and astroturf pitch",
+    "acorn guest house",
+    "adc theatre",
+    "addenbrookes hospital",
+    "alexander bed and breakfast",
+    "ali baba",
+    "all saints church",
+    "allenbell",
+    "alpha milton guest house",
+    "alyesbray lodge hotel",
+    "ambridge",
+    "anatolia",
+    "arbury lodge guesthouse",
+    "archway house",
+    "ashley hotel",
+    "ask",
+    "autumn house",
+    "avalon",
+    "aylesbray lodge guest house",
+    "backstreet bistro",
+    "ballare",
+    "bangkok city",
+    "bedouin",
+    "birmingham new street train station",
+    "bishops stortford train station",
+    "bloomsbury restaurant",
+    "bridge guest house",
+    "broughton house gallery",
+    "broxbourne train station",
+    "byard art",
+    "cafe jello gallery",
+    "cafe uno",
+    "caffee uno",
+    "camboats",
+    "cambridge",
+    "cambridge and county folk museum",
+    "cambridge arts theatre",
+    "cambridge artworks",
+    "cambridge belfry",
+    "cambridge book and print gallery",
+    "cambridge chop house",
+    "cambridge contemporary art",
+    "cambridge lodge restaurant",
+    "cambridge museum of technology",
+    "cambridge punter",
+    "cambridge towninfo centre",
+    "cambridge train station",
+    "cambridge university botanic gardens",
+    "carolina bed and breakfast",
+    "castle galleries",
+    "centre of town at my hotel",
+    "charlie chan",
+    "cherry hinton hall and grounds",
+    "cherry hinton village center",
+    "cherry hinton village centre",
+    "cherry hinton water play",
+    "chiquito restaurant bar",
+    "christ college",
+    "churchills college",
+    "cineworld cinema",
+    "citiroomz",
+    "city centre north bed and breakfast",
+    "city stop restaurant",
+    "cityroomz",
+    "clair hall",
+    "clare college",
+    "clare hall",
+    "clowns cafe",
+    "club salsa",
+    "cocum",
+    "copper kettle",
+    "corpus christi",
+    "cote",
+    "cotto",
+    "cow pizza kitchen and bar",
+    "curry garden",
+    "curry king",
+    "curry prince",
+    "curry queen",
+    "da vinci pizzeria",
+    "darrys cookhouse and wine shop",
+    "de luca cucina and bar",
+    "dojo noodle bar",
+    "don pasquale pizzeria",
+    "downing college",
+    "downing street",
+    "el shaddia guesthouse",
+    "ely",
+    "ely train station",
+    "emmanuel college",
+    "eraina",
+    "express by holiday inn cambridge",
+    "finches bed and breakfast",
+    "fitzbillies restaurant",
+    "fitzwilliam museum",
+    "frankie and bennys",
+    "funky fun house",
+    "galleria",
+    "gallery at 12 a high street",
+    "girton college",
+    "golden curry",
+    "golden house",
+    "golden wok",
+    "gonville and caius college",
+    "gonville hotel",
+    "good luck",
+    "gourmet burger kitchen",
+    "graffiti",
+    "grafton hotel restaurant",
+    "great saint marys church",
+    "hakka",
+    "hamilton lodge",
+    "hobsons house",
+    "holy trinity church",
+    "home",
+    "home from home",
+    "hotel",
+    "hotel du vin and bistro",
+    "hughes hall",
+    "huntingdon marriott hotel",
+    "india house",
+    "j restaurant",
+    "jesus college",
+    "jesus green outdoor pool",
+    "jinling noodle bar",
+    "junction theatre",
+    "kambar",
+    "kettles yard",
+    "kings college",
+    "kings hedges learner pool",
+    "kings lynn train station",
+    "kirkwood house",
+    "kohinoor",
+    "kymmoy",
+    "la margherita",
+    "la mimosa",
+    "la raza",
+    "la tasca",
+    "lan hong house",
+    "lensfield hotel",
+    "leverton house",
+    "limehouse",
+    "little saint marys church",
+    "little seoul",
+    "loch fyne",
+    "london kings cross train station",
+    "london liverpool street",
+    "london liverpool street train station",
+    "lovell lodge",
+    "lynne strover gallery",
+    "magdalene college",
+    "mahal of cambridge",
+    "maharajah tandoori restaurant",
+    "meghna",
+    "meze bar",
+    "michaelhouse cafe",
+    "milton country park",
+    "mumford theatre",
+    "museum",
+    "museum of archaelogy and anthropology",
+    "museum of classical archaeology",
+    "nandos",
+    "nandos city centre",
+    "new england",
+    "nirala",
+    "norwich train station",
+    "nstaot mentioned",
+    "nusha",
+    "old schools",
+    "panahar",
+    "parkside police station",
+    "parkside pools",
+    "peking restaurant",
+    "pembroke college",
+    "peoples portraits exhibition at girton college",
+    "peterborough train station",
+    "pizza express",
+    "pizza hut cherry hinton",
+    "pizza hut city centre",
+    "pizza hut fenditton",
+    "prezzo",
+    "primavera",
+    "queens college",
+    "rajmahal",
+    "regency gallery",
+    "restaurant 17",
+    "restaurant 2 two",
+    "restaurant alimentum",
+    "rice boat",
+    "rice house",
+    "riverboat georgina",
+    "riverside brasserie",
+    "rosas bed and breakfast",
+    "royal spice",
+    "royal standard",
+    "ruskin gallery",
+    "saffron brasserie",
+    "saigon city",
+    "saint barnabas press gallery",
+    "saint catharines college",
+    "saint johns chop house",
+    "saint johns college",
+    "sala thong",
+    "scott polar museum",
+    "scudamores punting co",
+    "sesame restaurant and bar",
+    "sheeps green and lammas land park",
+    "sheeps green and lammas land park fen causeway",
+    "shiraz",
+    "sidney sussex college",
+    "sitar tandoori",
+    "soul tree nightclub",
+    "st johns college",
+    "stazione restaurant and coffee bar",
+    "stevenage train station",
+    "taj tandoori",
+    "tandoori palace",
+    "tang chinese",
+    "tenpin",
+    "thanh binh",
+    "the cambridge corn exchange",
+    "the fez club",
+    "the gallery at 12",
+    "the gandhi",
+    "the gardenia",
+    "the hotpot",
+    "the junction",
+    "the lucky star",
+    "the man on the moon",
+    "the missing sock",
+    "the oak bistro",
+    "the place",
+    "the river bar steakhouse and grill",
+    "the slug and lettuce",
+    "the varsity restaurant",
+    "travellers rest",
+    "trinity college",
+    "ugly duckling",
+    "university arms hotel",
+    "vue cinema",
+    "wagamama",
+    "wandlebury country park",
+    "warkworth house",
+    "whale of a time",
+    "whipple museum of the history of science",
+    "williams art and antiques",
+    "worth house",
+    "yippee noodle bar",
+    "yu garden",
+    "zizzi cambridge",
+    "christ's college",
+    "city centre north b and b",
+    "the lensfield hotel",
+    "alpha-milton guest house",
+    "el shaddai",
+    "churchill college",
+    "the cambridge belfry",
+    "king's college",
+    "great saint mary's church",
+    "restaurant two two",
+    "queens' college",
+    "little saint mary's church",
+    "chinese city centre",
+    "kettle's yard",
+    "pizza hut",
+    "the golden curry",
+    "rosa's bed and breakfast",
+    "the cambridge punter",
+    "the byard art museum",
+    "saint catharine's college",
+    "meze bar restaurant",
+    "the good luck chinese food takeaway",
+    "restaurant one seven",
+    "pizza hut fen ditton",
+    "the nirala",
+    "the fitzwilliam museum",
+    "st. john's college",
+    "gallery at twelve a high street",
+    "sheep's green and lammas land park fen causeway",
+    "the cherry hinton village centre",
+    "pizza express fen ditton",
+    "corpus cristi",
+    "cas",
+    "acorn house",
+    "lens",
+    "the cambridge chop house",
+    "the copper kettle",
+    "the avalon",
+    "saint john's college",
+    "aylesbray lodge",
+    "the alexander bed and breakfast",
+    "cambridge belfy",
+    "people's portraits exhibition at girton college",
+    "gonville",
+    "caffe uno",
+    "the cow pizza kitchen and bar",
+    "lovell ldoge",
+    "cinema",
+    "shiraz restaurant",
+    "park",
+    "the allenbell"
+  ],
+  "restaurant-book day": [
+    "none",
+    "do not care",
+    "friday",
+    "monday",
+    "saterday",
+    "sunday",
+    "thursday",
+    "tuesday",
+    "wednesday"
+  ],
+  "restaurant-book people": [
+    "none",
+    "do not care",
+    "1",
+    "10 or more",
+    "2",
+    "3",
+    "4",
+    "5",
+    "6",
+    "7",
+    "8",
+    "9"
+  ],
+  "restaurant-book time": [
+    "none",
+    "do not care",
+    "00:00",
+    "00:05",
+    "00:10",
+    "00:15",
+    "00:20",
+    "00:25",
+    "00:30",
+    "00:35",
+    "00:40",
+    "00:45",
+    "00:50",
+    "00:55",
+    "01:00",
+    "01:05",
+    "01:10",
+    "01:15",
+    "01:20",
+    "01:25",
+    "01:30",
+    "01:35",
+    "01:40",
+    "01:45",
+    "01:50",
+    "01:55",
+    "02:00",
+    "02:05",
+    "02:10",
+    "02:15",
+    "02:20",
+    "02:25",
+    "02:30",
+    "02:35",
+    "02:40",
+    "02:45",
+    "02:50",
+    "02:55",
+    "03:00",
+    "03:05",
+    "03:10",
+    "03:15",
+    "03:20",
+    "03:25",
+    "03:30",
+    "03:35",
+    "03:40",
+    "03:45",
+    "03:50",
+    "03:55",
+    "04:00",
+    "04:05",
+    "04:10",
+    "04:15",
+    "04:20",
+    "04:25",
+    "04:30",
+    "04:35",
+    "04:40",
+    "04:45",
+    "04:50",
+    "04:55",
+    "05:00",
+    "05:05",
+    "05:10",
+    "05:15",
+    "05:20",
+    "05:25",
+    "05:30",
+    "05:35",
+    "05:40",
+    "05:45",
+    "05:50",
+    "05:55",
+    "06:00",
+    "06:05",
+    "06:10",
+    "06:15",
+    "06:20",
+    "06:25",
+    "06:30",
+    "06:35",
+    "06:40",
+    "06:45",
+    "06:50",
+    "06:55",
+    "07:00",
+    "07:05",
+    "07:10",
+    "07:15",
+    "07:20",
+    "07:25",
+    "07:30",
+    "07:35",
+    "07:40",
+    "07:45",
+    "07:50",
+    "07:55",
+    "08:00",
+    "08:05",
+    "08:10",
+    "08:15",
+    "08:20",
+    "08:25",
+    "08:30",
+    "08:35",
+    "08:40",
+    "08:45",
+    "08:50",
+    "08:55",
+    "09:00",
+    "09:05",
+    "09:10",
+    "09:15",
+    "09:20",
+    "09:25",
+    "09:30",
+    "09:35",
+    "09:40",
+    "09:45",
+    "09:50",
+    "09:55",
+    "10:00",
+    "10:05",
+    "10:10",
+    "10:15",
+    "10:20",
+    "10:25",
+    "10:30",
+    "10:35",
+    "10:40",
+    "10:45",
+    "10:50",
+    "10:55",
+    "11:00",
+    "11:05",
+    "11:10",
+    "11:15",
+    "11:20",
+    "11:25",
+    "11:30",
+    "11:35",
+    "11:40",
+    "11:45",
+    "11:50",
+    "11:55",
+    "12:00",
+    "12:05",
+    "12:10",
+    "12:15",
+    "12:20",
+    "12:25",
+    "12:30",
+    "12:35",
+    "12:40",
+    "12:45",
+    "12:50",
+    "12:55",
+    "13:00",
+    "13:05",
+    "13:10",
+    "13:15",
+    "13:20",
+    "13:25",
+    "13:30",
+    "13:35",
+    "13:40",
+    "13:45",
+    "13:50",
+    "13:55",
+    "14:00",
+    "14:05",
+    "14:10",
+    "14:15",
+    "14:20",
+    "14:25",
+    "14:30",
+    "14:35",
+    "14:40",
+    "14:45",
+    "14:50",
+    "14:55",
+    "15:00",
+    "15:05",
+    "15:10",
+    "15:15",
+    "15:20",
+    "15:25",
+    "15:30",
+    "15:35",
+    "15:40",
+    "15:45",
+    "15:50",
+    "15:55",
+    "16:00",
+    "16:05",
+    "16:10",
+    "16:15",
+    "16:20",
+    "16:25",
+    "16:30",
+    "16:35",
+    "16:40",
+    "16:45",
+    "16:50",
+    "16:55",
+    "17:00",
+    "17:05",
+    "17:10",
+    "17:15",
+    "17:20",
+    "17:25",
+    "17:30",
+    "17:35",
+    "17:40",
+    "17:45",
+    "17:50",
+    "17:55",
+    "18:00",
+    "18:05",
+    "18:10",
+    "18:15",
+    "18:20",
+    "18:25",
+    "18:30",
+    "18:35",
+    "18:40",
+    "18:45",
+    "18:50",
+    "18:55",
+    "19:00",
+    "19:05",
+    "19:10",
+    "19:15",
+    "19:20",
+    "19:25",
+    "19:30",
+    "19:35",
+    "19:40",
+    "19:45",
+    "19:50",
+    "19:55",
+    "20:00",
+    "20:05",
+    "20:10",
+    "20:15",
+    "20:20",
+    "20:25",
+    "20:30",
+    "20:35",
+    "20:40",
+    "20:45",
+    "20:50",
+    "20:55",
+    "21:00",
+    "21:05",
+    "21:10",
+    "21:15",
+    "21:20",
+    "21:25",
+    "21:30",
+    "21:35",
+    "21:40",
+    "21:45",
+    "21:50",
+    "21:55",
+    "22:00",
+    "22:05",
+    "22:10",
+    "22:15",
+    "22:20",
+    "22:25",
+    "22:30",
+    "22:35",
+    "22:40",
+    "22:45",
+    "22:50",
+    "22:55",
+    "23:00",
+    "23:05",
+    "23:10",
+    "23:15",
+    "23:20",
+    "23:25",
+    "23:30",
+    "23:35",
+    "23:40",
+    "23:45",
+    "23:50",
+    "23:55"
+  ],
+  "taxi-arrive by": [
+    "none",
+    "do not care",
+    "00:00",
+    "00:05",
+    "00:10",
+    "00:15",
+    "00:20",
+    "00:25",
+    "00:30",
+    "00:35",
+    "00:40",
+    "00:45",
+    "00:50",
+    "00:55",
+    "01:00",
+    "01:05",
+    "01:10",
+    "01:15",
+    "01:20",
+    "01:25",
+    "01:30",
+    "01:35",
+    "01:40",
+    "01:45",
+    "01:50",
+    "01:55",
+    "02:00",
+    "02:05",
+    "02:10",
+    "02:15",
+    "02:20",
+    "02:25",
+    "02:30",
+    "02:35",
+    "02:40",
+    "02:45",
+    "02:50",
+    "02:55",
+    "03:00",
+    "03:05",
+    "03:10",
+    "03:15",
+    "03:20",
+    "03:25",
+    "03:30",
+    "03:35",
+    "03:40",
+    "03:45",
+    "03:50",
+    "03:55",
+    "04:00",
+    "04:05",
+    "04:10",
+    "04:15",
+    "04:20",
+    "04:25",
+    "04:30",
+    "04:35",
+    "04:40",
+    "04:45",
+    "04:50",
+    "04:55",
+    "05:00",
+    "05:05",
+    "05:10",
+    "05:15",
+    "05:20",
+    "05:25",
+    "05:30",
+    "05:35",
+    "05:40",
+    "05:45",
+    "05:50",
+    "05:55",
+    "06:00",
+    "06:05",
+    "06:10",
+    "06:15",
+    "06:20",
+    "06:25",
+    "06:30",
+    "06:35",
+    "06:40",
+    "06:45",
+    "06:50",
+    "06:55",
+    "07:00",
+    "07:05",
+    "07:10",
+    "07:15",
+    "07:20",
+    "07:25",
+    "07:30",
+    "07:35",
+    "07:40",
+    "07:45",
+    "07:50",
+    "07:55",
+    "08:00",
+    "08:05",
+    "08:10",
+    "08:15",
+    "08:20",
+    "08:25",
+    "08:30",
+    "08:35",
+    "08:40",
+    "08:45",
+    "08:50",
+    "08:55",
+    "09:00",
+    "09:05",
+    "09:10",
+    "09:15",
+    "09:20",
+    "09:25",
+    "09:30",
+    "09:35",
+    "09:40",
+    "09:45",
+    "09:50",
+    "09:55",
+    "10:00",
+    "10:05",
+    "10:10",
+    "10:15",
+    "10:20",
+    "10:25",
+    "10:30",
+    "10:35",
+    "10:40",
+    "10:45",
+    "10:50",
+    "10:55",
+    "11:00",
+    "11:05",
+    "11:10",
+    "11:15",
+    "11:20",
+    "11:25",
+    "11:30",
+    "11:35",
+    "11:40",
+    "11:45",
+    "11:50",
+    "11:55",
+    "12:00",
+    "12:05",
+    "12:10",
+    "12:15",
+    "12:20",
+    "12:25",
+    "12:30",
+    "12:35",
+    "12:40",
+    "12:45",
+    "12:50",
+    "12:55",
+    "13:00",
+    "13:05",
+    "13:10",
+    "13:15",
+    "13:20",
+    "13:25",
+    "13:30",
+    "13:35",
+    "13:40",
+    "13:45",
+    "13:50",
+    "13:55",
+    "14:00",
+    "14:05",
+    "14:10",
+    "14:15",
+    "14:20",
+    "14:25",
+    "14:30",
+    "14:35",
+    "14:40",
+    "14:45",
+    "14:50",
+    "14:55",
+    "15:00",
+    "15:05",
+    "15:10",
+    "15:15",
+    "15:20",
+    "15:25",
+    "15:30",
+    "15:35",
+    "15:40",
+    "15:45",
+    "15:50",
+    "15:55",
+    "16:00",
+    "16:05",
+    "16:10",
+    "16:15",
+    "16:20",
+    "16:25",
+    "16:30",
+    "16:35",
+    "16:40",
+    "16:45",
+    "16:50",
+    "16:55",
+    "17:00",
+    "17:05",
+    "17:10",
+    "17:15",
+    "17:20",
+    "17:25",
+    "17:30",
+    "17:35",
+    "17:40",
+    "17:45",
+    "17:50",
+    "17:55",
+    "18:00",
+    "18:05",
+    "18:10",
+    "18:15",
+    "18:20",
+    "18:25",
+    "18:30",
+    "18:35",
+    "18:40",
+    "18:45",
+    "18:50",
+    "18:55",
+    "19:00",
+    "19:05",
+    "19:10",
+    "19:15",
+    "19:20",
+    "19:25",
+    "19:30",
+    "19:35",
+    "19:40",
+    "19:45",
+    "19:50",
+    "19:55",
+    "20:00",
+    "20:05",
+    "20:10",
+    "20:15",
+    "20:20",
+    "20:25",
+    "20:30",
+    "20:35",
+    "20:40",
+    "20:45",
+    "20:50",
+    "20:55",
+    "21:00",
+    "21:05",
+    "21:10",
+    "21:15",
+    "21:20",
+    "21:25",
+    "21:30",
+    "21:35",
+    "21:40",
+    "21:45",
+    "21:50",
+    "21:55",
+    "22:00",
+    "22:05",
+    "22:10",
+    "22:15",
+    "22:20",
+    "22:25",
+    "22:30",
+    "22:35",
+    "22:40",
+    "22:45",
+    "22:50",
+    "22:55",
+    "23:00",
+    "23:05",
+    "23:10",
+    "23:15",
+    "23:20",
+    "23:25",
+    "23:30",
+    "23:35",
+    "23:40",
+    "23:45",
+    "23:50",
+    "23:55",
+    "request"
+  ],
+  "restaurant-area": [
+    "none",
+    "do not care",
+    "centre",
+    "east",
+    "north",
+    "south",
+    "west",
+    "request"
+  ],
+  "hotel-area": [
+    "none",
+    "do not care",
+    "centre",
+    "east",
+    "north",
+    "south",
+    "west",
+    "request"
+  ],
+  "attraction-area": [
+    "none",
+    "do not care",
+    "centre",
+    "east",
+    "north",
+    "south",
+    "west",
+    "request"
+  ],
+  "hospital-department": [
+    "none",
+    "do not care",
+    "acute medical assessment unit",
+    "acute medicine for the elderly",
+    "antenatal",
+    "cambridge eye unit",
+    "cardiology",
+    "cardiology and coronary care unit",
+    "childrens oncology and haematology",
+    "childrens surgical and medicine",
+    "clinical decisions unit",
+    "clinical research facility",
+    "coronary care unit",
+    "diabetes and endocrinology",
+    "emergency department",
+    "gastroenterology",
+    "gynaecology",
+    "haematology",
+    "haematology and haematological oncology",
+    "haematology day unit",
+    "hepatobillary and gastrointestinal surgery regional referral centre",
+    "hepatology",
+    "infectious diseases",
+    "infusion services",
+    "inpatient occupational therapy",
+    "intermediate dependancy area",
+    "john farman intensive care unit",
+    "medical decisions unit",
+    "medicine for the elderly",
+    "neonatal unit",
+    "neurology",
+    "neurology neurosurgery",
+    "neurosciences",
+    "neurosciences critical care unit",
+    "oncology",
+    "oral and maxillofacial surgery and ent",
+    "paediatric clinic",
+    "paediatric day unit",
+    "paediatric intensive care unit",
+    "plastic and vascular surgery plastics",
+    "psychiatry",
+    "respiratory medicine",
+    "surgery",
+    "teenage cancer trust unit",
+    "transitional care",
+    "transplant high dependency unit",
+    "trauma and orthopaedics",
+    "trauma high dependency unit",
+    "urology"
+  ],
+  "police-postcode": [
+    "request"
+  ],
+  "restaurant-postcode": [
+    "request"
+  ],
+  "train-duration": [
+    "request"
+  ],
+  "train-trainid": [
+    "request"
+  ],
+  "hospital-address": [
+    "request"
+  ],
+  "restaurant-phone": [
+    "request"
+  ],
+  "hotel-phone": [
+    "request"
+  ],
+  "restaurant-address": [
+    "request"
+  ],
+  "hotel-postcode": [
+    "request"
+  ],
+  "attraction-phone": [
+    "request"
+  ],
+  "attraction-entrance fee": [
+    "request"
+  ],
+  "hotel-reference": [
+    "request"
+  ],
+  "taxi-taxi types": [
+    "request"
+  ],
+  "attraction-address": [
+    "request"
+  ],
+  "hospital-phone": [
+    "request"
+  ],
+  "attraction-postcode": [
+    "request"
+  ],
+  "police-address": [
+    "request"
+  ],
+  "taxi-taxi phone": [
+    "request"
+  ],
+  "train-price": [
+    "request"
+  ],
+  "hospital-postcode": [
+    "request"
+  ],
+  "police-phone": [
+    "request"
+  ],
+  "hotel-address": [
+    "request"
+  ],
+  "restaurant-reference": [
+    "request"
+  ],
+  "train-reference": [
+    "request"
+  ]
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json
new file mode 100644
index 0000000000000000000000000000000000000000..87e315363ad3dfd8cadd5bd10cfd7e5047450160
--- /dev/null
+++ b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json
@@ -0,0 +1,57 @@
+{
+  "hotel-price range": "preferred cost or price of the hotel",
+  "hotel-type": "what is the type of the hotel",
+  "hotel-parking": "does the hotel have parking",
+  "hotel-book stay": "number of nights for the hotel reservation",
+  "hotel-book day": "starting day of the hotel booking",
+  "hotel-book people": "number of people for the hotel booking",
+  "hotel-area": "area or place of the hotel",
+  "hotel-stars": "star rating of the hotel",
+  "hotel-internet": "does the hotel have internet or wifi",
+  "hotel-name": "name of the hotel",
+  "hotel-phone": "phone number of the hotel",
+  "hotel-postcode": "postcode of the hotel",
+  "hotel-reference": "booking reference of the hotel booking",
+  "hotel-address": "street address of the hotel",
+  "train-destination": "train station you want to travel to",
+  "train-day": "day of the train booking",
+  "train-departure": "train station you want to leave from",
+  "train-arrive by": "arrival time of the train",
+  "train-book people": "number of people for the train booking",
+  "train-leave at": "departure time for the train",
+  "train-duration": "duration of the train journey",
+  "train-trainid": "train identifier or number",
+  "train-price": "how much does the train trip cost",
+  "train-reference": "booking reference of the train booking",
+  "attraction-type": "type of attraction or point of interest",
+  "attraction-area": "area or place of the attraction",
+  "attraction-name": "name of the attraction",
+  "attraction-phone": "phone number of the attraction",
+  "attraction-entrance fee": "entrance fee at the attraction",
+  "attraction-address": "street address of the attraction",
+  "attraction-postcode": "postcode of the attraction",
+  "restaurant-book people": "number of people for the restaurant booking",
+  "restaurant-book day": "weekday for the restaurant booking",
+  "restaurant-book time": "time of the restaurant booking",
+  "restaurant-food": "type of food served at the restaurant",
+  "restaurant-price range": "preferred cost or price of the restaurant",
+  "restaurant-name": "name of the restaurant",
+  "restaurant-area": "area or place of the restaurant",
+  "restaurant-postcode": "postcode of the restaurant",
+  "restaurant-phone": "phone number of the restaurant",
+  "restaurant-address": "street address of the restaurant",
+  "restaurant-reference": "booking reference of the restaurant booking",
+  "taxi-leave at": "what time you want the taxi to leave by",
+  "taxi-destination": "where you want the taxi to drop you off",
+  "taxi-departure": "where you want the taxi to pick you up",
+  "taxi-arrive by": "what time you want to arrive at your destination",
+  "taxi-taxi types": "vehicle type of the taxi",
+  "taxi-taxi phone": "phone number of the taxi",
+  "hospital-department": "name of hospital department",
+  "hospital-address": "street address of the hospital",
+  "hospital-phone": "phone number of the hospital",
+  "hospital-postcode": "postcode of the hospital",
+  "police-postcode": "postcode of the police station",
+  "police-address": "street address of the police station",
+  "police-phone": "phone number of the police station"
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/ontology.py b/convlab/dst/setsumbt/multiwoz/dataset/ontology.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6b9c3365764eb8f1c1cab5d68674dd85b39ce2a
--- /dev/null
+++ b/convlab/dst/setsumbt/multiwoz/dataset/ontology.py
@@ -0,0 +1,168 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Create Ontology Embeddings"""
+
+import json
+import os
+import random
+
+import torch
+import numpy as np
+
+
+# Slot mapping table for description extractions
+# SLOT_NAME_MAPPINGS = {
+#     'arrive at': 'arriveAt',
+#     'arrive by': 'arriveBy',
+#     'leave at': 'leaveAt',
+#     'leave by': 'leaveBy',
+#     'arriveby': 'arriveBy',
+#     'arriveat': 'arriveAt',
+#     'leaveat': 'leaveAt',
+#     'leaveby': 'leaveBy',
+#     'price range': 'pricerange'
+# }
+
+# Set up global data directory
+def set_datadir(dir):
+    """Store ``dir`` in the module-level ``DATA_DIR`` global.
+
+    ``get_slot_candidate_embeddings`` joins its input file names onto ``DATA_DIR``
+    and nothing else in this module assigns it, so call this function first.
+    """
+    global DATA_DIR
+    DATA_DIR = dir
+
+
+# Set seeds
+def set_seed(args):
+    """Seed the Python, NumPy and PyTorch random number generators.
+
+    Args:
+        args: object exposing ``seed`` (the seed value) and ``n_gpu``; when
+            ``n_gpu`` is positive the CUDA generators are seeded as well.
+    """
+    seed = args.seed
+    for seeder in (random.seed, np.random.seed, torch.manual_seed):
+        seeder(seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(seed)
+
+
+# Get embeddings for slots and candidates
+def get_slot_candidate_embeddings(set_type, args, tokenizer, embedding_model, save_to_file=True):
+    """Embed every slot of an ontology together with its candidate values.
+
+    Reads ``ontology_<set_type>.json`` and ``slot_descriptions.json`` from the
+    module-level ``DATA_DIR`` (see ``set_datadir``), switches ``embedding_model``
+    to eval mode and runs it without gradients.
+
+    Args:
+        set_type: name of the data split, used in the ontology and output file names.
+        args: configuration object; the attributes read here are ``use_descriptions``,
+            ``max_slot_len``, ``max_candidate_len``, ``set_similarity``,
+            ``candidate_pooling`` and, when saving, ``output_dir``.
+        tokenizer: tokenizer providing ``encode_plus``.
+        embedding_model: model exposing ``device`` and ``eval()`` whose call returns
+            a ``(token_features, pooled_features)`` pair.
+        save_to_file: when true, the result is also written with ``torch.save`` to
+            ``<args.output_dir>/database/<set_type>.db``.
+
+    Returns:
+        dict mapping each slot name to ``(slot_emb, candidate_embs, is_requestable)``.
+        ``candidate_embs`` is ``None`` when the slot has no candidate left after the
+        ``'request'`` marker is removed; ``is_requestable`` tells whether that marker
+        was present.
+
+    Raises:
+        KeyError: if ``args.use_descriptions`` is set and a slot has no description.
+        ValueError: if no pooling strategy applies (see ``_pool_features``).
+    """
+    ontology = _load_json(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type))
+    slot_descriptions = _load_json(os.path.join(DATA_DIR, 'slot_descriptions.json'))
+
+    embedding_model.eval()
+
+    slots = dict()
+    for slot in ontology:
+        # Embed either the natural language description or the raw slot name.
+        desc = slot_descriptions[slot] if args.use_descriptions else slot
+        encoded = _encode_texts([desc], tokenizer, args.max_slot_len)
+        token_feats, pooled_feats = _embed_features(encoded, embedding_model)
+        # The batch holds a single description, so drop the batch dimension.
+        slot_emb = _pool_features(token_feats, pooled_feats, args)[0]
+
+        # 'request' marks a requestable slot; it is not a value candidate itself.
+        values = list(ontology[slot])
+        is_requestable = 'request' in values
+        if is_requestable:
+            values.remove('request')
+
+        if values:
+            encoded = _encode_texts(values, tokenizer, args.max_candidate_len)
+            token_feats, pooled_feats = _embed_features(encoded, embedding_model)
+            feats = _pool_features(token_feats, pooled_feats, args)
+        else:
+            feats = None
+        slots[slot] = (slot_emb, feats, is_requestable)
+
+    # Dump tensors for use in training
+    if save_to_file:
+        writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type)
+        torch.save(slots, writer)
+
+    return slots
+
+
+def _load_json(path):
+    """Return the parsed JSON document stored at ``path``, closing the file afterwards."""
+    with open(path, 'r') as reader:
+        return json.load(reader)
+
+
+def _encode_texts(texts, tokenizer, max_length):
+    """Tokenize each string in ``texts``, padded and truncated to ``max_length`` tokens."""
+    return [tokenizer.encode_plus(text, add_special_tokens=True, max_length=max_length,
+                                  padding='max_length', truncation='longest_first')
+            for text in texts]
+
+
+def _embed_features(encoded, embedding_model):
+    """Run ``embedding_model`` on a non-empty list of tokenized inputs.
+
+    ``token_type_ids`` and ``attention_mask`` are passed to the model only when the
+    tokenizer produced them. When an attention mask is available, the features of
+    masked-out positions are zeroed.
+
+    Returns:
+        ``(token_feats, pooled_feats)`` as returned by the model.
+    """
+    device = embedding_model.device
+    inputs = {'input_ids': torch.tensor([enc['input_ids'] for enc in encoded]).to(device)}
+    for key in ('token_type_ids', 'attention_mask'):
+        if key in encoded[0]:
+            inputs[key] = torch.tensor([enc[key] for enc in encoded]).to(device)
+
+    with torch.no_grad():
+        token_feats, pooled_feats = embedding_model(**inputs)
+        if 'attention_mask' in inputs:
+            mask = inputs['attention_mask'].unsqueeze(-1).repeat((1, 1, token_feats.size(-1)))
+            token_feats = token_feats * mask
+    return token_feats, pooled_feats
+
+
+def _pool_features(token_feats, pooled_feats, args):
+    """Reduce model outputs to the embeddings that get stored, moved to the CPU.
+
+    * ``args.set_similarity``: the token level features are kept as they are.
+    * ``args.candidate_pooling == 'cls'``: the model's pooled output is used.
+    * ``args.candidate_pooling == 'mean'``: token features are summed over
+      dimension 1 and layer-normalised.
+
+    NOTE(review): ``layer_norm`` is given the full size of the summed tensor, so for
+    a batch of candidates the statistics are shared across all candidates of the
+    slot - confirm this is intended.
+
+    Raises:
+        ValueError: if none of the strategies applies. Previously this left the slot
+            embedding unbound (or stale from the previous slot).
+    """
+    if args.set_similarity:
+        return token_feats.detach().cpu()
+    if args.candidate_pooling == 'cls' and pooled_feats is not None:
+        return pooled_feats.detach().cpu()
+    if args.candidate_pooling == 'mean':
+        summed = token_feats.sum(1)
+        return torch.nn.functional.layer_norm(summed, summed.size()).detach().cpu()
+    raise ValueError("Cannot pool embeddings: candidate_pooling=%r and the model %s a pooled output"
+                     % (args.candidate_pooling, 'returned' if pooled_feats is not None else 'did not return'))
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/utils.py b/convlab/dst/setsumbt/multiwoz/dataset/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..485dee643148ad1d37069064ebc4e2e4553b3dac
--- /dev/null
+++ b/convlab/dst/setsumbt/multiwoz/dataset/utils.py
@@ -0,0 +1,446 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Code adapted from the TRADE preprocessing code (https://github.com/jasonwu0731/trade-dst)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MultiWOZ2.1/3 data processing utilities"""
+
+import re
+import os
+
+from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
+from convlab.dst.rule.multiwoz import normalize_value
+
+# Domains whose annotations are processed by this module. The commented-out line
+# is the alternative without the hospital and police domains.
+# ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train']
+ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train', 'hospital', 'police']
+def set_util_domains(domains):
+    """Restrict the module-level ``ACTIVE_DOMAINS`` to the names listed in ``domains``.
+
+    Only names that are currently in ``ACTIVE_DOMAINS`` are kept, so repeated calls
+    can shrink the list but never add a domain back. The resulting order follows
+    ``domains``.
+    """
+    global ACTIVE_DOMAINS
+    ACTIVE_DOMAINS = [d for d in domains if d in ACTIVE_DOMAINS]
+
+# The mapping.pair file sits next to this module. Build the path from the directory
+# rather than by string replacement on the module file name.
+MAPPING_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mapping.pair')
+# Read replacement pairs from the mapping.pair file: one tab separated
+# "from<TAB>to" pair per line. Both sides are padded with spaces so that only
+# whole, space delimited tokens are replaced.
+REPLACEMENTS = []
+with open(MAPPING_PATH) as mapping_file:
+    for line in mapping_file:
+        tok_from, tok_to = line.replace('\n', '').split('\t')
+        REPLACEMENTS.append((' ' + tok_from + ' ', ' ' + tok_to + ' '))
+
+# Extract belief state from mturk annotations
+def build_dialoguestate(metadata, get_domains=False):
+    """Turn the belief state annotation of one turn into ``[slot, value]`` pairs.
+
+    For every domain of ``ACTIVE_DOMAINS`` present in ``metadata``:
+
+    * booking slots (``metadata[domain]['book']``, except ``'booked'``) become
+      ``'<domain>-book <slot>'`` entries unless empty or ``'not mentioned'``;
+    * informable slots (``metadata[domain]['semi']``) become ``'<domain>-<slot>'``
+      entries; the "don't care" spellings are unified to ``'do not care'``, while
+      empty and ``'not mentioned'`` values are skipped.
+
+    Args:
+        metadata: mapping ``domain -> {'book': {...}, 'semi': {...}}``.
+        get_domains: when true, return the domains that contributed at least one
+            entry instead of the state.
+
+    Returns:
+        The list of active domain names if ``get_domains`` is set, otherwise the
+        state after ``clean_dialoguestate``.
+    """
+    dontcare_spellings = ('dont care', 'dontcare', "don't care", 'don not care',
+                          'do not care', 'does not care')
+    dialogue_state = []
+    domains = []
+    for domain in [dom for dom in ACTIVE_DOMAINS if dom in metadata]:
+        n_before = len(dialogue_state)
+
+        # Booking information, visited in alphabetical slot order
+        book_slots = metadata[domain]['book']
+        for slot in sorted(book_slots):
+            raw = book_slots[slot]
+            if slot == 'booked' or raw == 'not mentioned' or raw == '':
+                continue
+            dialogue_state.append(['%s-book %s' % (domain, slot.strip().lower()), clean_text(raw)])
+
+        # Informable slots, visited in annotation order
+        for slot, raw in metadata[domain]['semi'].items():
+            if raw == 'not mentioned' or not raw:
+                continue
+            cleaned = 'do not care' if raw in dontcare_spellings else clean_text(raw)
+            dialogue_state.append(['%s-%s' % (domain, slot.strip().lower()), cleaned])
+
+        # A domain is active as soon as it contributed an entry
+        if len(dialogue_state) > n_before:
+            domains.append(domain)
+
+    return domains if get_domains else clean_dialoguestate(dialogue_state)
+
+
+# Closed value sets that clean_dialoguestate maps raw slot values onto.
+PRICERANGE = ['do not care', 'cheap', 'moderate', 'expensive']
+BOOLEAN = ['do not care', 'yes', 'no']
+# 'saturday' was misspelled as 'saterday', which made clean_dialoguestate drop
+# every saturday value.
+DAYS = ['do not care', 'monday', 'tuesday', 'wednesday', 'thursday',
+        'friday', 'saturday', 'sunday']
+QUANTITIES = ['do not care', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10 or more']
+# Every five minute time of day as 'hh:mm', ordered by minute first
+# ('00:00', '01:00', ..., '23:00', '00:05', ...). The order decides which entry
+# wins when a value contains several times.
+TIME = ['do not care'] + ['%02i:%02i' % (hour, minute)
+                          for minute in range(0, 60, 5) for hour in range(24)]
+
+# Substring rewrites applied by map_values, in this order.
+VALUE_MAP = {
+    'guesthouse': 'guest house',
+    'belfry': 'belfray',
+    '-': ' ',
+    '&': 'and',
+    'b and b': 'bed and breakfast',
+    'cityroomz': 'city roomz',
+    '  ': ' ',
+    'acorn house': 'acorn guest house',
+    'marriot': 'marriott',
+    'worth house': 'the worth house',
+    'alesbray lodge guest house': 'aylesbray lodge',
+    'huntingdon hotel': 'huntingdon marriott hotel',
+    'huntingd': 'huntingdon marriott hotel',
+    'jamaicanchinese': 'chinese',
+    'barbequemodern european': 'modern european',
+    'north americanindian': 'north american',
+    'caribbeanindian': 'indian',
+    'sheeps': "sheep's",
+}
+
+
+def map_values(value):
+    """Apply every ``VALUE_MAP`` rewrite to ``value`` and return the result.
+
+    The rewrites are plain substring replacements executed one after another in
+    the order of ``VALUE_MAP``, so a later rule sees the output of earlier ones.
+
+    NOTE(review): some rules also match their own replacement ('marriot' occurs in
+    'marriott', 'huntingd' in 'huntingdon marriott hotel'), so already correct
+    values get altered - confirm whether that is intended.
+    """
+    for old in VALUE_MAP:
+        value = value.replace(old, VALUE_MAP[old])
+    return value
+
+def clean_dialoguestate(states, is_acts=False):
+    """Normalise ``[slot, value]`` pairs, dropping values that cannot be normalised.
+
+    The rule is selected by the slot name (checked in this order):
+
+    * ``pricerange``: the slot becomes ``'<domain>-price range'`` and the value is
+      matched against ``PRICERANGE``.
+    * ``parking`` / ``internet``: ``'free'`` counts as ``'yes'``; the value is matched
+      against ``BOOLEAN``.
+    * ``day``: the value is matched against ``DAYS``.
+    * ``people`` / ``duration`` / ``stay``: integers of 10 and above become
+      ``'10 or more'``; otherwise the value is matched against ``QUANTITIES``.
+    * ``time`` / ``leaveat`` / ``arriveby``: ``leaveat`` / ``arriveby`` slots become
+      ``'<domain>-leave at'`` / ``'<domain>-arrive by'``; the value is matched
+      against ``TIME`` and, failing that, an ``hh:mm`` time found in the value is
+      rounded to the nearest five minutes.
+    * ``stars``: single characters and ``'do not care'`` are kept; longer values are
+      reduced to their leading digit.
+    * ``area``: only the text before the first ``'|'`` is kept.
+    * any other slot: a value containing ``'|'`` is cut at the first ``'|'`` and
+      passed through ``map_values``; other values are kept unchanged.
+
+    "Matched against" means: the value is kept if it is a member of the list,
+    otherwise it is replaced by the first member it contains as a substring.
+
+    Args:
+        states: iterable of ``[slot, value]`` pairs with string values.
+        is_acts: set when the pairs come from dialogue acts; a ``'?'`` value (a
+            requested slot) is then kept for every rule except ``day``.
+
+    Returns:
+        A new list of normalised ``[slot, value]`` pairs in input order.
+    """
+    clean_state = []
+    for slot, value in states:
+        is_request = is_acts and value == '?'
+
+        if 'pricerange' in slot:
+            domain, _ = slot.split('-', 1)
+            slot = f'{domain}-price range'
+            value = _match_option(value, PRICERANGE, is_request)
+        elif 'parking' in slot or 'internet' in slot:
+            # A single match now: the old `if ... if ... elif` chain appended
+            # plain 'yes' / 'no' values twice.
+            value = _match_option('yes' if value == 'free' else value, BOOLEAN, is_request)
+        elif 'day' in slot:
+            value = _match_option(value, DAYS, False)
+        elif 'people' in slot or 'duration' in slot or 'stay' in slot:
+            # Parse integers before substring matching, otherwise '12' would be
+            # turned into '1' and '10 or more' could never be produced.
+            if value not in QUANTITIES and _is_ten_or_more(value):
+                value = '10 or more'
+            value = _match_option(value, QUANTITIES, is_request)
+        elif 'time' in slot or 'leaveat' in slot or 'arriveby' in slot:
+            if 'leaveat' in slot:
+                domain, _ = slot.split('-', 1)
+                slot = f'{domain}-leave at'
+            if 'arriveby' in slot:
+                domain, _ = slot.split('-', 1)
+                slot = f'{domain}-arrive by'
+            matched = _match_option(value, TIME, is_request)
+            # Times off the five minute grid match no TIME entry; round them
+            # instead of dropping them.
+            value = matched if matched is not None else _round_time(value)
+        elif 'stars' in slot:
+            if len(value) == 1 or value == 'do not care':
+                pass
+            elif value:
+                try:
+                    value = str(int(value[0]))
+                except ValueError:
+                    value = None
+            else:
+                value = None
+        elif 'area' in slot:
+            value = value.split('|', 1)[0]
+        elif '|' in value:
+            # NOTE(review): map_values is only applied to values that contained
+            # a '|' - confirm that plain values are meant to bypass it.
+            value = map_values(value.split('|', 1)[0])
+
+        if value is not None:
+            clean_state.append([slot, value])
+
+    return clean_state
+
+
+def _match_option(value, options, keep_request):
+    """Map ``value`` onto ``options``.
+
+    Returns ``value`` if it is a member of ``options``, else the first member that
+    occurs in ``value`` as a substring, else ``value`` itself when ``keep_request``
+    is set, else ``None``.
+    """
+    if value in options:
+        return value
+    for option in options:
+        if option in value:
+            return option
+    return value if keep_request else None
+
+
+def _is_ten_or_more(value):
+    """Return True if ``value`` parses as an integer that is at least 10."""
+    try:
+        return int(value) >= 10
+    except ValueError:
+        return False
+
+
+def _round_time(value):
+    """Round the first ``hh:mm`` time in ``value`` to the nearest five minutes.
+
+    Returns the rounded time as ``'hh:mm'`` (wrapping past midnight), or ``None``
+    if ``value`` contains no valid 24 hour time.
+    """
+    found = re.search(r'([01][0-9]|2[0-3]):([0-5][0-9])', value)
+    if found is None:
+        return None
+    hour, minute = int(found.group(1)), int(found.group(2))
+    minute = round(minute / 5) * 5
+    if minute == 60:
+        minute = 0
+        hour = (hour + 1) % 24
+    return '%02i:%02i' % (hour, minute)
+
+
+# Module to process a dialogue and check its validity
+def process_dialogue(dialogue, max_utt_len=128):
+    """Validate a raw dialogue and split it into paired user and system turns.
+
+    A turn with empty ``metadata`` is treated as a user turn, any other turn as a
+    system turn. The dialogue is rejected (``None`` is returned) when
+
+    * its log is empty or has an odd number of turns,
+    * the mean number of space separated tokens per turn exceeds ``max_utt_len``,
+    * the first turn carries metadata, i.e. is a system turn, or
+    * any turn text contains non-ASCII characters.
+
+    Each system turn of a completed pair is extended in place with a
+    ``'dialogue_states'`` entry built from its metadata.
+
+    Args:
+        dialogue: mapping with a ``'log'`` list of turns; every turn has ``'text'``
+            and ``'metadata'``.
+        max_utt_len: upper bound on the mean utterance length.
+
+    Returns:
+        ``{'usr_log': [...], 'sys_log': [...]}`` with the paired turns, or ``None``.
+    """
+    log = dialogue['log']
+    # An empty log used to raise ZeroDivisionError in the average below.
+    if not log or len(log) % 2 != 0:
+        return None
+
+    avg_len = sum(len(utt['text'].split(' ')) for utt in log) / len(log)
+    if avg_len > max_utt_len:
+        return None
+
+    # If the first turn is a system turn then ignore the dialogue
+    if log[0]['metadata']:
+        return None
+
+    # Extract user and system utterances
+    usr_utts, sys_utts = [], []
+    usr_turn, sys_turn = None, None
+    for turn in log:
+        if not is_ascii(turn['text']):
+            return None
+
+        # The pending pair is reset as soon as it is complete, so at this point
+        # at least one of its halves is still missing.
+        if len(turn['metadata']) == 0:
+            usr_turn = turn
+        else:
+            sys_turn = turn
+
+        if usr_turn and sys_turn:
+            sys_turn['dialogue_states'] = build_dialoguestate(sys_turn['metadata'], get_domains=False)
+            usr_utts.append(usr_turn)
+            sys_utts.append(sys_turn)
+            usr_turn, sys_turn = None, None
+
+    return {'usr_log': usr_utts, 'sys_log': sys_utts}
+
+
+# Get new domains
+def get_act_domains(prev, crnt):
+    """Return the entries of ``crnt`` whose value differs from the one in ``prev``.
+
+    Both mappings are walked in parallel in their own iteration order and are
+    expected to list the same keys in the same order (checked with ``assert``).
+    An empty dict is returned when either argument is empty or ``None``.
+    """
+    if not (prev and crnt):
+        return {}
+
+    changed = {}
+    for old_item, new_item in zip(prev.items(), crnt.items()):
+        assert old_item[0] == new_item[0]
+        if old_item[1] != new_item[1]:
+            changed[new_item[0]] = new_item[1]
+    return changed
+
+
+# Get current domains
+def get_domains(dial_log, turn_id, prev_domain):
+    """Determine the domain of the turn at index ``turn_id``.
+
+    Candidate domains come from the belief state: for ``turn_id == 1`` the domains
+    active in that turn's metadata, otherwise the domains whose metadata changed
+    compared with turn ``turn_id - 2``. If the belief state yields none, the
+    domains of the turn's dialogue acts (except ``''`` and ``'general'``) are used.
+
+    Returns:
+        The first candidate domain. Without candidates the result is ``''`` for
+        ``turn_id == 1`` and ``prev_domain`` otherwise.
+    """
+    if turn_id == 1:
+        candidates = build_dialoguestate(dial_log[turn_id]['metadata'], get_domains=True)
+        fallback = ''
+    else:
+        changes = get_act_domains(dial_log[turn_id - 2]['metadata'], dial_log[turn_id]['metadata'])
+        candidates = list(changes)
+        fallback = prev_domain
+
+    # Dialogue acts are only consulted when the belief state gave no domain
+    if not candidates:
+        acts = format_acts(dial_log[turn_id].get('dialog_act', {}))
+        candidates = [domain for intent, domain, slot, value in acts if domain not in ['', 'general']]
+
+    return candidates[0] if candidates else fallback
+
+
+# Function to extract dialogue info from data
+def extract_dialogue(dialogue, max_utt_len=50):
+    """Convert a raw dialogue into a list of per-exchange dictionaries.
+
+    The dialogue is first validated and paired up by ``process_dialogue``; ``None``
+    is returned when that step rejects it.
+
+    Returns:
+        One dict per user/system turn pair with the keys ``'usr'`` (user text),
+        ``'sys'`` (system text), ``'usr_a'`` (formatted user dialogue acts, empty if
+        the turn has none), ``'domain'`` (the user turn's ``'domain'`` entry) and
+        ``'ds'`` (the system turn's dialogue state).
+    """
+    processed = process_dialogue(dialogue, max_utt_len)
+    if not processed:
+        return None
+
+    exchanges = []
+    for usr_turn, sys_turn in zip(processed['usr_log'], processed['sys_log']):
+        if 'dialog_act' in usr_turn:
+            usr_acts = format_acts(usr_turn['dialog_act'])
+        else:
+            usr_acts = []
+        exchanges.append({
+            'usr': usr_turn['text'],
+            'sys': sys_turn['text'],
+            'usr_a': usr_acts,
+            'domain': usr_turn['domain'],
+            'ds': sys_turn['dialogue_states'],
+        })
+    return exchanges
+
+
+def format_acts(acts):
+    """Flatten a dialogue act annotation into ``[intent, domain, slot, value]`` lists.
+
+    ``acts`` maps ``'<Domain>-<Intent>'`` keys to lists of ``(slot, value)`` pairs.
+    Acts whose lower-cased domain is neither in ``ACTIVE_DOMAINS`` nor ``'general'``
+    are skipped. Slot names are translated through ``REF_SYS_DA`` when it has an
+    entry for the domain, values are cleaned with ``clean_text``, and the pairs are
+    normalised by ``clean_dialoguestate``. For the ``general`` domain only the
+    ``thank`` and ``bye`` intents are kept, as a single ``none`` / ``none`` act.
+
+    Returns:
+        List of ``[intent, domain, slot, value]`` with lower-cased intent and domain
+        and the domain prefix removed from the slot name.
+    """
+    allowed_domains = ACTIVE_DOMAINS + ['general']
+    new_acts = []
+    for key, item in acts.items():
+        domain, intent = key.split('-', 1)
+        domain_lc = domain.lower()
+        if domain_lc not in allowed_domains:
+            continue
+
+        state = []
+        for slot, value in item:
+            if domain in REF_SYS_DA:
+                slot = str(REF_SYS_DA[domain].get(slot, slot)).lower()
+            value = clean_text(value)
+            slot = slot.replace('_', ' ').replace('ref', 'reference')
+            state.append([f'{domain_lc}-{slot}', value])
+        state = clean_dialoguestate(state, is_acts=True)
+
+        if domain == 'general':
+            state = [['general-none', 'none']] if intent in ['thank', 'bye'] else []
+
+        new_acts.extend([intent.lower(), domain_lc, slot.split('-', 1)[-1], value]
+                        for slot, value in state if slot not in ['train-people'])
+
+    return new_acts
+                
+
+# Rewrite domain names in the turn text that disagree with the domain of the turn's dialogue acts
+def fix_delexicalisation(turn):
+    if 'dialog_act' in turn:
+        for dom, act in turn['dialog_act'].items():
+            if 'Attraction' in dom:
+                if 'restaurant_' in turn['text']:
+                    turn['text'] = turn['text'].replace("restaurant", "attraction")
+                if 'hotel_' in turn['text']:
+                    turn['text'] = turn['text'].replace("hotel", "attraction")
+            if 'Hotel' in dom:
+                if 'attraction_' in turn['text']:
+                    turn['text'] = turn['text'].replace("attraction", "hotel")
+                if 'restaurant_' in turn['text']:
+                    turn['text'] = turn['text'].replace("restaurant", "hotel")
+            if 'Restaurant' in dom:
+                if 'attraction_' in turn['text']:
+                    turn['text'] = turn['text'].replace("attraction", "restaurant")
+                if 'hotel_' in turn['text']:
+                    turn['text'] = turn['text'].replace("hotel", "restaurant")
+
+    return turn
+
+
+# Check if every character in a string is an ascii character
+def is_ascii(s):
+    return all(ord(c) < 128 for c in s)
+
+
+# Insert white space around each occurrence of a token in the text
+def separate_token(token, text):
+    sidx = 0
+    while True:
+        # Find next instance of token
+        sidx = text.find(token, sidx)
+        if sidx == -1:
+            break
+        # If the token sits between two digits (e.g. '3.5') leave it as is and continue to the next
+        if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
+                re.match('[0-9]', text[sidx + 1]):
+            sidx += 1
+            continue
+        # Create white space separation around token
+        if text[sidx - 1] != ' ':
+            text = text[:sidx] + ' ' + text[sidx:]
+            sidx += 1
+        if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
+            text = text[:sidx + 1] + ' ' + text[sidx + 1:]
+        sidx += 1
+    return text
+
+
+def clean_text(text):
+    # Replace white spaces in front and end
+    text = re.sub(r'^\s*|\s*$', '', text.strip().lower())
+
+    # Replace 'b&b' or 'b and b' with 'bed and breakfast'
+    text = re.sub(r"b&b", "bed and breakfast", text)
+    text = re.sub(r"b and b", "bed and breakfast", text)
+
+    # Fix apostrophes (curly quotes -> straight quote)
+    text = re.sub(u"(\u2018|\u2019)", "'", text)
+
+    # Correct punctuation
+    text = text.replace(';', ',')
+    text = re.sub('$\/', '', text)
+    text = text.replace('/', ' and ')
+
+    # Replace special characters
+    text = text.replace('-', ' ')
+    text = re.sub('[\"\<>@\(\)]', '', text)
+
+    # Insert white space around special tokens:
+    for token in ['?', '.', ',', '!']:
+        text = separate_token(token, text)
+
+    # insert white space for 's
+    text = separate_token('\'s', text)
+
+    # replace it's, doesn't, you'd ... etc
+    text = re.sub('^\'', '', text)
+    text = re.sub('\'$', '', text)
+    text = re.sub('\'\s', ' ', text)
+    text = re.sub('\s\'', ' ', text)
+
+    # Perform pair replacements listed in the mapping.pair file
+    for fromx, tox in REPLACEMENTS:
+        text = ' ' + text + ' '
+        text = text.replace(fromx, tox)[1:-1]
+
+    # Remove multiple spaces
+    text = re.sub(' +', ' ', text)
+
+    # Concatenate numbers eg '1 3' -> '13'
+    tokens = text.split()
+    i = 1
+    while i < len(tokens):
+        if re.match(u'^\d+$', tokens[i]) and \
+                re.match(u'\d+$', tokens[i - 1]):
+            tokens[i - 1] += tokens[i]
+            del tokens[i]
+        else:
+            i += 1
+    text = ' '.join(tokens)
+
+    return text
diff --git a/convlab/dst/setsumbt/process_mwoz_data.py b/convlab/dst/setsumbt/process_mwoz_data.py
new file mode 100755
index 0000000000000000000000000000000000000000..701a523613961d83a5188fa9ab0cf786b19a5a7e
--- /dev/null
+++ b/convlab/dst/setsumbt/process_mwoz_data.py
@@ -0,0 +1,99 @@
+import os
+import json
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+import torch
+from tqdm import tqdm
+
+from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker
+from convlab.util.multiwoz.lexicalize import deflat_da, flat_da
+
+
+def load_data(path):
+    with open(path, 'r') as reader:
+        data = json.load(reader)
+        reader.close()
+    
+    return data
+
+
+def load_tracker(model_checkpoint):
+    model = SetSUMBTTracker(model_path=model_checkpoint)
+    model.init_session()
+
+    return model
+
+
+def process_dialogue(dial, model, get_full_belief_state):
+    model.store_full_belief_state = get_full_belief_state
+    model.init_session()
+
+    model.state['history'].append(['sys', ''])
+    processed_dial = []
+    belief_state = {}
+    for turn in dial:
+        if not turn['metadata']:
+            state = model.update(turn['text'])
+            model.state['history'].append(['usr', turn['text']])
+            
+            acts = model.state['user_action']
+            acts = [[val.replace('-', ' ') for val in act] for act in acts]
+            acts = flat_da(acts)
+            acts = deflat_da(acts)
+            turn['dialog_act'] = acts
+        else:
+            model.state['history'].append(['sys', turn['text']])
+            turn['metadata'] = model.state['belief_state']
+        
+        if get_full_belief_state:
+            for slot, probs in model.full_belief_state.items():
+                if slot not in belief_state:
+                    belief_state[slot] = [probs[0]]
+                else:
+                    belief_state[slot].append(probs[0])
+        
+        processed_dial.append(turn)
+    
+    if get_full_belief_state:
+        belief_state = {slot: torch.cat(probs, 0).cpu() for slot, probs in belief_state.items()}
+
+    return processed_dial, belief_state
+
+
+def process_dialogues(data, model, get_full_belief_state=False):
+    processed_data = {}
+    belief_states = {}
+    for dial_id, dial in tqdm(data.items()):
+        dial['log'], bs = process_dialogue(dial['log'], model, get_full_belief_state)
+        processed_data[dial_id] = dial
+        if get_full_belief_state:
+            belief_states[dial_id] = bs
+
+    return processed_data, belief_states
+
+
+def get_arguments():
+    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+
+    parser.add_argument('--model_path')
+    parser.add_argument('--data_path')
+    parser.add_argument('--get_full_belief_state', action='store_true')
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_arguments()
+
+    print('Loading data and model...')
+    data = load_data(os.path.join(args.data_path, 'data.json'))
+    model = load_tracker(args.model_path)
+
+    print('Processing data...\n')
+    data, belief_states = process_dialogues(data, model, get_full_belief_state=args.get_full_belief_state)
+    
+    print('Saving results...\n')
+    torch.save(belief_states, os.path.join(args.data_path, 'setsumbt_belief_states.bin'))
+    with open(os.path.join(args.data_path, 'setsumbt_data.json'), 'w') as writer:
+        json.dump(data, writer, indent=2)
+        writer.close()
diff --git a/convlab/dst/setsumbt/run.py b/convlab/dst/setsumbt/run.py
index e45bf129f0c9f2c5c1fba01d4b5eb80e29a5a1f0..b9c9a75b86d47cd5db733b4755d6af11f08b827d 100644
--- a/convlab/dst/setsumbt/run.py
+++ b/convlab/dst/setsumbt/run.py
@@ -33,8 +33,8 @@ def main():
     if args.run_nbt:
         from convlab.dst.setsumbt.do.nbt import main
         main(args, config)
-    if args.run_evaluation:
-        from convlab.dst.setsumbt.do.evaluate import main
+    if args.run_calibration:
+        from convlab.dst.setsumbt.do.calibration import main
         main(args, config)
 
 
diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py
deleted file mode 100644
index b58fc5bd17fdf486e544fe06c4f8c2bc0b83f8b3..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/tracker.py
+++ /dev/null
@@ -1,442 +0,0 @@
-import os
-import json
-import copy
-import logging
-
-import torch
-import transformers
-from transformers import BertModel, BertConfig, BertTokenizer, RobertaModel, RobertaConfig, RobertaTokenizer
-
-from convlab.dst.setsumbt.modeling import RobertaSetSUMBT, BertSetSUMBT
-from convlab.dst.setsumbt.modeling.training import set_ontology_embeddings
-from convlab.dst.dst import DST
-from convlab.util.custom_util import model_downloader
-
-USE_CUDA = torch.cuda.is_available()
-
-
-class SetSUMBTTracker(DST):
-    """SetSUMBT Tracker object for Convlab dialogue system"""
-
-    def __init__(self,
-                 model_path: str = "",
-                 model_type: str = "roberta",
-                 return_turn_pooled_representation: bool = False,
-                 return_confidence_scores: bool = False,
-                 confidence_threshold='auto',
-                 return_belief_state_entropy: bool = False,
-                 return_belief_state_mutual_info: bool = False,
-                 store_full_belief_state: bool = False):
-        """
-        Args:
-            model_path: Model path or download URL
-            model_type: Transformer type (roberta/bert)
-            return_turn_pooled_representation: If true a turn level pooled representation is returned
-            return_confidence_scores: If true act confidence scores are included in the state
-            confidence_threshold: Confidence threshold value for constraints or option auto
-            return_belief_state_entropy: If true belief state distribution entropies are included in the state
-            return_belief_state_mutual_info: If true belief state distribution mutual infos are included in the state
-            store_full_belief_state: If true full belief state is stored within tracker object
-        """
-        super(SetSUMBTTracker, self).__init__()
-
-        self.model_type = model_type
-        self.model_path = model_path
-        self.return_turn_pooled_representation = return_turn_pooled_representation
-        self.return_confidence_scores = return_confidence_scores
-        self.confidence_threshold = confidence_threshold
-        self.return_belief_state_entropy = return_belief_state_entropy
-        self.return_belief_state_mutual_info = return_belief_state_mutual_info
-        self.store_full_belief_state = store_full_belief_state
-        if self.store_full_belief_state:
-            self.full_belief_state = {}
-        self.info_dict = {}
-
-        # Download model if needed
-        if not os.path.exists(self.model_path):
-            # Get path /.../convlab/dst/setsumbt/multiwoz/models
-            download_path = os.path.dirname(os.path.abspath(__file__))
-            download_path = os.path.join(download_path, 'models')
-            if not os.path.exists(download_path):
-                os.mkdir(download_path)
-            model_downloader(download_path, self.model_path)
-            # Downloadable model path format http://.../setsumbt_model_name.zip
-            self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '')
-            self.model_path = os.path.join(download_path, self.model_path)
-
-        # Select model type based on the encoder
-        if model_type == "roberta":
-            self.config = RobertaConfig.from_pretrained(self.model_path)
-            self.tokenizer = RobertaTokenizer
-            self.model = RobertaSetSUMBT
-        elif model_type == "bert":
-            self.config = BertConfig.from_pretrained(self.model_path)
-            self.tokenizer = BertTokenizer
-            self.model = BertSetSUMBT
-        else:
-            logging.debug("Name Error: Not Implemented")
-
-        self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu')
-
-        self.load_weights()
-
-    def load_weights(self):
-        """Load model weights and model ontology"""
-        logging.info('Loading SetSUMBT pretrained model.')
-        self.tokenizer = self.tokenizer.from_pretrained(self.config.tokenizer_name)
-        logging.info(f'Model tokenizer loaded from {self.config.tokenizer_name}.')
-        self.model = self.model.from_pretrained(self.model_path, config=self.config)
-        logging.info(f'Model loaded from {self.model_path}.')
-
-        # Transfer model to compute device and setup eval environment
-        self.model = self.model.to(self.device)
-        self.model.eval()
-        logging.info(f'Model transferred to device: {self.device}')
-
-        logging.info('Loading model ontology')
-        f = open(os.path.join(self.model_path, 'database', 'test.json'), 'r')
-        self.ontology = json.load(f)
-        f.close()
-
-        db = torch.load(os.path.join(self.model_path, 'database', 'test.db'))
-        set_ontology_embeddings(self.model, db)
-
-        if self.return_confidence_scores:
-            logging.info('Model returns user action and belief state confidence scores.')
-            self.get_thresholds(self.confidence_threshold)
-            logging.info('Uncertain Querying set up and thresholds set up at:')
-            logging.info(self.confidence_thresholds)
-        if self.return_belief_state_entropy:
-            logging.info('Model returns belief state distribution entropy scores (Total uncertainty).')
-        if self.return_belief_state_mutual_info:
-            logging.info('Model returns belief state distribution mutual information scores (Knowledge uncertainty).')
-        logging.info('Ontology loaded successfully.')
-
-    def get_thresholds(self, threshold='auto') -> dict:
-        """
-        Setup dictionary of domain specific confidence thresholds
-
-        Args:
-            threshold: Threshold value or option auto
-
-        Returns:
-            confidence_thresholds: Domain specific confidence thresholds
-        """
-        self.confidence_thresholds = dict()
-        for domain, substate in self.ontology.items():
-            for slot, slot_info in substate.items():
-                # Auto thresholds are set based on the number of value candidates per slot
-                if domain not in self.confidence_thresholds:
-                    self.confidence_thresholds[domain] = dict()
-                if threshold == 'auto':
-                    thres = 1.0 / (float(len(slot_info['possible_values'])) - 2.1)
-                    self.confidence_thresholds[domain][slot] = max(0.05, thres)
-                else:
-                    self.confidence_thresholds[domain][slot] = max(0.05, threshold)
-
-        return self.confidence_thresholds
-
-    def init_session(self):
-        self.state = dict()
-        self.state['belief_state'] = dict()
-        for domain, substate in self.ontology.items():
-            self.state['belief_state'][domain] = dict()
-            for slot, slot_info in substate.items():
-                if slot_info['possible_values'] and slot_info['possible_values'] != ['?']:
-                    self.state['belief_state'][domain][slot] = ''
-        self.state['history'] = []
-        self.state['system_action'] = []
-        self.state['user_action'] = []
-        self.active_domains = {}
-        self.hidden_states = None
-        self.info_dict = {}
-
-    def update(self, user_act: str = '') -> dict:
-        """
-        Update user actions and dialogue and belief states.
-
-        Args:
-            user_act:
-
-        Returns:
-
-        """
-        prev_state = self.state
-        _output = self.predict(self.get_features(user_act))
-
-        # Format state entropy
-        if _output[5] is not None:
-            state_entropy = dict()
-            for slot, e in _output[5].items():
-                domain, slot = slot.split('-', 1)
-                if domain not in state_entropy:
-                    state_entropy[domain] = dict()
-                state_entropy[domain][slot] = e
-        else:
-            state_entropy = None
-
-        # Format state mutual information
-        if _output[6] is not None:
-            state_mutual_info = dict()
-            for slot, mi in _output[5].items():
-                domain, slot = slot.split('-', 1)
-                if domain not in state_mutual_info:
-                    state_mutual_info[domain] = dict()
-                state_mutual_info[domain][slot] = mi[0, 0]
-        else:
-            state_mutual_info = None
-
-        # Format all confidence scores
-        belief_state_confidence = None
-        if _output[4] is not None:
-            belief_state_confidence = dict()
-            belief_state_conf, request_probs, active_domain_probs, general_act_probs = _output[4]
-            for slot, p in belief_state_conf.items():
-                domain, slot = slot.split('-', 1)
-                if domain not in belief_state_confidence:
-                    belief_state_confidence[domain] = dict()
-                if slot not in belief_state_confidence[domain]:
-                    belief_state_confidence[domain][slot] = dict()
-                belief_state_confidence[domain][slot]['inform'] = p
-
-            for slot, p in request_probs.items():
-                domain, slot = slot.split('-', 1)
-                if domain not in belief_state_confidence:
-                    belief_state_confidence[domain] = dict()
-                if slot not in belief_state_confidence[domain]:
-                    belief_state_confidence[domain][slot] = dict()
-                belief_state_confidence[domain][slot]['request'] = p
-
-            for domain, p in active_domain_probs.items():
-                if domain not in belief_state_confidence:
-                    belief_state_confidence[domain] = dict()
-                belief_state_confidence[domain]['none'] = {'inform': p}
-
-            if 'general' not in belief_state_confidence:
-                belief_state_confidence['general'] = dict()
-            belief_state_confidence['general']['none'] = general_act_probs
-
-        # Get new domain activation actions
-        new_domains = [d for d, active in _output[1].items() if active]
-        new_domains = [d for d in new_domains if not self.active_domains.get(d, False)]
-        self.active_domains = _output[1]
-
-        user_acts = _output[2]
-        for domain in new_domains:
-            user_acts.append(['inform', domain, 'none', 'none'])
-
-        new_belief_state = copy.deepcopy(prev_state['belief_state'])
-        for domain, substate in _output[0].items():
-            for slot, value in substate.items():
-                value = '' if value == 'none' else value
-                value = 'dontcare' if value == 'do not care' else value
-                value = 'guesthouse' if value == 'guest house' else value
-
-                if domain not in new_belief_state:
-                    if domain == 'bus':
-                        continue
-                    else:
-                        logging.debug('Error: domain <{}> not in belief state'.format(domain))
-
-                # Uncertainty clipping of state
-                if belief_state_confidence is not None:
-                    threshold = self.confidence_thresholds[domain][slot]
-                    if belief_state_confidence[domain][slot].get('inform', 1.0) < threshold:
-                        value = ''
-
-                new_belief_state[domain][slot] = value
-                if prev_state['belief_state'][domain][slot] != value:
-                    user_acts.append(['inform', domain, slot, value])
-                else:
-                    bug = f'Unknown slot name <{slot}> with value <{value}> of domain <{domain}>'
-                    logging.debug(bug)
-
-        new_state = copy.deepcopy(dict(prev_state))
-        new_state['belief_state'] = new_belief_state
-        new_state['active_domains'] = self.active_domains
-        if belief_state_confidence is not None:
-            new_state['belief_state_probs'] = belief_state_confidence
-        if state_entropy is not None:
-            new_state['entropy'] = state_entropy
-        if state_mutual_info is not None:
-            new_state['mutual_information'] = state_mutual_info
-
-        user_acts = [act for act in user_acts if act not in new_state['system_action']]
-        new_state['user_action'] = user_acts
-
-        if _output[3] is not None:
-            new_state['turn_pooled_representation'] = _output[3]
-
-        self.state = new_state
-        self.info_dict = copy.deepcopy(dict(new_state))
-
-        return self.state
-
-    def predict(self, features: dict) -> tuple:
-        """
-        Model forward pass and prediction post processing.
-
-        Args:
-            features: Dictionary of model input features
-
-        Returns:
-            out: Model predictions and uncertainty features
-        """
-        state_mutual_info = None
-        with torch.no_grad():
-            turn_pooled_representation = None
-            if self.return_turn_pooled_representation:
-                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
-                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
-                                      get_turn_pooled_representation=True)
-                belief_state = _outputs[0]
-                request_probs = _outputs[1]
-                active_domain_probs = _outputs[2]
-                general_act_probs = _outputs[3]
-                self.hidden_states = _outputs[4]
-                turn_pooled_representation = _outputs[5]
-            elif self.return_belief_state_mutual_info:
-                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
-                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
-                                      get_turn_pooled_representation=True, calculate_inform_mutual_info=True)
-                belief_state = _outputs[0]
-                request_probs = _outputs[1]
-                active_domain_probs = _outputs[2]
-                general_act_probs = _outputs[3]
-                self.hidden_states = _outputs[4]
-                state_mutual_info = _outputs[5]
-            else:
-                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
-                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
-                                      get_turn_pooled_representation=False)
-                belief_state, request_probs, active_domain_probs, general_act_probs, self.hidden_states = _outputs
-
-        # Convert belief state into dialog state
-        dialogue_state = dict()
-        for slot, probs in belief_state.items():
-            dom, slot = slot.split('-', 1)
-            if dom not in dialogue_state:
-                dialogue_state[dom] = dict()
-            val = self.ontology[dom][slot]['possible_values'][probs[0, 0, :].argmax().item()]
-            if val != 'none':
-                dialogue_state[dom][slot] = val
-
-        if self.store_full_belief_state:
-            self.full_belief_state = belief_state
-
-        # Obtain model output probabilities
-        if self.return_confidence_scores:
-            state_entropy = None
-            if self.return_belief_state_entropy:
-                state_entropy = {slot: probs[0, 0, :] for slot, probs in belief_state.items()}
-                state_entropy = {slot: self.relative_entropy(p).item() for slot, p in state_entropy.items()}
-
-            # Confidence score is the max probability across all not "none" values candidates.
-            belief_state_conf = {slot: probs[0, 0, 1:].max().item() for slot, probs in belief_state.items()}
-            _request_probs = {slot: p[0, 0].item() for slot, p in request_probs.items()}
-            _active_domain_probs = {domain: p[0, 0].item() for domain, p in active_domain_probs.items()}
-            _general_act_probs = {'bye': general_act_probs[0, 0, 1].item(), 'thank': general_act_probs[0, 0, 2].item()}
-            confidence_scores = (belief_state_conf, _request_probs, _active_domain_probs, _general_act_probs)
-        else:
-            confidence_scores = None
-            state_entropy = None
-
-        # Construct request action prediction
-        request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5]
-        request_acts = [slot.split('-', 1) for slot in request_acts]
-        request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts]
-
-        # Construct active domain set
-        active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()}
-
-        # Construct general domain action
-        general_acts = general_act_probs[0, 0, :].argmax(-1).item()
-        general_acts = [[], ['bye'], ['thank']][general_acts]
-        general_acts = [[act, 'general', 'none', 'none'] for act in general_acts]
-
-        user_acts = request_acts + general_acts
-
-        out = (dialogue_state, active_domains, user_acts, turn_pooled_representation, confidence_scores)
-        out += (state_entropy, state_mutual_info)
-        return out
-
-    def relative_entropy(self, probs: torch.Tensor) -> torch.Tensor:
-        """
-        Compute relative entrop for a probability distribution
-
-        Args:
-            probs: Probability distributions
-
-        Returns:
-            entropy: Relative entropy
-        """
-        entropy = probs * torch.log(probs + 1e-8)
-        entropy = -entropy.sum()
-        # Maximum entropy of a K dimentional distribution is ln(K)
-        entropy /= torch.log(torch.tensor(probs.size(-1)).float())
-
-        return entropy
-
-    def get_features(self, user_act: str) -> dict:
-        """
-        Tokenize utterances and construct model input features
-
-        Args:
-            user_act: User action string
-
-        Returns:
-            features: Model input features
-        """
-        # Extract system utterance from dialog history
-        context = self.state['history']
-        if context:
-            if context[-1][0] != 'sys':
-                system_act = ''
-            else:
-                system_act = context[-1][-1]
-        else:
-            system_act = ''
-
-        # Tokenize dialog
-        features = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True,
-                                              max_length=self.config.max_turn_len, padding='max_length',
-                                              truncation='longest_first')
-
-        input_ids = torch.tensor(features['input_ids']).reshape(
-            1, 1, -1).to(self.device) if 'input_ids' in features else None
-        token_type_ids = torch.tensor(features['token_type_ids']).reshape(
-            1, 1, -1).to(self.device) if 'token_type_ids' in features else None
-        attention_mask = torch.tensor(features['attention_mask']).reshape(
-            1, 1, -1).to(self.device) if 'attention_mask' in features else None
-        features = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask}
-
-        return features
-
-
-# if __name__ == "__main__":
-#     from convlab.policy.vector.vector_uncertainty import VectorUncertainty
-#     # from convlab.policy.vector.vector_binary import VectorBinary
-#     tracker = SetSUMBTTracker(model_path='/gpfs/project/niekerk/src/SetSUMBT/models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-labelsmoothing-Seed0-10-08-22-12-42',
-#                               return_confidence_scores=True, confidence_threshold='auto',
-#                               return_belief_state_entropy=True)
-#     vector = VectorUncertainty(use_state_total_uncertainty=True, confidence_thresholds=tracker.confidence_thresholds,
-#                                use_masking=True)
-#     # vector = VectorBinary()
-#     tracker.init_session()
-#
-#     state = tracker.update('hey. I need a cheap restaurant.')
-#     tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
-#     tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
-#     state = tracker.update('If you have something Asian that would be great.')
-#     tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
-#     tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
-#     tracker.state['system_action'] = [['inform', 'restaurant', 'food', 'chinese'],
-#                                       ['inform', 'restaurant', 'name', 'the golden wok']]
-#     state = tracker.update('Great. Where are they located?')
-#     tracker.state['history'].append(['usr', 'Great. Where are they located?'])
-#     state = tracker.state
-#     state['terminated'] = False
-#     state['booked'] = {}
-#
-#     print(state)
-#     print(vector.state_vectorize(state))
diff --git a/convlab/dst/setsumbt/utils.py b/convlab/dst/setsumbt/utils.py
index 23cec10ddd06872118af77bfac9ea856f3947677..75a6a1febd7510f1a6d152676b82d709abf177f0 100644
--- a/convlab/dst/setsumbt/utils.py
+++ b/convlab/dst/setsumbt/utils.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,41 +15,57 @@
 # limitations under the License.
 """SetSUMBT utils"""
 
+import re
 import os
 import shutil
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+from glob import glob
 from datetime import datetime
 
+from google.cloud import storage
 
-def get_args(base_models: dict):
+
+def get_args(MODELS):
     # Get arguments
     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
 
     # Optional
-    parser.add_argument('--tensorboard_path', help='Path to tensorboard', default='')
+    parser.add_argument('--tensorboard_path',
+                        help='Path to tensorboard', default='')
     parser.add_argument('--logging_path', help='Path for log file', default='')
-    parser.add_argument('--seed', help='Seed value for reproducibility', default=0, type=int)
+    parser.add_argument(
+        '--seed', help='Seed value for reproducability', default=0, type=int)
 
     # DATASET (Optional)
-    parser.add_argument('--dataset', help='Dataset Name (See Convlab 3 unified format for possible datasets',
-                        default='multiwoz21')
-    parser.add_argument('--dataset_train_ratio', help='Fraction of training set to use in training', default=1.0,
-                        type=float)
-    parser.add_argument('--max_dialogue_len', help='Maximum number of turns per dialogue', default=12, type=int)
-    parser.add_argument('--max_turn_len', help='Maximum number of tokens per turn', default=64, type=int)
-    parser.add_argument('--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int)
-    parser.add_argument('--max_candidate_len', help='Maximum number of tokens per value candidate', default=12,
-                        type=int)
-    parser.add_argument('--force_processing', action='store_true', help='Force preprocessing of data.')
-    parser.add_argument('--data_sampling_size', help='Resampled dataset size', default=-1, type=int)
-    parser.add_argument('--no_descriptions', help='Do not use slot descriptions rather than slot names for embeddings',
+    parser.add_argument(
+        '--dataset', help='Dataset Name: multiwoz21/simr', default='multiwoz21')
+    parser.add_argument('--shrink_active_domains', help='Shrink active domains to only well represented test set domains',
+                        action='store_true')
+    parser.add_argument(
+        '--data_dir', help='Data storage directory', default=None)
+    parser.add_argument(
+        '--max_dialogue_len', help='Maximum number of turns per dialogue', default=12, type=int)
+    parser.add_argument(
+        '--max_turn_len', help='Maximum number of tokens per turn', default=64, type=int)
+    parser.add_argument(
+        '--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int)
+    parser.add_argument('--max_candidate_len',
+                        help='Maximum number of tokens per value candidate', default=12, type=int)
+    parser.add_argument('--force_processing', action='store_true',
+                        help='Force preprocessing of data.')
+    parser.add_argument('--data_sampling_size',
+                        help='Resampled dataset size', default=-1, type=int)
+    parser.add_argument('--use_descriptions', help='Use slot descriptions rather than slot names for embeddings',
                         action='store_true')
 
     # MODEL
     # Environment
-    parser.add_argument('--output_dir', help='Output storage directory', default=None)
-    parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta', default='roberta')
-    parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.', default=None)
+    parser.add_argument(
+        '--output_dir', help='Output storage directory', default=None)
+    parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta',
+                        default='roberta')
+    parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.',
+                        default=None)
     parser.add_argument('--candidate_embedding_model_name', default=None,
                         help='Name of the pretrained candidate embedding model.')
 
@@ -58,73 +74,92 @@ def get_args(base_models: dict):
                         action='store_true')
     parser.add_argument('--slot_attention_heads', help='Number of attention heads for slot conditioning',
                         default=12, type=int)
-    parser.add_argument('--dropout_rate', help='Dropout Rate', default=0.3, type=float)
-    parser.add_argument('--nbt_type', help='Belief Tracker type: gru/lstm', default='gru')
+    parser.add_argument('--dropout_rate', help='Dropout Rate',
+                        default=0.3, type=float)
+    parser.add_argument(
+        '--nbt_type', help='Belief Tracker type: gru/lstm', default='gru')
     parser.add_argument('--nbt_hidden_size', help='Hidden embedding size for the Neural Belief Tracker',
                         default=300, type=int)
-    parser.add_argument('--nbt_layers', help='Number of RNN layers in the NBT', default=1, type=int)
-    parser.add_argument('--rnn_zero_init', help='Zero Initialise RNN hidden states', action='store_true')
+    parser.add_argument(
+        '--nbt_layers', help='Number of RNN layers in the NBT', default=1, type=int)
+    parser.add_argument(
+        '--rnn_zero_init', help='Zero Initialise RNN hidden states', action='store_true')
     parser.add_argument('--distance_measure', default='cosine',
                         help='Similarity measure for candidate scoring: cosine/euclidean')
-    parser.add_argument('--ensemble_size', help='Number of models in ensemble', default=-1, type=int)
-    parser.add_argument('--no_set_similarity', action='store_true', help='Set True to not use set similarity')
-    parser.add_argument('--set_pooling',
-                        help='Set pooling method for set similarity model using single embedding distances',
+    parser.add_argument(
+        '--ensemble_size', help='Number of models in ensemble', default=-1, type=int)
+    parser.add_argument('--set_similarity', action='store_true',
+                        help='Set True to use set similarity (Model tracks latent belief state as sequence and performs semantic similarity of sets)')
+    parser.add_argument('--set_pooling', help='Set pooling method for set similarity model using single embedding distances',
                         default='cnn')
-    parser.add_argument('--candidate_pooling',
-                        help='Pooling approach for non set based candidate representations: cls/mean',
+    parser.add_argument('--candidate_pooling', help='Pooling approach for non set based candidate representations: cls/mean',
                         default='mean')
-    parser.add_argument('--no_action_prediction', help='Model does not predicts user actions and active domain',
+    parser.add_argument('--predict_actions', help='Model predicts user actions and active domain',
                         action='store_true')
 
     # Loss
-    parser.add_argument('--loss_function',
-                        help='Loss Function for training: crossentropy/bayesianmatching/labelsmoothing/...',
+    parser.add_argument('--loss_function', help='Loss Function for training: crossentropy/bayesianmatching/labelsmoothing/distillation/distribution_distillation',
                         default='labelsmoothing')
     parser.add_argument('--kl_scaling_factor', help='Scaling factor for KL divergence in bayesian matching loss',
                         type=float)
     parser.add_argument('--prior_constant', help='Constant parameter for prior in bayesian matching loss',
                         type=float)
-    parser.add_argument('--ensemble_smoothing', help='Ensemble distribution smoothing constant', type=float)
-    parser.add_argument('--annealing_base_temp', help='Ensemble Distribution distillation temp annealing base temp',
+    parser.add_argument('--ensemble_smoothing',
+                        help='Ensemble distribution smoothing constant', type=float)
+    parser.add_argument('--annealing_base_temp', help='Ensemble Distribution distillation temp annealing base temp',
                         type=float)
-    parser.add_argument('--annealing_cycle_len', help='Ensemble Distribution distillation temp annealing cycle length',
+    parser.add_argument('--annealing_cycle_len', help='Ensemble Distribution distillation temp annealing cycle length',
                         type=float)
-    parser.add_argument('--label_smoothing', help='Label smoothing coefficient.', type=float)
-    parser.add_argument('--user_goal_loss_weight', help='Weight of the user goal prediction loss. 0.0<weight<=1.0',
-                        type=float)
-    parser.add_argument('--user_request_loss_weight',
-                        help='Weight of the user request prediction loss. 0.0<weight<=1.0', type=float)
-    parser.add_argument('--user_general_act_loss_weight',
-                        help='Weight of the user general act prediction loss. 0.0<weight<=1.0', type=float)
-    parser.add_argument('--active_domain_loss_weight',
-                        help='Weight of the active domain prediction loss. 0.0<weight<=1.0', type=float)
+    parser.add_argument('--inhibiting_factor',
+                        help='Inhibiting factor for Inhibited Softmax CE', type=float)
+    parser.add_argument('--label_smoothing',
+                        help='Label smoothing coefficient.', type=float)
+    parser.add_argument(
+        '--user_goal_loss_weight', help='Weight of the user goal prediction loss. 0.0<weight<=1.0', type=float)
+    parser.add_argument(
+        '--user_request_loss_weight', help='Weight of the user request prediction loss. 0.0<weight<=1.0', type=float)
+    parser.add_argument(
+        '--user_general_act_loss_weight', help='Weight of the user general act prediction loss. 0.0<weight<=1.0', type=float)
+    parser.add_argument(
+        '--active_domain_loss_weight', help='Weight of the active domain prediction loss. 0.0<weight<=1.0', type=float)
 
     # TRAINING
-    parser.add_argument('--train_batch_size', help='Training Set Batch Size', default=8, type=int)
-    parser.add_argument('--max_training_steps', help='Maximum number of training update steps', default=-1, type=int)
+    parser.add_argument('--train_batch_size',
+                        help='Training Set Batch Size', default=4, type=int)
+    parser.add_argument('--max_training_steps', help='Maximum number of training update steps',
+                        default=-1, type=int)
     parser.add_argument('--gradient_accumulation_steps', default=1, type=int,
                         help='Number of batches accumulated for one update step')
-    parser.add_argument('--num_train_epochs', help='Number of training epochs', default=50, type=int)
+    parser.add_argument('--num_train_epochs',
+                        help='Number of training epochs', default=50, type=int)
     parser.add_argument('--patience', help='Number of training steps without improving model before stopping.',
-                        default=20, type=int)
-    parser.add_argument('--weight_decay', help='Weight decay rate', default=0.01, type=float)
-    parser.add_argument('--learning_rate', help='Initial Learning Rate', default=5e-5, type=float)
-    parser.add_argument('--warmup_proportion', help='Warmup proportion for linear scheduler', default=0.2, type=float)
-    parser.add_argument('--max_grad_norm', help='Maximum norm of the loss gradients', default=1.0, type=float)
-    parser.add_argument('--save_steps', help='Number of update steps between saving model', default=-1, type=int)
-    parser.add_argument('--keep_models', help='How many model checkpoints should be kept during training',
-                        default=1, type=int)
+                        default=25, type=int)
+    parser.add_argument(
+        '--weight_decay', help='Weight decay rate', default=0.01, type=float)
+    parser.add_argument('--learning_rate',
+                        help='Initial Learning Rate', default=5e-5, type=float)
+    parser.add_argument('--warmup_proportion', help='Warmup proportion for linear scheduler',
+                        default=0.2, type=float)
+    parser.add_argument(
+        '--max_grad_norm', help='Maximum norm of the loss gradients', default=1.0, type=float)
+    parser.add_argument(
+        '--save_steps', help='Number of update steps between saving model', default=-1, type=int)
+    parser.add_argument(
+        '--keep_models', help='How many model checkpoints should be kept during training', default=1, type=int)
 
     # CALIBRATION
-    parser.add_argument('--temp_scaling', help='Temperature scaling coefficient', default=1.0, type=float)
+    parser.add_argument(
+        '--temp_scaling', help='Temperature scaling coefficient', default=1.0, type=float)
 
     # EVALUATION
-    parser.add_argument('--dev_batch_size', help='Dev Set Batch Size', default=16, type=int)
-    parser.add_argument('--test_batch_size', help='Test Set Batch Size', default=16, type=int)
+    parser.add_argument('--dev_batch_size',
+                        help='Dev Set Batch Size', default=16, type=int)
+    parser.add_argument('--test_batch_size',
+                        help='Test Set Batch Size', default=16, type=int)
 
     # COMPUTING
-    parser.add_argument('--n_gpu', help='Number of GPUs to use', default=1, type=int)
+    parser.add_argument(
+        '--n_gpu', help='Number of GPUs to use', default=1, type=int)
     parser.add_argument('--fp16', action='store_true',
                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
     parser.add_argument('--fp16_opt_level', type=str, default='O1',
@@ -132,29 +167,32 @@ def get_args(base_models: dict):
                              "See details at https://nvidia.github.io/apex/amp.html")
 
     # ACTIONS
-    parser.add_argument('--run_nbt', help='Run NBT script',action='store_true')
-    parser.add_argument('--run_evaluation', help='Run evaluation script', action='store_true')
+    parser.add_argument('--run_nbt', help='Run NBT script',
+                        action='store_true')
+    parser.add_argument('--run_calibration',
+                        help='Run calibration', action='store_true')
 
     # RUN_NBT ACTIONS
-    parser.add_argument('--do_train', help='Perform training', action='store_true')
-    parser.add_argument('--do_eval', help='Perform model evaluation during training', action='store_true')
-    parser.add_argument('--do_test', help='Evaluate model on test data', action='store_true')
+    parser.add_argument(
+        '--do_train', help='Perform training', action='store_true')
+    parser.add_argument(
+        '--do_eval', help='Perform model evaluation during training', action='store_true')
+    parser.add_argument(
+        '--do_test', help='Evaluate model on test data', action='store_true')
     args = parser.parse_args()
 
-    # Simplify args
-    args.set_similarity = not args.no_set_similarity
-    args.use_descriptions = not args.no_descriptions
-    args.predict_actions = not args.no_action_prediction
-
     # Setup default directories
+    if not args.data_dir:
+        args.data_dir = os.path.dirname(os.path.abspath(__file__))
+        args.data_dir = os.path.join(args.data_dir, 'data')
+        os.makedirs(args.data_dir, exist_ok=True)
+
     if not args.output_dir:
         args.output_dir = os.path.dirname(os.path.abspath(__file__))
         args.output_dir = os.path.join(args.output_dir, 'models')
 
-        name = 'SetSUMBT' if args.set_similarity else 'SUMBT'
-        name += '+ActPrediction' if args.predict_actions else ''
-        name += '-' + args.dataset
-        name += '-' + str(round(args.dataset_train_ratio*100)) + '%' if args.dataset_train_ratio != 1.0 else ''
+        name = 'SetSUMBT'
+        name += '-Acts' if args.predict_actions else ''
         name += '-' + args.model_type
         name += '-' + args.nbt_type
         name += '-' + args.distance_measure
@@ -170,6 +208,9 @@ def get_args(base_models: dict):
             args.kl_scaling_factor = 0.001
         if not args.prior_constant:
             args.prior_constant = 1.0
+    if args.loss_function == 'inhibitedce':
+        if not args.inhibiting_factor:
+            args.inhibiting_factor = 1.0
     if args.loss_function == 'labelsmoothing':
         if not args.label_smoothing:
             args.label_smoothing = 0.05
@@ -192,8 +233,10 @@ def get_args(base_models: dict):
         if not args.active_domain_loss_weight:
             args.active_domain_loss_weight = 0.2
 
-    args.tensorboard_path = args.tensorboard_path if args.tensorboard_path else os.path.join(args.output_dir, 'tb_logs')
-    args.logging_path = args.logging_path if args.logging_path else os.path.join(args.output_dir, 'run.log')
+    args.tensorboard_path = args.tensorboard_path if args.tensorboard_path else os.path.join(
+        args.output_dir, 'tb_logs')
+    args.logging_path = args.logging_path if args.logging_path else os.path.join(
+        args.output_dir, 'run.log')
 
     # Default model_name's
     if not args.model_name_or_path:
@@ -207,22 +250,28 @@ def get_args(base_models: dict):
     if not args.candidate_embedding_model_name:
         args.candidate_embedding_model_name = args.model_name_or_path
 
-    if args.model_type in base_models:
-        config_class = base_models[args.model_type][-2]
+    if args.model_type in MODELS:
+        configClass = MODELS[args.model_type][-2]
     else:
         raise NameError('NotImplemented')
-    config = build_config(config_class, args)
+    config = build_config(configClass, args)
     return args, config
 
 
-def build_config(config_class, args):
-    config = config_class.from_pretrained(args.model_name_or_path)
-    if not os.path.exists(args.model_name_or_path):
+def build_config(configClass, args):
+    if args.model_type == 'fasttext':
+        config = configClass.from_pretrained('bert-base-uncased')
+        config.model_type = 'fasttext'  # assign (not compare) so the config is tagged as fasttext
+        config.fasttext_path = args.model_name_or_path
+        config.vocab_size = None
+    elif not os.path.exists(args.model_name_or_path):
+        config = configClass.from_pretrained(args.model_name_or_path)
         config.tokenizer_name = args.model_name_or_path
-    try:
-        config.tokenizer_name = config.tokenizer_name
-    except AttributeError:
+    elif 'tod-bert' in args.model_name_or_path.lower():
+        config = configClass.from_pretrained(args.model_name_or_path)
         config.tokenizer_name = args.model_name_or_path
+    else:
+        config = configClass.from_pretrained(args.model_name_or_path)
     if args.candidate_embedding_model_name:
         config.candidate_embedding_model_name = args.candidate_embedding_model_name
     config.max_dialogue_len = args.max_dialogue_len
diff --git a/convlab/dst/trippy/__init__.py b/convlab/dst/trippy/__init__.py
deleted file mode 100755
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/convlab/dst/trippy/multiwoz/__init__.py b/convlab/dst/trippy/multiwoz/__init__.py
deleted file mode 100644
index d432596abfaf8c0410b2b520cd59725b92de8932..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from convlab.dst.trippy.multiwoz.trippy import TRIPPY
diff --git a/convlab/dst/trippy/multiwoz/__init__.py~ b/convlab/dst/trippy/multiwoz/__init__.py~
deleted file mode 100644
index fed1a0ba2130022910438c88074a0791f5c795fe..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/__init__.py~
+++ /dev/null
@@ -1 +0,0 @@
-from convlab2.dst.trippy.multiwoz.trippy import TRIPPY
diff --git a/convlab/dst/trippy/multiwoz/modeling_bert_dst.prev.py b/convlab/dst/trippy/multiwoz/modeling_bert_dst.prev.py
deleted file mode 100644
index 1b39f7010516ddf72e326f07864434502cea389d..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/modeling_bert_dst.prev.py
+++ /dev/null
@@ -1,213 +0,0 @@
-# coding=utf-8
-#
-# Copyright 2020 Heinrich Heine University Duesseldorf
-#
-# Part of this code is based on the source code of BERT-DST
-# (arXiv:1907.03040)
-# Part of this code is based on the source code of Transformers
-# (arXiv:1910.03771)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss
-
-#from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable)
-#from transformers.modeling_bert import (BertModel, BertPreTrainedModel, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
-from transformers import (BertModel, BertConfig, BertPreTrainedModel)
-
-#@add_start_docstrings(
-#    """BERT Model with a classification heads for the DST task. """,
-#    BERT_START_DOCSTRING,
-#)
-class BertForDST(BertPreTrainedModel):
-    def __init__(self, config):
-        super(BertForDST, self).__init__(config)
-        self.slot_list = config.dst_slot_list
-        self.class_types = config.dst_class_types
-        self.class_labels = config.dst_class_labels
-        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
-        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
-        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
-        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
-        self.class_loss_ratio = config.dst_class_loss_ratio
-
-        # Only use refer loss if refer class is present in dataset.
-        if 'refer' in self.class_types:
-            self.refer_index = self.class_types.index('refer')
-        else:
-            self.refer_index = -1
-
-        self.bert = BertModel(config)
-        self.dropout = nn.Dropout(config.dst_dropout_rate)
-        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
-
-        if self.class_aux_feats_inform:
-            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-        if self.class_aux_feats_ds:
-            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-
-        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
-
-        for slot in self.slot_list:
-            self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
-            self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
-            self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
-
-        # Head for aux task
-        if hasattr(config, "aux_task_def"):
-            self.add_module("aux_out_projection", nn.Linear(config.hidden_size, int(config.aux_task_def['n_class'])))
-
-        self.init_weights()
-
-    #@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
-    def forward(self,
-                input_ids,
-                input_mask=None,
-                segment_ids=None,
-                position_ids=None,
-                head_mask=None,
-                start_pos=None,
-                end_pos=None,
-                inform_slot_id=None,
-                refer_id=None,
-                class_label_id=None,
-                diag_state=None,
-                aux_task_def=None):
-        outputs = self.bert(
-            input_ids,
-            attention_mask=input_mask,
-            token_type_ids=segment_ids,
-            position_ids=position_ids,
-            head_mask=head_mask
-        )
-
-        sequence_output = outputs[0]
-        pooled_output = outputs[1]
-
-        sequence_output = self.dropout(sequence_output)
-        pooled_output = self.dropout(pooled_output)
-
-        if aux_task_def is not None:
-            if aux_task_def['task_type'] == "classification":
-                aux_logits = getattr(self, 'aux_out_projection')(pooled_output)
-                aux_logits = self.dropout_heads(aux_logits)
-                aux_loss_fct = CrossEntropyLoss()
-                aux_loss = aux_loss_fct(aux_logits, class_label_id)
-                # add hidden states and attention if they are here
-                return (aux_loss,) + outputs[2:]
-            elif aux_task_def['task_type'] == "span":
-                aux_logits = getattr(self, 'aux_out_projection')(sequence_output)
-                aux_start_logits, aux_end_logits = aux_logits.split(1, dim=-1)
-                aux_start_logits = self.dropout_heads(aux_start_logits)
-                aux_end_logits = self.dropout_heads(aux_end_logits)
-                aux_start_logits = aux_start_logits.squeeze(-1)
-                aux_end_logits = aux_end_logits.squeeze(-1)
-
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos.size()) > 1:
-                    start_pos = start_pos.squeeze(-1)
-                if len(end_pos.size()) > 1:
-                    end_pos = end_pos.squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = aux_start_logits.size(1) # This is a single index
-                start_pos.clamp_(0, ignored_index)
-                end_pos.clamp_(0, ignored_index)
-            
-                aux_token_loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-                aux_start_loss = aux_token_loss_fct(torch.cat((aux_start_logits, aux_end_logits), 1), start_pos)
-                aux_end_loss = aux_token_loss_fct(torch.cat((aux_end_logits, aux_start_logits), 1), end_pos)
-                aux_loss = (aux_start_loss + aux_end_loss) / 2.0
-                return (aux_loss,) + outputs[2:]
-            else:
-                raise Exception("Unknown task_type")
-
-        # TODO: establish proper format in labels already?
-        if inform_slot_id is not None:
-            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
-        if diag_state is not None:
-            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
-        
-        total_loss = 0
-        per_slot_per_example_loss = {}
-        per_slot_class_logits = {}
-        per_slot_start_logits = {}
-        per_slot_end_logits = {}
-        per_slot_refer_logits = {}
-        for slot in self.slot_list:
-            if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-            elif self.class_aux_feats_inform:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-            elif self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-            else:
-                pooled_output_aux = pooled_output
-            class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
-
-            token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
-            start_logits, end_logits = token_logits.split(1, dim=-1)
-            start_logits = start_logits.squeeze(-1)
-            end_logits = end_logits.squeeze(-1)
-
-            refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
-
-            per_slot_class_logits[slot] = class_logits
-            per_slot_start_logits[slot] = start_logits
-            per_slot_end_logits[slot] = end_logits
-            per_slot_refer_logits[slot] = refer_logits
-            
-            # If there are no labels, don't compute loss
-            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos[slot].size()) > 1:
-                    start_pos[slot] = start_pos[slot].squeeze(-1)
-                if len(end_pos[slot].size()) > 1:
-                    end_pos[slot] = end_pos[slot].squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = start_logits.size(1) # This is a single index
-                start_pos[slot].clamp_(0, ignored_index)
-                end_pos[slot].clamp_(0, ignored_index)
-
-                class_loss_fct = CrossEntropyLoss(reduction='none')
-                token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
-                refer_loss_fct = CrossEntropyLoss(reduction='none')
-
-                start_loss = token_loss_fct(start_logits, start_pos[slot])
-                end_loss = token_loss_fct(end_logits, end_pos[slot])
-                token_loss = (start_loss + end_loss) / 2.0
-
-                token_is_pointable = (start_pos[slot] > 0).float()
-                if not self.token_loss_for_nonpointable:
-                    token_loss *= token_is_pointable
-
-                refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
-                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
-                if not self.refer_loss_for_nonpointable:
-                    refer_loss *= token_is_referrable
-
-                class_loss = class_loss_fct(class_logits, class_label_id[slot])
-
-                if self.refer_index > -1:
-                    per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
-                else:
-                    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
-
-                total_loss += per_example_loss.sum()
-                per_slot_per_example_loss[slot] = per_example_loss
-
-        # add hidden states and attention if they are here
-        outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + (pooled_output,) + outputs[2:]
-
-        return outputs
diff --git a/convlab/dst/trippy/multiwoz/modeling_bert_dst.py b/convlab/dst/trippy/multiwoz/modeling_bert_dst.py
deleted file mode 100644
index 8dc899344dc7e17883096997782ea2f7bf85d0f6..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/modeling_bert_dst.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# coding=utf-8
-#
-# Copyright 2020 Heinrich Heine University Duesseldorf
-#
-# Part of this code is based on the source code of BERT-DST
-# (arXiv:1907.03040)
-# Part of this code is based on the source code of Transformers
-# (arXiv:1910.03771)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss
-
-from transformers import (BertModel, BertPreTrainedModel)
-
-
-class BertForDST(BertPreTrainedModel):
-    def __init__(self, config):
-        super(BertForDST, self).__init__(config)
-        self.slot_list = config.dst_slot_list
-        self.class_types = config.dst_class_types
-        self.class_labels = config.dst_class_labels
-        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
-        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
-        self.stack_token_logits = config.dst_stack_token_logits
-        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
-        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
-        self.class_loss_ratio = config.dst_class_loss_ratio
-
-        # Only use refer loss if refer class is present in dataset.
-        if 'refer' in self.class_types:
-            self.refer_index = self.class_types.index('refer')
-        else:
-            self.refer_index = -1
-
-        self.bert = BertModel(config)
-        self.dropout = nn.Dropout(config.dst_dropout_rate)
-        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
-
-        if self.class_aux_feats_inform:
-            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-        if self.class_aux_feats_ds:
-            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-
-        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
-
-        for slot in self.slot_list:
-            self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
-            self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
-            self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
-
-        self.init_weights()
-
-    def forward(self,
-                input_ids,
-                input_mask=None,
-                segment_ids=None,
-                position_ids=None,
-                head_mask=None,
-                start_pos=None,
-                end_pos=None,
-                inform_slot_id=None,
-                refer_id=None,
-                class_label_id=None,
-                diag_state=None):
-        outputs = self.bert(
-            input_ids,
-            attention_mask=input_mask,
-            token_type_ids=segment_ids,
-            position_ids=position_ids,
-            head_mask=head_mask
-        )
-
-        sequence_output = outputs[0]
-        pooled_output = outputs[1]
-
-        sequence_output = self.dropout(sequence_output)
-        pooled_output = self.dropout(pooled_output)
-
-        # TODO: establish proper format in labels already?
-        if inform_slot_id is not None:
-            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
-        if diag_state is not None:
-            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
-        
-        total_loss = 0
-        per_slot_per_example_loss = {}
-        per_slot_class_logits = {}
-        per_slot_start_logits = {}
-        per_slot_end_logits = {}
-        per_slot_refer_logits = {}
-        for slot in self.slot_list:
-            if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-            elif self.class_aux_feats_inform:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-            elif self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-            else:
-                pooled_output_aux = pooled_output
-            class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
-
-            token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
-            start_logits, end_logits = token_logits.split(1, dim=-1)
-            start_logits = start_logits.squeeze(-1)
-            end_logits = end_logits.squeeze(-1)
-
-            refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
-
-            per_slot_class_logits[slot] = class_logits
-            per_slot_start_logits[slot] = start_logits
-            per_slot_end_logits[slot] = end_logits
-            per_slot_refer_logits[slot] = refer_logits
-            
-            # If there are no labels, don't compute loss
-            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos[slot].size()) > 1:
-                    start_pos[slot] = start_pos[slot].squeeze(-1)
-                if len(end_pos[slot].size()) > 1:
-                    end_pos[slot] = end_pos[slot].squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = start_logits.size(1) # This is a single index
-                start_pos[slot].clamp_(0, ignored_index)
-                end_pos[slot].clamp_(0, ignored_index)
-
-                class_loss_fct = CrossEntropyLoss(reduction='none')
-                token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
-                refer_loss_fct = CrossEntropyLoss(reduction='none')
-
-                if not self.stack_token_logits:
-                    start_loss = token_loss_fct(start_logits, start_pos[slot])
-                    end_loss = token_loss_fct(end_logits, end_pos[slot])
-                else:
-                    start_loss = token_loss_fct(torch.cat((start_logits, end_logits), 1), start_pos[slot])
-                    end_loss = token_loss_fct(torch.cat((end_logits, start_logits), 1), end_pos[slot])
-
-                token_loss = (start_loss + end_loss) / 2.0
-
-                token_is_pointable = (start_pos[slot] > 0).float()
-                if not self.token_loss_for_nonpointable:
-                    token_loss *= token_is_pointable
-
-                refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
-                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
-                if not self.refer_loss_for_nonpointable:
-                    refer_loss *= token_is_referrable
-
-                class_loss = class_loss_fct(class_logits, class_label_id[slot])
-
-                if self.refer_index > -1:
-                    per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
-                else:
-                    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
-
-                total_loss += per_example_loss.sum()
-                per_slot_per_example_loss[slot] = per_example_loss
-
-        # add hidden states and attention if they are here
-        outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + (pooled_output,) + outputs[2:]
-
-        return outputs
diff --git a/convlab/dst/trippy/multiwoz/modeling_bert_dst.py~ b/convlab/dst/trippy/multiwoz/modeling_bert_dst.py~
deleted file mode 100644
index 809c26c5b449d2c800d8407e6dc2c561ff253d33..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/modeling_bert_dst.py~
+++ /dev/null
@@ -1,174 +0,0 @@
-# coding=utf-8
-#
-# Copyright 2020 Heinrich Heine University Duesseldorf
-#
-# Part of this code is based on the source code of BERT-DST
-# (arXiv:1907.03040)
-# Part of this code is based on the source code of Transformers
-# (arXiv:1910.03771)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss
-
-from transformers import (BertModel, BertPreTrainedModel)
-
-
-class BertForDST(BertPreTrainedModel):
-    def __init__(self, config):
-        super(BertForDST, self).__init__(config)
-        self.slot_list = config.dst_slot_list
-        self.class_types = config.dst_class_types
-        self.class_labels = config.dst_class_labels
-        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
-        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
-        self.stack_token_logits = config.dst_stack_token_logits
-        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
-        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
-        self.class_loss_ratio = config.dst_class_loss_ratio
-
-        # Only use refer loss if refer class is present in dataset.
-        if 'refer' in self.class_types:
-            self.refer_index = self.class_types.index('refer')
-        else:
-            self.refer_index = -1
-
-        self.bert = BertModel(config)
-        self.dropout = nn.Dropout(config.dst_dropout_rate)
-        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
-
-        if self.class_aux_feats_inform:
-            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-        if self.class_aux_feats_ds:
-            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-
-        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
-
-        for slot in self.slot_list:
-            self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
-            self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
-            self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
-
-        self.init_weights()
-
-    def forward(self,
-                input_ids,
-                input_mask=None,
-                segment_ids=None,
-                position_ids=None,
-                head_mask=None,
-                start_pos=None,
-                end_pos=None,
-                inform_slot_id=None,
-                refer_id=None,
-                class_label_id=None,
-                diag_state=None):
-        outputs = self.bert(
-            input_ids,
-            attention_mask=input_mask,
-            token_type_ids=segment_ids,
-            position_ids=position_ids,
-            head_mask=head_mask
-        )
-
-        sequence_output = outputs[0]
-        pooled_output = outputs[1]
-
-        sequence_output = self.dropout(sequence_output)
-        pooled_output = self.dropout(pooled_output)
-
-        # TODO: establish proper format in labels already?
-        if inform_slot_id is not None:
-            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
-        if diag_state is not None:
-            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
-        
-        total_loss = 0
-        per_slot_per_example_loss = {}
-        per_slot_class_logits = {}
-        per_slot_start_logits = {}
-        per_slot_end_logits = {}
-        per_slot_refer_logits = {}
-        for slot in self.slot_list:
-            if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-            elif self.class_aux_feats_inform:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-            elif self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-            else:
-                pooled_output_aux = pooled_output
-            class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
-
-            token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
-            start_logits, end_logits = token_logits.split(1, dim=-1)
-            start_logits = start_logits.squeeze(-1)
-            end_logits = end_logits.squeeze(-1)
-
-            refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
-
-            per_slot_class_logits[slot] = class_logits
-            per_slot_start_logits[slot] = start_logits
-            per_slot_end_logits[slot] = end_logits
-            per_slot_refer_logits[slot] = refer_logits
-            
-            # If there are no labels, don't compute loss
-            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos[slot].size()) > 1:
-                    start_pos[slot] = start_pos[slot].squeeze(-1)
-                if len(end_pos[slot].size()) > 1:
-                    end_pos[slot] = end_pos[slot].squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = start_logits.size(1) # This is a single index
-                start_pos[slot].clamp_(0, ignored_index)
-                end_pos[slot].clamp_(0, ignored_index)
-
-                class_loss_fct = CrossEntropyLoss(reduction='none')
-                token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
-                refer_loss_fct = CrossEntropyLoss(reduction='none')
-
-                if not self.stack_token_logits:
-                    start_loss = token_loss_fct(start_logits, start_pos[slot])
-                    end_loss = token_loss_fct(end_logits, end_pos[slot])
-                else:
-                    start_loss = token_loss_fct(torch.cat((start_logits, end_logits), 1), start_pos[slot])
-                    end_loss = token_loss_fct(torch.cat((end_logits, start_logits), 1), end_pos[slot])
-
-                token_loss = (start_loss + end_loss) / 2.0
-
-                token_is_pointable = (start_pos[slot] > 0).float()
-                if not self.token_loss_for_nonpointable:
-                    token_loss *= token_is_pointable
-
-                refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
-                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
-                if not self.refer_loss_for_nonpointable:
-                    refer_loss *= token_is_referrable
-
-                class_loss = class_loss_fct(class_logits, class_label_id[slot])
-
-                if self.refer_index > -1:
-                    per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
-                else:
-                    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
-
-                total_loss += per_example_loss.sum()
-                per_slot_per_example_loss[slot] = per_example_loss
-
-        # add hidden states and attention if they are here
-        outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]
-
-        return outputs
diff --git a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.prev.py b/convlab/dst/trippy/multiwoz/modeling_roberta_dst.prev.py
deleted file mode 100644
index 5faf311318714007421242467d2579dfed99bad9..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.prev.py
+++ /dev/null
@@ -1,237 +0,0 @@
-# coding=utf-8
-#
-# Copyright 2020 Heinrich Heine University Duesseldorf
-#
-# Part of this code is based on the source code of BERT-DST
-# (arXiv:1907.03040)
-# Part of this code is based on the source code of Transformers
-# (arXiv:1910.03771)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss
-
-#from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable)
-#from transformers.modeling_utils import (PreTrainedModel)
-#from transformers.modeling_roberta import (RobertaModel, RobertaConfig, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
-#                                           ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING, BertLayerNorm)
-from transformers import (RobertaModel, RobertaConfig, RobertaPreTrainedModel)
-
-
-#class RobertaPreTrainedModel(PreTrainedModel):
-#    """ An abstract class to handle weights initialization and
-#        a simple interface for dowloading and loading pretrained models.
-#    """
-#    config_class = RobertaConfig
-#    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
-#    base_model_prefix = "roberta"
-#    
-#    def _init_weights(self, module):
-#        """ Initialize the weights """
-#        if isinstance(module, (nn.Linear, nn.Embedding)):
-#            # Slightly different from the TF version which uses truncated_normal for initialization
-#            # cf https://github.com/pytorch/pytorch/pull/5617
-#            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-#        elif isinstance(module, BertLayerNorm):
-#            module.bias.data.zero_()
-#            module.weight.data.fill_(1.0)
-#        if isinstance(module, nn.Linear) and module.bias is not None:
-#            module.bias.data.zero_()
-
-
-#@add_start_docstrings(
-#    """RoBERTa Model with classification heads for the DST task. """,
-#    ROBERTA_START_DOCSTRING,
-#)
-class RobertaForDST(RobertaPreTrainedModel):
-    def __init__(self, config):
-        super(RobertaForDST, self).__init__(config)
-        self.slot_list = config.dst_slot_list
-        self.class_types = config.dst_class_types
-        self.class_labels = config.dst_class_labels
-        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
-        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
-        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
-        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
-        self.class_loss_ratio = config.dst_class_loss_ratio
-
-        # Only use refer loss if refer class is present in dataset.
-        if 'refer' in self.class_types:
-            self.refer_index = self.class_types.index('refer')
-        else:
-            self.refer_index = -1
-
-        self.roberta = RobertaModel(config)
-        self.dropout = nn.Dropout(config.dst_dropout_rate)
-        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
-
-        if self.class_aux_feats_inform:
-            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-        if self.class_aux_feats_ds:
-            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-
-        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
-
-        for slot in self.slot_list:
-            self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
-            self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
-            self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
-
-        # Head for aux task
-        if hasattr(config, "aux_task_def"):
-            self.add_module("aux_out_projection", nn.Linear(config.hidden_size, int(config.aux_task_def['n_class'])))
-
-        self.init_weights()
-
-    #@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
-    def forward(self,
-                input_ids,
-                input_mask=None,
-                segment_ids=None,
-                position_ids=None,
-                head_mask=None,
-                start_pos=None,
-                end_pos=None,
-                inform_slot_id=None,
-                refer_id=None,
-                class_label_id=None,
-                diag_state=None,
-                aux_task_def=None):
-        outputs = self.roberta(
-            input_ids,
-            attention_mask=input_mask,
-            token_type_ids=segment_ids,
-            position_ids=position_ids,
-            head_mask=head_mask
-        )
-
-        sequence_output = outputs[0]
-        pooled_output = outputs[1]
-
-        sequence_output = self.dropout(sequence_output)
-        pooled_output = self.dropout(pooled_output)
-
-        if aux_task_def is not None:
-            if aux_task_def['task_type'] == "classification":
-                aux_logits = getattr(self, 'aux_out_projection')(pooled_output)
-                aux_logits = self.dropout_heads(aux_logits)
-                aux_loss_fct = CrossEntropyLoss()
-                aux_loss = aux_loss_fct(aux_logits, class_label_id)
-                # add hidden states and attention if they are here
-                return (aux_loss,) + outputs[2:]
-            elif aux_task_def['task_type'] == "span":
-                aux_logits = getattr(self, 'aux_out_projection')(sequence_output)
-                aux_start_logits, aux_end_logits = aux_logits.split(1, dim=-1)
-                aux_start_logits = self.dropout_heads(aux_start_logits)
-                aux_end_logits = self.dropout_heads(aux_end_logits)
-                aux_start_logits = aux_start_logits.squeeze(-1)
-                aux_end_logits = aux_end_logits.squeeze(-1)
-
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos.size()) > 1:
-                    start_pos = start_pos.squeeze(-1)
-                if len(end_pos.size()) > 1:
-                    end_pos = end_pos.squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = aux_start_logits.size(1) # This is a single index
-                start_pos.clamp_(0, ignored_index)
-                end_pos.clamp_(0, ignored_index)
-            
-                aux_token_loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-                aux_start_loss = aux_token_loss_fct(torch.cat((aux_start_logits, aux_end_logits), 1), start_pos)
-                aux_end_loss = aux_token_loss_fct(torch.cat((aux_end_logits, aux_start_logits), 1), end_pos)
-                aux_loss = (aux_start_loss + aux_end_loss) / 2.0
-                return (aux_loss,) + outputs[2:]
-            else:
-                raise Exception("Unknown task_type")
-
-        # TODO: establish proper format in labels already?
-        if inform_slot_id is not None:
-            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
-        if diag_state is not None:
-            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
-        
-        total_loss = 0
-        per_slot_per_example_loss = {}
-        per_slot_class_logits = {}
-        per_slot_start_logits = {}
-        per_slot_end_logits = {}
-        per_slot_refer_logits = {}
-        for slot in self.slot_list:
-            if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-            elif self.class_aux_feats_inform:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-            elif self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-            else:
-                pooled_output_aux = pooled_output
-            class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
-
-            token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
-            start_logits, end_logits = token_logits.split(1, dim=-1)
-            start_logits = start_logits.squeeze(-1)
-            end_logits = end_logits.squeeze(-1)
-
-            refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
-
-            per_slot_class_logits[slot] = class_logits
-            per_slot_start_logits[slot] = start_logits
-            per_slot_end_logits[slot] = end_logits
-            per_slot_refer_logits[slot] = refer_logits
-            
-            # If there are no labels, don't compute loss
-            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos[slot].size()) > 1:
-                    start_pos[slot] = start_pos[slot].squeeze(-1)
-                if len(end_pos[slot].size()) > 1:
-                    end_pos[slot] = end_pos[slot].squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = start_logits.size(1) # This is a single index
-                start_pos[slot].clamp_(0, ignored_index)
-                end_pos[slot].clamp_(0, ignored_index)
-
-                class_loss_fct = CrossEntropyLoss(reduction='none')
-                token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
-                refer_loss_fct = CrossEntropyLoss(reduction='none')
-
-                start_loss = token_loss_fct(start_logits, start_pos[slot])
-                end_loss = token_loss_fct(end_logits, end_pos[slot])
-                token_loss = (start_loss + end_loss) / 2.0
-
-                token_is_pointable = (start_pos[slot] > 0).float()
-                if not self.token_loss_for_nonpointable:
-                    token_loss *= token_is_pointable
-
-                refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
-                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
-                if not self.refer_loss_for_nonpointable:
-                    refer_loss *= token_is_referrable
-
-                class_loss = class_loss_fct(class_logits, class_label_id[slot])
-
-                if self.refer_index > -1:
-                    per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
-                else:
-                    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
-
-                total_loss += per_example_loss.sum()
-                per_slot_per_example_loss[slot] = per_example_loss
-
-        # add hidden states and attention if they are here
-        outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + (pooled_output,) + outputs[2:]
-
-        return outputs
diff --git a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py b/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py
deleted file mode 100644
index f4c2d773f80551965d0181c97c5fba64bbfb06b1..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# coding=utf-8
-#
-# Copyright 2020 Heinrich Heine University Duesseldorf
-#
-# Part of this code is based on the source code of BERT-DST
-# (arXiv:1907.03040)
-# Part of this code is based on the source code of Transformers
-# (arXiv:1910.03771)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss
-
-from transformers import (RobertaModel, RobertaConfig, RobertaPreTrainedModel)
-
-
-class RobertaForDST(RobertaPreTrainedModel):
-    def __init__(self, config):
-        super(RobertaForDST, self).__init__(config)
-        self.slot_list = config.dst_slot_list
-        self.class_types = config.dst_class_types
-        self.class_labels = config.dst_class_labels
-        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
-        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
-        self.stack_token_logits = config.dst_stack_token_logits
-        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
-        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
-        self.class_loss_ratio = config.dst_class_loss_ratio
-
-        # Only use refer loss if refer class is present in dataset.
-        if 'refer' in self.class_types:
-            self.refer_index = self.class_types.index('refer')
-        else:
-            self.refer_index = -1
-
-        self.roberta = RobertaModel(config)
-        self.dropout = nn.Dropout(config.dst_dropout_rate)
-        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
-
-        if self.class_aux_feats_inform:
-            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-        if self.class_aux_feats_ds:
-            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-
-        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
-
-        for slot in self.slot_list:
-            self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
-            self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
-            self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
-
-        self.init_weights()
-
-    def forward(self,
-                input_ids,
-                input_mask=None,
-                segment_ids=None,
-                position_ids=None,
-                head_mask=None,
-                start_pos=None,
-                end_pos=None,
-                inform_slot_id=None,
-                refer_id=None,
-                class_label_id=None,
-                diag_state=None):
-        outputs = self.roberta(
-            input_ids,
-            attention_mask=input_mask,
-            token_type_ids=segment_ids,
-            position_ids=position_ids,
-            head_mask=head_mask
-        )
-
-        sequence_output = outputs[0]
-        pooled_output = outputs[1]
-
-        sequence_output = self.dropout(sequence_output)
-        pooled_output = self.dropout(pooled_output)
-
-        # TODO: establish proper format in labels already?
-        if inform_slot_id is not None:
-            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
-        if diag_state is not None:
-            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
-        
-        total_loss = 0
-        per_slot_per_example_loss = {}
-        per_slot_class_logits = {}
-        per_slot_start_logits = {}
-        per_slot_end_logits = {}
-        per_slot_refer_logits = {}
-        for slot in self.slot_list:
-            if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-            elif self.class_aux_feats_inform:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-            elif self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-            else:
-                pooled_output_aux = pooled_output
-            class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
-
-            token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
-            start_logits, end_logits = token_logits.split(1, dim=-1)
-            start_logits = start_logits.squeeze(-1)
-            end_logits = end_logits.squeeze(-1)
-
-            refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
-
-            per_slot_class_logits[slot] = class_logits
-            per_slot_start_logits[slot] = start_logits
-            per_slot_end_logits[slot] = end_logits
-            per_slot_refer_logits[slot] = refer_logits
-            
-            # If there are no labels, don't compute loss
-            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos[slot].size()) > 1:
-                    start_pos[slot] = start_pos[slot].squeeze(-1)
-                if len(end_pos[slot].size()) > 1:
-                    end_pos[slot] = end_pos[slot].squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = start_logits.size(1) # This is a single index
-                start_pos[slot].clamp_(0, ignored_index)
-                end_pos[slot].clamp_(0, ignored_index)
-
-                class_loss_fct = CrossEntropyLoss(reduction='none')
-                token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
-                refer_loss_fct = CrossEntropyLoss(reduction='none')
-
-                if not self.stack_token_logits:
-                    start_loss = token_loss_fct(start_logits, start_pos[slot])
-                    end_loss = token_loss_fct(end_logits, end_pos[slot])
-                else:
-                    start_loss = token_loss_fct(torch.cat((start_logits, end_logits), 1), start_pos[slot])
-                    end_loss = token_loss_fct(torch.cat((end_logits, start_logits), 1), end_pos[slot])
-
-                token_loss = (start_loss + end_loss) / 2.0
-
-                token_is_pointable = (start_pos[slot] > 0).float()
-                if not self.token_loss_for_nonpointable:
-                    token_loss *= token_is_pointable
-
-                refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
-                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
-                if not self.refer_loss_for_nonpointable:
-                    refer_loss *= token_is_referrable
-
-                class_loss = class_loss_fct(class_logits, class_label_id[slot])
-
-                if self.refer_index > -1:
-                    per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
-                else:
-                    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
-
-                total_loss += per_example_loss.sum()
-                per_slot_per_example_loss[slot] = per_example_loss
-
-        # add hidden states and attention if they are here
-        outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + (pooled_output,) + outputs[2:]
-
-        return outputs
diff --git a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py~ b/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py~
deleted file mode 100644
index cdc2996f727c30bfa7b66e1b8bea2af28514a5a5..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py~
+++ /dev/null
@@ -1,174 +0,0 @@
-# coding=utf-8
-#
-# Copyright 2020 Heinrich Heine University Duesseldorf
-#
-# Part of this code is based on the source code of BERT-DST
-# (arXiv:1907.03040)
-# Part of this code is based on the source code of Transformers
-# (arXiv:1910.03771)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss
-
-from transformers import (RobertaModel, RobertaConfig, RobertaPreTrainedModel)
-
-
-class RobertaForDST(RobertaPreTrainedModel):
-    def __init__(self, config):
-        super(RobertaForDST, self).__init__(config)
-        self.slot_list = config.dst_slot_list
-        self.class_types = config.dst_class_types
-        self.class_labels = config.dst_class_labels
-        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
-        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
-        self.stack_token_logits = config.dst_stack_token_logits
-        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
-        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
-        self.class_loss_ratio = config.dst_class_loss_ratio
-
-        # Only use refer loss if refer class is present in dataset.
-        if 'refer' in self.class_types:
-            self.refer_index = self.class_types.index('refer')
-        else:
-            self.refer_index = -1
-
-        self.roberta = RobertaModel(config)
-        self.dropout = nn.Dropout(config.dst_dropout_rate)
-        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
-
-        if self.class_aux_feats_inform:
-            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-        if self.class_aux_feats_ds:
-            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-
-        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
-
-        for slot in self.slot_list:
-            self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
-            self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
-            self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
-
-        self.init_weights()
-
-    def forward(self,
-                input_ids,
-                input_mask=None,
-                segment_ids=None,
-                position_ids=None,
-                head_mask=None,
-                start_pos=None,
-                end_pos=None,
-                inform_slot_id=None,
-                refer_id=None,
-                class_label_id=None,
-                diag_state=None):
-        outputs = self.roberta(
-            input_ids,
-            attention_mask=input_mask,
-            token_type_ids=segment_ids,
-            position_ids=position_ids,
-            head_mask=head_mask
-        )
-
-        sequence_output = outputs[0]
-        pooled_output = outputs[1]
-
-        sequence_output = self.dropout(sequence_output)
-        pooled_output = self.dropout(pooled_output)
-
-        # TODO: establish proper format in labels already?
-        if inform_slot_id is not None:
-            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
-        if diag_state is not None:
-            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
-        
-        total_loss = 0
-        per_slot_per_example_loss = {}
-        per_slot_class_logits = {}
-        per_slot_start_logits = {}
-        per_slot_end_logits = {}
-        per_slot_refer_logits = {}
-        for slot in self.slot_list:
-            if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-            elif self.class_aux_feats_inform:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-            elif self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-            else:
-                pooled_output_aux = pooled_output
-            class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
-
-            token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
-            start_logits, end_logits = token_logits.split(1, dim=-1)
-            start_logits = start_logits.squeeze(-1)
-            end_logits = end_logits.squeeze(-1)
-
-            refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
-
-            per_slot_class_logits[slot] = class_logits
-            per_slot_start_logits[slot] = start_logits
-            per_slot_end_logits[slot] = end_logits
-            per_slot_refer_logits[slot] = refer_logits
-            
-            # If there are no labels, don't compute loss
-            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos[slot].size()) > 1:
-                    start_pos[slot] = start_pos[slot].squeeze(-1)
-                if len(end_pos[slot].size()) > 1:
-                    end_pos[slot] = end_pos[slot].squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = start_logits.size(1) # This is a single index
-                start_pos[slot].clamp_(0, ignored_index)
-                end_pos[slot].clamp_(0, ignored_index)
-
-                class_loss_fct = CrossEntropyLoss(reduction='none')
-                token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
-                refer_loss_fct = CrossEntropyLoss(reduction='none')
-
-                if not self.stack_token_logits:
-                    start_loss = token_loss_fct(start_logits, start_pos[slot])
-                    end_loss = token_loss_fct(end_logits, end_pos[slot])
-                else:
-                    start_loss = token_loss_fct(torch.cat((start_logits, end_logits), 1), start_pos[slot])
-                    end_loss = token_loss_fct(torch.cat((end_logits, start_logits), 1), end_pos[slot])
-
-                token_loss = (start_loss + end_loss) / 2.0
-
-                token_is_pointable = (start_pos[slot] > 0).float()
-                if not self.token_loss_for_nonpointable:
-                    token_loss *= token_is_pointable
-
-                refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
-                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
-                if not self.refer_loss_for_nonpointable:
-                    refer_loss *= token_is_referrable
-
-                class_loss = class_loss_fct(class_logits, class_label_id[slot])
-
-                if self.refer_index > -1:
-                    per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
-                else:
-                    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
-
-                total_loss += per_example_loss.sum()
-                per_slot_per_example_loss[slot] = per_example_loss
-
-        # add hidden states and attention if they are here
-        outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]
-
-        return outputs
diff --git a/convlab/dst/trippy/multiwoz/trippy.bck.py b/convlab/dst/trippy/multiwoz/trippy.bck.py
deleted file mode 100644
index 0a41812a3718115faee90099aafdfcc25378bd50..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/trippy.bck.py
+++ /dev/null
@@ -1,444 +0,0 @@
-# Copyright 2021 Heinrich Heine University Duesseldorf
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import re
-import copy
-
-import torch
-from transformers import (BertConfig, BertTokenizer,
-                          RobertaConfig, RobertaTokenizer)
-
-from convlab2.dst.trippy.multiwoz.modeling_bert_dst import (BertForDST)
-from convlab2.dst.trippy.multiwoz.modeling_roberta_dst import (RobertaForDST)
-
-from convlab2.dst.dst import DST
-from convlab2.util.multiwoz.state import default_state
-from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
-from convlab2.nlu.jointBERT.multiwoz import BERTNLU
-
-import pdb
-
-
-MODEL_CLASSES = {
-    'bert': (BertConfig, BertForDST, BertTokenizer),
-    'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer),
-}
-
-
-TEMPLATE_STATE = {
-    "attraction": {
-        "type": "",
-        "name": "",
-        "area": ""
-    },
-    "hotel": {
-        "name": "",
-        "area": "",
-        "parking": "",
-        "price range": "",
-        "stars": "",
-        "internet": "",
-        "type": "",
-        "book stay": "",
-        "book day": "",
-        "book people": ""
-    },
-    "restaurant": {
-        "food": "",
-        "price range": "",
-        "name": "",
-        "area": "",
-        "book time": "",
-        "book day": "",
-        "book people": ""
-    },
-    "taxi": {
-        "leave at": "",
-        "destination": "",
-        "departure": "",
-        "arrive by": ""
-    },
-    "train": {
-        "leave at": "",
-        "destination": "",
-        "day": "",
-        "arrive by": "",
-        "departure": "",
-        "book people": ""
-    },
-    "hospital": {
-        "department": ""
-    }
-}
-
-
-SLOT_MAP_TRIPPY_TO_UDF = {
-    'hotel': {
-        'pricerange': 'price range',
-        'book_stay': 'book stay',
-        'book_day': 'book day',
-        'book_people': 'book people',
-        'addr': 'address',
-        'post': 'postcode'
-    },
-    'restaurant': {
-        'pricerange': 'price range',
-        'book_time': 'book time',
-        'book_day': 'book day',
-        'book_people': 'book people',
-        'addr': 'address',
-        'post': 'postcode'
-    },
-    'taxi': {
-        'arriveBy': 'arrive by',
-        'leaveAt': 'leave at',
-        'arrive': 'arrive by',
-        'leave': 'leave at',
-        'car': 'type',
-        'car type': 'type',
-        'depart': 'departure',
-        'dest': 'destination'
-    },
-    'train': {
-        'arriveBy': 'arrive by',
-        'leaveAt': 'leave at',
-        'book_people': 'book people',
-        'arrive': 'arrive by',
-        'leave': 'leave at',
-        'depart': 'departure',
-        'dest': 'destination',
-        'id': 'train id',
-        'people': 'book people',
-        'time': 'duration',
-        'ticket': 'price',
-        'trainid': 'train id'
-    },
-    'attraction': {
-        'post': 'postcode',
-        'addr': 'address',
-        'fee': 'entrance fee',
-        'price': 'price range'
-    },
-    'general': {},
-    'hospital': {
-        'post': 'postcode',
-        'addr': 'address'
-    },
-    'police': {
-        'post': 'postcode',
-        'addr': 'address'
-    }
-}
-
-
-class TRIPPY(DST):
-    def print_header(self):
-        print(" _________  ________  ___  ________  ________  ___    ___ ")
-        print("|\___   ___\\\   __  \|\  \|\   __  \|\   __  \|\  \  /  /|")
-        print("\|___ \  \_\ \  \|\  \ \  \ \  \|\  \ \  \|\  \ \  \/  / /")
-        print("     \ \  \ \ \   _  _\ \  \ \   ____\ \   ____\ \    / / ")
-        print("      \ \  \ \ \  \\\  \\\ \  \ \  \___|\ \  \___|\/  /  /  ")
-        print("       \ \__\ \ \__\\\ _\\\ \__\ \__\    \ \__\ __/  / /    ")
-        print("        \|__|  \|__|\|__|\|__|\|__|     \|__||\___/ /     ")
-        print("          (c) 2022 Heinrich Heine University \|___|/      ")
-        print()
-        
-    def __init__(self, model_type="roberta", model_name="roberta-base", model_path="", nlu_path=""):
-        super(TRIPPY, self).__init__()
-
-        self.print_header()
-
-        self.model_type = model_type.lower()
-        self.model_name = model_name.lower()
-        self.model_path = model_path
-        self.nlu_path = nlu_path
-
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type]
-        self.config = self.config_class.from_pretrained(self.model_path)
-        # TODO: update config (parameters)
-
-        self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list}
-
-        self.load_weights()
-    
-    def load_weights(self):
-        self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name) # TODO: do_lower_case=args.do_lower_case ?
-        self.model = self.model_class.from_pretrained(self.model_path, config=self.config)
-        self.model.to(self.device)
-        self.model.eval()
-        self.nlu = BERTNLU(model_file=self.nlu_path) # TODO: remove, once TripPy takes over its task
-
-    def init_session(self):
-        self.state = default_state() # Initialise as empty state
-        self.state['belief_state'] = copy.deepcopy(TEMPLATE_STATE)
-        # TODO: define internal variables here as well that are tracked but not forwarded
-
-    def update(self, user_act=''):
-        prev_state = self.state
-
-        # TODO: add asserts to check format. if wrong, print suggested config
-        #print("--")
-
-        # --- Get inform memory and auxiliary features ---
-
-        # If system_action is plain text, get acts using NLU
-        if isinstance(prev_state['system_action'], str):
-            acts, _ = self.get_acts(prev_state['system_action'])
-        elif isinstance(prev_state['system_action'], list):
-            acts = prev_state['system_action']
-        else:
-            raise Exception('Unknown format for system action:', prev_state['system_action'])
-        inform_aux, inform_mem = self.get_inform_aux(acts)
-        printed_inform_mem = False
-        for s in inform_mem:
-            if inform_mem[s] != 'none':
-                if not printed_inform_mem:
-                    print("DST: inform_mem:")
-                print(s, ':', inform_mem[s])
-                printed_inform_mem = True
-
-        # --- Tokenize dialogue context and feed DST model ---
-
-        features = self.get_features(self.state['history'], ds_aux=self.ds_aux, inform_aux=inform_aux)
-        pred_states, cls_representation = self.predict(features, inform_mem)
-
-        # --- Update ConvLab-style dialogue state ---
-
-        new_belief_state = copy.deepcopy(prev_state['belief_state'])
-        user_acts = []
-        for state, value in pred_states.items():
-            if value == 'none':
-                continue
-            domain, slot = state.split('-', 1)
-            # TODO: value normalizations?
-            if domain == 'hotel' and slot == 'type':
-                value = "hotel" if value == "yes" else "guesthouse"
-            # TODO: needed?
-            if domain not in new_belief_state:
-                if domain == 'bus':
-                    continue
-                else:
-                    raise Exception('Domain <{}> not in belief state'.format(domain))
-            slot = SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot)
-            if slot in new_belief_state[domain]:
-                new_belief_state[domain][slot] = value # TODO: value normalization?
-                user_acts.append(['inform', domain, SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot), value]) # TODO: value normalization?
-            else:
-                raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain))
-
-        # TODO: only for debugging for now!
-        if re.search("not.*book.*", prev_state['user_action']) is not None:
-            user_acts.append(['inform', 'train', 'notbook', 'none'])
-
-        # Update request_state
-        #new_request_state = copy.deepcopy(prev_state['request_state'])
-
-        new_state = copy.deepcopy(dict(prev_state))
-        new_state['belief_state'] = new_belief_state
-        #new_state['request_state'] = new_request_state
-
-        # Get requestable slots from NLU, until implemented in DST.
-        nlu_user_acts, nlu_system_acts = self.get_acts(user_act)
-        for e in nlu_user_acts:
-            nlu_a, nlu_d, nlu_s, nlu_v = e
-            nlu_a = nlu_a.lower()
-            nlu_d = nlu_d.lower()
-            nlu_s = nlu_s.lower()
-            nlu_v = nlu_v.lower()
-            if nlu_a != 'inform':
-                user_acts.append([nlu_a, nlu_d, SLOT_MAP_TRIPPY_TO_UDF[nlu_d].get(nlu_s, nlu_s), nlu_v])
-        new_state['system_action'] = nlu_system_acts # Empty when DST for user -> needed?
-        new_state['user_action'] = user_acts
-
-        #new_state['cls_representation'] = cls_representation # TODO: needed by Nunu?
-
-        self.state = new_state
-
-        # (Re)set internal states
-        if self.state['terminated']:
-            #print("!!! RESET DS_AUX")
-            self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list}
-        else:
-            self.ds_aux = self.update_ds_aux(self.state['belief_state'])
-        #print("ds:", [self.ds_aux[s][0].item() for s in self.ds_aux])
-
-        #print("--")
-        return self.state
-    
-    def predict(self, features, inform_mem):
-        with torch.no_grad():
-            outputs = self.model(input_ids=features['input_ids'],
-                                 input_mask=features['attention_mask'],
-                                 inform_slot_id=features['inform_slot_id'],
-                                 diag_state=features['diag_state'])
-
-        input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked!
-
-        #total_loss = outputs[0]
-        #per_slot_per_example_loss = outputs[1]
-        per_slot_class_logits = outputs[2]
-        per_slot_start_logits = outputs[3]
-        per_slot_end_logits = outputs[4]
-        per_slot_refer_logits = outputs[5]
-
-        cls_representation = outputs[6]
-            
-        # TODO: maybe add assert to check that batch=1
-        
-        predictions = {slot: 'none' for slot in self.config.dst_slot_list}
-
-        for slot in self.config.dst_slot_list:
-            class_logits = per_slot_class_logits[slot][0]
-            start_logits = per_slot_start_logits[slot][0]
-            end_logits = per_slot_end_logits[slot][0]
-            refer_logits = per_slot_refer_logits[slot][0]
-
-            class_prediction = int(class_logits.argmax())
-            start_prediction = int(start_logits.argmax())
-            end_prediction = int(end_logits.argmax())
-            refer_prediction = int(refer_logits.argmax())
-
-            if class_prediction == self.config.dst_class_types.index('dontcare'):
-                predictions[slot] = 'dontcare'
-            elif class_prediction == self.config.dst_class_types.index('copy_value'):
-                predictions[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1])
-                predictions[slot] = re.sub("(^| )##", "", predictions[slot])
-                if "\u0120" in predictions[slot]:
-                    predictions[slot] = re.sub(" ", "", predictions[slot])
-                    predictions[slot] = re.sub("\u0120", " ", predictions[slot])
-                    predictions[slot] = predictions[slot].strip()
-            elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'):
-                predictions[slot] = "yes" # 'true'
-            elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'):
-                predictions[slot] = "no" # 'false'
-            elif class_prediction == self.config.dst_class_types.index('inform'):
-                #print("INFORM:", slot, ",", predictions[slot], "->", inform_mem[slot])
-                predictions[slot] = inform_mem[slot]
-            # Referral case is handled below
-
-        # Referral case. All other slot values need to be seen first in order
-        # to be able to do this correctly.
-        for slot in self.config.dst_slot_list:
-            class_logits = per_slot_class_logits[slot][0]
-            refer_logits = per_slot_refer_logits[slot][0]
-
-            class_prediction = int(class_logits.argmax())
-            refer_prediction = int(refer_logits.argmax())
-
-            if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'):
-                # Only slots that have been mentioned before can be referred to.
-                # One can think of a situation where one slot is referred to in the same utterance.
-                # This phenomenon is however currently not properly covered in the training data
-                # label generation process.
-                predictions[slot] = predictions[self.config.dst_slot_list[refer_prediction - 1]]
-
-            if class_prediction > 0:
-                print("  ", slot, "->", class_prediction, ",", predictions[slot])
-
-        return predictions, cls_representation
-
-    def get_features(self, context, ds_aux=None, inform_aux=None):
-        assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models
-        input_tokens = ['<s>']
-        for e_itr, e in enumerate(reversed(context)):
-            #input_tokens.append(e[1].lower() if e[1] != 'null' else ' ') # TODO: normalise text
-            input_tokens.append(e[1] if e[1] != 'null' else ' ') # TODO: normalise text
-            if e_itr < 2:
-                input_tokens.append('</s> </s>')
-        if e_itr == 0:
-            input_tokens.append('</s> </s>')
-        input_tokens.append('</s>')
-        input_tokens = ' '.join(input_tokens)
-
-        # TODO: delex sys utt somehow, or refrain from using delex for sys utts?
-        features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=False, max_length=self.config.dst_max_seq_length)
-
-        input_ids = torch.tensor(features['input_ids']).reshape(1,-1).to(self.device)
-        attention_mask = torch.tensor(features['attention_mask']).reshape(1,-1).to(self.device)
-        features = {'input_ids': input_ids,
-                    'attention_mask': attention_mask},
-                    'inform_slot_id': inform_aux,
-                    'diag_state': ds_aux}
-
-        return features
-
-    def update_ds_aux(self, state, terminated=False):
-        ds_aux = copy.deepcopy(self.ds_aux) # TODO: deepcopy necessary? just update class variable?
-        for slot in self.config.dst_slot_list:
-            d, s = slot.split('-')
-            ds_aux[slot][0] = int(state[d][SLOT_MAP_TRIPPY_TO_UDF[d].get(s, s)] != '')
-        return ds_aux
-
-    # TODO: consider "booked" values?
-    def get_inform_aux(self, state):
-        inform_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list}
-        inform_mem = {slot: 'none' for slot in self.config.dst_slot_list}
-        for e in state:
-            #print(e)
-            #pdb.set_trace()
-            a, d, s, v = e
-            if a in ['inform', 'recommend', 'select', 'book', 'offerbook']:
-                #ds_d = d.lower()
-                #if s in REF_SYS_DA[d]:
-                #    ds_s = REF_SYS_DA[d][s]
-                #elif s in REF_SYS_DA['Booking']:
-                #    ds_s = "book_" + REF_SYS_DA['Booking'][s]
-                #else:
-                #    ds_s = s.lower()
-                #    #raise Exception('Slot <{}> of domain <{}> unknown'.format(s, d))
-                slot = "%s-%s" % (d, s)
-                if slot in inform_aux:
-                    inform_aux[slot][0] = 1
-                    inform_mem[slot] = v
-        return inform_aux, inform_mem
-
-    # TODO: fix, still a mess...
-    def get_acts(self, user_act):
-        context = self.state['history']
-        if context:
-            if context[-1][0] != 'sys':
-                system_act = ''
-                context = [t for s,t in context]
-            else:
-                system_act = context[-1][-1]
-                context = [t for s,t in context[:-1]]
-        else:
-            system_act = ''
-            context = ['']
-
-        #print("  SYS:", system_act, context)
-        system_acts = self.nlu.predict(system_act, context=context)
-
-        context.append(system_act)
-        #print("  USR:", user_act, context)
-        user_acts = self.nlu.predict(user_act, context=context)
-        
-        return user_acts, system_acts
-
-
-# if __name__ == "__main__":
-#     tracker = TRIPPY(model_type='roberta', model_path='/path/to/model',
-#                         nlu_path='/path/to/nlu')
-#     tracker.init_session()
-#     state = tracker.update('hey. I need a cheap restaurant.')
-#     tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
-#     tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
-#     state = tracker.update('If you have something Asian that would be great.')
-#     tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
-#     tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
-#     state = tracker.update('Great. Where are they located?')
-#     tracker.state['history'].append(['usr', 'Great. Where are they located?'])
-#     print(tracker.state)
diff --git a/convlab/dst/trippy/multiwoz/trippy.py b/convlab/dst/trippy/multiwoz/trippy.py
deleted file mode 100644
index 6bebf35341e4284ca6c911a7f797149b6c265069..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/trippy.py
+++ /dev/null
@@ -1,653 +0,0 @@
-# Copyright 2021 Heinrich Heine University Duesseldorf
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import re
-import json
-import copy
-
-import torch
-from transformers import (BertConfig, BertTokenizer,
-                          RobertaConfig, RobertaTokenizer)
-
-from convlab.dst.trippy.multiwoz.modeling_bert_dst import (BertForDST)
-from convlab.dst.trippy.multiwoz.modeling_roberta_dst import (RobertaForDST)
-
-from convlab.dst.dst import DST
-from convlab.util.multiwoz.state import default_state
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
-from convlab.nlu.jointBERT.multiwoz import BERTNLU
-from convlab.nlg.template.multiwoz import TemplateNLG
-from convlab.dst.rule.multiwoz.dst_util import normalize_value
-
-from convlab.util import relative_import_module_from_unified_datasets
-ONTOLOGY = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'ontology')
-TEMPLATE_STATE = ONTOLOGY['state']
-
-import pdb
-import time
-
-
-MODEL_CLASSES = {
-    'bert': (BertConfig, BertForDST, BertTokenizer),
-    'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer),
-}
-
-
-SLOT_MAP_TRIPPY_TO_UDF = {
-    'hotel': {
-        'pricerange': 'price range',
-        'book_stay': 'book stay',
-        'book_day': 'book day',
-        'book_people': 'book people',
-        'addr': 'address',
-        'post': 'postcode',
-        'price': 'price range',
-        'people': 'book people'
-    },
-    'restaurant': {
-        'pricerange': 'price range',
-        'book_time': 'book time',
-        'book_day': 'book day',
-        'book_people': 'book people',
-        'addr': 'address',
-        'post': 'postcode',
-        'price': 'price range',
-        'people': 'book people'
-    },
-    'taxi': {
-        'arriveBy': 'arrive by',
-        'leaveAt': 'leave at',
-        'arrive': 'arrive by',
-        'leave': 'leave at',
-        'car': 'type',
-        'car type': 'type',
-        'depart': 'departure',
-        'dest': 'destination'
-    },
-    'train': {
-        'arriveBy': 'arrive by',
-        'leaveAt': 'leave at',
-        'book_people': 'book people',
-        'arrive': 'arrive by',
-        'leave': 'leave at',
-        'depart': 'departure',
-        'dest': 'destination',
-        'id': 'train id',
-        'people': 'book people',
-        'time': 'duration',
-        'ticket': 'price',
-        'trainid': 'train id'
-    },
-    'attraction': {
-        'post': 'postcode',
-        'addr': 'address',
-        'fee': 'entrance fee',
-        'price': 'entrance fee'
-    },
-    'general': {},
-    'hospital': {
-        'post': 'postcode',
-        'addr': 'address'
-    },
-    'police': {
-        'post': 'postcode',
-        'addr': 'address'
-    }
-}
-
-
-class TRIPPY(DST):
-    def print_header(self):
-        print(" _________  ________  ___  ________  ________  ___    ___ ")
-        print("|\___   ___\\\   __  \|\  \|\   __  \|\   __  \|\  \  /  /|")
-        print("\|___ \  \_\ \  \|\  \ \  \ \  \|\  \ \  \|\  \ \  \/  / /")
-        print("     \ \  \ \ \   _  _\ \  \ \   ____\ \   ____\ \    / / ")
-        print("      \ \  \ \ \  \\\  \\\ \  \ \  \___|\ \  \___|\/  /  /  ")
-        print("       \ \__\ \ \__\\\ _\\\ \__\ \__\    \ \__\ __/  / /    ")
-        print("        \|__|  \|__|\|__|\|__|\|__|     \|__||\___/ /     ")
-        print("          (c) 2022 Heinrich Heine University \|___|/      ")
-        print()
-
-    def print_dialog(self, hst):
-        #print("Dialogue %s, turn %s:" % (self.global_diag_cnt, int(len(hst) / 2) - 1))
-        print("Dialogue %s, turn %s:" % (self.global_diag_cnt, self.global_turn_cnt))
-        for utt in hst[:-2]:
-            print("  \033[92m%s\033[0m" % (utt))
-        if len(hst) > 1:
-            print(" ", hst[-2])
-            print(" ", hst[-1])
-
-    def print_inform_memory(self, inform_mem):
-        print("Inform memory:")
-        is_all_none = True
-        for s in inform_mem:
-            if inform_mem[s] != 'none':
-                print("  %s = %s" % (s, inform_mem[s]))
-                is_all_none = False
-        if is_all_none:
-            print("  -")
-
-    def eval_user_acts(self, user_act, user_acts):
-        print("User acts:")
-        for ua in user_acts:
-            if ua not in user_act:
-                print("  \033[33m%s\033[0m" % (ua))
-            else:
-                print("  \033[92m%s\033[0m" % (ua))
-        for ua in user_act:
-            if ua not in user_acts:
-                print("  \033[91m%s\033[0m" % (ua))
-
-    def eval_dialog_state(self, state_updates, new_belief_state):
-        print("Dialogue state:")
-        for d in self.gt_belief_state:
-            print("  %s:" % (d))
-            for s in new_belief_state[d]:
-                is_printed = False
-                is_updated = False
-                if state_updates[d][s] > 0:
-                    is_updated = True
-                if is_updated:
-                    print("\033[3m", end='')
-                if new_belief_state[d][s] != self.gt_belief_state[d][s]:
-                    self.global_eval_stats[d][s]['FP'] += 1
-                    if self.gt_belief_state[d][s] == '':
-                        print("    \033[33m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='')
-                    else:
-                        print("    \033[91m%s: %s\033[0m (label: %s)" % (s, new_belief_state[d][s] if new_belief_state[d][s] != '' else 'none', self.gt_belief_state[d][s]), end='')
-                        self.global_eval_stats[d][s]['FN'] += 1
-                    is_printed = True
-                elif new_belief_state[d][s] != '':
-                    print("    \033[92m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='')
-                    self.global_eval_stats[d][s]['TP'] += 1
-                    is_printed = True
-                if is_updated:
-                    print(" (%s)" % (self.config.dst_class_types[state_updates[d][s]]))
-                elif is_printed:
-                    print()
-
-    def eval_print_stats(self):
-        print("Statistics:")
-        for d in self.global_eval_stats:
-            for s in self.global_eval_stats[d]:
-                TP = self.global_eval_stats[d][s]['TP']
-                FP = self.global_eval_stats[d][s]['FP']
-                FN = self.global_eval_stats[d][s]['FN']
-                prec = TP / ( TP + FP + 1e-8)
-                rec = TP / ( TP + FN + 1e-8)
-                f1 = 2 * ((prec * rec) / (prec + rec + 1e-8))
-                print("  %s %s Recall: %.2f, Precision: %.2f, F1: %.2f" % (d, s, rec, prec, f1))
-
-    def __init__(self, model_type="roberta",
-                 model_name="roberta-base",
-                 model_path="",
-                 nlu_path="",
-                 no_eval=False,
-                 no_history=False,
-                 no_normalize_value=False,
-                 gt_user_acts=False,
-                 gt_ds=False,
-                 gt_request_acts=False):
-        super(TRIPPY, self).__init__()
-
-        self.print_header()
-
-        self.model_type = model_type.lower()
-        self.model_name = model_name.lower()
-        self.model_path = model_path
-        self.nlu_path = nlu_path
-        self.no_eval = no_eval
-        self.no_history = no_history
-        self.no_normalize_value = no_normalize_value
-        self.gt_user_acts = gt_user_acts
-        self.gt_ds = gt_ds
-        self.gt_request_acts = gt_request_acts
-
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type]
-        self.config = self.config_class.from_pretrained(self.model_path, local_files_only=True) # TODO: parameterize
-        # TODO: update config (parameters)
-
-        # For debugging only
-        path = os.path.dirname(
-            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
-        path = os.path.join(path, 'data/multiwoz/value_dict.json')
-        self.value_dict = json.load(open(path))
-
-        self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list}
-
-        self.global_eval_stats = copy.deepcopy(TEMPLATE_STATE)
-        for d in self.global_eval_stats:
-            for s in self.global_eval_stats[d]:
-                self.global_eval_stats[d][s] = {'TP': 0, 'FP': 0, 'FN': 0}
-        self.global_diag_cnt = -3
-        self.global_turn_cnt = -1
-
-        self.load_weights()
-    
-    def load_weights(self):
-        self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name, local_files_only=True) # TODO: do_lower_case=args.do_lower_case ? # TODO: parameterize
-        self.model = self.model_class.from_pretrained(self.model_path, config=self.config, local_files_only=True) # TODO: parameterize
-        self.model.to(self.device)
-        self.model.eval()
-        self.nlu = BERTNLU(model_file=self.nlu_path) # This is used for internal evaluation
-        self.nlg_usr = TemplateNLG(is_user=True)
-        self.nlg_sys = TemplateNLG(is_user=False)
-        
-    def init_session(self):
-        self.state = default_state() # Initialise as empty state
-        self.state['belief_state'] = copy.deepcopy(TEMPLATE_STATE)
-        self.nlg_history = []
-        self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list}
-        self.gt_belief_state = copy.deepcopy(TEMPLATE_STATE)
-        self.global_diag_cnt += 1
-        self.global_turn_cnt = -1
-
-    def update_gt_belief_state(self, user_act):
-        for intent, domain, slot, value in user_act:
-            if domain == 'police':
-                continue
-            if intent == 'inform':
-                if slot == 'none' or slot == '':
-                    continue
-                domain_dic = self.gt_belief_state[domain]
-                if slot in domain_dic:
-                    #nvalue = normalize_value(self.value_dict, domain, slot, value)
-                    self.gt_belief_state[domain][slot] = value # nvalue
-                #elif slot != 'none' or slot != '':
-                #    raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain))
-
-    # TODO: receive semantic, convert semantic -> text -> semantic for sanity check
-    # For TripPy: receive semantic, convert semantic -> text (with context) as input to DST
-    # - allows for accuracy estimates
-    # - allows isolating inform prediction from request prediction (as can be taken from input for sanity check)
-    def update(self, user_act=''):
-        def normalize_values(text):
-            text_to_num = {"zero": "0", "one": "1", "me": "1", "two": "2", "three": "3", "four": "4", "five": "5", "six": "6", "seven": "7"}
-            #text = re.sub("^(\d{2}) : (\d{2})$", r"\1:\2", text) # Times
-            #text = re.sub(" ?' ?s", "s", text) # Genitive
-            text = re.sub("\s*(\W)\s*", r"\1" , text) # Re-attach special characters
-            text = re.sub("s'([^s])", r"s' \1", text) # Add space after plural genitive apostrophe
-            if text in text_to_num:
-                text = text_to_num[text]
-            return text
-
-        prev_state = self.state
-
-        if not self.no_eval:
-            print("-" * 40)
-
-        #nlg_history = []
-        ##for h in prev_state['history'][-2:]: # TODO: make this an option?
-        #for h in prev_state['history']:
-        #    nlg_history.append([h[0], self.get_text(h[1], is_user=(h[0]=='user'))])
-        ## Special case: at the beginning of the dialog the history might be empty (depending on policy)
-        #if len(nlg_history) == 0:
-        #    nlg_history.append(['sys', self.get_text(prev_state['system_action'], is_user=False)])
-        #    nlg_history.append(['user', self.get_text(prev_state['user_action'], is_user=True)])
-        if self.no_history:
-            self.nlg_history = []
-        self.nlg_history.append(['sys', self.get_text(prev_state['system_action'], is_user=False, normalize=True)])
-        self.nlg_history.append(['user', self.get_text(prev_state['user_action'], is_user=True, normalize=True)])
-        self.global_turn_cnt += 1
-        if not self.no_eval:
-            self.print_dialog(self.nlg_history)
-
-        # --- Get inform memory and auxiliary features ---
-
-        # If system_action is plain text, get acts using NLU
-        if isinstance(prev_state['user_action'], str):
-            u_acts, s_acts = self.get_acts()
-        elif isinstance(prev_state['user_action'], list):
-            u_acts = user_act # same as prev_state['user_action']
-            s_acts = prev_state['system_action']
-        else:
-            raise Exception('Unknown format for user action:', prev_state['user_action'])
-        inform_aux, inform_mem = self.get_inform_aux(s_acts)
-        if not self.no_eval:
-            self.print_inform_memory(inform_mem)
-
-        # --- Tokenize dialogue context and feed DST model ---
-
-        ##features = self.get_features(self.state['history'], ds_aux=self.ds_aux, inform_aux=inform_aux)
-        used_ds_aux = None if not self.config.dst_class_aux_feats_ds else self.ds_aux
-        used_inform_aux = None if not self.config.dst_class_aux_feats_inform else inform_aux
-        features = self.get_features(self.nlg_history, ds_aux=used_ds_aux, inform_aux=used_inform_aux)
-        pred_states, pred_classes, cls_representation = self.predict(features, inform_mem)
-
-        # --- Update ConvLab-style dialogue state ---
-
-        new_belief_state = copy.deepcopy(prev_state['belief_state'])
-        user_acts = []
-        for state, value in pred_states.items():
-            value = normalize_values(value)
-            if value == 'none':
-                continue
-            domain, slot = state.split('-', 1)
-            # Value normalization # TODO: according to trippy rules?
-            if domain == 'hotel' and slot == 'type':
-                value = "hotel" if value == "yes" else "guesthouse"
-            if not self.no_normalize_value:
-                value = normalize_value(self.value_dict, domain, slot, value)
-            slot = SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot)
-            if slot in new_belief_state[domain]:
-                new_belief_state[domain][slot] = value
-                user_acts.append(['inform', domain, SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot), value])
-            else:
-                raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain))
-
-        self.update_gt_belief_state(u_acts) # For evaluation
-
-        # BELIEF STATE UPDATE
-        new_state = copy.deepcopy(dict(prev_state))
-        new_state['belief_state'] = new_belief_state # TripPy
-        if self.gt_ds:
-            new_state['belief_state'] = self.gt_belief_state # Rule
-
-        state_updates = {}
-        for cl in pred_classes:
-            cl_d, cl_s = cl.split('-')
-            # Some reformatting for the evaluation further down
-            if cl_d not in state_updates:
-                state_updates[cl_d] = {}
-            state_updates[cl_d][SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s)] = pred_classes[cl]
-            # We care only about the requestable slots here
-            if self.config.dst_class_types[pred_classes[cl]] != 'request':
-                continue
-            if cl_d != 'general' and cl_s == 'none':
-                user_acts.append(['inform', cl_d, '', ''])
-            elif cl_d == 'general':
-                user_acts.append([SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), 'general', '', ''])
-                #user_acts.append(['bye', 'general', '', '']) # Map "thank" to "bye"? Mind "hello" as well!
-            elif not self.gt_request_acts:
-                user_acts.append(['request', cl_d, SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), ''])
-
-        # TODO: For debugging -> doesn't make a difference
-        #for e in user_act:
-        #    nlu_a, nlu_d, nlu_s, nlu_v = e
-        #    nlu_a = nlu_a.lower()
-        #    nlu_d = nlu_d.lower()
-        #    nlu_s = nlu_s.lower()
-        #    nlu_v = nlu_v.lower()
-        #    # Mostly requestables
-        #    if nlu_a == 'inform' and nlu_d == 'train' and nlu_s == 'notbook':
-        #        user_acts.append([nlu_a, nlu_d, 'NotBook', 'none'])
-
-        # TODO: fix # TODO: still needed?
-        if 0:
-            domain = ''
-            is_inform = False
-            is_request = False
-            is_notbook = False
-            for act in user_act:
-                _, _, slot, _ = act
-                if slot == "NotBook":
-                    is_notbook = True
-            for act in user_acts:
-                intent, domain, slot, value = act
-                if intent == 'inform':
-                    is_inform = True
-                if intent == 'request':
-                    is_request = True
-            if is_inform and not is_request and not is_notbook and domain != '' and domain != "general":
-                user_acts = [['inform', domain, '', '']] + user_acts
-
-        # USER ACTS UPDATE
-        new_state['user_action'] = user_acts # TripPy
-        # ONLY FOR DEBUGGING
-        if self.gt_user_acts:
-            new_state['user_action'] = u_acts # Rule
-        elif self.gt_request_acts:
-            for e in u_acts:
-                ea, _, _, _ = e
-                if ea == 'request':
-                    user_acts.append(e)
-
-        if not self.no_eval:
-            self.eval_user_acts(u_acts, user_acts)
-            self.eval_dialog_state(state_updates, new_belief_state)
-
-        #new_state['cls_representation'] = cls_representation # TODO: needed by Nunu?
-
-        self.state = new_state
-
-        # Print eval statistics
-        if self.state['terminated'] and not self.no_eval:
-            print("Booked:", self.state['booked'])
-            self.eval_print_stats()
-            print("=" * 10, "End of the dialogue", "=" * 10)
-            #self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list}
-        #else:
-        self.ds_aux = self.update_ds_aux(self.state['belief_state'], pred_states)
-        #print("ds:", [self.ds_aux[s][0].item() for s in self.ds_aux])
-
-        return self.state
-    
-    def predict(self, features, inform_mem):
-        #aaa_time = time.time()
-        with torch.no_grad():
-            outputs = self.model(input_ids=features['input_ids'],
-                                 input_mask=features['attention_mask'],
-                                 inform_slot_id=features['inform_slot_id'],
-                                 diag_state=features['diag_state'])
-        #bbb_time = time.time()
-
-        input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked!
-
-        #total_loss = outputs[0]
-        #per_slot_per_example_loss = outputs[1]
-        per_slot_class_logits = outputs[2]
-        per_slot_start_logits = outputs[3]
-        per_slot_end_logits = outputs[4]
-        per_slot_refer_logits = outputs[5]
-
-        cls_representation = outputs[6]
-            
-        # TODO: maybe add assert to check that batch=1
-        
-        predictions = {slot: 'none' for slot in self.config.dst_slot_list}
-        class_predictions = {slot: 0 for slot in self.config.dst_slot_list}
-
-        for slot in self.config.dst_slot_list:
-            class_logits = per_slot_class_logits[slot][0]
-            start_logits = per_slot_start_logits[slot][0]
-            end_logits = per_slot_end_logits[slot][0]
-            refer_logits = per_slot_refer_logits[slot][0]
-
-            class_prediction = int(class_logits.argmax())
-            start_prediction = int(start_logits.argmax())
-            end_prediction = int(end_logits.argmax())
-            refer_prediction = int(refer_logits.argmax())
-
-            if class_prediction == self.config.dst_class_types.index('dontcare'):
-                predictions[slot] = 'dontcare'
-            elif class_prediction == self.config.dst_class_types.index('copy_value'):
-                predictions[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1])
-                predictions[slot] = re.sub("(^| )##", "", predictions[slot])
-                if "\u0120" in predictions[slot]:
-                    predictions[slot] = re.sub(" ", "", predictions[slot])
-                    predictions[slot] = re.sub("\u0120", " ", predictions[slot])
-                    predictions[slot] = predictions[slot].strip()
-            elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'):
-                predictions[slot] = "yes" # 'true'
-            elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'):
-                predictions[slot] = "no" # 'false'
-            elif class_prediction == self.config.dst_class_types.index('inform'):
-                #print("INFORM:", slot, ",", predictions[slot], "->", inform_mem[slot])
-                predictions[slot] = inform_mem[slot]
-            # Referral case is handled below
-
-        # Referral case. All other slot values need to be seen first in order
-        # to be able to do this correctly.
-        for slot in self.config.dst_slot_list:
-            class_logits = per_slot_class_logits[slot][0]
-            refer_logits = per_slot_refer_logits[slot][0]
-
-            class_prediction = int(class_logits.argmax())
-            refer_prediction = int(refer_logits.argmax())
-
-            if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'):
-                # Only slots that have been mentioned before can be referred to.
-                # First try to resolve a reference within the same turn. (One can think of a situation
-                # where one slot is referred to in the same utterance. This phenomenon is however
-                # currently not properly covered in the training data label generation process)
-                # Then try to resolve a reference given the current dialogue state.
-                predictions[slot] = predictions[self.config.dst_slot_list[refer_prediction - 1]]
-                if predictions[slot] == 'none':
-                    referred_slot = self.config.dst_slot_list[refer_prediction - 1]
-                    referred_slot_d, referred_slot_s = referred_slot.split('-')
-                    referred_slot_s = SLOT_MAP_TRIPPY_TO_UDF[referred_slot_d].get(referred_slot_s, referred_slot_s)
-                    if self.state['belief_state'][referred_slot_d][referred_slot_s] != '':
-                        predictions[slot] = self.state['belief_state'][referred_slot_d][referred_slot_s]
-                if predictions[slot] == 'none':
-                    ref_slot = self.config.dst_slot_list[refer_prediction - 1]
-                    if ref_slot == 'hotel-name':
-                        predictions[slot] = 'the hotel'
-                    elif ref_slot == 'restaurant-name':
-                        predictions[slot] = 'the restaurant'
-                    elif ref_slot == 'attraction-name':
-                        predictions[slot] = 'the attraction'
-                    elif ref_slot == 'hotel-area':
-                        predictions[slot] = 'same area as the hotel'
-                    elif ref_slot == 'restaurant-area':
-                        predictions[slot] = 'same area as the restaurant'
-                    elif ref_slot == 'attraction-area':
-                        predictions[slot] = 'same area as the attraction'
-                    elif ref_slot == 'hotel-pricerange':
-                        predictions[slot] = 'in the same price range as the hotel'
-                    elif ref_slot == 'restaurant-pricerange':
-                        predictions[slot] = 'in the same price range as the restaurant'
-
-            class_predictions[slot] = class_prediction
-            #if class_prediction > 0:
-            #    print("  ", slot, "->", class_prediction, ",", predictions[slot])
-        #ccc_time = time.time()
-        #print("TIME:", bbb_time - aaa_time, ccc_time - bbb_time)
-
-        return predictions, class_predictions, cls_representation
-
-    def get_features(self, context, ds_aux=None, inform_aux=None):
-        assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models
-        input_tokens = ['<s>']
-        e_itr = 0
-        for e_itr, e in enumerate(reversed(context)):
-            #input_tokens.append(e[1].lower() if e[1] != 'null' else ' ') # TODO: normalise text
-            input_tokens.append(e[1] if e[1] != 'null' else ' ') # TODO: normalise text
-            if e_itr < 2:
-                input_tokens.append('</s> </s>')
-        if e_itr == 0:
-            input_tokens.append('</s> </s>')
-        input_tokens.append('</s>')
-        input_tokens = ' '.join(input_tokens)
-
-        # TODO: delex sys utt somehow, or refrain from using delex for sys utts?
-        features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=False, max_length=self.config.dst_max_seq_length)
-
-        input_ids = torch.tensor(features['input_ids']).reshape(1,-1).to(self.device)
-        attention_mask = torch.tensor(features['attention_mask']).reshape(1,-1).to(self.device)
-        features = {'input_ids': input_ids,
-                    'attention_mask': attention_mask,
-                    'inform_slot_id': inform_aux,
-                    'diag_state': ds_aux}
-
-        return features
-
-    def update_ds_aux(self, state, pred_states, terminated=False):
-        ds_aux = copy.deepcopy(self.ds_aux) # TODO: deepcopy necessary? just update class variable?
-        for slot in self.config.dst_slot_list:
-            d, s = slot.split('-')
-            if d in state and s in state[d]:
-                ds_aux[slot][0] = int(state[d][SLOT_MAP_TRIPPY_TO_UDF[d].get(s, s)] != '')
-            else:
-                # Requestable slots are not found in the DS
-                ds_aux[slot][0] = int(pred_states[slot] != 'none')
-        return ds_aux
-
-    # TODO: consider "booked" values?
-    def get_inform_aux(self, state):
-        inform_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list}
-        inform_mem = {slot: 'none' for slot in self.config.dst_slot_list}
-        for e in state:
-            #print(e)
-            #pdb.set_trace()
-            a, d, s, v = e
-            if a in ['inform', 'recommend', 'select', 'book', 'offerbook']:
-                #ds_d = d.lower()
-                #if s in REF_SYS_DA[d]:
-                #    ds_s = REF_SYS_DA[d][s]
-                #elif s in REF_SYS_DA['Booking']:
-                #    ds_s = "book_" + REF_SYS_DA['Booking'][s]
-                #else:
-                #    ds_s = s.lower()
-                #    #raise Exception('Slot <{}> of domain <{}> unknown'.format(s, d))
-                slot = "%s-%s" % (d, s)
-                if slot in inform_aux:
-                    inform_aux[slot][0] = 1
-                    inform_mem[slot] = v
-        return inform_aux, inform_mem
-
-    def get_acts(self):
-        context = self.state['history']
-        if context[-1][0] != 'user':
-            raise Exception("Wrong order of utterances, check your input.")
-        system_act = context[-2][-1]
-        user_act = context[-1][-1]
-        system_context = [t for s,t in context[:-2]]
-        user_context = [t for s,t in context[:-1]]
-
-        #print("  SYS:", system_act, system_context)
-        system_acts = self.nlu.predict(system_act, context=system_context)
-
-        #print("  USR:", user_act, user_context)
-        user_acts = self.nlu.predict(user_act, context=user_context)
-        
-        return user_acts, system_acts
-
-    def get_text(self, act, is_user=False, normalize=False):
-        if act == 'null':
-            return 'null'
-        if not isinstance(act, list):
-            result = act
-        elif is_user:
-            result = self.nlg_usr.generate(act)
-        else:
-            result = self.nlg_sys.generate(act)
-        if normalize:
-            return self.normalize_text(result)
-        else:
-            return result
-
-    def normalize_text(self, text):
-        norm_text = text.lower()
-        #norm_text = re.sub("n't", " not", norm_text) # Does not make much of a difference
-        #norm_text = re.sub("ca not", "cannot", norm_text)
-        norm_text = ' '.join([tok for tok in map(str.strip, re.split("(\W+)", norm_text)) if len(tok) > 0])
-        return norm_text
-
-
-# if __name__ == "__main__":
-#     tracker = TRIPPY(model_type='roberta', model_path='/path/to/model',
-#                         nlu_path='/path/to/nlu')
-#     tracker.init_session()
-#     state = tracker.update('hey. I need a cheap restaurant.')
-#     tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
-#     tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
-#     state = tracker.update('If you have something Asian that would be great.')
-#     tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
-#     tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
-#     state = tracker.update('Great. Where are they located?')
-#     tracker.state['history'].append(['usr', 'Great. Where are they located?'])
-#     print(tracker.state)
diff --git a/convlab/dst/trippy/multiwoz/trippy.py~ b/convlab/dst/trippy/multiwoz/trippy.py~
deleted file mode 100644
index fd1317036efafbf9c0b9ed21ba1aa7ea15955934..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/trippy.py~
+++ /dev/null
@@ -1,547 +0,0 @@
-# Copyright 2021 Heinrich Heine University Duesseldorf
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import re
-import copy
-
-import torch
-from transformers import (BertConfig, BertTokenizer,
-                          RobertaConfig, RobertaTokenizer)
-
-from convlab.dst.trippy.multiwoz.modeling_bert_dst import (BertForDST)
-from convlab.dst.trippy.multiwoz.modeling_roberta_dst import (RobertaForDST)
-
-from convlab.dst.dst import DST
-from convlab.util.multiwoz.state import default_state
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
-from convlab.nlu.jointBERT.multiwoz import BERTNLU
-from convlab.nlg.template.multiwoz import TemplateNLG
-
-from convlab.util import relative_import_module_from_unified_datasets
-ONTOLOGY = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'ontology')
-TEMPLATE_STATE = ONTOLOGY['state']
-
-# For debugging only
-from convlab.dst.rule.multiwoz.dst_util import normalize_value
-import os
-import json
-
-import pdb
-
-
-MODEL_CLASSES = {
-    'bert': (BertConfig, BertForDST, BertTokenizer),
-    'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer),
-}
-
-
-SLOT_MAP_TRIPPY_TO_UDF = {
-    'hotel': {
-        'pricerange': 'price range',
-        'book_stay': 'book stay',
-        'book_day': 'book day',
-        'book_people': 'book people',
-        'addr': 'address',
-        'post': 'postcode',
-        'price': 'price range',
-        'people': 'book people'
-    },
-    'restaurant': {
-        'pricerange': 'price range',
-        'book_time': 'book time',
-        'book_day': 'book day',
-        'book_people': 'book people',
-        'addr': 'address',
-        'post': 'postcode',
-        'price': 'price range',
-        'people': 'book people'
-    },
-    'taxi': {
-        'arriveBy': 'arrive by',
-        'leaveAt': 'leave at',
-        'arrive': 'arrive by',
-        'leave': 'leave at',
-        'car': 'type',
-        'car type': 'type',
-        'depart': 'departure',
-        'dest': 'destination'
-    },
-    'train': {
-        'arriveBy': 'arrive by',
-        'leaveAt': 'leave at',
-        'book_people': 'book people',
-        'arrive': 'arrive by',
-        'leave': 'leave at',
-        'depart': 'departure',
-        'dest': 'destination',
-        'id': 'train id',
-        'people': 'book people',
-        'time': 'duration',
-        'ticket': 'price',
-        'trainid': 'train id'
-    },
-    'attraction': {
-        'post': 'postcode',
-        'addr': 'address',
-        'fee': 'entrance fee',
-        'price': 'entrance fee'
-    },
-    'general': {},
-    'hospital': {
-        'post': 'postcode',
-        'addr': 'address'
-    },
-    'police': {
-        'post': 'postcode',
-        'addr': 'address'
-    }
-}
-
-
-class TRIPPY(DST):
-    def print_header(self):
-        print(" _________  ________  ___  ________  ________  ___    ___ ")
-        print("|\___   ___\\\   __  \|\  \|\   __  \|\   __  \|\  \  /  /|")
-        print("\|___ \  \_\ \  \|\  \ \  \ \  \|\  \ \  \|\  \ \  \/  / /")
-        print("     \ \  \ \ \   _  _\ \  \ \   ____\ \   ____\ \    / / ")
-        print("      \ \  \ \ \  \\\  \\\ \  \ \  \___|\ \  \___|\/  /  /  ")
-        print("       \ \__\ \ \__\\\ _\\\ \__\ \__\    \ \__\ __/  / /    ")
-        print("        \|__|  \|__|\|__|\|__|\|__|     \|__||\___/ /     ")
-        print("          (c) 2022 Heinrich Heine University \|___|/      ")
-        print()
-
-    def print_dialog(self, hst):
-        print("Dialogue turn %s:" % (int(len(hst) / 2) - 1))
-        for utt in hst[:-2]:
-            print("  \033[92m%s\033[0m" % (utt))
-        if len(hst) > 1:
-            print(" ", hst[-2])
-            print(" ", hst[-1])
-
-    def print_inform_memory(self, inform_mem):
-        print("Inform memory:")
-        is_all_none = True
-        for s in inform_mem:
-            if inform_mem[s] != 'none':
-                print("  %s = %s" % (s, inform_mem[s]))
-                is_all_none = False
-        if is_all_none:
-            print("  -")
-
-    def eval_user_acts(self, user_act, user_acts):
-        print("User acts:")
-        for ua in user_acts:
-            if ua not in user_act:
-                print("  \033[33m%s\033[0m" % (ua))
-            else:
-                print("  \033[92m%s\033[0m" % (ua))
-        for ua in user_act:
-            if ua not in user_acts:
-                print("  \033[91m%s\033[0m" % (ua))
-
-    def eval_dialog_state(self, state_updates, new_belief_state):
-        print("Dialogue state:")
-        for d in self.gt_belief_state:
-            print("  %s:" % (d))
-            for s in new_belief_state[d]:
-                is_printed = False
-                is_updated = False
-                if state_updates[d][s] > 0:
-                    is_updated = True
-                if is_updated:
-                    print("\033[3m", end='')
-                if new_belief_state[d][s] != self.gt_belief_state[d][s]:
-                    if self.gt_belief_state[d][s] == '':
-                        print("    \033[33m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='')
-                    else:
-                        print("    \033[91m%s: %s\033[0m (label: %s)" % (s, new_belief_state[d][s] if new_belief_state[d][s] != '' else 'none', self.gt_belief_state[d][s]), end='')
-                    is_printed = True
-                elif new_belief_state[d][s] != '':
-                    print("    \033[92m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='')
-                    is_printed = True
-                if is_updated:
-                    print(" (%s)" % (self.config.dst_class_types[state_updates[d][s]]))
-                elif is_printed:
-                    print()
-
-    def __init__(self, model_type="roberta", model_name="roberta-base", model_path="", nlu_path=""):
-        super(TRIPPY, self).__init__()
-
-        self.print_header()
-
-        self.model_type = model_type.lower()
-        self.model_name = model_name.lower()
-        self.model_path = model_path
-        self.nlu_path = nlu_path
-
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type]
-        self.config = self.config_class.from_pretrained(self.model_path)
-        # TODO: update config (parameters)
-
-        # For debugging only
-        path = os.path.dirname(
-            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
-        path = os.path.join(path, 'data/multiwoz/value_dict.json')
-        self.value_dict = json.load(open(path))
-
-        self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list}
-
-        self.load_weights()
-    
-    def load_weights(self):
-        self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name) # TODO: do_lower_case=args.do_lower_case ?
-        self.model = self.model_class.from_pretrained(self.model_path, config=self.config)
-        self.model.to(self.device)
-        self.model.eval()
-        self.nlu = BERTNLU(model_file=self.nlu_path) # TODO: remove, once TripPy takes over its task
-        self.nlg_usr = TemplateNLG(is_user=True)
-        self.nlg_sys = TemplateNLG(is_user=False)
-        
-    def init_session(self):
-        self.state = default_state() # Initialise as empty state
-        self.state['belief_state'] = copy.deepcopy(TEMPLATE_STATE)
-        self.gt_belief_state = copy.deepcopy(TEMPLATE_STATE)
-
-    def update_gt_belief_state(self, user_act):
-        print(user_act)
-        for intent, domain, slot, value in user_act:
-            if domain == 'police':
-                continue
-            if intent == 'inform':
-                if slot == 'none' or slot == '':
-                    continue
-                domain_dic = self.gt_belief_state[domain]
-                if slot in domain_dic:
-                    #nvalue = normalize_value(self.value_dict, domain, slot, value)
-                    self.gt_belief_state[domain][slot] = value # nvalue
-                #elif slot != 'none' or slot != '':
-                #    raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain))
-
-    # TODO: receive semantic, convert semantic -> text -> semantic for sanity check
-    # For TripPy: receive semantic, convert semantic -> text (with context) as input to DST
-    # - allows for accuracy estimates
-    # - allows isolating inform prediction from request prediction (as can be taken from input for sanity check)
-    def update(self, user_act=''):
-        prev_state = self.state
-
-        user_act2 = self.get_acts(user_act)
-        
-        print("-" * 40)
-
-        nlg_history = []
-        #for h in prev_state['history'][-2:]: # TODO: make this an option?
-        for h in prev_state['history']:
-            nlg_history.append([h[0], self.get_text(h[1], is_user=(h[0]=='user'))])
-        self.print_dialog(nlg_history)
-
-        # --- Get inform memory and auxiliary features ---
-
-        # If system_action is plain text, get acts using NLU
-        if isinstance(prev_state['system_action'], str):
-            acts, _ = self.get_acts(prev_state['system_action'])
-        elif isinstance(prev_state['system_action'], list):
-            acts = prev_state['system_action']
-        else:
-            raise Exception('Unknown format for system action:', prev_state['system_action'])
-        inform_aux, inform_mem = self.get_inform_aux(acts)
-        self.print_inform_memory(inform_mem)
-
-        # --- Tokenize dialogue context and feed DST model ---
-
-        ##features = self.get_features(self.state['history'], ds_aux=self.ds_aux, inform_aux=inform_aux)
-        used_ds_aux = None if not self.config.dst_class_aux_feats_ds else self.ds_aux
-        used_inform_aux = None if not self.config.dst_class_aux_feats_inform else inform_aux
-        features = self.get_features(nlg_history, ds_aux=used_ds_aux, inform_aux=used_inform_aux)
-        pred_states, class_preds, cls_representation = self.predict(features, inform_mem)
-
-        # --- Update ConvLab-style dialogue state ---
-
-        new_belief_state = copy.deepcopy(prev_state['belief_state'])
-        user_acts = []
-        for state, value in pred_states.items():
-            if value == 'none':
-                continue
-            domain, slot = state.split('-', 1)
-            # TODO: value normalizations?
-            if domain == 'hotel' and slot == 'type':
-                value = "hotel" if value == "yes" else "guesthouse"
-            orig_slot = slot
-            slot = SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot)
-            if slot in new_belief_state[domain]:
-                new_belief_state[domain][slot] = value # TODO: value normalization?
-                user_acts.append(['inform', domain, SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot), value]) # TODO: value normalization?
-            else:
-                raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain))
-
-        self.update_gt_belief_state(user_act2) # For evaluation
-
-        # BELIEF STATE UPDATE
-        new_state = copy.deepcopy(dict(prev_state))
-        new_state['belief_state'] = new_belief_state # TripPy
-        #new_state['belief_state'] = self.gt_belief_state # Rule
-
-        state_updates = {}
-        for cl in class_preds:
-            cl_d, cl_s = cl.split('-')
-            # Some reformatting for the evaluation further down
-            if cl_d not in state_updates:
-                state_updates[cl_d] = {}
-            state_updates[cl_d][SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s)] = class_preds[cl]
-            # We care only about the requestable slots here
-            if self.config.dst_class_types[class_preds[cl]] != 'request':
-                continue
-            if cl_d != 'general' and cl_s == 'none':
-                user_acts.append(['inform', cl_d, '', ''])
-            elif cl_d == 'general':
-                user_acts.append([SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), 'general', '', ''])
-                #user_acts.append(['bye', 'general', '', '']) # Map "thank" to "bye"? Mind "hello" as well!
-            else:
-                user_acts.append(['request', cl_d, SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), ''])
-
-        # TODO: For debugging -> doesn't make a difference
-        #for e in user_act:
-        #    nlu_a, nlu_d, nlu_s, nlu_v = e
-        #    nlu_a = nlu_a.lower()
-        #    nlu_d = nlu_d.lower()
-        #    nlu_s = nlu_s.lower()
-        #    nlu_v = nlu_v.lower()
-        #    # Mostly requestables
-        #    if nlu_a == 'inform' and nlu_d == 'train' and nlu_s == 'notbook':
-        #        user_acts.append([nlu_a, nlu_d, 'NotBook', 'none'])
-
-        # TODO: fix # TODO: still needed?
-        if 0:
-            domain = ''
-            is_inform = False
-            is_request = False
-            is_notbook = False
-            for act in user_act:
-                _, _, slot, _ = act
-                if slot == "NotBook":
-                    is_notbook = True
-            for act in user_acts:
-                intent, domain, slot, value = act
-                if intent == 'inform':
-                    is_inform = True
-                if intent == 'request':
-                    is_request = True
-            if is_inform and not is_request and not is_notbook and domain != '' and domain != "general":
-                user_acts = [['inform', domain, '', '']] + user_acts
-
-        # USER ACTS UPDATE
-        new_state['user_action'] = user_acts # TripPy
-        #new_state['user_action'] = user_act # Rule
-
-        self.eval_user_acts(user_act2, user_acts)
-        self.eval_dialog_state(state_updates, new_belief_state)
-
-        #new_state['cls_representation'] = cls_representation # TODO: needed by Nunu?
-
-        self.state = new_state
-
-        # (Re)set internal states
-        if self.state['terminated']:
-            print("=" * 15, "End of dialog", "=" * 15)
-            self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list}
-        else:
-            self.ds_aux = self.update_ds_aux(self.state['belief_state'], pred_states)
-        #print("ds:", [self.ds_aux[s][0].item() for s in self.ds_aux])
-
-        return self.state
-    
-    def predict(self, features, inform_mem):
-        with torch.no_grad():
-            outputs = self.model(input_ids=features['input_ids'],
-                                 input_mask=features['attention_mask'],
-                                 inform_slot_id=features['inform_slot_id'],
-                                 diag_state=features['diag_state'])
-
-        input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked!
-
-        #total_loss = outputs[0]
-        #per_slot_per_example_loss = outputs[1]
-        per_slot_class_logits = outputs[2]
-        per_slot_start_logits = outputs[3]
-        per_slot_end_logits = outputs[4]
-        per_slot_refer_logits = outputs[5]
-
-        cls_representation = outputs[6]
-            
-        # TODO: maybe add assert to check that batch=1
-        
-        predictions = {slot: 'none' for slot in self.config.dst_slot_list}
-        class_predictions = {slot: 0 for slot in self.config.dst_slot_list}
-
-        for slot in self.config.dst_slot_list:
-            class_logits = per_slot_class_logits[slot][0]
-            start_logits = per_slot_start_logits[slot][0]
-            end_logits = per_slot_end_logits[slot][0]
-            refer_logits = per_slot_refer_logits[slot][0]
-
-            class_prediction = int(class_logits.argmax())
-            start_prediction = int(start_logits.argmax())
-            end_prediction = int(end_logits.argmax())
-            refer_prediction = int(refer_logits.argmax())
-
-            if class_prediction == self.config.dst_class_types.index('dontcare'):
-                predictions[slot] = 'dontcare'
-            elif class_prediction == self.config.dst_class_types.index('copy_value'):
-                predictions[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1])
-                predictions[slot] = re.sub("(^| )##", "", predictions[slot])
-                if "\u0120" in predictions[slot]:
-                    predictions[slot] = re.sub(" ", "", predictions[slot])
-                    predictions[slot] = re.sub("\u0120", " ", predictions[slot])
-                    predictions[slot] = predictions[slot].strip()
-            elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'):
-                predictions[slot] = "yes" # 'true'
-            elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'):
-                predictions[slot] = "no" # 'false'
-            elif class_prediction == self.config.dst_class_types.index('inform'):
-                #print("INFORM:", slot, ",", predictions[slot], "->", inform_mem[slot])
-                predictions[slot] = inform_mem[slot]
-            # Referral case is handled below
-
-        # Referral case. All other slot values need to be seen first in order
-        # to be able to do this correctly.
-        # TODO: right now, resolution is only attempted within single turn. Consider previous state instead!
-        for slot in self.config.dst_slot_list:
-            class_logits = per_slot_class_logits[slot][0]
-            refer_logits = per_slot_refer_logits[slot][0]
-
-            class_prediction = int(class_logits.argmax())
-            refer_prediction = int(refer_logits.argmax())
-
-            if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'):
-                # Only slots that have been mentioned before can be referred to.
-                # One can think of a situation where one slot is referred to in the same utterance.
-                # This phenomenon is however currently not properly covered in the training data
-                # label generation process.
-                predictions[slot] = predictions[self.config.dst_slot_list[refer_prediction - 1]]
-
-            class_predictions[slot] = class_prediction
-            #if class_prediction > 0:
-            #    print("  ", slot, "->", class_prediction, ",", predictions[slot])
-
-        return predictions, class_predictions, cls_representation
-
-    def get_features(self, context, ds_aux=None, inform_aux=None):
-        assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models
-        input_tokens = ['<s>']
-        e_itr = 0
-        for e_itr, e in enumerate(reversed(context)):
-            #input_tokens.append(e[1].lower() if e[1] != 'null' else ' ') # TODO: normalise text
-            input_tokens.append(e[1] if e[1] != 'null' else ' ') # TODO: normalise text
-            if e_itr < 2:
-                input_tokens.append('</s> </s>')
-        if e_itr == 0:
-            input_tokens.append('</s> </s>')
-        input_tokens.append('</s>')
-        input_tokens = ' '.join(input_tokens)
-
-        # TODO: delex sys utt somehow, or refrain from using delex for sys utts?
-        features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=False, max_length=self.config.dst_max_seq_length)
-
-        input_ids = torch.tensor(features['input_ids']).reshape(1,-1).to(self.device)
-        attention_mask = torch.tensor(features['attention_mask']).reshape(1,-1).to(self.device)
-        features = {'input_ids': input_ids,
-                    'attention_mask': attention_mask,
-                    'inform_slot_id': inform_aux,
-                    'diag_state': ds_aux}
-
-        return features
-
-    def update_ds_aux(self, state, pred_states, terminated=False):
-        ds_aux = copy.deepcopy(self.ds_aux) # TODO: deepcopy necessary? just update class variable?
-        for slot in self.config.dst_slot_list:
-            d, s = slot.split('-')
-            if d in state and s in state[d]:
-                ds_aux[slot][0] = int(state[d][SLOT_MAP_TRIPPY_TO_UDF[d].get(s, s)] != '')
-            else:
-                # Requestable slots are not found in the DS
-                ds_aux[slot][0] = int(pred_states[slot] != 'none')
-        return ds_aux
-
-    # TODO: consider "booked" values?
-    def get_inform_aux(self, state):
-        inform_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list}
-        inform_mem = {slot: 'none' for slot in self.config.dst_slot_list}
-        for e in state:
-            #print(e)
-            #pdb.set_trace()
-            a, d, s, v = e
-            if a in ['inform', 'recommend', 'select', 'book', 'offerbook']:
-                #ds_d = d.lower()
-                #if s in REF_SYS_DA[d]:
-                #    ds_s = REF_SYS_DA[d][s]
-                #elif s in REF_SYS_DA['Booking']:
-                #    ds_s = "book_" + REF_SYS_DA['Booking'][s]
-                #else:
-                #    ds_s = s.lower()
-                #    #raise Exception('Slot <{}> of domain <{}> unknown'.format(s, d))
-                slot = "%s-%s" % (d, s)
-                if slot in inform_aux:
-                    inform_aux[slot][0] = 1
-                    inform_mem[slot] = v
-        return inform_aux, inform_mem
-
-    # TODO: fix, still a mess...
-    def get_acts(self, user_act):
-        context = self.state['history']
-        if context:
-            if context[-1][0] != 'sys':
-                system_act = ''
-                context = [t for s,t in context]
-            else:
-                system_act = context[-1][-1]
-                context = [t for s,t in context[:-1]]
-        else:
-            system_act = ''
-            context = ['']
-
-        #print("  SYS:", system_act, context)
-        system_acts = [] # self.nlu.predict(system_act, context=context)
-        context = ['']
-
-        context.append(system_act)
-        #print("  USR:", user_act, context)
-        user_acts = self.nlu.predict(user_act, context=context)
-        
-        return user_acts, system_acts
-
-    def get_text(self, user_act, is_user=False):
-        if user_act == 'null':
-            return 'null'
-        if not isinstance(user_act, list):
-            return user_act
-        if is_user:
-            return self.nlg_usr.generate(user_act)
-        else:
-            return self.nlg_sys.generate(user_act)
-    
-    
-# if __name__ == "__main__":
-#     tracker = TRIPPY(model_type='roberta', model_path='/path/to/model',
-#                         nlu_path='/path/to/nlu')
-#     tracker.init_session()
-#     state = tracker.update('hey. I need a cheap restaurant.')
-#     tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
-#     tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
-#     state = tracker.update('If you have something Asian that would be great.')
-#     tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
-#     tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
-#     state = tracker.update('Great. Where are they located?')
-#     tracker.state['history'].append(['usr', 'Great. Where are they located?'])
-#     print(tracker.state)
diff --git a/convlab/dst/trippy/multiwoz/trippypublic b/convlab/dst/trippy/multiwoz/trippypublic
deleted file mode 160000
index a4e3f675b1f075a697e774833aec72f8148000b0..0000000000000000000000000000000000000000
--- a/convlab/dst/trippy/multiwoz/trippypublic
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit a4e3f675b1f075a697e774833aec72f8148000b0
diff --git a/convlab/dst/trippyr/__init__.py b/convlab/dst/trippyr/__init__.py
deleted file mode 100755
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/convlab/dst/trippyr/multiwoz/__init__.py b/convlab/dst/trippyr/multiwoz/__init__.py
deleted file mode 100644
index b2b9444a46ae879cbd3670b067e3712a4e34c43b..0000000000000000000000000000000000000000
--- a/convlab/dst/trippyr/multiwoz/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from convlab.dst.trippyr.multiwoz.trippyr import TRIPPYR
diff --git a/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py b/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py
deleted file mode 100644
index 36e1a1970b0a46ffd26cf5d2a4752f3178d0c748..0000000000000000000000000000000000000000
--- a/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py
+++ /dev/null
@@ -1,825 +0,0 @@
-# coding=utf-8
-#
-# Copyright 2020 Heinrich Heine University Duesseldorf
-#
-# Part of this code is based on the source code of BERT-DST
-# (arXiv:1907.03040)
-# Part of this code is based on the source code of Transformers
-# (arXiv:1910.03771)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss
-from torch.nn import TripletMarginLoss
-from torch.nn import PairwiseDistance
-from torch.nn import MultiheadAttention
-import torch.nn.functional as F
-
-from transformers import (RobertaModel, RobertaConfig, RobertaPreTrainedModel)
-
-import time
-
-
-class RobertaForDST(RobertaPreTrainedModel):
-    def __init__(self, config):
-        super(RobertaForDST, self).__init__(config)
-        self.slot_list = config.dst_slot_list
-        self.noncategorical = config.dst_noncategorical
-        self.categorical = config.dst_categorical
-        self.class_types = config.dst_class_types
-        self.class_labels = config.dst_class_labels
-        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
-        self.value_loss_for_nonpointable = config.dst_value_loss_for_nonpointable
-        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
-        self.stack_token_logits = False # config.dst_stack_token_logits # TODO
-        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
-        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
-        self.class_loss_ratio = config.dst_class_loss_ratio
-        self.slot_attention_heads = config.dst_slot_attention_heads
-        self.dropout_rate = config.dst_dropout_rate
-        self.heads_dropout_rate = config.dst_heads_dropout_rate
-        
-        self.debug_fix_slot_embs = config.debug_fix_slot_embs
-        self.debug_joint_slot_gate = config.debug_joint_slot_gate
-        self.debug_joint_refer_gate = config.debug_joint_refer_gate
-        self.debug_simple_joint_slot_gate = config.debug_simple_joint_slot_gate
-        self.debug_separate_seq_tagging = config.debug_separate_seq_tagging
-        self.debug_sigmoid_sequence_tagging = config.debug_sigmoid_sequence_tagging
-        self.debug_att_output = config.debug_att_output
-        self.debug_tanh_for_att_output = config.debug_tanh_for_att_output
-        self.debug_stack_att = config.debug_stack_att
-        self.debug_stack_rep = config.debug_stack_rep
-        self.debug_sigmoid_slot_gates = config.debug_sigmoid_slot_gates
-        self.debug_att_slot_gates = config.debug_att_slot_gates
-        self.debug_use_triplet_loss = config.debug_use_triplet_loss
-        self.debug_use_tlf = config.debug_use_tlf
-        self.debug_value_att_none_class = config.debug_value_att_none_class
-        self.debug_tag_none_target = config.debug_tag_none_target
-        self.triplet_loss_weight = config.triplet_loss_weight
-        self.none_weight = config.none_weight
-        self.pretrain_loss_function = config.pretrain_loss_function
-        self.token_loss_function = config.token_loss_function
-        self.value_loss_function = config.value_loss_function
-        self.class_loss_function = config.class_loss_function
-        self.sequence_tagging_dropout = -1
-        if config.sequence_tagging_dropout > 0.0:
-            self.sequence_tagging_dropout = int(1 / config.sequence_tagging_dropout)
-        self.ignore_o_tags = config.ignore_o_tags
-        self.debug_slot_embs_per_step_nograd = config.debug_slot_embs_per_step_nograd
-        try:
-            self.debug_use_cross_attention = config.debug_use_cross_attention
-        except:
-            self.debug_use_cross_attention = False
-
-        self.val_rep_mode = config.val_rep_mode # TODO: for debugging at the moment.
- 
-        config.output_hidden_states = True # TODO
-
-        #config.dst_dropout_rate = 0.0 # TODO: for debugging
-        #config.dst_heads_dropout_rate = 0.0 # TODO: for debugging
-        
-        self.roberta = RobertaModel(config)
-        self.dropout = nn.Dropout(config.dst_dropout_rate)
-        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
-
-        # -- Dialogue state tracking functionality --
-
-        # Only use refer loss if refer class is present in dataset.
-        if 'refer' in self.class_types:
-            self.refer_index = self.class_types.index('refer')
-        else:
-            self.refer_index = -1
-
-        if self.class_aux_feats_inform:
-            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-        if self.class_aux_feats_ds:
-            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-
-        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
-
-        # Slot specific gates
-        for slot in self.slot_list:
-            if not self.debug_joint_slot_gate:
-                if self.debug_stack_att:
-                    if self.debug_sigmoid_slot_gates:
-                        for cl in range(self.class_labels):
-                            self.add_module("class_" + slot + "_" + str(cl), nn.Linear(config.hidden_size * 2 + aux_dims, 1))
-                    elif self.debug_att_slot_gates:
-                        self.add_module("class_" + slot, MultiheadAttention(config.hidden_size * 2 + aux_dims, self.slot_attention_heads))
-                    else:
-                        self.add_module("class_" + slot, nn.Linear(config.hidden_size * 2 + aux_dims, self.class_labels))
-                else:
-                    if self.debug_sigmoid_slot_gates:
-                        for cl in range(self.class_labels):
-                            self.add_module("class_" + slot + "_" + str(cl), nn.Linear(config.hidden_size + aux_dims, 1))
-                    elif self.debug_att_slot_gates:
-                        self.add_module("class_" + slot, MultiheadAttention(config.hidden_size + aux_dims, self.slot_attention_heads))
-                    else:
-                        self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
-            #self.add_module("token_" + slot, nn.Linear(config.hidden_size, 1))
-            if not self.debug_joint_refer_gate:
-                self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
-            if self.debug_separate_seq_tagging:
-                if self.debug_sigmoid_sequence_tagging:
-                    self.add_module("token_" + slot, nn.Linear(config.hidden_size, 1))
-                else:
-                    self.add_module("token_" + slot, MultiheadAttention(config.hidden_size, self.slot_attention_heads))
-
-        if self.debug_att_output:
-            self.class_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads)
-        # Conditioned sequence tagging
-        if not self.debug_separate_seq_tagging:
-            if self.debug_sigmoid_sequence_tagging:
-                self.h1t = nn.Linear(config.hidden_size, config.hidden_size)
-                self.h2t = nn.Linear(config.hidden_size * 2, config.hidden_size * 2)
-                self.llt = nn.Linear(config.hidden_size * 2, 1)
-            else:
-                self.token_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads)
-        if self.debug_joint_refer_gate:
-            self.refer_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads)
-        self.value_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads)
-        #self.slot_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads)
-
-        self.token_layer_norm = nn.LayerNorm(config.hidden_size)
-        self.token_layer_norm2 = nn.LayerNorm(config.hidden_size)
-        self.class_layer_norm = nn.LayerNorm(config.hidden_size)
-        self.class_layer_norm2 = nn.LayerNorm(config.hidden_size)
-
-        self.tlinear = nn.Linear(config.hidden_size, config.hidden_size)
-        self.clinear = nn.Linear(config.hidden_size, config.hidden_size)
-
-        # Conditioned slot gate
-        self.h1c = nn.Linear(config.hidden_size + aux_dims, config.hidden_size)
-        if self.debug_stack_att:
-            self.h0c = nn.Linear(config.hidden_size, config.hidden_size)
-            self.h2c = nn.Linear(config.hidden_size * 3, config.hidden_size * 3)
-            if self.debug_sigmoid_slot_gates:
-                for cl in range(self.class_labels):
-                    self.add_module("llc_" + str(cl), nn.Linear(config.hidden_size * 3, 1))
-            elif self.debug_att_slot_gates:
-                if self.debug_att_output and self.debug_simple_joint_slot_gate:
-                    self.llc = MultiheadAttention(config.hidden_size * 2, self.slot_attention_heads)
-                else:
-                    self.llc = MultiheadAttention(config.hidden_size * 3, self.slot_attention_heads)
-            else:
-                if self.debug_att_output and self.debug_simple_joint_slot_gate:
-                    self.llc = nn.Linear(config.hidden_size * 2, self.class_labels)
-                else:
-                    self.llc = nn.Linear(config.hidden_size * 3, self.class_labels)
-        else:
-            if self.debug_att_slot_gates:
-                self.h2c = nn.Linear(config.hidden_size * 2, config.hidden_size * 1)
-            else:
-                self.h2c = nn.Linear(config.hidden_size * 2, config.hidden_size * 2)
-            if self.debug_sigmoid_slot_gates:
-                for cl in range(self.class_labels):
-                    self.add_module("llc_" + str(cl), nn.Linear(config.hidden_size * 2, 1))
-            elif self.debug_att_slot_gates:
-                if self.debug_att_output and self.debug_simple_joint_slot_gate:
-                    self.llc = MultiheadAttention(config.hidden_size * 1, self.slot_attention_heads)
-                else:
-                    self.llc = MultiheadAttention(config.hidden_size * 2, self.slot_attention_heads)
-            else:
-                if self.debug_att_output and self.debug_simple_joint_slot_gate:
-                    self.llc = nn.Linear(config.hidden_size * 1, self.class_labels)
-                else:
-                    self.llc = nn.Linear(config.hidden_size * 2, self.class_labels)
-
-        # Conditioned refer gate
-        self.h2r = nn.Linear(config.hidden_size * 2, config.hidden_size * 1)
-
-        # -- Spanless sequence tagging functionality --
-
-        self.dis = PairwiseDistance(p=2)
-        self.sigmoid = nn.Sigmoid()
-        self.tanh = nn.Tanh()
-        self.relu = nn.ReLU()
-        self.gelu = nn.GELU()
-        self.binary_cross_entropy = F.binary_cross_entropy
-        self.triplet_loss = TripletMarginLoss(margin=1.0, reduction="none")
-        self.mse = nn.MSELoss(reduction="none")
-        self.refer_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=len(self.slot_list)) # Ignore 'none' target
-        self.value_loss_fct = CrossEntropyLoss(reduction='none')
-
-        if self.none_weight != 1.0:
-            none_weight = self.none_weight
-            weight_mass = none_weight + (self.class_labels - 1)
-            none_weight /= weight_mass
-            other_weights = 1 / weight_mass
-            #self.clweights = torch.tensor([none_weight] + [other_weights] * (self.class_labels - 1)) # .to(outputs[0].device)
-            self.clweights = torch.tensor([other_weights] * self.class_labels) # .to(outputs[0].device)
-            self.clweights[self.class_types.index('none')] = none_weight
-            if self.debug_sigmoid_slot_gates:
-                self.class_loss_fct = F.binary_cross_entropy # (reduction="none") # weights for classes are controlled when adding losses
-            else:
-                self.class_loss_fct = CrossEntropyLoss(weight=self.clweights, reduction='none')
-        else:
-            if self.debug_sigmoid_slot_gates:
-                self.class_loss_fct = F.binary_cross_entropy # (reduction="none")
-            else:
-                self.class_loss_fct = CrossEntropyLoss(reduction='none')
-
-        self.init_weights()
-
-    def compute_triplet_loss_att(self, att_output, pos_sampling_input, neg_sampling_input, slot):
-        loss = self.triplet_loss(att_output, pos_sampling_input[slot].squeeze(), neg_sampling_input[slot].squeeze())
-        sample_mask = neg_sampling_input[slot].squeeze(1).sum(1) != 0
-        loss *= sample_mask
-        return loss
-
-    def forward(self, batch, step=None, epoch=None, t_slot=None, mode=None):
-        assert(mode in [None, "pretrain", "tag", "encode", "encode_vals", "represent"]) # TODO
-
-        input_ids = batch['input_ids']
-        input_mask = batch['input_mask']
-        #segment_ids = batch['segment_ids']
-        usr_mask = batch['usr_mask']
-        start_pos = batch['start_pos']
-        end_pos = batch['end_pos']
-        refer_id = batch['refer_id']
-        class_label_id = batch['class_label_id']
-        inform_slot_id = batch['inform_slot_id']
-        diag_state = batch['diag_state']
-        slot_ids = batch['slot_ids'] if 'slot_ids' in batch else None # TODO: fix?
-        slot_mask = batch['slot_mask'] if 'slot_mask' in batch else None # TODO: fix?
-        cl_ids = batch['cl_ids'] if 'cl_ids' in batch else None # TODO: fix?
-        cl_mask = batch['cl_mask'] if 'cl_mask' in batch else None # TODO: fix?
-        pos_sampling_input = batch['pos_sampling_input']
-        neg_sampling_input = batch['neg_sampling_input']
-        value_labels = batch['value_labels'] if 'value_labels' in batch else None # TODO: fix?
-
-        batch_input_mask = input_mask
-        if slot_ids is not None and slot_mask is not None:
-            if self.debug_slot_embs_per_step_nograd:
-                with torch.no_grad():
-                    outputs_slot = self.roberta(
-                        slot_ids,
-                        attention_mask=slot_mask,
-                        token_type_ids=None, # segment_ids,
-                        position_ids=None,
-                        head_mask=None
-                    )
-            else:
-                input_ids = torch.cat((input_ids, slot_ids))
-                input_mask = torch.cat((input_mask, slot_mask))
-        if cl_ids is not None and cl_mask is not None:
-            input_ids = torch.cat((input_ids, cl_ids))
-            input_mask = torch.cat((input_mask, cl_mask))
-
-        outputs = self.roberta(
-            input_ids,
-            attention_mask=input_mask,
-            token_type_ids=None, # segment_ids,
-            position_ids=None,
-            head_mask=None
-        )
-
-        sequence_output = outputs[0]
-        pooled_output = outputs[1]
-
-        if cl_ids is not None and cl_mask is not None:
-            sequence_output = sequence_output[:-1 * len(cl_ids), :, :]
-            encoded_classes = pooled_output[-1 * len(cl_ids):, :]
-            pooled_output = pooled_output[:-1 * len(cl_ids), :]
-        if slot_ids is not None and slot_mask is not None:
-            if self.debug_slot_embs_per_step_nograd:
-                encoded_slots_seq = outputs_slot[0]
-                encoded_slots_pooled = outputs_slot[1]
-            else:
-                encoded_slots_seq = sequence_output[-1 * len(slot_ids):, :, :]
-                sequence_output = sequence_output[:-1 * len(slot_ids), :, :]
-                encoded_slots_pooled = pooled_output[-1 * len(slot_ids):, :]
-                pooled_output = pooled_output[:-1 * len(slot_ids), :]
-
-        sequence_output = self.dropout(sequence_output)
-        pooled_output = self.dropout(pooled_output)
-
-        #sequence_output11 = outputs[2][-2]
-
-        inverted_input_mask = ~(batch_input_mask.bool())
-        if usr_mask is None:
-            usr_mask = input_mask
-        inverted_usr_mask = ~(usr_mask.bool())
-
-        if mode == "encode": # Create vector representations only
-            return pooled_output, sequence_output, None
-
-        if mode == "encode_vals": # Create vector representations only
-            no_seq_w = 1 / batch_input_mask.sum(1)
-            uniform_weights = batch_input_mask * no_seq_w.unsqueeze(1)
-            pooled_output_vals = torch.matmul(uniform_weights, sequence_output).squeeze(1)
-            pooled_output_vals = self.token_layer_norm(pooled_output_vals)
-            return pooled_output_vals, None, None
- 
-        if mode == "pretrain":
-            pos_vectors = {}
-            pos_weights = {}
-            pos_vectors, pos_weights = self.token_att(
-                query=encoded_slots_pooled.squeeze(1).unsqueeze(0),
-                key=sequence_output.transpose(0, 1),
-                value=sequence_output.transpose(0, 1),
-                key_padding_mask=inverted_input_mask,
-                need_weights=True)
-            pos_vectors = pos_vectors.squeeze(0)
-            pos_vectors = self.token_layer_norm(pos_vectors)
-            pos_weights = pos_weights.squeeze(1)
-
-            #neg_vectors = {}
-            neg_weights = {}
-            #neg_vectors[slot], neg_weights[slot] = self.token_att(
-            #    query=neg_sampling_input[slot].squeeze(1).unsqueeze(0),
-            #    key=sequence_output.transpose(0, 1),
-            #    value=sequence_output.transpose(0, 1),
-            #    key_padding_mask=inverted_input_mask,
-            #    need_weights=True)
-            #neg_vectors[slot] = neg_vectors[slot].squeeze(0)
-            #neg_vectors[slot] = self.token_layer_norm(neg_vectors[slot])
-            #neg_weights[slot] = neg_weights[slot].squeeze(1)
-
-            pos_labels_clipped = torch.clamp(start_pos.float(), min=0, max=1)
-            pos_labels_clipped_scaled = pos_labels_clipped / torch.clamp(pos_labels_clipped.sum(1).unsqueeze(1), min=1) # scaled
-            #no_seq_w = 1 / batch_input_mask.sum(1)
-            #neg_labels_clipped = batch_input_mask * no_seq_w.unsqueeze(1)
-            if self.pretrain_loss_function == "mse":
-                pos_token_loss = self.mse(pos_weights, pos_labels_clipped_scaled) # TODO: MSE might be better for scaled targets
-            else:
-                pos_token_loss = self.binary_cross_entropy(pos_weights, pos_labels_clipped_scaled, reduction="none")
-            if self.ignore_o_tags:
-                pos_token_loss *= pos_labels_clipped # TODO: good idea?
-            #mm = torch.clamp(start_pos[slot].float(), min=0, max=1)
-            #mm = torch.clamp(mm * 10, min=1)
-            #pos_token_loss *= mm
-            pos_token_loss = pos_token_loss.sum(1)
-            #neg_token_loss = self.mse(neg_weights[slot], neg_labels_clipped)
-            #neg_token_loss = neg_token_loss.sum(1)
-
-            #triplet_loss = self.compute_triplet_loss_att(pos_vectors[slot], pos_sampling_input, neg_sampling_input, slot)
-
-            per_example_loss = pos_token_loss # + triplet_loss # + neg_token_loss
-            total_loss = per_example_loss.sum()
-            
-            return (total_loss, pos_weights, neg_weights,)
-
-        if mode == "tag":
-            query = torch.stack(list(batch['value_reps'].values())).transpose(1, 2).reshape(-1, pooled_output.size()[0], pooled_output.size()[1])
-            _, weights = self.token_att(query=query,
-                                        key=sequence_output.transpose(0, 1),
-                                        value=sequence_output.transpose(0, 1),
-                                        key_padding_mask=inverted_input_mask + inverted_usr_mask,
-                                        need_weights=True)
-            return (weights,)
-
-        aaa_time = time.time()
-        # Attention for sequence tagging
-        vectors = {}
-        weights = {}
-        for s_itr, slot in enumerate(self.slot_list):
-            if slot_ids is not None and slot_mask is not None:
-                encoded_slot_seq = encoded_slots_seq[s_itr]
-                encoded_slot_pooled = encoded_slots_pooled[s_itr]
-            else:
-                encoded_slot_seq = batch['encoded_slots_seq'][slot]
-                encoded_slot_pooled = batch['encoded_slots_pooled'][slot]
-            if self.debug_separate_seq_tagging:
-                if self.debug_sigmoid_sequence_tagging:
-                    weights[slot] = self.sigmoid(getattr(self, "token_" + slot)(sequence_output)).squeeze(2)
-                    vectors[slot] = torch.matmul(weights[slot], sequence_output)
-                    vectors[slot] = torch.diagonal(vectors[slot]).transpose(0, 1)
-                    vectors[slot] = self.token_layer_norm(vectors[slot])
-                else:
-                    if self.debug_use_cross_attention:
-                        query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1)
-                    else:
-                        query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0)
-                    vectors[slot], weights[slot] = getattr(self, "token_" + slot)(
-                        query=query,
-                        key=sequence_output.transpose(0, 1),
-                        value=sequence_output.transpose(0, 1),
-                        key_padding_mask=inverted_input_mask + inverted_usr_mask,
-                        need_weights=True) # TODO: use usr_mask better or worse?
-                    vectors[slot] = vectors[slot].squeeze(0)
-                    #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize?
-                    if self.debug_tanh_for_att_output:
-                        #vectors[slot] = self.tanh(vectors[slot])
-                        vectors[slot] = self.token_layer_norm2(vectors[slot])
-                        vectors[slot] = self.tanh(self.tlinear(vectors[slot]))
-                    #vectors[slot] += pooled_output # Residual
-                    vectors[slot] = self.token_layer_norm(vectors[slot])
-                    #vectors[slot] = self.dropout_heads(vectors[slot])
-                    weights[slot] = weights[slot].squeeze(1)
-                    if self.debug_stack_rep:
-                        vectors[slot] = torch.cat((pooled_output, vectors[slot]), 1)
-            else:
-                if self.debug_sigmoid_sequence_tagging:
-                    # Conditioned sequence tagging
-                    sequence_output_add = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(1).expand(sequence_output.size())
-                    xxt = self.gelu(self.h1t(sequence_output))
-                    yyt = self.gelu(self.h2t(torch.cat((sequence_output_add, xxt), 2)))
-                    weights[slot] = self.sigmoid(self.llt(yyt)).squeeze(2)
-                    vectors[slot] = torch.matmul(weights[slot], sequence_output)
-                    vectors[slot] = torch.diagonal(vectors[slot]).transpose(0, 1)
-                    vectors[slot] = self.token_layer_norm(vectors[slot])
-                else:
-                    if self.debug_use_cross_attention:
-                        query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1)
-                    else:
-                        query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0)
-                    vectors[slot], weights[slot] = self.token_att(
-                        query=query,
-                        key=sequence_output.transpose(0, 1),
-                        value=sequence_output.transpose(0, 1),
-                        key_padding_mask=inverted_input_mask + inverted_usr_mask,
-                        need_weights=True) # TODO: use usr_mask better or worse?
-                    vectors[slot] = vectors[slot].squeeze(0)
-                    if self.debug_use_cross_attention:
-                        vectors[slot] = torch.mean(vectors[slot] * (batch_input_mask + usr_mask).transpose(0, 1).unsqueeze(-1), dim=0)
-                    #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize?
-                    if self.debug_tanh_for_att_output:
-                        #vectors[slot] = self.tanh(vectors[slot])
-                        vectors[slot] = self.token_layer_norm2(vectors[slot])
-                        vectors[slot] = self.tanh(self.tlinear(vectors[slot]))
-                    #vectors[slot] += pooled_output # Residual
-                    vectors[slot] = self.token_layer_norm(vectors[slot])
-                    #vectors[slot] = self.dropout_heads(vectors[slot])
-                    if self.debug_use_cross_attention:
-                        weights[slot] = torch.mean(weights[slot] * (batch_input_mask + usr_mask).unsqueeze(-1), dim=1)
-                    weights[slot] = weights[slot].squeeze(1)
-                    if self.debug_stack_rep:
-                        vectors[slot] = torch.cat((pooled_output, vectors[slot]), 1)
-
-        if mode == "represent": # Create vector representations only
-            return vectors, None, weights
-
-        # TODO: establish proper format in labels already?
-        if inform_slot_id is not None:
-            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
-        if diag_state is not None:
-            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
-
-        bbb_time = time.time()
-        total_loss = 0
-        total_cl_loss = 0
-        total_tk_loss = 0
-        total_tp_loss = 0
-        per_slot_per_example_loss = {}
-        per_slot_per_example_cl_loss = {}
-        per_slot_per_example_tk_loss = {}
-        per_slot_per_example_tp_loss = {}
-        per_slot_att_weights = {}
-        per_slot_class_logits = {}
-        per_slot_start_logits = {}
-        per_slot_end_logits = {}
-        per_slot_value_logits = {}
-        per_slot_refer_logits = {}
-        for s_itr, slot in enumerate(self.slot_list):
-            #if t_slot is not None and slot != t_slot:
-            #    continue
-            if slot_ids is not None and slot_mask is not None:
-                encoded_slot_seq = encoded_slots_seq[s_itr]
-                encoded_slot_pooled = encoded_slots_pooled[s_itr]
-            else:
-                encoded_slot_seq = batch['encoded_slots_seq'][slot]
-                encoded_slot_pooled = batch['encoded_slots_pooled'][slot]
-
-            # Attention for slot gates
-            if self.debug_att_output:
-                if self.debug_use_cross_attention:
-                    query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1)
-                else:
-                    query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0)
-                att_output, c_weights = self.class_att(
-                    query=query,
-                    key=sequence_output.transpose(0, 1),
-                    value=sequence_output.transpose(0, 1),
-                    key_padding_mask=inverted_input_mask,
-                    need_weights=True)
-                if self.debug_use_cross_attention:
-                    att_output = torch.mean(att_output, dim=0)
-                    c_weights = torch.mean(c_weights, dim=1)
-                if self.debug_tanh_for_att_output:
-                    #att_output = self.tanh(att_output)
-                    att_output = self.class_layer_norm2(att_output)
-                    att_output = self.tanh(self.clinear(att_output))
-                att_output = self.class_layer_norm(att_output)
-                att_output = self.dropout_heads(att_output)
-                per_slot_att_weights[slot] = c_weights.squeeze(1)
-            else:
-                per_slot_att_weights[slot] = None
-
-            # Conditioned slot gate, or separate slot gates
-            if self.debug_joint_slot_gate:
-                if self.debug_att_output:
-                    if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                        xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)))
-                    elif self.class_aux_feats_inform:
-                        xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels)), 1)))
-                    elif self.class_aux_feats_ds:
-                        xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1)))
-                    else:
-                        xx = self.gelu(self.h1c(att_output.squeeze(0)))
-                    if self.debug_stack_att:
-                        x0 = self.gelu(self.h0c(pooled_output))
-                else:
-                    if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                        xx = self.gelu(self.h1c(torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)))
-                    elif self.class_aux_feats_inform:
-                        xx = self.gelu(self.h1c(torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)))
-                    elif self.class_aux_feats_ds:
-                        xx = self.gelu(self.h1c(torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)))
-                    else:
-                        xx = self.gelu(self.h1c(pooled_output))
-                if self.debug_att_output and not self.debug_simple_joint_slot_gate and self.debug_stack_att:
-                    yy = self.gelu(self.h2c(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), x0, xx), 1)))
-                elif self.debug_att_output and self.debug_simple_joint_slot_gate and self.debug_stack_att:
-                    yy = torch.cat((pooled_output, att_output.squeeze(0)), 1)
-                elif self.debug_att_output and self.debug_simple_joint_slot_gate:
-                    yy = att_output.squeeze(0)
-                else:
-                    yy = self.gelu(self.h2c(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), xx), 1)))
-                slot_gate_input = yy
-                slot_gate_layer = "llc"
-            else:
-                if self.debug_att_output:
-                    if self.debug_stack_att:
-                        if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                            slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-                        elif self.class_aux_feats_inform:
-                            slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.inform_projection(inform_labels)), 1)
-                        elif self.class_aux_feats_ds:
-                            slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1)
-                        else:
-                            slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0)), 1)
-                    else:
-                        if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                            slot_gate_input = torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-                        elif self.class_aux_feats_inform:
-                            slot_gate_input = torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels),), 1)
-                        elif self.class_aux_feats_ds:
-                            slot_gate_input = torch.cat((att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1)
-                        else:
-                            slot_gate_input = att_output.squeeze(0)
-                else:
-                    if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                        slot_gate_input = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-                    elif self.class_aux_feats_inform:
-                        slot_gate_input = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-                    elif self.class_aux_feats_ds:
-                        slot_gate_input = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-                    else:
-                        slot_gate_input = pooled_output
-                slot_gate_layer = "class_" + slot
-
-            # Conditioned refer gate, or separate refer gates
-            if self.debug_joint_slot_gate and self.debug_joint_refer_gate:
-                slot_refer_input = self.gelu(self.h2r(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), xx), 1)))
-            else:
-                if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                    slot_refer_input = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-                elif self.class_aux_feats_inform:
-                    slot_refer_input = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-                elif self.class_aux_feats_ds:
-                    slot_refer_input = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-                else:
-                    slot_refer_input = pooled_output
-
-            # Slot gate classification
-            if self.debug_sigmoid_slot_gates:
-                class_logits = []
-                for cl in range(self.class_labels):
-                    class_logits.append(self.sigmoid(getattr(self, slot_gate_layer + "_" + str(cl))(slot_gate_input)))
-                class_logits = torch.stack(class_logits, 1).squeeze(-1)
-            elif self.debug_att_slot_gates:
-                # TODO: implement separate gates as well
-                if cl_ids is not None and cl_mask is not None:
-                    bla = encoded_classes.unsqueeze(1).expand(-1, pooled_output.size()[0], -1)
-                else:
-                    bla = torch.stack(list(batch['encoded_classes'].values())).expand(-1, pooled_output.size()[0], -1) # TODO
-                #_, class_logits = self.slot_att(
-                _, class_logits = getattr(self, slot_gate_layer)(
-                    query=slot_gate_input.unsqueeze(0),
-                    key=bla,
-                    value=bla,
-                    need_weights=True)
-                class_logits = class_logits.squeeze(1)
-            else:
-                class_logits = getattr(self, slot_gate_layer)(slot_gate_input)
-
-            class_logits = self.dropout_heads(class_logits)
-
-            #token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output).squeeze(-1))
-            token_weights = weights[slot]
-
-            # ---
-
-            if self.triplet_loss_weight > 0.0 and not self.debug_use_triplet_loss:
-                slot_values = torch.stack(list(batch['encoded_slot_values'][slot].values())) # TODO: filter?
-                slot_values = slot_values.expand((-1, pooled_output.size(0), -1))
-                #yy = (batch['pos_sampling_input'][slot].sum(2) > 0.0)
-                #if self.debug_value_att_none_class:
-                #    ww = batch['value_labels'][slot][:, 1:] * yy
-                #else:
-                #    ww = batch['value_labels'][slot] * yy
-                #wwx = ww == 0
-                #bbb = slot_values * wwx.transpose(0, 1).unsqueeze(2)
-                #wwy = ww == 1
-                #yyy = batch['pos_sampling_input'][slot].expand(-1, slot_values.size(0), -1) * wwy.unsqueeze(2)
-                #slot_values = bbb + yyy.transpose(0, 1)
-                if self.debug_value_att_none_class:
-                    slot_values = torch.cat((vectors[slot].unsqueeze(0), slot_values))
-                #slot_values = slot_values.expand(-1, pooled_output.size()[0], -1)
-                #slot_values = torch.stack((batch['pos_sampling_input'][slot], batch['neg_sampling_input'][slot])) # TODO: filter?
-                #slot_values = torch.stack((vectors[slot].unsqueeze(1), batch['pos_sampling_input'][slot], batch['neg_sampling_input'][slot])) # TODO: filter?
-                #slot_values = slot_values.squeeze(2)
-                _, value_weights = self.value_att(
-                    query=vectors[slot].unsqueeze(0),
-                    key=slot_values,
-                    value=slot_values,
-                    need_weights=True)
-                ##vectors[slot] = vectors[slot].squeeze(0)
-                #vectors[slot] = torch.matmul(weights[slot], sequence_output).squeeze(1)
-                #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize?
-                #if self.debug_tanh_for_att_output:
-                    #vectors[slot] = self.tanh(vectors[slot])
-                    #vectors[slot] = self.token_layer_norm2(vectors[slot])
-                    #vectors[slot] = self.tanh(self.tlinear(vectors[slot]))
-                #vectors[slot] += pooled_output # Residual
-                #vectors[slot] = self.token_layer_norm(vectors[slot])
-                #vectors[slot] = self.dropout_heads(vectors[slot])
-                value_weights = value_weights.squeeze(1)
-            else:
-                value_weights = None
-
-            # ---
-
-            # TODO: implement choice between joint_refer_gate and individual refer gates, analogous to
-            # slot gates. Use same input as slot gate, i.e., for joint case, use yy, for individual
-            # use pooled. This is stored in slot_gate_input.
-            
-            # ---
-
-            if self.debug_joint_refer_gate:
-                if slot_ids is not None and slot_mask is not None:
-                    refer_slots = encoded_slots_pooled.unsqueeze(1).expand(-1, pooled_output.size()[0], -1)
-                else:
-                    #refer_slots = torch.stack((list(self.encoded_slots.values()))).expand(-1, pooled_output.size()[0], -1)
-                    refer_slots = torch.stack(list(batch['encoded_slots_pooled'].values())).expand(-1, pooled_output.size()[0], -1)
-                _, refer_weights = self.refer_att(
-                    query=slot_refer_input.unsqueeze(0),
-                    key=refer_slots,
-                    value=refer_slots,
-                    need_weights=True)
-                refer_weights = refer_weights.squeeze(1)
-                refer_logits = refer_weights
-            else:
-                refer_logits = getattr(self, "refer_" + slot)(slot_refer_input)
-
-            refer_logits = self.dropout_heads(refer_logits)
-
-            per_slot_class_logits[slot] = class_logits
-            per_slot_start_logits[slot] = token_weights # TODO
-            per_slot_value_logits[slot] = value_weights
-            #per_slot_start_logits[slot] = self.sigmoid(token_logits) # TODO # this is for sigmoid approach
-            per_slot_refer_logits[slot] = refer_logits
-
-            # If there are no labels, don't compute loss
-            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos[slot].size()) > 1:
-                    start_pos[slot] = start_pos[slot].squeeze(-1)
-
-                # TODO: solve this using the sequence_tagging def?
-                labels_clipped = torch.clamp(start_pos[slot].float(), min=0, max=1)
-                labels_clipped_scaled = labels_clipped / torch.clamp(labels_clipped.sum(1).unsqueeze(1), min=1) # Scale targets?
-                no_seq_mask = labels_clipped_scaled.sum(1) == 0
-                no_seq_w = 1 / batch_input_mask.sum(1)
-                labels_clipped_scaled += batch_input_mask * (no_seq_mask * no_seq_w).unsqueeze(1)
-                #token_weights = self.sigmoid(token_logits) # TODO # this is for sigmoid approach
-                if self.token_loss_function == "mse":
-                    token_loss = self.mse(token_weights, labels_clipped_scaled) # TODO: MSE might be better for scaled targets
-                else:
-                    token_loss = self.binary_cross_entropy(token_weights, labels_clipped_scaled, reduction="none")
-                if self.ignore_o_tags:
-                    token_loss *= labels_clipped # TODO: good idea?
-
-                # TODO: do negative examples have to be balanced due to their large number?
-                token_loss = token_loss.sum(1)
-                token_is_pointable = (start_pos[slot].sum(1) > 0).float()
-                if not self.token_loss_for_nonpointable:
-                    token_loss *= token_is_pointable
-
-                value_loss = torch.zeros(token_is_pointable.size(), device=token_is_pointable.device)
-                if self.triplet_loss_weight > 0.0:
-                    if self.debug_use_triplet_loss:
-                        value_loss = self.compute_triplet_loss_att(vectors[slot], pos_sampling_input, neg_sampling_input, slot)
-                        #triplet_loss = torch.clamp(triplet_loss, max=1) # TODO: parameterize # Not the best idea I think...
-                    else:
-                        value_labels_clipped = torch.clamp(value_labels[slot].float(), min=0, max=1)
-                        value_labels_clipped /= torch.clamp(value_labels_clipped.sum(1).unsqueeze(1), min=1) # Scale targets?
-                        value_no_seq_mask = value_labels_clipped.sum(1) == 0
-                        value_no_seq_w = 1 / value_labels_clipped.size(1)
-                        value_labels_clipped += (value_no_seq_mask * value_no_seq_w).unsqueeze(1)
-                        if self.value_loss_function == "mse":
-                            value_loss = self.mse(value_weights, value_labels_clipped) # TODO: scale value_labels to also cover nonpointable cases and multitarget cases
-                        else:
-                            value_loss = self.binary_cross_entropy(value_weights, value_labels_clipped, reduction="none")
-                            #print(slot)
-                            #print(value_labels_clipped)
-                            #print(value_weights)
-                            #print(value_loss)
-                        value_loss = value_loss.sum(1)
-                    token_is_matchable = token_is_pointable
-                    if self.debug_tag_none_target:
-                        token_is_matchable *= (start_pos[slot][:, 1] == 0).float()
-                    if not self.value_loss_for_nonpointable:
-                        value_loss *= token_is_matchable
-
-                # Re-definition necessary to make slot-independent prediction possible
-                self.refer_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=len(self.slot_list)) # Ignore 'none' target
-                refer_loss = self.refer_loss_fct(refer_logits, refer_id[slot])
-                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
-                if not self.refer_loss_for_nonpointable:
-                    refer_loss *= token_is_referrable
-
-                if self.debug_sigmoid_slot_gates:
-                    class_loss = []
-                    for cl in range(self.class_labels):
-                        #class_loss.append(self.binary_cross_entropy(class_logits[:,cl], labels, reduction="none"))
-                        class_loss.append(self.class_loss_fct(class_logits[:, cl], (class_label_id[slot] == cl).float(), reduction="none"))
-                    class_loss = torch.stack(class_loss, 1)
-                    if self.none_weight != 1.0:
-                        class_loss *= self.clweights.to(outputs[0].device)
-                    class_loss = class_loss.sum(1)
-                elif self.debug_att_slot_gates:
-                    if self.class_loss_function == "mse":
-                        class_loss = self.mse(class_logits, torch.nn.functional.one_hot(class_label_id[slot], self.class_labels).float())
-                    else:
-                        class_loss = self.binary_cross_entropy(class_logits, torch.nn.functional.one_hot(class_label_id[slot], self.class_labels).float(), reduction="none")
-                    if self.none_weight != 1.0:
-                        class_loss *= self.clweights.to(outputs[0].device)
-                    class_loss = class_loss.sum(1)
-                else:
-                    class_loss = self.class_loss_fct(class_logits, class_label_id[slot])
-
-                #print("%15s, class loss: %.3f, token loss: %.3f, triplet loss: %.3f" % (slot, (class_loss.sum()).item(), (token_loss.sum()).item(), (triplet_loss.sum()).item()))
-                #print("%15s, class loss: %.3f, token loss: %.3f, value loss: %.3f" % (slot, (class_loss.sum()).item(), (token_loss.sum()).item(), (value_loss.sum()).item()))
-
-                st_switch = int(not (self.sequence_tagging_dropout >= 1 and step is not None and step % self.sequence_tagging_dropout == 0))
-
-                if self.refer_index > -1:
-                    #per_example_loss = (self.class_loss_ratio) * class_loss + st_switch * ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + self.triplet_loss_weight * triplet_loss
-                    per_example_loss = (self.class_loss_ratio) * class_loss + st_switch * ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + self.triplet_loss_weight * value_loss
-                    #per_example_loss = class_loss
-                else:
-                    #per_example_loss = self.class_loss_ratio * class_loss + st_switch * (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * triplet_loss
-                    per_example_loss = self.class_loss_ratio * class_loss + st_switch * (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * value_loss
-                    #if epoch is not None and epoch > 20:
-                    #    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * triplet_loss
-                    #else:
-                    #    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
-                    #per_example_loss = class_loss
-                if self.debug_use_tlf:
-                    per_example_loss *= 1.0 + ((batch['diag_len'] - batch['turn_id']) * 0.05)
-
-                total_loss += per_example_loss.sum()
-                total_cl_loss += class_loss.sum()
-                total_tk_loss += token_loss.sum()
-                #total_tp_loss += triplet_loss.sum()
-                total_tp_loss += value_loss.sum()
-                per_slot_per_example_loss[slot] = per_example_loss
-                per_slot_per_example_cl_loss[slot] = class_loss
-                per_slot_per_example_tk_loss[slot] = token_loss
-                #per_slot_per_example_tp_loss[slot] = triplet_loss
-                per_slot_per_example_tp_loss[slot] = value_loss
-        ccc_time = time.time()
-        #print("TIME:", bbb_time - aaa_time, ccc_time - bbb_time) # 0.028620243072509766
-
-        # add hidden states and attention if they are here
-        outputs = (total_loss,
-                   total_cl_loss,
-                   total_tk_loss,
-                   total_tp_loss,
-                   per_slot_per_example_loss,
-                   per_slot_per_example_cl_loss,
-                   per_slot_per_example_tk_loss,
-                   per_slot_per_example_tp_loss,
-                   per_slot_class_logits,
-                   per_slot_start_logits,
-                   per_slot_end_logits,
-                   per_slot_value_logits,
-                   per_slot_refer_logits,
-                   per_slot_att_weights,) + (vectors, weights,) + (pooled_output,) # + outputs[2:]
-        #outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits, per_slot_att_weights,) + (vectors, weights,) # + outputs[2:]
-
-        return outputs
diff --git a/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py~ b/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py~
deleted file mode 100644
index 2e4dec8bfea07922e25261bdb07798df559c5c4e..0000000000000000000000000000000000000000
--- a/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py~
+++ /dev/null
@@ -1,854 +0,0 @@
-# coding=utf-8
-#
-# Copyright 2020 Heinrich Heine University Duesseldorf
-#
-# Part of this code is based on the source code of BERT-DST
-# (arXiv:1907.03040)
-# Part of this code is based on the source code of Transformers
-# (arXiv:1910.03771)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss
-from torch.nn import TripletMarginLoss
-from torch.nn import PairwiseDistance
-from torch.nn import MultiheadAttention
-import torch.nn.functional as F
-
-from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable)
-from transformers.modeling_utils import (PreTrainedModel)
-from transformers.modeling_roberta import (RobertaModel, RobertaConfig, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                           ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING, BertLayerNorm)
-
-import time
-
-
-class RobertaPreTrainedModel(PreTrainedModel):
-    """ An abstract class to handle weights initialization and
-        a simple interface for dowloading and loading pretrained models.
-    """
-    config_class = RobertaConfig
-    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
-    base_model_prefix = "roberta"
-    
-    def _init_weights(self, module):
-        """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-        elif isinstance(module, BertLayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
-
-
-@add_start_docstrings(
-    """RoBERTa Model with classification heads for the DST task. """,
-    ROBERTA_START_DOCSTRING,
-)
-class RobertaForDST(RobertaPreTrainedModel):
-    def __init__(self, config):
-        super(RobertaForDST, self).__init__(config)
-        self.slot_list = config.dst_slot_list
-        self.noncategorical = config.dst_noncategorical
-        self.categorical = config.dst_categorical
-        self.class_types = config.dst_class_types
-        self.class_labels = config.dst_class_labels
-        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
-        self.value_loss_for_nonpointable = config.dst_value_loss_for_nonpointable
-        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
-        self.stack_token_logits = False # config.dst_stack_token_logits # TODO
-        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
-        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
-        self.class_loss_ratio = config.dst_class_loss_ratio
-        self.slot_attention_heads = config.dst_slot_attention_heads
-        self.dropout_rate = config.dst_dropout_rate
-        self.heads_dropout_rate = config.dst_heads_dropout_rate
-        
-        self.debug_fix_slot_embs = config.debug_fix_slot_embs
-        self.debug_joint_slot_gate = config.debug_joint_slot_gate
-        self.debug_joint_refer_gate = config.debug_joint_refer_gate
-        self.debug_simple_joint_slot_gate = config.debug_simple_joint_slot_gate
-        self.debug_separate_seq_tagging = config.debug_separate_seq_tagging
-        self.debug_sigmoid_sequence_tagging = config.debug_sigmoid_sequence_tagging
-        self.debug_att_output = config.debug_att_output
-        self.debug_tanh_for_att_output = config.debug_tanh_for_att_output
-        self.debug_stack_att = config.debug_stack_att
-        self.debug_stack_rep = config.debug_stack_rep
-        self.debug_sigmoid_slot_gates = config.debug_sigmoid_slot_gates
-        self.debug_att_slot_gates = config.debug_att_slot_gates
-        self.debug_use_triplet_loss = config.debug_use_triplet_loss
-        self.debug_use_tlf = config.debug_use_tlf
-        self.debug_value_att_none_class = config.debug_value_att_none_class
-        self.debug_tag_none_target = config.debug_tag_none_target
-        self.triplet_loss_weight = config.triplet_loss_weight
-        self.none_weight = config.none_weight
-        self.pretrain_loss_function = config.pretrain_loss_function
-        self.token_loss_function = config.token_loss_function
-        self.value_loss_function = config.value_loss_function
-        self.class_loss_function = config.class_loss_function
-        self.sequence_tagging_dropout = -1
-        if config.sequence_tagging_dropout > 0.0:
-            self.sequence_tagging_dropout = int(1 / config.sequence_tagging_dropout)
-        self.ignore_o_tags = config.ignore_o_tags
-        self.debug_slot_embs_per_step_nograd = config.debug_slot_embs_per_step_nograd
-        try:
-            self.debug_use_cross_attention = config.debug_use_cross_attention
-        except:
-            self.debug_use_cross_attention = False
-
-        self.val_rep_mode = config.val_rep_mode # TODO: for debugging at the moment.
- 
-        config.output_hidden_states = True # TODO
-
-        #config.dst_dropout_rate = 0.0 # TODO: for debugging
-        #config.dst_heads_dropout_rate = 0.0 # TODO: for debugging
-        
-        self.roberta = RobertaModel(config)
-        self.dropout = nn.Dropout(config.dst_dropout_rate)
-        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
-
-        # -- Dialogue state tracking functionality --
-
-        # Only use refer loss if refer class is present in dataset.
-        if 'refer' in self.class_types:
-            self.refer_index = self.class_types.index('refer')
-        else:
-            self.refer_index = -1
-
-        if self.class_aux_feats_inform:
-            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-        if self.class_aux_feats_ds:
-            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-
-        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
-
-        # Slot specific gates
-        for slot in self.slot_list:
-            if not self.debug_joint_slot_gate:
-                if self.debug_stack_att:
-                    if self.debug_sigmoid_slot_gates:
-                        for cl in range(self.class_labels):
-                            self.add_module("class_" + slot + "_" + str(cl), nn.Linear(config.hidden_size * 2 + aux_dims, 1))
-                    elif self.debug_att_slot_gates:
-                        self.add_module("class_" + slot, MultiheadAttention(config.hidden_size * 2 + aux_dims, self.slot_attention_heads))
-                    else:
-                        self.add_module("class_" + slot, nn.Linear(config.hidden_size * 2 + aux_dims, self.class_labels))
-                else:
-                    if self.debug_sigmoid_slot_gates:
-                        for cl in range(self.class_labels):
-                            self.add_module("class_" + slot + "_" + str(cl), nn.Linear(config.hidden_size + aux_dims, 1))
-                    elif self.debug_att_slot_gates:
-                        self.add_module("class_" + slot, MultiheadAttention(config.hidden_size + aux_dims, self.slot_attention_heads))
-                    else:
-                        self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
-            #self.add_module("token_" + slot, nn.Linear(config.hidden_size, 1))
-            if not self.debug_joint_refer_gate:
-                self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
-            if self.debug_separate_seq_tagging:
-                if self.debug_sigmoid_sequence_tagging:
-                    self.add_module("token_" + slot, nn.Linear(config.hidden_size, 1))
-                else:
-                    self.add_module("token_" + slot, MultiheadAttention(config.hidden_size, self.slot_attention_heads))
-
-        if self.debug_att_output:
-            self.class_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads)
-        # Conditioned sequence tagging
-        if not self.debug_separate_seq_tagging:
-            if self.debug_sigmoid_sequence_tagging:
-                self.h1t = nn.Linear(config.hidden_size, config.hidden_size)
-                self.h2t = nn.Linear(config.hidden_size * 2, config.hidden_size * 2)
-                self.llt = nn.Linear(config.hidden_size * 2, 1)
-            else:
-                self.token_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads)
-        if self.debug_joint_refer_gate:
-            self.refer_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads)
-        self.value_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads)
-        #self.slot_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads)
-
-        self.token_layer_norm = nn.LayerNorm(config.hidden_size)
-        self.token_layer_norm2 = nn.LayerNorm(config.hidden_size)
-        self.class_layer_norm = nn.LayerNorm(config.hidden_size)
-        self.class_layer_norm2 = nn.LayerNorm(config.hidden_size)
-
-        self.tlinear = nn.Linear(config.hidden_size, config.hidden_size)
-        self.clinear = nn.Linear(config.hidden_size, config.hidden_size)
-
-        # Conditioned slot gate
-        self.h1c = nn.Linear(config.hidden_size + aux_dims, config.hidden_size)
-        if self.debug_stack_att:
-            self.h0c = nn.Linear(config.hidden_size, config.hidden_size)
-            self.h2c = nn.Linear(config.hidden_size * 3, config.hidden_size * 3)
-            if self.debug_sigmoid_slot_gates:
-                for cl in range(self.class_labels):
-                    self.add_module("llc_" + str(cl), nn.Linear(config.hidden_size * 3, 1))
-            elif self.debug_att_slot_gates:
-                if self.debug_att_output and self.debug_simple_joint_slot_gate:
-                    self.llc = MultiheadAttention(config.hidden_size * 2, self.slot_attention_heads)
-                else:
-                    self.llc = MultiheadAttention(config.hidden_size * 3, self.slot_attention_heads)
-            else:
-                if self.debug_att_output and self.debug_simple_joint_slot_gate:
-                    self.llc = nn.Linear(config.hidden_size * 2, self.class_labels)
-                else:
-                    self.llc = nn.Linear(config.hidden_size * 3, self.class_labels)
-        else:
-            if self.debug_att_slot_gates:
-                self.h2c = nn.Linear(config.hidden_size * 2, config.hidden_size * 1)
-            else:
-                self.h2c = nn.Linear(config.hidden_size * 2, config.hidden_size * 2)
-            if self.debug_sigmoid_slot_gates:
-                for cl in range(self.class_labels):
-                    self.add_module("llc_" + str(cl), nn.Linear(config.hidden_size * 2, 1))
-            elif self.debug_att_slot_gates:
-                if self.debug_att_output and self.debug_simple_joint_slot_gate:
-                    self.llc = MultiheadAttention(config.hidden_size * 1, self.slot_attention_heads)
-                else:
-                    self.llc = MultiheadAttention(config.hidden_size * 2, self.slot_attention_heads)
-            else:
-                if self.debug_att_output and self.debug_simple_joint_slot_gate:
-                    self.llc = nn.Linear(config.hidden_size * 1, self.class_labels)
-                else:
-                    self.llc = nn.Linear(config.hidden_size * 2, self.class_labels)
-
-        # Conditioned refer gate
-        self.h2r = nn.Linear(config.hidden_size * 2, config.hidden_size * 1)
-
-        # -- Spanless sequence tagging functionality --
-
-        self.dis = PairwiseDistance(p=2)
-        self.sigmoid = nn.Sigmoid()
-        self.tanh = nn.Tanh()
-        self.relu = nn.ReLU()
-        self.gelu = nn.GELU()
-        self.binary_cross_entropy = F.binary_cross_entropy
-        self.triplet_loss = TripletMarginLoss(margin=1.0, reduction="none")
-        self.mse = nn.MSELoss(reduction="none")
-        self.refer_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=len(self.slot_list)) # Ignore 'none' target
-        self.value_loss_fct = CrossEntropyLoss(reduction='none')
-
-        if self.none_weight != 1.0:
-            none_weight = self.none_weight
-            weight_mass = none_weight + (self.class_labels - 1)
-            none_weight /= weight_mass
-            other_weights = 1 / weight_mass
-            #self.clweights = torch.tensor([none_weight] + [other_weights] * (self.class_labels - 1)) # .to(outputs[0].device)
-            self.clweights = torch.tensor([other_weights] * self.class_labels) # .to(outputs[0].device)
-            self.clweights[self.class_types.index('none')] = none_weight
-            if self.debug_sigmoid_slot_gates:
-                self.class_loss_fct = F.binary_cross_entropy # (reduction="none") # weights for classes are controlled when adding losses
-            else:
-                self.class_loss_fct = CrossEntropyLoss(weight=self.clweights, reduction='none')
-        else:
-            if self.debug_sigmoid_slot_gates:
-                self.class_loss_fct = F.binary_cross_entropy # (reduction="none")
-            else:
-                self.class_loss_fct = CrossEntropyLoss(reduction='none')
-
-        self.init_weights()
-
-    def compute_triplet_loss_att(self, att_output, pos_sampling_input, neg_sampling_input, slot):
-        loss = self.triplet_loss(att_output, pos_sampling_input[slot].squeeze(), neg_sampling_input[slot].squeeze())
-        sample_mask = neg_sampling_input[slot].squeeze(1).sum(1) != 0
-        loss *= sample_mask
-        return loss
-
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
-    def forward(self, batch, step=None, epoch=None, t_slot=None, mode=None):
-        assert(mode in [None, "pretrain", "tag", "encode", "encode_vals", "represent"]) # TODO
-
-        input_ids = batch['input_ids']
-        input_mask = batch['input_mask']
-        #segment_ids = batch['segment_ids']
-        usr_mask = batch['usr_mask']
-        start_pos = batch['start_pos']
-        end_pos = batch['end_pos']
-        refer_id = batch['refer_id']
-        class_label_id = batch['class_label_id']
-        inform_slot_id = batch['inform_slot_id']
-        diag_state = batch['diag_state']
-        slot_ids = batch['slot_ids'] if 'slot_ids' in batch else None # TODO: fix?
-        slot_mask = batch['slot_mask'] if 'slot_mask' in batch else None # TODO: fix?
-        cl_ids = batch['cl_ids'] if 'cl_ids' in batch else None # TODO: fix?
-        cl_mask = batch['cl_mask'] if 'cl_mask' in batch else None # TODO: fix?
-        pos_sampling_input = batch['pos_sampling_input']
-        neg_sampling_input = batch['neg_sampling_input']
-        value_labels = batch['value_labels'] if 'value_labels' in batch else None # TODO: fix?
-
-        batch_input_mask = input_mask
-        if slot_ids is not None and slot_mask is not None:
-            if self.debug_slot_embs_per_step_nograd:
-                with torch.no_grad():
-                    outputs_slot = self.roberta(
-                        slot_ids,
-                        attention_mask=slot_mask,
-                        token_type_ids=None, # segment_ids,
-                        position_ids=None,
-                        head_mask=None
-                    )
-            else:
-                input_ids = torch.cat((input_ids, slot_ids))
-                input_mask = torch.cat((input_mask, slot_mask))
-        if cl_ids is not None and cl_mask is not None:
-            input_ids = torch.cat((input_ids, cl_ids))
-            input_mask = torch.cat((input_mask, cl_mask))
-
-        outputs = self.roberta(
-            input_ids,
-            attention_mask=input_mask,
-            token_type_ids=None, # segment_ids,
-            position_ids=None,
-            head_mask=None
-        )
-
-        sequence_output = outputs[0]
-        pooled_output = outputs[1]
-
-        if cl_ids is not None and cl_mask is not None:
-            sequence_output = sequence_output[:-1 * len(cl_ids), :, :]
-            encoded_classes = pooled_output[-1 * len(cl_ids):, :]
-            pooled_output = pooled_output[:-1 * len(cl_ids), :]
-        if slot_ids is not None and slot_mask is not None:
-            if self.debug_slot_embs_per_step_nograd:
-                encoded_slots_seq = outputs_slot[0]
-                encoded_slots_pooled = outputs_slot[1]
-            else:
-                encoded_slots_seq = sequence_output[-1 * len(slot_ids):, :, :]
-                sequence_output = sequence_output[:-1 * len(slot_ids), :, :]
-                encoded_slots_pooled = pooled_output[-1 * len(slot_ids):, :]
-                pooled_output = pooled_output[:-1 * len(slot_ids), :]
-
-        sequence_output = self.dropout(sequence_output)
-        pooled_output = self.dropout(pooled_output)
-
-        #sequence_output11 = outputs[2][-2]
-
-        inverted_input_mask = ~(batch_input_mask.bool())
-        if usr_mask is None:
-            usr_mask = input_mask
-        inverted_usr_mask = ~(usr_mask.bool())
-
-        if mode == "encode": # Create vector representations only
-            return pooled_output, sequence_output, None
-
-        if mode == "encode_vals": # Create vector representations only
-            no_seq_w = 1 / batch_input_mask.sum(1)
-            uniform_weights = batch_input_mask * no_seq_w.unsqueeze(1)
-            pooled_output_vals = torch.matmul(uniform_weights, sequence_output).squeeze(1)
-            pooled_output_vals = self.token_layer_norm(pooled_output_vals)
-            return pooled_output_vals, None, None
- 
-        if mode == "pretrain":
-            pos_vectors = {}
-            pos_weights = {}
-            pos_vectors, pos_weights = self.token_att(
-                query=encoded_slots_pooled.squeeze(1).unsqueeze(0),
-                key=sequence_output.transpose(0, 1),
-                value=sequence_output.transpose(0, 1),
-                key_padding_mask=inverted_input_mask,
-                need_weights=True)
-            pos_vectors = pos_vectors.squeeze(0)
-            pos_vectors = self.token_layer_norm(pos_vectors)
-            pos_weights = pos_weights.squeeze(1)
-
-            #neg_vectors = {}
-            neg_weights = {}
-            #neg_vectors[slot], neg_weights[slot] = self.token_att(
-            #    query=neg_sampling_input[slot].squeeze(1).unsqueeze(0),
-            #    key=sequence_output.transpose(0, 1),
-            #    value=sequence_output.transpose(0, 1),
-            #    key_padding_mask=inverted_input_mask,
-            #    need_weights=True)
-            #neg_vectors[slot] = neg_vectors[slot].squeeze(0)
-            #neg_vectors[slot] = self.token_layer_norm(neg_vectors[slot])
-            #neg_weights[slot] = neg_weights[slot].squeeze(1)
-
-            pos_labels_clipped = torch.clamp(start_pos.float(), min=0, max=1)
-            pos_labels_clipped_scaled = pos_labels_clipped / torch.clamp(pos_labels_clipped.sum(1).unsqueeze(1), min=1) # scaled
-            #no_seq_w = 1 / batch_input_mask.sum(1)
-            #neg_labels_clipped = batch_input_mask * no_seq_w.unsqueeze(1)
-            if self.pretrain_loss_function == "mse":
-                pos_token_loss = self.mse(pos_weights, pos_labels_clipped_scaled) # TODO: MSE might be better for scaled targets
-            else:
-                pos_token_loss = self.binary_cross_entropy(pos_weights, pos_labels_clipped_scaled, reduction="none")
-            if self.ignore_o_tags:
-                pos_token_loss *= pos_labels_clipped # TODO: good idea?
-            #mm = torch.clamp(start_pos[slot].float(), min=0, max=1)
-            #mm = torch.clamp(mm * 10, min=1)
-            #pos_token_loss *= mm
-            pos_token_loss = pos_token_loss.sum(1)
-            #neg_token_loss = self.mse(neg_weights[slot], neg_labels_clipped)
-            #neg_token_loss = neg_token_loss.sum(1)
-
-            #triplet_loss = self.compute_triplet_loss_att(pos_vectors[slot], pos_sampling_input, neg_sampling_input, slot)
-
-            per_example_loss = pos_token_loss # + triplet_loss # + neg_token_loss
-            total_loss = per_example_loss.sum()
-            
-            return (total_loss, pos_weights, neg_weights,)
-
-        if mode == "tag":
-            query = torch.stack(list(batch['value_reps'].values())).transpose(1, 2).reshape(-1, pooled_output.size()[0], pooled_output.size()[1])
-            _, weights = self.token_att(query=query,
-                                        key=sequence_output.transpose(0, 1),
-                                        value=sequence_output.transpose(0, 1),
-                                        key_padding_mask=inverted_input_mask + inverted_usr_mask,
-                                        need_weights=True)
-            return (weights,)
-
-        aaa_time = time.time()
-        # Attention for sequence tagging
-        vectors = {}
-        weights = {}
-        for s_itr, slot in enumerate(self.slot_list):
-            if slot_ids is not None and slot_mask is not None:
-                encoded_slot_seq = encoded_slots_seq[s_itr]
-                encoded_slot_pooled = encoded_slots_pooled[s_itr]
-            else:
-                encoded_slot_seq = batch['encoded_slots_seq'][slot]
-                encoded_slot_pooled = batch['encoded_slots_pooled'][slot]
-            if self.debug_separate_seq_tagging:
-                if self.debug_sigmoid_sequence_tagging:
-                    weights[slot] = self.sigmoid(getattr(self, "token_" + slot)(sequence_output)).squeeze(2)
-                    vectors[slot] = torch.matmul(weights[slot], sequence_output)
-                    vectors[slot] = torch.diagonal(vectors[slot]).transpose(0, 1)
-                    vectors[slot] = self.token_layer_norm(vectors[slot])
-                else:
-                    if self.debug_use_cross_attention:
-                        query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1)
-                    else:
-                        query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0)
-                    vectors[slot], weights[slot] = getattr(self, "token_" + slot)(
-                        query=query,
-                        key=sequence_output.transpose(0, 1),
-                        value=sequence_output.transpose(0, 1),
-                        key_padding_mask=inverted_input_mask + inverted_usr_mask,
-                        need_weights=True) # TODO: use usr_mask better or worse?
-                    vectors[slot] = vectors[slot].squeeze(0)
-                    #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize?
-                    if self.debug_tanh_for_att_output:
-                        #vectors[slot] = self.tanh(vectors[slot])
-                        vectors[slot] = self.token_layer_norm2(vectors[slot])
-                        vectors[slot] = self.tanh(self.tlinear(vectors[slot]))
-                    #vectors[slot] += pooled_output # Residual
-                    vectors[slot] = self.token_layer_norm(vectors[slot])
-                    #vectors[slot] = self.dropout_heads(vectors[slot])
-                    weights[slot] = weights[slot].squeeze(1)
-                    if self.debug_stack_rep:
-                        vectors[slot] = torch.cat((pooled_output, vectors[slot]), 1)
-            else:
-                if self.debug_sigmoid_sequence_tagging:
-                    # Conditioned sequence tagging
-                    sequence_output_add = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(1).expand(sequence_output.size())
-                    xxt = self.gelu(self.h1t(sequence_output))
-                    yyt = self.gelu(self.h2t(torch.cat((sequence_output_add, xxt), 2)))
-                    weights[slot] = self.sigmoid(self.llt(yyt)).squeeze(2)
-                    vectors[slot] = torch.matmul(weights[slot], sequence_output)
-                    vectors[slot] = torch.diagonal(vectors[slot]).transpose(0, 1)
-                    vectors[slot] = self.token_layer_norm(vectors[slot])
-                else:
-                    if self.debug_use_cross_attention:
-                        query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1)
-                    else:
-                        query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0)
-                    vectors[slot], weights[slot] = self.token_att(
-                        query=query,
-                        key=sequence_output.transpose(0, 1),
-                        value=sequence_output.transpose(0, 1),
-                        key_padding_mask=inverted_input_mask + inverted_usr_mask,
-                        need_weights=True) # TODO: use usr_mask better or worse?
-                    vectors[slot] = vectors[slot].squeeze(0)
-                    if self.debug_use_cross_attention:
-                        vectors[slot] = torch.mean(vectors[slot] * (batch_input_mask + usr_mask).transpose(0, 1).unsqueeze(-1), dim=0)
-                    #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize?
-                    if self.debug_tanh_for_att_output:
-                        #vectors[slot] = self.tanh(vectors[slot])
-                        vectors[slot] = self.token_layer_norm2(vectors[slot])
-                        vectors[slot] = self.tanh(self.tlinear(vectors[slot]))
-                    #vectors[slot] += pooled_output # Residual
-                    vectors[slot] = self.token_layer_norm(vectors[slot])
-                    #vectors[slot] = self.dropout_heads(vectors[slot])
-                    if self.debug_use_cross_attention:
-                        weights[slot] = torch.mean(weights[slot] * (batch_input_mask + usr_mask).unsqueeze(-1), dim=1)
-                    weights[slot] = weights[slot].squeeze(1)
-                    if self.debug_stack_rep:
-                        vectors[slot] = torch.cat((pooled_output, vectors[slot]), 1)
-
-        if mode == "represent": # Create vector representations only
-            return vectors, None, weights
-
-        # TODO: establish proper format in labels already?
-        if inform_slot_id is not None:
-            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
-        if diag_state is not None:
-            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
-
-        bbb_time = time.time()
-        total_loss = 0
-        total_cl_loss = 0
-        total_tk_loss = 0
-        total_tp_loss = 0
-        per_slot_per_example_loss = {}
-        per_slot_per_example_cl_loss = {}
-        per_slot_per_example_tk_loss = {}
-        per_slot_per_example_tp_loss = {}
-        per_slot_att_weights = {}
-        per_slot_class_logits = {}
-        per_slot_start_logits = {}
-        per_slot_end_logits = {}
-        per_slot_value_logits = {}
-        per_slot_refer_logits = {}
-        for s_itr, slot in enumerate(self.slot_list):
-            #if t_slot is not None and slot != t_slot:
-            #    continue
-            if slot_ids is not None and slot_mask is not None:
-                encoded_slot_seq = encoded_slots_seq[s_itr]
-                encoded_slot_pooled = encoded_slots_pooled[s_itr]
-            else:
-                encoded_slot_seq = batch['encoded_slots_seq'][slot]
-                encoded_slot_pooled = batch['encoded_slots_pooled'][slot]
-
-            # Attention for slot gates
-            if self.debug_att_output:
-                if self.debug_use_cross_attention:
-                    query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1)
-                else:
-                    query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0)
-                att_output, c_weights = self.class_att(
-                    query=query,
-                    key=sequence_output.transpose(0, 1),
-                    value=sequence_output.transpose(0, 1),
-                    key_padding_mask=inverted_input_mask,
-                    need_weights=True)
-                if self.debug_use_cross_attention:
-                    att_output = torch.mean(att_output, dim=0)
-                    c_weights = torch.mean(c_weights, dim=1)
-                if self.debug_tanh_for_att_output:
-                    #att_output = self.tanh(att_output)
-                    att_output = self.class_layer_norm2(att_output)
-                    att_output = self.tanh(self.clinear(att_output))
-                att_output = self.class_layer_norm(att_output)
-                att_output = self.dropout_heads(att_output)
-                per_slot_att_weights[slot] = c_weights.squeeze(1)
-            else:
-                per_slot_att_weights[slot] = None
-
-            # Conditioned slot gate, or separate slot gates
-            if self.debug_joint_slot_gate:
-                if self.debug_att_output:
-                    if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                        xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)))
-                    elif self.class_aux_feats_inform:
-                        xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels)), 1)))
-                    elif self.class_aux_feats_ds:
-                        xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1)))
-                    else:
-                        xx = self.gelu(self.h1c(att_output.squeeze(0)))
-                    if self.debug_stack_att:
-                        x0 = self.gelu(self.h0c(pooled_output))
-                else:
-                    if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                        xx = self.gelu(self.h1c(torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)))
-                    elif self.class_aux_feats_inform:
-                        xx = self.gelu(self.h1c(torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)))
-                    elif self.class_aux_feats_ds:
-                        xx = self.gelu(self.h1c(torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)))
-                    else:
-                        xx = self.gelu(self.h1c(pooled_output))
-                if self.debug_att_output and not self.debug_simple_joint_slot_gate and self.debug_stack_att:
-                    yy = self.gelu(self.h2c(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), x0, xx), 1)))
-                elif self.debug_att_output and self.debug_simple_joint_slot_gate and self.debug_stack_att:
-                    yy = torch.cat((pooled_output, att_output.squeeze(0)), 1)
-                elif self.debug_att_output and self.debug_simple_joint_slot_gate:
-                    yy = att_output.squeeze(0)
-                else:
-                    yy = self.gelu(self.h2c(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), xx), 1)))
-                slot_gate_input = yy
-                slot_gate_layer = "llc"
-            else:
-                if self.debug_att_output:
-                    if self.debug_stack_att:
-                        if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                            slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-                        elif self.class_aux_feats_inform:
-                            slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.inform_projection(inform_labels)), 1)
-                        elif self.class_aux_feats_ds:
-                            slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1)
-                        else:
-                            slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0)), 1)
-                    else:
-                        if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                            slot_gate_input = torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-                        elif self.class_aux_feats_inform:
-                            slot_gate_input = torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels),), 1)
-                        elif self.class_aux_feats_ds:
-                            slot_gate_input = torch.cat((att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1)
-                        else:
-                            slot_gate_input = att_output.squeeze(0)
-                else:
-                    if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                        slot_gate_input = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-                    elif self.class_aux_feats_inform:
-                        slot_gate_input = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-                    elif self.class_aux_feats_ds:
-                        slot_gate_input = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-                    else:
-                        slot_gate_input = pooled_output
-                slot_gate_layer = "class_" + slot
-
-            # Conditioned refer gate, or separate refer gates
-            if self.debug_joint_slot_gate and self.debug_joint_refer_gate:
-                slot_refer_input = self.gelu(self.h2r(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), xx), 1)))
-            else:
-                if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                    slot_refer_input = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-                elif self.class_aux_feats_inform:
-                    slot_refer_input = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-                elif self.class_aux_feats_ds:
-                    slot_refer_input = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-                else:
-                    slot_refer_input = pooled_output
-
-            # Slot gate classification
-            if self.debug_sigmoid_slot_gates:
-                class_logits = []
-                for cl in range(self.class_labels):
-                    class_logits.append(self.sigmoid(getattr(self, slot_gate_layer + "_" + str(cl))(slot_gate_input)))
-                class_logits = torch.stack(class_logits, 1).squeeze(-1)
-            elif self.debug_att_slot_gates:
-                # TODO: implement separate gates as well
-                if cl_ids is not None and cl_mask is not None:
-                    bla = encoded_classes.unsqueeze(1).expand(-1, pooled_output.size()[0], -1)
-                else:
-                    bla = torch.stack(list(batch['encoded_classes'].values())).expand(-1, pooled_output.size()[0], -1) # TODO
-                #_, class_logits = self.slot_att(
-                _, class_logits = getattr(self, slot_gate_layer)(
-                    query=slot_gate_input.unsqueeze(0),
-                    key=bla,
-                    value=bla,
-                    need_weights=True)
-                class_logits = class_logits.squeeze(1)
-            else:
-                class_logits = getattr(self, slot_gate_layer)(slot_gate_input)
-
-            class_logits = self.dropout_heads(class_logits)
-
-            #token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output).squeeze(-1))
-            token_weights = weights[slot]
-
-            # ---
-
-            if self.triplet_loss_weight > 0.0 and not self.debug_use_triplet_loss:
-                slot_values = torch.stack(list(batch['encoded_slot_values'][slot].values())) # TODO: filter?
-                slot_values = slot_values.expand((-1, pooled_output.size(0), -1))
-                #yy = (batch['pos_sampling_input'][slot].sum(2) > 0.0)
-                #if self.debug_value_att_none_class:
-                #    ww = batch['value_labels'][slot][:, 1:] * yy
-                #else:
-                #    ww = batch['value_labels'][slot] * yy
-                #wwx = ww == 0
-                #bbb = slot_values * wwx.transpose(0, 1).unsqueeze(2)
-                #wwy = ww == 1
-                #yyy = batch['pos_sampling_input'][slot].expand(-1, slot_values.size(0), -1) * wwy.unsqueeze(2)
-                #slot_values = bbb + yyy.transpose(0, 1)
-                if self.debug_value_att_none_class:
-                    slot_values = torch.cat((vectors[slot].unsqueeze(0), slot_values))
-                #slot_values = slot_values.expand(-1, pooled_output.size()[0], -1)
-                #slot_values = torch.stack((batch['pos_sampling_input'][slot], batch['neg_sampling_input'][slot])) # TODO: filter?
-                #slot_values = torch.stack((vectors[slot].unsqueeze(1), batch['pos_sampling_input'][slot], batch['neg_sampling_input'][slot])) # TODO: filter?
-                #slot_values = slot_values.squeeze(2)
-                _, value_weights = self.value_att(
-                    query=vectors[slot].unsqueeze(0),
-                    key=slot_values,
-                    value=slot_values,
-                    need_weights=True)
-                ##vectors[slot] = vectors[slot].squeeze(0)
-                #vectors[slot] = torch.matmul(weights[slot], sequence_output).squeeze(1)
-                #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize?
-                #if self.debug_tanh_for_att_output:
-                    #vectors[slot] = self.tanh(vectors[slot])
-                    #vectors[slot] = self.token_layer_norm2(vectors[slot])
-                    #vectors[slot] = self.tanh(self.tlinear(vectors[slot]))
-                #vectors[slot] += pooled_output # Residual
-                #vectors[slot] = self.token_layer_norm(vectors[slot])
-                #vectors[slot] = self.dropout_heads(vectors[slot])
-                value_weights = value_weights.squeeze(1)
-            else:
-                value_weights = None
-
-            # ---
-
-            # TODO: implement choice between joint_refer_gate and individual refer gates, analogous to
-            # slot gates. Use same input as slot gate, i.e., for joint case, use yy, for individual
-            # use pooled. This is stored in slot_gate_input.
-            
-            # ---
-
-            if self.debug_joint_refer_gate:
-                if slot_ids is not None and slot_mask is not None:
-                    refer_slots = encoded_slots_pooled.unsqueeze(1).expand(-1, pooled_output.size()[0], -1)
-                else:
-                    #refer_slots = torch.stack((list(self.encoded_slots.values()))).expand(-1, pooled_output.size()[0], -1)
-                    refer_slots = torch.stack(list(batch['encoded_slots_pooled'].values())).expand(-1, pooled_output.size()[0], -1)
-                _, refer_weights = self.refer_att(
-                    query=slot_refer_input.unsqueeze(0),
-                    key=refer_slots,
-                    value=refer_slots,
-                    need_weights=True)
-                refer_weights = refer_weights.squeeze(1)
-                refer_logits = refer_weights
-            else:
-                refer_logits = getattr(self, "refer_" + slot)(slot_refer_input)
-
-            refer_logits = self.dropout_heads(refer_logits)
-
-            per_slot_class_logits[slot] = class_logits
-            per_slot_start_logits[slot] = token_weights # TODO
-            per_slot_value_logits[slot] = value_weights
-            #per_slot_start_logits[slot] = self.sigmoid(token_logits) # TODO # this is for sigmoid approach
-            per_slot_refer_logits[slot] = refer_logits
-
-            # If there are no labels, don't compute loss
-            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos[slot].size()) > 1:
-                    start_pos[slot] = start_pos[slot].squeeze(-1)
-
-                # TODO: solve this using the sequence_tagging def?
-                labels_clipped = torch.clamp(start_pos[slot].float(), min=0, max=1)
-                labels_clipped_scaled = labels_clipped / torch.clamp(labels_clipped.sum(1).unsqueeze(1), min=1) # Scale targets?
-                no_seq_mask = labels_clipped_scaled.sum(1) == 0
-                no_seq_w = 1 / batch_input_mask.sum(1)
-                labels_clipped_scaled += batch_input_mask * (no_seq_mask * no_seq_w).unsqueeze(1)
-                #token_weights = self.sigmoid(token_logits) # TODO # this is for sigmoid approach
-                if self.token_loss_function == "mse":
-                    token_loss = self.mse(token_weights, labels_clipped_scaled) # TODO: MSE might be better for scaled targets
-                else:
-                    token_loss = self.binary_cross_entropy(token_weights, labels_clipped_scaled, reduction="none")
-                if self.ignore_o_tags:
-                    token_loss *= labels_clipped # TODO: good idea?
-
-                # TODO: do negative examples have to be balanced due to their large number?
-                token_loss = token_loss.sum(1)
-                token_is_pointable = (start_pos[slot].sum(1) > 0).float()
-                if not self.token_loss_for_nonpointable:
-                    token_loss *= token_is_pointable
-
-                value_loss = torch.zeros(token_is_pointable.size(), device=token_is_pointable.device)
-                if self.triplet_loss_weight > 0.0:
-                    if self.debug_use_triplet_loss:
-                        value_loss = self.compute_triplet_loss_att(vectors[slot], pos_sampling_input, neg_sampling_input, slot)
-                        #triplet_loss = torch.clamp(triplet_loss, max=1) # TODO: parameterize # Not the best idea I think...
-                    else:
-                        value_labels_clipped = torch.clamp(value_labels[slot].float(), min=0, max=1)
-                        value_labels_clipped /= torch.clamp(value_labels_clipped.sum(1).unsqueeze(1), min=1) # Scale targets?
-                        value_no_seq_mask = value_labels_clipped.sum(1) == 0
-                        value_no_seq_w = 1 / value_labels_clipped.size(1)
-                        value_labels_clipped += (value_no_seq_mask * value_no_seq_w).unsqueeze(1)
-                        if self.value_loss_function == "mse":
-                            value_loss = self.mse(value_weights, value_labels_clipped) # TODO: scale value_labels to also cover nonpointable cases and multitarget cases
-                        else:
-                            value_loss = self.binary_cross_entropy(value_weights, value_labels_clipped, reduction="none")
-                            #print(slot)
-                            #print(value_labels_clipped)
-                            #print(value_weights)
-                            #print(value_loss)
-                        value_loss = value_loss.sum(1)
-                    token_is_matchable = token_is_pointable
-                    if self.debug_tag_none_target:
-                        token_is_matchable *= (start_pos[slot][:, 1] == 0).float()
-                    if not self.value_loss_for_nonpointable:
-                        value_loss *= token_is_matchable
-
-                # Re-definition necessary to make slot-independent prediction possible
-                self.refer_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=len(self.slot_list)) # Ignore 'none' target
-                refer_loss = self.refer_loss_fct(refer_logits, refer_id[slot])
-                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
-                if not self.refer_loss_for_nonpointable:
-                    refer_loss *= token_is_referrable
-
-                if self.debug_sigmoid_slot_gates:
-                    class_loss = []
-                    for cl in range(self.class_labels):
-                        #class_loss.append(self.binary_cross_entropy(class_logits[:,cl], labels, reduction="none"))
-                        class_loss.append(self.class_loss_fct(class_logits[:, cl], (class_label_id[slot] == cl).float(), reduction="none"))
-                    class_loss = torch.stack(class_loss, 1)
-                    if self.none_weight != 1.0:
-                        class_loss *= self.clweights.to(outputs[0].device)
-                    class_loss = class_loss.sum(1)
-                elif self.debug_att_slot_gates:
-                    if self.class_loss_function == "mse":
-                        class_loss = self.mse(class_logits, torch.nn.functional.one_hot(class_label_id[slot], self.class_labels).float())
-                    else:
-                        class_loss = self.binary_cross_entropy(class_logits, torch.nn.functional.one_hot(class_label_id[slot], self.class_labels).float(), reduction="none")
-                    if self.none_weight != 1.0:
-                        class_loss *= self.clweights.to(outputs[0].device)
-                    class_loss = class_loss.sum(1)
-                else:
-                    class_loss = self.class_loss_fct(class_logits, class_label_id[slot])
-
-                #print("%15s, class loss: %.3f, token loss: %.3f, triplet loss: %.3f" % (slot, (class_loss.sum()).item(), (token_loss.sum()).item(), (triplet_loss.sum()).item()))
-                #print("%15s, class loss: %.3f, token loss: %.3f, value loss: %.3f" % (slot, (class_loss.sum()).item(), (token_loss.sum()).item(), (value_loss.sum()).item()))
-
-                st_switch = int(not (self.sequence_tagging_dropout >= 1 and step is not None and step % self.sequence_tagging_dropout == 0))
-
-                if self.refer_index > -1:
-                    #per_example_loss = (self.class_loss_ratio) * class_loss + st_switch * ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + self.triplet_loss_weight * triplet_loss
-                    per_example_loss = (self.class_loss_ratio) * class_loss + st_switch * ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + self.triplet_loss_weight * value_loss
-                    #per_example_loss = class_loss
-                else:
-                    #per_example_loss = self.class_loss_ratio * class_loss + st_switch * (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * triplet_loss
-                    per_example_loss = self.class_loss_ratio * class_loss + st_switch * (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * value_loss
-                    #if epoch is not None and epoch > 20:
-                    #    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * triplet_loss
-                    #else:
-                    #    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
-                    #per_example_loss = class_loss
-                if self.debug_use_tlf:
-                    per_example_loss *= 1.0 + ((batch['diag_len'] - batch['turn_id']) * 0.05)
-
-                total_loss += per_example_loss.sum()
-                total_cl_loss += class_loss.sum()
-                total_tk_loss += token_loss.sum()
-                #total_tp_loss += triplet_loss.sum()
-                total_tp_loss += value_loss.sum()
-                per_slot_per_example_loss[slot] = per_example_loss
-                per_slot_per_example_cl_loss[slot] = class_loss
-                per_slot_per_example_tk_loss[slot] = token_loss
-                #per_slot_per_example_tp_loss[slot] = triplet_loss
-                per_slot_per_example_tp_loss[slot] = value_loss
-        ccc_time = time.time()
-        #print(bbb_time - aaa_time, ccc_time - bbb_time) # 0.028620243072509766
-
-        # add hidden states and attention if they are here
-        outputs = (total_loss,
-                   total_cl_loss,
-                   total_tk_loss,
-                   total_tp_loss,
-                   per_slot_per_example_loss,
-                   per_slot_per_example_cl_loss,
-                   per_slot_per_example_tk_loss,
-                   per_slot_per_example_tp_loss,
-                   per_slot_class_logits,
-                   per_slot_start_logits,
-                   per_slot_end_logits,
-                   per_slot_value_logits,
-                   per_slot_refer_logits,
-                   per_slot_att_weights,) + (vectors, weights,) + (pooled_output,) # + outputs[2:]
-        #outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits, per_slot_att_weights,) + (vectors, weights,) # + outputs[2:]
-
-        return outputs
diff --git a/convlab/dst/trippyr/multiwoz/trippyr.py b/convlab/dst/trippyr/multiwoz/trippyr.py
deleted file mode 100644
index f89de78cc1024d5fce60767addc3ce803f3e23ad..0000000000000000000000000000000000000000
--- a/convlab/dst/trippyr/multiwoz/trippyr.py
+++ /dev/null
@@ -1,793 +0,0 @@
-# Copyright 2021 Heinrich Heine University Duesseldorf
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import re
-import json
-import copy
-import pickle
-
-import torch
-from transformers import (RobertaConfig, RobertaTokenizer)
-
-from convlab.dst.trippyr.multiwoz.modeling_roberta_dst import (RobertaForDST)
-
-from convlab.dst.dst import DST
-from convlab.util.multiwoz.state import default_state
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
-from convlab.nlu.jointBERT.multiwoz import BERTNLU
-from convlab.nlg.template.multiwoz import TemplateNLG
-from convlab.dst.rule.multiwoz import normalize_value
-
-from convlab.util import relative_import_module_from_unified_datasets
-ONTOLOGY = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'ontology')
-TEMPLATE_STATE = ONTOLOGY['state']
-
-import pdb
-import time
-
-
-MODEL_CLASSES = {
-    'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer),
-}
-
-
-SLOT_MAP_TRIPPY_TO_UDF = {
-    'hotel': {
-        'pricerange': 'price range',
-        'book_stay': 'book stay',
-        'book_day': 'book day',
-        'book_people': 'book people',
-        'addr': 'address',
-        'post': 'postcode',
-        'price': 'price range',
-        'people': 'book people'
-    },
-    'restaurant': {
-        'pricerange': 'price range',
-        'book_time': 'book time',
-        'book_day': 'book day',
-        'book_people': 'book people',
-        'addr': 'address',
-        'post': 'postcode',
-        'price': 'price range',
-        'people': 'book people'
-    },
-    'taxi': {
-        'arriveBy': 'arrive by',
-        'leaveAt': 'leave at',
-        'arrive': 'arrive by',
-        'leave': 'leave at',
-        'car': 'type',
-        'car type': 'type',
-        'depart': 'departure',
-        'dest': 'destination'
-    },
-    'train': {
-        'arriveBy': 'arrive by',
-        'leaveAt': 'leave at',
-        'book_people': 'book people',
-        'arrive': 'arrive by',
-        'leave': 'leave at',
-        'depart': 'departure',
-        'dest': 'destination',
-        'id': 'train id',
-        'people': 'book people',
-        'time': 'duration',
-        'ticket': 'price',
-        'trainid': 'train id'
-    },
-    'attraction': {
-        'post': 'postcode',
-        'addr': 'address',
-        'fee': 'entrance fee',
-        'price': 'entrance fee'
-    },
-    'general': {},
-    'hospital': {
-        'post': 'postcode',
-        'addr': 'address'
-    },
-    'police': {
-        'post': 'postcode',
-        'addr': 'address'
-    }
-}
-
-
-class TRIPPYR(DST):
-    def print_header(self):
-        print(" _________  ________  ___  ________  ________  ___    ___        ________     ")
-        print("|\___   ___\\\   __  \|\  \|\   __  \|\   __  \|\  \  /  /|      |\   __  \    ")
-        print("\|___ \  \_\ \  \|\  \ \  \ \  \|\  \ \  \|\  \ \  \/  / /______\ \  \|\  \   ")
-        print("     \ \  \ \ \   _  _\ \  \ \   ____\ \   ____\ \    / /\_______\ \   _  _\  ")
-        print("      \ \  \ \ \  \\\  \\\ \  \ \  \___|\ \  \___|\/  /  /\|_______|\ \  \\\  \| ")
-        print("       \ \__\ \ \__\\\ _\\\ \__\ \__\    \ \__\ __/  / /             \ \__\\\ _\ ")
-        print("        \|__|  \|__|\|__|\|__|\|__|     \|__||\___/ /               \|__|\|__|")
-        print("          (c) 2022 Heinrich Heine University \|___|/                          ")
-        print()
-
-    def print_dialog(self, hst):
-        #print("Dialogue %s, turn %s:" % (self.global_diag_cnt, int(len(hst) / 2) - 1))
-        print("Dialogue %s, turn %s:" % (self.global_diag_cnt, self.global_turn_cnt))
-        for utt in hst[:-2]:
-            print("  \033[92m%s\033[0m" % (utt))
-        if len(hst) > 1:
-            print(" ", hst[-2])
-            print(" ", hst[-1])
-
-    def print_inform_memory(self, inform_mem):
-        print("Inform memory:")
-        is_all_none = True
-        for s in inform_mem:
-            if inform_mem[s] != 'none':
-                print("  %s = %s" % (s, inform_mem[s]))
-                is_all_none = False
-        if is_all_none:
-            print("  -")
-
-    def eval_user_acts(self, user_act, user_acts):
-        print("User acts:")
-        for ua in user_acts:
-            if ua not in user_act:
-                print("  \033[33m%s\033[0m" % (ua))
-            else:
-                print("  \033[92m%s\033[0m" % (ua))
-        for ua in user_act:
-            if ua not in user_acts:
-                print("  \033[91m%s\033[0m" % (ua))
-
-    def eval_dialog_state(self, state_updates, new_belief_state):
-        print("Dialogue state:")
-        for d in self.gt_belief_state:
-            print("  %s:" % (d))
-            for s in new_belief_state[d]:
-                is_printed = False
-                is_updated = False
-                if state_updates[d][s] > 0:
-                    is_updated = True
-                if is_updated:
-                    print("\033[3m", end='')
-                if new_belief_state[d][s] != self.gt_belief_state[d][s]:
-                    self.global_eval_stats[d][s]['FP'] += 1
-                    if self.gt_belief_state[d][s] == '':
-                        print("    \033[33m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='')
-                    else:
-                        print("    \033[91m%s: %s\033[0m (label: %s)" % (s, new_belief_state[d][s] if new_belief_state[d][s] != '' else 'none', self.gt_belief_state[d][s]), end='')
-                        self.global_eval_stats[d][s]['FN'] += 1
-                    is_printed = True
-                elif new_belief_state[d][s] != '':
-                    print("    \033[92m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='')
-                    self.global_eval_stats[d][s]['TP'] += 1
-                    is_printed = True
-                if is_updated:
-                    print(" (%s)" % (self.config.dst_class_types[state_updates[d][s]]))
-                elif is_printed:
-                    print()
-
-    def eval_print_stats(self):
-        print("Statistics:")
-        for d in self.global_eval_stats:
-            for s in self.global_eval_stats[d]:
-                TP = self.global_eval_stats[d][s]['TP']
-                FP = self.global_eval_stats[d][s]['FP']
-                FN = self.global_eval_stats[d][s]['FN']
-                prec = TP / ( TP + FP + 1e-8)
-                rec = TP / ( TP + FN + 1e-8)
-                f1 = 2 * ((prec * rec) / (prec + rec + 1e-8))
-                print("  %s %s Recall: %.2f, Precision: %.2f, F1: %.2f" % (d, s, rec, prec, f1))
-
-    def __init__(self, model_type="roberta",
-                 model_name="roberta-base",
-                 model_path="",
-                 nlu_path="",
-                 emb_path="",
-                 no_eval=False,
-                 no_history=False,
-                 no_normalize_value=False,
-                 no_smoothing=False,
-                 gt_user_acts=False,
-                 gt_ds=False,
-                 gt_request_acts=False,
-                 fp16=False):
-        super(TRIPPYR, self).__init__()
-
-        self.print_header()
-
-        self.model_type = model_type.lower()
-        self.model_name = model_name.lower()
-        self.model_path = model_path
-        self.nlu_path = nlu_path
-        self.emb_path = emb_path
-        self.no_eval = no_eval
-        self.no_history = no_history
-        self.no_normalize_value = no_normalize_value
-        self.no_smoothing = no_smoothing
-        self.gt_user_acts = gt_user_acts
-        self.gt_ds = gt_ds
-        self.gt_request_acts = gt_request_acts
-        self.fp16 = fp16
-
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type]
-        self.config = self.config_class.from_pretrained(self.model_path, local_files_only=True) # TODO: parameterize
-        # TODO: update config (parameters)
-
-        path = os.path.dirname(
-            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
-        path = os.path.join(path, 'data/multiwoz/value_dict.json')
-        self.value_dict = json.load(open(path))
-
-        self.global_eval_stats = copy.deepcopy(TEMPLATE_STATE)
-        for d in self.global_eval_stats:
-            for s in self.global_eval_stats[d]:
-                self.global_eval_stats[d][s] = {'TP': 0, 'FP': 0, 'FN': 0}
-        self.global_diag_cnt = -3
-        self.global_turn_cnt = -1
-
-        self.load_weights()
-        self.load_embeddings(self.emb_path, self.fp16)
-    
-    def load_weights(self):
-        self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name, local_files_only=True) # TODO: do_lower_case=args.do_lower_case ? # TODO: parameterize
-        self.model = self.model_class.from_pretrained(self.model_path, config=self.config, local_files_only=True) # TODO: parameterize
-        self.model.to(self.device)
-        self.model.eval()
-        self.nlu = BERTNLU(model_file=self.nlu_path) # This is used for internal evaluation
-        self.nlg_usr = TemplateNLG(is_user=True)
-        self.nlg_sys = TemplateNLG(is_user=False)
-
-    def load_embeddings(self, emb_path, fp16=False):
-        self.encoded_slots_pooled = pickle.load(open(os.path.join(emb_path, "encoded_slots_pooled.pickle"), "rb"))
-        self.encoded_slots_seq = pickle.load(open(os.path.join(emb_path, "encoded_slots_seq.pickle"), "rb"))
-        self.encoded_slot_values = pickle.load(open(os.path.join(emb_path, "encoded_slot_values_test.pickle"), "rb")) # TODO: maybe load a clean list in accordance with ontology?
-        if fp16:
-            for e in self.encoded_slots_pooled:
-                self.encoded_slots_pooled[e] = self.encoded_slots_pooled[e].type(torch.float32)
-            for e in self.encoded_slots_seq:
-                self.encoded_slots_seq[e] = self.encoded_slots_seq[e].type(torch.float32)
-            for e in self.encoded_slot_values:
-                for f in self.encoded_slot_values[e]:
-                    self.encoded_slot_values[e][f] = self.encoded_slot_values[e][f].type(torch.float32)
-
-    def init_session(self):
-        self.state = default_state() # Initialise as empty state
-        self.state['belief_state'] = copy.deepcopy(TEMPLATE_STATE)
-        self.nlg_history = []
-        self.gt_belief_state = copy.deepcopy(TEMPLATE_STATE)
-        self.global_diag_cnt += 1
-        self.global_turn_cnt = -1
-
-    def update_gt_belief_state(self, user_act):
-        for intent, domain, slot, value in user_act:
-            if domain == 'police':
-                continue
-            if intent == 'inform':
-                if slot == 'none' or slot == '':
-                    continue
-                domain_dic = self.gt_belief_state[domain]
-                if slot in domain_dic:
-                    #nvalue = normalize_value(self.value_dict, domain, slot, value)
-                    self.gt_belief_state[domain][slot] = value # nvalue
-                #elif slot != 'none' or slot != '':
-                #    raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain))
-
-    # TODO: receive semantic, convert semantic -> text -> semantic for sanity check
-    # For TripPy: receive semantic, convert semantic -> text (with context) as input to DST
-    # - allows for accuracy estimates
-    # - allows isolating inform prediction from request prediction (as can be taken from input for sanity check)
-    def update(self, user_act=''):
-        def normalize_values(text):
-            text_to_num = {"zero": "0", "one": "1", "me": "1", "two": "2", "three": "3", "four": "4", "five": "5", "six": "6", "seven": "7"}
-            #text = re.sub("^(\d{2}) : (\d{2})$", r"\1:\2", text) # Times
-            #text = re.sub(" ?' ?s", "s", text) # Genitive
-            text = re.sub("\s*(\W)\s*", r"\1" , text) # Re-attach special characters
-            text = re.sub("s'([^s])", r"s' \1", text) # Add space after plural genitive apostrophe
-            if text in text_to_num:
-                text = text_to_num[text]
-            return text
-
-        def filter_sequences(seqs, mode="first"):
-            if mode == "first":
-                return seqs[0][0][0]
-            elif mode == "max_first":
-                max_conf = 0
-                max_idx = 0
-                for e_itr, e in enumerate(seqs[0]):
-                    if e[1] > max_conf:
-                        max_conf = e[1]
-                        max_idx = e_itr
-                return seqs[0][max_idx][0]
-            elif mode == "max":
-                max_conf = 0
-                max_t_idx = 0
-                for t_itr, t in enumerate(seqs):
-                    for e_itr, e in enumerate(t):
-                        if e[1] > max_conf:
-                            max_conf = e[1]
-                            max_t_idx = t_itr
-                            max_idx = e_itr
-                return seqs[max_t_idx][max_idx][0]
-            else:
-                print("WARN: mode %s unknown. Aborting." % mode)
-                exit()
-
-        prev_state = self.state
-
-        if not self.no_eval:
-            print("-" * 40)
-
-        #nlg_history = []
-        ##for h in prev_state['history'][-2:]: # TODO: make this an option?
-        #for h in prev_state['history']:
-        #    nlg_history.append([h[0], self.get_text(h[1], is_user=(h[0]=='user'), normalize=True)])
-        ## Special case: at the beginning of the dialog the history might be empty (depending on policy)
-        #if len(nlg_history) == 0:
-        #    nlg_history.append(['sys', self.get_text(prev_state['system_action'], is_user=False, normalize=True)])
-        #    nlg_history.append(['user', self.get_text(prev_state['user_action'], is_user=True, normalize=True)])
-        if self.no_history:
-            self.nlg_history = []
-        self.nlg_history.append(['sys', self.get_text(prev_state['system_action'], is_user=False, normalize=True)])
-        self.nlg_history.append(['user', self.get_text(prev_state['user_action'], is_user=True, normalize=True)])
-        self.global_turn_cnt += 1
-        if not self.no_eval:
-            self.print_dialog(self.nlg_history)
-
-        # --- Get inform memory and auxiliary features ---
-
-        # If system_action is plain text, get acts using NLU
-        if isinstance(prev_state['user_action'], str):
-            u_acts, s_acts = self.get_acts()
-        elif isinstance(prev_state['user_action'], list):
-            u_acts = user_act # same as prev_state['user_action']
-            s_acts = prev_state['system_action']
-        else:
-            raise Exception('Unknown format for user action:', prev_state['user_action'])
-        inform_mem = self.get_inform_mem(s_acts)
-        if not self.no_eval:
-            self.print_inform_memory(inform_mem)
-        #printed_inform_mem = False
-        #for s in inform_mem:
-        #    if inform_mem[s] != 'none':
-        #        if not printed_inform_mem:
-        #            print("DST: inform_mem:")
-        #        print(s, ':', inform_mem[s])
-        #        printed_inform_mem = True
-
-        # --- Tokenize dialogue context and feed DST model ---
-
-        features = self.get_features(self.nlg_history)
-        pred_states, pred_classes, cls_representation = self.predict(features, inform_mem)
-
-        # --- Update ConvLab-style dialogue state ---
-
-        new_belief_state = copy.deepcopy(prev_state['belief_state'])
-        user_acts = []
-        for state, value in pred_states.items():
-            if isinstance(value, list):
-                value = filter_sequences(value, mode="max")
-            value = normalize_values(value)
-            if value == 'none':
-                continue
-            domain, slot = state.split('-', 1)
-            # Value normalization # TODO: according to trippy rules?
-            if domain == 'hotel' and slot == 'type':
-                value = "hotel" if value == "yes" else "guesthouse"
-            if not self.no_normalize_value:
-                value = normalize_value(self.value_dict, domain, slot, value)
-            slot = SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot)
-            if slot in new_belief_state[domain]:
-                new_belief_state[domain][slot] = value
-                user_acts.append(['inform', domain, SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot), value])
-            else:
-                raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain))
-
-        self.update_gt_belief_state(u_acts) # For evaluation
-
-        # BELIEF STATE UPDATE
-        new_state = copy.deepcopy(dict(prev_state))
-        new_state['belief_state'] = new_belief_state # TripPy-R
-        if self.gt_ds:
-            new_state['belief_state'] = self.gt_belief_state # Rule
-
-        state_updates = {}
-        for cl in pred_classes:
-            cl_d, cl_s = cl.split('-')
-            # Some reformatting for the evaluation further down
-            if cl_d not in state_updates:
-                state_updates[cl_d] = {}
-            state_updates[cl_d][SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s)] = pred_classes[cl]
-            # We care only about the requestable slots here
-            if self.config.dst_class_types[pred_classes[cl]] != 'request':
-                continue
-            if cl_d != 'general' and cl_s == 'none':
-                user_acts.append(['inform', cl_d, '', ''])
-            elif cl_d == 'general':
-                user_acts.append([SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), 'general', '', ''])
-                #user_acts.append(['bye', 'general', '', '']) # Map "thank" to "bye"? Mind "hello" as well!
-            elif not self.gt_request_acts:
-                user_acts.append(['request', cl_d, SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), ''])
-
-        # USER ACTS UPDATE
-        new_state['user_action'] = user_acts # TripPy-R
-        # ONLY FOR DEBUGGING
-        if self.gt_user_acts:
-            new_state['user_action'] = u_acts # Rule
-        elif self.gt_request_acts:
-            for e in u_acts:
-                ea, _, _, _ = e
-                if ea == 'request':
-                    user_acts.append(e)
-
-        if not self.no_eval:
-            self.eval_user_acts(u_acts, user_acts)
-            self.eval_dialog_state(state_updates, new_belief_state)
-
-        #new_state['cls_representation'] = cls_representation # TODO: needed by Nunu?
-
-        self.state = new_state
-
-        # Print eval statistics
-        if self.state['terminated'] and not self.no_eval:
-            print("Booked:", self.state['booked'])
-            self.eval_print_stats()
-            print("=" * 10, "End of the dialogue", "=" * 10)
-
-        return self.state
-    
-    def predict(self, features, inform_mem):
-        def _tokenize(text):
-            if "\u0120" in text:
-                text = re.sub(" ", "", text)
-                text = re.sub("\u0120", " ", text)
-                text = text.strip()
-            return text
-            #return ' '.join([tok for tok in map(str.strip, re.split("(\W+)", text)) if len(tok) > 0])
-
-        def get_spans(pred, norm_logits, input_tokens, usr_utt_spans):
-            span_indices = [i for i in range(len(pred)) if pred[i]]
-            prev_si = None
-            spans = []
-            #confs = []
-            for si in span_indices:
-                if prev_si is None or si - prev_si > 1:
-                    spans.append(([], [], []))
-                    #confs.append([])
-                #spans[-1].append(input_tokens[si])
-                spans[-1][0].append(si)
-                spans[-1][1].append(input_tokens[si])
-                spans[-1][2].append(norm_logits[si])
-                #confs[-1].append(norm_logits[si])
-                prev_si = si
-            #spans = [' '.join(t for t in s) for s in spans]
-            spans = [(min(i), max(i), ' '.join(t for t in s), (sum(c) / len(c)).item()) for (i, s, c) in spans]
-            #confs = [(sum(c) / len(c)).item() for c in confs]
-            final_spans = {}
-            for s in spans:
-                for us_itr, us in enumerate(usr_utt_spans):
-                    if s[0] >= us[0] and s[1] <= us[1]:
-                        if us_itr not in final_spans:
-                            final_spans[us_itr] = []
-                        final_spans[us_itr].append(s[2:])
-                        break
-            final_spans = list(final_spans.values())
-            return final_spans # , confs
-        
-        def get_usr_utt_spans(usr_mask):
-            span_indices = [i for i in range(len(usr_mask)) if usr_mask[i]]
-            prev_si = None
-            spans = []
-            for si in span_indices:
-                if prev_si is None or si - prev_si > 1:
-                    spans.append([])
-                spans[-1].append(si)
-                prev_si = si
-            spans = [[min(s), max(s)] for s in spans]
-            return spans
-
-        def smooth_roberta_predictions(pred, input_tokens):
-            smoothed_pred = pred.detach().clone()
-            # Forward
-            span = False
-            for i in range(len(pred)):
-                if pred[i] > 0:
-                    span = True
-                elif span and input_tokens[i][0] != "\u0120" and input_tokens[i][0] != "<":
-                    smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens
-                elif span and (input_tokens[i][0] == "\u0120" or input_tokens[i][0] == "<"):
-                    span = False
-            # Backward
-            span = False
-            for i in range(len(pred) - 1, 0, -1):
-                if pred[i] > 0:
-                    span = True
-                if span and input_tokens[i][0] != "\u0120" and input_tokens[i][0] != "<":
-                    smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens
-                elif span and input_tokens[i][0] == "\u0120":
-                    smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens
-                    span = False
-            #if pred != smoothed_pred:
-            #    print(get_spans(pred, input_tokens))
-            #    print(get_spans(smoothed_pred, input_tokens))
-            return smoothed_pred
-
-        #aaa_time = time.time()
-        with torch.no_grad():
-            outputs = self.model(features) # TODO: mode etc
-        #bbb_time = time.time()
-
-        input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked!
-
-        # assign identified spans to their respective usr turns (simply append spans as list of lists)
-        usr_utt_spans = get_usr_utt_spans(features['usr_mask'][0][1:])
-
-        per_slot_class_logits = outputs[8] # [2]
-        per_slot_start_logits = outputs[9] # [3]
-        per_slot_end_logits = outputs[10] # [4]
-        per_slot_value_logits = outputs[11] # [4]
-        per_slot_refer_logits = outputs[12] # [5]
-
-        cls_representation = outputs[16]
-
-        # TODO: maybe add assert to check that batch=1
-        
-        predictions = {slot: 'none' for slot in self.config.dst_slot_list}
-        class_predictions = {slot: 'none' for slot in self.config.dst_slot_list}
-
-        for slot in self.config.dst_slot_list:
-            class_logits = per_slot_class_logits[slot][0]
-            start_logits = per_slot_start_logits[slot][0]
-            end_logits = per_slot_end_logits[slot][0] if slot in per_slot_end_logits else None
-            value_logits = per_slot_value_logits[slot][0] if per_slot_value_logits[slot] is not None else None
-            refer_logits = per_slot_refer_logits[slot][0]
-
-            weights = start_logits[:len(features['input_ids'][0])]
-            norm_logits = torch.div(torch.clamp(weights - torch.mean(weights), min=0), torch.max(weights))
-
-            class_prediction = int(class_logits.argmax())
-            start_prediction = norm_logits > 0.0
-            if not self.no_smoothing:
-                start_prediction = smooth_roberta_predictions(start_prediction, input_tokens) # Slow!
-            start_prediction[0] = False # Ignore <s>
-            end_prediction = int(end_logits.argmax()) if end_logits is not None else None
-            refer_prediction = int(refer_logits.argmax())
-
-            if class_prediction == self.config.dst_class_types.index('dontcare'):
-                predictions[slot] = 'dontcare'
-            elif class_prediction == self.config.dst_class_types.index('copy_value'):
-                spans = get_spans(start_prediction[1:], norm_logits[1:], input_tokens[1:], usr_utt_spans)
-                if len(spans) > 0:
-                    for e_itr in range(len(spans)):
-                        for ee_itr in range(len(spans[e_itr])):
-                            tmp = list(spans[e_itr][ee_itr])
-                            tmp[0] = _tokenize(tmp[0])
-                            spans[e_itr][ee_itr] = tuple(tmp)
-                    predictions[slot] = spans
-                else:
-                    predictions[slot] = "none"
-            elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'):
-                predictions[slot] = "yes" # 'true'
-            elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'):
-                predictions[slot] = "no" # 'false'
-            elif class_prediction == self.config.dst_class_types.index('inform'):
-                #print("INFORM:", slot, ",", predictions[slot], "->", inform_mem[slot])
-                predictions[slot] = inform_mem[slot]
-            #elif class_prediction == self.config.dst_class_types.index('request'):
-            #    if slot in ["hotel-internet", "hotel-parking"]:
-            #        predictions[slot] = "yes" # 'true'
-            # Referral case is handled below
-
-        # Referral case. All other slot values need to be seen first in order
-        # to be able to do this correctly.
-        for slot in self.config.dst_slot_list:
-            class_logits = per_slot_class_logits[slot][0]
-            refer_logits = per_slot_refer_logits[slot][0]
-
-            class_prediction = int(class_logits.argmax())
-            refer_prediction = int(refer_logits.argmax())
-
-            if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'):
-                # Only slots that have been mentioned before can be referred to.
-                # First try to resolve a reference within the same turn. (One can think of a situation
-                # where one slot is referred to in the same utterance. This phenomenon is however
-                # currently not properly covered in the training data label generation process)
-                # Then try to resolve a reference given the current dialogue state.
-                predictions[slot] = predictions[list(self.config.dst_slot_list.keys())[refer_prediction]]
-                if predictions[slot] == 'none':
-                    referred_slot = list(self.config.dst_slot_list.keys())[refer_prediction]
-                    referred_slot_d, referred_slot_s = referred_slot.split('-')
-                    referred_slot_s = SLOT_MAP_TRIPPY_TO_UDF[referred_slot_d].get(referred_slot_s, referred_slot_s)
-                    #pdb.set_trace()
-                    if self.state['belief_state'][referred_slot_d][referred_slot_s] != '':
-                        predictions[slot] = self.state['belief_state'][referred_slot_d][referred_slot_s]
-                if predictions[slot] == 'none':
-                    ref_slot = list(self.config.dst_slot_list.keys())[refer_prediction]
-                    if ref_slot == 'hotel-name':
-                        predictions[slot] = 'the hotel'
-                    elif ref_slot == 'restaurant-name':
-                        predictions[slot] = 'the restaurant'
-                    elif ref_slot == 'attraction-name':
-                        predictions[slot] = 'the attraction'
-                    elif ref_slot == 'hotel-area':
-                        predictions[slot] = 'same area as the hotel'
-                    elif ref_slot == 'restaurant-area':
-                        predictions[slot] = 'same area as the restaurant'
-                    elif ref_slot == 'attraction-area':
-                        predictions[slot] = 'same area as the attraction'
-                    elif ref_slot == 'hotel-pricerange':
-                        predictions[slot] = 'in the same price range as the hotel'
-                    elif ref_slot == 'restaurant-pricerange':
-                        predictions[slot] = 'in the same price range as the restaurant'
-
-            class_predictions[slot] = class_prediction
-
-        # TODO: value normalization
-        # TODO: value matching
-        #ccc_time = time.time()
-        #print("TIME:", bbb_time - aaa_time, ccc_time - bbb_time)
-
-        return predictions, class_predictions, cls_representation
-
-    def get_features(self, context):
-        def to_device(batch, device):
-            if isinstance(batch, tuple):
-                batch_on_device = tuple([to_device(element, device) for element in batch])
-            if isinstance(batch, dict):
-                batch_on_device = {k: to_device(v, device) for k, v in batch.items()}
-            else:
-                batch_on_device = batch.to(device) if batch is not None else batch
-            return batch_on_device
-
-        assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models
-        input_tokens = ['<s>']
-        e_itr = 0
-        for e_itr, e in enumerate(reversed(context)):
-            input_tokens.append(e[1] if e[1] != 'null' else ' ')
-            if e_itr < 2:
-                input_tokens.append('</s> </s>')
-            else:
-                input_tokens.append('</s>')
-            # Ignore history for now
-            #if e_itr == 1:
-            #    break
-        if e_itr == 0:
-            input_tokens.append('</s> </s>')
-        input_tokens.append('</s>')
-        input_tokens = ' '.join(input_tokens)
-
-        # TODO: delex sys utt somehow, or refrain from using delex for sys utts?
-        features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=False, max_length=self.config.dst_max_seq_length)
-
-        input_ids = torch.tensor(features['input_ids']).reshape(1,-1)
-        input_mask = torch.tensor(features['attention_mask']).reshape(1,-1)
-        usr_mask = torch.zeros(input_ids.size())
-        usr_seen = False
-        sys_seen = False
-        usr_sep = 0
-        sys_sep = 0
-        hst_cnt = 0
-        for i_itr, i in enumerate(input_ids[0,:]):
-            if i_itr == 0:
-                continue
-            is_usr = True
-            if i == 1:
-                is_usr = False
-            if i == 2:
-                is_usr = False
-                if not usr_seen:
-                    usr_sep += 1
-                    if usr_sep == 2:
-                        usr_seen = True
-                elif not sys_seen:
-                    sys_sep += 1
-                    if sys_sep == 2:
-                        sys_seen = True
-                else:
-                    hst_cnt += 1
-            if usr_seen and not sys_seen:
-                is_usr = False
-            elif usr_seen and sys_seen and hst_cnt % 2 == 1:
-                is_usr = False
-            if is_usr:
-                usr_mask[0,i_itr] = 1
-        #usr_mask = torch.tensor(features['attention_mask']).reshape(1,-1) # TODO
-        features = {'input_ids': input_ids,
-                    'input_mask': input_mask,
-                    'usr_mask': usr_mask,
-                    'start_pos': None,
-                    'end_pos': None,
-                    'refer_id': None,
-                    'class_label_id': None,
-                    'inform_slot_id': None,
-                    'diag_state': None,
-                    'pos_sampling_input': None,
-                    'neg_sampling_input': None,
-                    'encoded_slots_pooled': self.encoded_slots_pooled,
-                    'encoded_slots_seq': self.encoded_slots_seq,
-                    'encoded_slot_values': self.encoded_slot_values}
-
-        return to_device(features, self.device)
-
-    # TODO: consider "booked" values?
-    def get_inform_mem(self, state):
-        inform_mem = {slot: 'none' for slot in self.config.dst_slot_list}
-        for e in state:
-            a, d, s, v = e
-            if a in ['inform', 'recommend', 'select', 'book', 'offerbook']:
-                #ds_d = d.lower()
-                #if s in REF_SYS_DA[d]:
-                #    ds_s = REF_SYS_DA[d][s]
-                #elif s in REF_SYS_DA['Booking']:
-                #    ds_s = "book_" + REF_SYS_DA['Booking'][s]
-                #else:
-                #    ds_s = s.lower()
-                #    #raise Exception('Slot <{}> of domain <{}> unknown'.format(s, d))
-                slot = "%s-%s" % (d, s)
-                if slot in inform_mem:
-                    inform_mem[slot] = v
-        return inform_mem
-
-    def get_acts(self):
-        context = self.state['history']
-        if context[-1][0] != 'user':
-            raise Exception("Wrong order of utterances, check your input.")
-        system_act = context[-2][-1]
-        user_act = context[-1][-1]
-        system_context = [t for s,t in context[:-2]]
-        user_context = [t for s,t in context[:-1]]
-
-        #print("  SYS:", system_act, system_context)
-        system_acts = self.nlu.predict(system_act, context=system_context)
-
-        #print("  USR:", user_act, user_context)
-        user_acts = self.nlu.predict(user_act, context=user_context)
-        
-        return user_acts, system_acts
-
-    def get_text(self, act, is_user=False, normalize=False):
-        if act == 'null':
-            return 'null'
-        if not isinstance(act, list):
-            result = act
-        elif is_user:
-            result = self.nlg_usr.generate(act)
-        else:
-            result = self.nlg_sys.generate(act)
-        if normalize:
-            return self.normalize_text(result)
-        else:
-            return result
-
-    def normalize_text(self, text):
-        norm_text = text.lower()
-        #norm_text = re.sub("n't", " not", norm_text) # Does not make much of a difference
-        #norm_text = re.sub("ca not", "cannot", norm_text)
-        norm_text = ' '.join([tok for tok in map(str.strip, re.split("(\W+)", norm_text)) if len(tok) > 0])
-        return norm_text
-
-
-# if __name__ == "__main__":
-#     tracker = TRIPPY(model_type='roberta', model_path='/path/to/model',
-#                         nlu_path='/path/to/nlu')
-#     tracker.init_session()
-#     state = tracker.update('hey. I need a cheap restaurant.')
-#     tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
-#     tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
-#     state = tracker.update('If you have something Asian that would be great.')
-#     tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
-#     tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
-#     state = tracker.update('Great. Where are they located?')
-#     tracker.state['history'].append(['usr', 'Great. Where are they located?'])
-#     print(tracker.state)
diff --git a/convlab/dst/trippyr/multiwoz/trippyr.py~ b/convlab/dst/trippyr/multiwoz/trippyr.py~
deleted file mode 100644
index fc5696227ffbd6de639af9eda92983b77c41059e..0000000000000000000000000000000000000000
--- a/convlab/dst/trippyr/multiwoz/trippyr.py~
+++ /dev/null
@@ -1,521 +0,0 @@
-# Copyright 2021 Heinrich Heine University Duesseldorf
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import re
-import json
-import copy
-import pickle
-
-import torch
-from transformers import (RobertaConfig, RobertaTokenizer)
-
-from convlab2.dst.trippyr.multiwoz.modeling_roberta_dst import (RobertaForDST)
-
-from convlab2.dst.dst import DST
-from convlab2.util.multiwoz.state import default_state
-from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
-from convlab2.nlu.jointBERT.multiwoz import BERTNLU
-from convlab2.dst.rule.multiwoz import normalize_value
-
-MODEL_CLASSES = {
-    'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer),
-}
-
-
-class TRIPPYR(DST):
-    def __init__(self, model_type="roberta", model_name="roberta-base", model_path="", nlu_path="", emb_path="", fp16=False):
-        super(TRIPPYR, self).__init__()
-
-        self.model_type = model_type.lower()
-        self.model_name = model_name.lower()
-        self.model_path = model_path
-        self.nlu_path = nlu_path
-
-        path = os.path.dirname(
-            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
-        path = os.path.join(path, 'data/multiwoz/value_dict.json')
-        self.value_dict = json.load(open(path))
-
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type]
-        self.config = self.config_class.from_pretrained(self.model_path)
-        # TODO: update config (parameters)
-
-        self.load_weights()
-        self.load_embeddings(emb_path, fp16)
-    
-    def load_weights(self):
-        self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name) # TODO: do_lower_case=args.do_lower_case ?
-        self.model = self.model_class.from_pretrained(self.model_path, config=self.config)
-        self.model.to(self.device)
-        self.model.eval()
-        self.nlu = BERTNLU(model_file=self.nlu_path) # TODO: remove, once TripPy takes over its task
-
-    def load_embeddings(self, emb_path, fp16=False):
-        self.encoded_slots_pooled = pickle.load(open(os.path.join(emb_path, "encoded_slots_pooled.pickle"), "rb"))
-        self.encoded_slots_seq = pickle.load(open(os.path.join(emb_path, "encoded_slots_seq.pickle"), "rb"))
-        self.encoded_slot_values = pickle.load(open(os.path.join(emb_path, "encoded_slot_values_test.pickle"), "rb"))
-        if fp16:
-            for e in self.encoded_slots_pooled:
-                self.encoded_slots_pooled[e] = self.encoded_slots_pooled[e].type(torch.float32)
-            for e in self.encoded_slots_seq:
-                self.encoded_slots_seq[e] = self.encoded_slots_seq[e].type(torch.float32)
-            for e in self.encoded_slot_values:
-                for f in self.encoded_slot_values[e]:
-                    self.encoded_slot_values[e][f] = self.encoded_slot_values[e][f].type(torch.float32)
-
-    def init_session(self):
-        self.state = default_state()
-        self.hidden_states = None
-
-    def update(self, user_act=''):
-        def filter_sequences(seqs, mode="first"):
-            if mode == "first":
-                return tokenize(seqs[0][0][0])
-            elif mode == "max_first":
-                max_conf = 0
-                max_idx = 0
-                for e_itr, e in enumerate(seqs[0]):
-                    if e[1] > max_conf:
-                        max_conf = e[1]
-                        max_idx = e_itr
-                return tokenize(seqs[0][max_idx][0])
-            elif mode == "max":
-                max_conf = 0
-                max_t_idx = 0
-                for t_itr, t in enumerate(seqs):
-                    for e_itr, e in enumerate(t):
-                        if e[1] > max_conf:
-                            max_conf = e[1]
-                            max_t_idx = t_itr
-                            max_idx = e_itr
-                return tokenize(seqs[max_t_idx][max_idx][0])
-            else:
-                print("WARN: mode %s unknown. Aborting." % mode)
-                exit()
-
-        def tokenize(text):
-            if "\u0120" in text:
-                text = re.sub(" ", "", text)
-                text = re.sub("\u0120", " ", text)
-                text = text.strip()
-            return ' '.join([tok for tok in map(str.strip, re.split("(\W+)", text)) if len(tok) > 0])
-
-        prev_state = self.state
-        #print("--")
-
-        # --- Get inform memory  ---
-
-        # If system_action is plain text, get acts using NLU
-        if isinstance(prev_state['system_action'], str):
-            acts, _ = self.get_acts(prev_state['system_action'])
-        elif isinstance(prev_state['system_action'], list):
-            acts = prev_state['system_action']
-        else:
-            raise Exception('Unknown format for system action:', prev_state['system_action'])
-        inform_mem = self.get_inform_mem(acts)
-        #print("inform_mem:")
-        #for s in inform_mem:
-        #    if inform_mem[s] != 'none':
-        #        print(s, ':', inform_mem[s])
-
-        # --- Tokenize dialogue context and feed DST model ---
-
-        features = self.get_features(self.state['history'])
-        pred_states, pred_classes, cls_representation = self.predict(features, inform_mem)
-        #print(pred_states)
-        #print(pred_classes)
-        #import pdb
-        #pdb.set_trace()
-
-        # --- Update ConvLab-style dialogue state ---
-
-        new_belief_state = copy.deepcopy(prev_state['belief_state'])
-        user_acts = []
-        for state, value in pred_states.items():
-            if isinstance(value, list):
-                value = filter_sequences(value, mode="max")
-            else:
-                value = tokenize(value)
-            #if pred_classes[state] > 0:
-            #    print(pred_classes[state], state, value)
-
-            domain, slot = state.split('-', 1)
-
-            # 1.) Domain prediction
-            if slot == "none": # and pred_classes[state] == 3:
-                continue # for now, continue
-
-            # 2.) Requests and greetings
-            if pred_classes[state] == 7:
-                if domain == "general":
-                    user_acts.append([slot, 'general', 'none', 'none'])
-                else:
-                    user_acts.append(['Request', domain.capitalize(), REF_USR_DA[domain.capitalize()].get(slot, slot.capitalize()), '?'])
-
-            # 3.) Informable slots
-            if value == 'none':
-                continue
-            # Value normalization # TODO: according to trippy rules?
-            if domain == 'hotel' and slot == 'type':
-                value = "hotel" if value == "yes" else "guesthouse"
-            value = normalize_value(self.value_dict, domain, slot, value)
-            if slot not in ['name', 'book']:
-                if domain not in new_belief_state:
-                    if domain == 'bus':
-                        continue
-                    else:
-                        raise Exception('Domain <{}> not in belief state'.format(domain))
-            slot = REF_SYS_DA[domain.capitalize()].get(slot, slot) if domain.capitalize() in REF_SYS_DA else slot
-            assert 'semi' in new_belief_state[domain]
-            assert 'book' in new_belief_state[domain]
-            if 'book' in slot:
-                assert slot.startswith('book_')
-                s = slot.split('_')[1]
-                if s in new_belief_state[domain]['book']:
-                    new_belief_state[domain]['book'][s] = value
-                    user_acts.append(['Inform', domain.capitalize(), REF_USR_DA[domain.capitalize()].get(s, s.capitalize()), value])
-            elif slot in new_belief_state[domain]['semi']:
-                new_belief_state[domain]['semi'][slot] = value
-                user_acts.append(['Inform', domain.capitalize(), REF_USR_DA[domain.capitalize()].get(slot, slot.capitalize()), value])
-            else:
-                raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain))
-
-        # Update request_state
-        #new_request_state = copy.deepcopy(prev_state['request_state'])
-
-        new_state = copy.deepcopy(dict(prev_state))
-        new_state['belief_state'] = new_belief_state
-        #new_state['request_state'] = new_request_state
-
-        nlu_user_acts, nlu_system_acts = self.get_acts(user_act)
-        user_acts = []
-        for e in nlu_user_acts:
-            if e[0] != 'Inform':
-                user_acts.append(e)
-        #new_state['system_action'] = nlu_system_acts # Empty when DST for user -> needed?
-        new_state['user_action'] = user_acts
-
-        new_state['cls_representation'] = cls_representation
-
-        self.state = new_state
-
-        #print("--")
-        return self.state
-    
-    def predict(self, features, inform_mem):
-        def _tokenize(text):
-            if "\u0120" in text:
-                text = re.sub(" ", "", text)
-                text = re.sub("\u0120", " ", text)
-                text = text.strip()
-            return ' '.join([tok for tok in map(str.strip, re.split("(\W+)", text)) if len(tok) > 0])
-
-        def get_spans(pred, norm_logits, input_tokens, usr_utt_spans):
-            span_indices = [i for i in range(len(pred)) if pred[i]]
-            prev_si = None
-            spans = []
-            #confs = []
-            for si in span_indices:
-                if prev_si is None or si - prev_si > 1:
-                    spans.append(([], [], []))
-                    #confs.append([])
-                #spans[-1].append(input_tokens[si])
-                spans[-1][0].append(si)
-                spans[-1][1].append(input_tokens[si])
-                spans[-1][2].append(norm_logits[si])
-                #confs[-1].append(norm_logits[si])
-                prev_si = si
-            #spans = [' '.join(t for t in s) for s in spans]
-            spans = [(min(i), max(i), ' '.join(t for t in s), (sum(c) / len(c)).item()) for (i, s, c) in spans]
-            #confs = [(sum(c) / len(c)).item() for c in confs]
-            final_spans = {}
-            for s in spans:
-                for us_itr, us in enumerate(usr_utt_spans):
-                    if s[0] >= us[0] and s[1] <= us[1]:
-                        if us_itr not in final_spans:
-                            final_spans[us_itr] = []
-                        final_spans[us_itr].append(s[2:])
-                        break
-            final_spans = list(final_spans.values())
-            return final_spans # , confs
-        
-        def get_usr_utt_spans(usr_mask):
-            span_indices = [i for i in range(len(usr_mask)) if usr_mask[i]]
-            prev_si = None
-            spans = []
-            for si in span_indices:
-                if prev_si is None or si - prev_si > 1:
-                    spans.append([])
-                spans[-1].append(si)
-                prev_si = si
-            spans = [[min(s), max(s)] for s in spans]
-            return spans
-
-        def smooth_roberta_predictions(pred, input_tokens):
-            smoothed_pred = pred.detach().clone()
-            # Forward
-            span = False
-            i = 0
-            while i < len(pred):
-                if pred[i] > 0:
-                    span = True
-                elif span and input_tokens[i][0] != "\u0120" and input_tokens[i][0] != "<":
-                    smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens
-                elif span and (input_tokens[i][0] == "\u0120" or input_tokens[i][0] == "<"):
-                    span = False
-                i += 1
-            # Backward
-            span = False
-            i = len(pred) - 1
-            while i >= 0:
-                if pred[i] > 0:
-                    span = True
-                if span and input_tokens[i][0] != "\u0120" and input_tokens[i][0] != "<":
-                    smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens
-                elif span and input_tokens[i][0] == "\u0120":
-                    smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens
-                    span = False
-                i -= 1
-            #if pred != smoothed_pred:
-            #    print(get_spans(pred, input_tokens))
-            #    print(get_spans(smoothed_pred, input_tokens))
-            return smoothed_pred
-
-        with torch.no_grad():
-            outputs = self.model(features) # TODO: mode etc
-
-        input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked!
-
-        # assign identified spans to their respective usr turns (simply append spans as list of lists)
-        usr_utt_spans = get_usr_utt_spans(features['usr_mask'][0][1:])
-
-        per_slot_class_logits = outputs[8] # [2]
-        per_slot_start_logits = outputs[9] # [3]
-        per_slot_end_logits = outputs[10] # [4]
-        per_slot_value_logits = outputs[11] # [4]
-        per_slot_refer_logits = outputs[12] # [5]
-
-        cls_representation = outputs[16]
-
-        # TODO: maybe add assert to check that batch=1
-        
-        predictions = {slot: 'none' for slot in self.config.dst_slot_list}
-        class_predictions = {slot: 'none' for slot in self.config.dst_slot_list}
-
-        for slot in self.config.dst_slot_list:
-            class_logits = per_slot_class_logits[slot][0]
-            start_logits = per_slot_start_logits[slot][0]
-            end_logits = per_slot_end_logits[slot][0] if slot in per_slot_end_logits else None
-            value_logits = per_slot_value_logits[slot][0] if per_slot_value_logits[slot] is not None else None
-            refer_logits = per_slot_refer_logits[slot][0]
-
-            weights = start_logits[:len(features['input_ids'][0])]
-            norm_logits = torch.clamp(weights - torch.mean(weights), min=0) / max(weights)
-
-            class_prediction = int(class_logits.argmax())
-            start_prediction = norm_logits > 0.0
-            start_prediction = smooth_roberta_predictions(start_prediction, input_tokens)
-            start_prediction[0] = False # Ignore <s>
-            end_prediction = int(end_logits.argmax()) if end_logits is not None else None
-            refer_prediction = int(refer_logits.argmax())
-
-            if class_prediction == self.config.dst_class_types.index('dontcare'):
-                predictions[slot] = 'dontcare'
-            elif class_prediction == self.config.dst_class_types.index('copy_value'):
-                spans = get_spans(start_prediction[1:], norm_logits[1:], input_tokens[1:], usr_utt_spans)
-                if len(spans) > 0:
-                    for e_itr in range(len(spans)):
-                        for ee_itr in range(len(spans[e_itr])):
-                            tmp = list(spans[e_itr][ee_itr])
-                            tmp[0] = _tokenize(tmp[0])
-                            spans[e_itr][ee_itr] = tuple(tmp)
-                    predictions[slot] = spans
-                else:
-                    predictions[slot] = "none"
-            elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'):
-                predictions[slot] = "yes" # 'true'
-            elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'):
-                predictions[slot] = "no" # 'false'
-            elif class_prediction == self.config.dst_class_types.index('inform'):
-                #print("INFORM:", slot, ",", predictions[slot], "->", inform_mem[slot])
-                predictions[slot] = inform_mem[slot]
-            #elif class_prediction == self.config.dst_class_types.index('request'):
-            #    if slot in ["hotel-internet", "hotel-parking"]:
-            #        predictions[slot] = "yes" # 'true'
-            # Referral case is handled below
-
-            class_predictions[slot] = class_prediction
-
-        # Referral case. All other slot values need to be seen first in order
-        # to be able to do this correctly.
-        for slot in self.config.dst_slot_list:
-            class_logits = per_slot_class_logits[slot][0]
-            refer_logits = per_slot_refer_logits[slot][0]
-
-            class_prediction = int(class_logits.argmax())
-            refer_prediction = int(refer_logits.argmax())
-
-            if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'):
-                # Only slots that have been mentioned before can be referred to.
-                # One can think of a situation where one slot is referred to in the same utterance.
-                # This phenomenon is however currently not properly covered in the training data
-                # label generation process.
-                predictions[slot] = predictions[list(self.config.dst_slot_list.keys())[refer_prediction]]
-
-        # TODO: value normalization
-        # TODO: value matching
-
-            #if class_prediction > 0:
-            #    print("  ", slot, "->", class_prediction, ",", predictions[slot])
-
-        return predictions, class_predictions, cls_representation
-
-    def get_features(self, context):
-        def to_device(batch, device):
-            if isinstance(batch, tuple):
-                batch_on_device = tuple([to_device(element, device) for element in batch])
-            if isinstance(batch, dict):
-                batch_on_device = {k: to_device(v, device) for k, v in batch.items()}
-            else:
-                batch_on_device = batch.to(device) if batch is not None else batch
-            return batch_on_device
-
-        assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models
-        input_tokens = [] # ['<s>']
-        for e_itr, e in enumerate(reversed(context)):
-            input_tokens.append(e[1] if e[1] != 'null' else ' ')
-            if e_itr < 2:
-                input_tokens.append('</s> </s>')
-            else:
-                input_tokens.append('</s>')
-            # Ignore history for now
-            if e_itr == 1:
-                break
-        if e_itr == 0:
-            input_tokens.append('</s> </s>')
-        #input_tokens.append('</s>')
-        input_tokens = ' '.join(input_tokens)
-
-        # TODO: delex sys utt somehow, or refrain from using delex for sys utts?
-        features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=True, max_length=self.config.dst_max_seq_length)
-
-        input_ids = torch.tensor(features['input_ids']).reshape(1,-1)
-        input_mask = torch.tensor(features['attention_mask']).reshape(1,-1)
-        usr_mask = torch.zeros(input_ids.size())
-        usr_seen = False
-        sys_seen = False
-        usr_sep = 0
-        sys_sep = 0
-        hst_cnt = 0
-        for i_itr, i in enumerate(input_ids[0,:]):
-            if i_itr == 0:
-                continue
-            is_usr = True
-            if i == 1:
-                is_usr = False
-            if i == 2:
-                is_usr = False
-                if not usr_seen:
-                    usr_sep += 1
-                    if usr_sep == 2:
-                        usr_seen = True
-                elif not sys_seen:
-                    sys_sep += 1
-                    if sys_sep == 2:
-                        sys_seen = True
-                else:
-                    hst_cnt += 1
-            if usr_seen and not sys_seen:
-                is_usr = False
-            elif usr_seen and sys_seen and hst_cnt % 2 == 1:
-                is_usr = False
-            if is_usr:
-                usr_mask[0,i_itr] = 1
-        #usr_mask = torch.tensor(features['attention_mask']).reshape(1,-1) # TODO
-        features = {'input_ids': input_ids,
-                    'input_mask': input_mask,
-                    'usr_mask': usr_mask,
-                    'start_pos': None,
-                    'end_pos': None,
-                    'refer_id': None,
-                    'class_label_id': None,
-                    'inform_slot_id': None,
-                    'diag_state': None,
-                    'pos_sampling_input': None,
-                    'neg_sampling_input': None,
-                    'encoded_slots_pooled': self.encoded_slots_pooled,
-                    'encoded_slots_seq': self.encoded_slots_seq,
-                    'encoded_slot_values': self.encoded_slot_values}
-
-        return to_device(features, self.device)
-
-    # TODO: consider "booked" values?
-    def get_inform_mem(self, state):
-        inform_mem = {slot: 'none' for slot in self.config.dst_slot_list}
-        for e in state:
-            a, d, s, v = e
-            if a in ['Inform', 'Recommend', 'Select', 'Book', 'OfferBook']:
-                ds_d = d.lower()
-                if s in REF_SYS_DA[d]:
-                    ds_s = REF_SYS_DA[d][s]
-                elif s in REF_SYS_DA['Booking']:
-                    ds_s = "book_" + REF_SYS_DA['Booking'][s]
-                else:
-                    ds_s = s.lower()
-                    #raise Exception('Slot <{}> of domain <{}> unknown'.format(s, d))
-                slot = "%s-%s" % (ds_d, ds_s)
-                if slot in inform_mem:
-                    inform_mem[slot] = v
-        return inform_mem
-
-    # TODO: fix, still a mess...
-    def get_acts(self, user_act):
-        context = self.state['history']
-        if context:
-            if context[-1][0] != 'sys':
-                system_act = ''
-                context = [t for s,t in context]
-            else:
-                system_act = context[-1][-1]
-                context = [t for s,t in context[:-1]]
-        else:
-            system_act = ''
-            context = ['']
-
-        #print("  SYS:", system_act, context)
-        system_acts = self.nlu.predict(system_act, context=context)
-
-        context.append(system_act)
-        #print("  USR:", user_act, context)
-        user_acts = self.nlu.predict(user_act, context=context)
-        
-        return user_acts, system_acts
-
-
-# if __name__ == "__main__":
-#     tracker = TRIPPY(model_type='roberta', model_path='/path/to/model',
-#                         nlu_path='/path/to/nlu')
-#     tracker.init_session()
-#     state = tracker.update('hey. I need a cheap restaurant.')
-#     tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
-#     tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
-#     state = tracker.update('If you have something Asian that would be great.')
-#     tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
-#     tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
-#     state = tracker.update('Great. Where are they located?')
-#     tracker.state['history'].append(['usr', 'Great. Where are they located?'])
-#     print(tracker.state)
diff --git a/convlab/evaluator/multiwoz_eval.py b/convlab/evaluator/multiwoz_eval.py
index a29af17b6b1da3be4e4347326f4d72e4187fd4c0..c89361b2a84824198e516941b71806440a9ba3a5 100755
--- a/convlab/evaluator/multiwoz_eval.py
+++ b/convlab/evaluator/multiwoz_eval.py
@@ -5,16 +5,29 @@ import re
 import numpy as np
 
 from copy import deepcopy
+from data.unified_datasets.multiwoz21.preprocess import reverse_da, reverse_da_slot_name_map
+from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
 from convlab.evaluator.evaluator import Evaluator
 from data.unified_datasets.multiwoz21.preprocess import reverse_da_slot_name_map
 from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple
 from convlab.util.multiwoz.dbquery import Database
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
 from convlab.util import relative_import_module_from_unified_datasets
 
-DEBUG = False
-reverse_da = relative_import_module_from_unified_datasets(
-    'multiwoz21', 'preprocess.py', 'reverse_da')
+# import reflect table
+REF_SYS_DA_M = {}
+for dom, ref_slots in REF_SYS_DA.items():
+    dom = dom.lower()
+    REF_SYS_DA_M[dom] = {}
+    for slot_a, slot_b in ref_slots.items():
+        if slot_a == 'Ref':
+            slot_b = 'ref'
+        REF_SYS_DA_M[dom][slot_a.lower()] = slot_b
+    REF_SYS_DA_M[dom]['none'] = 'none'
+REF_SYS_DA_M['taxi']['phone'] = 'phone'
+REF_SYS_DA_M['taxi']['car'] = 'car type'
+
+reverse_da = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da')
+
 
 requestable = \
     {'attraction': ['post', 'phone', 'addr', 'fee', 'area', 'type'],
@@ -28,20 +41,15 @@ requestable = \
 belief_domains = requestable.keys()
 
 mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'name': 'name', 'phone': 'phone',
-                          'post': 'postcode', 'price': 'pricerange', 'ref': 'ref',
-                          'price range': 'pricerange', 'address': 'address', 'postcode': 'postcode'},
+                          'post': 'postcode', 'price': 'pricerange', 'ref': 'ref'},
            'hotel': {'addr': 'address', 'area': 'area', 'internet': 'internet', 'parking': 'parking', 'name': 'name',
-                     'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type', 'ref': 'ref',
-                     'price range': 'pricerange', 'address': 'address', 'postcode': 'postcode'},
+                     'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type', 'ref': 'ref'},
            'attraction': {'addr': 'address', 'area': 'area', 'fee': 'entrance fee', 'name': 'name', 'phone': 'phone',
-                          'post': 'postcode', 'type': 'type', 'entrance fee': 'entrance fee'},
-           'train': {'id': 'trainID', 'arrive': 'arrive by', 'day': 'day', 'depart': 'departure', 'dest': 'destination',
-                     'time': 'duration', 'leave': 'leave at', 'ticket': 'price', 'ref': 'ref',
-                     'arrive by': 'arrive by', 'leave at': 'leave at', 'departure': 'departure', 'destination': "destination", "duration": "duration", "price": "price"},
-           'taxi': {'car': 'car type', 'phone': 'phone',
-                    'car type': 'car type'},
-           'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department',
-                        'postcode': 'postcode', 'address': 'address'},
+                          'post': 'postcode', 'type': 'type'},
+           'train': {'id': 'trainID', 'arrive': 'arriveBy', 'day': 'day', 'depart': 'departure', 'dest': 'destination',
+                     'time': 'duration', 'leave': 'leaveAt', 'ticket': 'price', 'ref': 'ref'},
+           'taxi': {'car': 'car type', 'phone': 'phone'},
+           'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'},
            'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}}
 
 
@@ -59,7 +67,6 @@ for dom, ref_slots in REF_SYS_DA.items():
     REF_SYS_DA_M[dom]['none'] = 'none'
 REF_SYS_DA_M['taxi']['phone'] = 'phone'
 REF_SYS_DA_M['taxi']['car'] = 'car type'
-REF_SYS_DA_M['train']['id'] = 'train id'
 DEF_VAL_UNK = '?'  # Unknown
 DEF_VAL_DNC = 'dontcare'  # Do not care
 DEF_VAL_NUL = 'none'  # for none
@@ -68,6 +75,14 @@ DEF_VAL_NOBOOK = 'no'  # for booked
 
 NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK]
 
+# Not sure values in inform
+DEF_VAL_UNK = '?'  # Unknown
+DEF_VAL_DNC = 'dontcare'  # Do not care
+DEF_VAL_NUL = 'none'  # for none
+DEF_VAL_BOOKED = 'yes'  # for booked
+DEF_VAL_NOBOOK = 'no'  # for booked
+NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK]
+
 
 class MultiWozEvaluator(Evaluator):
     def __init__(self, check_book_constraints=True, check_domain_success=False):
@@ -277,8 +292,7 @@ class MultiWozEvaluator(Evaluator):
                         if value == value_predicted:
                             match += 1
                     except Exception as e:
-                        if DEBUG:
-                            print("Tracker probably does not track that slot.", e)
+                        print("Tracker probably does not track that slot.", e)
                         # if tracker does not track it, it trivially matches since policy has no chance otherwise
                         match += 1
 
@@ -319,7 +333,6 @@ class MultiWozEvaluator(Evaluator):
                 else:
                     # print('FN + 1')
                     reqt_not_inform.add(('request', domain, k))
-                    print("reqt_not_inform.add(('request', domain, k))", domain, k)
                     FN += 1
             for k in inform_slot[domain]:
                 # exclude slots that are informed by users
@@ -329,13 +342,12 @@ class MultiWozEvaluator(Evaluator):
                     # print('FP + 1 @2', k)
                     inform_not_reqt.add(('inform', domain, k,))
                     FP += 1
-
         return TP, FP, FN, bad_inform, reqt_not_inform, inform_not_reqt
 
     def _check_value(self, domain, key, value):
         if key == "area":
             return value.lower() in ["centre", "east", "south", "west", "north"]
-        elif key == "arriveBy" or key == "leaveAt" or key == "arrive by" or key == "leave at":
+        elif key == "arriveBy" or key == "leaveAt":
             return time_re.match(value)
         elif key == "day":
             return value.lower() in ["monday", "tuesday", "wednesday", "thursday", "friday",
@@ -348,13 +360,13 @@ class MultiWozEvaluator(Evaluator):
             return re.match(r'^\d{11}$', value) or domain == "restaurant"
         elif key == "price":
             return 'pound' in value
-        elif key == "pricerange" or key == "price range":
+        elif key == "pricerange":
             return value in ["cheap", "expensive", "moderate", "free"] or domain == "attraction"
         elif key == "postcode":
             return re.match(r'^cb\d{1,3}[a-z]{2,3}$', value) or value == 'pe296fl'
         elif key == "stars":
             return re.match(r'^\d$', value)
-        elif key == "trainID" or key == "train id":
+        elif key == "trainID":
             return re.match(r'^tr\d{4}$', value.lower())
         else:
             return True
@@ -426,9 +438,6 @@ class MultiWozEvaluator(Evaluator):
 
         TP, FP, FN, bad_inform, reqt_not_inform, inform_not_reqt = self._inform_F1_goal(
             goal, self.sys_da_array)
-        if len(reqt_not_inform) > 0:
-            print("bad_inform", bad_inform)
-            print("reqt_not_inform", reqt_not_inform)
         if aggregate:
             try:
                 rec = TP / (TP + FN)
@@ -463,14 +472,6 @@ class MultiWozEvaluator(Evaluator):
                 book_constraint_sess == 1 or book_constraint_sess is None) else 0
             return self.success if not self.check_book_constraints else self.success_strict
         else:
-            print("===== fail reason =====")
-            if goal_sess != 1:
-                print("goal_sess", goal_sess)
-            if book_sess is not None and book_sess < 1:
-                print("book_sess", book_sess)
-            if inform_sess[1] is not None and inform_sess[1] < 1:
-                print("inform_sess", inform_sess)
-
             self.complete = 1 if booking_done and (
                 inform_sess[1] == 1 or inform_sess[1] is None) else 0
             self.success = 0
@@ -668,12 +669,11 @@ class MultiWozEvaluator(Evaluator):
                 slot = reverse_da_slot_name_map[domain][slot]
             else:
                 slot = slot.capitalize()
-
             if intent.lower() in ['inform', 'recommend']:
                 if domain.lower() in goal:
                     if 'reqt' in goal[domain.lower()]:
-                        if REF_SYS_DA_M.get(domain.lower(), {}).get(slot.lower(), slot.lower()) in goal[domain.lower()][
-                                'reqt']:
+                        if REF_SYS_DA_M.get(domain.lower(), {}).get(slot.lower(), slot.lower()) \
+                                in goal[domain.lower()]['reqt']:
                             if val in NOT_SURE_VALS:
                                 val = '\"' + val + '\"'
                             goal[domain.lower()]['reqt'][
diff --git a/convlab/nlg/template/multiwoz/manual_system_template_nlg.bck.json b/convlab/nlg/template/multiwoz/manual_system_template_nlg.bck.json
deleted file mode 100755
index aa5f6b69615092d5c2bb64e645e048ef081c20fe..0000000000000000000000000000000000000000
--- a/convlab/nlg/template/multiwoz/manual_system_template_nlg.bck.json
+++ /dev/null
@@ -1,1452 +0,0 @@
-{
-    "Attraction-Inform": {
-        "Addr": [
-            "it is located in #ATTRACTION-INFORM-ADDR#",
-            "adress is #ATTRACTION-INFORM-ADDR#",
-            "It is on #ATTRACTION-INFORM-ADDR# .",
-            "their address in our system is listed as #ATTRACTION-INFORM-ADDR# .",
-            "The address is #ATTRACTION-INFORM-ADDR# .",
-            "it 's located at #ATTRACTION-INFORM-ADDR# .",
-            "#ATTRACTION-INFORM-ADDR# is the address",
-            "They are located at #ATTRACTION-INFORM-ADDR# ."
-        ],
-        "none": [
-            "what information would you like about it today ?",
-            "i have their info , what would you like to know ?"
-        ],
-        "Choice": [
-            "I ' ve found #ATTRACTION-INFORM-CHOICE# places for you to go . Do you have any specific ideas in mind ?",
-            "sorry about that , there are actually #ATTRACTION-INFORM-CHOICE# .",
-            "sure , there are #ATTRACTION-INFORM-CHOICE# for you to choose from .",
-            "There are #ATTRACTION-INFORM-CHOICE# . Would you like me to recommend one for you ?",
-            "We have #ATTRACTION-INFORM-CHOICE# of those ! Anything specific you need or just a recommendation ?",
-            "sure , there are #ATTRACTION-INFORM-CHOICE# options for you",
-            "sure , there are #ATTRACTION-INFORM-CHOICE# in that area .",
-            "we have #ATTRACTION-INFORM-CHOICE# options , can i reccomend for you ?",
-            "there are #ATTRACTION-INFORM-CHOICE# , anything in particular you are looking for ?",
-            "We have #ATTRACTION-INFORM-CHOICE# such location ."
-        ],
-        "Post": [
-            "The postcode of the attraction is #ATTRACTION-INFORM-POST# .",
-            "The post code is #ATTRACTION-INFORM-POST# .",
-            "Its postcode is #ATTRACTION-INFORM-POST# .",
-            "Their postcode is #ATTRACTION-INFORM-POST# ."
-        ],
-        "Fee": [
-            "Its entrance fee is #ATTRACTION-INFORM-FEE# .",
-            "The entry fee is #ATTRACTION-INFORM-FEE# .",
-            "their entrance fee is #ATTRACTION-INFORM-FEE# by our system currently ."
-        ],
-        "Name": [
-            "I think a fun place to visit is #ATTRACTION-INFORM-NAME# .",
-            "#ATTRACTION-INFORM-NAME# looks good .",
-            "#ATTRACTION-INFORM-NAME# is available , would that work for you ?",
-            "we have #ATTRACTION-INFORM-NAME# .",
-            "#ATTRACTION-INFORM-NAME# is popular among visitors .",
-            "How about #ATTRACTION-INFORM-NAME# ?",
-            "What about #ATTRACTION-INFORM-NAME# ?",
-            "you might want to try the #ATTRACTION-INFORM-NAME# ."
-        ],
-        "Area": [
-            "That one is located in the #ATTRACTION-INFORM-AREA# .",
-            "it is located in the #ATTRACTION-INFORM-AREA# .",
-            "They are located within the #ATTRACTION-INFORM-AREA# .",
-            "it 's located in the #ATTRACTION-INFORM-AREA# .",
-            "That is in the #ATTRACTION-INFORM-AREA# .",
-            "It will be located in the #ATTRACTION-INFORM-AREA# .",
-            "it is in the #ATTRACTION-INFORM-AREA# of the city",
-            "It is in the #ATTRACTION-INFORM-AREA# ."
-        ],
-        "Phone": [
-            "The attraction phone number is #ATTRACTION-INFORM-PHONE# .",
-            "Here is the attraction phone number , #ATTRACTION-INFORM-PHONE# ."
-        ],
-        "Type": [
-            "It is listed as a #ATTRACTION-INFORM-TYPE# attraction .",
-            "it is a #ATTRACTION-INFORM-TYPE# .",
-            "There are some wonderful #ATTRACTION-INFORM-TYPE# in that area .",
-            "it 's considered a #ATTRACTION-INFORM-TYPE# .",
-            "Would you be interested in visiting a #ATTRACTION-INFORM-TYPE# ?",
-            "It 's a #ATTRACTION-INFORM-TYPE# attraction .",
-            "It is listed as #ATTRACTION-INFORM-TYPE# ."
-        ],
-        "Open": [
-            "#ATTRACTION-INFORM-OPEN# in our database .",
-            "#ATTRACTION-INFORM-OPEN# ."
-        ],
-        "Price": [
-            "it is in the #ATTRACTION-INFORM-PRICE# price range",
-            "The fee is #ATTRACTION-INFORM-PRICE# ."
-        ]
-    },
-    "Attraction-NoOffer": {
-        "none": [
-            "There are no attractions matching that description .",
-            "I ' m sorry but I have not found any matches .",
-            "I ' m sorry there are no matches .",
-            "There are none available at this time .",
-            "I ' m sorry . I ' m not finding any attractions that meet your criteria .",
-            "we do nt have any in that area .",
-            "I do n't have anything meeting that criteria ."
-        ],
-        "Type": [
-            "There are no #ATTRACTION-NOOFFER-TYPE# close to the area you are requesting",
-            "No , I ' m sorry , I am not finding anything with #ATTRACTION-NOOFFER-TYPE# .",
-            "I ' m sorry , but it does n't look like we have a #ATTRACTION-NOOFFER-TYPE# that matches your criteria .",
-            "Unfortunately there are no #ATTRACTION-NOOFFER-TYPE# venues in that location .",
-            "I ' m sorry , I do n't see any #ATTRACTION-NOOFFER-TYPE# attractions in that area of town . Is there anything else you 'd be interested in seeing ?",
-            "There are no #ATTRACTION-NOOFFER-TYPE# in this area .",
-            "I ' m sorry . There are no #ATTRACTION-NOOFFER-TYPE# listed in that area .",
-            "Unfortunately I can not find anything strictly categorized as #ATTRACTION-NOOFFER-TYPE# in that area can you provide more specifications ?",
-            "There are no #ATTRACTION-NOOFFER-TYPE# in that area ."
-        ],
-        "Name": [
-            "There is no listing for #ATTRACTION-NOOFFER-NAME#",
-            "i am sorry but i actually am not finding any information for #ATTRACTION-NOOFFER-NAME# ."
-        ],
-        "Area": [
-            "no such attractions in #ATTRACTION-NOOFFER-AREA#",
-            "I ' m sorry , I could n't find anything like that in the #ATTRACTION-NOOFFER-AREA# .",
-            "sorry , i could n't find anything in the #ATTRACTION-NOOFFER-AREA# .",
-            "I am sorry , I am unable to locate an attraction in #ATTRACTION-NOOFFER-AREA# ?"
-        ],
-        "Fee": [
-            "sorry , i could n't find anything with #ATTRACTION-NOOFFER-Fee# .",
-            "There are no attractions with #ATTRACTION-NOOFFER-Fee# ."
-        ],
-        "Addr": [
-            "I ' m sorry , but I do n't see any attractions at #ATTRACTION-NOOFFER-ADDR# ."
-        ]
-    },
-    "Attraction-Recommend": {
-        "Name": [
-            "we have #ATTRACTION-RECOMMEND-NAME#",
-            "#ATTRACTION-RECOMMEND-NAME# looks good , would you like to head there ?",
-            "Would you like #ATTRACTION-RECOMMEND-NAME# ?",
-            "you would love #ATTRACTION-RECOMMEND-NAME#",
-            "I recommend #ATTRACTION-RECOMMEND-NAME# ",
-            "I would suggest #ATTRACTION-RECOMMEND-NAME# .",
-            "I 'd recommend #ATTRACTION-RECOMMEND-NAME# . Would you like some information on it ?",
-            "You would love #ATTRACTION-RECOMMEND-NAME#",
-            "#ATTRACTION-RECOMMEND-NAME# meets your requirements .",
-            "how about #ATTRACTION-RECOMMEND-NAME# ? they 're pretty fun ."
-        ],
-        "Type": [
-            "I would suggest visiting one of the famous #ATTRACTION-RECOMMEND-TYPE# .",
-            "How about a #ATTRACTION-RECOMMEND-TYPE# ?",
-            "Would a #ATTRACTION-RECOMMEND-TYPE# work for you ?",
-            "It 's an #ATTRACTION-RECOMMEND-TYPE# . Great for the whole family , especially the younger ones !"
-        ],
-        "none": [
-            "It 's a really fun attraction with lots of interesting things to do in it ."
-        ],
-        "Fee": [
-            "Its entrance fee is #ATTRACTION-RECOMMEND-FEE# .",
-            "The entry fee is #ATTRACTION-RECOMMEND-FEE# .",
-            "their entrance fee is #ATTRACTION-RECOMMEND-FEE# by our system currently ."
-        ],
-        "Addr": [
-            "it is located in #ATTRACTION-RECOMMEND-ADDR#",
-            "adress is #ATTRACTION-RECOMMEND-ADDR#",
-            "It is on #ATTRACTION-RECOMMEND-ADDR# .",
-            "their address in our system is listed as #ATTRACTION-RECOMMEND-ADDR# .",
-            "The address is #ATTRACTION-RECOMMEND-ADDR# .",
-            "it 's located at #ATTRACTION-RECOMMEND-ADDR# .",
-            "#ATTRACTION-RECOMMEND-ADDR# is the address",
-            "They are located at #ATTRACTION-RECOMMEND-ADDR# ."
-        ],
-        "Post": [
-            "The postcode is #ATTRACTION-RECOMMEND-POST# .",
-            "The post code is #ATTRACTION-RECOMMEND-POST# .",
-            "Its postcode is #ATTRACTION-RECOMMEND-POST# .",
-            "Their postcode is #ATTRACTION-RECOMMEND-POST# ."
-        ],
-        "Phone": [
-            "The attraction phone number is #ATTRACTION-RECOMMEND-PHONE# .",
-            "Here is the attraction phone number , #ATTRACTION-RECOMMEND-PHONE# ."
-        ],
-        "Area": [
-            "That one is located in the #ATTRACTION-RECOMMEND-AREA# .",
-            "it is located in the #ATTRACTION-RECOMMEND-AREA# .",
-            "They are located within the #ATTRACTION-RECOMMEND-AREA# .",
-            "it 's located in the #ATTRACTION-RECOMMEND-AREA# .",
-            "That is in the #ATTRACTION-RECOMMEND-AREA# .",
-            "It will be located in the #ATTRACTION-RECOMMEND-AREA# .",
-            "it is in the #ATTRACTION-RECOMMEND-AREA# of the city",
-            "It is in the #ATTRACTION-RECOMMEND-AREA# ."
-        ],
-        "Price": [
-            "it is in the #ATTRACTION-RECOMMEND-PRICE# price range",
-            "The fee is #ATTRACTION-RECOMMEND-PRICE# ."
-        ],
-        "Choice": [
-            "I ' ve found #ATTRACTION-RECOMMEND-CHOICE# places for you to go . Do you have any specific ideas in mind ?",
-            "sure , there are #ATTRACTION-RECOMMEND-CHOICE# for you to choose from .",
-            "There are #ATTRACTION-RECOMMEND-CHOICE# . Would you like me to recommend one for you ?",
-            "We have #ATTRACTION-RECOMMEND-CHOICE# of those ! Anything specific you need or just a recommendation ?",
-            "sure , there are #ATTRACTION-RECOMMEND-CHOICE# options for you",
-            "sure , there are #ATTRACTION-RECOMMEND-CHOICE# in that area .",
-            "we have #ATTRACTION-RECOMMEND-CHOICE# options , can i reccomend for you ?",
-            "there are #ATTRACTION-RECOMMEND-CHOICE# , anything in particular you are looking for ?",
-            "We have #ATTRACTION-RECOMMEND-CHOICE# such location ."
-        ],
-        "Open": [
-            "#ATTRACTION-RECOMMEND-OPEN# in our database .",
-            "#ATTRACTION-RECOMMEND-OPEN# ."
-        ]
-    },
-    "Attraction-Request": {
-        "Type": [
-            "What type of attraction are you looking for ?",
-            "Please specify the type of attraction you 're interested in .",
-            "What type of attraction are you interested in ?",
-            "What kind of attraction are you interested in ?",
-            "What kind of attraction were you looking for ?",
-            "you have any particular attraction in mind ?",
-            "what type of attractions are you interested in ?",
-            "What sort of attraction would you like it to be ?"
-        ],
-        "Area": [
-            "Any particular area ?",
-            "is there a certain area of town you would prefer ?",
-            "I have various attractions all over town . Is there a specific area you are wanting to find something to do ?",
-            "Do you have a part of town you prefer ?",
-            "What part of town would you like it",
-            "Do you have a preference for the area of town you wish to visit ?",
-            "Where in town would you like to go ?",
-            "Which part of town would you prefer ?",
-            "Is there a specific area you are looking for ?"
-        ],
-        "Name": [
-            "What is the name of the attraction ?",
-            "What attraction are you thinking about ?",
-            "I ' m sorry for the confusion , what attraction are you interested in ?",
-            "What attraction were you thinking of ?",
-            "Do you know the name of it ?",
-            "can you give me the name of it ?"
-        ],
-        "Price": [
-            "any specific price range to help narrow down available options ?",
-            "What price range would you like ?",
-            "what is your price range for that ?",
-            "What price range are you looking for ?",
-            "What price point is good for you ?",
-            "Does a entrance fee make any difference ?",
-            "Would you like a free entrance fee or paid ?",
-            "Do you need free admission or pay to get in ?",
-            "What price point would you like ?"
-        ]
-    },
-    "Booking-Book": {
-        "Ref": [
-            "Booking was successful . Reference number is : #BOOKING-BOOK-REF# .",
-            "your reference number is #BOOKING-BOOK-REF# .",
-            "Here is the booking information : Booking was successful . Reference number is : #BOOKING-BOOK-REF#",
-            "Reference number is : #BOOKING-BOOK-REF# .",
-            "All set . Your reference number is #BOOKING-BOOK-REF# ."
-        ],
-        "Name": [
-            "the #BOOKING-BOOK-NAME# seems appropriate . i have booked it for you .",
-            "i have booked you #BOOKING-BOOK-NAME#",
-            "I booked you at #BOOKING-BOOK-NAME#",
-            "Your reservation is at #BOOKING-BOOK-NAME# ."
-        ],
-        "none": [
-            "Thanks , booking has been completed .",
-            "Booking was successful .",
-            "Your booking was successful .",
-            "You 're all booked ."
-        ],
-        "Day": [
-            "I was able to book you for #BOOKING-BOOK-DAY# .",
-            "I have reserved for #BOOKING-BOOK-DAY#",
-            "I was able to get you that reservation on #BOOKING-BOOK-DAY# ."
-        ],
-        "Time": [
-            "Would #BOOKING-BOOK-TIME# be a convenient time for you ?"
-        ],
-        "Stay": [
-            "I ' ve booked you for #BOOKING-BOOK-STAY# night .",
-            "I was able to book your reservation for #BOOKING-BOOK-STAY# days ."
-        ],
-        "People": [
-            "I was able to book it for you for #BOOKING-BOOK-PEOPLE# people .",
-            "I have booked for #BOOKING-BOOK-PEOPLE# people .",
-            "I was able to book a reservation for #BOOKING-BOOK-PEOPLE# people ."
-        ]
-    },
-    "Booking-Inform": {
-        "none": [
-            "Shall I try to start and book you into one ?",
-            "I will book it for you and get a reference number ?",
-            "Would you like for me to try and make a reservation ?",
-            "I will go ahead and book that now .",
-            "Can I make a reservation for you ?",
-            "Would you like me to book it ?"
-        ],
-        "Ref": [
-            "Booking was successful . Reference number is : #BOOKING-INFORM-REF#.",
-            "I was able to book it , reference number is #BOOKING-INFORM-REF# ."
-        ],
-        "Name": [
-            "Did you need to book the #BOOKING-INFORM-NAME# ?",
-            "It is #BOOKING-INFORM-NAME# . Do you want a reservation ?",
-            "#BOOKING-INFORM-NAME# . Would you like to book a reservation ?",
-            "Would you like to book the #BOOKING-INFORM-NAME# ?",
-            "Want me to book #BOOKING-INFORM-NAME# ?",
-            "Would you like me to book the #BOOKING-INFORM-NAME# for you ?",
-            "I ' ve located #BOOKING-INFORM-NAME# , would you like me to assist you with booking ?"
-        ],
-        "People": [
-            "Will you be booking for #BOOKING-INFORM-PEOPLE# people ?",
-            "Would you like to book for #BOOKING-INFORM-PEOPLE# people ?",
-            "that was #BOOKING-INFORM-PEOPLE# , correct ?",
-            "i want to confirm this , do i book for #BOOKING-INFORM-PEOPLE# person ?",
-            "So for #BOOKING-INFORM-PEOPLE# people in total ?",
-            "Ok I will book it for you for #BOOKING-INFORM-PEOPLE# people",
-            "I will book that for #BOOKING-INFORM-PEOPLE# people .",
-            "Do you want reservations for #BOOKING-INFORM-PEOPLE# people ?"
-        ],
-        "Day": [
-            "Okay , so you would like the reservation for #BOOKING-INFORM-DAY# ?",
-            "Will you be coming in on #BOOKING-INFORM-DAY# ?",
-            "Would you like this reservation be for #BOOKING-INFORM-DAY# ?",
-            "Do you want the reservations to begin on #BOOKING-INFORM-DAY# ?"
-        ],
-        "Time": [
-            "#BOOKING-INFORM-TIME# is available , would you like me to book that for you ?",
-            "I try to make the reservation for #BOOKING-INFORM-TIME# ."
-        ],
-        "Stay": [
-            "#BOOKING-INFORM-STAY# . Would you like me to book that ?",
-            "For #BOOKING-INFORM-STAY# day ?"
-        ]
-    },
-    "Booking-NoBook": {
-        "none": [
-            "I am unable to book this for you . Do you have any other preferences ?",
-            "Booking was unsuccessful do you have any other preference ?",
-            "I ' m sorry , I was unable to reserve rooms . Would you like to try anything else ?",
-            "I ' m sorry those are not available .",
-            "Unfortunately the booking was not successful ."
-        ],
-        "Day": [
-            "I apologize , but it looks like #BOOKING-NOBOOK-DAY# is not working .",
-            "I ' m sorry #BOOKING-NOBOOK-DAY# is n't working either .",
-            "I ' m sorry but i ' m unable to make the reservation on #BOOKING-NOBOOK-DAY# .",
-            "sorry , but #BOOKING-NOBOOK-DAY# is all booked",
-            "we are currently full on #BOOKING-NOBOOK-DAY# would you like to book at another hotel ?",
-            "I ' m sorry , but there 's nothing available starting on #BOOKING-NOBOOK-DAY# .",
-            "I am unable to book for #BOOKING-NOBOOK-DAY# .",
-            "#BOOKING-NOBOOK-DAY# is not available ."
-        ],
-        "Stay": [
-            "Sorry , the hotel ca n't accommodate you for #BOOKING-NOBOOK-STAY# . want to change dates ?",
-            "Neither is available for #BOOKING-NOBOOK-STAY# nights .",
-            "They do n't have a room available for #BOOKING-NOBOOK-STAY# nights . Anything else you 'd like me to try ?",
-            "Unfortunately it can not be booked for #BOOKING-NOBOOK-STAY# days . Did you want to get information about a different hotel instead ?"
-        ],
-        "Ref": [
-            "Great the booking was successful , your reference number is #BOOKING-NOBOOK-REF#.",
-            "Booking was successful . Reference number is : #BOOKING-NOBOOK-REF# .",
-            "Okay that booking was successful and your reference number is #BOOKING-NOBOOK-REF#."
-        ],
-        "Name": [
-            "Let 's decide on #BOOKING-NOBOOK-NAME# . Unfortunately , that appears to already be booked . Do you want to try one of the others ?"
-        ],
-        "Time": [
-            "I am sorry they do not have a table at #BOOKING-NOBOOK-TIME# , perhaps another restaurant ?"
-        ],
-        "People": [
-            "I am sorry , I am unable to make a reservation for #BOOKING-NOBOOK-PEOPLE# people",
-            "I ' m unable to book for #BOOKING-NOBOOK-PEOPLE# people .",
-            "I ' m sorry but there is no availability for #BOOKING-NOBOOK-PEOPLE# people ."
-        ]
-    },
-    "Booking-Request": {
-        "Day": [
-            "What day would you like your booking for ?",
-            "What day would you like that reservation ?",
-            "what day would you like the booking to be made for ?",
-            "What day would you like to book ?",
-            "Ok , what day would you like to make the reservation on ?"
-        ],
-        "Stay": [
-            "How many nights will you be staying ?",
-            "And how many nights ?",
-            "for how many days ?",
-            "And for how many days ?",
-            "how many days would you like to stay ?",
-            "How many nights would you like to book it for ?",
-            "And what nights would you like me to reserve for you ?",
-            "How many nights are you wanting to stay ?",
-            "How many days will you be staying ?"
-        ],
-        "People": [
-            "For how many people ?",
-            "How many people will be ?",
-            "How many people will be with you ?",
-            "How many people is the reservation for ?"
-        ],
-        "Time": [
-            "Do you have a time preference ?",
-            "what time are you looking for a reservation at ?",
-            "For what time ?",
-            "What time would you like me to make your reservation ?",
-            "What time would you like the reservation for ?",
-            "what time should I make the reservation for ?",
-            "What time would you prefer ?",
-            "What time would you like the reservation for ?"
-        ]
-    },
-    "Hotel-Inform": {
-        "Internet": [
-            "it has wifi .",
-            "the place provides free wifi .",
-            "the wifi is included .",
-            "There is wifi available at the hotel .",
-            "internet is available .",
-            "it has free wifi ."
-        ],
-        "Stars": [
-            "It is rated #HOTEL-INFORM-STARS# stars ,",
-            "It is rated #HOTEL-INFORM-STARS# stars , is that okay ?",
-            "It has #HOTEL-INFORM-STARS# stars .",
-            "It has a rating of #HOTEL-INFORM-STARS# stars .",
-            "They have a #HOTEL-INFORM-STARS# Star rating",
-            "The hotel is #HOTEL-INFORM-STARS# stars .",
-            "It is a #HOTEL-INFORM-STARS#-star rating .",
-            "it does have #HOTEL-INFORM-STARS# stars ."
-        ],
-        "Name": [
-            "#HOTEL-INFORM-NAME# is available would you like to try that ?",
-            "Does the #HOTEL-INFORM-NAME# work ?",
-            "You can try #HOTEL-INFORM-NAME#",
-            "#HOTEL-INFORM-NAME# is a great place",
-            "how about #HOTEL-INFORM-NAME# ?",
-            "How about #HOTEL-INFORM-NAME# ?",
-            "Okay , how about #HOTEL-INFORM-NAME# ?",
-            "what about #HOTEL-INFORM-NAME# ?",
-            "How about #HOTEL-INFORM-NAME# ?"
-        ],
-        "Area": [
-            "It is in the #HOTEL-INFORM-AREA# area .",
-            "They are located in the #HOTEL-INFORM-AREA# .",
-            "it is in the #HOTEL-INFORM-AREA# .",
-            "It 's located in the #HOTEL-INFORM-AREA# .",
-            "It is in the #HOTEL-INFORM-AREA# part of town .",
-            "It 's in #HOTEL-INFORM-AREA# .",
-            "It is indeed in the #HOTEL-INFORM-AREA# ."
-        ],
-        "Parking": [
-            "It does include free parking .",
-            "they have free parking .",
-            "it offers free parking .",
-            "the parking is free ."
-        ],
-        "Phone": [
-            "The hotel phone number is #HOTEL-INFORM-PHONE# .",
-            "The phone number of the hotel is #HOTEL-INFORM-PHONE# ."
-        ],
-        "Choice": [
-            "i have #HOTEL-INFORM-CHOICE# options for you",
-            "There are #HOTEL-INFORM-CHOICE# of those .",
-            "We have #HOTEL-INFORM-CHOICE# such places .",
-            "I have #HOTEL-INFORM-CHOICE# different options for you !"
-        ],
-        "Addr": [
-            "The hotel address is #HOTEL-INFORM-ADDR# .",
-            "They are located at #HOTEL-INFORM-ADDR#",
-            "It is located at #HOTEL-INFORM-ADDR#"
-        ],
-        "Post": [
-            "The postal code for that hotel is #HOTEL-INFORM-POST# .",
-            "The postcode is #HOTEL-INFORM-POST# .",
-            "their postcode is #HOTEL-INFORM-POST#"
-        ],
-        "none": [
-            "Yes , it fits all those needs .",
-            "Yes it does"
-        ],
-        "Type": [
-            "It is a #HOTEL-INFORM-TYPE# ."
-        ],
-        "Price": [
-            "Its listed as #HOTEL-INFORM-PRICE# .",
-            "It is in the #HOTEL-INFORM-PRICE# price range .",
-            "It is a #HOTEL-INFORM-PRICE# place .",
-            "It is #HOTEL-INFORM-PRICE# .",
-            "This is an #HOTEL-INFORM-PRICE# hotel .",
-            "It is #HOTEL-INFORM-PRICE# priced .",
-            "The price range is #HOTEL-INFORM-PRICE# ."
-        ],
-        "Ref": [
-            "the reference number is #HOTEL-INFORM-REF# .",
-            "Your reference number is #HOTEL-INFORM-REF#.",
-            "You 're all set ! Your reference number is #HOTEL-INFORM-REF# .",
-            "Your reference number is #HOTEL-INFORM-REF#",
-            "The reference number is #HOTEL-INFORM-REF#",
-            "Reference number is : #HOTEL-INFORM-REF# .",
-            "The Reference number is : #HOTEL-INFORM-REF# ."
-        ]
-    },
-    "Hotel-NoOffer": {
-        "Type": [
-            "Sorry there is no #HOTEL-NOOFFER-TYPE# fitting the description you asked for",
-            "I was not able to find any #HOTEL-NOOFFER-TYPE# that met those requirements .",
-            "There are no #HOTEL-NOOFFER-TYPE# that meet that criteria , would you like information about the hotel options ?",
-            "I ' m sorry , I ' m afraid I do n't see any #HOTEL-NOOFFER-TYPE# matching that description . Do you want to try a different price range or star rating ?",
-            "i ca n't find any #HOTEL-NOOFFER-TYPE# that fit your criteria , i ' m sorry .",
-            "It is n't , and unfortunately I do n't have a #HOTEL-NOOFFER-TYPE# that matches that criteria .",
-            "I ' m sorry , there are no #HOTEL-NOOFFER-TYPE# that match your preferences .",
-            "no #HOTEL-NOOFFER-TYPE# meet your criteria ."
-        ],
-        "none": [
-            "Sorry , my search did n't bring back any results .",
-            "I ' m sorry , I can not help you with hotels . Are you sure that 's what you 're looking for ?",
-            "I was unable to find any matching places for that .",
-            "I ' m sorry there are no matches .",
-            "There were no matches found .",
-            "There are no hotels meeting these requirements .",
-            "Nothing fits all of that criteria .",
-            "Sorry , I ' m not finding anything ."
-        ],
-        "Stars": [
-            "I am sorry , but that hotel does not have a #HOTEL-NOOFFER-STARS# star rating , would you like another option ?",
-            "Unfortunately , I could n't find anything with #HOTEL-NOOFFER-STARS# stars .",
-            "I am sorry I have no listings for any with #HOTEL-NOOFFER-STARS# stars .",
-            "I am sorry , there are not #HOTEL-NOOFFER-STARS# stars available .",
-            "I have not found anything with a star of #HOTEL-NOOFFER-STARS# ."
-        ],
-        "Parking": [
-            "I ' m not showing anything in that area of town with no parking ."
-        ],
-        "Area": [
-            "Sorry there are none in the #HOTEL-NOOFFER-AREA# .",
-            "There are none in the #HOTEL-NOOFFER-AREA# . Perhaps another criteria change ?",
-            "I ' m sorry , no , none in the #HOTEL-NOOFFER-AREA# .",
-            "There are n't any that match your criteria in the #HOTEL-NOOFFER-AREA# . Any other suggestions ?",
-            "I have nothing in the #HOTEL-NOOFFER-AREA# . Can I try something else ?"
-        ],
-        "Name": [
-            "I am not finding anything for #HOTEL-NOOFFER-NAME# that suit your needs",
-            "There is no #HOTEL-NOOFFER-NAME# in our system .",
-            "#HOTEL-NOOFFER-NAME# is not available ."
-        ],
-        "Price": [
-            "I ' m sorry , I do n't have anything in the #HOTEL-NOOFFER-PRICE# price range , would you like to search for something else ?",
-            "There is none that is #HOTEL-NOOFFER-PRICE# . Would you like to change your criteria ?",
-            "I ' m sorry , there are no #HOTEL-NOOFFER-PRICE# hotels . Would you like to try searching for something else ?"
-        ],
-        "Internet": [
-            "There does n't appear to be a hotel with wifi .",
-            "I ' m sorry there are no results for hotels with free internet .",
-            "There are not any with wifi ."
-        ]
-    },
-    "Hotel-Recommend": {
-        "Name": [
-            "How about #HOTEL-RECOMMEND-NAME# ? It has all the attributes you requested and a great name !",
-            "How about #HOTEL-RECOMMEND-NAME# ? Fits your request perfectly .",
-            "Would #HOTEL-RECOMMEND-NAME# work for you ?",
-            "Everyone seems to enjoy the #HOTEL-RECOMMEND-NAME# .",
-            "How about #HOTEL-RECOMMEND-NAME# ?",
-            "#HOTEL-RECOMMEND-NAME# looks like it would be a good choice .",
-            "Will #HOTEL-RECOMMEND-NAME# be alright ?",
-            "I would suggest #HOTEL-RECOMMEND-NAME#"
-        ],
-        "Type": [
-            "would a #HOTEL-RECOMMEND-TYPE# be OK ?"
-        ],
-        "Price": [
-            "Its listed as #HOTEL-RECOMMEND-PRICE# .",
-            "It is in the #HOTEL-RECOMMEND-PRICE# price range .",
-            "It is a #HOTEL-RECOMMEND-PRICE# place .",
-            "It is #HOTEL-RECOMMEND-PRICE# .",
-            "This is an #HOTEL-RECOMMEND-PRICE# hotel .",
-            "The price range is #HOTEL-RECOMMEND-PRICE# ."
-        ],
-        "Area": [
-            "It is in the #HOTEL-RECOMMEND-AREA# area .",
-            "They are located in the #HOTEL-RECOMMEND-AREA# .",
-            "it is in the #HOTEL-RECOMMEND-AREA# .",
-            "It 's located in the #HOTEL-RECOMMEND-AREA# .",
-            "It is in the #HOTEL-RECOMMEND-AREA# part of town .",
-            "It 's in #HOTEL-RECOMMEND-AREA# .",
-            "It is indeed in the #HOTEL-RECOMMEND-AREA# ."
-        ],
-        "Addr": [
-            "The address is #HOTEL-RECOMMEND-ADDR# .",
-            "They are located at #HOTEL-RECOMMEND-ADDR#",
-            "The address is #HOTEL-RECOMMEND-ADDR# .",
-            "It is located at #HOTEL-RECOMMEND-ADDR#"
-        ],
-        "Post": [
-            "The postal code for that hotel is #HOTEL-RECOMMEND-POST# .",
-            "The postcode is #HOTEL-RECOMMEND-POST# .",
-            "their postcode is #HOTEL-RECOMMEND-POST#"
-        ],
-        "Internet": [
-            "it has wifi .",
-            "the place provides free wifi .",
-            "the wifi is included .",
-            "There is wifi available at the hotel .",
-            "internet is available .",
-            "it has free wifi ."
-        ],
-        "Parking": [
-            "It does include free parking .",
-            "they have free parking .",
-            "it offers free parking .",
-            "the parking is free ."
-        ],
-        "Stars": [
-            "It is rated #HOTEL-RECOMMEND-STARS# stars ,",
-            "It is rated #HOTEL-RECOMMEND-STARS# stars , is that okay ?",
-            "It has #HOTEL-RECOMMEND-STARS# stars .",
-            "It has a rating of #HOTEL-RECOMMEND-STARS# stars .",
-            "it has #HOTEL-RECOMMEND-STARS# star rating .",
-            "They have a #HOTEL-RECOMMEND-STARS# Star rating",
-            "The hotel is #HOTEL-RECOMMEND-STARS# stars .",
-            "It is a #HOTEL-RECOMMEND-STARS#-star rating .",
-            "it does have #HOTEL-RECOMMEND-STARS# stars ."
-        ],
-        "none": [
-            "I would suggest that place .",
-            "May I recommend something to you ?",
-            "would you like me to make a recommendation ?",
-            "would you like a recommendation ?"
-        ],
-        "Phone": [
-            "The hotel phone number is #HOTEL-RECOMMEND-PHONE# .",
-            "The phone number of the hotel is #HOTEL-RECOMMEND-PHONE# ."
-        ],
-        "Choice": [
-            "i have #HOTEL-RECOMMEND-CHOICE# options for you",
-            "There are #HOTEL-RECOMMEND-CHOICE# of those .",
-            "We have #HOTEL-RECOMMEND-CHOICE# such places .",
-            "I have #HOTEL-RECOMMEND-CHOICE# different options for you !"
-        ]
-    },
-    "Hotel-Request": {
-        "Area": [
-            "Okay , do you have a specific area you want to stay in ?",
-            "What area are you looking to stay in ?",
-            "Remind me of the area you need that in .",
-            "Do you have an idea on the location ?",
-            "What area would you like the hotel in ?",
-            "Is there a specific area of town you 're interested in ?",
-            "What area of town would you like to be in ?",
-            "What area would you like to stay in ?",
-            "Let me get some additional information . What area of town would you like to stay in ?"
-        ],
-        "Price": [
-            "What is your price range ?",
-            "What price range were you thinking ?",
-            "What price range do you prefer ?",
-            "Okay , do you have any price range you 're looking for ?",
-            "Is there a price you are looking for ?",
-            "Do you have a price range preference ?",
-            "Is there a price range you prefer ?",
-            "do you have a price range preference ?",
-            "what price range are you looking for ?",
-            "What is the price range for you ?"
-        ],
-        "Type": [
-            "Would you like a guesthouse or a hotel ?",
-            "Would you like to stay in a guesthouse , or in a hotel ?",
-            "would you like a guesthouse or hotel ?",
-            "Okay , were you looking for a hotel or a guesthouse ?",
-            "Do you have a preference for a hotel versus a guesthouse ?",
-            "Would you like to try a hotel ?",
-            "Were you looking for a hotel with a guesthouse ?",
-            "What type of hotel are you looking for ?"
-        ],
-        "Internet": [
-            "Would you prefer one with free internet ?",
-            "do you need free internet ?",
-            "Would you prefer one with internet ?"
-        ],
-        "Name": [
-            "do you have the name of the hotel ?",
-            "Which hotel is it ?",
-            "do you know what you 're looking for ?",
-            "Do you have the hotel name ?",
-            "What is the name of the hotel you are looking for ?",
-            "could you please give me the name of the Hotel you are looking for ?",
-            "Can you give me the name of the place ?",
-            "What is the name of the hotel you 'd like to book ?",
-            "Would you like to tell me the name of that hotel ?"
-        ],
-        "Stars": [
-            "How many stars would you like ?",
-            "Is there a number of stars you prefer ?",
-            "Is there a certain star rating you would like it to have ?",
-            "Do you have a preference of star rating ?",
-            "Do you have a preference of number of stars ?",
-            "How many stars should the hotel be rated for ?",
-            "What star rating do you prefer ?",
-            "How many stars are you looking for ?"
-        ],
-        "Parking": [
-            "Will you need parking ?",
-            "Do you need it to have free parking ?",
-            "Do you have a parking preference ?",
-            "Does the hotel needs to have free parking ?",
-            "Will you need free parking ?",
-            "Do you need free parking ?",
-            "Will you need parking while you 're there ?",
-            "Will you be needing free parking ?"
-        ]
-    },
-    "Restaurant-Inform": {
-        "Addr": [
-            "they are located at #RESTAURANT-INFORM-ADDR#",
-            "It is at #RESTAURANT-INFORM-ADDR#",
-            "The restaurant address is #RESTAURANT-INFORM-ADDR# .",
-            "Their address is #RESTAURANT-INFORM-ADDR# ."
-        ],
-        "Price": [
-            "They are in the #RESTAURANT-INFORM-PRICE# price range .",
-            "It is a #RESTAURANT-INFORM-PRICE# restaurant .",
-            "This is a #RESTAURANT-INFORM-PRICE# one .",
-            "They are in the #RESTAURANT-INFORM-PRICE# price range .",
-            "They are #RESTAURANT-INFORM-PRICE#",
-            "It 's in the #RESTAURANT-INFORM-PRICE# price range .",
-            "This restaurant is in the #RESTAURANT-INFORM-PRICE# price range ."
-        ],
-        "Food": [
-            "They serve #RESTAURANT-INFORM-FOOD# food .",
-            "They serve #RESTAURANT-INFORM-FOOD# .",
-            "It is #RESTAURANT-INFORM-FOOD# food .",
-            "That is a #RESTAURANT-INFORM-FOOD# restaurant ."
-        ],
-        "Name": [
-            "How about #RESTAURANT-INFORM-NAME# ?",
-            "#RESTAURANT-INFORM-NAME# looks like a good place .",
-            "How does the #RESTAURANT-INFORM-NAME# sound ?",
-            "Okay , how about #RESTAURANT-INFORM-NAME# ?",
-            "Would you like to try #RESTAURANT-INFORM-NAME# ?",
-            "There is a restaurant called #RESTAURANT-INFORM-NAME# that meets your criteria .",
-            "How about #RESTAURANT-INFORM-NAME# ?",
-            "there 's a place called #RESTAURANT-INFORM-NAME#",
-            "Would you like to try #RESTAURANT-INFORM-NAME# ?"
-        ],
-        "Choice": [
-            "I have #RESTAURANT-INFORM-CHOICE# different restaurants I can give you some information for . They are all pretty good .",
-            "there are #RESTAURANT-INFORM-CHOICE# different places that match your description .",
-            "There are #RESTAURANT-INFORM-CHOICE# restaurants in that area that fit that criteria .",
-            "i have #RESTAURANT-INFORM-CHOICE# options for you",
-            "We have #RESTAURANT-INFORM-CHOICE# such places .",
-            "I have #RESTAURANT-INFORM-CHOICE# options for you !",
-            "there are #RESTAURANT-INFORM-CHOICE# available restaurants ."
-        ],
-        "Post": [
-            "The restaurant postcode is #RESTAURANT-INFORM-POST# .",
-            "Their postcode is #RESTAURANT-INFORM-POST#",
-            "The post code is #RESTAURANT-INFORM-POST# ."
-        ],
-        "Phone": [
-            "The number of the restaurant is #RESTAURANT-INFORM-PHONE# .",
-            "The restaurant's phone number is #RESTAURANT-INFORM-PHONE# .",
-            "The phone number of the restaurant is #RESTAURANT-INFORM-PHONE# .",
-            "#RESTAURANT-INFORM-PHONE# is the restaurant phone number"
-        ],
-        "Area": [
-            "it is in the #RESTAURANT-INFORM-AREA# area .",
-            "It is located in the #RESTAURANT-INFORM-AREA# ."
-        ],
-        "Ref": [
-            "I was able to get you into that restaurant and your reference number is #RESTAURANT-INFORM-REF#.",
-            "The reference number is #RESTAURANT-INFORM-REF# .",
-            "I have a reference number for you . It is #RESTAURANT-INFORM-REF# .",
-            "the reference number is #RESTAURANT-INFORM-REF#.",
-            "The reference number is #RESTAURANT-INFORM-REF#",
-            "Reference number is : #RESTAURANT-INFORM-REF#.",
-            "Your reference number is #RESTAURANT-INFORM-REF# . You will likely need it .",
-            "Your reference code is #RESTAURANT-INFORM-REF# .",
-            "Your reference number is #RESTAURANT-INFORM-REF# ."
-        ],
-        "none": [
-            "Is there anything else I can help you with ?",
-            "I have found that restaurant for you",
-            "It 's a perfect fit .",
-            "That 's a great place ."
-        ]
-    },
-    "Restaurant-NoOffer": {
-        "Food": [
-            "There are no #RESTAURANT-NOOFFER-FOOD# food places , shall I run another search ?",
-            "I do not have anything in that price range for #RESTAURANT-NOOFFER-FOOD# . Another criteria perhaps ?",
-            "I am sorry there are not #RESTAURANT-NOOFFER-FOOD# restaurants . Can I check for a Chinese restaurant ?",
-            "I am unable to find any #RESTAURANT-NOOFFER-FOOD# restaurants in town .",
-            "I have nothing with #RESTAURANT-NOOFFER-FOOD# . Do you have another preference ?",
-            "There are no #RESTAURANT-NOOFFER-FOOD# restaurants .",
-            "I did not find any #RESTAURANT-NOOFFER-FOOD# restaurants .",
-            "There no #RESTAURANT-NOOFFER-FOOD# restaurants that I can find right now . Would something else work ?",
-            "I ' m sorry I have no restaurants serving #RESTAURANT-NOOFFER-FOOD# food .",
-            "There are no #RESTAURANT-NOOFFER-FOOD# restaurants unfortunately ."
-        ],
-        "none": [
-            "I do n't have anything meeting that criteria . Can I look for something else ?",
-            "we do nt have a place that matches those qualities . can you try something else ?",
-            "I am afraid there is none .",
-            "We do n't have any of those , sad to say . Want to broaden the search ?",
-            "There are no matching records found for that request .",
-            "No , I ' m sorry . The search did n't pull up any matches .",
-            "There is no listing for this restaurant"
-        ],
-        "Area": [
-            "Sorry , there are no restaurants like that in the #RESTAURANT-NOOFFER-AREA# .",
-            "I did not find any restaurants in #RESTAURANT-NOOFFER-AREA# .",
-            "I am sorry there is none even in the #RESTAURANT-NOOFFER-AREA#",
-            "I am sorry there are no restaurants in #RESTAURANT-NOOFFER-AREA# that match that description .",
-            "There are none in #RESTAURANT-NOOFFER-AREA# of town .",
-            "I am sorry but there are no restaurants that fit that criteria in the #RESTAURANT-NOOFFER-AREA# .",
-            "i have n't found any in the #RESTAURANT-NOOFFER-AREA#",
-            "there no such restraunts in #RESTAURANT-NOOFFER-AREA#"
-        ],
-        "Price": [
-            "I do n't have anything in the #RESTAURANT-NOOFFER-PRICE# range that fits that criteria .",
-            "There are none in #RESTAURANT-NOOFFER-PRICE# , perhaps something else ?",
-            "There are no #RESTAURANT-NOOFFER-PRICE# ones .",
-            "No #RESTAURANT-NOOFFER-PRICE# restaurant"
-        ],
-        "Name": [
-            "i ' m sorry . i can not find details for #RESTAURANT-NOOFFER-NAME# ."
-        ]
-    },
-    "Restaurant-Recommend": {
-        "Name": [
-            "How about #RESTAURANT-RECOMMEND-NAME# ?",
-            "#RESTAURANT-RECOMMEND-NAME# has some great reviews .",
-            "I would suggest #RESTAURANT-RECOMMEND-NAME# .",
-            "The #RESTAURANT-RECOMMEND-NAME# is a nice place would you like to try that one ?",
-            "Excellent . #RESTAURANT-RECOMMEND-NAME# is just your thing .",
-            "I would suggest #RESTAURANT-RECOMMEND-NAME# .",
-            "#RESTAURANT-RECOMMEND-NAME# sounds like it might be what you are looking for .",
-            "#RESTAURANT-RECOMMEND-NAME# matches your description .",
-            "I have a place called #RESTAURANT-RECOMMEND-NAME# , does that sound like something you would enjoy ?",
-            "I recommend #RESTAURANT-RECOMMEND-NAME#"
-        ],
-        "Food": [
-            "Would you like #RESTAURANT-RECOMMEND-FOOD# food ?",
-            "Would you like to try an #RESTAURANT-RECOMMEND-FOOD# restaurant ?",
-            "How about #RESTAURANT-RECOMMEND-FOOD# ?",
-            "Okay , may I suggest #RESTAURANT-RECOMMEND-FOOD# food ?"
-        ],
-        "Price": [
-            "They are in the #RESTAURANT-RECOMMEND-PRICE# price range .",
-            "It is a #RESTAURANT-RECOMMEND-PRICE# restaurant .",
-            "This is a #RESTAURANT-RECOMMEND-PRICE# one .",
-            "They are in the #RESTAURANT-RECOMMEND-PRICE# price range .",
-            "They are #RESTAURANT-RECOMMEND-PRICE#",
-            "It 's in the #RESTAURANT-RECOMMEND-PRICE# price range .",
-            "This restaurant is in the #RESTAURANT-RECOMMEND-PRICE# price range ."
-        ],
-        "Area": [
-            "it is in the #RESTAURANT-RECOMMEND-AREA# area .",
-            "It is located in the #RESTAURANT-RECOMMEND-AREA# ."
-        ],
-        "Addr": [
-            "they are located at #RESTAURANT-RECOMMEND-ADDR#",
-            "It is at #RESTAURANT-RECOMMEND-ADDR#",
-            "The address is #RESTAURANT-RECOMMEND-ADDR# .",
-            "Their address is #RESTAURANT-RECOMMEND-ADDR# ."
-        ],
-        "Post": [
-            "The postcode is #RESTAURANT-RECOMMEND-POST# .",
-            "Their postcode is #RESTAURANT-RECOMMEND-POST#",
-            "The post code is #RESTAURANT-RECOMMEND-POST# ."
-        ],
-        "Phone": [
-            "The number of the restaurant is #RESTAURANT-RECOMMEND-PHONE# .",
-            "The restaurant's phone number is #RESTAURANT-RECOMMEND-PHONE#",
-            "The phone number of the restaurant is #RESTAURANT-RECOMMEND-PHONE# .",
-            "#RESTAURANT-RECOMMEND-PHONE# is the restaurant phone number"
-        ],
-        "none": [
-            "Is there anything else I can help you with ?",
-            "I have found that restaurant for you",
-            "It 's a perfect fit .",
-            "That 's a great place ."
-        ]
-    },
-    "Restaurant-Request": {
-        "Area": [
-            "what location ?",
-            "What area of town would you prefer ?",
-            "Do you have a specific area in mind ?",
-            "We need some more information . Where are looking to eat ?",
-            "Do you have an area of town you prefer ?",
-            "Which side of town would you prefer ?",
-            "What area should the restaurant be in ?",
-            "Do you have an area preference ?",
-            "Do you have a location preference ?"
-        ],
-        "Food": [
-            "Do you have any specific type of food you would like ?",
-            "What type of food are you looking for ?",
-            "Do you have a preference in food type ?",
-            "What cuisine are you interested in ?",
-            "What type of food would you like ?",
-            "What type of food do you want to eat ?",
-            "Is there a certain kind of food you would like ?",
-            "what type of food would you like ?",
-            "what type of food would you like to eat ?",
-            "did you have a specific kind of cuisine in mind ?"
-        ],
-        "Price": [
-            "is there a price range that you prefer ?",
-            "Do you have a price range ?",
-            "what price range would you like to stay within ?",
-            "Do you have a preference for the price range ?",
-            "Do you have a certain price range you would like ?",
-            "Did you have a price range in mind ?",
-            "what is the price range you are looking for ?",
-            "what price range are you looking for ?",
-            "Do you have a price range in mind ?",
-            "Is there a price range you would prefer to stay within ?"
-        ],
-        "Name": [
-            "Do you know the name ?",
-            "what is the name of the restaurant ?",
-            "What 's the name of the restaurant you 're looking for ?",
-            "what is the name of the restaurant ?",
-            "What is the name of the restaurant you have in mind ?",
-            "Are you looking for something in particular ?",
-            "what is the name of the restaurant you are needing information on ?",
-            "Do you know the name of the location ?",
-            "Is there a certain restaurant you 're looking for ?"
-        ]
-    },
-    "Taxi-Inform": {
-        "Arrive": [
-            "it will arrive by #TAXI-INFORM-ARRIVE#",
-            "the taxi is due to arrive at #TAXI-INFORM-ARRIVE# .",
-            "I can book one for the closest time to #TAXI-INFORM-ARRIVE# .",
-            "you will arrive by #TAXI-INFORM-ARRIVE# ."
-        ],
-        "none": [
-            "Your taxi will be available and has been booked .",
-            "I have booked you a Taxi that fits your needs .",
-            "The Taxi has been booked as requested .",
-            "you have been assigned a specific car .",
-            "I have booked your taxi",
-            "Okay I completed a booking for you"
-        ],
-        "Car": [
-            "The model of the car was #TAXI-INFORM-CAR# .",
-            "A #TAXI-INFORM-CAR# is booked for you .",
-            "The taxi is a #TAXI-INFORM-CAR#",
-            "a #TAXI-INFORM-CAR# is booked .",
-            "I was able to book a #TAXI-INFORM-CAR# for you !",
-            "it will be a #TAXI-INFORM-CAR# .",
-            "Your booking is complete , a #TAXI-INFORM-CAR# will be picking you up .",
-            "The car arriving will be a #TAXI-INFORM-CAR# .",
-            "You got the #TAXI-INFORM-CAR# enjoy the ride .",
-            "I have successfully booked you a #TAXI-INFORM-CAR# ."
-        ],
-        "Phone": [
-            "The contact number is #TAXI-INFORM-PHONE# .",
-            "Contact number for the taxi is #TAXI-INFORM-PHONE# .",
-            "their contact number is #TAXI-INFORM-PHONE#",
-            "the contact is #TAXI-INFORM-PHONE# .",
-            "You can give them a call at #TAXI-INFORM-PHONE#",
-            "The contact number is #TAXI-INFORM-PHONE# .",
-            "you can reach them on #TAXI-INFORM-PHONE#"
-        ],
-        "Leave": [
-            "It will leave at #TAXI-INFORM-LEAVE# ."
-        ],
-        "Dest": [
-            "it 's taking you to the #TAXI-INFORM-DEST# in time for your reservation .",
-            "I have booked a taxi to take you to #TAXI-INFORM-DEST# for your reservation ."
-        ],
-        "Depart": [
-            "It will pick you up at #TAXI-INFORM-DEPART# .",
-            "Okay , I ' ve booked you a taxi leaving #TAXI-INFORM-DEPART# .",
-            "I have booked you a taxi from #TAXI-INFORM-DEPART# .",
-            "I have booked a taxi to pick you up at #TAXI-INFORM-DEPART# ."
-        ]
-    },
-    "Taxi-Request": {
-        "Depart": [
-            "Where will you be departing from ?",
-            "Where will you leave from ?",
-            "I need to know where you 'd like picked up at .",
-            "Where are you departing from ?",
-            "Where are you departing from , please ?",
-            "where are you leaving from ?",
-            "Where will you be leaving from ?"
-        ],
-        "Dest": [
-            "what will your destination be ?",
-            "Where will the taxi be taking you ?",
-            "I need to know where you are traveling to",
-            "what is your destination please",
-            "What is the destination ?",
-            "Where are you going ?",
-            "And where will you be going ?",
-            "Where will you be going to ?"
-        ],
-        "Leave": [
-            "What time would you like the taxi to pick you up ?",
-            "What time would you like to leave ?",
-            "What time would you like to be picked up ?",
-            "what time do you want to leave by ?",
-            "What time would you like to be picked up ?",
-            "What time do you need to book a taxi for ?",
-            "When would you like to leave ?",
-            "what time will you be leaving .",
-            "Can you tell me what time you would like the taxi to pick you up ?",
-            "I need to know what time you need to leave ."
-        ],
-        "Arrive": [
-            "at what time would you like to arrive by ?",
-            "when would you like to arrive ?",
-            "What time would you like to arrive ?",
-            "Do you have a arrival time in mind ?",
-            "When is your arrival time ?",
-            "when would you like to arrive ?",
-            "when would you like to arrive by ?",
-            "What time would you like to arrive there by ?",
-            "What time would you like to arrive at your destination ?"
-        ]
-    },
-    "Train-Inform": {
-        "Ticket": [
-            "The fare is #TRAIN-INFORM-TICKET# per ticket .",
-            "The price of those tickets are #TRAIN-INFORM-TICKET# .",
-            "It is #TRAIN-INFORM-TICKET#",
-            "The price is #TRAIN-INFORM-TICKET# per ticket .",
-            "The price is #TRAIN-INFORM-TICKET# .",
-            "The trip will cost #TRAIN-INFORM-TICKET# .",
-            "The fare is #TRAIN-INFORM-TICKET#",
-            "It would cost #TRAIN-INFORM-TICKET# .",
-            "The cost of the one way journey is #TRAIN-INFORM-TICKET# .",
-            "The price is #TRAIN-INFORM-TICKET# per ticket ."
-        ],
-        "Leave": [
-            "it leaves at #TRAIN-INFORM-LEAVE#",
-            "I have a train leaving at #TRAIN-INFORM-LEAVE# would that be okay ?",
-            "There is a train that leaves at #TRAIN-INFORM-LEAVE# .",
-            "How about #TRAIN-INFORM-LEAVE# will that work for you ?",
-            "There is a train meeting your criteria and is leaving at #TRAIN-INFORM-LEAVE# .",
-            "I would say you should leave by #TRAIN-INFORM-LEAVE#"
-        ],
-        "Id": [
-            "The train ID is #TRAIN-INFORM-ID# .",
-            "Their ID is #TRAIN-INFORM-ID# .",
-            "#TRAIN-INFORM-ID# would be your perfect fit ."
-        ],
-        "Ref": [
-            "The reference number is #TRAIN-INFORM-REF# .",
-            "The reference number of your trip is #TRAIN-INFORM-REF# .",
-            "Here is the reference number #TRAIN-INFORM-REF#",
-            "your reference number is #TRAIN-INFORM-REF#"
-        ],
-        "Time": [
-            "The travel time is #TRAIN-INFORM-TIME# .",
-            "That would be #TRAIN-INFORM-TIME# .",
-            "#TRAIN-INFORM-TIME# would be the total duration .",
-            "The trip will last #TRAIN-INFORM-TIME#",
-            "the travel will take #TRAIN-INFORM-TIME#",
-            "The trip is #TRAIN-INFORM-TIME# .",
-            "The travel time for the trip is #TRAIN-INFORM-TIME# one way ."
-        ],
-        "Arrive": [
-            "It arrives at #TRAIN-INFORM-ARRIVE# .",
-            "it should arrive by #TRAIN-INFORM-ARRIVE#",
-            "The arrival time is #TRAIN-INFORM-ARRIVE# ."
-        ],
-        "Depart": [
-            "that train is departing from #TRAIN-INFORM-DEPART# .",
-            "it departs from #TRAIN-INFORM-DEPART# .",
-            "the train will be departing from #TRAIN-INFORM-DEPART# ."
-        ],
-        "Choice": [
-            "I have #TRAIN-INFORM-CHOICE# trains that meet your criteria .",
-            "There are #TRAIN-INFORM-CHOICE# .",
-            "There are #TRAIN-INFORM-CHOICE# options",
-            "There are #TRAIN-INFORM-CHOICE# trains available .",
-            "There are #TRAIN-INFORM-CHOICE# total trips available to you"
-        ],
-        "none": [
-            "Yes it is .",
-            "Yes it does ."
-        ],
-        "Day": [
-            "The train is for #TRAIN-INFORM-DAY# you are all set",
-            "that train leaves on #TRAIN-INFORM-DAY# ."
-        ],
-        "Dest": [
-            "the booking is for arriving in #TRAIN-INFORM-DEST# .",
-            "the train stop is #TRAIN-INFORM-DEST# ."
-        ],
-        "People": [
-            "I booked #TRAIN-INFORM-PEOPLE# tickets ."
-        ]
-    },
-    "Train-NoOffer": {
-        "Depart": [
-            "There is no train leaving #TRAIN-NOOFFER-DEPART# .",
-            "There are no trains leaving from #TRAIN-NOOFFER-DEPART# ."
-        ],
-        "none": [
-            "I ' m sorry , unfortunately there are no trains available at that time .",
-            "I do not have any trains that match your request .",
-            "There is no train leaving at that time .",
-            "I ' m sorry , there are no trains that meet your criteria ."
-        ],
-        "Leave": [
-            "There is not one that leaves at #TRAIN-NOOFFER-LEAVE# .",
-            "There is no train leaving at #TRAIN-NOOFFER-LEAVE# .",
-            "There are no trains that leave after #TRAIN-NOOFFER-LEAVE# .",
-            "There are no trains leaving at #TRAIN-NOOFFER-LEAVE# .",
-            "Unfortunately , there are no trains that leave after #TRAIN-NOOFFER-LEAVE# ."
-        ],
-        "Day": [
-            "No I ' m sorry there are not on a #TRAIN-NOOFFER-DAY# .",
-            "There are no trains on #TRAIN-NOOFFER-DAY# .",
-            "Unfortunately , there is not a train on #TRAIN-NOOFFER-DAY# matching your criteria .",
-            "There are no trains on #TRAIN-NOOFFER-DAY# ."
-        ],
-        "Dest": [
-            "I ' m sorry there are no trains to #TRAIN-NOOFFER-Dest# .",
-            "There are no trains going to #TRAIN-NOOFFER-Dest# . Do you have another destination in mind ?",
-            "There do n't seem to be any trains going to #TRAIN-NOOFFER-Dest# ."
-        ],
-        "Arrive": [
-            "I am sorry there are not trains to arrive at #TRAIN-NOOFFER-ARRIVE# .",
-            "There are no trains arriving by #TRAIN-NOOFFER-ARRIVE# .",
-            "I ' m sorry , we do n't have any trains arriving by #TRAIN-NOOFFER-ARRIVE# .",
-            "no trains arrives by #TRAIN-NOOFFER-ARRIVE#",
-            "I do not have a train arriving at #TRAIN-NOOFFER-ARRIVE# ."
-        ],
-        "Id": [
-            "Sorry it looks like #TRAIN-NOOFFER-ID# is no longer available .",
-            "I ' m sorry , train #TRAIN-NOOFFER-ARRIVE# is not available ."
-        ]
-    },
-    "Train-OfferBook": {
-        "none": [
-            "Ok I will book that for you and get you a confirmation number",
-            "Ok great . Would you like for me to go ahead and book the train for you ?",
-            "Would you like me to reserve your train tickets ?",
-            "I will book that for you now .",
-            "Great , I have a train that meets your criteria . Would you like me to book it for you ?",
-            "Would you like me to book this train for you ?",
-            "Well , I can book it for YOU if you would like . Would you like me to ?",
-            "Did you want reservations ?"
-        ],
-        "Depart": [
-            "Would you like me to book a train from #TRAIN-OFFERBOOK-DEPART# for you ?"
-        ],
-        "Id": [
-            "Train #TRAIN-OFFERBOOK-ID# would work for you .",
-            "Okay the trainID is #TRAIN-OFFERBOOK-ID# ",
-            "The #TRAIN-OFFERBOOK-ID# meets your criteria .",
-            "Shall I go ahead and book you for train #TRAIN-OFFERBOOK-ID# ?",
-            "Okay , would you like me to book #TRAIN-OFFERBOOK-ID# ?",
-            "#TRAIN-OFFERBOOK-ID# looks like it would work for you",
-            "I would suggest booking #TRAIN-OFFERBOOK-ID#",
-            "Train #TRAIN-OFFERBOOK-ID# meets your criteria"
-        ],
-        "Arrive": [
-            "Will you train arriving at #TRAIN-OFFERBOOK-ARRIVE# work ok for you ?",
-            "I have a train arriving at #TRAIN-OFFERBOOK-ARRIVE# . Would that do ?",
-            "It arrives by #TRAIN-OFFERBOOK-ARRIVE# would that be okay to book for you ?",
-            "I can get you tickets for an arrival time at #TRAIN-OFFERBOOK-ARRIVE# , is that okay ?",
-            "Will an arrival time of #TRAIN-OFFERBOOK-ARRIVE# work for you ?",
-            "Would you like me to book the #TRAIN-OFFERBOOK-ARRIVE# train ?"
-        ],
-        "People": [
-            "I will book it for you for #TRAIN-OFFERBOOK-PEOPLE# people",
-            "I will go ahead and book #TRAIN-OFFERBOOK-PEOPLE# tickets .",
-            "i just want to confirm if i am booking #TRAIN-OFFERBOOK-PEOPLE# ticket",
-            "I will book #TRAIN-OFFERBOOK-PEOPLE# tickets for you .",
-            "I will book it for #TRAIN-OFFERBOOK-PEOPLE# people ."
-        ],
-        "Ticket": [
-            "The price is #TRAIN-OFFERBOOK-TICKET# per ticket ",
-            "The cost per seat is #TRAIN-OFFERBOOK-TICKET# .",
-            "The price of the ticket is #TRAIN-OFFERBOOK-TICKET# "
-        ],
-        "Leave": [
-            "Would you like me to book you on the #TRAIN-OFFERBOOK-LEAVE# train ?",
-            "There is a train that leaves at #TRAIN-OFFERBOOK-LEAVE# would you like me to book that train for you ?",
-            "Okay ! How about the train that leaves at #TRAIN-OFFERBOOK-LEAVE# ?",
-            "I would recommend the train that leaves at #TRAIN-OFFERBOOK-LEAVE# . Would you like me to book that ?",
-            "There is a train arriving at #TRAIN-OFFERBOOK-LEAVE# would you like me to book tickets for that one ?",
-            "The earliest train is at #TRAIN-OFFERBOOK-LEAVE# , do you want me to book it ?",
-            "There is a train leaving at #TRAIN-OFFERBOOK-LEAVE# would you like me to book this ?",
-            "Great , I have a train leaving there at #TRAIN-OFFERBOOK-LEAVE# . Would you like to book that ?",
-            "I can book a #TRAIN-OFFERBOOK-LEAVE# for you .",
-            "We can book you for the train leaving at #TRAIN-OFFERBOOK-LEAVE# ."
-        ],
-        "Choice": [
-            "I have #TRAIN-OFFERBOOK-CHOICE# trains available that meet all of your requirements , would you like me to book a ticket for you ?",
-            "There are #TRAIN-OFFERBOOK-CHOICE# trains available . Should I book a train for you ?"
-        ],
-        "Dest": [
-            "Would you like me to book a train to #TRAIN-OFFERBOOK-DEST# for you ?"
-        ],
-        "Time": [
-            "The travel time is #TRAIN-OFFERBOOK-TIME# , would you like me to book it for you ?"
-        ],
-        "Ref": [
-            "Your reference number is #TRAIN-OFFERBOOK-REF# ."
-        ],
-        "Day": [
-            "Would you like to take the train on #TRAIN-OFFERBOOK-DAY# ?",
-            "I can book you on #TRAIN-OFFERBOOK-DAY#",
-            "I can book your tickets for #TRAIN-OFFERBOOK-DAY# ."
-        ]
-    },
-    "Train-OfferBooked": {
-        "Ref": [
-            "I ' ve booked your train tickets , and your reference number is #TRAIN-OFFERBOOKED-REF#.",
-            "Your booking was successful . Your reference number is #TRAIN-OFFERBOOKED-REF# .",
-            "I have made those reservations and your reference number is #TRAIN-OFFERBOOKED-REF# ."
-        ],
-        "none": [
-            "Wonderful , your train has been booked !",
-            "your reservation is booked",
-            "ok , i got that fixed for you .",
-            "Booking was successful",
-            "Great your booking is all set !",
-            "You have been booked !"
-        ],
-        "Ticket": [
-            "The cost of your ticket will be #TRAIN-OFFERBOOKED-TICKET#",
-            "the total fee is #TRAIN-OFFERBOOKED-TICKET# payable at the station",
-            "It is #TRAIN-OFFERBOOKED-TICKET# .",
-            "Booking was successful , the total fee is #TRAIN-OFFERBOOKED-TICKET# payable at the station ."
-        ],
-        "Id": [
-            "the train ID is #TRAIN-OFFERBOOKED-ID# .",
-            "Ok . You should be set . The booking was successful . The train number is #TRAIN-OFFERBOOKED-ID# ."
-        ],
-        "People": [
-            "I have successfully make a booking for #TRAIN-OFFERBOOKED-PEOPLE# on that train .",
-            "I have booked you #TRAIN-OFFERBOOKED-PEOPLE# tickets ."
-        ],
-        "Leave": [
-            "i have booked you one leaving at #TRAIN-OFFERBOOKED-LEAVE# ."
-        ],
-        "Time": [
-            "The travel time is #TRAIN-OFFERBOOKED-TIME# .",
-            "That would be #TRAIN-OFFERBOOKED-TIME# .",
-            "#TRAIN-OFFERBOOKED-TIME# would be the total duration .",
-            "The trip will last #TRAIN-OFFERBOOKED-TIME#",
-            "the travel will take #TRAIN-OFFERBOOKED-TIME#",
-            "The trip is #TRAIN-OFFERBOOKED-TIME# .",
-            "The travel time for the trip is #TRAIN-OFFERBOOKED-TIME# one way ."
-        ],
-        "Day": [
-            "The train is for #TRAIN-OFFERBOOKED-DAY# you are all set",
-            "that train leaves on #TRAIN-OFFERBOOKED-DAY# ."
-        ],
-        "Depart": [
-            "that train is departing from #TRAIN-OFFERBOOKED-DEPART# .",
-            "it departs from #TRAIN-OFFERBOOKED-DEPART# .",
-            "the train will be departing from #TRAIN-OFFERBOOKED-DEPART# ."
-        ],
-        "Dest": [
-            "the booking is for arriving in #TRAIN-OFFERBOOKED-DEST# .",
-            "the train stop is #TRAIN-OFFERBOOKED-DEST# ."
-        ],
-        "Arrive": [
-            "It arrives at #TRAIN-OFFERBOOKED-ARRIVE# .",
-            "it should arrive by #TRAIN-OFFERBOOKED-ARRIVE#",
-            "The arrival time is #TRAIN-OFFERBOOKED-ARRIVE# ."
-        ],
-        "Choice": [
-            "I have #TRAIN-OFFERBOOKED-CHOICE# trains that meet your criteria .",
-            "There are #TRAIN-OFFERBOOKED-CHOICE# .",
-            "There are #TRAIN-OFFERBOOKED-CHOICE# options",
-            "There are #TRAIN-OFFERBOOKED-CHOICE# trains available .",
-            "There are #TRAIN-OFFERBOOKED-CHOICE# total trips available to you"
-        ]
-    },
-    "Train-Request": {
-        "Day": [
-            "Can you confirm your desired travel day ?",
-            "Can you tell me what day you would like to travel , please ?",
-            "Can you tell me what day you would like to travel ?",
-            "can you tell me which day you 'd like to travel on ?",
-            "What day would you like ?",
-            "What day will you travel on ?",
-            "On what day will you be traveling ?",
-            "What day will you be traveling ?",
-            "what day did you have in mind ?",
-            "what day would you like to travel ?"
-        ],
-        "Dest": [
-            "Where would you like to go to ?",
-            "Where is your destination ?",
-            "where would you like your train to take you ?",
-            "Where are you heading to ?",
-            "Where are you headed ?",
-            "Where will you be arriving at ?",
-            "What is your destination ?",
-            "What station would you like to arrive at ?"
-        ],
-        "People": [
-            "How many tickets would you like ?",
-            "how many tickets would you like me to book ?",
-            "for how many tickets ?",
-            "For how many people ?",
-            "how many tickets do you need ?",
-            "No problem . How many seats would you like to book ?"
-        ],
-        "Depart": [
-            "where will you be departing from ?",
-            "Where are you departing from ?",
-            "Where will you be leaving from ?",
-            "where did you want to depart from ?",
-            "Where are you departing from ?",
-            "Where will you be traveling from ?"
-        ],
-        "Leave": [
-            "what time would you like to depart ?",
-            "when would you like to travel ?",
-            "Is there a certain time you are wanting to leave ?",
-            "When would you like to leave by ?",
-            "departure time in mind ?",
-            "When would you like the train to depart ?",
-            "what time do you want to depart ?",
-            "when would you like to leave by ?",
-            "what time would you like to leave ?"
-        ],
-        "Arrive": [
-            "Is there a time you would prefer to arrive ?",
-            "What time would you like to arrive by ?",
-            "What time do you need to arrive ?",
-            "What time do you want to arrive by ?",
-            "Do you have an arrival time in mind ?",
-            "Is there a time you would like to arrive by ?",
-            "Is there a time you would like to get there by ?",
-            "Is there a time you need to arrive by ?"
-        ]
-    },
-    "Police-Inform": {
-        "Addr": [
-            "it is located in #POLICE-INFORM-ADDR#",
-            "adress is #POLICE-INFORM-ADDR#",
-            "It is on #POLICE-INFORM-ADDR# .",
-            "their address in our system is listed as #POLICE-INFORM-ADDR# .",
-            "The address is #POLICE-INFORM-ADDR# .",
-            "it 's located at #POLICE-INFORM-ADDR# .",
-            "#POLICE-INFORM-ADDR# is the address",
-            "They are located at #POLICE-INFORM-ADDR# ."
-        ],
-        "Post": [
-            "The postcode of the police is #POLICE-INFORM-POST# .",
-            "The post code is #POLICE-INFORM-POST# .",
-            "Its postcode is #POLICE-INFORM-POST# .",
-            "Their postcode is #POLICE-INFORM-POST# ."
-        ],
-        "Name": [
-            "I think a fun place to visit is #POLICE-INFORM-NAME# .",
-            "#POLICE-INFORM-NAME# looks good .",
-            "#POLICE-INFORM-NAME# is available , would that work for you ?",
-            "we have #POLICE-INFORM-NAME# .",
-            "#POLICE-INFORM-NAME# is popular among visitors .",
-            "How about #POLICE-INFORM-NAME# ?",
-            "What about #POLICE-INFORM-NAME# ?",
-            "you might want to try the #POLICE-INFORM-NAME# ."
-        ],
-        "Phone": [
-            "The police phone number is #POLICE-INFORM-PHONE# .",
-            "Here is the police phone number , #POLICE-INFORM-PHONE# ."
-        ]
-    },
-    "Hospital-Inform": {
-        "Addr": [
-            "it is located in #HOSPITAL-INFORM-ADDR#",
-            "adress is #HOSPITAL-INFORM-ADDR#",
-            "It is on #HOSPITAL-INFORM-ADDR# .",
-            "their address in our system is listed as #HOSPITAL-INFORM-ADDR# .",
-            "The address is #HOSPITAL-INFORM-ADDR# .",
-            "it 's located at #HOSPITAL-INFORM-ADDR# .",
-            "#HOSPITAL-INFORM-ADDR# is the address",
-            "They are located at #HOSPITAL-INFORM-ADDR# ."
-        ],
-        "Post": [
-            "The postcode of the hospital is #HOSPITAL-INFORM-POST# .",
-            "The post code is #HOSPITAL-INFORM-POST# .",
-            "Its postcode is #HOSPITAL-INFORM-POST# .",
-            "Their postcode is #HOSPITAL-INFORM-POST# ."
-        ],
-        "Department": [
-            "The department of the hospital is #HOSPITAL-INFORM-POST# .",
-            "The department is #HOSPITAL-INFORM-POST# .",
-            "Its department is #HOSPITAL-INFORM-POST# .",
-            "Their department is #HOSPITAL-INFORM-POST# ."
-	    
-        ],
-        "Phone": [
-            "The hospital phone number is #HOSPITAL-INFORM-PHONE# .",
-            "Here is the hospital phone number , #HOSPITAL-INFORM-PHONE# ."
-        ]
-    },
-    "Hospital-Request": {
-	"Department": [
-            "What is the name of the hospital department ?",
-            "What hospital department are you thinking about ?",
-            "I ' m sorry for the confusion , what hospital department are you interested in ?",
-            "What hospital department were you thinking of ?",
-            "Do you know the department of it ?",
-            "can you give me the department of it ?"
-	]
-    },
-    "general-bye": {
-        "none": [
-            "Thank you for using our services .",
-            "Goodbye . If you think of anything else you need do n't hesitate to contact us .",
-            "You are very welcome . Goodbye .",
-            "Thank you and enjoy your visit . Have a great day .",
-            "I ' m happy to help , and I hope you enjoy your stay !",
-            "Thank you and goodbye .",
-            "Thank you for using our system !"
-        ]
-    },
-    "general-greet": {
-        "none": [
-            "You are more than welcome !",
-            "Have a good day .",
-            "I ' m happy to have been able to help you today .",
-            "Glad to have been of help . Thank you for using the Cambridge TownInfo centre . Enjoy the rest of your day !",
-            "Thank you for contacting the help desk . Have a great day .",
-            "Thank you for contacting us and have a nice day .",
-            "Ok , thank you . Have a good day .",
-            "Thank you for using our services ."
-        ]
-    },
-    "general-reqmore": {
-        "none": [
-            "You are welcome . Is there anything else I can help you with today ?",
-            "is there anything else I can help you with ?",
-            "Is there anything else I can help you with today ?",
-            "Did you need any further assistance today ?"
-        ]
-    },
-    "general-reqinfo": {
-        "none": [
-            "Could you please provide me with more information ."
-        ]
-    },
-    "general-welcome": {
-        "none": [
-            "you are welcome",
-            "Welcome , it was a pleasure serving you .",
-            "You 're welcome ! I hope you have a wonderful trip !",
-            "Okay ! Glad I could help . Enjoy your stay .",
-            "You are welcome . Have a good day !",
-            "you are welcome",
-            "You 're welcome . Have a good day !"
-        ]
-    }
-}
diff --git a/convlab/nlg/template/multiwoz/manual_system_template_nlg.json b/convlab/nlg/template/multiwoz/manual_system_template_nlg.json
index 300369b525d3d5c4781ea00587f198e56657a37b..414f0b80b740d0c33568ba21bab5e13e822d95df 100755
--- a/convlab/nlg/template/multiwoz/manual_system_template_nlg.json
+++ b/convlab/nlg/template/multiwoz/manual_system_template_nlg.json
@@ -225,7 +225,7 @@
             "Do you know the name of it ?",
             "can you give me the name of it ?"
         ],
-        "Fee": [
+        "Price": [
             "any specific price range to help narrow down available options ?",
             "What price range would you like ?",
             "what is your price range for that ?",
@@ -363,6 +363,42 @@
             "I ' m sorry but there is no availability for #BOOKING-NOBOOK-PEOPLE# people ."
         ]
     },
+    "Booking-Request": {
+        "Day": [
+            "What day would you like your booking for ?",
+            "What day would you like that reservation ?",
+            "what day would you like the booking to be made for ?",
+            "What day would you like to book ?",
+            "Ok , what day would you like to make the reservation on ?"
+        ],
+        "Stay": [
+            "How many nights will you be staying ?",
+            "And how many nights ?",
+            "for how many days ?",
+            "And for how many days ?",
+            "how many days would you like to stay ?",
+            "How many nights would you like to book it for ?",
+            "And what nights would you like me to reserve for you ?",
+            "How many nights are you wanting to stay ?",
+            "How many days will you be staying ?"
+        ],
+        "People": [
+            "For how many people ?",
+            "How many people will be ?",
+            "How many people will be with you ?",
+            "How many people is the reservation for ?"
+        ],
+        "Time": [
+            "Do you have a time preference ?",
+            "what time are you looking for a reservation at ?",
+            "For what time ?",
+            "What time would you like me to make your reservation ?",
+            "What time would you like the reservation for ?",
+            "what time should I make the reservation for ?",
+            "What time would you prefer ?",
+            "What time would you like the reservation for ?"
+        ]
+    },
     "Hotel-Inform": {
         "Internet": [
             "it has wifi .",
@@ -661,30 +697,6 @@
             "Do you need free parking ?",
             "Will you need parking while you 're there ?",
             "Will you be needing free parking ?"
-        ],
-        "Day": [
-            "What day would you like your booking for ?",
-            "What day would you like that reservation ?",
-            "what day would you like the booking to be made for ?",
-            "What day would you like to book ?",
-            "Ok , what day would you like to make the reservation on ?"
-        ],
-        "Stay": [
-            "How many nights will you be staying ?",
-            "And how many nights ?",
-            "for how many days ?",
-            "And for how many days ?",
-            "how many days would you like to stay ?",
-            "How many nights would you like to book it for ?",
-            "And what nights would you like me to reserve for you ?",
-            "How many nights are you wanting to stay ?",
-            "How many days will you be staying ?"
-        ],
-        "People": [
-            "For how many people ?",
-            "How many people will be ?",
-            "How many people will be with you ?",
-            "How many people is the reservation for ?"
         ]
     },
     "Restaurant-Inform": {
@@ -906,29 +918,6 @@
             "what is the name of the restaurant you are needing information on ?",
             "Do you know the name of the location ?",
             "Is there a certain restaurant you 're looking for ?"
-        ],
-        "Day": [
-            "What day would you like your booking for ?",
-            "What day would you like that reservation ?",
-            "what day would you like the booking to be made for ?",
-            "What day would you like to book ?",
-            "Ok , what day would you like to make the reservation on ?"
-        ],
-        "People": [
-            "For how many people ?",
-            "How many people will be ?",
-            "How many people will be with you ?",
-            "How many people is the reservation for ?"
-        ],
-        "Time": [
-            "Do you have a time preference ?",
-            "what time are you looking for a reservation at ?",
-            "For what time ?",
-            "What time would you like me to make your reservation ?",
-            "What time would you like the reservation for ?",
-            "what time should I make the reservation for ?",
-            "What time would you prefer ?",
-            "What time would you like the reservation for ?"
         ]
     },
     "Taxi-Inform": {
@@ -1342,77 +1331,6 @@
             "Is there a time you need to arrive by ?"
         ]
     },
-    "Police-Inform": {
-        "Addr": [
-            "it is located in #POLICE-INFORM-ADDR#",
-            "adress is #POLICE-INFORM-ADDR#",
-            "It is on #POLICE-INFORM-ADDR# .",
-            "their address in our system is listed as #POLICE-INFORM-ADDR# .",
-            "The address is #POLICE-INFORM-ADDR# .",
-            "it 's located at #POLICE-INFORM-ADDR# .",
-            "#POLICE-INFORM-ADDR# is the address",
-            "They are located at #POLICE-INFORM-ADDR# ."
-        ],
-        "Post": [
-            "The postcode of the police is #POLICE-INFORM-POST# .",
-            "The post code is #POLICE-INFORM-POST# .",
-            "Its postcode is #POLICE-INFORM-POST# .",
-            "Their postcode is #POLICE-INFORM-POST# ."
-        ],
-        "Name": [
-            "I think a fun place to visit is #POLICE-INFORM-NAME# .",
-            "#POLICE-INFORM-NAME# looks good .",
-            "#POLICE-INFORM-NAME# is available , would that work for you ?",
-            "we have #POLICE-INFORM-NAME# .",
-            "#POLICE-INFORM-NAME# is popular among visitors .",
-            "How about #POLICE-INFORM-NAME# ?",
-            "What about #POLICE-INFORM-NAME# ?",
-            "you might want to try the #POLICE-INFORM-NAME# ."
-        ],
-        "Phone": [
-            "The police phone number is #POLICE-INFORM-PHONE# .",
-            "Here is the police phone number , #POLICE-INFORM-PHONE# ."
-        ]
-    },
-    "Hospital-Inform": {
-        "Addr": [
-            "it is located in #HOSPITAL-INFORM-ADDR#",
-            "adress is #HOSPITAL-INFORM-ADDR#",
-            "It is on #HOSPITAL-INFORM-ADDR# .",
-            "their address in our system is listed as #HOSPITAL-INFORM-ADDR# .",
-            "The address is #HOSPITAL-INFORM-ADDR# .",
-            "it 's located at #HOSPITAL-INFORM-ADDR# .",
-            "#HOSPITAL-INFORM-ADDR# is the address",
-            "They are located at #HOSPITAL-INFORM-ADDR# ."
-        ],
-        "Post": [
-            "The postcode of the hospital is #HOSPITAL-INFORM-POST# .",
-            "The post code is #HOSPITAL-INFORM-POST# .",
-            "Its postcode is #HOSPITAL-INFORM-POST# .",
-            "Their postcode is #HOSPITAL-INFORM-POST# ."
-        ],
-        "Department": [
-            "The department of the hospital is #HOSPITAL-INFORM-POST# .",
-            "The department is #HOSPITAL-INFORM-POST# .",
-            "Its department is #HOSPITAL-INFORM-POST# .",
-            "Their department is #HOSPITAL-INFORM-POST# ."
-	    
-        ],
-        "Phone": [
-            "The hospital phone number is #HOSPITAL-INFORM-PHONE# .",
-            "Here is the hospital phone number , #HOSPITAL-INFORM-PHONE# ."
-        ]
-    },
-    "Hospital-Request": {
-	"Department": [
-            "What is the name of the hospital department ?",
-            "What hospital department are you thinking about ?",
-            "I ' m sorry for the confusion , what hospital department are you interested in ?",
-            "What hospital department were you thinking of ?",
-            "Do you know the department of it ?",
-            "can you give me the department of it ?"
-	]
-    },
     "general-bye": {
         "none": [
             "Thank you for using our services .",
@@ -1460,4 +1378,4 @@
             "You 're welcome . Have a good day !"
         ]
     }
-}
+}
\ No newline at end of file
diff --git a/convlab/nlg/template/multiwoz/nlg.bck.py b/convlab/nlg/template/multiwoz/nlg.bck.py
deleted file mode 100755
index c657f13058e5aa273a7898de113b611be628cf8a..0000000000000000000000000000000000000000
--- a/convlab/nlg/template/multiwoz/nlg.bck.py
+++ /dev/null
@@ -1,502 +0,0 @@
-import json
-import random
-import os
-from pprint import pprint
-import collections
-import logging
-
-import numpy as np
-import re
-
-from convlab.nlg import NLG
-from convlab.nlg.template.multiwoz.noise_functions import delete_random_token, random_token_permutation, spelling_noise
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
-from convlab.util import relative_import_module_from_unified_datasets
-from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple
-
-reverse_da = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da')
-
-def lower_keys(x, d=None):
-    if isinstance(x, list):
-        return [lower_keys(v) for v in x]
-    elif isinstance(x, dict):
-        #return {k.lower(): lower_keys(v) for k, v in x.items()}
-        return {SLOT_MAP_TO_UDF.get(k.title(), k.lower()): lower_keys(v) for k, v in x.items()}
-    else:
-        #return x
-        y = re.sub("#([^\-]+)-([^\-]+)-([^\-]+)#", lambda x: "#" + x.group(1) + "-" + x.group(2) + "-" + SLOT_MAP_TO_UDF[d].get(x.group(3).title(), x.group(3)).upper() + "#", x)
-        return y
-
-
-def read_json(filename):
-    raw_data = None
-    with open(filename, 'r') as f:
-        raw_data = json.load(f)
-        #return lower_keys(json.load(f), None)
-    data = {}
-    for di in raw_data:
-        new_di = di.lower()
-        data[new_di] = {}
-        d, _ = di.split('-')
-        for s in raw_data[di]:
-            new_s = SLOT_MAP_TO_UDF[d].get(s.title(), s.lower()) if d in SLOT_MAP_TO_UDF else s.lower()
-            data[new_di][new_s] = []
-            for t in raw_data[di][s]:
-                new_t = re.sub("#([^\-]+)-([^\-]+)-([^\-]+)#", lambda x: "#" + x.group(1) + "-" + x.group(2) + "-" + SLOT_MAP_TO_UDF[d].get(x.group(3).title(), x.group(3)).upper() + "#" if d in SLOT_MAP_TO_UDF else x.group(3).upper() + "#", t)
-                data[new_di][new_s].append(new_t)
-    return data
-
-
-# supported slot
-Slot2word = {
-    'Fee': 'fee',
-    'Addr': 'address',
-    'Area': 'area',
-    'Stars': 'stars',
-    'Internet': 'Internet',
-    'Department': 'department',
-    'Choice': 'choice',
-    'Ref': 'reference number',
-    'Food': 'food',
-    'Type': 'type',
-    'Price': 'price range',
-    'Stay': 'stay',
-    'Phone': 'phone number',
-    'Post': 'postcode',
-    'Day': 'day',
-    'Name': 'name',
-    'Car': 'car type',
-    'Leave': 'leave',
-    'Time': 'time',
-    'Arrive': 'arrive',
-    'Ticket': 'ticket',
-    'Depart': 'departure',
-    'People': 'people',
-    'Dest': 'destination',
-    'Parking': 'parking',
-    'Open': 'open',
-    'Id': 'Id',
-    # 'TrainID': 'TrainID'
-}
-
-
-SLOT_MAP_TO_UDF = {
-    'Attraction': {
-        'Addr': 'address',
-        'Post': 'postcode',
-        'Fee': 'entrance fee'
-    },
-    'Hospital': {
-        'Post': 'postcode',
-        'Addr': 'address'
-    },
-    'Hotel': {
-        'Addr': 'address',
-        'Post': 'postcode',
-        'Price': 'price range',
-        'Stay': 'book stay',
-        'Day': 'book day',
-        'People': 'book people'
-    },
-    'Police': {
-        'Post': 'postcode',
-        'Addr': 'address'
-    },
-    'Restaurant': {
-        'Price': 'price range',
-        'Time': 'book time',
-        'People': 'book people',
-        'Day': 'book day',
-        'Addr': 'address',
-        'Post': 'postcode'
-    },
-    'Taxi': {
-        'Leave': 'leave at',
-        'Arrive': 'arrive by',
-        'Dest': 'destination',
-        'Depart': 'departure',
-        'Car': 'type'
-    },
-    'Train': {
-        'Time': 'duration',
-        'Ticket': 'price',
-        'Leave': 'leave at',
-        'Id': 'train id',
-        'Arrive': 'arrive by',
-        'Depart': 'departure',
-        'People': 'book people',
-        'Dest': 'destination'
-    }
-}
-
-
-SLOT_MAP_TO_UDF2 = {
-    'Addr': 'address',
-    'Area': 'area',
-    'Arrive': 'arrive by',
-    'Car': 'type',
-    'Choice': 'choice',
-    'Day': 'day',
-    'Depart': 'departure',
-    'Department': 'department',
-    'Dest': 'destination',
-    'Fee': 'entrance fee',
-    'Food': 'food',
-    'Id': 'train id',
-    'Internet': 'internet',
-    'Leave': 'leave at',
-    'Name': 'name',
-    'Open': 'open',
-    'Parking': 'parking',
-    'People': 'book people',
-    'Phone': 'phone',
-    'Post': 'postcode',
-    'Price': 'price range',
-    'Ref': 'ref',
-    'Stars': 'stars',
-    'Stay': 'book stay',
-    'Ticket': 'price',
-    'Time': 'time',
-    'TrainID': 'train id',
-    'Type': 'type'
-}
-
-slot2word = dict((k.lower(), v.lower()) for k, v in Slot2word.items())
-
-
-class TemplateNLG(NLG):
-
-    def __init__(self, is_user, mode="manual", label_noise=0.0, text_noise=0.0, seed=0):
-        """
-        Args:
-            is_user:
-                if dialog_act from user or system
-            mode:
-                - `auto`: templates extracted from data without manual modification, may have no match;
-
-                - `manual`: templates with manual modification, sometimes verbose;
-
-                - `auto_manual`: use auto templates first. When fails, use manual templates.
-
-                both template are dict, *_template[dialog_act][slot] is a list of templates.
-        """
-        super().__init__()
-        self.is_user = is_user
-        self.mode = mode
-        self.label_noise = label_noise
-        self.text_noise = text_noise
-        if not is_user:
-            self.label_noise, self.text_noise = 0.0, 0.0
-
-        print("NLG seed " + str(seed))
-        logging.info(f'Building {"user" if is_user else "system"} template NLG module using {mode} templates.')
-        if self.label_noise > 0.0 or self.text_noise > 0.0:
-            self.seed = seed
-            np.random.seed(seed)
-            random.seed(seed)
-            logging.info(f'Template NLG will generate {self.label_noise * 100}% noise in values and {self.text_noise * 100}% random text noise.')
-
-        self.load_templates()
-
-    def load_templates(self):
-        template_dir = os.path.dirname(os.path.abspath(__file__))
-        if self.is_user:
-            if 'manual' in self.mode:
-                self.manual_user_template = read_json(os.path.join(
-                    template_dir, 'manual_user_template_nlg.json'))
-            if 'auto' in self.mode:
-                self.auto_user_template = read_json(os.path.join(
-                    template_dir, 'auto_user_template_nlg.json'))
-        else:
-            if 'manual' in self.mode:
-                self.manual_system_template = read_json(os.path.join(
-                    template_dir, 'manual_system_template_nlg.json'))
-            if 'auto' in self.mode:
-                self.auto_system_template = read_json(os.path.join(template_dir, 'auto_system_template_nlg.json'))
-        logging.info('NLG templates loaded.')
-        
-        if self.label_noise > 0.0 and self.is_user:
-            self.label_map = read_json(os.path.join(template_dir, 'label_maps.json'))
-            logging.info('NLG value noise label map loaded.')
-
-    def sorted_dialog_act(self, dialog_acts):
-        new_action_group = {}
-        for item in dialog_acts:
-            intent, domain, slot, value = item
-            if domain not in new_action_group:
-                new_action_group[domain] = {
-                    'nooffer': [], 'inform-name': [], 'inform-other': [], 'request': [], 'other': []}
-            if intent == 'NoOffer':
-                new_action_group[domain]['nooffer'].append(item)
-            elif intent == 'Inform' and slot == 'Name':
-                new_action_group[domain]['inform-name'].append(item)
-            elif intent == 'Inform':
-                new_action_group[domain]['inform-other'].append(item)
-            elif intent == 'request':
-                new_action_group[domain]['request'].append(item)
-            else:
-                new_action_group[domain]['other'].append(item)
-
-        new_action = []
-        if 'general' in new_action_group:
-            new_action += new_action_group['general']['other']
-            del new_action_group['general']
-        for domain in new_action_group:
-            for k in ['other', 'request', 'inform-other', 'inform-name', 'nooffer']:
-                new_action = new_action_group[domain][k] + new_action
-        return new_action
-
-    def noisy_dialog_acts(self, dialog_acts):
-        if self.label_noise > 0.0:
-            noisy_acts = []
-            for intent, domain, slot, value in dialog_acts:
-                if intent == 'Inform':
-                    if value in self.label_map:
-                        if np.random.uniform() < self.label_noise:
-                            value = self.label_map[value]
-                            value = np.random.choice(value)
-                noisy_acts.append([intent, domain, slot, value])
-            return noisy_acts
-        return dialog_acts
-
-    def generate(self, dialog_acts):
-        """NLG for Multiwoz dataset
-
-        Args:
-            dialog_acts
-        Returns:
-            generated sentence
-        """
-        print("dialog_acts0:", dialog_acts)
-        dialog_acts = unified_format(dialog_acts)
-        print("dialog_acts1:", dialog_acts)
-        dialog_acts = reverse_da(dialog_acts)
-        print("dialog_acts2:", dialog_acts)
-        dialog_acts = act_dict_to_flat_tuple(dialog_acts)
-        print("dialog_acts3:", dialog_acts)
-        dialog_acts = self.noisy_dialog_acts(
-            dialog_acts) if self.is_user else dialog_acts
-        dialog_acts = self.sorted_dialog_act(dialog_acts)
-        action = collections.OrderedDict()
-        for intent, domain, slot, value in dialog_acts:
-            k = '-'.join([domain.lower(), intent.lower()])
-            action.setdefault(k, [])
-            action[k].append([slot.lower(), value])
-        dialog_acts = action
-        print("dialog_acts4:", dialog_acts)
-        mode = self.mode
-        try:
-            is_user = self.is_user
-            if mode == 'manual':
-                if is_user:
-                    template = self.manual_user_template
-                else:
-                    template = self.manual_system_template
-
-                return self._manual_generate(dialog_acts, template)
-
-            elif mode == 'auto':
-                if is_user:
-                    template = self.auto_user_template
-                else:
-                    template = self.auto_system_template
-
-                return self._auto_generate(dialog_acts, template)
-
-            elif mode == 'auto_manual':
-                if is_user:
-                    template1 = self.auto_user_template
-                    template2 = self.manual_user_template
-                else:
-                    template1 = self.auto_system_template
-                    template2 = self.manual_system_template
-
-                res = self._auto_generate(dialog_acts, template1)
-                if res == 'None':
-                    res = self._manual_generate(dialog_acts, template2)
-                return res
-
-            else:
-                raise Exception(
-                    "Invalid mode! available mode: auto, manual, auto_manual")
-        except Exception as e:
-            print('Error in processing:')
-            pprint(dialog_acts)
-            raise e
-
-    def _postprocess(self, sen):
-        sen_strip = sen.strip()
-        sen = ''.join([val.capitalize() if i == 0 else val for i,
-                       val in enumerate(sen_strip)])
-        if len(sen) > 0 and sen[-1] != '?' and sen[-1] != '.':
-            sen += '.'
-        sen += ' '
-        return sen
-
-    def _add_random_noise(self, sen):
-        if self.text_noise > 0.0:
-            end = sen[-3:]
-            sen = sen[:-3]
-            sen = random_token_permutation(
-                sen, probability=self.text_noise / 2)
-            sen = delete_random_token(sen, probability=self.text_noise)
-            sen += end
-        return sen
-
-    def _add_random_noise(self, sen):
-        if self.text_noise > 0.0:
-            end = sen[-3:]
-            sen = sen[:-3]
-            sen = spelling_noise(sen, prob=self.text_noise / 2)
-            # sen = random_token_permutation(sen, probability=self.text_noise / 4)
-            sen = delete_random_token(sen, probability=self.text_noise / 2)
-            sen += end
-        return sen
-
-    def _manual_generate(self, dialog_acts, template):
-        sentences = ''
-        for dialog_act, slot_value_pairs in dialog_acts.items():
-            intent = dialog_act.split('-')
-            if 'select' == intent[1]:
-                slot2values = {}
-                for slot, value in slot_value_pairs:
-                    slot2values.setdefault(slot, [])
-                    slot2values[slot].append(value)
-                for slot, values in slot2values.items():
-                    if slot == 'none':
-                        continue
-                    sentence = 'Do you prefer ' + values[0]
-                    for i, value in enumerate(values[1:]):
-                        if i == (len(values) - 2):
-                            sentence += ' or ' + value
-                        else:
-                            sentence += ' , ' + value
-                    sentence += ' {} ? '.format(slot2word[slot])
-                    sentences += self._add_random_noise(sentence)
-            elif 'request' == intent[1]:
-                for slot, value in slot_value_pairs:
-                    if dialog_act not in template or slot not in template[dialog_act]:
-                        if dialog_act not in template:
-                            print("dialog_act not in template:", dialog_act)
-                        else:
-                            print("slot not in template:", dialog_act, slot)
-                            import pdb
-                            pdb.set_trace()
-                        sentence = 'What is the {} of {} ? '.format(
-                            slot.lower(), dialog_act.split('-')[0].lower())
-                        sentences += self._add_random_noise(sentence)
-                    else:
-                        sentence = random.choice(template[dialog_act][slot])
-                        sentence = self._postprocess(sentence)
-                        sentences += self._add_random_noise(sentence)
-            elif 'general' == intent[0] and dialog_act in template:
-                sentence = random.choice(template[dialog_act]['none'])
-                sentence = self._postprocess(sentence)
-                sentences += self._add_random_noise(sentence)
-            else:
-                for slot, value in slot_value_pairs:
-                    if isinstance(value, str):
-                        value_lower = value.lower()
-                    if value in ["do nt care", "do n't care", "dontcare"]:
-                        sentence = 'I don\'t care about the {} of the {}'.format(
-                            slot, dialog_act.split('-')[0])
-                    elif self.is_user and dialog_act.split('-')[1] == 'inform' and slot == 'choice' and value_lower == 'any':
-                        # user have no preference, any choice is ok
-                        sentence = random.choice([
-                            "Please pick one for me. ",
-                            "Anyone would be ok. ",
-                            "Just select one for me. "
-                        ])
-                    elif slot == 'price' and 'same price range' in value_lower:
-                        sentence = random.choice([
-                            "it just needs to be {} .".format(value),
-                            "Oh , I really need something {} .".format(value),
-                            "I would prefer something that is {} .".format(
-                                value),
-                            "it needs to be {} .".format(value)
-                        ])
-                    elif slot in ['internet', 'parking'] and value_lower == 'no':
-                        sentence = random.choice([
-                            "It does n't need to have {} .".format(slot),
-                            "I do n't need free {} .".format(slot),
-                        ])
-                    elif dialog_act in template and slot in template[dialog_act]:
-                        sentence = random.choice(template[dialog_act][slot])
-                        if 'not available' in value.lower():
-                            domain_ = dialog_act.split('-', 1)[0]
-                            slot_ = slot2word.get(slot, None)
-                            slot_ = REF_SYS_DA.get(domain_, {}).get(
-                                slot.title(), None) if not slot_ else slot_
-                            if slot_:
-                                sentence = f"Sorry, I do not know the {slot_.lower()} of the {domain_.lower()}."
-                            else:
-                                sentence = "Sorry, I do not have that information."
-                        else:
-                            sentence = sentence.replace(
-                                '#{}-{}#'.format(dialog_act.upper(), slot.upper()), str(value))
-                    elif slot == 'notbook':
-                        sentence = random.choice([
-                            "I do not need to book. ",
-                            "I 'm not looking to make a booking at the moment."
-                        ])
-                    else:
-                        if slot in slot2word:
-                            if 'not available' in value.lower():
-                                domain_ = dialog_act.split('-', 1)[0]
-                                slot_ = slot2word[slot]
-                                if slot_:
-                                    sentence = f"Sorry, I do not know the {slot_.lower()} of the {domain_.lower()}."
-                                else:
-                                    sentence = "Sorry, I do not have that information."
-                            else:
-                                sentence = 'The {} is {} . '.format(
-                                    slot2word[slot], str(value))
-                        else:
-                            sentence = ''
-                    sentence = self._postprocess(sentence)
-                    sentences += self._add_random_noise(sentence)
-        return sentences.strip()
-
-    def _auto_generate(self, dialog_acts, template):
-        sentences = ''
-        for dialog_act, slot_value_pairs in dialog_acts.items():
-            key = ''
-            for s, v in sorted(slot_value_pairs, key=lambda x: x[0]):
-                key += s + ';'
-            if dialog_act in template and key in template[dialog_act]:
-                sentence = random.choice(template[dialog_act][key])
-                if 'request' in dialog_act or 'general' in dialog_act:
-                    sentence = self._postprocess(sentence)
-                    sentences += sentence
-                else:
-                    for s, v in sorted(slot_value_pairs, key=lambda x: x[0]):
-                        if v != 'none':
-                            sentence = sentence.replace(
-                                '#{}-{}#'.format(dialog_act.upper(), s.upper()), v, 1)
-                    sentence = self._postprocess(sentence)
-                    sentences += sentence
-            else:
-                return 'None'
-        return sentences.strip()
-
-
-def example():
-    # dialog act
-    dialog_acts = [['Inform', 'Hotel', 'Area', 'east'], [
-        'Inform', 'Hotel', 'Internet', 'no'], ['welcome', 'general', 'none', 'none']]
-    #dialog_acts = [['Inform', 'Restaurant', 'NotBook', 'none']]
-    print(dialog_acts)
-
-    # system model for manual, auto, auto_manual
-    nlg_sys_manual = TemplateNLG(is_user=False, mode='manual')
-    nlg_sys_auto = TemplateNLG(is_user=False, mode='auto')
-    nlg_sys_auto_manual = TemplateNLG(is_user=False, mode='auto_manual')
-
-    # generate
-    print('manual      : ', nlg_sys_manual.generate(dialog_acts))
-    print('auto        : ', nlg_sys_auto.generate(dialog_acts))
-    print('auto_manual : ', nlg_sys_auto_manual.generate(dialog_acts))
-
-
-if __name__ == '__main__':
-    example()
diff --git a/convlab/nlg/template/multiwoz/nlg.py b/convlab/nlg/template/multiwoz/nlg.py
index 5f362ebbabd68ee0b6fab62c692caf0e8da436e9..f83a6db4e2ad6f77f9bdc154cfda7bf2db0ff2c5 100755
--- a/convlab/nlg/template/multiwoz/nlg.py
+++ b/convlab/nlg/template/multiwoz/nlg.py
@@ -31,33 +31,33 @@ def read_json(filename):
 
 # supported slot
 Slot2word = {
-    'Fee': 'entrance fee',
+    'Fee': 'fee',
     'Addr': 'address',
     'Area': 'area',
-    'Stars': 'number of stars',
-    'Internet': 'internet',
+    'Stars': 'stars',
+    'Internet': 'Internet',
     'Department': 'department',
     'Choice': 'choice',
     'Ref': 'reference number',
     'Food': 'food',
     'Type': 'type',
     'Price': 'price range',
-    'Stay': 'length of the stay',
+    'Stay': 'stay',
     'Phone': 'phone number',
     'Post': 'postcode',
     'Day': 'day',
     'Name': 'name',
     'Car': 'car type',
-    'Leave': 'departure time',
+    'Leave': 'leave',
     'Time': 'time',
-    'Arrive': 'arrival time',
-    'Ticket': 'ticket price',
+    'Arrive': 'arrive',
+    'Ticket': 'ticket',
     'Depart': 'departure',
-    'People': 'number of people',
+    'People': 'people',
     'Dest': 'destination',
     'Parking': 'parking',
-    'Open': 'opening hours',
-    'Id': 'id',
+    'Open': 'open',
+    'Id': 'Id',
     # 'TrainID': 'TrainID'
 }
 
@@ -271,10 +271,6 @@ class TemplateNLG(NLG):
             elif 'request' == intent[1]:
                 for slot, value in slot_value_pairs:
                     if dialog_act not in template or slot not in template[dialog_act]:
-                        if dialog_act not in template:
-                            print("WARNING (nlg.py): (User?: %s) dialog_act '%s' not in template!" % (self.is_user, dialog_act))
-                        else:
-                            print("WARNING (nlg.py): (User?: %s) slot '%s' of dialog_act '%s' not in template!" % (self.is_user, slot, dialog_act))
                         sentence = 'What is the {} of {} ? '.format(
                             slot.lower(), dialog_act.split('-')[0].lower())
                         sentences += self._add_random_noise(sentence)
@@ -292,7 +288,7 @@ class TemplateNLG(NLG):
                         value_lower = value.lower()
                     if value in ["do nt care", "do n't care", "dontcare"]:
                         sentence = 'I don\'t care about the {} of the {}'.format(
-                            slot2word.get(slot, slot), dialog_act.split('-')[0])
+                            slot, dialog_act.split('-')[0])
                     elif self.is_user and dialog_act.split('-')[1] == 'inform' and slot == 'choice' and value_lower == 'any':
                         # user have no preference, any choice is ok
                         sentence = random.choice([
diff --git a/convlab/nlu/jointBERT/dataloader.py b/convlab/nlu/jointBERT/dataloader.py
index 2aa04ae3cb2735e55717037524543b9f1b7a8039..d1fcbc7a4864211a9956cedacc7c3479c195733a 100755
--- a/convlab/nlu/jointBERT/dataloader.py
+++ b/convlab/nlu/jointBERT/dataloader.py
@@ -21,7 +21,7 @@ class Dataloader:
         self.intent2id = dict([(x, i) for i, x in enumerate(intent_vocab)])
         self.id2tag = dict([(i, x) for i, x in enumerate(tag_vocab)])
         self.tag2id = dict([(x, i) for i, x in enumerate(tag_vocab)])
-        self.tokenizer = BertTokenizer.from_pretrained(pretrained_weights, local_files_only=True)
+        self.tokenizer = BertTokenizer.from_pretrained(pretrained_weights)
         self.data = {}
         self.intent_weight = [1] * len(self.intent2id)
 
diff --git a/convlab/nlu/jointBERT/jointBERT.py b/convlab/nlu/jointBERT/jointBERT.py
index 9550c50b51707e29d8cadd0d67d125ee7716f92a..5f73c9aba6808ff92e15d7967bf6a2aa43419c75 100755
--- a/convlab/nlu/jointBERT/jointBERT.py
+++ b/convlab/nlu/jointBERT/jointBERT.py
@@ -12,7 +12,7 @@ class JointBERT(nn.Module):
         self.intent_weight = intent_weight if intent_weight is not None else torch.tensor([1.]*intent_dim)
 
         print(model_config['pretrained_weights'])
-        self.bert = BertModel.from_pretrained(model_config['pretrained_weights'], local_files_only=True)
+        self.bert = BertModel.from_pretrained(model_config['pretrained_weights'])
         self.dropout = nn.Dropout(model_config['dropout'])
         self.context = model_config['context']
         self.finetune = model_config['finetune']
diff --git a/convlab/nlu/jointBERT/multiwoz/nlu.py b/convlab/nlu/jointBERT/multiwoz/nlu.py
index 1373919e5861156c87a2ba14d6506d13e0842204..e25fbad1227c4b1f85ae2ae42a8ac899fa61d7b8 100755
--- a/convlab/nlu/jointBERT/multiwoz/nlu.py
+++ b/convlab/nlu/jointBERT/multiwoz/nlu.py
@@ -74,8 +74,7 @@ class BERTNLU(NLU):
         for token in token_list:
             token = token.strip()
             self.nlp.tokenizer.add_special_case(
-                #token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}])
-                token, [{ORTH: token}])
+                token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}])
         logging.info("BERTNLU loaded")
 
     def predict(self, utterance, context=list()):
diff --git a/convlab/policy/evaluate.py b/convlab/policy/evaluate.py
index 4d89faff2a9dd133dffec76e8964d7e427dc27f2..7a692261869f35e587c34a26e425d1489abdcf56 100755
--- a/convlab/policy/evaluate.py
+++ b/convlab/policy/evaluate.py
@@ -14,7 +14,6 @@ from convlab.policy.rule.multiwoz import RulePolicy
 from convlab.task.multiwoz.goal_generator import GoalGenerator
 from convlab.util.custom_util import set_seed, get_config, env_config, create_goals, data_goals
 from tqdm import tqdm
-from pprint import pprint
 
 
 def init_logging(log_dir_path, path_suffix=None):
@@ -60,10 +59,7 @@ def evaluate(config_path, model_name, verbose=False, model_path="", goals_from_d
         policy_sys = GDPL(vectorizer=conf['vectorizer_sys_activated'])
     elif model_name == "DDPT":
         from convlab.policy.vtrace_DPT import VTRACE
-        policy_sys = VTRACE(
-            is_train=False, vectorizer=conf['vectorizer_sys_activated'])
-    else:
-        print("Unknown model name", model_name)
+        policy_sys = VTRACE(is_train=False, vectorizer=conf['vectorizer_sys_activated'])
 
     try:
         if model_path:
@@ -83,7 +79,6 @@ def evaluate(config_path, model_name, verbose=False, model_path="", goals_from_d
     if goals_from_data:
         logging.info("read goals from dataset...")
         goals = data_goals(dialogues, dataset="multiwoz21", dial_ids_order=0)
-
     else:
         logging.info("create goals from goal_generator...")
         goals = create_goals(goal_generator, num_goals=dialogues,
@@ -100,21 +95,18 @@ def evaluate(config_path, model_name, verbose=False, model_path="", goals_from_d
         task_succ_strict = 0
         complete = 0
 
-        # if verbose:
-        #     logging.info("NEW EPISODE!!!!" + "-" * 80)
-        #     logging.info(f"\n Seed: {seed}")
-        #     logging.info(f"GOAL: {sess.evaluator.goal}")
-        #     logging.info("\n")
-        dialog = []
+        if verbose:
+            logging.info("NEW EPISODE!!!!" + "-" * 80)
+            logging.info(f"\n Seed: {seed}")
+            logging.info(f"GOAL: {sess.evaluator.goal}")
+            logging.info("\n")
         for i in range(40):
             sys_response, user_response, session_over, reward = sess.next_turn(
                 sys_response)
 
             if verbose:
-                dialog.append({"usr": user_response})
-                dialog.append({"sys": sys_response})
-                logging.info(f"usr {user_response}")
-                logging.info(f"sys {sys_response}")
+                logging.info(f"USER RESPONSE: {user_response}")
+                logging.info(f"SYS RESPONSE: {sys_response}")
 
             actions += len(sys_response)
             length = len(sys_response)
@@ -133,19 +125,18 @@ def evaluate(config_path, model_name, verbose=False, model_path="", goals_from_d
                 task_succ = sess.evaluator.task_success()
                 task_succ = sess.evaluator.success
                 task_succ_strict = sess.evaluator.success_strict
-                complete = sess.evaluator.complete
-                # TODO check the definision of complete rate
-                # complete = sess.user_agent.policy.policy.goal.task_complete()
-
-                # if goals_from_data:
-                #     complete = sess.user_agent.policy.policy.goal.task_complete()
-                # else:
-                #     complete = sess.evaluator.complete
+                if goals_from_data:
+                    complete = sess.user_agent.policy.policy.goal.task_complete()
+                else:
+                    complete = sess.evaluator.complete
                 break
 
         if verbose:
             logging.info(f"Complete: {complete}")
             logging.info(f"Success: {task_succ}")
+            logging.info(f"Success strict: {task_succ_strict}")
+            logging.info(f"Return: {total_return}")
+            logging.info(f"Average actions: {actions / turns}")
 
         task_success['Complete'].append(complete)
         task_success['Success'].append(task_succ)
diff --git a/convlab/policy/evaluate_distributed.py b/convlab/policy/evaluate_distributed.py
index d58b7dee7ae242292f303d6442fad74e40566e97..1f7b3ffe93c040e6e18aa0ccd88b8c21fbcd178c 100644
--- a/convlab/policy/evaluate_distributed.py
+++ b/convlab/policy/evaluate_distributed.py
@@ -73,9 +73,7 @@ def sampler(pid, queue, evt, sess, seed_range, goals):
 
             if session_over is True:
                 success = sess.evaluator.task_success()
-                # TODO check the differenct between complete and success
-                # complete = sess.evaluator.complete
-                complete = sess.user_agent.policy.policy.goal.task_complete()
+                complete = sess.evaluator.complete
                 success = sess.evaluator.success
                 success_strict = sess.evaluator.success_strict
                 break
diff --git a/convlab/policy/genTUS/evaluate.py b/convlab/policy/genTUS/evaluate.py
deleted file mode 100644
index 87de854970d2701900ba180d2bf15736071e0c1a..0000000000000000000000000000000000000000
--- a/convlab/policy/genTUS/evaluate.py
+++ /dev/null
@@ -1,257 +0,0 @@
-import json
-import os
-import sys
-from argparse import ArgumentParser
-from pprint import pprint
-
-import torch
-from convlab.nlg.evaluate import fine_SER
-from datasets import load_metric
-
-# from convlab.policy.genTUS.pg.stepGenTUSagent import \
-#     stepGenTUSPG as UserPolicy
-from convlab.policy.genTUS.stepGenTUS import UserActionPolicy
-from tqdm import tqdm
-
-sys.path.append(os.path.dirname(os.path.dirname(
-    os.path.dirname(os.path.abspath(__file__)))))
-
-
-def arg_parser():
-    parser = ArgumentParser()
-    parser.add_argument("--model-checkpoint", type=str, help="the model path")
-    parser.add_argument("--model-weight", type=str,
-                        help="the model weight", default="")
-    parser.add_argument("--input-file", type=str, help="the testing input file",
-                        default="")
-    parser.add_argument("--generated-file", type=str, help="the generated results",
-                        default="")
-    parser.add_argument("--only-action", action="store_true")
-    parser.add_argument("--dataset", default="multiwoz")
-    parser.add_argument("--do-semantic", action="store_true",
-                        help="do semantic evaluation")
-    parser.add_argument("--do-nlg", action="store_true",
-                        help="do nlg generation")
-    parser.add_argument("--do-golden-nlg", action="store_true",
-                        help="do golden nlg generation")
-    return parser.parse_args()
-
-
-class Evaluator:
-    def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False):
-        self.dataset = dataset
-        self.model_checkpoint = model_checkpoint
-        self.model_weight = model_weight
-        # if model_weight:
-        #     self.usr_policy = UserPolicy(
-        #         self.model_checkpoint, only_action=only_action)
-        #     self.usr_policy.load(model_weight)
-        #     self.usr = self.usr_policy.usr
-        # else:
-        self.usr = UserActionPolicy(
-            model_checkpoint, only_action=only_action, dataset=self.dataset)
-        self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
-
-    def generate_results(self, f_eval, golden=False):
-        in_file = json.load(open(f_eval))
-        r = {
-            "input": [],
-            "golden_acts": [],
-            "golden_utts": [],
-            "gen_acts": [],
-            "gen_utts": []
-        }
-        for dialog in tqdm(in_file['dialog']):
-            inputs = dialog["in"]
-            labels = self.usr._parse_output(dialog["out"])
-            if golden:
-                usr_act = labels["action"]
-                usr_utt = self.usr.generate_text_from_give_semantic(
-                    inputs, usr_act)
-
-            else:
-                output = self.usr._parse_output(
-                    self.usr._generate_action(inputs))
-                usr_act = self.usr._remove_illegal_action(output["action"])
-                usr_utt = output["text"]
-            r["input"].append(inputs)
-            r["golden_acts"].append(labels["action"])
-            r["golden_utts"].append(labels["text"])
-            r["gen_acts"].append(usr_act)
-            r["gen_utts"].append(usr_utt)
-
-        return r
-
-    def read_generated_result(self, f_eval):
-        in_file = json.load(open(f_eval))
-        r = {
-            "input": [],
-            "golden_acts": [],
-            "golden_utts": [],
-            "gen_acts": [],
-            "gen_utts": []
-        }
-        for dialog in tqdm(in_file['dialog']):
-            for x in dialog:
-                r[x].append(dialog[x])
-
-        return r
-
-    def nlg_evaluation(self, input_file=None, generated_file=None, golden=False):
-        if input_file:
-            print("Force generation")
-            gen_r = self.generate_results(input_file, golden)
-
-        elif generated_file:
-            gen_r = self.read_generated_result(generated_file)
-        else:
-            print("You must specify the input_file or the generated_file")
-
-        nlg_eval = {
-            "golden": golden,
-            "metrics": {},
-            "dialog": []
-        }
-        for input, golden_act, golden_utt, gen_act, gen_utt in zip(gen_r["input"], gen_r["golden_acts"], gen_r["golden_utts"], gen_r["gen_acts"], gen_r["gen_utts"]):
-            nlg_eval["dialog"].append({
-                "input": input,
-                "golden_acts": golden_act,
-                "golden_utts": golden_utt,
-                "gen_acts": gen_act,
-                "gen_utts": gen_utt
-            })
-
-        if golden:
-            print("Calculate BLEU")
-            bleu_metric = load_metric("sacrebleu")
-            labels = [[utt] for utt in gen_r["golden_utts"]]
-
-            bleu_score = bleu_metric.compute(predictions=gen_r["gen_utts"],
-                                             references=labels,
-                                             force=True)
-            print("bleu_metric", bleu_score)
-            nlg_eval["metrics"]["bleu"] = bleu_score
-
-        else:
-            print("Calculate SER")
-            missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
-                gen_r["gen_acts"], gen_r["gen_utts"])
-
-            print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
-                "genTUSNLG", missing, total, hallucinate, missing/total))
-            nlg_eval["metrics"]["SER"] = missing/total
-
-        dir_name = self.model_checkpoint
-        json.dump(nlg_eval,
-                  open(os.path.join(dir_name, "nlg_eval.json"), 'w'),
-                  indent=2)
-        return os.path.join(dir_name, "nlg_eval.json")
-
-    def evaluation(self, input_file=None, generated_file=None):
-        force_prediction = True
-        if generated_file:
-            gen_file = json.load(open(generated_file))
-            force_prediction = False
-            if gen_file["golden"]:
-                force_prediction = True
-
-        if force_prediction:
-            in_file = json.load(open(input_file))
-            dialog_result = []
-            gen_acts, golden_acts = [], []
-            # scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
-            for dialog in tqdm(in_file['dialog']):
-                inputs = dialog["in"]
-                labels = self.usr._parse_output(dialog["out"])
-                ans_action = self.usr._remove_illegal_action(labels["action"])
-                preds = self.usr._generate_action(inputs)
-                preds = self.usr._parse_output(preds)
-                usr_action = self.usr._remove_illegal_action(preds["action"])
-
-                gen_acts.append(usr_action)
-                golden_acts.append(ans_action)
-
-                d = {"input": inputs,
-                     "golden_acts": ans_action,
-                     "gen_acts": usr_action}
-                if "text" in preds:
-                    d["golden_utts"] = labels["text"]
-                    d["gen_utts"] = preds["text"]
-                    # print("pred text", preds["text"])
-
-                dialog_result.append(d)
-        else:
-            gen_acts, golden_acts = [], []
-            for dialog in gen_file['dialog']:
-                gen_acts.append(dialog["gen_acts"])
-                golden_acts.append(dialog["golden_acts"])
-            dialog_result = gen_file['dialog']
-
-        scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
-
-        for gen_act, golden_act in zip(gen_acts, golden_acts):
-            s = f1_measure(preds=gen_act, labels=golden_act)
-            for metric in scores:
-                scores[metric].append(s[metric])
-
-        result = {}
-        for metric in scores:
-            result[metric] = sum(scores[metric])/len(scores[metric])
-            print(f"{metric}: {result[metric]}")
-
-        result["dialog"] = dialog_result
-        basename = "semantic_evaluation_result"
-        json.dump(result, open(os.path.join(
-            self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
-        # if self.model_weight:
-        #     json.dump(result, open(os.path.join(
-        #         'results', f"{basename}.json"), 'w'))
-        # else:
-        #     json.dump(result, open(os.path.join(
-        #         self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
-
-
-def f1_measure(preds, labels):
-    tp = 0
-    score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0}
-    for p in preds:
-        if p in labels:
-            tp += 1.0
-    if preds:
-        score["precision"] = tp/len(preds)
-    if labels:
-        score["recall"] = tp/len(labels)
-    if (score["precision"] + score["recall"]) > 0:
-        score["f1"] = 2*(score["precision"]*score["recall"]) / \
-            (score["precision"]+score["recall"])
-    if tp == len(preds) and tp == len(labels):
-        score["turn_acc"] = 1
-    return score
-
-
-def main():
-    args = arg_parser()
-    eval = Evaluator(args.model_checkpoint,
-                     args.dataset,
-                     args.model_weight,
-                     args.only_action)
-    print("model checkpoint", args.model_checkpoint)
-    print("generated_file", args.generated_file)
-    print("input_file", args.input_file)
-    with torch.no_grad():
-        if args.do_semantic:
-            eval.evaluation(args.input_file)
-        if args.do_nlg:
-            nlg_result = eval.nlg_evaluation(input_file=args.input_file,
-                                             generated_file=args.generated_file,
-                                             golden=args.do_golden_nlg)
-            if args.generated_file:
-                generated_file = args.generated_file
-            else:
-                generated_file = nlg_result
-            eval.evaluation(args.input_file,
-                            generated_file)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/convlab/policy/genTUS/ppo/vector.py b/convlab/policy/genTUS/ppo/vector.py
deleted file mode 100644
index 4c502a46f87582008ff49219f8a14844378b9ed2..0000000000000000000000000000000000000000
--- a/convlab/policy/genTUS/ppo/vector.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import json
-
-import torch
-from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
-from convlab.policy.genTUS.token_map import tokenMap
-from convlab.policy.tus.unify.Goal import Goal
-from transformers import BartTokenizer
-
-
-class stepGenTUSVector:
-    def __init__(self, model_checkpoint, max_in_len=400, max_out_len=80, allow_general_intent=True):
-        self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
-        self.vocab = len(self.tokenizer)
-        self.max_in_len = max_in_len
-        self.max_out_len = max_out_len
-        self.token_map = tokenMap(tokenizer=self.tokenizer)
-        self.token_map.default(only_action=True)
-        self.kg = KnowledgeGraph(self.tokenizer)
-        self.mentioned_domain = []
-        self.allow_general_intent = allow_general_intent
-        self.candidate_num = 5
-        if self.allow_general_intent:
-            print("---> allow_general_intent")
-
-    def init_session(self, goal: Goal):
-        self.goal = goal
-        self.mentioned_domain = []
-
-    def encode(self, raw_inputs, max_length, return_tensors="pt", truncation=True):
-        model_input = self.tokenizer(raw_inputs,
-                                     max_length=max_length,
-                                     return_tensors=return_tensors,
-                                     truncation=truncation,
-                                     padding="max_length")
-        return model_input
-
-    def decode(self, generated_so_far, skip_special_tokens=True):
-        output = self.tokenizer.decode(
-            generated_so_far, skip_special_tokens=skip_special_tokens)
-        return output
-
-    def state_vectorize(self, action, history, turn):
-        self.goal.update_user_goal(action=action)
-        inputs = json.dumps({"system": action,
-                             "goal": self.goal.get_goal_list(),
-                             "history": history,
-                             "turn": str(turn)})
-        inputs = self.encode(inputs, self.max_in_len)
-        s_vec, action_mask = inputs["input_ids"][0], inputs["attention_mask"][0]
-
-        return s_vec, action_mask
-
-    def action_vectorize(self, action, s=None):
-        # action:  [[intent, domain, slot, value], ...]
-        vec = {"vector": torch.tensor([]), "mask": torch.tensor([])}
-        if s is not None:
-            raw_inputs = self.decode(s[0])
-            self.kg.parse_input(raw_inputs)
-
-        self._append(vec, self._get_id("<s>"))
-        self._append(vec, self.token_map.get_id('start_json'))
-        self._append(vec, self.token_map.get_id('start_act'))
-
-        act_len = len(action)
-        for i, (intent, domain, slot, value) in enumerate(action):
-            if value == '?':
-                value = '<?>'
-            c_idx = {x: None for x in ["intent", "domain", "slot", "value"]}
-
-            if s is not None:
-                c_idx["intent"] = self._candidate_id(self.kg.candidate(
-                    "intent", allow_general_intent=self.allow_general_intent))
-                c_idx["domain"] = self._candidate_id(self.kg.candidate(
-                    "domain", intent=intent))
-                c_idx["slot"] = self._candidate_id(self.kg.candidate(
-                    "slot", intent=intent, domain=domain, is_mentioned=self.is_mentioned(domain)))
-                c_idx["value"] = self._candidate_id(self.kg.candidate(
-                    "value", intent=intent, domain=domain, slot=slot))
-
-            self._append(vec, self._get_id(intent), c_idx["intent"])
-            self._append(vec, self.token_map.get_id('sep_token'))
-            self._append(vec, self._get_id(domain), c_idx["domain"])
-            self._append(vec, self.token_map.get_id('sep_token'))
-            self._append(vec, self._get_id(slot), c_idx["slot"])
-            self._append(vec, self.token_map.get_id('sep_token'))
-            self._append(vec, self._get_id(value), c_idx["value"])
-
-            c_idx = [0]*self.candidate_num
-            c_idx[0] = self.token_map.get_id('end_act')[0]
-            c_idx[1] = self.token_map.get_id('sep_act')[0]
-            if i == act_len - 1:
-                x = self.token_map.get_id('end_act')
-            else:
-                x = self.token_map.get_id('sep_act')
-
-            self._append(vec, x, c_idx)
-
-        self._append(vec, self._get_id("</s>"))
-
-        # pad
-        if len(vec["vector"]) < self.max_out_len:
-            pad_len = self.max_out_len-len(vec["vector"])
-            self._append(vec, x=torch.tensor([1]*pad_len))
-        for vec_type in vec:
-            vec[vec_type] = vec[vec_type].to(torch.int64)
-
-        return vec
-
-    def _append(self, vec, x, candidate=None):
-        if type(x) is list:
-            x = torch.tensor(x)
-        mask = self._mask(x, candidate)
-        vec["vector"] = torch.cat((vec["vector"], x), dim=-1)
-        vec["mask"] = torch.cat((vec["mask"], mask), dim=0)
-
-    def _mask(self, idx, c_idx=None):
-        mask = torch.zeros(len(idx), self.candidate_num)
-        mask[:, 0] = idx
-        if c_idx is not None and len(c_idx) > 1:
-            mask[0, :] = torch.tensor(c_idx)
-
-        return mask
-
-    def _candidate_id(self, candidate):
-        if len(candidate) > self.candidate_num:
-            print(f"too many candidates. Max = {self.candidate_num}")
-        c_idx = [0]*self.candidate_num
-        for i, idx in enumerate([self._get_id(c)[0] for c in candidate[:self.candidate_num]]):
-            c_idx[i] = idx
-        return c_idx
-
-    def _get_id(self, value):
-        token_id = self.tokenizer(value, add_special_tokens=False)
-        return token_id["input_ids"]
-
-    def action_devectorize(self, action_id):
-        return self.decode(action_id)
-
-    def update_mentioned_domain(self, semantic_act):
-        for act in semantic_act:
-            domain = act[1]
-            if domain not in self.mentioned_domain:
-                self.mentioned_domain.append(domain)
-
-    def is_mentioned(self, domain):
-        if domain in self.mentioned_domain:
-            return True
-        return False
diff --git a/convlab/policy/genTUS/stepGenTUS.py b/convlab/policy/genTUS/stepGenTUS.py
deleted file mode 100644
index 902a54068446741b39fc93f264d9a4b06245afe0..0000000000000000000000000000000000000000
--- a/convlab/policy/genTUS/stepGenTUS.py
+++ /dev/null
@@ -1,656 +0,0 @@
-import json
-import os
-
-import torch
-from transformers import BartTokenizer
-
-from convlab.policy.genTUS.ppo.vector import stepGenTUSVector
-from convlab.policy.genTUS.stepGenTUSmodel import stepGenTUSmodel
-from convlab.policy.genTUS.token_map import tokenMap
-from convlab.policy.genTUS.unify.Goal import Goal
-from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
-from convlab.policy.policy import Policy
-from convlab.task.multiwoz.goal_generator import GoalGenerator
-
-DEBUG = False
-
-
-class UserActionPolicy(Policy):
-    def __init__(self, model_checkpoint, mode="semantic", only_action=True, max_turn=40, **kwargs):
-        self.mode = mode
-        # if mode == "semantic" and only_action:
-        #     # only generate semantic action in prediction
-        print("model_checkpoint", model_checkpoint)
-        self.only_action = only_action
-        if self.only_action:
-            print("change mode to semantic because only_action=True")
-            self.mode = "semantic"
-        self.max_in_len = 500
-        self.max_out_len = 50 if only_action else 200
-        max_act_len = kwargs.get("max_act_len", 2)
-        print("max_act_len", max_act_len)
-        self.max_action_len = max_act_len
-        if "max_act_len" in kwargs:
-            self.max_out_len = 30 * self.max_action_len
-            print("max_act_len", self.max_out_len)
-        self.max_turn = max_turn
-        if mode not in ["semantic", "language"]:
-            print("Unknown user mode")
-
-        self.reward = {"success":  self.max_turn*2,
-                       "fail": self.max_turn*-1}
-        self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        train_whole_model = kwargs.get("whole_model", True)
-        self.model = stepGenTUSmodel(
-            model_checkpoint, train_whole_model=train_whole_model)
-        self.model.eval()
-        self.model.to(self.device)
-        self.model.share_memory()
-
-        self.turn_level_reward = kwargs.get("turn_level_reward", True)
-        self.cooperative = kwargs.get("cooperative", True)
-
-        dataset = kwargs.get("dataset", "")
-        self.kg = KnowledgeGraph(
-            tokenizer=self.tokenizer,
-            dataset=dataset)
-
-        self.goal_gen = GoalGenerator()
-
-        self.vector = stepGenTUSVector(
-            model_checkpoint, self.max_in_len, self.max_out_len)
-        self.norm_reward = False
-
-        self.action_penalty = kwargs.get("action_penalty", False)
-        self.usr_act_penalize = kwargs.get("usr_act_penalize", 0)
-        self.goal_list_type = kwargs.get("goal_list_type", "normal")
-        self.update_mode = kwargs.get("update_mode", "normal")
-        self.max_history = kwargs.get("max_history", 3)
-        self.init_session()
-
-    def _update_seq(self, sub_seq: list, pos: int):
-        for x in sub_seq:
-            self.seq[0, pos] = x
-            pos += 1
-
-        return pos
-
-    def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True):
-        # TODO no duplicate
-        self.kg.parse_input(raw_inputs)
-        model_input = self.vector.encode(raw_inputs, self.max_in_len)
-        # start token
-        self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
-        pos = self._update_seq([0], 0)
-        pos = self._update_seq(self.token_map.get_id('start_json'), pos)
-        pos = self._update_seq(self.token_map.get_id('start_act'), pos)
-
-        # get semantic actions
-        for act_len in range(self.max_action_len):
-            pos = self._get_semantic_action(
-                model_input, pos, mode, allow_general_intent)
-
-            terminate, token_name = self._stop_semantic(
-                model_input, pos, act_len)
-            pos = self._update_seq(self.token_map.get_id(token_name), pos)
-
-            if terminate:
-                break
-
-        if self.only_action:
-            # return semantic action. Don't need to generate text
-            return self.vector.decode(self.seq[0, :pos])
-
-        # TODO remove illegal action here?
-
-        # get text output
-        pos = self._update_seq(self.token_map.get_id("start_text"), pos)
-
-        text = self._get_text(model_input, pos)
-
-        return text
-
-    def generate_text_from_give_semantic(self, raw_inputs, semantic_action):
-        self.kg.parse_input(raw_inputs)
-        model_input = self.vector.encode(raw_inputs, self.max_in_len)
-        self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
-        pos = self._update_seq([0], 0)
-        pos = self._update_seq(self.token_map.get_id('start_json'), pos)
-        pos = self._update_seq(self.token_map.get_id('start_act'), pos)
-
-        if len(semantic_action) == 0:
-            pos = self._update_seq(self.token_map.get_id("end_act"), pos)
-
-        for act_id, (intent, domain, slot, value) in enumerate(semantic_action):
-            pos = self._update_seq(self.kg._get_token_id(intent), pos)
-            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
-            pos = self._update_seq(self.kg._get_token_id(domain), pos)
-            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
-            pos = self._update_seq(self.kg._get_token_id(slot), pos)
-            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
-            pos = self._update_seq(self.kg._get_token_id(value), pos)
-
-            if act_id == len(semantic_action) - 1:
-                token_name = "end_act"
-            else:
-                token_name = "sep_act"
-            pos = self._update_seq(self.token_map.get_id(token_name), pos)
-        pos = self._update_seq(self.token_map.get_id("start_text"), pos)
-
-        raw_output = self._get_text(model_input, pos)
-        return self._parse_output(raw_output)["text"]
-
-    def _get_text(self, model_input, pos):
-        s_pos = pos
-        for i in range(s_pos, self.max_out_len):
-            next_token_logits = self.model.get_next_token_logits(
-                model_input, self.seq[:1, :pos])
-            next_token = torch.argmax(next_token_logits, dim=-1)
-
-            if self._stop_text(next_token):
-                # text = self.vector.decode(self.seq[0, s_pos:pos])
-                # text = self._norm_str(text)
-                # return self.vector.decode(self.seq[0, :s_pos]) + text + '"}'
-                break
-
-            pos = self._update_seq([next_token], pos)
-        text = self.vector.decode(self.seq[0, s_pos:pos])
-        text = self._norm_str(text)
-        return self.vector.decode(self.seq[0, :s_pos]) + text + '"}'
-        # TODO return None
-
-    def _stop_text(self, next_token):
-        if next_token == self.token_map.get_id("end_json")[0]:
-            return True
-        elif next_token == self.token_map.get_id("end_json_2")[0]:
-            return True
-
-        return False
-
-    @staticmethod
-    def _norm_str(text: str):
-        text = text.strip('"')
-        text = text.replace('"', "'")
-        text = text.replace('\\', "")
-        return text
-
-    def _stop_semantic(self, model_input, pos, act_length=0):
-
-        outputs = self.model.get_next_token_logits(
-            model_input, self.seq[:1, :pos])
-        tokens = {}
-        for token_name in ['sep_act', 'end_act']:
-            tokens[token_name] = {
-                "token_id": self.token_map.get_id(token_name)}
-            hash_id = tokens[token_name]["token_id"][0]
-            tokens[token_name]["score"] = outputs[:, hash_id].item()
-
-        if tokens['end_act']["score"] > tokens['sep_act']["score"]:
-            terminate = True
-        else:
-            terminate = False
-
-        if act_length >= self.max_action_len - 1:
-            terminate = True
-
-        token_name = "end_act" if terminate else "sep_act"
-
-        return terminate, token_name
-
-    def _get_semantic_action(self, model_input, pos, mode="max", allow_general_intent=True):
-
-        intent = self._get_intent(
-            model_input, self.seq[:1, :pos], mode, allow_general_intent)
-        pos = self._update_seq(intent["token_id"], pos)
-        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
-
-        # get domain
-        domain = self._get_domain(
-            model_input, self.seq[:1, :pos], intent["token_name"], mode)
-        pos = self._update_seq(domain["token_id"], pos)
-        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
-
-        # get slot
-        slot = self._get_slot(
-            model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode)
-        pos = self._update_seq(slot["token_id"], pos)
-        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
-
-        # get value
-
-        value = self._get_value(
-            model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], slot["token_name"], mode)
-        pos = self._update_seq(value["token_id"], pos)
-
-        return pos
-
-    def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True):
-        next_token_logits = self.model.get_next_token_logits(
-            model_input, generated_so_far)
-
-        return self.kg.get_intent(next_token_logits, mode, allow_general_intent)
-
-    def _get_domain(self, model_input, generated_so_far, intent, mode="max"):
-        next_token_logits = self.model.get_next_token_logits(
-            model_input, generated_so_far)
-
-        return self.kg.get_domain(next_token_logits, intent, mode)
-
-    def _get_slot(self, model_input, generated_so_far, intent, domain, mode="max"):
-        next_token_logits = self.model.get_next_token_logits(
-            model_input, generated_so_far)
-        is_mentioned = self.vector.is_mentioned(domain)
-        return self.kg.get_slot(next_token_logits, intent, domain, mode, is_mentioned)
-
-    def _get_value(self, model_input, generated_so_far, intent, domain, slot, mode="max"):
-        next_token_logits = self.model.get_next_token_logits(
-            model_input, generated_so_far)
-
-        return self.kg.get_value(next_token_logits, intent, domain, slot, mode)
-
-    def _remove_illegal_action(self, action):
-        # Transform illegal action to legal action
-        new_action = []
-        for act in action:
-            if len(act) == 4:
-                if "<?>" in act[-1]:
-                    act = [act[0], act[1], act[2], "?"]
-                if act not in new_action:
-                    new_action.append(act)
-            else:
-                print("illegal action:", action)
-        return new_action
-
-    def _parse_output(self, in_str):
-        in_str = str(in_str)
-        in_str = in_str.replace('<s>', '').replace(
-            '<\\s>', '').replace('o"clock', "o'clock")
-        action = {"action": [], "text": ""}
-        try:
-            action = json.loads(in_str)
-        except:
-            print("invalid action:", in_str)
-            print("-"*20)
-        return action
-
-    def predict(self, sys_act, mode="max", allow_general_intent=True):
-        # raw_sys_act = sys_act
-        # sys_act = sys_act[:5]
-        # update goal
-        # TODO
-        allow_general_intent = False
-        self.model.eval()
-
-        if not self.add_sys_from_reward:
-            self.goal.update_user_goal(action=sys_act, char="sys")
-            self.sys_acts.append(sys_act)  # for terminate conversation
-
-        # update constraint
-        self.time_step += 2
-
-        history = []
-        if self.usr_acts:
-            if self.max_history == 1:
-                history = self.usr_acts[-1]
-            else:
-                history = self.usr_acts[-1*self.max_history:]
-        inputs = json.dumps({"system": sys_act,
-                             "goal": self.goal.get_goal_list(),
-                             "history": history,
-                             "turn": str(int(self.time_step/2))})
-        with torch.no_grad():
-            raw_output = self._generate_action(
-                raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
-        output = self._parse_output(raw_output)
-        self.semantic_action = self._remove_illegal_action(output["action"])
-        if not self.only_action:
-            self.utterance = output["text"]
-
-        # TODO
-        if self.is_finish():
-            self.semantic_action, self.utterance = self._good_bye()
-
-        # if self.is_finish():
-        #     print("terminated")
-
-        # if self.is_finish():
-        #     good_bye = self._good_bye()
-        #     self.goal.add_usr_da(good_bye)
-        #     return good_bye
-
-        self.goal.update_user_goal(action=self.semantic_action, char="usr")
-        self.vector.update_mentioned_domain(self.semantic_action)
-        self.usr_acts.append(self.semantic_action)
-
-        # if self._usr_terminate(usr_action):
-        #     print("terminated by user")
-        #     self.terminated = True
-
-        del inputs
-
-        if self.mode == "language":
-            # print("in", sys_act)
-            # print("out", self.utterance)
-            return self.utterance
-        else:
-            return self.semantic_action
-
-    def init_session(self, goal=None):
-        self.token_map = tokenMap(tokenizer=self.tokenizer)
-        self.token_map.default(only_action=self.only_action)
-        self.time_step = 0
-        remove_domain = "police"  # remove police domain in inference
-
-        if not goal:
-            self._new_goal(remove_domain=remove_domain)
-        else:
-            self._read_goal(goal)
-
-        self.vector.init_session(goal=self.goal)
-        print("="*20)
-        print("goal for GenTUS", self.goal)
-
-        self.terminated = False
-        self.add_sys_from_reward = False
-        self.sys_acts = []
-        self.usr_acts = []
-        self.semantic_action = []
-        self.utterance = ""
-
-    def _read_goal(self, data_goal):
-        self.goal = Goal(goal=data_goal)
-
-    def _new_goal(self, remove_domain="police", domain_len=None):
-        goal = self.goal_gen.get_user_goal()
-        self.goal = Goal(goal)
-        # keep_generate_goal = True
-        # # domain_len = 1
-        # while keep_generate_goal:
-        #     self.goal = Goal(goal_generator=self.goal_gen,
-        #                      goal_list_type=self.goal_list_type,
-        #                      update_mode=self.update_mode)
-        #     if (domain_len and len(self.goal.domains) != domain_len) or \
-        #             (remove_domain and remove_domain in self.goal.domains):
-        #         keep_generate_goal = True
-        #     else:
-        #         keep_generate_goal = False
-
-    def load(self, model_path):
-        self.model.load_state_dict(torch.load(
-            model_path, map_location=self.device))
-        # self.model = BartForConditionalGeneration.from_pretrained(
-        #     model_checkpoint)
-
-    def get_goal(self):
-        if self.goal.raw_goal is not None:
-            return self.goal.raw_goal
-        goal = {}
-        for domain in self.goal.domain_goals:
-            if domain not in goal:
-                goal[domain] = {}
-            for intent in self.goal.domain_goals[domain]:
-                if intent == "inform":
-                    slot_type = "info"
-                elif intent == "request":
-                    slot_type = "reqt"
-                elif intent == "book":
-                    slot_type = "book"
-                else:
-                    print("unknown slot type")
-                if slot_type not in goal[domain]:
-                    goal[domain][slot_type] = {}
-                for slot, value in self.goal.domain_goals[domain][intent].items():
-                    goal[domain][slot_type][slot] = value
-        return goal
-
-    def get_reward(self, sys_response=None):
-        self.add_sys_from_reward = False if sys_response is None else True
-
-        if self.add_sys_from_reward:
-            self.goal.update_user_goal(action=sys_response, char="sys")
-            self.goal.add_sys_da(sys_response)  # for evaluation
-            self.sys_acts.append(sys_response)  # for terminate conversation
-
-        if self.is_finish():
-            if self.is_success():
-                reward = self.reward["success"]
-                self.success = True
-            else:
-                reward = self.reward["fail"]
-                self.success = False
-
-        else:
-            reward = -1
-            if self.turn_level_reward:
-                reward += self.turn_reward()
-
-            self.success = None
-            # if self.action_penalty:
-            #     reward += self._system_action_penalty()
-
-        if self.norm_reward:
-            reward = (reward - 20)/60
-        return reward
-
-    def _system_action_penalty(self):
-        free_action_len = 3
-        if len(self.sys_acts) < 1:
-            return 0
-        # TODO only penalize the slots not in user goal
-        # else:
-        #     penlaty = 0
-        #     for i in range(len(self.sys_acts[-1])):
-        #         penlaty += -1*i
-        #     return penlaty
-        if len(self.sys_acts[-1]) > 3:
-            return -1*(len(self.sys_acts[-1])-free_action_len)
-        return 0
-
-    def turn_reward(self):
-        r = 0
-        r += self._new_act_reward()
-        r += self._reply_reward()
-        r += self._usr_act_len()
-        return r
-
-    def _usr_act_len(self):
-        last_act = self.usr_acts[-1]
-        penalty = 0
-        if len(last_act) > 2:
-            penalty = (2-len(last_act))*self.usr_act_penalize
-        return penalty
-
-    def _new_act_reward(self):
-        last_act = self.usr_acts[-1]
-        if last_act != self.semantic_action:
-            print(f"---> why? last {last_act} usr {self.semantic_action}")
-        new_act = []
-        for act in last_act:
-            if len(self.usr_acts) < 2:
-                break
-            if act[1].lower() == "general":
-                new_act.append(0)
-            elif act in self.usr_acts[-2]:
-                new_act.append(-1)
-            elif act not in self.usr_acts[-2]:
-                new_act.append(1)
-
-        return sum(new_act)
-
-    def _reply_reward(self):
-        if self.cooperative:
-            return self._cooperative_reply_reward()
-        else:
-            return self._non_cooperative_reply_reward()
-
-    def _non_cooperative_reply_reward(self):
-        r = []
-        reqts = []
-        infos = []
-        reply_len = 0
-        max_len = 1
-        for act in self.sys_acts[-1]:
-            if act[0] == "request":
-                reqts.append([act[1], act[2]])
-        for act in self.usr_acts[-1]:
-            if act[0] == "inform":
-                infos.append([act[1], act[2]])
-        for req in reqts:
-            if req in infos:
-                if reply_len < max_len:
-                    r.append(1)
-                elif reply_len == max_len:
-                    r.append(0)
-                else:
-                    r.append(-5)
-
-        if r:
-            return sum(r)
-        return 0
-
-    def _cooperative_reply_reward(self):
-        r = []
-        reqts = []
-        infos = []
-        for act in self.sys_acts[-1]:
-            if act[0] == "request":
-                reqts.append([act[1], act[2]])
-        for act in self.usr_acts[-1]:
-            if act[0] == "inform":
-                infos.append([act[1], act[2]])
-        for req in reqts:
-            if req in infos:
-                r.append(1)
-            else:
-                r.append(-1)
-        if r:
-            return sum(r)
-        return 0
-
-    def _usr_terminate(self):
-        for act in self.semantic_action:
-            if act[0] in ['thank', 'bye']:
-                return True
-        return False
-
-    def is_finish(self):
-        # stop by model generation?
-        if self._finish_conversation_rule():
-            self.terminated = True
-            return True
-        elif self._usr_terminate():
-            self.terminated = True
-            return True
-        self.terminated = False
-        return False
-
-    def is_success(self):
-        task_complete = self.goal.task_complete()
-        # goal_status = self.goal.all_mentioned()
-        # should mentioned all slots
-        if task_complete:  # and goal_status["complete"] > 0.6:
-            return True
-        return False
-
-    def _good_bye(self):
-        if self.is_success():
-            return [['thank', 'general', 'none', 'none']], "thank you. bye"
-            # if self.mode == "semantic":
-            #     return [['thank', 'general', 'none', 'none']]
-            # else:
-            #     return "bye"
-        else:
-            return [["bye", "general", "None", "None"]], "bye"
-            if self.mode == "semantic":
-                return [["bye", "general", "None", "None"]]
-            return "bye"
-
-    def _finish_conversation_rule(self):
-        if self.is_success():
-            return True
-
-        if self.time_step > self.max_turn:
-            return True
-
-        if (len(self.sys_acts) > 4) and (self.sys_acts[-1] == self.sys_acts[-2]) and (self.sys_acts[-2] == self.sys_acts[-3]):
-            return True
-        return False
-
-    def is_terminated(self):
-        # Is there any action to say?
-        self.is_finish()
-        return self.terminated
-
-
-class UserPolicy(Policy):
-    def __init__(self,
-                 model_checkpoint,
-                 mode="semantic",
-                 only_action=True,
-                 sample=False,
-                 action_penalty=False,
-                 **kwargs):
-        # self.config = config
-        # if not os.path.exists(self.config["model_dir"]):
-        #     os.mkdir(self.config["model_dir"])
-        #     model_downloader(self.config["model_dir"],
-        #                      "https://zenodo.org/record/5779832/files/default.zip")
-
-        self.policy = UserActionPolicy(
-            model_checkpoint,
-            mode=mode,
-            only_action=only_action,
-            action_penalty=action_penalty,
-            **kwargs)
-        self.policy.load(os.path.join(
-            model_checkpoint, "pytorch_model.bin"))
-        self.sample = sample
-
-    def predict(self, sys_act, mode="max"):
-        if self.sample:
-            mode = "sample"
-        else:
-            mode = "max"
-        response = self.policy.predict(sys_act, mode)
-        return response
-
-    def init_session(self, goal=None):
-        self.policy.init_session(goal)
-
-    def is_terminated(self):
-        return self.policy.is_terminated()
-
-    def get_reward(self, sys_response=None):
-        return self.policy.get_reward(sys_response)
-
-    def get_goal(self):
-        if hasattr(self.policy, 'get_goal'):
-            return self.policy.get_goal()
-        return None
-
-
-if __name__ == "__main__":
-    import os
-
-    from convlab.dialog_agent import PipelineAgent
-    # from convlab.nlu.jointBERT.multiwoz import BERTNLU
-    from convlab.util.custom_util import set_seed
-
-    set_seed(20220220)
-    # Test semantic level behaviour
-    model_checkpoint = 'convlab/policy/genTUS/unify/experiments/multiwoz21_0_1.0'
-    usr_policy = UserPolicy(
-        model_checkpoint,
-        mode="semantic")
-    usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
-    usr_nlu = None  # BERTNLU()
-    usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
-    print(usr.policy.get_goal())
-
-    print(usr.response([]))
-    print(usr.policy.policy.goal.status)
-    print(usr.response([["request", "attraction", "area", "?"]]))
-    print(usr.policy.policy.goal.status)
-    print(usr.response([["request", "attraction", "area", "?"]]))
-    print(usr.policy.policy.goal.status)
diff --git a/convlab/policy/genTUS/stepGenTUSmodel.py b/convlab/policy/genTUS/stepGenTUSmodel.py
deleted file mode 100644
index e2eaf7bc5064262808120aac4a9cbe2eb007d863..0000000000000000000000000000000000000000
--- a/convlab/policy/genTUS/stepGenTUSmodel.py
+++ /dev/null
@@ -1,114 +0,0 @@
-
-import json
-
-import torch
-from torch.nn.functional import softmax, one_hot, cross_entropy
-
-from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
-from convlab.policy.genTUS.token_map import tokenMap
-from convlab.policy.genTUS.utils import append_tokens
-from transformers import (BartConfig, BartForConditionalGeneration,
-                          BartTokenizer)
-
-
-class stepGenTUSmodel(BartForConditionalGeneration):
-    def __init__(self, model_checkpoint, train_whole_model=True, **kwargs):
-        config = BartConfig.from_pretrained(model_checkpoint)
-        super().__init__(config, **kwargs)
-
-        self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
-        self.vocab = len(self.tokenizer)
-        self.kg = KnowledgeGraph(self.tokenizer)
-        self.action_kg = KnowledgeGraph(self.tokenizer)
-        self.token_map = tokenMap(self.tokenizer)
-        # only_action doesn't matter. it is only used for get_log_prob
-        self.token_map.default(only_action=True)
-
-        if not train_whole_model:
-            for param in self.parameters():
-                param.requires_grad = False
-
-            for param in self.model.decoder.layers[-1].fc1.parameters():
-                param.requires_grad = True
-            for param in self.model.decoder.layers[-1].fc2.parameters():
-                param.requires_grad = True
-
-    def get_trainable_param(self):
-
-        return filter(
-            lambda p: p.requires_grad, self.parameters())
-
-    def get_next_token_logits(self, model_input, generated_so_far):
-        input_ids = model_input["input_ids"].to(self.device)
-        attention_mask = model_input["attention_mask"].to(self.device)
-        outputs = self.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=generated_so_far,
-            return_dict=True)
-        return outputs.logits[:, -1, :]
-
-    def get_log_prob(self, s, a, action_mask, prob_mask):
-        output = self.forward(input_ids=s,
-                              attention_mask=action_mask,
-                              decoder_input_ids=a)
-        prob = self._norm_prob(a[:, 1:].long(),
-                               output.logits[:, :-1, :],
-                               prob_mask[:, 1:, :].long())
-        return prob
-
-    def _norm_prob(self, a, prob, mask):
-        prob = softmax(prob, -1)
-        base = self._base(prob, mask).to(self.device)  # [b, seq_len]
-        prob = (prob*one_hot(a, num_classes=self.vocab)).sum(-1)
-        prob = torch.log(prob / base)
-        pad_mask = a != 1
-        prob = prob*pad_mask.float()
-        return prob.sum(-1)
-
-    @staticmethod
-    def _base(prob, mask):
-        batch_size, seq_len, dim = prob.shape
-        base = torch.zeros(batch_size, seq_len)
-        for b in range(batch_size):
-            for s in range(seq_len):
-                temp = [prob[b, s, c] for c in mask[b, s, :] if c > 0]
-                base[b, s] = torch.sum(torch.tensor(temp))
-        return base
-
-
-if __name__ == "__main__":
-    import os
-    from convlab.util.custom_util import set_seed
-    from convlab.policy.genTUS.stepGenTUS import UserActionPolicy
-    set_seed(0)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    model_checkpoint = 'results/genTUS-22-01-31-09-21/'
-    usr = UserActionPolicy(model_checkpoint=model_checkpoint)
-    usr.model.load_state_dict(torch.load(
-        os.path.join(model_checkpoint, "pytorch_model.bin"), map_location=device))
-    usr.model.eval()
-
-    test_file = "convlab/policy/genTUS/data/goal_status_validation_v1.json"
-    data = json.load(open(test_file))
-    test_id = 20
-    inputs = usr.tokenizer(data["dialog"][test_id]["in"],
-                           max_length=400,
-                           return_tensors="pt",
-                           truncation=True)
-
-    actions = [data["dialog"][test_id]["out"],
-               data["dialog"][test_id+100]["out"]]
-
-    for action in actions:
-        action = json.loads(action)
-        vec = usr.vector.action_vectorize(
-            action["action"], s=inputs["input_ids"])
-
-        print({"action": action["action"]})
-        print("get_log_prob", usr.model.get_log_prob(
-            inputs["input_ids"],
-            torch.unsqueeze(vec["vector"], 0),
-            inputs["attention_mask"],
-            torch.unsqueeze(vec["mask"], 0)))
diff --git a/convlab/policy/genTUS/token_map.py b/convlab/policy/genTUS/token_map.py
deleted file mode 100644
index 7825c2880928c40f68284b0c3199932cd1cfc477..0000000000000000000000000000000000000000
--- a/convlab/policy/genTUS/token_map.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import json
-
-
-class tokenMap:
-    def __init__(self, tokenizer):
-        self.tokenizer = tokenizer
-        self.token_name = {}
-        self.hash_map = {}
-        self.debug = False
-        self.default()
-
-    def default(self, only_action=False):
-        self.format_tokens = {
-            'start_json': '{"action": [',   # 49643, 10845, 7862, 646
-            'start_act': '["',              # 49329
-            'sep_token': '", "',            # 1297('",'), 22
-            'sep_act': '"], ["',               # 49177
-            'end_act': '"]], "',            # 42248, 7479, 22
-            'start_text': 'text": "',       # 29015, 7862, 22
-            'end_json': '}',                 # 24303
-            'end_json_2': '"}'                 # 48805
-        }
-        if only_action:
-            self.format_tokens['end_act'] = '"]]}'
-        for token_name in self.format_tokens:
-            self.add_token(
-                token_name, self.format_tokens[token_name])
-
-    def add_token(self, token_name, value):
-        if token_name in self.token_name and self.debug:
-            print(f"---> duplicate token: {token_name}({value})!!!!!!!")
-
-        token_id = self.tokenizer(str(value), add_special_tokens=False)[
-            "input_ids"]
-        self.token_name[token_name] = {"value": value, "token_id": token_id}
-        # print(token_id)
-        hash_id = token_id[0]
-        if hash_id in self.hash_map and self.debug:
-            print(
-                f"---> conflict hash number {hash_id}: {self.hash_map[hash_id]['name']} and {token_name}")
-        self.hash_map[hash_id] = {
-            "name": token_name, "value": value, "token_id": token_id}
-
-    def get_info(self, hash_id):
-        return self.hash_map[hash_id]
-
-    def get_id(self, token_name):
-        # workaround
-        # if token_name not in self.token_name[token_name]:
-        #     self.add_token(token_name, token_name)
-        return self.token_name[token_name]["token_id"]
-
-    def get_token_value(self, token_name):
-        return self.token_name[token_name]["value"]
-
-    def token_name_is_in(self, token_name):
-        if token_name in self.token_name:
-            return True
-        return False
-
-    def hash_id_is_in(self, hash_id):
-        if hash_id in self.hash_map:
-            return True
-        return False
diff --git a/convlab/policy/genTUS/train_model.py b/convlab/policy/genTUS/train_model.py
deleted file mode 100644
index 2162417461d692514e3b27742dbfd477491fc24e..0000000000000000000000000000000000000000
--- a/convlab/policy/genTUS/train_model.py
+++ /dev/null
@@ -1,258 +0,0 @@
-import json
-import os
-import sys
-from argparse import ArgumentParser
-from datetime import datetime
-from pprint import pprint
-import numpy as np
-import torch
-import transformers
-from datasets import Dataset, load_metric
-from tqdm import tqdm
-from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
-                          BartForConditionalGeneration, BartTokenizer,
-                          DataCollatorForSeq2Seq, Seq2SeqTrainer,
-                          Seq2SeqTrainingArguments)
-
-sys.path.append(os.path.dirname(os.path.dirname(
-    os.path.dirname(os.path.abspath(__file__)))))
-
-os.environ["WANDB_DISABLED"] = "true"
-
-METRIC = load_metric("sacrebleu")
-TOKENIZER = BartTokenizer.from_pretrained("facebook/bart-base")
-TOKENIZER.add_tokens(["<?>"])
-MAX_IN_LEN = 500
-MAX_OUT_LEN = 500
-
-
-def arg_parser():
-    parser = ArgumentParser()
-    # data_name, dial_ids_order, split2ratio
-    parser.add_argument("--model-type", type=str, default="unify",
-                        help="unify or multiwoz")
-    parser.add_argument("--data-name", type=str, default="multiwoz21",
-                        help="multiwoz21, sgd, tm1, tm2, tm3, sgd+tm, or all")
-    parser.add_argument("--dial-ids-order", type=int, default=0)
-    parser.add_argument("--split2ratio", type=float, default=1)
-    parser.add_argument("--batch-size", type=int, default=16)
-    parser.add_argument("--model-checkpoint", type=str,
-                        default="facebook/bart-base")
-    return parser.parse_args()
-
-
-def gentus_compute_metrics(eval_preds):
-    preds, labels = eval_preds
-    if isinstance(preds, tuple):
-        preds = preds[0]
-    decoded_preds = TOKENIZER.batch_decode(
-        preds, skip_special_tokens=True, max_length=MAX_OUT_LEN)
-
-    # Replace -100 in the labels as we can't decode them.
-    labels = np.where(labels != -100, labels, TOKENIZER.pad_token_id)
-    decoded_labels = TOKENIZER.batch_decode(
-        labels, skip_special_tokens=True, max_length=MAX_OUT_LEN)
-
-    act, text = postprocess_text(decoded_preds, decoded_labels)
-
-    result = METRIC.compute(
-        # predictions=decoded_preds, references=decoded_labels)
-        predictions=text["preds"], references=text["labels"])
-    result = {"bleu": result["score"]}
-    f1_scores = f1_measure(pred_acts=act["preds"], label_acts=act["labels"])
-    for s in f1_scores:
-        result[s] = f1_scores[s]
-
-    result = {k: round(v, 4) for k, v in result.items()}
-    return result
-
-
-def postprocess_text(preds, labels):
-    act = {"preds": [], "labels": []}
-    text = {"preds": [], "labels": []}
-
-    for pred, label in zip(preds, labels):
-        model_output = parse_output(pred.strip())
-        label_output = parse_output(label.strip())
-        if len(label_output["text"]) < 1:
-            continue
-        act["preds"].append(model_output.get("action", []))
-        text["preds"].append(model_output.get("text", pred.strip()))
-        act["labels"].append(label_output["action"])
-        text["labels"].append([label_output["text"]])
-
-    return act, text
-
-
-def parse_output(in_str):
-    in_str = in_str.replace('<s>', '').replace('<\\s>', '')
-    try:
-        output = json.loads(in_str)
-    except:
-        # print(f"invalid action {in_str}")
-        output = {"action": [], "text": ""}
-    return output
-
-
-def f1_measure(pred_acts, label_acts):
-    result = {"precision": [], "recall": [], "f1": []}
-    for pred, label in zip(pred_acts, label_acts):
-        r = tp_fn_fp(pred, label)
-        for m in result:
-            result[m].append(r[m])
-    for m in result:
-        result[m] = sum(result[m])/len(result[m])
-
-    return result
-
-
-def tp_fn_fp(pred, label):
-    tp, fn, fp = 0.0, 0.0, 0.0
-    precision, recall, f1 = 0, 0, 0
-    for p in pred:
-        if p in label:
-            tp += 1
-        else:
-            fp += 1
-    for l in label:
-        if l not in pred:
-            fn += 1
-    if (tp+fp) > 0:
-        precision = tp / (tp+fp)
-    if (tp+fn) > 0:
-        recall = tp/(tp+fn)
-    if (precision + recall) > 0:
-        f1 = (2*precision*recall)/(precision+recall)
-
-    return {"precision": precision, "recall": recall, "f1": f1}
-
-
-class TrainerHelper:
-    def __init__(self, tokenizer, max_input_length=500, max_target_length=500):
-        print("transformers version is: ", transformers.__version__)
-        self.tokenizer = tokenizer
-        self.max_input_length = max_input_length
-        self.max_target_length = max_target_length
-        self.base_name = "convlab/policy/genTUS"
-        self.dir_name = ""
-
-    def _get_data_folder(self, model_type, data_name, dial_ids_order=0, split2ratio=1):
-        # base_name = "convlab/policy/genTUS/unify/data"
-        if model_type not in ["unify", "multiwoz"]:
-            print("Unknown model type. Currently only support unify and multiwoz")
-        self.dir_name = f"{data_name}_{dial_ids_order}_{split2ratio}"
-        return os.path.join(self.base_name, model_type, 'data', self.dir_name)
-
-    def get_model_folder(self, model_type):
-        folder_name = os.path.join(
-            self.base_name, model_type, "experiments", self.dir_name)
-        if not os.path.exists(folder_name):
-            os.makedirs(folder_name)
-        return folder_name
-
-    def parse_data(self, model_type, data_name, dial_ids_order=0, split2ratio=1):
-        data_folder = self._get_data_folder(
-            model_type, data_name, dial_ids_order, split2ratio)
-
-        raw_data = {}
-        for d_type in ["train", "validation", "test"]:
-            f_name = os.path.join(data_folder, f"{d_type}.json")
-            raw_data[d_type] = json.load(open(f_name))
-
-        tokenized_datasets = {}
-        for data_type, data in raw_data.items():
-            tokenized_datasets[data_type] = Dataset.from_dict(
-                self._preprocess(data["dialog"]))
-
-        return tokenized_datasets
-
-    def _preprocess(self, examples):
-        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-        if isinstance(examples, dict):
-            examples = [examples]
-        for example in tqdm(examples):
-            inputs = self.tokenizer(example["in"],
-                                    max_length=self.max_input_length,
-                                    truncation=True)
-
-            # Setup the tokenizer for targets
-            with self.tokenizer.as_target_tokenizer():
-                labels = self.tokenizer(example["out"],
-                                        max_length=self.max_target_length,
-                                        truncation=True)
-            for key in ["input_ids", "attention_mask"]:
-                model_inputs[key].append(inputs[key])
-            model_inputs["labels"].append(labels["input_ids"])
-
-        return model_inputs
-
-
-def train(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"):
-    tokenizer = TOKENIZER
-
-    train_helper = TrainerHelper(
-        tokenizer=tokenizer, max_input_length=max_input_length, max_target_length=max_target_length)
-    data = train_helper.parse_data(model_type=model_type,
-                                   data_name=data_name,
-                                   dial_ids_order=dial_ids_order,
-                                   split2ratio=split2ratio)
-
-    model = BartForConditionalGeneration.from_pretrained(model_checkpoint)
-    model.resize_token_embeddings(len(tokenizer))
-    fp16 = False
-    if torch.cuda.is_available():
-        fp16 = True
-
-    model_dir = os.path.join(
-        train_helper.get_model_folder(model_type),
-        f"{datetime.now().strftime('%y-%m-%d-%H-%M')}")
-
-    args = Seq2SeqTrainingArguments(
-        model_dir,
-        evaluation_strategy="epoch",
-        learning_rate=2e-5,
-        per_device_train_batch_size=batch_size,
-        per_device_eval_batch_size=batch_size,
-        weight_decay=0.01,
-        save_total_limit=2,
-        num_train_epochs=5,
-        predict_with_generate=True,
-        fp16=fp16,
-        push_to_hub=False,
-        generation_max_length=max_target_length,
-        logging_dir=os.path.join(model_dir, 'log')
-    )
-    data_collator = DataCollatorForSeq2Seq(
-        tokenizer, model=model, padding=True)
-
-    # customize this trainer
-    trainer = Seq2SeqTrainer(
-        model=model,
-        args=args,
-        train_dataset=data["train"],
-        eval_dataset=data["test"],
-        data_collator=data_collator,
-        tokenizer=tokenizer,
-        compute_metrics=gentus_compute_metrics)
-    print("start training...")
-    trainer.train()
-    print("saving model...")
-    trainer.save_model()
-
-
-def main():
-    args = arg_parser()
-    print("---> data_name", args.data_name)
-    train(model_type=args.model_type,
-          data_name=args.data_name,
-          dial_ids_order=args.dial_ids_order,
-          split2ratio=args.split2ratio,
-          batch_size=args.batch_size,
-          max_input_length=MAX_IN_LEN,
-          max_target_length=MAX_OUT_LEN,
-          model_checkpoint=args.model_checkpoint)
-
-
-if __name__ == "__main__":
-    main()
-    # sgd+tm: 46000
diff --git a/convlab/policy/genTUS/unify/Goal.py b/convlab/policy/genTUS/unify/Goal.py
deleted file mode 100644
index f257fa19d2ee5f87b9fc8d16b806705516c864a4..0000000000000000000000000000000000000000
--- a/convlab/policy/genTUS/unify/Goal.py
+++ /dev/null
@@ -1,231 +0,0 @@
-"""
-The user goal for unify data format
-"""
-import json
-from convlab.policy.tus.unify.Goal import old_goal2list
-from convlab.task.multiwoz.goal_generator import GoalGenerator
-from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal
-from convlab.util.custom_util import slot_mapping
-DEF_VAL_UNK = '?'  # Unknown
-DEF_VAL_DNC = 'dontcare'  # Do not care
-DEF_VAL_NUL = 'none'  # for none
-NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, ""]
-
-NOT_MENTIONED = "not mentioned"
-FULFILLED = "fulfilled"
-REQUESTED = "requested"
-CONFLICT = "conflict"
-
-
-class Goal:
-    """ User Goal Model Class. """
-
-    def __init__(self, goal=None):
-        """
-        create new Goal from a dialog or from goal_generator
-        Args:
-            goal: can be a list (create from a dialog), an abus goal, or none
-        """
-        self.domains = []
-        self.domain_goals = {}
-        self.status = {}
-        self.invert_slot_mapping = {v: k for k, v in slot_mapping.items()}
-        self.raw_goal = None
-
-        self._init_goal_from_data(goal)
-        self._init_status()
-
-    def __str__(self):
-        return '-----Goal-----\n' + \
-               json.dumps(self.domain_goals, indent=4) + \
-               '\n-----Goal-----'
-
-    def _init_goal_from_data(self, goal=None):
-        if not goal:
-            goal_gen = GoalGenerator()
-            old_goal = goal_gen.get_user_goal()
-            self.raw_goal = old_goal
-            goal = old_goal2list(old_goal)
-
-        elif isinstance(goal, dict):
-            self.raw_goal = goal
-            goal = old_goal2list(goal)
-
-        elif isinstance(goal, ABUS_Goal):
-            self.raw_goal = goal.domain_goals
-            goal = old_goal2list(goal.domain_goals)
-
-        # be careful of this order
-        for domain, intent, slot, value in goal:
-            if domain == "none":
-                continue
-            if domain not in self.domains:
-                self.domains.append(domain)
-                self.domain_goals[domain] = {}
-            if intent not in self.domain_goals[domain]:
-                self.domain_goals[domain][intent] = {}
-
-            if not value:
-                if intent == "request":
-                    self.domain_goals[domain][intent][slot] = DEF_VAL_UNK
-                else:
-                    print(
-                        f"unknown no value intent {domain}, {intent}, {slot}")
-            else:
-                self.domain_goals[domain][intent][slot] = value
-
-    def _init_status(self):
-        for domain, domain_goal in self.domain_goals.items():
-            if domain not in self.status:
-                self.status[domain] = {}
-            for slot_type, sub_domain_goal in domain_goal.items():
-                if slot_type not in self.status[domain]:
-                    self.status[domain][slot_type] = {}
-                for slot in sub_domain_goal:
-                    if slot not in self.status[domain][slot_type]:
-                        self.status[domain][slot_type][slot] = {}
-                    self.status[domain][slot_type][slot] = {
-                        "value": str(sub_domain_goal[slot]),
-                        "status": NOT_MENTIONED}
-
-    def get_goal_list(self, data_goal=None):
-        goal_list = []
-        if data_goal:
-            # make sure the order!!!
-            for domain, intent, slot, _ in data_goal:
-                status = self._get_status(domain, intent, slot)
-                value = self.domain_goals[domain][intent][slot]
-                goal_list.append([intent, domain, slot, value, status])
-            return goal_list
-        else:
-            for domain, domain_goal in self.domain_goals.items():
-                for intent, sub_goal in domain_goal.items():
-                    for slot, value in sub_goal.items():
-                        status = self._get_status(domain, intent, slot)
-                        goal_list.append([intent, domain, slot, value, status])
-
-        return goal_list
-
-    def _get_status(self, domain, intent, slot):
-        if domain not in self.status:
-            return NOT_MENTIONED
-        if intent not in self.status[domain]:
-            return NOT_MENTIONED
-        if slot not in self.status[domain][intent]:
-            return NOT_MENTIONED
-        return self.status[domain][intent][slot]["status"]
-
-    def task_complete(self):
-        """
-        Check that all requests have been met
-        Returns:
-            (boolean): True to accomplish.
-        """
-        for domain, domain_goal in self.status.items():
-            if domain not in self.domain_goals:
-                continue
-            for slot_type, sub_domain_goal in domain_goal.items():
-                if slot_type not in self.domain_goals[domain]:
-                    continue
-                for slot, status in sub_domain_goal.items():
-                    if slot not in self.domain_goals[domain][slot_type]:
-                        continue
-                    # for strict success, turn this on
-                    if status["status"] in [NOT_MENTIONED, CONFLICT]:
-                        if status["status"] == CONFLICT and slot in ["arrive by", "leave at"]:
-                            continue
-                        return False
-                    if "?" in status["value"]:
-                        return False
-
-        return True
-
-    # TODO change to update()?
-    def update_user_goal(self, action, char="usr"):
-        # update request and booked
-        if char == "usr":
-            self._user_action_update(action)
-        elif char == "sys":
-            self._system_action_update(action)
-        else:
-            print("!!!UNKNOWN CHAR!!!")
-
-    def _user_action_update(self, action):
-        # no need to update user goal
-        for intent, domain, slot, _ in action:
-            goal_intent = self._check_slot_and_intent(domain, slot)
-            if not goal_intent:
-                continue
-            # fulfilled by user
-            if is_inform(intent):
-                self._set_status(goal_intent, domain, slot, FULFILLED)
-            # requested by user
-            if is_request(intent):
-                self._set_status(goal_intent, domain, slot, REQUESTED)
-
-    def _system_action_update(self, action):
-        for intent, domain, slot, value in action:
-            goal_intent = self._check_slot_and_intent(domain, slot)
-            if not goal_intent:
-                continue
-            # fulfill request by system
-            if is_inform(intent) and is_request(goal_intent):
-                self._set_status(goal_intent, domain, slot, FULFILLED)
-                self._set_goal(goal_intent, domain, slot, value)
-
-            if is_inform(intent) and is_inform(goal_intent):
-                # fulfill infrom by system
-                if value == self.domain_goals[domain][goal_intent][slot]:
-                    self._set_status(goal_intent, domain, slot, FULFILLED)
-                # conflict system inform
-                else:
-                    self._set_status(goal_intent, domain, slot, CONFLICT)
-            # requested by system
-            if is_request(intent) and is_inform(goal_intent):
-                self._set_status(goal_intent, domain, slot, REQUESTED)
-
-    def _set_status(self, intent, domain, slot, status):
-        self.status[domain][intent][slot]["status"] = status
-
-    def _set_goal(self, intent, domain, slot, value):
-        # old_value = self.domain_goals[domain][intent][slot]
-        self.domain_goals[domain][intent][slot] = value
-        self.status[domain][intent][slot]["value"] = value
-        # print(
-        #     f"updating user goal {intent}-{domain}-{slot} {old_value}-> {value}")
-
-    def _check_slot_and_intent(self, domain, slot):
-        not_found = ""
-        if domain not in self.domain_goals:
-            return not_found
-        for intent in self.domain_goals[domain]:
-            if slot in self.domain_goals[domain][intent]:
-                return intent
-        return not_found
-
-
-def is_inform(intent):
-    if "inform" in intent:
-        return True
-    return False
-
-
-def is_request(intent):
-    if "request" in intent:
-        return True
-    return False
-
-
-def transform_data_act(data_action):
-    action_list = []
-    for _, dialog_act in data_action.items():
-        for act in dialog_act:
-            value = act.get("value", "")
-            if not value:
-                if "request" in act["intent"]:
-                    value = "?"
-                else:
-                    value = "none"
-            action_list.append(
-                [act["intent"], act["domain"], act["slot"], value])
-    return action_list
diff --git a/convlab/policy/genTUS/unify/build_data.py b/convlab/policy/genTUS/unify/build_data.py
deleted file mode 100644
index 50873a1d4b6ffaa8ee49c84a4b088ae56ad13554..0000000000000000000000000000000000000000
--- a/convlab/policy/genTUS/unify/build_data.py
+++ /dev/null
@@ -1,211 +0,0 @@
-import json
-import os
-import sys
-from argparse import ArgumentParser
-
-from tqdm import tqdm
-
-from convlab.policy.genTUS.unify.Goal import Goal, transform_data_act
-from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset
-
-
-sys.path.append(os.path.dirname(os.path.dirname(
-    os.path.dirname(os.path.abspath(__file__)))))
-
-
-def arg_parser():
-    parser = ArgumentParser()
-    parser.add_argument("--dataset", type=str, default="multiwoz21",
-                        help="the dataset, such as multiwoz21, sgd, tm1, tm2, and tm3.")
-    parser.add_argument("--dial-ids-order", type=int, default=0)
-    parser.add_argument("--split2ratio", type=float, default=1)
-    parser.add_argument("--random-order", action="store_true")
-    parser.add_argument("--no-status", action="store_true")
-    parser.add_argument("--add-history",  action="store_true")
-    parser.add_argument("--remove-domain", type=str, default="")
-
-    return parser.parse_args()
-
-class DataBuilder:
-    def __init__(self, dataset='multiwoz21'):
-        self.dataset = dataset
-
-    def setup_data(self,
-                   raw_data,
-                   random_order=False,
-                   no_status=False,
-                   add_history=False,
-                   remove_domain=None):
-        examples = {data_split: {"dialog": []} for data_split in raw_data}
-
-        for data_split, dialogs in raw_data.items():
-            for dialog in tqdm(dialogs, ascii=True):
-                example = self._one_dialog(dialog=dialog,
-                                           add_history=add_history,
-                                           random_order=random_order,
-                                           no_status=no_status)
-                examples[data_split]["dialog"] += example
-
-        return examples
-
-    def _one_dialog(self, dialog, add_history=True, random_order=False, no_status=False):
-        example = []
-        history = []
-
-        data_goal = self.norm_domain_goal(create_goal(dialog))
-        if not data_goal:
-            return example
-        user_goal = Goal(goal=data_goal)
-
-        for turn_id in range(0, len(dialog["turns"]), 2):
-            sys_act = self._get_sys_act(dialog, turn_id)
-
-            user_goal.update_user_goal(action=sys_act, char="sys")
-            usr_goal_str = self._user_goal_str(user_goal, data_goal, random_order, no_status)
-
-            usr_act = self.norm_domain(transform_data_act(
-                dialog["turns"][turn_id]["dialogue_acts"]))
-            user_goal.update_user_goal(action=usr_act, char="usr")
-
-            # change value "?" to "<?>"
-            usr_act = self._modify_act(usr_act)
-
-            in_str = self._dump_in_str(sys_act, usr_goal_str, history, turn_id, add_history)
-            out_str = self._dump_out_str(usr_act, dialog["turns"][turn_id]["utterance"])
-
-            history.append(usr_act)
-            if usr_act:
-                example.append({"in": in_str, "out": out_str})
-
-        return example
-
-    def _get_sys_act(self, dialog, turn_id):
-        sys_act = []
-        if turn_id > 0:
-            sys_act = self.norm_domain(transform_data_act(
-                dialog["turns"][turn_id - 1]["dialogue_acts"]))
-        return sys_act
-
-    def _user_goal_str(self, user_goal, data_goal, random_order, no_status):
-        if random_order:
-            usr_goal_str = user_goal.get_goal_list()
-        else:
-            usr_goal_str = user_goal.get_goal_list(data_goal=data_goal)
-
-        if no_status:
-            usr_goal_str = self._remove_status(usr_goal_str)
-        return usr_goal_str
-
-    def _dump_in_str(self, sys_act, usr_goal_str, history, turn_id, add_history):
-        in_str = {}
-        in_str["system"] = self._modify_act(sys_act)
-        in_str["goal"] = usr_goal_str
-        if add_history:
-            h = []
-            if history:
-                h = history[-3:]
-            in_str["history"] = h
-            in_str["turn"] = str(int(turn_id/2))
-
-        return json.dumps(in_str)
-
-    def _dump_out_str(self, usr_act, text):
-        out_str = {"action": usr_act, "text": text}
-        return json.dumps(out_str)
-
-    @staticmethod
-    def _norm_intent(intent):
-        if intent in ["inform_intent", "negate_intent", "affirm_intent", "request_alts"]:
-            return f"_{intent}"
-        return intent
-
-    def norm_domain(self, x):
-        if not x:
-            return x
-        norm_result = []
-        # print(x)
-        for intent, domain, slot, value in x:
-            if "_" in domain:
-                domain = domain.split('_')[0]
-            if not domain:
-                domain = "none"
-            if not slot:
-                slot = "none"
-            if not value:
-                if intent == "request":
-                    value = "<?>"
-                else:
-                    value = "none"
-            norm_result.append([self._norm_intent(intent), domain, slot, value])
-        return norm_result
-
-    def norm_domain_goal(self, x):
-        if not x:
-            return x
-        norm_result = []
-        # take care of the order!
-        for domain, intent, slot, value in x:
-            if "_" in domain:
-                domain = domain.split('_')[0]
-            if not domain:
-                domain = "none"
-            if not slot:
-                slot = "none"
-            if not value:
-                if intent == "request":
-                    value = "<?>"
-                else:
-                    value = "none"
-            norm_result.append([domain, self._norm_intent(intent), slot, value])
-        return norm_result
-
-    @staticmethod
-    def _remove_status(goal_list):
-        new_list = [[goal[0], goal[1], goal[2], goal[3]]
-                    for goal in goal_list]
-        return new_list
-
-    @staticmethod
-    def _modify_act(act):
-        new_act = []
-        for i, d, s, value in act:
-            if value == "?":
-                new_act.append([i, d, s, "<?>"])
-            else:
-                new_act.append([i, d, s, value])
-        return new_act
-
-
-if __name__ == "__main__":
-    args = arg_parser()
-
-    base_name = "convlab/policy/genTUS/unify/data"
-    dir_name = f"{args.dataset}_{args.dial_ids_order}_{args.split2ratio}"
-    folder_name = os.path.join(base_name, dir_name)
-    remove_domain = args.remove_domain
-
-    if not os.path.exists(folder_name):
-        os.makedirs(folder_name)
-
-    dataset = load_experiment_dataset(
-        data_name=args.dataset,
-        dial_ids_order=args.dial_ids_order,
-        split2ratio=args.split2ratio)
-    data_builder = DataBuilder(dataset=args.dataset)
-    data = data_builder.setup_data(
-        raw_data=dataset,
-        random_order=args.random_order,
-        no_status=args.no_status,
-        add_history=args.add_history,
-        remove_domain=remove_domain)
-
-    for data_type in data:
-        if remove_domain:
-            file_name = os.path.join(
-                folder_name,
-                f"no{remove_domain}_{data_type}.json")
-        else:
-            file_name = os.path.join(
-                folder_name,
-                f"{data_type}.json")
-        json.dump(data[data_type], open(file_name, 'w'), indent=2)
diff --git a/convlab/policy/genTUS/unify/knowledge_graph.py b/convlab/policy/genTUS/unify/knowledge_graph.py
deleted file mode 100644
index 68af13e481fe4799dfc2a6f3763b526611eabd9c..0000000000000000000000000000000000000000
--- a/convlab/policy/genTUS/unify/knowledge_graph.py
+++ /dev/null
@@ -1,252 +0,0 @@
-import json
-from random import choices
-
-from convlab.policy.genTUS.token_map import tokenMap
-
-from transformers import BartTokenizer
-
-DEBUG = False
-DATASET = "unify"
-
-
-class KnowledgeGraph:
-    def __init__(self, tokenizer: BartTokenizer, ontology_file=None, dataset="multiwoz21"):
-        print("dataset", dataset)
-        self.debug = DEBUG
-        self.tokenizer = tokenizer
-
-        if "multiwoz" in dataset:
-            self.domain_intent = ["inform", "request"]
-            self.general_intent = ["thank", "bye"]
-        # use sgd dataset intents as default
-        else:
-            self.domain_intent = ["_inform_intent",
-                                  "_negate_intent",
-                                  "_affirm_intent",
-                                  "inform",
-                                  "request",
-                                  "affirm",
-                                  "negate",
-                                  "select",
-                                  "_request_alts"]
-            self.general_intent = ["thank_you", "goodbye"]
-
-        self.general_domain = "none"
-        self.kg_map = {"intent": tokenMap(tokenizer=self.tokenizer)}
-
-        for intent in self.domain_intent + self.general_intent:
-            self.kg_map["intent"].add_token(intent, intent)
-
-        self.init()
-
-    def init(self):
-        for map_type in ["domain", "slot", "value"]:
-            self.kg_map[map_type] = tokenMap(tokenizer=self.tokenizer)
-        self.add_token("<?>", "value")
-
-    def parse_input(self, in_str):
-        self.init()
-        inputs = json.loads(in_str)
-        self.sys_act = inputs["system"]
-        self.user_goal = {}
-        self._add_none_domain()
-        for intent, domain, slot, value, _ in inputs["goal"]:
-            self._update_user_goal(intent, domain, slot, value, source="goal")
-
-        for intent, domain, slot, value in self.sys_act:
-            self._update_user_goal(intent, domain, slot, value, source="sys")
-
-    def _add_none_domain(self):
-        self.user_goal["none"] = {"none": "none"}
-        # add slot
-        self.add_token("none", "domain")
-        self.add_token("none", "slot")
-        self.add_token("none", "value")
-
-    def _update_user_goal(self, intent, domain, slot, value, source="goal"):
-
-        if value == "?":
-            value = "<?>"
-
-        if intent == "request" and source == "sys":
-            value = "dontcare"  # user can "dontcare" system request
-
-        if source == "sys" and intent != "request":
-            return
-
-        if domain not in self.user_goal:
-            self.user_goal[domain] = {}
-            self.user_goal[domain]["none"] = ["none"]
-            self.add_token(domain, "domain")
-            self.add_token("none", "slot")
-            self.add_token("none", "value")
-
-        if slot not in self.user_goal[domain]:
-            self.user_goal[domain][slot] = []
-            self.add_token(domain, "slot")
-
-        if value not in self.user_goal[domain][slot]:
-            value = json.dumps(str(value))[1:-1]
-            self.user_goal[domain][slot].append(value)
-            value = value.replace('"', "'")
-            self.add_token(value, "value")
-
-    def add_token(self, token_name, map_type):
-        if map_type == "value":
-            token_name = token_name.replace('"', "'")
-        if not self.kg_map[map_type].token_name_is_in(token_name):
-            self.kg_map[map_type].add_token(token_name, token_name)
-
-    def _get_max_score(self, outputs, candidate_list, map_type):
-        score = {}
-        if not candidate_list:
-            print(f"ERROR: empty candidate list for {map_type}")
-            score[1] = {"token_id": self._get_token_id(
-                "none"), "token_name": "none"}
-
-        for x in candidate_list:
-            hash_id = self._get_token_id(x)[0]
-            s = outputs[:, hash_id].item()
-            score[s] = {"token_id": self._get_token_id(x),
-                        "token_name": x}
-        return score
-
-    def _select(self, score, mode="max"):
-        probs = [s for s in score]
-        if mode == "max":
-            s = max(probs)
-        elif mode == "sample":
-            s = choices(probs, weights=probs, k=1)
-            s = s[0]
-
-        else:
-            print("unknown select mode")
-
-        return s
-
-    def _get_max_domain_token(self, outputs, candidates, map_type, mode="max"):
-        score = self._get_max_score(outputs, candidates, map_type)
-        s = self._select(score, mode)
-        token_id = score[s]["token_id"]
-        token_name = score[s]["token_name"]
-
-        return {"token_id": token_id, "token_name": token_name}
-
-    def candidate(self, candidate_type, **kwargs):
-        if "intent" in kwargs:
-            intent = kwargs["intent"]
-        if candidate_type == "intent":
-            allow_general_intent = kwargs.get("allow_general_intent", True)
-            if allow_general_intent:
-                return self.domain_intent + self.general_intent
-            else:
-                return self.domain_intent
-        elif candidate_type == "domain":
-            if intent in self.general_intent:
-                return [self.general_domain]
-            else:
-                return [d for d in self.user_goal]
-        elif candidate_type == "slot":
-            if intent in self.general_intent:
-                return ["none"]
-            else:
-                return self._filter_slot(intent, kwargs["domain"], kwargs["is_mentioned"])
-        else:
-            if intent in self.general_intent:
-                return ["none"]
-            elif intent.lower() == "request":
-                return ["<?>"]
-            else:
-                return self._filter_value(intent, kwargs["domain"], kwargs["slot"])
-
-    def get_intent(self, outputs, mode="max", allow_general_intent=True):
-        # return intent, token_id_list
-        # TODO request?
-        canidate_list = self.candidate(
-            "intent", allow_general_intent=allow_general_intent)
-        score = self._get_max_score(outputs, canidate_list, "intent")
-        s = self._select(score, mode)
-
-        return score[s]
-
-    def get_domain(self, outputs, intent, mode="max"):
-        if intent in self.general_intent:
-            token_name = self.general_domain
-            token_id = self.tokenizer(token_name, add_special_tokens=False)
-            token_map = {"token_id": token_id['input_ids'],
-                         "token_name": token_name}
-
-        elif intent in self.domain_intent:
-            # [d for d in self.user_goal]
-            domain_list = self.candidate("domain", intent=intent)
-            token_map = self._get_max_domain_token(
-                outputs=outputs, candidates=domain_list, map_type="domain", mode=mode)
-        else:
-            if self.debug:
-                print("unknown intent", intent)
-
-        return token_map
-
-    def get_slot(self, outputs, intent, domain, mode="max", is_mentioned=False):
-        if intent in self.general_intent:
-            token_name = "none"
-            token_id = self.tokenizer(token_name, add_special_tokens=False)
-            token_map = {"token_id": token_id['input_ids'],
-                         "token_name": token_name}
-
-        elif intent in self.domain_intent:
-            slot_list = self.candidate(
-                candidate_type="slot", intent=intent, domain=domain, is_mentioned=is_mentioned)
-            token_map = self._get_max_domain_token(
-                outputs=outputs, candidates=slot_list, map_type="slot", mode=mode)
-
-        return token_map
-
-    def get_value(self, outputs, intent, domain, slot, mode="max"):
-        if intent in self.general_intent or slot.lower() == "none":
-            token_name = "none"
-            token_id = self.tokenizer(token_name, add_special_tokens=False)
-            token_map = {"token_id": token_id['input_ids'],
-                         "token_name": token_name}
-
-        elif intent.lower() == "request":
-            token_name = "<?>"
-            token_id = self.tokenizer(token_name, add_special_tokens=False)
-            token_map = {"token_id": token_id['input_ids'],
-                         "token_name": token_name}
-
-        elif intent in self.domain_intent:
-            # TODO should not none ?
-            # value_list = [v for v in self.user_goal[domain][slot]]
-            value_list = self.candidate(
-                candidate_type="value", intent=intent, domain=domain, slot=slot)
-
-            token_map = self._get_max_domain_token(
-                outputs=outputs, candidates=value_list, map_type="value", mode=mode)
-
-        return token_map
-
-    def _filter_slot(self, intent, domain, is_mentioned=True):
-        slot_list = []
-        for slot in self.user_goal[domain]:
-            value_list = self._filter_value(intent, domain, slot)
-            if len(value_list) > 0:
-                slot_list.append(slot)
-        if not is_mentioned and intent.lower() != "request":
-            slot_list.append("none")
-        return slot_list
-
-    def _filter_value(self, intent, domain, slot):
-        value_list = [v for v in self.user_goal[domain][slot]]
-        if "none" in value_list:
-            value_list.remove("none")
-        if intent.lower() != "request":
-            if "?" in value_list:
-                value_list.remove("?")
-            if "<?>" in value_list:
-                value_list.remove("<?>")
-        # print(f"{intent}-{domain}-{slot}= {value_list}")
-        return value_list
-
-    def _get_token_id(self, token):
-        return self.tokenizer(token, add_special_tokens=False)["input_ids"]
diff --git a/convlab/policy/genTUS/utils.py b/convlab/policy/genTUS/utils.py
deleted file mode 100644
index 39c822dd35f985790d53a2834e8b6fe437864f24..0000000000000000000000000000000000000000
--- a/convlab/policy/genTUS/utils.py
+++ /dev/null
@@ -1,5 +0,0 @@
-import torch
-
-
-def append_tokens(tokens, new_token, device):
-    return torch.cat((tokens, torch.tensor([new_token]).to(device)), dim=1)
diff --git a/convlab/policy/ppo/semantic_level_config.json b/convlab/policy/ppo/semantic_level_config.json
index 98aad8ddbcf50060c0d38a640f1ac1bb2d7a5c7e..a4a24598d7eac283bcc05d21b0f695cad90b6433 100644
--- a/convlab/policy/ppo/semantic_level_config.json
+++ b/convlab/policy/ppo/semantic_level_config.json
@@ -1,6 +1,6 @@
 {
 	"model": {
-		"load_path": "convlab/policy/ppo/pretrained_models/supervised",
+		"load_path": "",
 		"use_pretrained_initialisation": false,
 		"pretrained_load_path": "",
 		"batchsz": 1000,
@@ -40,4 +40,4 @@
 		}
 	},
 	"usr_nlg": {}
-}
+}
\ No newline at end of file
diff --git a/convlab/policy/ppo/semantic_level_config_eval.json b/convlab/policy/ppo/semantic_level_config_eval.json
deleted file mode 100644
index be03421a84f5cd48f04ef255901942225a84e6ac..0000000000000000000000000000000000000000
--- a/convlab/policy/ppo/semantic_level_config_eval.json
+++ /dev/null
@@ -1,43 +0,0 @@
-{
-	"model": {
-		"load_path": "",
-		"use_pretrained_initialisation": false,
-		"pretrained_load_path": "",
-		"batchsz": 1000,
-		"seed": 0,
-		"epoch": 0,
-		"eval_frequency": 5,
-		"process_num": 4,
-		"sys_semantic_to_usr": false,
-		"num_eval_dialogues": 500
-	},
-	"vectorizer_sys": {
-		"uncertainty_vector_mul": {
-			"class_path": "convlab.policy.vector.vector_binary.VectorBinary",
-			"ini_params": {
-				"use_masking": true,
-				"manually_add_entity_names": false,
-				"seed": 0
-			}
-		}
-	},
-	"nlu_sys": {},
-	"dst_sys": {
-		"RuleDST": {
-			"class_path": "convlab.dst.rule.multiwoz.dst.RuleDST",
-			"ini_params": {}
-		}
-	},
-	"sys_nlg": {},
-	"nlu_usr": {},
-	"dst_usr": {},
-	"policy_usr": {
-		"RulePolicy": {
-			"class_path": "convlab.policy.rule.multiwoz.RulePolicy",
-			"ini_params": {
-				"character": "usr"
-			}
-		}
-	},
-	"usr_nlg": {}
-}
diff --git a/convlab/policy/ppo/setsumbt_end_baseline_config_eval.json b/convlab/policy/ppo/setsumbt_end_baseline_config_eval.json
deleted file mode 100644
index b4fd62aa810a67e1a2219f70ba630bedab249431..0000000000000000000000000000000000000000
--- a/convlab/policy/ppo/setsumbt_end_baseline_config_eval.json
+++ /dev/null
@@ -1,65 +0,0 @@
-{
-	"model": {
-		"load_path": "supervised",
-		"pretrained_load_path": "",
-		"use_pretrained_initialisation": false,
-		"batchsz": 1000,
-		"seed": 0,
-		"epoch": 0,
-		"eval_frequency": 5,
-		"process_num": 2,
-		"num_eval_dialogues": 500,
-		"sys_semantic_to_usr": false
-	},
-	"vectorizer_sys": {
-		"uncertainty_vector_mul": {
-			"class_path": "convlab.policy.vector.vector_uncertainty.VectorUncertainty",
-			"ini_params": {
-				"use_masking": false,
-				"manually_add_entity_names": false,
-			        "seed": 0,
-			        "use_confidence_scores": true,
-			        "use_state_total_uncertainty": true
-			}
-		}
-	},
-	"nlu_sys": {},
-	"dst_sys": {
-		"setsumbt-mul": {
-			"class_path": "convlab.dst.setsumbt.SetSUMBTTracker",
-			"ini_params": {
-			        "model_path": "https://zenodo.org/record/5497808/files/setsumbt_end.zip",
-			        "return_confidence_scores": true,
-			        "return_belief_state_entropy": true
-			}
-		}
-	},
-	"sys_nlg": {
-		"TemplateNLG": {
-			"class_path": "convlab.nlg.template.multiwoz.TemplateNLG",
-			"ini_params": {
-				"is_user": false
-			}
-		}
-	},
-	"nlu_usr": {},
-	"dst_usr": {},
-	"policy_usr": {
-		"RulePolicy": {
-			"class_path": "convlab.policy.rule.multiwoz.RulePolicy",
-			"ini_params": {
-				"character": "usr"
-			}
-		}
-	},
-	"usr_nlg": {
-		"TemplateNLG": {
-			"class_path": "convlab.nlg.template.multiwoz.TemplateNLG",
-			"ini_params": {
-				"is_user": true,
-				"label_noise": 0.0,
-				"text_noise": 0.0
-			}
-		}
-	}
-}
diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py
index 1d453c9ac395d981a0adebe912be5266ec6bd0e1..9f00d3ddfd635e52c0084b0989ae46763aec4de1 100755
--- a/convlab/policy/ppo/train.py
+++ b/convlab/policy/ppo/train.py
@@ -34,6 +34,7 @@ except RuntimeError:
 
 
 def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
+
     """
     This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
     processes.
@@ -108,6 +109,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
 
 
 def sample(env, policy, batchsz, process_num, seed):
+
     """
     Given batchsz number of task, the batchsz will be splited equally to each processes
     and when processes return, it merge all data and return
@@ -135,8 +137,7 @@ def sample(env, policy, batchsz, process_num, seed):
     evt = mp.Event()
     processes = []
     for i in range(process_num):
-        process_args = (i, queue, evt, env, policy,
-                        process_batchsz, train_seeds[i])
+        process_args = (i, queue, evt, env, policy, process_batchsz, train_seeds[i])
         processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stoped.
@@ -189,41 +190,31 @@ if __name__ == '__main__':
                         help="Set level for logger")
     parser.add_argument("--save_eval_dials", type=bool, default=False,
                         help="Flag for saving dialogue_info during evaluation")
-    parser.add_argument("--save_path", type=str, default=None,
-                        help="Custom save path other than the path of this script")
 
     path = parser.parse_args().path
     seed = parser.parse_args().seed
     mode = parser.parse_args().mode
     save_eval = parser.parse_args().save_eval_dials
-    custom_save_path = parser.parse_args().save_path
 
-    if custom_save_path:
-        logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
-            init_logging(custom_save_path, mode)
-    else:
-        logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
-            init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
+    logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
+        init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
 
     args = [('model', 'seed', seed)] if seed is not None else list()
 
     environment_config = load_config_file(path)
-    save_config(vars(parser.parse_args()),
-                environment_config, config_save_path)
+    save_config(vars(parser.parse_args()), environment_config, config_save_path)
 
     conf = get_config(path, args)
     seed = conf['model']['seed']
     logging.info('Train seed is ' + str(seed))
     set_seed(seed)
 
-    policy_sys = PPO(True, seed=conf['model']['seed'],
-                     vectorizer=conf['vectorizer_sys_activated'])
+    policy_sys = PPO(True, seed=conf['model']['seed'], vectorizer=conf['vectorizer_sys_activated'])
 
     # Load model
     if conf['model']['use_pretrained_initialisation']:
         logging.info("Loading supervised model checkpoint.")
-        policy_sys.load_from_pretrained(
-            conf['model'].get('pretrained_load_path', ""))
+        policy_sys.load_from_pretrained(conf['model'].get('pretrained_load_path', ""))
     elif conf['model']['load_path']:
         try:
             policy_sys.load(conf['model']['load_path'])
@@ -251,8 +242,7 @@ if __name__ == '__main__':
 
     logging.info(f"Evaluating at start - {time_now}" + '-'*60)
     time_now = time.time()
-    eval_dict = eval_policy(conf, policy_sys, env, sess,
-                            save_eval, log_save_path)
+    eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path)
     logging.info(f"Finished evaluating, time spent: {time.time() - time_now}")
 
     for key in eval_dict:
@@ -267,15 +257,13 @@ if __name__ == '__main__':
     for i in range(conf['model']['epoch']):
         idx = i + 1
         # print("Epoch :{}".format(str(idx)))
-        update(env, policy_sys, conf['model']['batchsz'],
-               idx, conf['model']['process_num'], seed=seed)
+        update(env, policy_sys, conf['model']['batchsz'], idx, conf['model']['process_num'], seed=seed)
 
         if idx % conf['model']['eval_frequency'] == 0 and idx != 0:
             time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
             logging.info(f"Evaluating at Epoch: {idx} - {time_now}" + '-'*60)
 
-            eval_dict = eval_policy(
-                conf, policy_sys, env, sess, save_eval, log_save_path)
+            eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path)
 
             best_complete_rate, best_success_rate, best_return = \
                 save_best(policy_sys, best_complete_rate, best_success_rate, best_return,
@@ -283,8 +271,7 @@ if __name__ == '__main__':
                           eval_dict["avg_return"], save_path)
             policy_sys.save(save_path, "last")
             for key in eval_dict:
-                tb_writer.add_scalar(
-                    key, eval_dict[key], idx * conf['model']['batchsz'])
+                tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['batchsz'])
 
     logging.info("End of Training: " +
                  time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
@@ -293,9 +280,5 @@ if __name__ == '__main__':
     f.write(str(datetime.now() - begin_time))
     f.close()
 
-    if custom_save_path:
-        move_finished_training(dir_path, os.path.join(
-            custom_save_path, "finished_experiments"))
-    else:
-        move_finished_training(dir_path, os.path.join(
-            os.path.dirname(os.path.abspath(__file__)), "finished_experiments"))
+    move_finished_training(dir_path, os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "finished_experiments"))
diff --git a/convlab/policy/ppo/trippy_eval.json b/convlab/policy/ppo/trippy_eval.json
deleted file mode 100644
index d7f6ff5dba432bf519fc1bc1e13c425d8f897dfb..0000000000000000000000000000000000000000
--- a/convlab/policy/ppo/trippy_eval.json
+++ /dev/null
@@ -1,47 +0,0 @@
-{
-	"model": {
-		"load_path": "finished_experiments/experiment_2022-09-08-14-53-25/save/best_ppo",
-		"pretrained_load_path": "",
-		"use_pretrained_initialisation": false,
-		"batchsz": 1000,
-		"seed": 0,
-		"epoch": 0,
-		"eval_frequency": 5,
-		"process_num": 1,
-		"num_eval_dialogues": 500,
-		"sys_semantic_to_usr": false
-	},
-        "vectorizer_sys": {
-	        "uncertainty_vector_mul": {
-		        "class_path": "convlab.policy.vector.vector_binary.VectorBinary",
-		        "ini_params": {
-			        "use_masking": true,
-			        "manually_add_entity_names": true,
-			        "seed": 0
-			}
-		}
-	},
-	"nlu_sys": {},
-	"dst_sys": {
-		"TripPy": {
-			"class_path": "convlab.dst.trippy.multiwoz.TRIPPY",
-		        "ini_params": {
-			        "model_type": "roberta",
-			        "model_path": "/gpfs/project/heckmi/trippy/multiwoz21/trippy.convlab3/results.42/checkpoint-28392",
-			        "nlu_path": "/home/heckmi/models/bert_multiwoz_all_context.zip"
-			}
-		}
-	},
-        "sys_nlg": {},
-	"nlu_usr": {},
-	"dst_usr": {},
-	"policy_usr": {
-		"RulePolicy": {
-			"class_path": "convlab.policy.rule.multiwoz.RulePolicy",
-			"ini_params": {
-				"character": "usr"
-			}
-		}
-	},
-	"usr_nlg": {}
-}
diff --git a/convlab/policy/ppo/trippy_train.json b/convlab/policy/ppo/trippy_train.json
deleted file mode 100644
index 248811acb2c22c18dd66b54787740968067a703c..0000000000000000000000000000000000000000
--- a/convlab/policy/ppo/trippy_train.json
+++ /dev/null
@@ -1,47 +0,0 @@
-{
-	"model": {
-		"load_path": "/home/heckmi/models/supervised",
-		"pretrained_load_path": "",
-		"use_pretrained_initialisation": false,
-		"batchsz": 1000,
-		"seed": 0,
-		"epoch": 50,
-		"eval_frequency": 5,
-		"process_num": 1,
-	        "num_eval_dialogues": 500,
-		"sys_semantic_to_usr": false
-	},
-        "vectorizer_sys": {
-	        "uncertainty_vector_mul": {
-		        "class_path": "convlab.policy.vector.vector_binary.VectorBinary",
-		        "ini_params": {
-			        "use_masking": true,
-			        "manually_add_entity_names": true,
-			        "seed": 0
-			}
-		}
-	},
-	"nlu_sys": {},
-	"dst_sys": {
-		"TripPy": {
-			"class_path": "convlab.dst.trippy.multiwoz.TRIPPY",
-		        "ini_params": {
-			        "model_type": "roberta",
-			        "model_path": "/gpfs/project/heckmi/trippy/multiwoz21/trippy.convlab3/results.42/checkpoint-28392",
-			        "nlu_path": "/home/heckmi/models/bert_multiwoz_all_context.zip"
-			}
-		}
-	},
-        "sys_nlg": {},
-	"nlu_usr": {},
-	"dst_usr": {},
-	"policy_usr": {
-		"RulePolicy": {
-			"class_path": "convlab.policy.rule.multiwoz.RulePolicy",
-			"ini_params": {
-				"character": "usr"
-			}
-		}
-	},
-	"usr_nlg": {}
-}
diff --git a/convlab/policy/ppo/tus_semantic_level_config.json b/convlab/policy/ppo/tus_semantic_level_config.json
index 445fa9278552a9353529d8eca6dc8b2c6c169fd9..cc0216122dd473ab270fe52617621732d5b2034c 100644
--- a/convlab/policy/ppo/tus_semantic_level_config.json
+++ b/convlab/policy/ppo/tus_semantic_level_config.json
@@ -1,11 +1,11 @@
 {
 	"model": {
-		"load_path": "convlab/policy/ppo/pretrained_models/supervised",
+		"load_path": "convlab/policy/mle/experiments/experiment_2022-05-23-14-08-43/save/supervised",
 		"use_pretrained_initialisation": false,
 		"pretrained_load_path": "",
 		"batchsz": 1000,
 		"seed": 0,
-		"epoch": 200,
+		"epoch": 50,
 		"eval_frequency": 5,
 		"process_num": 4,
 		"sys_semantic_to_usr": false,
@@ -35,9 +35,9 @@
 		"TUSPolicy": {
 			"class_path": "convlab.policy.tus.unify.TUS.UserPolicy",
 			"ini_params": {
-				"config": "convlab/policy/tus/unify/exp/multiwoz.json"
+				"config": "convlab/policy/tus/unify/exp/all.json"
 			}
 		}
 	},
 	"usr_nlg": {}
-}
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/Goal.py b/convlab/policy/tus/unify/Goal.py
index 610469bb4246cf6c16ef4271c89827cab6421437..82a6e7a9799b4537495cf35e57f6f52711e78af8 100644
--- a/convlab/policy/tus/unify/Goal.py
+++ b/convlab/policy/tus/unify/Goal.py
@@ -1,9 +1,6 @@
 import time
 import json
-from convlab.policy.tus.unify.util import split_slot_name, slot_name_map
-from convlab.util.custom_util import slot_mapping
-
-from random import sample, shuffle
+from convlab.policy.tus.unify.util import split_slot_name
 from pprint import pprint
 DEF_VAL_UNK = '?'  # Unknown
 DEF_VAL_DNC = 'dontcare'  # Do not care
@@ -30,35 +27,6 @@ def isTimeFormat(input):
         return False
 
 
-def old_goal2list(goal: dict, reorder=False) -> list:
-    goal_list = []
-    for domain in goal:
-        for slot_type in ['info', 'book', 'reqt']:
-            if slot_type not in goal[domain]:
-                continue
-            temp = []
-            for slot in goal[domain][slot_type]:
-                s = slot
-                if slot in slot_name_map:
-                    s = slot_name_map[slot]
-                elif slot in slot_name_map[domain]:
-                    s = slot_name_map[domain][slot]
-                # domain, intent, slot, value
-                if slot_type in ['info', 'book']:
-                    i = "inform"
-                    v = goal[domain][slot_type][slot]
-                else:
-                    i = "request"
-                    v = DEF_VAL_UNK
-                s = slot_mapping.get(s, s)
-                temp.append([domain, i, s, v])
-            shuffle(temp)
-            goal_list = goal_list + temp
-    # shuffle_goal = goal_list[:1] + sample(goal_list[1:], len(goal_list)-1)
-    # return shuffle_goal
-    return goal_list
-
-
 class Goal(object):
     """ User Goal Model Class. """
 
@@ -133,7 +101,6 @@ class Goal(object):
                     if self.domain_goals[domain]["reqt"][slot] == DEF_VAL_UNK:
                         # print(f"not fulfilled request{domain}-{slot}")
                         return False
-
         return True
 
     def init_local_id(self):
@@ -202,10 +169,6 @@ class Goal(object):
 
     def _update_status(self, action: list, char: str):
         for intent, domain, slot, value in action:
-            if slot == "arrive by":
-                slot = "arriveBy"
-            elif slot == "leave at":
-                slot = "leaveAt"
             if domain not in self.status:
                 self.status[domain] = {}
             # update info
@@ -217,10 +180,6 @@ class Goal(object):
     def _update_goal(self, action: list, char: str):
         # update requt slots in goal
         for intent, domain, slot, value in action:
-            if slot == "arrive by":
-                slot = "arriveBy"
-            elif slot == "leave at":
-                slot = "leaveAt"
             if "info" not in intent:
                 continue
             if self._check_update_request(domain, slot) and value != "?":
diff --git a/convlab/policy/tus/unify/TUS.py b/convlab/policy/tus/unify/TUS.py
index de98e6d1dfea7d4486d56a60dee5662f745c2df9..09c50672fb889ffd3e965ec0e90d2dee30b2c360 100644
--- a/convlab/policy/tus/unify/TUS.py
+++ b/convlab/policy/tus/unify/TUS.py
@@ -5,18 +5,19 @@ from copy import deepcopy
 
 import torch
 from convlab.policy.policy import Policy
+from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import (
+    act_dict_to_flat_tuple, unified_format)
 from convlab.policy.tus.multiwoz.transformer import TransformerActionPrediction
 from convlab.policy.tus.unify.Goal import Goal
 from convlab.policy.tus.unify.usermanager import BinaryFeature
-from convlab.policy.tus.unify.util import create_goal, split_slot_name
+from convlab.policy.tus.unify.util import (create_goal, int2onehot,
+                                           metadata2state, parse_dialogue_act,
+                                           parse_user_goal, split_slot_name)
 from convlab.util import (load_dataset,
                           relative_import_module_from_unified_datasets)
 from convlab.util.custom_util import model_downloader
-from convlab.task.multiwoz.goal_generator import GoalGenerator
-from convlab.policy.tus.unify.Goal import old_goal2list
-from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal
-
-
+from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA
+from pprint import pprint
 reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets(
     'multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value'])
 
@@ -50,7 +51,7 @@ class UserActionPolicy(Policy):
         self.user = TransformerActionPrediction(self.config).to(device=DEVICE)
         if pretrain:
             model_path = os.path.join(
-                self.config["model_dir"], "model-non-zero")# self.config["model_name"])
+                self.config["model_dir"], self.config["model_name"])
             print(f"loading model from {model_path}...")
             self.load(model_path)
         self.user.eval()
@@ -60,7 +61,6 @@ class UserActionPolicy(Policy):
         self.reward = {"success": 40,
                        "fail": -20}
         self.sys_acts = []
-        self.goal_gen = GoalGenerator()
 
     def _no_offer(self, system_in):
         for intent, domain, slot, value in system_in:
@@ -127,14 +127,13 @@ class UserActionPolicy(Policy):
         self.topic = 'NONE'
         remove_domain = "police"  # remove police domain in inference
 
+        # if not goal:
+        #     self.new_goal(remove_domain=remove_domain)
+        # else:
+        #     self.read_goal(goal)
         if not goal:
-            old_goal = self.goal_gen.get_user_goal()
-            goal_list = old_goal2list(old_goal)
-            goal = Goal(goal_list)
-
-        elif type(goal) == ABUS_Goal:
-            goal_list = old_goal2list(goal.domain_goals)
-            goal = Goal(goal_list)
+            data = load_dataset(self.dataset, 0)
+            goal = Goal(create_goal(data["test"][0]))
 
         self.read_goal(goal)
         self.feat_handler.initFeatureHandeler(self.goal)
@@ -156,15 +155,15 @@ class UserActionPolicy(Policy):
         else:
             self.goal = Goal(goal=data_goal)
 
-    # def new_goal(self, remove_domain="police", domain_len=None):
-    #     keep_generate_goal = True
-    #     while keep_generate_goal:
-    #         self.goal = Goal(goal_generator=self.goal_gen)
-    #         if (domain_len and len(self.goal.domains) != domain_len) or \
-    #                 (remove_domain and remove_domain in self.goal.domains):
-    #             keep_generate_goal = True
-    #         else:
-    #             keep_generate_goal = False
+    def new_goal(self, remove_domain="police", domain_len=None):
+        keep_generate_goal = True
+        while keep_generate_goal:
+            self.goal = Goal(goal_generator=self.goal_gen)
+            if (domain_len and len(self.goal.domains) != domain_len) or \
+                    (remove_domain and remove_domain in self.goal.domains):
+                keep_generate_goal = True
+            else:
+                keep_generate_goal = False
 
     def load(self, model_path=None):
         self.user.load_state_dict(torch.load(model_path, map_location=DEVICE))
@@ -404,32 +403,16 @@ class UserPolicy(Policy):
             self.config = json.load(open(config))
         else:
             self.config = config
-        self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}/multiwoz'
+        self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}'
         if not os.path.exists(self.config["model_dir"]):
             # os.mkdir(self.config["model_dir"])
             model_downloader(os.path.dirname(self.config["model_dir"]),
                              "https://zenodo.org/record/5779832/files/default.zip")
-        self.slot2dbattr = {
-            'open hours': 'openhours',
-            'price range': 'pricerange',
-            'arrive by': 'arriveBy',
-            'leave at': 'leaveAt',
-            'train id': 'trainID'
-        }
-        self.dbattr2slot = {}
-        for k,v in self.slot2dbattr.items():
-            self.dbattr2slot[v] = k
 
         self.policy = UserActionPolicy(self.config)
 
     def predict(self, state):
-        raw_act = self.policy.predict(state)
-        act = []
-        for intent, domain, slot, value in raw_act:
-            if slot in self.dbattr2slot:
-                slot = self.dbattr2slot[slot]
-            act.append([intent, domain, slot, value])
-        return act
+        return self.policy.predict(state)
 
     def init_session(self, goal=None):
         self.policy.init_session(goal)
@@ -441,6 +424,13 @@ class UserPolicy(Policy):
         return self.policy.get_reward()
 
     def get_goal(self):
+        slot2dbattr = {
+            'open hours': 'openhours',
+            'price range': 'pricerange',
+            'arrive by': 'arriveBy',
+            'leave at': 'leaveAt',
+            'train id': 'trainID'
+        }
         if hasattr(self.policy, 'get_goal'):
             # workaround: convert goal to old format
             multiwoz_goal = {}
@@ -459,8 +449,8 @@ class UserPolicy(Policy):
                                 multiwoz_goal[domain]["book"] = {}
                             norm_slot = slot.split(' ')[-1]
                             multiwoz_goal[domain]["book"][norm_slot] = value
-                        elif slot in self.slot2dbattr:
-                            norm_slot = self.slot2dbattr[slot]
+                        elif slot in slot2dbattr:
+                            norm_slot = slot2dbattr[slot]
                             multiwoz_goal[domain][slot_type][norm_slot] = value
                         else:
                             multiwoz_goal[domain][slot_type][slot] = value
diff --git a/convlab/policy/tus/unify/usermanager.py b/convlab/policy/tus/unify/usermanager.py
index 3192d3d578ca3698424b16f47933d6863208f77c..640da7b9ee01414f04700e6aaa2976f9c916914f 100644
--- a/convlab/policy/tus/unify/usermanager.py
+++ b/convlab/policy/tus/unify/usermanager.py
@@ -97,7 +97,7 @@ class TUSDataManager(Dataset):
                     action_list, user_goal, cur_state, usr_act)
                 domain_label = self.feature_handler.domain_label(
                     user_goal, usr_act)
-                # pre_state = user_goal.update(action=usr_act, char="user") # trick?
+                pre_state = user_goal.update(action=usr_act, char="user")
                 feature["id"].append(dialog["dialogue_id"])
                 feature["input"].append(input_feature)
                 feature["mask"].append(mask)
diff --git a/convlab/policy/tus/unify/util.py b/convlab/policy/tus/unify/util.py
index d65f72a06e181e66bfe0d7ac0c60f0c03a56ad43..1978f489534f74839b11dee333e232cbc601f270 100644
--- a/convlab/policy/tus/unify/util.py
+++ b/convlab/policy/tus/unify/util.py
@@ -1,49 +1,9 @@
 from convlab.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
-from convlab.util import load_dataset
-
 import json
 
 NOT_MENTIONED = "not mentioned"
 
 
-def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2ratio=1):
-    ratio = {'train': split2ratio, 'validation': split2ratio}
-    if data_name == "all" or data_name == "sgd+tm" or data_name == "tm":
-        print("merge all datasets...")
-        if data_name == "all":
-            all_dataset = ["multiwoz21", "sgd", "tm1", "tm2", "tm3"]
-        if data_name == "sgd+tm":
-            all_dataset = ["sgd", "tm1", "tm2", "tm3"]
-        if data_name == "tm":
-            all_dataset = ["tm1", "tm2", "tm3"]
-
-        datasets = {}
-        for name in all_dataset:
-            datasets[name] = load_dataset(
-                name,
-                dial_ids_order=dial_ids_order,
-                split2ratio=ratio)
-        raw_data = merge_dataset(datasets, all_dataset[0])
-
-    else:
-        print(f"load single dataset {data_name}/{split2ratio}")
-        raw_data = load_dataset(data_name,
-                                dial_ids_order=dial_ids_order,
-                                split2ratio=ratio)
-    return raw_data
-
-
-def merge_dataset(datasets, data_name):
-    data_split = [x for x in datasets[data_name]]
-    raw_data = {}
-    for data_type in data_split:
-        raw_data[data_type] = []
-        for dataname, dataset in datasets.items():
-            print(f"merge {dataname}...")
-            raw_data[data_type] += dataset[data_type]
-    return raw_data
-
-
 def int2onehot(index, output_dim=6, remove_zero=False):
     one_hot = [0] * output_dim
     if remove_zero:
@@ -129,6 +89,50 @@ def get_booking_domain(slot, value, all_values, domain_list):
     return found
 
 
+def act2slot(intent, domain, slot, value, all_values):
+
+    if domain not in UsrDa2Goal:
+        # print(f"Not handle domain {domain}")
+        return ""
+
+    if domain == "booking":
+        slot = SysDa2Goal[domain][slot]
+        domain = get_booking_domain(slot, value, all_values)
+        return f"{domain}-{slot}"
+
+    elif domain in UsrDa2Goal:
+        if slot in SysDa2Goal[domain]:
+            slot = SysDa2Goal[domain][slot]
+        elif slot in UsrDa2Goal[domain]:
+            slot = UsrDa2Goal[domain][slot]
+        elif slot in SysDa2Goal["booking"]:
+            slot = SysDa2Goal["booking"][slot]
+        # else:
+        #     print(
+        #         f"UNSEEN ACTION IN GENERATE LABEL {intent, domain, slot, value}")
+
+        return f"{domain}-{slot}"
+
+    print("strange!!!")
+    print(intent, domain, slot, value)
+
+    return ""
+
+
+def get_user_history(dialog, all_values):
+    turn_num = len(dialog)
+    mentioned_slot = []
+    for turn_id in range(0, turn_num, 2):
+        usr_act = parse_dialogue_act(
+            dialog[turn_id]["dialog_act"])
+        for intent, domain, slot, value in usr_act:
+            slot_name = act2slot(
+                intent, domain.lower(), slot.lower(), value.lower(), all_values)
+            if slot_name not in mentioned_slot:
+                mentioned_slot.append(slot_name)
+    return mentioned_slot
+
+
 def update_config_file(file_name, attribute, value):
     with open(file_name, 'r') as config_file:
         config = json.load(config_file)
@@ -143,7 +147,7 @@ def update_config_file(file_name, attribute, value):
 def create_goal(dialog) -> list:
     # a list of {'intent': ..., 'domain': ..., 'slot': ..., 'value': ...}
     dicts = []
-    for turn in dialog['turns']:
+    for i, turn in enumerate(dialog['turns']):
         # print(turn['speaker'])
         # assert (i % 2 == 0) == (turn['speaker'] == 'user')
         # if i % 2 == 0:
@@ -201,45 +205,6 @@ def split_slot_name(slot_name):
         return tokens[0], '-'.join(tokens[1:])
 
 
-# copy from data.unified_datasets.multiwoz21
-slot_name_map = {
-    'addr': "address",
-    'post': "postcode",
-    'pricerange': "price range",
-    'arrive': "arrive by",
-    'arriveby': "arrive by",
-    'leave': "leave at",
-    'leaveat': "leave at",
-    'depart': "departure",
-    'dest': "destination",
-    'fee': "entrance fee",
-    'open': 'open hours',
-    'car': "type",
-    'car type': "type",
-    'ticket': 'price',
-    'trainid': 'train id',
-    'id': 'train id',
-    'people': 'book people',
-    'stay': 'book stay',
-    'none': '',
-    'attraction': {
-        'price': 'entrance fee'
-    },
-    'hospital': {},
-    'hotel': {
-        'day': 'book day', 'price': "price range"
-    },
-    'restaurant': {
-        'day': 'book day', 'time': 'book time', 'price': "price range"
-    },
-    'taxi': {},
-    'train': {
-        'day': 'day', 'time': "duration"
-    },
-    'police': {},
-    'booking': {}
-}
-
 if __name__ == "__main__":
     print(split_slot_name("restaurant-search-location"))
     print(split_slot_name("sports-day.match"))
diff --git a/convlab/policy/vector/vector_uncertainty.py b/convlab/policy/vector/vector_uncertainty.py
deleted file mode 100644
index ee4582467925b76b2c1472f445e057808cfda9fc..0000000000000000000000000000000000000000
--- a/convlab/policy/vector/vector_uncertainty.py
+++ /dev/null
@@ -1,166 +0,0 @@
-# -*- coding: utf-8 -*-
-import sys
-import numpy as np
-import logging
-
-from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da
-from convlab.policy.vector.vector_binary import VectorBinary
-
-
-class VectorUncertainty(VectorBinary):
-    """Vectorise state and state uncertainty predictions"""
-
-    def __init__(self,
-                 dataset_name: str = 'multiwoz21',
-                 character: str = 'sys',
-                 use_masking: bool = False,
-                 manually_add_entity_names: bool = True,
-                 seed: str = 0,
-                 use_confidence_scores: bool = True,
-                 confidence_thresholds: dict = None,
-                 use_state_total_uncertainty: bool = False,
-                 use_state_knowledge_uncertainty: bool = False):
-        """
-        Args:
-            dataset_name: Name of environment dataset
-            character: Character of the agent (sys/usr)
-            use_masking: If true certain actions are masked during devectorisation
-            manually_add_entity_names: If true inform entity name actions are manually added
-            seed: Seed
-            use_confidence_scores: If true confidence scores are used in state vectorisation
-            confidence_thresholds: If true confidence thresholds are used in database querying
-            use_state_total_uncertainty: If true state entropy is added to the state vector
-            use_state_knowledge_uncertainty: If true state mutual information is added to the state vector
-        """
-
-        self.use_confidence_scores = use_confidence_scores
-        self.use_state_total_uncertainty = use_state_total_uncertainty
-        self.use_state_knowledge_uncertainty = use_state_knowledge_uncertainty
-        if confidence_thresholds is not None:
-            self.setup_uncertain_query(confidence_thresholds)
-
-        super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed)
-
-    def get_state_dim(self):
-        self.belief_state_dim = 0
-
-        for domain in self.ontology['state']:
-            for slot in self.ontology['state'][domain]:
-                # Dim 1 - indicator/confidence score
-                # Dim 2 - Entropy (Total uncertainty) / Mutual information (knowledge unc)
-                slot_dim = 1 if not self.use_state_total_uncertainty else 2
-                slot_dim += 1 if self.use_state_knowledge_uncertainty else 0
-                self.belief_state_dim += slot_dim
-
-        self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \
-            len(self.db_domains) + 6 * len(self.db_domains) + 1
-
-    # Add thresholds for db_queries
-    def setup_uncertain_query(self, confidence_thresholds):
-        self.use_confidence_scores = True
-        self.confidence_thresholds = confidence_thresholds
-        logging.info('DB Search uncertainty activated.')
-
-    def dbquery_domain(self, domain):
-        """
-        query entities of specified domain
-        Args:
-            domain string:
-                domain to query
-        Returns:
-            entities list:
-                list of entities of the specified domain
-        """
-        # Get all user constraints
-        constraints = {slot: value for slot, value in self.state[domain].items()
-                       if slot and value not in ['dontcare',
-                                                 "do n't care", "do not care"]} if domain in self.state else dict()
-
-        # Remove constraints for which the uncertainty is high
-        if self.confidence_scores is not None and self.use_confidence_scores and self.confidence_thresholds is not None:
-            # Collect threshold values for each domain-slot pair
-            threshold = self.confidence_thresholds.get(domain, dict())
-            threshold = {slot: threshold.get(slot, 0.05) for slot in constraints}
-            # Get confidence scores for each constraint
-            probs = self.confidence_scores.get(domain, dict())
-            probs = {slot: probs.get(slot, {}).get('inform', 1.0) for slot in constraints}
-
-            # Filter out constraints for which confidence is lower than threshold
-            constraints = {slot: value for slot, value in constraints.items() if probs[slot] >= threshold[slot]}
-
-        return self.db.query(domain, constraints.items(), topk=10)
-
-    def vectorize_user_act(self, state):
-        """Return confidence scores for the user actions"""
-        self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None
-        action = state['user_action'] if self.character == 'sys' else state['system_action']
-        opp_action = delexicalize_da(action, self.requestable)
-        opp_action = flat_da(opp_action)
-        opp_act_vec = np.zeros(self.da_opp_dim)
-        for da in opp_action:
-            if da in self.opp2vec:
-                if 'belief_state_probs' in state and self.use_confidence_scores:
-                    domain, intent, slot, value = da.split('-')
-                    if domain in state['belief_state_probs']:
-                        slot = slot if slot else 'none'
-                        if slot in state['belief_state_probs'][domain]:
-                            prob = state['belief_state_probs'][domain][slot]
-                        elif slot.lower() in state['belief_state_probs'][domain]:
-                            prob = state['belief_state_probs'][domain][slot.lower()]
-                        else:
-                            prob = dict()
-
-                        if intent in prob:
-                            prob = float(prob[intent])
-                        else:
-                            prob = 1.0
-                    else:
-                        prob = 1.0
-                else:
-                    prob = 1.0
-                opp_act_vec[self.opp2vec[da]] = prob
-
-        return opp_act_vec
-
-    def vectorize_belief_state(self, state, domain_active_dict):
-        belief_state = np.zeros(self.belief_state_dim)
-        i = 0
-        for domain in self.belief_domains:
-            if self.use_confidence_scores and 'belief_state_probs' in state:
-                for slot in state['belief_state'][domain]:
-                    prob = None
-                    if slot in state['belief_state_probs'][domain]:
-                        prob = state['belief_state_probs'][domain][slot]
-                        prob = prob['inform'] if 'inform' in prob else None
-                    if prob:
-                        belief_state[i] = float(prob)
-                    i += 1
-            else:
-                for slot, value in state['belief_state'][domain].items():
-                    if value and value != 'not mentioned':
-                        belief_state[i] = 1.
-                    i += 1
-
-            if 'active_domains' in state:
-                domain_active = state['active_domains'][domain]
-                domain_active_dict[domain] = domain_active
-            else:
-                if [slot for slot, value in state['belief_state'][domain].items() if value]:
-                    domain_active_dict[domain] = True
-
-        # Add knowledge and/or total uncertainty to the belief state
-        if self.use_state_total_uncertainty and 'entropy' in state:
-            for domain in self.belief_domains:
-                for slot in state['belief_state'][domain]:
-                    if slot in state['entropy'][domain]:
-                        belief_state[i] = float(state['entropy'][domain][slot])
-                    i += 1
-
-        if self.use_state_knowledge_uncertainty and 'mutual_information' in state:
-            for domain in self.belief_domains:
-                for slot in state['belief_state'][domain]['semi']:
-                    if slot in state['mutual_information'][domain]:
-                        belief_state[i] = float(state['mutual_information'][domain][slot])
-                    i += 1
-
-        return belief_state, domain_active_dict
diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt
new file mode 100644
index 0000000000000000000000000000000000000000..2f7e5e9cdf949b082b9ae752c09d2eafeb322bfb
Binary files /dev/null and b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt differ
diff --git a/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt
new file mode 100644
index 0000000000000000000000000000000000000000..5c6193163b6a1fc788b959eff2d0153b3a4bf7cb
Binary files /dev/null and b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt differ
diff --git a/convlab/util/analysis_tool/analyzer.py b/convlab/util/analysis_tool/analyzer.py
index 61a863ca49601344a5b3652d670ba40d217268ce..5163bee1e151b3d8855a481222d753a862f1f783 100755
--- a/convlab/util/analysis_tool/analyzer.py
+++ b/convlab/util/analysis_tool/analyzer.py
@@ -109,7 +109,6 @@ class Analyzer:
 
             step = 0
 
-            print('Goal:', sess.evaluator.goal, file=flog)
             # print('init goal:',file=f)
             # # print(sess.evaluator.goal, file=f)
             # # pprint(sess.evaluator.goal)
@@ -156,7 +155,7 @@ class Analyzer:
                 if session_over:
                     break
 
-            task_success = sess.evaluator.task_success(f=flog)
+            task_success = sess.evaluator.task_success()
             task_complete = sess.user_agent.policy.policy.goal.task_complete()
             book_rate = sess.evaluator.book_rate()
             stats = sess.evaluator.inform_F1()
@@ -164,7 +163,6 @@ class Analyzer:
             if task_success:
                 print('Dialogue succesfully completed!', file=flog)
             else:
-                print("> final goal:", sess.evaluator.goal, file=flog)
                 print('Dialogue NOT completed succesfully!', file=flog)
 
             percentage = sess.evaluator.final_goal_analyze()
diff --git a/convlab/util/custom_util.py b/convlab/util/custom_util.py
index 80fa115a2b1d6bb46f8366c9cbfc6d0911f75849..aad6c4cd9f2898a57ae59c47c9cf1bd5629c8fec 100644
--- a/convlab/util/custom_util.py
+++ b/convlab/util/custom_util.py
@@ -25,7 +25,8 @@ import signal
 
 
 slot_mapping = {"pricerange": "price range", "post": "postcode", "arriveBy": "arrive by", "leaveAt": "leave at",
-                "Id": "train id", "ref": "reference", "trainID": "train id"}
+                "Id": "trainid", "ref": "reference"}
+
 
 sys.path.append(os.path.dirname(os.path.dirname(
     os.path.dirname(os.path.abspath(__file__)))))
@@ -102,8 +103,7 @@ def load_config_file(filepath: str = None) -> dict:
 
 def save_config(terminal_args, config_file_args, config_save_path, policy_config=None):
     config_save_path = os.path.join(config_save_path, f'config_saved.json')
-    args_dict = {"args": terminal_args,
-                 "config": config_file_args, "policy_config": policy_config}
+    args_dict = {"args": terminal_args, "config": config_file_args, "policy_config": policy_config}
     json.dump(args_dict, open(config_save_path, 'w'))
 
 
@@ -165,29 +165,26 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do
     goals = []
     for seed in range(1000, 1000 + conf['model']['num_eval_dialogues']):
         set_seed(seed)
-        goal = create_goals(goal_generator, 1,
-                            single_domain_goals, allowed_domains)
+        goal = create_goals(goal_generator, 1, single_domain_goals, allowed_domains)
         goals.append(goal[0])
 
     if conf['model']['process_num'] == 1:
         complete_rate, success_rate, success_rate_strict, avg_return, turns, \
             avg_actions, task_success, book_acts, inform_acts, request_acts, \
-            select_acts, offer_acts, recommend_acts = evaluate(sess,
-                                                               num_dialogues=conf['model']['num_eval_dialogues'],
-                                                               sys_semantic_to_usr=conf['model'][
-                                                                   'sys_semantic_to_usr'],
-                                                               save_flag=save_eval, save_path=log_save_path, goals=goals)
-
-        total_acts = book_acts + inform_acts + request_acts + \
-            select_acts + offer_acts + recommend_acts
+                select_acts, offer_acts, recommend_acts = evaluate(sess,
+                                                num_dialogues=conf['model']['num_eval_dialogues'],
+                                                sys_semantic_to_usr=conf['model'][
+                                                    'sys_semantic_to_usr'],
+                                                save_flag=save_eval, save_path=log_save_path, goals=goals)
+
+        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts
     else:
         complete_rate, success_rate, success_rate_strict, avg_return, turns, \
             avg_actions, task_success, book_acts, inform_acts, request_acts, \
             select_acts, offer_acts, recommend_acts = \
             evaluate_distributed(sess, list(range(1000, 1000 + conf['model']['num_eval_dialogues'])),
                                  conf['model']['process_num'], goals)
-        total_acts = book_acts + inform_acts + request_acts + \
-            select_acts + offer_acts + recommend_acts
+        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts
 
         task_success_gathered = {}
         for task_dict in task_success:
@@ -199,18 +196,12 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do
 
     policy_sys.is_train = True
 
-    mean_complete, err_complete = np.average(complete_rate), np.std(
-        complete_rate) / np.sqrt(len(complete_rate))
-    mean_success, err_success = np.average(success_rate), np.std(
-        success_rate) / np.sqrt(len(success_rate))
-    mean_success_strict, err_success_strict = np.average(success_rate_strict), np.std(
-        success_rate_strict) / np.sqrt(len(success_rate_strict))
-    mean_return, err_return = np.average(avg_return), np.std(
-        avg_return) / np.sqrt(len(avg_return))
-    mean_turns, err_turns = np.average(
-        turns), np.std(turns) / np.sqrt(len(turns))
-    mean_actions, err_actions = np.average(avg_actions), np.std(
-        avg_actions) / np.sqrt(len(avg_actions))
+    mean_complete, err_complete = np.average(complete_rate), np.std(complete_rate) / np.sqrt(len(complete_rate))
+    mean_success, err_success = np.average(success_rate), np.std(success_rate) / np.sqrt(len(success_rate))
+    mean_success_strict, err_success_strict = np.average(success_rate_strict), np.std(success_rate_strict) / np.sqrt(len(success_rate_strict))
+    mean_return, err_return = np.average(avg_return), np.std(avg_return) / np.sqrt(len(avg_return))
+    mean_turns, err_turns = np.average(turns), np.std(turns) / np.sqrt(len(turns))
+    mean_actions, err_actions = np.average(avg_actions), np.std(avg_actions) / np.sqrt(len(avg_actions))
 
     logging.info(f"Complete: {mean_complete}+-{round(err_complete, 2)}, "
                  f"Success: {mean_success}+-{round(err_success, 2)}, "
@@ -381,9 +372,6 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
                 complete = sess.evaluator.complete
                 task_succ = sess.evaluator.success
                 task_succ_strict = sess.evaluator.success_strict
-                if not task_succ:
-                    print("%s | %s %s | GOAL: %s" %
-                          (seed - 1000, task_succ, complete, sess.evaluator.goal))
                 break
         else:
             complete = 0
@@ -428,11 +416,10 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
     # save dialogue_info and clear mem
 
     return task_success['All_user_sim'], task_success['All_evaluator'], task_success['All_evaluator_strict'], \
-        task_success['total_return'], task_success['turns'], task_success['avg_actions'], task_success, \
+           task_success['total_return'], task_success['turns'], task_success['avg_actions'], task_success, \
         np.average(task_success['total_booking_acts']), np.average(task_success['total_inform_acts']), \
         np.average(task_success['total_request_acts']), np.average(task_success['total_select_acts']), \
-        np.average(task_success['total_offer_acts']), np.average(
-            task_success['total_recommend_acts'])
+        np.average(task_success['total_offer_acts']), np.average(task_success['total_recommend_acts'])
 
 
 def model_downloader(download_dir, model_path):
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index aa71884703ffc51be528db177f02e3f22e2ee89d..726079d1c2b04c304bfd7055d48b1b4ae4905856 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
 from pprint import pprint
 from convlab.util.file_util import cached_path
 import shutil
-# from sentence_transformers import SentenceTransformer, util
+from sentence_transformers import SentenceTransformer, util
 import torch
 from tqdm import tqdm
 
diff --git a/examples/agent_examples/test_heckmi_PPO.py b/examples/agent_examples/test_heckmi_PPO.py
deleted file mode 100755
index 39b5a79f898a8be8178a9fbac504d674168a2cd9..0000000000000000000000000000000000000000
--- a/examples/agent_examples/test_heckmi_PPO.py
+++ /dev/null
@@ -1,53 +0,0 @@
-
-from convlab.dst.rule.multiwoz import RuleDST
-from convlab.policy.ppo import PPO
-from convlab.policy.rule.multiwoz import RulePolicy
-
-from convlab.dialog_agent import PipelineAgent
-from convlab.util.analysis_tool.analyzer import Analyzer
-
-import random
-import numpy as np
-from datetime import datetime
-
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-
-def set_seed(r_seed):
-    random.seed(r_seed)
-    np.random.seed(r_seed)
-
-
-def test_end2end(seed=20200202, n_dialogues=1000):
-
-    # Dialogue System
-    sys_nlu = None
-    sys_dst = RuleDST()
-    sys_policy = PPO(False, seed=seed)
-    sys_policy.load('/home/heckmi/gcp/supervised')
-    sys_nlg = None
-    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
-
-    # User Simulator
-    user_nlu = None
-    user_dst = None
-    user_policy = RulePolicy(character='usr')
-    user_nlg = None
-    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
-
-    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
-
-    set_seed(seed)
-    now = datetime.now()
-    time = now.strftime("%Y%m%d%H%M%S")
-    name = f'PPO-Seed{seed}-{time}'
-    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues)
-
-if __name__ == '__main__':
-    # Get arguments
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--seed', help='Seed', default=20200202, type=int)
-    parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int)
-    args = parser.parse_args()
-
-    test_end2end(seed=args.seed, n_dialogues=args.n_dialogues)
diff --git a/examples/agent_examples/test_heckmi_PPO_semantic.py b/examples/agent_examples/test_heckmi_PPO_semantic.py
deleted file mode 100755
index 3fe3f536673016d8383ac598ff35dbb9c97b219f..0000000000000000000000000000000000000000
--- a/examples/agent_examples/test_heckmi_PPO_semantic.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# available NLU models
-
-from convlab2.dst.rule.multiwoz import RuleDST
-from convlab2.policy.ppo import PPO
-from convlab2.policy.rule.multiwoz import RulePolicy
-
-from convlab2.dialog_agent import PipelineAgent
-from convlab2.util.analysis_tool.analyzer import Analyzer
-
-import random
-import numpy as np
-
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-
-def set_seed(r_seed):
-    random.seed(r_seed)
-    np.random.seed(r_seed)
-
-
-def test_end2end(seed=20200202, n_dialogues=1000):
-
-    # Dialogue System, receiving user (simulator) output
-    sys_nlu = None
-    sys_dst = RuleDST()
-    sys_policy = PPO(False, seed=seed)
-    sys_policy.load('/home/heckmi/gcp/supervised')
-    sys_nlg = None
-    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
-
-    # User Simulator, receiving system output
-    user_nlu = None
-    user_dst = None
-    user_policy = RulePolicy(character='usr')
-    user_nlg = None
-    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
-
-    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
-
-    set_seed(seed)
-    name=f'TripPy-PPO-Seed{seed}'
-    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues)
-
-if __name__ == '__main__':
-    # Get arguments
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--seed', help='Seed', default=20200202, type=int)
-    parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int)
-    args = parser.parse_args()
-
-    test_end2end(seed=args.seed, n_dialogues=args.n_dialogues)
diff --git a/examples/agent_examples/test_heckmi_SetSUMBT_PPO.py b/examples/agent_examples/test_heckmi_SetSUMBT_PPO.py
deleted file mode 100755
index 0fd5232150e35833d26e780fdd0e3060f02bd0df..0000000000000000000000000000000000000000
--- a/examples/agent_examples/test_heckmi_SetSUMBT_PPO.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# available NLU models
-
-from convlab2.nlu.jointBERT.multiwoz import BERTNLU
-from convlab2.dst.setsumbt.unified_format_data.Tracker import SetSUMBTTracker
-from convlab2.policy.ppo import PPO
-from convlab2.policy.rule.multiwoz import RulePolicy
-from convlab2.nlg.template.multiwoz import TemplateNLG
-
-from convlab2.dialog_agent import PipelineAgent
-from convlab2.util.analysis_tool.analyzer import Analyzer
-
-import random
-import numpy as np
-from datetime import datetime
-
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-
-def set_seed(r_seed):
-    random.seed(r_seed)
-    np.random.seed(r_seed)
-
-
-def test_end2end(seed=20200202, n_dialogues=1000):
-
-    # Dialogue System, receiving user (simulator) output
-    sys_nlu = None
-    sys_dst = SetSUMBTTracker(model_type='roberta',
-                              model_path="https://cloud.cs.uni-duesseldorf.de/s/Yqkzz8NW3yoMWRk/download/setsumbt_end.zip")
-    sys_policy = PPO(False, seed=seed)
-    sys_policy.load('/home/heckmi/gcp/supervised')
-    sys_nlg = TemplateNLG(is_user=False)
-    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True)
-
-    # User Simulator, receiving system output
-    user_nlu = None
-    #user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
-    #                   model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
-    user_dst = None
-    user_policy = RulePolicy(character='usr')
-    #user_nlg = None
-    user_nlg = TemplateNLG(is_user=True)
-    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
-
-    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
-
-    set_seed(seed)
-    now = datetime.now()
-    time = now.strftime("%Y%m%d%H%M%S")
-    name = f'SetSUMBT-PPO-Rule-Seed{seed}-{time}'
-    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues)
-
-if __name__ == '__main__':
-    # Get arguments
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--seed', help='Seed', default=20200202, type=int)
-    parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int)
-    args = parser.parse_args()
-
-    test_end2end(seed=args.seed, n_dialogues=args.n_dialogues)
diff --git a/examples/agent_examples/test_heckmi_TripPyDummy_PPO.py b/examples/agent_examples/test_heckmi_TripPyDummy_PPO.py
deleted file mode 100755
index 9be653273b1e96b4dd076614baa48532ca25752c..0000000000000000000000000000000000000000
--- a/examples/agent_examples/test_heckmi_TripPyDummy_PPO.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# available NLU models
-
-from convlab.nlu.jointBERT.multiwoz import BERTNLU
-from convlab.dst.trippy.multiwoz import TRIPPY
-from convlab.policy.ppo import PPO
-from convlab.policy.rule.multiwoz import RulePolicy
-from convlab.nlg.template.multiwoz import TemplateNLG
-
-from convlab.dialog_agent import PipelineAgent
-from convlab.util.analysis_tool.analyzer import Analyzer
-
-import random
-import numpy as np
-from datetime import datetime
-
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-
-def set_seed(r_seed):
-    random.seed(r_seed)
-    np.random.seed(r_seed)
-
-
-def test_end2end(seed=20200202, n_dialogues=1000):
-
-    # Dialogue System, receiving user (simulator) output
-    sys_nlu = None
-    sys_dst = TRIPPY(model_type='roberta',
-                     model_path='/gpfs/project/heckmi/trippy/multiwoz21/trippy.convlab3/results.42/checkpoint-28392',
-                     nlu_path='/home/heckmi/models/bert_multiwoz_all_context.zip')
-    sys_policy = PPO(False, seed=seed)
-    sys_policy.load('/home/heckmi/models/supervised')
-    sys_nlg = None
-    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True)
-
-    # User Simulator, receiving system output
-    user_nlu = None
-    user_dst = None
-    user_policy = RulePolicy(character='usr')
-    user_nlg = None
-    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
-
-    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
-
-    set_seed(seed)
-    now = datetime.now()
-    time = now.strftime("%Y%m%d%H%M%S")
-    name = f'TripPyDummy-PPO-Rule-Seed{seed}-{time}'
-    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues)
-
-if __name__ == '__main__':
-    # Get arguments
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--seed', help='Seed', default=20200202, type=int)
-    parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int)
-    args = parser.parse_args()
-
-    test_end2end(seed=args.seed, n_dialogues=args.n_dialogues)
diff --git a/examples/agent_examples/test_heckmi_TripPyDummy_PPO2.py b/examples/agent_examples/test_heckmi_TripPyDummy_PPO2.py
deleted file mode 100755
index 7a79225d4581d80abcb400bda1fbd94733d51c15..0000000000000000000000000000000000000000
--- a/examples/agent_examples/test_heckmi_TripPyDummy_PPO2.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# available NLU models
-
-from convlab2.nlu.jointBERT.multiwoz import BERTNLU
-from convlab2.dst.trippy.multiwoz import TRIPPY
-from convlab2.policy.ppo import PPO
-from convlab2.policy.rule.multiwoz import RulePolicy
-from convlab2.nlg.template.multiwoz import TemplateNLG
-
-from convlab2.dialog_agent import PipelineAgent
-from convlab2.util.analysis_tool.analyzer import Analyzer
-
-import random
-import numpy as np
-from datetime import datetime
-
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-
-def set_seed(r_seed):
-    random.seed(r_seed)
-    np.random.seed(r_seed)
-
-
-def test_end2end(seed=20200202, n_dialogues=1000):
-
-    # Dialogue System, receiving user (simulator) output
-    sys_nlu = None
-    sys_dst = TRIPPY(model_type='roberta',
-                     model_path='/home/heckmi/zim/checkpoints/trippy',
-                     nlu_path='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_all_context.zip')
-    sys_policy = PPO(False, seed=seed)
-    sys_policy.load('/home/heckmi/gcp/supervised')
-    sys_nlg = TemplateNLG(is_user=False)
-    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True)
-
-    # User Simulator, receiving system output
-    user_nlu = None
-    user_dst = None
-    user_policy = RulePolicy(character='usr')
-    user_nlg = TemplateNLG(is_user=True)
-    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
-
-    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
-
-    set_seed(seed)
-    now = datetime.now()
-    time = now.strftime("%Y%m%d%H%M%S")
-    name = f'TripPyDummy2-PPO-Rule-Seed{seed}-{time}'
-    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues)
-
-if __name__ == '__main__':
-    # Get arguments
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--seed', help='Seed', default=20200202, type=int)
-    parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int)
-    args = parser.parse_args()
-
-    test_end2end(seed=args.seed, n_dialogues=args.n_dialogues)
diff --git a/examples/agent_examples/test_heckmi_TripPyR_PPO.py b/examples/agent_examples/test_heckmi_TripPyR_PPO.py
deleted file mode 100755
index 97b7ea1ebb7c0aba4a78dd9ddac9dc0ba20c436a..0000000000000000000000000000000000000000
--- a/examples/agent_examples/test_heckmi_TripPyR_PPO.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# available NLU models
-
-from convlab2.nlu.jointBERT.multiwoz import BERTNLU
-from convlab2.dst.trippyr.multiwoz import TRIPPYR
-from convlab2.policy.ppo import PPO
-from convlab2.policy.rule.multiwoz import RulePolicy
-from convlab2.nlg.template.multiwoz import TemplateNLG
-
-from convlab2.dialog_agent import PipelineAgent
-from convlab2.util.analysis_tool.analyzer import Analyzer
-
-import random
-import numpy as np
-from datetime import datetime
-
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-
-def set_seed(r_seed):
-    random.seed(r_seed)
-    np.random.seed(r_seed)
-
-
-def test_end2end(seed=20200202, n_dialogues=1000):
-
-    # Dialogue System, receiving user (simulator) output
-    sys_nlu = None
-    #sys_nlu = BERTNLU()
-    sys_dst = TRIPPYR(model_type='roberta',
-                      model_path='/home/heckmi/zim/checkpoints/trippyr',
-                      nlu_path='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_all_context.zip',
-                      emb_path='/home/heckmi/zim/checkpoints/trippyr',
-                      fp16=True)
-    sys_policy = PPO(False, seed=seed)
-    sys_policy.load('/home/heckmi/gcp/supervised')
-    #sys_nlg = None
-    sys_nlg = TemplateNLG(is_user=False)
-    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True)
-
-    # User Simulator, receiving system output
-    user_nlu = None
-    #user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
-    #                   model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
-    user_dst = None
-    user_policy = RulePolicy(character='usr')
-    #user_nlg = None
-    user_nlg = TemplateNLG(is_user=True)
-    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
-
-    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
-
-    set_seed(seed)
-    now = datetime.now()
-    time = now.strftime("%Y%m%d%H%M%S")
-    name = f'TripPyR-PPO-Rule-Seed{seed}-{time}'
-    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues)
-
-if __name__ == '__main__':
-    # Get arguments
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--seed', help='Seed', default=20200202, type=int)
-    parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int)
-    args = parser.parse_args()
-
-    test_end2end(seed=args.seed, n_dialogues=args.n_dialogues)
diff --git a/examples/agent_examples/test_heckmi_TripPy_PPO.py b/examples/agent_examples/test_heckmi_TripPy_PPO.py
deleted file mode 100755
index cf1bd546cb8ff8c354278e8be065556e7f18ac33..0000000000000000000000000000000000000000
--- a/examples/agent_examples/test_heckmi_TripPy_PPO.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# available NLU models
-
-from convlab.nlu.jointBERT.multiwoz import BERTNLU
-from convlab.dst.trippy.multiwoz import TRIPPY
-from convlab.policy.ppo import PPO
-from convlab.policy.rule.multiwoz import RulePolicy
-from convlab.nlg.template.multiwoz import TemplateNLG
-
-from convlab.dialog_agent import PipelineAgent
-from convlab.util.analysis_tool.analyzer import Analyzer
-
-import random
-import numpy as np
-from datetime import datetime
-
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-
-def set_seed(r_seed):
-    random.seed(r_seed)
-    np.random.seed(r_seed)
-
-
-def test_end2end(seed=20200202, n_dialogues=1000):
-
-    # Dialogue System, receiving user (simulator) output
-    sys_nlu = None
-    #sys_nlu = BERTNLU()
-    #sys_dst = TRIPPY(model_type='roberta',
-    #                 model_path='/home/heckmi/gcp/roberta_corrected/20200805-155426803/results.42',
-    #                 nlu_path='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_all_context.zip')
-    sys_dst = TRIPPY(model_type='roberta',
-                     model_path='/home/heckmi/zim/checkpoints/trippy',
-                     nlu_path='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_all_context.zip')
-    sys_policy = PPO(False, seed=seed)
-    sys_policy.load('/home/heckmi/gcp/supervised')
-    #sys_nlg = None
-    sys_nlg = TemplateNLG(is_user=False)
-    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True)
-
-    # User Simulator, receiving system output
-    user_nlu = None
-    #user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
-    #                   model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
-    user_dst = None
-    user_policy = RulePolicy(character='usr')
-    #user_nlg = None
-    user_nlg = TemplateNLG(is_user=True)
-    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
-
-    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
-
-    set_seed(seed)
-    now = datetime.now()
-    time = now.strftime("%Y%m%d%H%M%S")
-    name = f'TripPy-PPO-Rule-Seed{seed}-{time}'
-    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues)
-
-if __name__ == '__main__':
-    # Get arguments
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--seed', help='Seed', default=20200202, type=int)
-    parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int)
-    args = parser.parse_args()
-
-    test_end2end(seed=args.seed, n_dialogues=args.n_dialogues)
diff --git a/examples/agent_examples/test_heckmi_TripPy_PPO_semantic.py b/examples/agent_examples/test_heckmi_TripPy_PPO_semantic.py
deleted file mode 100755
index 3fe3f536673016d8383ac598ff35dbb9c97b219f..0000000000000000000000000000000000000000
--- a/examples/agent_examples/test_heckmi_TripPy_PPO_semantic.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# available NLU models
-
-from convlab2.dst.rule.multiwoz import RuleDST
-from convlab2.policy.ppo import PPO
-from convlab2.policy.rule.multiwoz import RulePolicy
-
-from convlab2.dialog_agent import PipelineAgent
-from convlab2.util.analysis_tool.analyzer import Analyzer
-
-import random
-import numpy as np
-
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-
-def set_seed(r_seed):
-    random.seed(r_seed)
-    np.random.seed(r_seed)
-
-
-def test_end2end(seed=20200202, n_dialogues=1000):
-
-    # Dialogue System, receiving user (simulator) output
-    sys_nlu = None
-    sys_dst = RuleDST()
-    sys_policy = PPO(False, seed=seed)
-    sys_policy.load('/home/heckmi/gcp/supervised')
-    sys_nlg = None
-    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
-
-    # User Simulator, receiving system output
-    user_nlu = None
-    user_dst = None
-    user_policy = RulePolicy(character='usr')
-    user_nlg = None
-    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
-
-    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
-
-    set_seed(seed)
-    name=f'TripPy-PPO-Seed{seed}'
-    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues)
-
-if __name__ == '__main__':
-    # Get arguments
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--seed', help='Seed', default=20200202, type=int)
-    parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int)
-    args = parser.parse_args()
-
-    test_end2end(seed=args.seed, n_dialogues=args.n_dialogues)