diff --git a/.gitignore b/.gitignore
index 39548e53ce43c63e08a25e379b22cf6bd4ac402a..364e64ff7f4e3e260b5ffd3f8843519d251edbff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -102,8 +102,8 @@ convlab/dst/trade/multiwoz_config/
 convlab/deploy/bert_multiwoz_all.zip
 convlab/deploy/templates/dialog_eg.html
 test.py
-
 *convlab/policy/vector/action_dicts
+
 *.egg-info
 pre-trained-models/
 venv
diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py
index 6216eaaac9fb81615f903f048d6d85766ce663c5..508d06b56080f72146d65a98286a0fd099f9b4c2 100755
--- a/convlab/dialog_agent/env.py
+++ b/convlab/dialog_agent/env.py
@@ -50,10 +50,10 @@ class Environment():
             observation) if self.sys_nlu else observation
         self.sys_dst.state['user_action'] = dialog_act
         state = self.sys_dst.update(dialog_act)
-        state = deepcopy(state)
+        self.sys_dst.state['history'].append(["sys", model_response])
+        self.sys_dst.state['history'].append(["usr", observation])
 
-        state['history'].append(["sys", model_response])
-        state['history'].append(["usr", observation])
+        state = deepcopy(state)
 
         terminated = self.usr.is_terminated()
 
diff --git a/convlab/dst/evaluate_unified_datasets.py b/convlab/dst/evaluate_unified_datasets.py
index d4e0720dc02bf90efcfebb1780630211f0722f7f..d4b3ad2464be6c8de0d6f1f51ecdf1bf6bfdbd0e 100644
--- a/convlab/dst/evaluate_unified_datasets.py
+++ b/convlab/dst/evaluate_unified_datasets.py
@@ -7,7 +7,6 @@ def evaluate(predict_result):
 
     metrics = {'TP':0, 'FP':0, 'FN':0}
     acc = []
-
     for sample in predict_result:
         pred_state = sample['predictions']['state']
         gold_state = sample['state']
@@ -37,7 +36,7 @@ def evaluate(predict_result):
                         flag = False
 
         acc.append(flag)
-    
+
     TP = metrics.pop('TP')
     FP = metrics.pop('FP')
     FN = metrics.pop('FN')
diff --git a/convlab/dst/setsumbt/__init__.py b/convlab/dst/setsumbt/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..9492faa9c9a20d1c476819bb995900ca71d56607 100644
--- a/convlab/dst/setsumbt/__init__.py
+++ b/convlab/dst/setsumbt/__init__.py
@@ -0,0 +1 @@
+from convlab.dst.setsumbt.tracker import SetSUMBTTracker
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/calibration_plots.py b/convlab/dst/setsumbt/calibration_plots.py
index 379057e6411082d466b10a81027bb57e4131bb9b..a41f280d3349164a2a67333d0ab176a37cbe50ea 100644
--- a/convlab/dst/setsumbt/calibration_plots.py
+++ b/convlab/dst/setsumbt/calibration_plots.py
@@ -35,7 +35,7 @@ def main():
     path = args.data_dir
 
     models = os.listdir(path)
-    models = [os.path.join(path, model, 'test.belief') for model in models]
+    models = [os.path.join(path, model, 'test.predictions') for model in models]
 
     fig = plt.figure(figsize=(14,8))
     font=20
@@ -56,16 +56,16 @@ def main():
 
 
 def get_calibration(path, device, n_bins=10, temperature=1.00):
-    logits = torch.load(path, map_location=device)
-    y_true = logits['labels']
-    logits = logits['belief_states']
+    probs = torch.load(path, map_location=device)
+    y_true = probs['state_labels']
+    probs = probs['belief_states']
 
-    y_pred = {slot: logits[slot].reshape(-1, logits[slot].size(-1)).argmax(-1) for slot in logits}
+    y_pred = {slot: probs[slot].reshape(-1, probs[slot].size(-1)).argmax(-1) for slot in probs}
     goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred}
     goal_acc = sum([goal_acc[slot] for slot in goal_acc])
     goal_acc = (goal_acc == len(y_true)).int()
 
-    scores = [logits[slot].reshape(-1, logits[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in logits]
+    scores = [probs[slot].reshape(-1, probs[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in probs]
     scores = torch.cat(scores, 0).min(0)[0]
 
     step = 1.0 / float(n_bins)
diff --git a/convlab/dst/setsumbt/configs/setsumbt_multitask.json b/convlab/dst/setsumbt/configs/setsumbt_multitask.json
new file mode 100644
index 0000000000000000000000000000000000000000..c076a557cb3e1d567784c70559fb1922fe05c545
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/setsumbt_multitask.json
@@ -0,0 +1,11 @@
+{
+  "model_type": "SetSUMBT",
+  "dataset": "multiwoz21+sgd+tm1+tm2+tm3",
+  "no_action_prediction": true,
+  "model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
+  "transformers_local_files_only": true,
+  "train_batch_size": 3,
+  "dev_batch_size": 8,
+  "test_batch_size": 8,
+  "run_nbt": true
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json b/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json
new file mode 100644
index 0000000000000000000000000000000000000000..0bff751c16f0bdcdf61f04ce33d616370c0d32d8
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json
@@ -0,0 +1,11 @@
+{
+  "model_type": "SetSUMBT",
+  "dataset": "multiwoz21",
+  "no_action_prediction": true,
+  "model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
+  "transformers_local_files_only": true,
+  "train_batch_size": 3,
+  "dev_batch_size": 16,
+  "test_batch_size": 16,
+  "run_nbt": true
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/configs/setsumbt_pretrain.json b/convlab/dst/setsumbt/configs/setsumbt_pretrain.json
new file mode 100644
index 0000000000000000000000000000000000000000..fdc22d157840e7494b0266d0bd99f8a99d242969
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/setsumbt_pretrain.json
@@ -0,0 +1,11 @@
+{
+  "model_type": "SetSUMBT",
+  "dataset": "sgd+tm1+tm2+tm3",
+  "no_action_prediction": true,
+  "model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
+  "transformers_local_files_only": true,
+  "train_batch_size": 3,
+  "dev_batch_size": 12,
+  "test_batch_size": 12,
+  "run_nbt": true
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/configs/setsumbt_sgd.json b/convlab/dst/setsumbt/configs/setsumbt_sgd.json
new file mode 100644
index 0000000000000000000000000000000000000000..97f5818334af4c7984ec24448861b627315820e3
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/setsumbt_sgd.json
@@ -0,0 +1,11 @@
+{
+  "model_type": "SetSUMBT",
+  "dataset": "sgd",
+  "no_action_prediction": true,
+  "model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
+  "transformers_local_files_only": true,
+  "train_batch_size": 3,
+  "dev_batch_size": 6,
+  "test_batch_size": 3,
+  "run_nbt": true
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/configs/setsumbt_tm.json b/convlab/dst/setsumbt/configs/setsumbt_tm.json
new file mode 100644
index 0000000000000000000000000000000000000000..138f84c358067389d5f7b478ae94c3eb2aa90ea3
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/setsumbt_tm.json
@@ -0,0 +1,11 @@
+{
+  "model_type": "SetSUMBT",
+  "dataset": "tm1+tm2+tm3",
+  "no_action_prediction": true,
+  "model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
+  "transformers_local_files_only": true,
+  "train_batch_size": 3,
+  "dev_batch_size": 8,
+  "test_batch_size": 8,
+  "run_nbt": true
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/dataset/__init__.py b/convlab/dst/setsumbt/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..17b1f93b3b39f95827cf6c09e8826383cd00b805
--- /dev/null
+++ b/convlab/dst/setsumbt/dataset/__init__.py
@@ -0,0 +1,2 @@
+from convlab.dst.setsumbt.dataset.unified_format import get_dataloader, change_batch_size
+from convlab.dst.setsumbt.dataset.ontology import get_slot_candidate_embeddings
diff --git a/convlab/dst/setsumbt/dataset/ontology.py b/convlab/dst/setsumbt/dataset/ontology.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce150a61077ad61ab9d7af2ae3537971ae925f55
--- /dev/null
+++ b/convlab/dst/setsumbt/dataset/ontology.py
@@ -0,0 +1,134 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Create Ontology Embeddings"""
+
+import json
+import os
+import random
+from copy import deepcopy
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+
+def set_seed(args):
+    """
+    Seed the Python, NumPy and PyTorch random number generators
+
+    Args:
+        args (Arguments class): Runtime arguments providing the `seed` and the number of gpus (`n_gpu`) in use
+    """
+    seed = args.seed
+    for seeder in (random.seed, np.random.seed, torch.manual_seed):
+        seeder(seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(seed)
+
+
+def encode_candidates(candidates: list, args, tokenizer, embedding_model) -> torch.tensor:
+    """
+    Embed candidates
+
+    Args:
+        candidates (list): Non-empty list of candidate descriptions
+        args (argument class): Runtime arguments (max_candidate_len, set_similarity and candidate_pooling are read)
+        tokenizer (transformers Tokenizer): Tokenizer for the embedding_model
+        embedding_model (transformer Model): Transformer model for embedding candidate descriptions
+
+    Returns:
+        feats (torch.tensor): Candidate embeddings on the cpu (ValueError is raised for an unknown candidate_pooling)
+    """
+    # Tokenize candidate descriptions
+    feats = [tokenizer.encode_plus(val, add_special_tokens=True, max_length=args.max_candidate_len,
+                                   padding='max_length', truncation='longest_first')
+             for val in candidates]
+
+    # Encode tokenized descriptions
+    with torch.no_grad():
+        feats = {key: torch.tensor([f[key] for f in feats]).to(embedding_model.device) for key in feats[0]}
+        embedded_feats = embedding_model(**feats)  # [num_candidates, max_candidate_len, hidden_dim]
+
+    # Reduce/pool descriptions embeddings if required
+    if args.set_similarity:
+        feats = embedded_feats.last_hidden_state.detach().cpu()  # [num_candidates, max_candidate_len, hidden_dim]
+    elif args.candidate_pooling == 'cls':
+        feats = embedded_feats.pooler_output.detach().cpu()  # [num_candidates, hidden_dim]
+    elif args.candidate_pooling == "mean":
+        feats = embedded_feats.last_hidden_state.detach().cpu().sum(1)
+        feats = torch.nn.functional.layer_norm(feats, feats.size())  # normalised jointly over all candidates
+    else:
+        raise ValueError(f"Unsupported candidate_pooling: {args.candidate_pooling!r}")
+
+    return feats
+
+
+def get_slot_candidate_embeddings(ontology: dict, set_type: str, args, tokenizer, embedding_model, save_to_file=True):
+    """
+    Get embeddings for slots and candidates
+
+    Args:
+        ontology (dict): Dictionary of domain-slot pair descriptions and possible value sets
+        set_type (str): Subset of the dataset being used (train/validation/test)
+        args (argument class): Runtime arguments
+        tokenizer (transformers Tokenizer): Tokenizer for the embedding_model
+        embedding_model (transformer Model): Transformer model for embedding candidate descriptions
+        save_to_file (bool): Indication of whether to save information to file
+
+    Returns:
+        slots (dict): domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
+    """
+    # Set model to eval mode
+    embedding_model.eval()
+
+    slots = dict()
+    for domain, subset in tqdm(ontology.items(), desc='Domains'):
+        for slot, slot_info in tqdm(subset.items(), desc='Slots'):
+            # Get description or use "domain-slot"
+            if args.use_descriptions:
+                desc = slot_info['description']
+            else:
+                desc = f"{domain}-{slot}"
+
+            # Encode domain-slot pair description
+            slot_emb = encode_candidates([desc], args, tokenizer, embedding_model)[0]
+
+            # Obtain possible value set and discard requestable value
+            values = deepcopy(slot_info['possible_values'])
+            # '?' flags the slot as requestable and is not a value candidate
+            is_requestable = '?' in values
+            if is_requestable:
+                values.remove('?')
+
+            # Encode value candidates
+            if values:
+                feats = encode_candidates(values, args, tokenizer, embedding_model)
+            else:
+                feats = None
+
+            # Store domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
+            slots[f"{domain}-{slot}"] = (slot_emb, feats, is_requestable)
+
+    # Dump tensors and ontology for use in training and evaluation
+    if save_to_file:
+        database_dir = os.path.join(args.output_dir, 'database')
+        os.makedirs(database_dir, exist_ok=True)  # torch.save and open do not create missing directories
+        torch.save(slots, os.path.join(database_dir, '%s.db' % set_type))
+        # Context manager closes the ontology file even if json.dump raises
+        with open(os.path.join(database_dir, '%s.json' % set_type), 'w') as writer:
+            json.dump(ontology, writer, indent=2)
+
+    return slots
diff --git a/convlab/dst/setsumbt/dataset/unified_format.py b/convlab/dst/setsumbt/dataset/unified_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c3a68c3b2e627ac60f555a642dfa837734249b6
--- /dev/null
+++ b/convlab/dst/setsumbt/dataset/unified_format.py
@@ -0,0 +1,429 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convlab3 Unified Format Dialogue Datasets"""
+
+from copy import deepcopy
+
+import torch
+import transformers
+from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
+from transformers.tokenization_utils import PreTrainedTokenizer
+from tqdm import tqdm
+
+from convlab.util import load_dataset
+from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values,
+                                                get_values_from_data, ontology_add_requestable_slots,
+                                                get_requestable_slots, load_dst_data, extract_dialogues,
+                                                combine_value_sets, IdTensor)
+
+transformers.logging.set_verbosity_error()
+
+
+def convert_examples_to_features(data: list,
+                                 ontology: dict,
+                                 tokenizer: PreTrainedTokenizer,
+                                 max_turns: int = 12,
+                                 max_seq_len: int = 64) -> dict:
+    """
+    Convert dialogue examples to model input features and labels
+
+    Args:
+        data (list): List of all extracted dialogues, each dialogue being a list of turn dictionaries
+        ontology (dict): Ontology dictionary containing slots, slot descriptions and
+        possible value sets including requests ('?'), a copy is modified so the caller's dictionary is kept intact
+        tokenizer (PreTrainedTokenizer): Tokenizer for the encoder model used
+        max_turns (int): Maximum numbers of turns in a dialogue (longer dialogues are truncated)
+        max_seq_len (int): Maximum number of tokens in a dialogue turn
+
+    Returns:
+        features (dict): All inputs and labels required to train the model (-1 labels mark padding or no label)
+    """
+    features = dict()
+    ontology = deepcopy(ontology)  # local copy, '?' is removed from its value sets below
+
+    # Get encoder input for system, user utterance pairs
+    input_feats = []
+    for dial in tqdm(data):
+        dial_feats = []
+        for turn in dial:
+            if len(turn['system_utterance']) == 0:  # no system utterance, encode the user utterance alone
+                usr = turn['user_utterance']
+                dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens=True,
+                                                        max_length=max_seq_len, padding='max_length',
+                                                        truncation='longest_first'))
+            else:
+                usr = turn['user_utterance']
+                sys = turn['system_utterance']
+                dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens=True,
+                                                        max_length=max_seq_len, padding='max_length',
+                                                        truncation='longest_first'))
+            # Truncate the dialogue once max_turns turns have been encoded
+            if len(dial_feats) >= max_turns:
+                break
+        input_feats.append(dial_feats)
+    del dial_feats
+
+    # Perform turn level padding so that every dialogue has exactly max_turns turns
+    dial_ids = list()
+    for dial in data:
+        _ids = [turn['dialogue_id'] for turn in dial][:max_turns]
+        _ids += [''] * (max_turns - len(_ids))  # padded turns get an empty dialogue id
+        dial_ids.append(_ids)
+    input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
+                 for dial in input_feats]
+    if 'token_type_ids' in input_feats[0][0]:  # checked on the first encoded turn only, which must exist
+        token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
+                          for dial in input_feats]
+    else:
+        token_type_ids = None
+    if 'attention_mask' in input_feats[0][0]:
+        attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
+                          for dial in input_feats]
+    else:
+        attention_mask = None
+    del input_feats
+
+    # Create torch data tensors (optional encoder inputs are stored as None when the tokenizer does not return them)
+    features['dialogue_ids'] = IdTensor(dial_ids)
+    features['input_ids'] = torch.tensor(input_ids)
+    features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
+    features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
+    del input_ids, token_type_ids, attention_mask
+
+    # Extract all informable and requestable slots from the ontology
+    informable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain]
+                        if ontology[domain][slot]['possible_values']
+                        and ontology[domain][slot]['possible_values'] != ['?']]
+    requestable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain]
+                         if '?' in ontology[domain][slot]['possible_values']]
+    for slot in requestable_slots:
+        domain, slot = slot.split('-', 1)
+        ontology[domain][slot]['possible_values'].remove('?')  # value indices below exclude the request marker
+
+    # Extract a list of domains from the ontology slots
+    domains = list(set(informable_slots + requestable_slots))
+    domains = list(set([slot.split('-', 1)[0] for slot in domains]))
+
+    # Create slot labels (index of the turn's value within the possible values of the slot)
+    for domslot in tqdm(informable_slots):
+        labels = []
+        for dial in data:
+            labs = []
+            for turn in dial:
+                value = [v for d, substate in turn['state'].items() for s, v in substate.items()
+                         if f'{d}-{s}' == domslot]
+                domain, slot = domslot.split('-', 1)
+                if turn['dataset_name'] in ontology[domain][slot]['dataset_names']:
+                    value = value[0] if value else 'none'  # slots missing from the state are labelled 'none'
+                else:
+                    value = -1  # the slot is not part of the dataset this turn comes from
+                if value in ontology[domain][slot]['possible_values'] and value != -1:
+                    value = ontology[domain][slot]['possible_values'].index(value)
+                else:
+                    value = -1  # If value is not in ontology then we do not penalise the model
+                labs.append(value)
+                if len(labs) >= max_turns:
+                    break
+            labs = labs + [-1] * (max_turns - len(labs))  # pad missing turns with -1
+            labels.append(labs)
+
+        labels = torch.tensor(labels)
+        features['state_labels-' + domslot] = labels
+
+    # Create requestable slot labels (1-requested, 0-not requested, -1-slot not in the turn's dataset or padding)
+    for domslot in tqdm(requestable_slots):
+        labels = []
+        for dial in data:
+            labs = []
+            for turn in dial:
+                domain, slot = domslot.split('-', 1)
+                if turn['dataset_name'] in ontology[domain][slot]['dataset_names']:
+                    acts = [act['intent'] for act in turn['dialogue_acts']
+                            if act['domain'] == domain and act['slot'] == slot]
+                    if acts:
+                        act_ = acts[0]  # only the first act on this domain-slot pair is considered
+                        if act_ == 'request':
+                            labs.append(1)
+                        else:
+                            labs.append(0)
+                    else:
+                        labs.append(0)
+                else:
+                    labs.append(-1)
+                if len(labs) >= max_turns:
+                    break
+            labs = labs + [-1] * (max_turns - len(labs))
+            labels.append(labs)
+
+        labels = torch.tensor(labels)
+        features['request_labels-' + domslot] = labels
+
+    # General act labels (0-neither, 1-goodbye, 2-thank you, -1-padding), goodbye wins when both occur
+    labels = []
+    for dial in tqdm(data):
+        labs = []
+        for turn in dial:
+            acts = [act['intent'] for act in turn['dialogue_acts'] if act['intent'] in ['bye', 'thank']]
+            if acts:
+                if 'bye' in acts:
+                    labs.append(1)
+                else:
+                    labs.append(2)
+            else:
+                labs.append(0)
+            if len(labs) >= max_turns:
+                break
+        labs = labs + [-1] * (max_turns - len(labs))
+        labels.append(labs)
+
+    labels = torch.tensor(labels)
+    features['general_act_labels'] = labels
+
+    # Create active domain labels (1-active, 0-inactive but possible for the turn's dataset, -1-otherwise/padding)
+    for domain in tqdm(domains):
+        labels = []
+        for dial in data:
+            labs = []
+            for turn in dial:
+                possible_domains = list()  # domains with at least one slot from the dataset of this turn
+                for dom in ontology:
+                    for slt in ontology[dom]:
+                        if turn['dataset_name'] in ontology[dom][slt]['dataset_names']:
+                            possible_domains.append(dom)
+
+                if domain in turn['active_domains']:
+                    labs.append(1)
+                elif domain in possible_domains:
+                    labs.append(0)
+                else:
+                    labs.append(-1)
+                if len(labs) >= max_turns:
+                    break
+            labs = labs + [-1] * (max_turns - len(labs))
+            labels.append(labs)
+
+        labels = torch.tensor(labels)
+        features['active_domain_labels-' + domain] = labels
+
+    del labels
+
+    return features
+
+
+class UnifiedFormatDataset(Dataset):
+    """
+    Dataset of preprocessed dialogues loaded from the Convlab3 unified format.
+
+    Attributes:
+        dataset_dicts (list): Loaded dataset dictionaries, one per dataset name (only set when loading by name)
+        ontology (dict): Set of all domain-slot-value triplets in the ontology of the model
+        features (dict): Set of numeric features containing all inputs and labels formatted for the SetSUMBT model
+    """
+    def __init__(self,
+                 dataset_name: str,
+                 set_type: str,
+                 tokenizer: PreTrainedTokenizer,
+                 max_turns: int = 12,
+                 max_seq_len: int = 64,
+                 train_ratio: float = 1.0,
+                 seed: int = 0,
+                 data: dict = None,
+                 ontology: dict = None):
+        """
+        Args:
+            dataset_name (str): Name of the dataset/s to load (multiple to be separated by +)
+            set_type (str): Subset of the dataset to load (train, validation or test)
+            tokenizer (transformers tokenizer): Tokenizer for the encoder model used
+            max_turns (int): Maximum numbers of turns in a dialogue
+            max_seq_len (int): Maximum number of tokens in a dialogue turn
+            train_ratio (float): Fraction of training data to use during training
+            seed (int): Seed governing random order of ids for subsampling
+            data (dict): Dataset features for loading from dict
+            ontology (dict): Ontology dict for loading from dict
+        """
+        if data is not None:
+            # Preprocessed features and their ontology were supplied directly
+            self.ontology = ontology
+            self.features = data
+            return
+
+        # str.split yields a single name when dataset_name contains no '+'
+        dataset_args = [{"dataset_name": name} for name in dataset_name.split('+')]
+        self.dataset_dicts = [load_dataset(**kwargs) for kwargs in dataset_args]
+        self.ontology = get_ontology_slots(dataset_name)
+        value_sets = [get_values_from_data(dataset, set_type) for dataset in self.dataset_dicts]
+        self.ontology = ontology_add_values(self.ontology, combine_value_sets(value_sets), set_type)
+        self.ontology = ontology_add_requestable_slots(self.ontology, get_requestable_slots(self.dataset_dicts))
+
+        # Reload the datasets so that a train_ratio other than 1.0 is applied to the dialogues
+        if train_ratio != 1.0:
+            for kwargs in dataset_args:
+                kwargs['dial_ids_order'] = seed
+                kwargs['split2ratio'] = {'train': train_ratio, 'validation': train_ratio}
+        self.dataset_dicts = [load_dataset(**kwargs) for kwargs in dataset_args]
+
+        # Extract the dialogues of the requested subset from every dataset
+        splits = [load_dst_data(dataset_dict, data_split=set_type, speaker='all',
+                                dialogue_acts=True, split_to_turn=False)[set_type]
+                  for dataset_dict in self.dataset_dicts]
+        dialogues = []
+        for kwargs, split in zip(dataset_args, splits):
+            dialogues += extract_dialogues(split, kwargs["dataset_name"])
+        self.features = convert_examples_to_features(dialogues, self.ontology, tokenizer, max_turns, max_seq_len)
+
+    def __getitem__(self, index: int) -> dict:
+        """
+        Get the features of the dialogue/s at the given index/indices
+
+        Args:
+            index (int/list/tensor): Index/indices of dialogues to get
+
+        Returns:
+            features (dict): All inputs and labels required to train the model (features stored as None are left out)
+        """
+        available = ((label, feats) for label, feats in self.features.items() if feats is not None)
+        return {label: feats[index] for label, feats in available}
+
+    def __len__(self):
+        """
+        Count the dialogues stored in the dataset
+
+        Returns:
+            len (int): Number of dialogues in the dataset object
+        """
+        return self.features['input_ids'].shape[0]
+
+    def resample(self, size: int = None) -> Dataset:
+        """
+        Replace the stored features by a random sample of the dialogues (drawn with replacement)
+
+        Args:
+            size (int): Number of dialogues to sample
+
+        Returns:
+            self (Dataset): Dataset object
+        """
+        # Without a requested size the sample is as large as the full dataset
+        n_dialogues = len(self)
+        size = size or n_dialogues
+
+        # Dialogue indices are drawn independently, so a dialogue may occur more than once
+        sampled_ids = torch.randint(low=0, high=n_dialogues, size=(size,))
+        self.features = self[sampled_ids]
+
+        return self
+
+    def to(self, device):
+        """
+        Map all data to a device
+
+        Args:
+            device (torch device): Device to map data to
+        """
+        self.device = device
+        self.features = {label: feats.to(device) for label, feats in self.features.items()
+                         if feats is not None}
+
+    @classmethod
+    def from_datadict(cls, data: dict, ontology: dict):
+        return cls(None, None, None, data=data, ontology=ontology)  # name, subset and tokenizer are not needed
+
+
+def get_dataloader(dataset_name: str,
+                   set_type: str,
+                   batch_size: int,
+                   tokenizer: PreTrainedTokenizer,
+                   max_turns: int = 12,
+                   max_seq_len: int = 64,
+                   device='cpu',
+                   resampled_size: int = None,
+                   train_ratio: float = 1.0,
+                   seed: int = 0) -> DataLoader:
+    """
+    Build a torch dataloader over a unified format dataset
+
+    Args:
+        dataset_name (str): Name of the dataset to load
+        set_type (str): Subset of the dataset to load (train, validation or test)
+        batch_size (int): Batch size for the dataloader
+        tokenizer (transformers tokenizer): Tokenizer for the encoder model used
+        max_turns (int): Maximum numbers of turns in a dialogue
+        max_seq_len (int): Maximum number of tokens in a dialogue turn
+        device (torch device): Device to map data to
+        resampled_size (int): Number of dialogues to sample
+        train_ratio (float): Ratio of training data to use for training
+        seed (int): Seed governing random order of ids for subsampling
+
+    Returns:
+        loader (torch dataloader): Dataloader to train and evaluate the setsumbt model
+    """
+    # Preprocess the dataset and map its features to the requested device
+    dataset = UnifiedFormatDataset(dataset_name, set_type, tokenizer, max_turns, max_seq_len,
+                                   train_ratio=train_ratio, seed=seed)
+    dataset.to(device)
+
+    # Optionally replace the dataset by a resampled set of dialogues
+    if resampled_size:
+        dataset = dataset.resample(resampled_size)
+
+    # Evaluation subsets are read in order, any other subset is shuffled
+    keep_order = set_type in ['test', 'validation']
+    sampler = SequentialSampler(dataset) if keep_order else RandomSampler(dataset)
+
+    return DataLoader(dataset, sampler=sampler, batch_size=batch_size)
+
+
+def change_batch_size(loader: DataLoader, batch_size: int) -> DataLoader:
+    """
+    Change the batch size of a preloaded loader
+
+    Args:
+        loader (DataLoader): Dataloader to train and evaluate the setsumbt model
+        batch_size (int): Batch size for the dataloader
+
+    Returns:
+        loader (DataLoader): New dataloader over the same dataset, keeping sequential or random sampling
+    """
+    # Check the sampler type directly rather than searching its string representation
+    if isinstance(loader.sampler, SequentialSampler):
+        sampler = SequentialSampler(loader.dataset)
+    else:
+        sampler = RandomSampler(loader.dataset)
+    loader = DataLoader(loader.dataset, sampler=sampler, batch_size=batch_size)
+
+    return loader
+
+def dataloader_sample_dialogues(loader: DataLoader, sample_size: int) -> DataLoader:
+    """
+    Sample a subset of the dialogues in a dataloader
+
+    Args:
+        loader (DataLoader): Dataloader whose dataset provides a resample method
+        sample_size (int): Number of dialogues to sample
+
+    Returns:
+        loader (DataLoader): New dataloader over the resampled dataset, keeping the batch size and sampler type
+    """
+    # The sampler and the new loader must both wrap the dataset object returned by resample
+    dataset = loader.dataset.resample(sample_size)
+
+    if isinstance(loader.sampler, SequentialSampler):
+        sampler = SequentialSampler(dataset)
+    else:
+        sampler = RandomSampler(dataset)
+    loader = DataLoader(dataset, sampler=sampler, batch_size=loader.batch_size)
+
+    return loader
diff --git a/convlab/dst/setsumbt/dataset/utils.py b/convlab/dst/setsumbt/dataset/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..96773d6b9b181925b3004e4971e440d9c7720bfb
--- /dev/null
+++ b/convlab/dst/setsumbt/dataset/utils.py
@@ -0,0 +1,442 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convlab3 Unified dataset data processing utilities"""
+
+import numpy
+import pdb
+
+from convlab.util import load_ontology, load_dst_data, load_nlu_data
+from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
+
+
+def get_ontology_slots(dataset_name: str) -> dict:
+    """
+    Collect the slots, slot descriptions and categorical slot values of one or more dataset ontologies.
+
+    Args:
+        dataset_name (str): Dataset name, the names of several datasets can be joined using '+'
+
+    Returns:
+        ontology_slots (dict): Description, possible values and source datasets of each slot in each domain
+    """
+    ontology_slots = {}
+    # Splitting a name without '+' yields a single element list, no special case is required
+    for name in dataset_name.split('+'):
+        ontology = load_ontology(name)
+        for domain in ontology['domains']:
+            # The booking and general domains are excluded from the ontology
+            if domain in ['booking', 'general']:
+                continue
+            domain_slots = ontology_slots.setdefault(DOMAINS_MAP.get(domain, domain.lower()), {})
+            for slot, slot_info in ontology['domains'][domain]['slots'].items():
+                if slot not in domain_slots:
+                    domain_slots[slot] = {'description': slot_info['description'],
+                                          'possible_values': [],
+                                          'dataset_names': []}
+                entry = domain_slots[slot]
+                if slot_info['is_categorical']:
+                    entry['possible_values'] += slot_info['possible_values']
+                entry['possible_values'] = list(set(entry['possible_values']))
+                entry['dataset_names'].append(name)
+
+    return ontology_slots
+
+
+def get_values_from_data(dataset: dict, data_split: str = "train") -> dict:
+    """
+    Collect the values observed for each slot in the dialogue states of a dataset.
+
+    Args:
+        dataset (dict): Dataset dictionary obtained using the load_dataset function
+        data_split (str): Dataset split: train/validation/test
+
+    Returns:
+        value_sets (dict): Cleaned dictionary containing the values observed for each domain and slot
+    """
+    # Load all subsets of the dataset (speaker='user')
+    data = load_dst_data(dataset, data_split='all', speaker='user')
+
+    # Remove unseen subsets from the data when building the training/validation ontology,
+    # for any other data_split all subsets are used
+    allowed_subsets = {'train': ['train'], 'validation': ['train', 'validation']}
+    if data_split in allowed_subsets:
+        data = {key: itm for key, itm in data.items() if key in allowed_subsets[data_split]}
+
+    value_sets = {}
+    for subset in data.values():
+        for turn in subset:
+            for domain, substate in turn['state'].items():
+                domain_name = DOMAINS_MAP.get(domain, domain.lower())
+                slot_values = value_sets.setdefault(domain_name, {})
+                for slot, value in substate.items():
+                    observed = slot_values.setdefault(slot, [])
+                    # Skip empty values and values which were already recorded for the slot
+                    if not value or value in observed:
+                        continue
+                    observed.append(value)
+
+    return clean_values(value_sets)
+
+
+def combine_value_sets(value_sets: list) -> dict:
+    """
+    Function to combine value sets extracted from different datasets
+
+    Args:
+        value_sets (list): List of value sets extracted using the get_values_from_data function
+
+    Returns:
+        value_set (dict): New dictionary containing possible values obtained from datasets (empty for no datasets)
+    """
+    # Copy the first value set so that merging does not modify the value sets passed in by the caller
+    value_set = {domain: {slot: list(values) for slot, values in domain_info.items()}
+                 for domain, domain_info in value_sets[0].items()} if value_sets else dict()
+    for _value_set in value_sets[1:]:
+        for domain, domain_info in _value_set.items():
+            for slot, possible_values in domain_info.items():
+                # Merge without duplicates, the order of the merged values is not defined
+                combined = value_set.setdefault(domain, dict())
+                merged = combined.get(slot, list()) + list(possible_values)
+                combined[slot] = list(set(merged))
+
+    return value_set
+
+
+def clean_values(value_sets: dict, value_map: dict = VALUE_MAP) -> dict:
+    """
+    Function to clean up the possible value sets extracted from the states in the dataset
+
+    Args:
+        value_sets (dict): Dictionary containing possible values obtained from dataset
+        value_map (dict): Label map to avoid duplication and typos in values
+
+    Returns:
+        clean_vals (dict): Cleaned Dictionary containing possible values obtained from dataset
+    """
+    clean_vals = {}
+    for domain, subset in value_sets.items():
+        clean_vals[domain] = {}
+        for slot, values in subset.items():
+            # Remove pipe separated values
+            values = list({val.split('|', 1)[0] for val in values})
+
+            # Map values using value_map
+            for old, new in value_map.items():
+                values = list({val.replace(old, new) for val in values})
+
+            # Remove empty and dontcare from possible value sets
+            values = [val for val in values if val not in ['', 'dontcare']]
+
+            # MultiWOZ specific value sets for quantity, time and boolean slots
+            if 'people' in slot or 'duration' in slot or 'stay' in slot:
+                values = list(QUANTITIES)  # Copy, the result must not share the module level list
+            elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
+                values = list(TIME)  # Copy, the result must not share the module level list
+            elif 'parking' in slot or 'internet' in slot:
+                values = ['yes', 'no']
+
+            clean_vals[domain][slot] = values
+
+    return clean_vals
+
+
+def ontology_add_values(ontology_slots: dict, value_sets: dict, data_split: str = "train") -> dict:
+    """
+    Fill in the value sets of the ontology slots using the values obtained from the dataset
+    Args:
+        ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
+        value_sets (dict): Cleaned Dictionary containing possible values obtained from dataset
+        data_split (str): Dataset split: train/validation/test
+
+    Returns:
+        ontology (dict): Ontology dictionary containing slots, slot descriptions and possible value sets
+    """
+    ontology = {}
+    for domain in sorted(ontology_slots):
+        domain_values = value_sets.get(domain, {})
+        # For train and validation, domains for which no values were observed are left out
+        if data_split in ['train', 'validation'] and not any(domain_values.values()):
+            continue
+
+        ontology[domain] = {}
+        for slot in sorted(ontology_slots[domain]):
+            slot_info = ontology_slots[domain][slot]
+            # Slots without ontology values fall back on the values observed in the data
+            if not slot_info['possible_values'] and slot in domain_values:
+                slot_info['possible_values'] = domain_values[slot]
+            # Every non-empty value set is sorted and prefixed by the none and do not care options
+            if slot_info['possible_values']:
+                values = sorted(slot_info['possible_values'])
+                slot_info['possible_values'] = ['none', 'do not care'] + values
+
+            ontology[domain][slot] = slot_info
+
+    return ontology
+
+
+def get_requestable_slots(datasets: list) -> dict:
+    """
+    Collect the set of requestable slots from the request actions in the dataset action labels.
+    Args:
+        datasets (list): List of dataset dictionaries obtained using the load_dataset function
+
+    Returns:
+        slots (dict): Dictionary containing requestable domain-slot pairs
+    """
+    loaded = [load_nlu_data(dataset, data_split='all', speaker='user') for dataset in datasets]
+    act_types = ['categorical', 'non-categorical', 'binary']
+
+    requested = {}
+    for data in loaded:
+        for subset in data.values():
+            for turn in subset:
+                for act_type in act_types:
+                    for act in turn['dialogue_acts'][act_type]:
+                        # Only request actions contribute requestable slots
+                        if act['intent'] != 'request':
+                            continue
+                        domain_name = DOMAINS_MAP.get(act['domain'], act['domain'].lower())
+                        requested.setdefault(domain_name, []).append(act['slot'])
+
+    # Remove duplicate slots within each domain
+    slots = {domain: list(set(slot_list)) for domain, slot_list in requested.items()}
+
+    return slots
+
+
+def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict) -> dict:
+    """
+    Mark the requestable slots in the ontology by adding the request value '?' to their value sets
+    Args:
+        ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
+        requestable_slots (dict): Dictionary containing requestable domain-slot pairs
+
+    Returns:
+        ontology_slots (dict): The ontology dictionary passed in, updated in place, where the possible
+        values of requestable slots include '?'
+    """
+    for domain, domain_slots in ontology_slots.items():
+        requested = requestable_slots.get(domain, ())
+        for slot, slot_info in domain_slots.items():
+            if slot in requested:
+                slot_info['possible_values'].append('?')
+
+    return ontology_slots
+
+
+def extract_turns(dialogue: list, dataset_name: str, dialogue_id: str) -> list:
+    """
+    Pair the system and user utterances of a dialogue, provided by the unified loader, into turns
+    Args:
+        dialogue (list): List of turns within a dialogue
+        dataset_name (str): Name of the dataset to which the dialogue belongs
+        dialogue_id (str): ID of the dialogue
+
+    Returns:
+        turns (list): List of turns, each containing a system utterance, a user utterance and the user labels
+    """
+    act_types = ['categorical', 'non-categorical', 'binary']
+    turns = []
+    current = {}
+    for utterance in dialogue:
+        speaker = utterance['speaker']
+        if speaker == 'system':
+            current['system_utterance'] = utterance['utterance']
+
+        # The system utterance of the turn which is open at utterance index 1 is always set to empty
+        # NOTE(review): the utterance at index 1 may itself be a system utterance, its text is then
+        # discarded here - confirm that this is the intended pairing of system and user utterances
+        if utterance['utt_idx'] == 1:
+            current['system_utterance'] = ''
+
+        if speaker == 'user':
+            current['user_utterance'] = utterance['utterance']
+            # Inform acts not required by model
+            acts = utterance['dialogue_acts']
+            current['dialogue_acts'] = [act for act_type in act_types for act in acts[act_type]
+                                        if act['intent'] not in ['inform']]
+            current['state'] = utterance['state']
+            current['dataset_name'] = dataset_name
+            current['dialogue_id'] = dialogue_id
+
+        # A turn is complete once both a system and a user utterance are available
+        if 'system_utterance' in current and 'user_utterance' in current:
+            turns.append(current)
+            current = {}
+
+    return turns
+
+
+def clean_states(turns: list) -> list:
+    """
+    Clean the state within each turn of a dialogue (cleaning values and mapping to options used in ontology)
+    Args:
+        turns (list): List of turns within a dialogue, the turns and their dialogue acts are updated in place
+
+    Returns:
+        clean_turns (list): List of turns within a dialogue
+    """
+    clean_turns = []
+    for turn in turns:
+        clean_state = {}
+        clean_acts = []
+        for act in turn['dialogue_acts']:
+            act['domain'] = DOMAINS_MAP.get(act['domain'], act['domain'].lower())
+            clean_acts.append(act)
+        for domain, subset in turn['state'].items():
+            domain_name = DOMAINS_MAP.get(domain, domain.lower())
+            clean_state[domain_name] = {}
+            for slot, value in subset.items():
+                # Remove pipe separated values
+                value = value.split('|', 1)[0]
+
+                # Map values using value_map
+                for old, new in VALUE_MAP.items():
+                    value = value.replace(old, new)
+
+                # Map dontcare to "do not care" and empty to 'none'
+                value = value.replace('dontcare', 'do not care')
+                value = value if value else 'none'
+
+                # Map quantity values to the integer quantity value
+                if 'people' in slot or 'duration' in slot or 'stay' in slot:
+                    try:
+                        if value not in ['do not care', 'none']:
+                            value = int(value)
+                            value = str(value) if value < 10 else QUANTITIES[-1]
+                    except ValueError:
+                        pass  # Quantities which are not integers are kept unchanged
+                # Map time values to the most appropriate value in the standard time set
+                elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
+                    try:
+                        if value not in ['do not care', 'none']:
+                            # Strip after/before from time value
+                            value = value.replace('after ', '').replace('before ', '')
+                            # Extract hours and minutes (am/pm first, 10pm would otherwise be read as HHMM)
+                            if 'pm' in value:
+                                # 12pm is midday, all other pm hours are shifted by 12
+                                h, m = int(value.replace('pm', '')) % 12 + 12, 0
+                            elif 'am' in value:
+                                # 12am is midnight
+                                h, m = int(value.replace('am', '')) % 12, 0
+                            elif ':' not in value and len(value) == 4:
+                                h, m = value[:2], value[2:]
+                            elif len(value) == 1:
+                                h, m = value, 0
+                            elif ':' in value:
+                                h, m = value.split(':')
+                            elif ';' in value:
+                                h, m = value.split(';')
+                            else:
+                                # Unknown format, h and m must not be reused from a previously cleaned value
+                                raise ValueError(value)
+                            h, m = int(h), int(m)
+                            # Map to closest 5 minutes
+                            if m % 5 != 0:
+                                m = round(m / 5) * 5
+                                if m == 60:
+                                    m = 0
+                                    h += 1
+                                if h >= 24:
+                                    h -= 24
+                            # Set in standard 24 hour format
+                            value = '%02i:%02i' % (h, m)
+                    except ValueError:
+                        pass  # Times in an unknown format are kept as is (after/before already stripped)
+                # Map boolean slots to yes/no value
+                elif 'parking' in slot or 'internet' in slot:
+                    if value not in ['do not care', 'none']:
+                        if value == 'free':
+                            value = 'yes'
+                        elif any(v in value.lower() for v in ['yes', 'no']):
+                            value = [v for v in ['yes', 'no'] if v in value.lower()][0]
+
+                clean_state[domain_name][slot] = value
+        turn['state'] = clean_state
+        turn['dialogue_acts'] = clean_acts
+        clean_turns.append(turn)
+
+    return clean_turns
+
+
+def get_active_domains(turns: list) -> list:
+    """
+    Add the list of active domains to each turn in a dialogue
+
+    A domain is active in a turn when one of its slots takes a value other than none, which at later turns
+    must also differ from the value in the previous turn, or when it is in the state and in a user action.
+    Args:
+        turns (list): List of turns within a dialogue
+
+    Returns:
+        turns (list): The list of turns passed in, an 'active_domains' list is added to each turn in place
+    """
+    previous_state = None
+    for turn in turns:
+        state = turn['state']
+        domains = []
+        # Domains activated by the values in the state
+        for domain, substate in state.items():
+            domain_name = DOMAINS_MAP.get(domain, domain.lower())
+            for slot, value in substate.items():
+                # At the first turn there is no previous state, all values count as changed
+                changed = previous_state is None or value != previous_state[domain][slot]
+                if changed and value != 'none':
+                    domains.append(domain_name)
+
+        # Add all domains activated by a user action
+        act_domains = [act['domain'] for act in turn['dialogue_acts']
+                       if act['domain'] in state]
+        # Only at the first turn are the domains of the user actions mapped using the domain map
+        if previous_state is None:
+            act_domains = [DOMAINS_MAP.get(domain, domain.lower()) for domain in act_domains]
+        turn['active_domains'] = list(set(domains + act_domains))
+        # The next turn is compared to the state of this turn
+        previous_state = state
+
+    return turns
+
+
+class IdTensor:  # Holds a sequence of values in a numpy array behind a small tensor-like interface
+    def __init__(self, values):
+        self.values = numpy.array(values)  # The values are stored as a numpy array
+
+    def __getitem__(self, index: int):
+        return self.values[index].tolist()  # Indexed entries are returned via numpy's tolist()
+
+    def to(self, device):
+        return self  # The device argument is ignored, the object itself is returned
+
+
+def extract_dialogues(data: list, dataset_name: str) -> list:
+    """
+    Extract the cleaned turns, including the active domains, of all dialogues in a dataset
+    Args:
+        data (list): List of all dialogues in a subset of the data
+        dataset_name (str): Name of the dataset to which the dialogues belongs
+
+    Returns:
+        dialogues (list): List of all extracted dialogues
+    """
+    def _process(dial: dict) -> list:
+        # Extract the turns, clean their states and then add the active domains
+        dial_id = dial['dialogue_id']
+        turns = extract_turns(dial['turns'], dataset_name, dial_id)
+        return get_active_domains(clean_states(turns))
+
+    dialogues = [_process(dial) for dial in data]
+
+    return dialogues
diff --git a/convlab/dst/setsumbt/dataset/value_maps.py b/convlab/dst/setsumbt/dataset/value_maps.py
new file mode 100644
index 0000000000000000000000000000000000000000..619600a7b0a57096918058ff117aa2ca5aac864a
--- /dev/null
+++ b/convlab/dst/setsumbt/dataset/value_maps.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convlab3 Unified dataset value maps"""
+
+
+# MultiWOZ specific label map to avoid duplication and typos in values (keys are replaced wherever they occur in a value)
+VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast',
+             'cityroomz': 'city roomz', '  ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott',
+             'worth house': 'the worth house', 'alesbray lodge guest house': 'aylesbray lodge',
+             'huntingdon hotel': 'huntingdon marriott hotel', 'huntingd': 'huntingdon marriott hotel',
+             'jamaicanchinese': 'chinese', 'barbequemodern european': 'modern european',
+             'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"}
+
+
+# Domain map for SGD and TM Data, maps the dataset specific domain names to lower case names shared across datasets
+DOMAINS_MAP = {'Alarm_1': 'alarm', 'Banks_1': 'banks', 'Banks_2': 'banks', 'Buses_1': 'bus', 'Buses_2': 'bus',
+               'Buses_3': 'bus', 'Calendar_1': 'calendar', 'Events_1': 'events', 'Events_2': 'events',
+               'Events_3': 'events', 'Flights_1': 'flights', 'Flights_2': 'flights', 'Flights_3': 'flights',
+               'Flights_4': 'flights', 'Homes_1': 'homes', 'Homes_2': 'homes', 'Hotels_1': 'hotel',
+               'Hotels_2': 'hotel', 'Hotels_3': 'hotel', 'Hotels_4': 'hotel', 'Media_1': 'media',
+               'Media_2': 'media', 'Media_3': 'media', 'Messaging_1': 'messaging', 'Movies_1': 'movies',
+               'Movies_2': 'movies', 'Movies_3': 'movies', 'Music_1': 'music', 'Music_2': 'music', 'Music_3': 'music',
+               'Payment_1': 'payment', 'RentalCars_1': 'rentalcars', 'RentalCars_2': 'rentalcars',
+               'RentalCars_3': 'rentalcars', 'Restaurants_1': 'restaurant', 'Restaurants_2': 'restaurant',
+               'RideSharing_1': 'ridesharing', 'RideSharing_2': 'ridesharing', 'Services_1': 'services',
+               'Services_2': 'services', 'Services_3': 'services', 'Services_4': 'services', 'Trains_1': 'train',
+               'Travel_1': 'travel', 'Weather_1': 'weather', 'movie_ticket': 'movies',
+               'restaurant_reservation': 'restaurant', 'coffee_ordering': 'coffee', 'pizza_ordering': 'takeout',
+               'auto_repair': 'car_repairs', 'flights': 'flights', 'food-ordering': 'takeout', 'hotels': 'hotel',
+               'movies': 'movies', 'music': 'music', 'restaurant-search': 'restaurant', 'sports': 'sports',
+               'movie': 'movies'}
+
+
+# Generic value sets for quantity and time slots
+QUANTITIES = [str(quantity) for quantity in range(1, 10)] + ['10 or more']
+# Times are listed per 5 minute offset for all 24 hours: 00:00, 01:00, ..., 23:00, 00:05, ...
+TIME = ['%02i:%02i' % (hour, minute) for minute in range(0, 60, 5) for hour in range(24)]
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/distillation_setup.py b/convlab/dst/setsumbt/distillation_setup.py
index e0d87bb964041cae19d58351cd6b31f6d836f125..2279e22265ea417ebe9a13e63837a625f858e73d 100644
--- a/convlab/dst/setsumbt/distillation_setup.py
+++ b/convlab/dst/setsumbt/distillation_setup.py
@@ -1,53 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Get ensemble predictions and build distillation dataloaders"""
+
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
 import os
+import json
 
 import torch
-import transformers
-from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
-from transformers import RobertaConfig, BertConfig
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from tqdm import tqdm
 
-import convlab
-from convlab.dst.setsumbt.multiwoz.dataset.multiwoz21 import EnsembleMultiWoz21
+from convlab.dst.setsumbt.dataset.unified_format import UnifiedFormatDataset, change_batch_size
 from convlab.dst.setsumbt.modeling import EnsembleSetSUMBT
+from convlab.dst.setsumbt.modeling import training
 
-DEVICE = 'cuda'
-
-
-def args_parser():
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--model_path', type=str)
-    parser.add_argument('--model_type', type=str)
-    parser.add_argument('--set_type', type=str)
-    parser.add_argument('--batch_size', type=int)
-    parser.add_argument('--ensemble_size', type=int)
-    parser.add_argument('--reduction', type=str, default='mean')
-    parser.add_argument('--get_ensemble_distributions', action='store_true')
-    parser.add_argument('--build_dataloaders', action='store_true')
-    
-    return parser.parse_args()
-
-
-def main():
-    args = args_parser()
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-    if args.get_ensemble_distributions:
-        get_ensemble_distributions(args)
-    elif args.build_dataloaders:
-        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
-        data = torch.load(path)
-        loader = get_loader(data, args.set_type, args.batch_size)
 
-        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
-        torch.save(loader, path)
-    else:
-        raise NameError("NotImplemented")
+def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size: int = 3) -> DataLoader:
+    """
+    Build dataloader from ensemble prediction data
 
+    Args:
+        data: Dictionary of ensemble predictions
+        ontology: Data ontology
+        set_type: Data subset (train/validation/test)
+        batch_size: Number of dialogues per batch
 
-def get_loader(data, set_type='train', batch_size=3):
+    Returns:
+        loader: Data loader object
+    """
     data = flatten_data(data)
     data = do_label_padding(data)
-    data = EnsembleMultiWoz21(data)
+    data = UnifiedFormatDataset.from_datadict(data, ontology)
     if set_type == 'train':
         sampler = RandomSampler(data)
     else:
@@ -57,7 +55,16 @@ def get_loader(data, set_type='train', batch_size=3):
     return loader
 
 
-def do_label_padding(data):
+def do_label_padding(data: dict) -> dict:
+    """
+    Add padding to the ensemble predictions (used as labels in distillation)
+
+    Args:
+        data: Dictionary of ensemble predictions
+
+    Returns:
+        data: Padded ensemble predictions
+    """
     if 'attention_mask' in data:
         dialogs, turns = torch.where(data['attention_mask'].sum(-1) == 0.0)
     else:
@@ -70,13 +77,17 @@ def do_label_padding(data):
     return data
 
 
-map_dict = {'belief_state': 'belief', 'greeting_act_belief': 'goodbye_belief',
-            'state_labels': 'labels', 'request_labels': 'request',
-            'domain_labels': 'active', 'greeting_labels': 'goodbye'}
-def flatten_data(data):
+def flatten_data(data: dict) -> dict:
+    """
+    Map data to flattened feature format used in training
+    Args:
+        data: Ensemble prediction data
+
+    Returns:
+        data: Flattened ensemble prediction data
+    """
     data_new = dict()
     for label, feats in data.items():
-        label = map_dict.get(label, label)
         if type(feats) == dict:
             for label_, feats_ in feats.items():
                 data_new[label + '-' + label_] = feats_
@@ -87,13 +98,11 @@ def flatten_data(data):
 
 
 def get_ensemble_distributions(args):
-    if args.model_type == 'roberta':
-        config = RobertaConfig
-    elif args.model_type == 'bert':
-        config = BertConfig
-    config = config.from_pretrained(args.model_path)
-    config.ensemble_size = args.ensemble_size
-
+    """
+    Load data and get ensemble predictions
+    Args:
+        args: Runtime arguments
+    """
     device = DEVICE
 
     model = EnsembleSetSUMBT.from_pretrained(args.model_path)
@@ -107,16 +116,10 @@ def get_ensemble_distributions(args):
     dataloader = torch.load(dataloader)
     database = torch.load(database)
 
-    # Get slot and value embeddings
-    slots = {slot: val for slot, val in database.items()}
-    values = {slot: val[1] for slot, val in database.items()}
-    del database
+    if dataloader.batch_size != args.batch_size:
+        dataloader = change_batch_size(dataloader, args.batch_size)
 
-    # Load model ontology
-    model.add_slot_candidates(slots)
-    for slot in model.informable_slot_ids:
-        model.add_value_candidates(slot, values[slot], replace=True)
-    del slots, values
+    training.set_ontology_embeddings(model, database)
 
     print('Environment set up.')
 
@@ -125,18 +128,24 @@ def get_ensemble_distributions(args):
     attention_mask = []
     state_labels = {slot: [] for slot in model.informable_slot_ids}
     request_labels = {slot: [] for slot in model.requestable_slot_ids}
-    domain_labels = {domain: [] for domain in model.domain_ids}
-    greeting_labels = []
+    active_domain_labels = {domain: [] for domain in model.domain_ids}
+    general_act_labels = []
+
+    is_noisy = [] if 'is_noisy' in dataloader.dataset.features else None
+
     belief_state = {slot: [] for slot in model.informable_slot_ids}
-    request_belief = {slot: [] for slot in model.requestable_slot_ids}
-    domain_belief = {domain: [] for domain in model.domain_ids}
-    greeting_act_belief = []
+    request_probs = {slot: [] for slot in model.requestable_slot_ids}
+    active_domain_probs = {domain: [] for domain in model.domain_ids}
+    general_act_probs = []
     model.eval()
     for batch in tqdm(dataloader, desc='Batch:'):
         ids = batch['input_ids']
         tt_ids = batch['token_type_ids'] if 'token_type_ids' in batch else None
         mask = batch['attention_mask'] if 'attention_mask' in batch else None
 
+        if 'is_noisy' in batch:
+            is_noisy.append(batch['is_noisy'])
+
         input_ids.append(ids)
         token_type_ids.append(tt_ids)
         attention_mask.append(mask)
@@ -146,61 +155,123 @@ def get_ensemble_distributions(args):
         mask = mask.to(device) if mask is not None else None
 
         for slot in state_labels:
-            state_labels[slot].append(batch['labels-' + slot])
-        if model.config.predict_intents:
+            state_labels[slot].append(batch['state_labels-' + slot])
+        if model.config.predict_actions:
             for slot in request_labels:
-                request_labels[slot].append(batch['request-' + slot])
-            for domain in domain_labels:
-                domain_labels[domain].append(batch['active-' + domain])
-            greeting_labels.append(batch['goodbye'])
+                request_labels[slot].append(batch['request_labels-' + slot])
+            for domain in active_domain_labels:
+                active_domain_labels[domain].append(batch['active_domain_labels-' + domain])
+            general_act_labels.append(batch['general_act_labels'])
 
         with torch.no_grad():
-            p, p_req, p_dom, p_bye, _ = model(ids, mask, tt_ids,
-                                            reduction=args.reduction)
+            p, p_req, p_dom, p_gen, _ = model(ids, mask, tt_ids, reduction=args.reduction)
 
         for slot in belief_state:
             belief_state[slot].append(p[slot].cpu())
-        if model.config.predict_intents:
-            for slot in request_belief:
-                request_belief[slot].append(p_req[slot].cpu())
-            for domain in domain_belief:
-                domain_belief[domain].append(p_dom[domain].cpu())
-            greeting_act_belief.append(p_bye.cpu())
+        if model.config.predict_actions:
+            for slot in request_probs:
+                request_probs[slot].append(p_req[slot].cpu())
+            for domain in active_domain_probs:
+                active_domain_probs[domain].append(p_dom[domain].cpu())
+            general_act_probs.append(p_gen.cpu())
     
     input_ids = torch.cat(input_ids, 0) if input_ids[0] is not None else None
     token_type_ids = torch.cat(token_type_ids, 0) if token_type_ids[0] is not None else None
     attention_mask = torch.cat(attention_mask, 0) if attention_mask[0] is not None else None
+    is_noisy = torch.cat(is_noisy, 0) if is_noisy is not None else None
 
     state_labels = {slot: torch.cat(l, 0) for slot, l in state_labels.items()}
-    if model.config.predict_intents:
+    if model.config.predict_actions:
         request_labels = {slot: torch.cat(l, 0) for slot, l in request_labels.items()}
-        domain_labels = {domain: torch.cat(l, 0) for domain, l in domain_labels.items()}
-        greeting_labels = torch.cat(greeting_labels, 0)
+        active_domain_labels = {domain: torch.cat(l, 0) for domain, l in active_domain_labels.items()}
+        general_act_labels = torch.cat(general_act_labels, 0)
     
     belief_state = {slot: torch.cat(p, 0) for slot, p in belief_state.items()}
-    if model.config.predict_intents:
-        request_belief = {slot: torch.cat(p, 0) for slot, p in request_belief.items()}
-        domain_belief = {domain: torch.cat(p, 0) for domain, p in domain_belief.items()}
-        greeting_act_belief = torch.cat(greeting_act_belief, 0)
+    if model.config.predict_actions:
+        request_probs = {slot: torch.cat(p, 0) for slot, p in request_probs.items()}
+        active_domain_probs = {domain: torch.cat(p, 0) for domain, p in active_domain_probs.items()}
+        general_act_probs = torch.cat(general_act_probs, 0)
 
     data = {'input_ids': input_ids}
     if token_type_ids is not None:
         data['token_type_ids'] = token_type_ids
     if attention_mask is not None:
         data['attention_mask'] = attention_mask
+    if is_noisy is not None:
+        data['is_noisy'] = is_noisy
     data['state_labels'] = state_labels
     data['belief_state'] = belief_state
-    if model.config.predict_intents:
+    if model.config.predict_actions:
         data['request_labels'] = request_labels
-        data['domain_labels'] = domain_labels
-        data['greeting_labels'] = greeting_labels
-        data['request_belief'] = request_belief
-        data['domain_belief'] = domain_belief
-        data['greeting_act_belief'] = greeting_act_belief
+        data['active_domain_labels'] = active_domain_labels
+        data['general_act_labels'] = general_act_labels
+        data['request_probs'] = request_probs
+        data['active_domain_probs'] = active_domain_probs
+        data['general_act_probs'] = general_act_probs
 
     file = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
     torch.save(data, file)
 
 
+def ensemble_distribution_data_to_predictions_format(model_path: str, set_type: str):
+    """
+    Convert ensemble predictions to predictions file format.
+
+    Args:
+        model_path: Path to ensemble location.
+        set_type: Evaluation dataset (train/dev/test).
+    """
+    data = torch.load(os.path.join(model_path, 'dataloaders', f"{set_type}.data"))
+
+    # Get oracle labels
+    if 'request_probs' in data:
+        data_new = {'state_labels': data['state_labels'],
+                    'request_labels': data['request_labels'],
+                    'active_domain_labels': data['active_domain_labels'],
+                    'general_act_labels': data['general_act_labels']}
+    else:
+        data_new = {'state_labels': data['state_labels']}
+
+    # Marginalising across ensemble distributions
+    data_new['belief_states'] = {slot: distribution.mean(-2) for slot, distribution in data['belief_state'].items()}
+    if 'request_probs' in data:
+        data_new['request_probs'] = {slot: distribution.mean(-1)
+                                     for slot, distribution in data['request_probs'].items()}
+        data_new['active_domain_probs'] = {domain: distribution.mean(-1)
+                                           for domain, distribution in data['active_domain_probs'].items()}
+        data_new['general_act_probs'] = data['general_act_probs'].mean(-2)
+
+    # Save predictions file
+    predictions_dir = os.path.join(model_path, 'predictions')
+    if not os.path.exists(predictions_dir):
+        os.mkdir(predictions_dir)
+    torch.save(data_new, os.path.join(predictions_dir, f"{set_type}.predictions"))
+
+
 if __name__ == "__main__":
-    main()
+    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--model_path', type=str)
+    parser.add_argument('--set_type', type=str)
+    parser.add_argument('--batch_size', type=int, default=3)
+    parser.add_argument('--reduction', type=str, default='none')
+    parser.add_argument('--get_ensemble_distributions', action='store_true')
+    parser.add_argument('--convert_distributions_to_predictions', action='store_true')
+    parser.add_argument('--build_dataloaders', action='store_true')
+    args = parser.parse_args()
+
+    if args.get_ensemble_distributions:
+        get_ensemble_distributions(args)
+    if args.convert_distributions_to_predictions:
+        ensemble_distribution_data_to_predictions_format(args.model_path, args.set_type)
+    if args.build_dataloaders:
+        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
+        data = torch.load(path)
+
+        reader = open(os.path.join(args.model_path, 'database', f'{args.set_type}.json'), 'r')
+        ontology = json.load(reader)
+        reader.close()
+
+        loader = get_loader(data, ontology, args.set_type, args.batch_size)
+
+        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
+        torch.save(loader, path)
diff --git a/convlab/dst/setsumbt/do/calibration.py b/convlab/dst/setsumbt/do/calibration.py
deleted file mode 100644
index 27ee058eca882ce7e10937f9640d143b88e57f5e..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/do/calibration.py
+++ /dev/null
@@ -1,481 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Run SetSUMBT Calibration"""
-
-import logging
-import random
-import os
-from shutil import copy2 as copy
-
-import torch
-from transformers import (BertModel, BertConfig, BertTokenizer,
-                          RobertaModel, RobertaConfig, RobertaTokenizer,
-                          AdamW, get_linear_schedule_with_warmup)
-from tqdm import tqdm, trange
-from tensorboardX import SummaryWriter
-from torch.distributions import Categorical
-
-from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
-from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
-from convlab.dst.setsumbt.multiwoz import multiwoz21
-from convlab.dst.setsumbt.multiwoz import ontology as embeddings
-from convlab.dst.setsumbt.utils import get_args, upload_local_directory_to_gcs, update_args
-from convlab.dst.setsumbt.modeling import calibration_utils
-from convlab.dst.setsumbt.modeling import ensemble_utils
-from convlab.dst.setsumbt.loss.ece import ece, jg_ece, l2_acc
-
-
-# Datasets
-DATASETS = {
-    'multiwoz21': multiwoz21
-}
-
-MODELS = {
-    'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
-    'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
-}
-
-
-def main(args=None, config=None):
-    # Get arguments
-    if args is None:
-        args, config = get_args(MODELS)
-
-    # Select Dataset object
-    if args.dataset in DATASETS:
-        Dataset = DATASETS[args.dataset]
-    else:
-        raise NameError('NotImplemented')
-
-    if args.model_type in MODELS:
-        SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
-    else:
-        raise NameError('NotImplemented')
-
-    # Set up output directory
-    OUTPUT_DIR = args.output_dir
-    if not os.path.exists(OUTPUT_DIR):
-        os.mkdir(OUTPUT_DIR)
-    args.output_dir = OUTPUT_DIR
-    if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
-        os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
-
-    paths = os.listdir(args.output_dir) if os.path.exists(
-        args.output_dir) else []
-    if 'pytorch_model.bin' in paths and 'config.json' in paths:
-        args.model_name_or_path = args.output_dir
-        config = ConfigClass.from_pretrained(args.model_name_or_path)
-    else:
-        paths = os.listdir(args.output_dir) if os.path.exists(
-            args.output_dir) else []
-        paths = [os.path.join(args.output_dir, p)
-                 for p in paths if 'checkpoint-' in p]
-        if paths:
-            paths = paths[0]
-            args.model_name_or_path = paths
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
-
-    if args.ensemble_size > 0:
-        paths = os.listdir(args.output_dir) if os.path.exists(
-            args.output_dir) else []
-        paths = [os.path.join(args.output_dir, p)
-                 for p in paths if 'ensemble_' in p]
-        if paths:
-            args.model_name_or_path = args.output_dir
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
-
-    args = update_args(args, config)
-
-    # Set up data directory
-    DATA_DIR = args.data_dir
-    Dataset.set_datadir(DATA_DIR)
-    embeddings.set_datadir(DATA_DIR)
-
-    if args.shrink_active_domains and args.dataset == 'multiwoz21':
-        Dataset.set_active_domains(
-            ['attraction', 'hotel', 'restaurant', 'taxi', 'train'])
-
-    # Download and preprocess
-    Dataset.create_examples(
-        args.max_turn_len, args.predict_intents, args.force_processing)
-
-    # Create logger
-    global logger
-    logger = logging.getLogger(__name__)
-    logger.setLevel(logging.INFO)
-
-    formatter = logging.Formatter(
-        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-
-    if 'stream' not in args.logging_path:
-        fh = logging.FileHandler(args.logging_path)
-        fh.setLevel(logging.INFO)
-        fh.setFormatter(formatter)
-        logger.addHandler(fh)
-    else:
-        ch = logging.StreamHandler()
-        ch.setLevel(level=logging.INFO)
-        ch.setFormatter(formatter)
-        logger.addHandler(ch)
-
-    # Get device
-    if torch.cuda.is_available() and args.n_gpu > 0:
-        device = torch.device('cuda')
-    else:
-        device = torch.device('cpu')
-        args.n_gpu = 0
-
-    if args.n_gpu == 0:
-        args.fp16 = False
-
-    # Set up model training/evaluation
-    calibration.set_logger(logger, None)
-    calibration.set_seed(args)
-
-    if args.ensemble_size > 0:
-        ensemble.set_logger(logger, tb_writer)
-        ensemble_utils.set_seed(args)
-
-    # Perform tasks
-
-    if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')):
-        pred = torch.load(os.path.join(
-            OUTPUT_DIR, 'predictions', 'test.predictions'))
-        labels = pred['labels']
-        belief_states = pred['belief_states']
-        if 'request_labels' in pred:
-            request_labels = pred['request_labels']
-            request_belief = pred['request_belief']
-            domain_labels = pred['domain_labels']
-            domain_belief = pred['domain_belief']
-            greeting_labels = pred['greeting_labels']
-            greeting_belief = pred['greeting_belief']
-        else:
-            request_belief = None
-        del pred
-    elif args.ensemble_size > 0:
-        # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-        else:
-            # Create Tokenizer and embedding model for Data Loaders and ontology
-            encoder = CandidateEncoderModel.from_pretrained(
-                config.candidate_embedding_model_name)
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            embeddings.get_slot_candidate_embeddings(
-                'test', args, tokenizer, encoder)
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-
-        exists = False
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size == args.test_batch_size:
-                exists = True
-        if not exists:
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                     config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
-        config, models = ensemble.get_models(
-            args.model_name_or_path, device, ConfigClass, SetSumbtModel)
-
-        belief_states, labels = ensemble_utils.get_predictions(
-            args, models, device, test_dataloader, test_slots)
-        torch.save({'belief_states': belief_states, 'labels': labels},
-                   os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
-    else:
-        # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-        else:
-            # Create Tokenizer and embedding model for Data Loaders and ontology
-            encoder = CandidateEncoderModel.from_pretrained(
-                config.candidate_embedding_model_name)
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            embeddings.get_slot_candidate_embeddings(
-                'test', args, tokenizer, encoder)
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-
-        exists = False
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size == args.test_batch_size:
-                exists = True
-        if not exists:
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                     config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
-        # Initialise Model
-        model = SetSumbtModel.from_pretrained(
-            args.model_name_or_path, config=config)
-        model = model.to(device)
-
-        # Get slot and value embeddings
-        slots = {slot: test_slots[slot] for slot in test_slots}
-        values = {slot: test_slots[slot][1] for slot in test_slots}
-
-        # Load model ontology
-        model.add_slot_candidates(slots)
-        for slot in model.informable_slot_ids:
-            model.add_value_candidates(slot, values[slot], replace=True)
-
-        belief_states = calibration.get_predictions(
-            args, model, device, test_dataloader)
-        belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = belief_states
-        out = {'belief_states': belief_states, 'labels': labels,
-               'request_belief': request_belief, 'request_labels': request_labels,
-               'domain_belief': domain_belief, 'domain_labels': domain_labels,
-               'greeting_belief': greeting_belief, 'greeting_labels': greeting_labels}
-        torch.save(out, os.path.join(
-            OUTPUT_DIR, 'predictions', 'test.predictions'))
-
-    # err = [ece(belief_states[slot].reshape(-1, belief_states[slot].size(-1)), labels[slot].reshape(-1), 10)
-    #         for slot in belief_states]
-    # err = max(err)
-    # logger.info('ECE: %f' % err)
-
-    # Calculate calibration metrics
-
-    jg = jg_ece(belief_states, labels, 10)
-    logger.info('Joint Goal ECE: %f' % jg)
-
-    binary_states = {}
-    for slot, p in belief_states.items():
-        shp = p.shape
-        p = p.reshape(-1, p.size(-1))
-        p_ = torch.ones(p.shape).to(p.device) * 1e-8
-        p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
-        binary_states[slot] = p_.reshape(shp)
-    jg = jg_ece(binary_states, labels, 10)
-    logger.info('Joint Goal Binary ECE: %f' % jg)
-
-    bs = {slot: torch.cat((p[:, :, 0].unsqueeze(-1), p[:, :, 1:].max(-1)
-                          [0].unsqueeze(-1)), -1) for slot, p in belief_states.items()}
-    ls = {}
-    for slot, l in labels.items():
-        y = torch.zeros((l.size(0), l.size(1))).to(l.device)
-        dials, turns = torch.where(l > 0)
-        y[dials, turns] = 1.0
-        dials, turns = torch.where(l < 0)
-        y[dials, turns] = -1.0
-        ls[slot] = y
-
-    jg = jg_ece(bs, ls, 10)
-    logger.info('Slot presence ECE: %f' % jg)
-
-    binary_states = {}
-    for slot, p in bs.items():
-        shp = p.shape
-        p = p.reshape(-1, p.size(-1))
-        p_ = torch.ones(p.shape).to(p.device) * 1e-8
-        p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
-        binary_states[slot] = p_.reshape(shp)
-    jg = jg_ece(binary_states, ls, 10)
-    logger.info('Slot presence Binary ECE: %f' % jg)
-
-    jg_acc = 0.0
-    padding = torch.cat([item.unsqueeze(-1)
-                        for _, item in labels.items()], -1).sum(-1) * -1.0
-    padding = (padding == len(labels))
-    padding = padding.reshape(-1)
-    for slot in belief_states:
-        topn = args.accuracy_topn
-        p_ = belief_states[slot]
-        gold = labels[slot]
-
-        if p_.size(-1) <= topn:
-            topn = p_.size(-1) - 1
-        if topn <= 0:
-            topn = 1
-
-        if topn > 1:
-            labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True)
-            labs = labs[:, :topn]
-        else:
-            labs = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1)
-        acc = [lab in s for lab, s, pad in zip(
-            gold.reshape(-1), labs, padding) if not pad]
-        acc = torch.tensor(acc).float()
-
-        jg_acc += acc
-
-    n_turns = jg_acc.size(0)
-    sl_acc = sum(jg_acc / len(belief_states)).float()
-    jg_acc = sum((jg_acc / len(belief_states)).int()).float()
-
-    sl_acc /= n_turns
-    jg_acc /= n_turns
-
-    logger.info('Joint Goal Accuracy: %f, Slot Accuracy %f' % (jg_acc, sl_acc))
-
-    l2 = l2_acc(belief_states, labels, remove_belief=False)
-    logger.info(f'Model L2 Norm Goal Accuracy: {l2}')
-    l2 = l2_acc(belief_states, labels, remove_belief=True)
-    logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}')
-
-    for slot in belief_states:
-        p = belief_states[slot]
-        p = p.reshape(-1, p.size(-1))
-        p = torch.cat(
-            (p[:, 0].unsqueeze(-1), p[:, 1:].max(-1)[0].unsqueeze(-1)), -1)
-        belief_states[slot] = p
-
-        l = labels[slot].reshape(-1)
-        l[l > 0] = 1
-        labels[slot] = l
-
-    f1 = 0.0
-    for slot in belief_states:
-        prd = belief_states[slot].argmax(-1)
-        tp = ((prd == 1) * (labels[slot] == 1)).sum()
-        fp = ((prd == 1) * (labels[slot] == 0)).sum()
-        fn = ((prd == 0) * (labels[slot] == 1)).sum()
-        if tp > 0:
-            f1 += tp / (tp + 0.5 * (fp + fn))
-    f1 /= len(belief_states)
-    logger.info(f'Trucated Goal F1 Score: {f1}')
-
-    l2 = l2_acc(belief_states, labels, remove_belief=False)
-    logger.info(f'Model L2 Norm Trucated Goal Accuracy: {l2}')
-    l2 = l2_acc(belief_states, labels, remove_belief=True)
-    logger.info(f'Binary Model L2 Norm Trucated Goal Accuracy: {l2}')
-
-    if request_belief is not None:
-        tp, fp, fn = 0.0, 0.0, 0.0
-        for slot in request_belief:
-            p = request_belief[slot]
-            l = request_labels[slot]
-
-            tp += (p.round().int() * (l == 1)).reshape(-1).float()
-            fp += (p.round().int() * (l == 0)).reshape(-1).float()
-            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
-        tp /= len(request_belief)
-        fp /= len(request_belief)
-        fn /= len(request_belief)
-        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
-        logger.info('Request F1 Score: %f' % f1.item())
-
-        for slot in request_belief:
-            p = request_belief[slot]
-            p = p.unsqueeze(-1)
-            p = torch.cat((1 - p, p), -1)
-            request_belief[slot] = p
-        jg = jg_ece(request_belief, request_labels, 10)
-        logger.info('Request Joint Goal ECE: %f' % jg)
-
-        binary_states = {}
-        for slot, p in request_belief.items():
-            shp = p.shape
-            p = p.reshape(-1, p.size(-1))
-            p_ = torch.ones(p.shape).to(p.device) * 1e-8
-            p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
-            binary_states[slot] = p_.reshape(shp)
-        jg = jg_ece(binary_states, request_labels, 10)
-        logger.info('Request Joint Goal Binary ECE: %f' % jg)
-
-        tp, fp, fn = 0.0, 0.0, 0.0
-        for dom in domain_belief:
-            p = domain_belief[dom]
-            l = domain_labels[dom]
-
-            tp += (p.round().int() * (l == 1)).reshape(-1).float()
-            fp += (p.round().int() * (l == 0)).reshape(-1).float()
-            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
-        tp /= len(domain_belief)
-        fp /= len(domain_belief)
-        fn /= len(domain_belief)
-        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
-        logger.info('Domain F1 Score: %f' % f1.item())
-
-        for dom in domain_belief:
-            p = domain_belief[dom]
-            p = p.unsqueeze(-1)
-            p = torch.cat((1 - p, p), -1)
-            domain_belief[dom] = p
-        jg = jg_ece(domain_belief, domain_labels, 10)
-        logger.info('Domain Joint Goal ECE: %f' % jg)
-
-        binary_states = {}
-        for slot, p in domain_belief.items():
-            shp = p.shape
-            p = p.reshape(-1, p.size(-1))
-            p_ = torch.ones(p.shape).to(p.device) * 1e-8
-            p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8
-            binary_states[slot] = p_.reshape(shp)
-        jg = jg_ece(binary_states, domain_labels, 10)
-        logger.info('Domain Joint Goal Binary ECE: %f' % jg)
-
-        tp = ((greeting_belief.argmax(-1) > 0) *
-              (greeting_labels > 0)).reshape(-1).float().sum()
-        fp = ((greeting_belief.argmax(-1) > 0) *
-              (greeting_labels == 0)).reshape(-1).float().sum()
-        fn = ((greeting_belief.argmax(-1) == 0) *
-              (greeting_labels > 0)).reshape(-1).float().sum()
-        f1 = tp / (tp + 0.5 * (fp + fn))
-        logger.info('Greeting F1 Score: %f' % f1.item())
-
-        err = ece(greeting_belief.reshape(-1, greeting_belief.size(-1)),
-                  greeting_labels.reshape(-1), 10)
-        logger.info('Greetings ECE: %f' % err)
-
-        greeting_belief = greeting_belief.reshape(-1, greeting_belief.size(-1))
-        binary_states = torch.ones(greeting_belief.shape).to(
-            greeting_belief.device) * 1e-8
-        binary_states[range(greeting_belief.size(0)),
-                      greeting_belief.argmax(-1)] = 1.0 - 1e-8
-        err = ece(binary_states, greeting_labels.reshape(-1), 10)
-        logger.info('Greetings Binary ECE: %f' % err)
-
-        for slot in request_belief:
-            p = request_belief[slot].unsqueeze(-1)
-            request_belief[slot] = torch.cat((1 - p, p), -1)
-
-        l2 = l2_acc(request_belief, request_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm Request Accuracy: {l2}')
-        l2 = l2_acc(request_belief, request_labels, remove_belief=True)
-        logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}')
-
-        for slot in domain_belief:
-            p = domain_belief[slot].unsqueeze(-1)
-            domain_belief[slot] = torch.cat((1 - p, p), -1)
-
-        l2 = l2_acc(domain_belief, domain_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm Domain Accuracy: {l2}')
-        l2 = l2_acc(domain_belief, domain_labels, remove_belief=True)
-        logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}')
-
-        greeting_labels = {'bye': greeting_labels}
-        greeting_belief = {'bye': greeting_belief}
-
-        l2 = l2_acc(greeting_belief, greeting_labels, remove_belief=False)
-        logger.info(f'Model L2 Norm Greeting Accuracy: {l2}')
-        l2 = l2_acc(greeting_belief, greeting_labels, remove_belief=False)
-        logger.info(f'Binary Model L2 Norm Greeting Accuracy: {l2}')
-
-
-if __name__ == "__main__":
-    main()
diff --git a/convlab/dst/setsumbt/do/evaluate.py b/convlab/dst/setsumbt/do/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fe351b3d5c2af187da58ffcc46e8184013bbcdb
--- /dev/null
+++ b/convlab/dst/setsumbt/do/evaluate.py
@@ -0,0 +1,296 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Run SetSUMBT Calibration"""
+
+import logging
+import os
+
+import torch
+from transformers import (BertModel, BertConfig, BertTokenizer,
+                          RobertaModel, RobertaConfig, RobertaTokenizer)
+
+from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
+from convlab.dst.setsumbt.dataset import unified_format
+from convlab.dst.setsumbt.dataset import ontology as embeddings
+from convlab.dst.setsumbt.utils import get_args, update_args
+from convlab.dst.setsumbt.modeling import evaluation_utils
+from convlab.dst.setsumbt.loss.uncertainty_measures import ece, jg_ece, l2_acc
+from convlab.dst.setsumbt.modeling import training
+
+
+# Available models: model type -> (SetSUMBT model class, candidate encoder class, config class, tokenizer class)
+MODELS = {
+    'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
+    'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
+}
+
+
+def main(args=None, config=None):
+    """Evaluate a trained SetSUMBT model on the test split, logging accuracy and calibration metrics.
+
+    If args is None, args and config are obtained from get_args(MODELS). When a trained model or a
+    checkpoint is found in args.output_dir, config is replaced by that model's configuration.
+    Predictions are cached in <output_dir>/predictions/test.predictions and reused on later runs.
+    """
+    if args is None:
+        args, config = get_args(MODELS)
+
+    if args.model_type in MODELS:
+        SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
+    else:
+        raise NameError('NotImplemented')
+
+    # Set up output directory
+    OUTPUT_DIR = args.output_dir
+    args.output_dir = OUTPUT_DIR
+    if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
+        os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
+
+    # Set pretrained model path to the trained checkpoint
+    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
+    if 'pytorch_model.bin' in paths and 'config.json' in paths:
+        args.model_name_or_path = args.output_dir
+        config = ConfigClass.from_pretrained(args.model_name_or_path)
+    else:
+        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
+        if paths:
+            paths = paths[0]  # first checkpoint directory in os.listdir order
+            args.model_name_or_path = paths
+            config = ConfigClass.from_pretrained(args.model_name_or_path)
+
+    args = update_args(args, config)
+
+    # Create logger
+    global logger
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+
+    formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')
+
+    fh = logging.FileHandler(args.logging_path)
+    fh.setLevel(logging.INFO)
+    fh.setFormatter(formatter)
+    logger.addHandler(fh)
+
+    # Get device
+    if torch.cuda.is_available() and args.n_gpu > 0:
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+        args.n_gpu = 0
+
+    if args.n_gpu == 0:
+        args.fp16 = False
+
+    # Set up model training/evaluation
+    evaluation_utils.set_seed(args)
+
+    # Perform tasks
+    if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')):
+        pred = torch.load(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
+        state_labels = pred['state_labels']
+        belief_states = pred['belief_states']
+        if 'request_labels' in pred:
+            request_labels = pred['request_labels']
+            request_probs = pred['request_probs']
+            active_domain_labels = pred['active_domain_labels']
+            active_domain_probs = pred['active_domain_probs']
+            general_act_labels = pred['general_act_labels']
+            general_act_probs = pred['general_act_probs']
+        else:
+            request_probs = None  # disables the request/domain/general act metrics below
+        del pred
+    else:
+        tokenizer = None  # built lazily: needed only when the dataloader or the slot embeddings are not cached
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
+            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+            if test_dataloader.batch_size != args.test_batch_size:
+                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
+        else:
+            tokenizer = Tokenizer.from_pretrained(config.candidate_embedding_model_name)
+            test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
+                                                            args.test_batch_size, tokenizer, args.max_dialogue_len,
+                                                            config.max_turn_len)
+            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
+            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
+        else:
+            if tokenizer is None:  # dataloader came from the cache, so no tokenizer has been created yet
+                tokenizer = Tokenizer.from_pretrained(config.candidate_embedding_model_name)
+            encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
+            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
+                                                                  'test', args, tokenizer, encoder)
+
+        # Initialise Model
+        model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config).to(device)
+
+        training.set_ontology_embeddings(model, test_slots)
+
+        belief_states = evaluation_utils.get_predictions(args, model, device, test_dataloader)
+        state_labels = belief_states[1]
+        request_probs = belief_states[2]
+        request_labels = belief_states[3]
+        active_domain_probs = belief_states[4]
+        active_domain_labels = belief_states[5]
+        general_act_probs = belief_states[6]
+        general_act_labels = belief_states[7]
+        belief_states = belief_states[0]  # unpacked last: the name held the whole result tuple until here
+        out = {'belief_states': belief_states, 'state_labels': state_labels, 'request_probs': request_probs,
+               'request_labels': request_labels, 'active_domain_probs': active_domain_probs,
+               'active_domain_labels': active_domain_labels, 'general_act_probs': general_act_probs,
+               'general_act_labels': general_act_labels}
+        torch.save(out, os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
+
+    # Calculate calibration metrics
+    jg = jg_ece(belief_states, state_labels, 10)
+    logger.info('Joint Goal ECE: %f' % jg)
+
+    jg_acc = 0.0
+    padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0
+    padding = (padding == len(state_labels))  # True where the labels of all slots sum to -len(state_labels)
+    padding = padding.reshape(-1)
+    for slot in belief_states:
+        p_ = belief_states[slot]
+        gold = state_labels[slot]
+
+        pred = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1)
+        acc = [lab in s for lab, s, pad in zip(gold.reshape(-1), pred, padding) if not pad]
+        acc = torch.tensor(acc).float()
+
+        jg_acc += acc
+
+    n_turns = jg_acc.size(0)
+    jg_acc = sum((jg_acc / len(belief_states)).int()).float()  # int() keeps only turns where every slot is correct
+
+    jg_acc /= n_turns
+
+    logger.info(f'Joint Goal Accuracy: {jg_acc}')
+
+    l2 = l2_acc(belief_states, state_labels, remove_belief=False)
+    logger.info(f'Model L2 Norm Goal Accuracy: {l2}')
+    l2 = l2_acc(belief_states, state_labels, remove_belief=True)
+    logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}')
+
+    # The padding mask computed for the joint goal accuracy above is reused: state_labels has not changed
+
+    tp, fp, fn, tn, n = 0.0, 0.0, 0.0, 0.0, 0.0
+    for slot in belief_states:
+        p_ = belief_states[slot]
+        gold = state_labels[slot].reshape(-1)
+        p_ = p_.reshape(-1, p_.size(-1))
+
+        p_ = p_[~padding].argmax(-1)
+        gold = gold[~padding]
+
+        tp += (p_ == gold)[gold != 0].int().sum().item()
+        fp += (p_ != 0)[gold == 0].int().sum().item()
+        fp += (p_ != gold)[gold != 0].int().sum().item()
+        fp -= (p_ == 0)[gold != 0].int().sum().item()
+        fn += (p_ == 0)[gold != 0].int().sum().item()
+        tn += (p_ == 0)[gold == 0].int().sum().item()
+        n += p_.size(0)
+
+    acc = (tp + tn) / n if n else 0.0  # zero-denominator guards: report 0.0 instead of raising
+    prec = tp / (tp + fp) if (tp + fp) else 0.0
+    rec = tp / (tp + fn) if (tp + fn) else 0.0
+    f1 = 2 * (prec * rec) / (prec + rec) if (prec + rec) else 0.0
+
+    logger.info(f"Slot Accuracy: {acc}, Slot F1: {f1}, Slot Precision: {prec}, Slot Recall: {rec}")
+
+    if request_probs is not None:
+        tp, fp, fn = 0.0, 0.0, 0.0
+        for slot in request_probs:
+            p = request_probs[slot]
+            l = request_labels[slot]
+
+            tp += (p.round().int() * (l == 1)).reshape(-1).float()
+            fp += (p.round().int() * (l == 0)).reshape(-1).float()
+            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
+        tp /= len(request_probs)
+        fp /= len(request_probs)
+        fn /= len(request_probs)
+        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
+        logger.info('Request F1 Score: %f' % f1.item())
+
+        for slot in request_probs:
+            p = request_probs[slot]
+            p = p.unsqueeze(-1)
+            p = torch.cat((1 - p, p), -1)  # stack (1 - p, p) along a new last dimension
+            request_probs[slot] = p
+        jg = jg_ece(request_probs, request_labels, 10)
+        logger.info('Request Joint Goal ECE: %f' % jg)
+
+        tp, fp, fn = 0.0, 0.0, 0.0
+        for dom in active_domain_probs:
+            p = active_domain_probs[dom]
+            l = active_domain_labels[dom]
+
+            tp += (p.round().int() * (l == 1)).reshape(-1).float()
+            fp += (p.round().int() * (l == 0)).reshape(-1).float()
+            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
+        tp /= len(active_domain_probs)
+        fp /= len(active_domain_probs)
+        fn /= len(active_domain_probs)
+        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
+        logger.info('Domain F1 Score: %f' % f1.item())
+
+        for dom in active_domain_probs:
+            p = active_domain_probs[dom]
+            p = p.unsqueeze(-1)
+            p = torch.cat((1 - p, p), -1)
+            active_domain_probs[dom] = p
+        jg = jg_ece(active_domain_probs, active_domain_labels, 10)
+        logger.info('Domain Joint Goal ECE: %f' % jg)
+
+        tp = ((general_act_probs.argmax(-1) > 0) * (general_act_labels > 0)).reshape(-1).float().sum()
+        fp = ((general_act_probs.argmax(-1) > 0) * (general_act_labels == 0)).reshape(-1).float().sum()
+        fn = ((general_act_probs.argmax(-1) == 0) * (general_act_labels > 0)).reshape(-1).float().sum()
+        f1 = tp / (tp + 0.5 * (fp + fn))
+        logger.info('General Act F1 Score: %f' % f1.item())
+
+        err = ece(general_act_probs.reshape(-1, general_act_probs.size(-1)), general_act_labels.reshape(-1), 10)
+        logger.info('General Act ECE: %f' % err)
+
+        for slot in request_probs:  # NOTE(review): already expanded to (1 - p, p) above - confirm this is intended
+            p = request_probs[slot].unsqueeze(-1)
+            request_probs[slot] = torch.cat((1 - p, p), -1)
+
+        l2 = l2_acc(request_probs, request_labels, remove_belief=False)
+        logger.info(f'Model L2 Norm Request Accuracy: {l2}')
+        l2 = l2_acc(request_probs, request_labels, remove_belief=True)
+        logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}')
+
+        for slot in active_domain_probs:  # NOTE(review): already expanded to (1 - p, p) above - confirm
+            p = active_domain_probs[slot].unsqueeze(-1)
+            active_domain_probs[slot] = torch.cat((1 - p, p), -1)
+
+        l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=False)
+        logger.info(f'Model L2 Norm Domain Accuracy: {l2}')
+        l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=True)
+        logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}')
+
+        general_act_labels = {'general': general_act_labels}  # l2_acc is called with per-name dicts elsewhere
+        general_act_probs = {'general': general_act_probs}
+
+        l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False)
+        logger.info(f'Model L2 Norm General Act Accuracy: {l2}')
+        l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=True)  # binary variant, as for request/domain
+        logger.info(f'Binary Model L2 Norm General Act Accuracy: {l2}')
+
+
+if __name__ == "__main__":  # run the evaluation only when this file is executed as a script
+    main()
diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py
index 821dca598c814240f39e359cedec7ef795a341b5..ea099442ddd18d0cd36a79db13b1f47788eb4fd4 100644
--- a/convlab/dst/setsumbt/do/nbt.py
+++ b/convlab/dst/setsumbt/do/nbt.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,33 +16,27 @@
 """Run SetSUMBT training/eval"""
 
 import logging
-import random
 import os
 from shutil import copy2 as copy
+import json
+from copy import deepcopy
 
 import torch
-from torch.nn import DataParallel
+import transformers
 from transformers import (BertModel, BertConfig, BertTokenizer,
-                          RobertaModel, RobertaConfig, RobertaTokenizer,
-                          AdamW, get_linear_schedule_with_warmup)
-from tqdm import tqdm, trange
-import numpy as np
+                          RobertaModel, RobertaConfig, RobertaTokenizer)
 from tensorboardX import SummaryWriter
+from tqdm import tqdm
 
-from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
-from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
-from convlab.dst.setsumbt.multiwoz import multiwoz21
+from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
+from convlab.dst.setsumbt.dataset import unified_format
 from convlab.dst.setsumbt.modeling import training
-from convlab.dst.setsumbt.multiwoz import ontology as embeddings
+from convlab.dst.setsumbt.dataset import ontology as embeddings
 from convlab.dst.setsumbt.utils import get_args, update_args
-from convlab.dst.setsumbt.modeling import ensemble_utils
+from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble
 
 
-# Datasets
-DATASETS = {
-    'multiwoz21': multiwoz21
-}
-
+# Available model
 MODELS = {
     'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
     'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
@@ -54,12 +48,6 @@ def main(args=None, config=None):
     if args is None:
         args, config = get_args(MODELS)
 
-    # Select Dataset object
-    if args.dataset in DATASETS:
-        Dataset = DATASETS[args.dataset]
-    else:
-        raise NameError('NotImplemented')
-
     if args.model_type in MODELS:
         SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
     else:
@@ -74,53 +62,21 @@ def main(args=None, config=None):
     args.output_dir = OUTPUT_DIR
 
     # Set pretrained model path to the trained checkpoint
-    if args.do_train:
-        paths = os.listdir(args.output_dir) if os.path.exists(
-            args.output_dir) else []
-        paths = [os.path.join(args.output_dir, p)
-                 for p in paths if 'checkpoint-' in p]
+    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
+    if 'pytorch_model.bin' in paths and 'config.json' in paths:
+        args.model_name_or_path = args.output_dir
+        config = ConfigClass.from_pretrained(args.model_name_or_path,
+                                             local_files_only=args.transformers_local_files_only)
+    else:
+        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
         if paths:
             paths = paths[0]
             args.model_name_or_path = paths
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
-        else:
-            paths = os.listdir(args.output_dir) if os.path.exists(
-                args.output_dir) else []
-            if 'pytorch_model.bin' in paths and 'config.json' in paths:
-                args.model_name_or_path = args.output_dir
-                config = ConfigClass.from_pretrained(args.model_name_or_path)
-    else:
-        paths = os.listdir(args.output_dir) if os.path.exists(
-            args.output_dir) else []
-        if 'pytorch_model.bin' in paths and 'config.json' in paths:
-            args.model_name_or_path = args.output_dir
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
-        else:
-            paths = os.listdir(args.output_dir) if os.path.exists(
-                args.output_dir) else []
-            paths = [os.path.join(args.output_dir, p)
-                     for p in paths if 'checkpoint-' in p]
-            if paths:
-                paths = paths[0]
-                args.model_name_or_path = paths
-                config = ConfigClass.from_pretrained(args.model_name_or_path)
+            config = ConfigClass.from_pretrained(args.model_name_or_path,
+                                                 local_files_only=args.transformers_local_files_only)
 
     args = update_args(args, config)
 
-    # Set up data directory
-    DATA_DIR = args.data_dir
-    Dataset.set_datadir(DATA_DIR)
-    embeddings.set_datadir(DATA_DIR)
-
-    # If use shrinked domains, remove bus and hospital domains from the training data and model ontology
-    if args.shrink_active_domains and args.dataset == 'multiwoz21':
-        Dataset.set_active_domains(
-            ['attraction', 'hotel', 'restaurant', 'taxi', 'train'])
-
-    # Download and preprocess
-    Dataset.create_examples(
-        args.max_turn_len, args.predict_actions, args.force_processing)
-
     # Create TensorboardX writer
     tb_writer = SummaryWriter(logdir=args.tensorboard_path)
 
@@ -129,19 +85,12 @@ def main(args=None, config=None):
     logger = logging.getLogger(__name__)
     logger.setLevel(logging.INFO)
 
-    formatter = logging.Formatter(
-        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')
 
-    if 'stream' not in args.logging_path:
-        fh = logging.FileHandler(args.logging_path)
-        fh.setLevel(logging.INFO)
-        fh.setFormatter(formatter)
-        logger.addHandler(fh)
-    else:
-        ch = logging.StreamHandler()
-        ch.setLevel(level=logging.INFO)
-        ch.setFormatter(formatter)
-        logger.addHandler(ch)
+    fh = logging.FileHandler(args.logging_path)
+    fh.setLevel(logging.INFO)
+    fh.setFormatter(formatter)
+    logger.addHandler(fh)
 
     # Get device
     if torch.cuda.is_available() and args.n_gpu > 0:
@@ -154,103 +103,123 @@ def main(args=None, config=None):
         args.fp16 = False
 
     # Initialise Model
-    model = SetSumbtModel.from_pretrained(
-        args.model_name_or_path, config=config)
+    transformers.utils.logging.set_verbosity_info()
+    model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config,
+                                          local_files_only=args.transformers_local_files_only)
     model = model.to(device)
 
     # Create Tokenizer and embedding model for Data Loaders and ontology
-    encoder = model.roberta if args.model_type == 'roberta' else None
-    encoder = model.bert if args.model_type == 'bert' else encoder
-
-    tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config)
+    encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name,
+                                                    local_files_only=args.transformers_local_files_only)
+    tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config,
+                                          local_files_only=args.transformers_local_files_only)
 
     # Set up model training/evaluation
     training.set_logger(logger, tb_writer)
     training.set_seed(args)
     embeddings.set_seed(args)
 
+    transformers.utils.logging.set_verbosity_error()
     if args.ensemble_size > 1:
-        ensemble_utils.set_logger(logger, tb_writer)
-        ensemble.set_seed(args)
-        logger.info('Building %i resampled dataloaders each of size %i' % (args.ensemble_size,
-                                                                           args.data_sampling_size))
-        dataloaders = ensemble_utils.build_train_loaders(args, tokenizer, Dataset)
+        # Build all dataloaders
+        train_dataloader = unified_format.get_dataloader(args.dataset,
+                                                         'train',
+                                                         args.train_batch_size,
+                                                         tokenizer,
+                                                         args.max_dialogue_len,
+                                                         args.max_turn_len,
+                                                         train_ratio=args.dataset_train_ratio,
+                                                         seed=args.seed)
+        torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+        dev_dataloader = unified_format.get_dataloader(args.dataset,
+                                                       'validation',
+                                                       args.dev_batch_size,
+                                                       tokenizer,
+                                                       args.max_dialogue_len,
+                                                       args.max_turn_len,
+                                                       train_ratio=args.dataset_train_ratio,
+                                                       seed=args.seed)
+        torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+        test_dataloader = unified_format.get_dataloader(args.dataset,
+                                                        'test',
+                                                        args.test_batch_size,
+                                                        tokenizer,
+                                                        args.max_dialogue_len,
+                                                        args.max_turn_len,
+                                                        train_ratio=args.dataset_train_ratio,
+                                                        seed=args.seed)
+        torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+
+        embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology, 'train', args, tokenizer, encoder)
+        embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, 'dev', args, tokenizer, encoder)
+        embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, 'test', args, tokenizer, encoder)
+
+        setup_ensemble(OUTPUT_DIR, args.ensemble_size)
+
+        logger.info(f'Building {args.ensemble_size} resampled dataloaders each of size {args.data_sampling_size}.')
+        dataloaders = [unified_format.dataloader_sample_dialogues(deepcopy(train_dataloader), args.data_sampling_size)
+                       for _ in tqdm(range(args.ensemble_size))]
         logger.info('Dataloaders built.')
+
         for i, loader in enumerate(dataloaders):
-            path = os.path.join(OUTPUT_DIR, 'ensemble-%i' % i)
+            path = os.path.join(OUTPUT_DIR, 'ens-%i' % i)
             if not os.path.exists(path):
                 os.mkdir(path)
-            path = os.path.join(path, 'train.dataloader')
+            path = os.path.join(path, 'dataloaders', 'train.dataloader')
             torch.save(loader, path)
         logger.info('Dataloaders saved.')
 
-        train_slots = embeddings.get_slot_candidate_embeddings(
-            'train', args, tokenizer, encoder)
-        dev_slots = embeddings.get_slot_candidate_embeddings(
-            'dev', args, tokenizer, encoder)
-        test_slots = embeddings.get_slot_candidate_embeddings(
-            'test', args, tokenizer, encoder)
-
-        train_dataloader = Dataset.get_dataloader(
-            'train', args.train_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
-        torch.save(dev_dataloader, os.path.join(
-            OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
-        dev_dataloader = Dataset.get_dataloader(
-            'dev', args.dev_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
-        torch.save(dev_dataloader, os.path.join(
-            OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-        test_dataloader = Dataset.get_dataloader(
-            'test', args.test_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
-        torch.save(test_dataloader, os.path.join(
-            OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
         # Do not perform standard training after ensemble setup is created
         return 0
 
     # Perform tasks
     # TRAINING
     if args.do_train:
-        # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')):
-            train_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'train.db'))
-        else:
-            train_slots = embeddings.get_slot_candidate_embeddings(
-                'train', args, tokenizer, encoder)
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
-            dev_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'dev.db'))
-        else:
-            dev_slots = embeddings.get_slot_candidate_embeddings(
-                'dev', args, tokenizer, encoder)
-
-        exists = False
         if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')):
-            train_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
-            if train_dataloader.batch_size == args.train_batch_size:
-                exists = True
-        if not exists:
+            train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+            if train_dataloader.batch_size != args.train_batch_size:
+                train_dataloader = unified_format.change_batch_size(train_dataloader, args.train_batch_size)
+        else:
             if args.data_sampling_size <= 0:
                 args.data_sampling_size = None
-            train_dataloader = Dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len,
-                                                      config.max_turn_len, resampled_size=args.data_sampling_size)
-            torch.save(train_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+            train_dataloader = unified_format.get_dataloader(args.dataset,
+                                                             'train',
+                                                             args.train_batch_size,
+                                                             tokenizer,
+                                                             args.max_dialogue_len,
+                                                             config.max_turn_len,
+                                                             resampled_size=args.data_sampling_size,
+                                                             train_ratio=args.dataset_train_ratio,
+                                                             seed=args.seed)
+            torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
+
+        # Get training batch loaders and ontology embeddings
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')):
+            train_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'train.db'))
+        else:
+            train_slots = embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology,
+                                                                   'train', args, tokenizer, encoder)
 
         # Get development set batch loaders= and ontology embeddings
         if args.do_eval:
-            exists = False
             if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
-                dev_dataloader = torch.load(os.path.join(
-                    OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-                if dev_dataloader.batch_size == args.dev_batch_size:
-                    exists = True
-            if not exists:
-                dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len,
-                                                        config.max_turn_len)
-                torch.save(dev_dataloader, os.path.join(
-                    OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+                dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+                if dev_dataloader.batch_size != args.dev_batch_size:
+                    dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size)
+            else:
+                dev_dataloader = unified_format.get_dataloader(args.dataset,
+                                                               'validation',
+                                                               args.dev_batch_size,
+                                                               tokenizer,
+                                                               args.max_dialogue_len,
+                                                               config.max_turn_len)
+                torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+
+            if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
+                dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
+            else:
+                dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
+                                                                     'dev', args, tokenizer, encoder)
         else:
             dev_dataloader = None
             dev_slots = None
@@ -259,94 +228,80 @@ def main(args=None, config=None):
         training.set_ontology_embeddings(model, train_slots)
 
         # TRAINING !!!!!!!!!!!!!!!!!!
-        training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots,
-                       embeddings=embeddings, tokenizer=tokenizer)
+        training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots)
 
         # Copy final best model to the output dir
         checkpoints = os.listdir(OUTPUT_DIR)
         checkpoints = [p for p in checkpoints if 'checkpoint' in p]
         checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints])
-        best_checkpoint = checkpoints[-1]
-        best_checkpoint = os.path.join(
-            OUTPUT_DIR, f'checkpoint-{best_checkpoint}')
-        copy(os.path.join(best_checkpoint, 'pytorch_model.bin'),
-             os.path.join(OUTPUT_DIR, 'pytorch_model.bin'))
-        copy(os.path.join(best_checkpoint, 'config.json'),
-             os.path.join(OUTPUT_DIR, 'config.json'))
+        best_checkpoint = os.path.join(OUTPUT_DIR, f'checkpoint-{checkpoints[-1]}')
+        copy(os.path.join(best_checkpoint, 'pytorch_model.bin'), os.path.join(OUTPUT_DIR, 'pytorch_model.bin'))
+        copy(os.path.join(best_checkpoint, 'config.json'), os.path.join(OUTPUT_DIR, 'config.json'))
 
         # Load best model for evaluation
-        model = SumbtModel.from_pretrained(OUTPUT_DIR)
+        model = SetSumbtModel.from_pretrained(OUTPUT_DIR)
         model = model.to(device)
 
     # Evaluation on the development set
     if args.do_eval:
-        # Get development set batch loaders= and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
-            dev_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'dev.db'))
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
+            dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+            if dev_dataloader.batch_size != args.dev_batch_size:
+                dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size)
         else:
-            dev_slots = embeddings.get_slot_candidate_embeddings(
-                'dev', args, tokenizer, encoder)
+            dev_dataloader = unified_format.get_dataloader(args.dataset,
+                                                           'validation',
+                                                           args.dev_batch_size,
+                                                           tokenizer,
+                                                           args.max_dialogue_len,
+                                                           config.max_turn_len)
+            torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
 
-        exists = False
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
-            dev_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
-            if dev_dataloader.batch_size == args.dev_batch_size:
-                exists = True
-        if not exists:
-            dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len,
-                                                    config.max_turn_len)
-            torch.save(dev_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
+            dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
+        else:
+            dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
+                                                                 'dev', args, tokenizer, encoder)
 
         # Load model ontology
         training.set_ontology_embeddings(model, dev_slots)
 
         # EVALUATION
-        jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate(
-            args, model, device, dev_dataloader)
-        if req_f1:
-            logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f'
-                        % (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
-        else:
-            logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f'
-                        % (loss, jg_acc, sl_acc))
+        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss = training.evaluate(args, model, device, dev_dataloader)
+        training.log_info('dev', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
 
     # Evaluation on the test set
     if args.do_test:
-        # Get test set batch loaders= and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
+            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+            if test_dataloader.batch_size != args.test_batch_size:
+                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
         else:
-            test_slots = embeddings.get_slot_candidate_embeddings(
-                'test', args, tokenizer, encoder)
+            test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
+                                                            args.test_batch_size, tokenizer, args.max_dialogue_len,
+                                                            config.max_turn_len)
+            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
 
-        exists = False
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size == args.test_batch_size:
-                exists = True
-        if not exists:
-            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                     config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
+            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
+        else:
+            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
+                                                                  'test', args, tokenizer, encoder)
 
         # Load model ontology
         training.set_ontology_embeddings(model, test_slots)
 
         # TESTING
-        jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate(
-            args, model, device, test_dataloader)
-        if req_f1:
-            logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f'
-                        % (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
-        else:
-            logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f'
-                        % (loss, jg_acc, sl_acc))
+        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, output = training.evaluate(args, model, device, test_dataloader,
+                                                                                 return_eval_output=True)
+
+        if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
+            os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
+        writer = open(os.path.join(OUTPUT_DIR, 'predictions', 'test.json'), 'w')
+        json.dump(output, writer)
+        writer.close()
+
+        training.log_info('test', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
 
     tb_writer.close()
 
diff --git a/convlab/dst/setsumbt/get_golden_labels.py b/convlab/dst/setsumbt/get_golden_labels.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fb2841d0d503181119c791a7046fd7e0025d236
--- /dev/null
+++ b/convlab/dst/setsumbt/get_golden_labels.py
@@ -0,0 +1,138 @@
+import json
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+import os
+
+from tqdm import tqdm
+
+from convlab.util import load_dataset
+from convlab.util import load_dst_data
+from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
+
+
+def extract_data(dataset_names: str) -> list:
+    """Load the test set dialogues of every dataset named in a '+' separated string of dataset names."""
+    loaded = [load_dataset(dataset_name=name) for name in dataset_names.split('+')]
+    dialogues = []
+    for dataset_dict in loaded:
+        test_split = load_dst_data(dataset_dict, data_split='test', speaker='all', dialogue_acts=True,
+                                   split_to_turn=False)['test']
+        dialogues.extend(test_split)
+    return dialogues
+
+def clean_state(state):
+    """Normalise the slot values of a dialogue state of the form {domain: {slot: value}}.
+
+    Applies VALUE_MAP, then standardises quantity, time (HH:MM in 5 minute steps) and yes/no slots."""
+    cleaned = dict()
+    for domain, subset in state.items():
+        cleaned[domain] = {}
+        for slot, value in subset.items():
+            # Apply the value map to each of the pipe separated values
+            value = value.split('|')
+            for old, new in VALUE_MAP.items():
+                value = [val.replace(old, new) for val in value]
+            value = '|'.join(value)
+
+            # Map dontcare to "do not care" and empty to 'none'
+            value = value.replace('dontcare', 'do not care')
+            value = value if value else 'none'
+
+            # Map quantity values to the integer quantity value
+            if 'people' in slot or 'duration' in slot or 'stay' in slot:
+                try:
+                    if value not in ['do not care', 'none']:
+                        value = int(value)
+                        value = str(value) if value < 10 else QUANTITIES[-1]
+                except ValueError:
+                    pass  # Not an integer: keep the mapped value
+            # Map time values to the most appropriate value in the standard time set
+            elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
+                try:
+                    if value not in ['do not care', 'none']:
+                        # Strip after/before from time value
+                        value = value.replace('after ', '').replace('before ', '')
+                        # Extract hours and minutes (only all-digit values are HHMM, so '10pm' reaches the pm branch)
+                        if len(value) == 4 and value.isdigit():
+                            h, m = value[:2], value[2:]
+                        elif len(value) == 1:
+                            h, m = int(value), 0
+                        elif 'pm' in value:
+                            h, m = int(value.replace('pm', '')) % 12 + 12, 0
+                        elif 'am' in value:
+                            h, m = int(value.replace('am', '')) % 12, 0
+                        elif ':' in value:
+                            h, m = value.split(':')
+                        elif ';' in value:
+                            h, m = value.split(';')
+                        else:  # Never continue with h, m left over from a previous slot
+                            raise ValueError(f'Unrecognised time format: {value}')
+                        # Map to closest 5 minutes
+                        if int(m) % 5 != 0:
+                            m = round(int(m) / 5) * 5
+                            h = int(h)
+                            if m == 60:
+                                m = 0
+                                h += 1
+                            if h >= 24:
+                                h -= 24
+                        # Set in standard 24 hour format
+                        h, m = int(h), int(m)
+                        value = '%02i:%02i' % (h, m)
+                except ValueError:
+                    pass  # Unparsable time: keep the value as is
+            # Map boolean slots to yes/no value
+            elif 'parking' in slot or 'internet' in slot:
+                if value not in ['do not care', 'none']:
+                    if value == 'free':
+                        value = 'yes'
+                    else:
+                        # Match against the lowercased value so that e.g. 'Yes' cannot raise an IndexError
+                        matches = [v for v in ['yes', 'no'] if v in value.lower()]
+                        value = matches[0] if matches else value
+
+            cleaned[domain][slot] = value if value != 'none' else ''
+
+    return cleaned
+
+def extract_states(data):
+    """
+    Collect the cleaned state of every turn that carries a state, grouped per dialogue.
+
+    Args:
+        data (list): Dialogues, each with a 'dialogue_id' and a list of 'turns'
+    Returns:
+        states_data (dict): Mapping of dialogue id to the list of cleaned turn states
+    """
+    return {dial['dialogue_id']: [clean_state(turn['state']) for turn in dial['turns'] if 'state' in turn]
+            for dial in data}
+
+
+def get_golden_state(prediction, data):
+    """Attach the golden state to `prediction` in place and align its predicted state to the golden domains/slots."""
+    golden = data[prediction['dial_idx']][prediction['utt_idx']]
+    predicted, aligned = prediction['predictions']['state'], {}
+    for domain, slots in golden.items():
+        domain_pred = predicted.get(DOMAINS_MAP.get(domain, domain.lower()), dict())
+        aligned[domain] = {slot: domain_pred.get(slot, '') for slot in slots}
+    prediction['state'], prediction['predictions']['state'] = golden, aligned
+    return prediction
+
+
+if __name__ == "__main__":
+    # Add the golden test set states to the predictions stored in <model_path>/predictions/test.json
+    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
+    parser.add_argument('--model_path', type=str, help='Path to model dir', required=True)
+    args = parser.parse_args()
+
+    data = extract_data(args.dataset_name)
+    data = extract_states(data)
+
+    # Context managers close the files even if reading or writing the JSON fails
+    predictions_dir = os.path.join(args.model_path, "predictions")
+    with open(os.path.join(predictions_dir, "test.json"), 'r') as reader:
+        predictions = json.load(reader)
+    predictions = [get_golden_state(pred, data) for pred in tqdm(predictions)]
+
+    with open(os.path.join(predictions_dir, f"test_{args.dataset_name}.json"), 'w') as writer:
+        json.dump(predictions, writer)
diff --git a/convlab/dst/setsumbt/loss/__init__.py b/convlab/dst/setsumbt/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..475f7646126ea03b630efcbbc688f86c5a8ec16e
--- /dev/null
+++ b/convlab/dst/setsumbt/loss/__init__.py
@@ -0,0 +1,4 @@
+from convlab.dst.setsumbt.loss.bayesian_matching import BayesianMatchingLoss, BinaryBayesianMatchingLoss
+from convlab.dst.setsumbt.loss.kl_distillation import KLDistillationLoss, BinaryKLDistillationLoss
+from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss
+from convlab.dst.setsumbt.loss.endd_loss import RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss
diff --git a/convlab/dst/setsumbt/loss/bayesian.py b/convlab/dst/setsumbt/loss/bayesian.py
deleted file mode 100644
index e52d8d07733383c7f95b6825b4ab5e5e1c7a0977..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/loss/bayesian.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Bayesian Matching Activation and Loss Functions"""
-
-import torch
-from torch import digamma, lgamma
-from torch.nn import Module
-
-
-# Inverse Linear activation function
-def invlinear(x):
-    z = (1.0 / (1.0 - x)) * (x < 0)
-    z += (1.0 + x) * (x >= 0)
-    return z
-
-# Exponential activation function
-def exponential(x):
-    return torch.exp(x)
-
-
-# Dirichlet activation function for the model
-def dirichlet(a):
-    p = exponential(a)
-    repeat_dim = (1,)*(len(p.shape)-1) + (p.size(-1),)
-    p = p / p.sum(-1).unsqueeze(-1).repeat(repeat_dim)
-    return p
-
-
-# Pytorch BayesianMatchingLoss nn.Module
-class BayesianMatchingLoss(Module):
-
-    def __init__(self, lamb=0.01, ignore_index=-1):
-        super(BayesianMatchingLoss, self).__init__()
-
-        self.lamb = lamb
-        self.ignore_index = ignore_index
-    
-    def forward(self, alpha, labels, prior=None):
-        # Assert input sizes
-        assert alpha.dim() == 2                 # Observations, predictive distribution
-        assert labels.dim() == 1                # Label for each observation
-        assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-        # Confirm predictive distribution dimension
-        if labels.max() <= alpha.size(-1):
-            dimension = alpha.size(-1)
-        else:
-            raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.max(), alpha.size(-1)))
-        
-        # Remove observations with no labels
-        if prior is not None:
-            prior = prior[labels != self.ignore_index]
-        alpha = exponential(alpha[labels != self.ignore_index])
-        labels = labels[labels != self.ignore_index]
-        
-        # Initialise and reshape prior parameters
-        if prior is None:
-            prior = torch.ones(dimension)
-        prior = prior.to(alpha.device)
-
-        # KL divergence term
-        lb = lgamma(alpha.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(alpha)).sum(-1)
-        e = digamma(alpha) - digamma(alpha.sum(-1)).unsqueeze(-1).repeat((1, alpha.size(-1)))
-        e = ((alpha - prior) * e).sum(-1)
-        kl = lb + e
-        kl *= self.lamb
-        del lb, e, prior
-
-        # Expected log likelihood
-        expected_likelihood = digamma(alpha[range(labels.size(0)), labels]) - digamma(alpha.sum(1))
-        del alpha, labels
-
-        # Apply ELBO loss and mean reduction
-        loss = (kl - expected_likelihood).mean()
-        del kl, expected_likelihood
-
-        return loss
-
-
-# Pytorch BayesianMatchingLoss nn.Module
-class BinaryBayesianMatchingLoss(Module):
-
-    def __init__(self, lamb=0.01, ignore_index=-1):
-        super(BinaryBayesianMatchingLoss, self).__init__()
-
-        self.lamb = lamb
-        self.ignore_index = ignore_index
-    
-    def forward(self, alpha, labels, prior=None):
-        # Assert input sizes
-        assert alpha.dim() == 1                 # Observations, predictive distribution
-        assert labels.dim() == 1                # Label for each observation
-        assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-        # Confirm predictive distribution dimension
-        if labels.max() <= 2:
-            dimension = 2
-        else:
-            raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.max(), alpha.size(-1)))
-        
-        # Remove observations with no labels
-        if prior is not None:
-            prior = prior[labels != self.ignore_index]
-        alpha = alpha[labels != self.ignore_index]
-        alpha_sum = 1 + (1 / self.lamb)
-        alpha = (torch.sigmoid(alpha) * alpha_sum).reshape(-1, 1)
-        alpha = torch.cat((alpha_sum - alpha, alpha), 1)
-        labels = labels[labels != self.ignore_index]
-        
-        # Initialise and reshape prior parameters
-        if prior is None:
-            prior = torch.ones(dimension)
-        prior = prior.to(alpha.device)
-
-        # KL divergence term
-        lb = lgamma(alpha.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(alpha)).sum(-1)
-        e = digamma(alpha) - digamma(alpha.sum(-1)).unsqueeze(-1).repeat((1, alpha.size(-1)))
-        e = ((alpha - prior) * e).sum(-1)
-        kl = lb + e
-        kl *= self.lamb
-        del lb, e, prior
-
-        # Expected log likelihood
-        expected_likelihood = digamma(alpha[range(labels.size(0)), labels.long()]) - digamma(alpha.sum(1))
-        del alpha, labels
-
-        # Apply ELBO loss and mean reduction
-        loss = (kl - expected_likelihood).mean()
-        del kl, expected_likelihood
-
-        return loss
diff --git a/convlab/dst/setsumbt/loss/bayesian_matching.py b/convlab/dst/setsumbt/loss/bayesian_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e91444d60afeeb6e2ca54192dd2283810fc5135
--- /dev/null
+++ b/convlab/dst/setsumbt/loss/bayesian_matching.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Bayesian Matching Activation and Loss Functions (see https://arxiv.org/pdf/2002.07965.pdf for details)"""
+
+import torch
+from torch import digamma, lgamma
+from torch.nn import Module
+
+
+class BayesianMatchingLoss(Module):
+    """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation"""
+
+    def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> None:
+        """
+        Args:
+            lamb (float): Weighting factor for the KL Divergence component
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
+        super(BayesianMatchingLoss, self).__init__()
+
+        self.lamb = lamb
+        self.ignore_index = ignore_index
+
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor:
+        """
+        Args:
+            inputs (Tensor): Log concentration parameters of the predictive Dirichlet (observations x classes)
+            labels (Tensor): Label index of each observation
+            prior (Tensor): Prior concentration parameters, one row per observation (all ones if not given)
+
+        Returns:
+            loss (Tensor): Loss value
+        Raises:
+            NameError: If the largest label index is not smaller than the number of classes
+        """
+        # Assert input sizes
+        assert inputs.dim() == 2                 # Observations, predictive distribution
+        assert labels.dim() == 1                # Label for each observation
+        assert labels.size(0) == inputs.size(0)  # Equal number of observation
+
+        # Confirm predictive distribution dimension (valid label indices are 0 .. dimension - 1)
+        if labels.max() < inputs.size(-1):
+            dimension = inputs.size(-1)
+        else:
+            raise NameError(f'Label dimension {labels.max()} is larger than prediction dimension {inputs.size(-1)}.')
+
+        # Remove observations to be ignored in loss calculation
+        if prior is not None:
+            prior = prior[labels != self.ignore_index]
+        inputs = torch.exp(inputs[labels != self.ignore_index])
+        labels = labels[labels != self.ignore_index]
+
+        # Initialise the prior (uniform if not provided) on the device of the inputs
+        prior = (torch.ones(dimension) if prior is None else prior).to(inputs.device)
+
+        # KL divergence term (divergence of predictive distribution from prior over label classes - regularisation term)
+        log_gamma_term = lgamma(inputs.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(inputs)).sum(-1)
+        div_term = digamma(inputs) - digamma(inputs.sum(-1)).unsqueeze(-1).repeat((1, inputs.size(-1)))
+        div_term = ((inputs - prior) * div_term).sum(-1)
+        kl_term = log_gamma_term + div_term
+        kl_term *= self.lamb
+        del log_gamma_term, div_term, prior
+
+        # Expected log likelihood
+        expected_likelihood = digamma(inputs[range(labels.size(0)), labels]) - digamma(inputs.sum(-1))
+        del inputs, labels
+
+        # Apply ELBO loss and mean reduction
+        loss = (kl_term - expected_likelihood).mean()
+        del kl_term, expected_likelihood
+        return loss
+
+
+class BinaryBayesianMatchingLoss(BayesianMatchingLoss):
+    """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation for binary predictions"""
+
+    def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> None:
+        """
+        Args:
+            lamb (float): Weighting factor for the KL Divergence component
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
+        super(BinaryBayesianMatchingLoss, self).__init__(lamb, ignore_index)
+
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor:
+        """
+        Args:
+            inputs (Tensor): Logit of the positive class for each observation
+            labels (Tensor): Binary label of each observation
+            prior (Tensor): Prior distribution over label classes
+
+        Returns:
+            loss (Tensor): Loss value
+        """
+        # Build the log concentrations [log(s * sigmoid(-x)), log(s * sigmoid(x))]: the parent forward applies exp to
+        # its inputs, so passing the concentrations themselves would wrongly exponentiate them and overflow.
+        input_sum = 1 + (1 / self.lamb)
+        inputs = inputs.reshape(-1, 1)
+        inputs = torch.nn.functional.logsigmoid(torch.cat((-inputs, inputs), 1)) + inputs.new_tensor(input_sum).log()
+        # The parent indexes with the labels, so they must be integer typed
+        return super().forward(inputs, labels.long(), prior=prior)
diff --git a/convlab/dst/setsumbt/loss/distillation.py b/convlab/dst/setsumbt/loss/distillation.py
deleted file mode 100644
index 3cf13f10635376467f3b137adaa83367f26603ef..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/loss/distillation.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Bayesian Matching Activation and Loss Functions"""
-
-import torch
-from torch import lgamma, log
-from torch.nn import Module
-from torch.nn.functional import kl_div
-
-from convlab.dst.setsumbt.loss.bayesian import BayesianMatchingLoss
-
-
-# Pytorch BayesianMatchingLoss nn.Module
-class DistillationKL(Module):
-
-    def __init__(self, lamb=1e-4, ignore_index=-1):
-        super(DistillationKL, self).__init__()
-
-        self.lamb = lamb
-        self.ignore_index = ignore_index
-    
-    def forward(self, alpha, labels, temp=1.0):
-        # Assert input sizes
-        assert alpha.dim() == 2                 # Observations, predictive distribution
-        assert labels.dim() == 2                # Label for each observation
-        assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-        # Confirm predictive distribution dimension
-        if labels.size(-1) == alpha.size(-1):
-            dimension = alpha.size(-1)
-        else:
-            raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1)))
-        
-        alpha = torch.log(torch.softmax(alpha / temp, -1))
-        ids = torch.where(labels[:, 0] != self.ignore_index)[0]
-        alpha = alpha[ids]
-        labels = labels[ids]
-
-        labels = ((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1)))
-
-        kl = kl_div(alpha, labels, reduction='none').sum(-1).mean()
-        return kl    
-
-
-# Pytorch BayesianMatchingLoss nn.Module
-class BinaryDistillationKL(Module):
-
-    def __init__(self, lamb=1e-4, ignore_index=-1):
-        super(BinaryDistillationKL, self).__init__()
-
-        self.lamb = lamb
-        self.ignore_index = ignore_index
-    
-    def forward(self, alpha, labels, temp=0.0):
-        # Assert input sizes
-        assert alpha.dim() == 1                 # Observations, predictive distribution
-        assert labels.dim() == 1                # Label for each observation
-        assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-        # Confirm predictive distribution dimension
-        # if labels.size(-1) == alpha.size(-1):
-        #     dimension = alpha.size(-1)
-        # else:
-        #     raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1)))
-        
-        alpha = torch.sigmoid(alpha / temp).unsqueeze(-1)
-        ids = torch.where(labels != self.ignore_index)[0]
-        alpha = alpha[ids]
-        labels = labels[ids]
-
-        alpha = torch.log(torch.cat((1 - alpha, alpha), 1))
-        
-        labels = labels.unsqueeze(-1)
-        labels = torch.cat((1 - labels, labels), -1)
-        labels = ((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1)))
-
-        kl = kl_div(alpha, labels, reduction='none').sum(-1).mean()
-        return kl  
-
-
-# def smart_sort(x, permutation):
-#     assert x.dim() == permutation.dim()
-#     if x.dim() == 3:
-#         d1, d2, d3 = x.size()
-#         ret = x[torch.arange(d1).unsqueeze(-1).unsqueeze(-1).repeat((1, d2, d3)).flatten(),
-#                 torch.arange(d2).unsqueeze(0).unsqueeze(-1).repeat((d1, 1, d3)).flatten(),
-#                 permutation.flatten()].view(d1, d2, d3)
-#         return ret
-#     elif x.dim() == 2:
-#         d1, d2 = x.size()
-#         ret = x[torch.arange(d1).unsqueeze(-1).repeat((1, d2)).flatten(),
-#                 permutation.flatten()].view(d1, d2)
-#         return ret
-
-
-# # Pytorch BayesianMatchingLoss nn.Module
-# class DistillationNLL(Module):
-
-#     def __init__(self, lamb=1e-4, ignore_index=-1):
-#         super(DistillationNLL, self).__init__()
-
-#         self.lamb = lamb
-#         self.ignore_index = ignore_index
-#         self.loss_add = BayesianMatchingLoss(lamb=0.001)
-    
-#     def forward(self, alpha, labels, temp=1.0):
-#         # Assert input sizes
-#         assert alpha.dim() == 2                 # Observations, predictive distribution
-#         assert labels.dim() == 3                # Label for each observation
-#         assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-#         # Confirm predictive distribution dimension
-#         if labels.size(-1) == alpha.size(-1):
-#             dimension = alpha.size(-1)
-#         else:
-#             raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1)))
-        
-#         alpha = torch.exp(alpha / temp)
-#         ids = torch.where(labels[:, 0, 0] != self.ignore_index)[0]
-#         alpha = alpha[ids]
-#         labels = labels[ids]
-
-#         best_labels = labels.mean(-2).argmax(-1)
-#         loss2 = self.loss_add(alpha, best_labels)
-
-#         topn = labels.mean(-2).argsort(-1, descending=True)
-#         n = 10
-#         alpha = smart_sort(alpha, topn)[:, :n]
-#         labels = smart_sort(labels, topn.unsqueeze(-2).repeat((1, labels.size(-2), 1)))
-#         labels = labels[:, :, :n]
-#         labels = labels / labels.sum(-1).unsqueeze(-1).repeat((1, 1, labels.size(-1)))
-
-#         labels = log(((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1))))
-
-#         loss = (alpha - 1) * labels.mean(-2)
-#         # loss = (alpha - 1) * labels
-#         loss = lgamma(alpha.sum(-1)) - lgamma(alpha).sum(-1) + loss.sum(-1) 
-#         loss = -1.0 * loss.mean()
-#         # loss = -1.0 * loss.mean() / alpha.size(-1)
-
-#         return loss      
-
-
-# # Pytorch BayesianMatchingLoss nn.Module
-# class BinaryDistillationNLL(Module):
-
-#     def __init__(self, lamb=1e-4, ignore_index=-1):
-#         super(BinaryDistillationNLL, self).__init__()
-
-#         self.lamb = lamb
-#         self.ignore_index = ignore_index
-    
-#     def forward(self, alpha, labels, temp=0.0):
-#         # Assert input sizes
-#         assert alpha.dim() == 1                 # Observations, predictive distribution
-#         assert labels.dim() == 2                # Label for each observation
-#         assert labels.size(0) == alpha.size(0)  # Equal number of observation
-
-#         # Confirm predictive distribution dimension
-#         # if labels.size(-1) == alpha.size(-1):
-#         #     dimension = alpha.size(-1)
-#         # else:
-#         #     raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1)))
-        
-#         # Remove observations with no labels
-#         ids = torch.where(labels[:, 0] != self.ignore_index)[0]
-#         # alpha_sum = 1 + (1 / self.lamb)
-#         alpha_sum = 10.0
-#         alpha = (torch.sigmoid(alpha) * alpha_sum).reshape(-1, 1)
-#         alpha = alpha[ids]
-#         labels = labels[ids]
-
-#         if temp != 1.0:
-#             alpha = torch.log(alpha + 1e-4)
-#             alpha = torch.exp(alpha / temp)
-
-#         alpha = torch.cat((alpha_sum - alpha, alpha), 1)
-        
-#         labels = labels.unsqueeze(-1)
-#         labels = torch.cat((1 - labels, labels), -1)
-#         # labels[labels[:, 0, 0] == self.ignore_index] = 1
-#         labels = log(((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1))))
-
-#         loss = (alpha - 1) * labels.mean(-2)
-#         loss = lgamma(alpha.sum(-1)) - lgamma(alpha).sum(-1) + loss.sum(-1)
-#         loss = -1.0 * loss.mean()
-
-#         return loss    
diff --git a/convlab/dst/setsumbt/loss/endd_loss.py b/convlab/dst/setsumbt/loss/endd_loss.py
index d84c3f720d4970c520d97aa9e65a12647468d3ae..9bd794bf4569f54f5896e1e88ed1edeadc0fe1e2 100644
--- a/convlab/dst/setsumbt/loss/endd_loss.py
+++ b/convlab/dst/setsumbt/loss/endd_loss.py
@@ -1,30 +1,46 @@
 import torch
+from torch.nn import Module
+from torch.nn.functional import kl_div
 
 EPS = torch.finfo(torch.float32).eps
 
+
 @torch.no_grad()
-def compute_mkl(ensemble_probs, ensemble_mean_probs, ensemble_logprobs):
-    mkl = torch.nn.functional.kl_div(ensemble_logprobs, ensemble_mean_probs.unsqueeze(1).expand_as(ensemble_probs),
-                                     reduction='none').sum(-1).mean(1)
-    return mkl
+def compute_mkl(ensemble_mean_probs: torch.Tensor, ensemble_logprobs: torch.Tensor) -> torch.Tensor:
+    """
+    Computing MKL in ensemble.
+
+    Args:
+        ensemble_mean_probs (Tensor): Marginal predictive distribution of the ensemble
+        ensemble_logprobs (Tensor): Log predictive distributions of individual ensemble members
+
+    Returns:
+        mkl (Tensor): MKL
+    """
+    mkl = kl_div(ensemble_logprobs, ensemble_mean_probs.unsqueeze(1).expand_as(ensemble_logprobs),reduction='none')
+    return mkl.sum(-1).mean(1)
+
 
 @torch.no_grad()
-def compute_ensemble_stats(ensemble_logits):
-    # ensemble_probs = torch.softmax(ensemble_logits, dim=-1)
-    # ensemble_mean_probs = ensemble_probs.mean(dim=1)
-    # ensemble_logprobs = torch.log_softmax(ensemble_logits, dim=-1)
-    ensemble_probs = ensemble_logits
+def compute_ensemble_stats(ensemble_probs: torch.Tensor) -> dict:
+    """
+    Compute a range of ensemble uncertainty measures
+
+    Args:
+        ensemble_probs (Tensor): Predictive distributions of the ensemble members
+
+    Returns:
+        stats (dict): Dictionary of ensemble uncertainty measures
+    """
     ensemble_mean_probs = ensemble_probs.mean(dim=1)
-    num_classes = ensemble_logits.size(-1)
-    ensemble_logprobs = torch.log(ensemble_logits + (1e-4 / num_classes))
+    num_classes = ensemble_probs.size(-1)
+    ensemble_logprobs = torch.log(ensemble_probs + (1e-4 / num_classes))
 
     entropy_of_expected = torch.distributions.Categorical(probs=ensemble_mean_probs).entropy()
     expected_entropy = torch.distributions.Categorical(probs=ensemble_probs).entropy().mean(dim=1)
     mutual_info = entropy_of_expected - expected_entropy
 
-    mkl = compute_mkl(ensemble_probs, ensemble_mean_probs, ensemble_logprobs)
-
-    # num_classes = ensemble_logits.size(-1)
+    mkl = compute_mkl(ensemble_mean_probs, ensemble_logprobs)
 
     ensemble_precision = (num_classes - 1) / (2 * mkl.unsqueeze(1) + EPS)
 
@@ -39,108 +55,226 @@ def compute_ensemble_stats(ensemble_logits):
     }
     return stats
 
-def entropy(probs, dim: int = -1):
+
+def entropy(probs: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    """
+    Compute entropy in a predictive distribution
+
+    Args:
+        probs (Tensor): Predictive distributions
+        dim (int): Dimension representing the predictive probabilities for a single prediction
+
+    Returns:
+        entropy (Tensor): Entropy
+    """
     return -(probs * (probs + EPS).log()).sum(dim=dim)
 
 
-def compute_dirichlet_uncertainties(dirichlet_params, precisions, expected_dirichlet):
+def compute_dirichlet_uncertainties(dirichlet_params: torch.Tensor,
+                                    precisions: torch.Tensor,
+                                    expected_dirichlet: torch.Tensor) -> tuple:
     """
     Function which computes measures of uncertainty for Dirichlet model.
-    :param dirichlet_params:  Tensor of size [batch_size, n_classes] of Dirichlet concentration parameters.
-    :param precisions: Tensor of size [batch_size, 1] of Dirichlet Precisions
-    :param expected_dirichlet: Tensor of size [batch_size, n_classes] of probablities of expected categorical under Dirichlet.
-    :return: Tensors of token level uncertainties of size [batch_size]
+
+    Args:
+        dirichlet_params (Tensor): Dirichlet concentration parameters.
+        precisions (Tensor): Dirichlet Precisions
+        expected_dirichlet (Tensor): Probabilities of expected categorical under Dirichlet.
+
+    Returns:
+        stats (tuple): Token level uncertainties
     """
     batch_size, n_classes = dirichlet_params.size()
 
     entropy_of_expected = entropy(expected_dirichlet)
 
-    expected_entropy = (
-            -expected_dirichlet * (torch.digamma(dirichlet_params + 1) - torch.digamma(precisions + 1))).sum(dim=-1)
+    expected_entropy = -expected_dirichlet * (torch.digamma(dirichlet_params + 1) - torch.digamma(precisions + 1))
+    expected_entropy = expected_entropy.sum(dim=-1)
 
-    mutual_information = -((expected_dirichlet + EPS) * (
-            torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + 1 + EPS) + torch.digamma(
-        precisions + 1 + EPS))).sum(dim=-1)
-    # assert torch.allclose(mutual_information, entropy_of_expected - expected_entropy, atol=1e-4, rtol=0)
+    mutual_information = torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + 1 + EPS)
+    mutual_information += torch.digamma(precisions + 1 + EPS)
+    mutual_information *= -(expected_dirichlet + EPS)
+    mutual_information = mutual_information.sum(dim=-1)
 
     epkl = (n_classes - 1) / precisions.squeeze(-1)
 
-    mkl = (expected_dirichlet * (
-            torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + EPS) + torch.digamma(
-        precisions + EPS))).sum(dim=-1)
+    mkl = torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + EPS)
+    mkl += torch.digamma(precisions + EPS)
+    mkl *= expected_dirichlet
+    mkl = mkl.sum(dim=-1)
+
+    stats = (entropy_of_expected.clamp(min=0), expected_entropy.clamp(min=0), mutual_information.clamp(min=0))
+    stats += (epkl.clamp(min=0), mkl.clamp(min=0))
+
+    return stats
+
+
+def get_dirichlet_parameters(logits: torch.Tensor,
+                             parametrization,
+                             add_to_alphas: float = 0,
+                             dtype=torch.double) -> tuple:
+    """
+    Get dirichlet parameters from model logits
 
-    return entropy_of_expected.clamp(min=0), \
-           expected_entropy.clamp(min=0), \
-           mutual_information.clamp(min=0), \
-           epkl.clamp(min=0), \
-           mkl.clamp(min=0)
+    Args:
+        logits (Tensor): Model logits
+        parametrization (function): Mapping from logits to concentration parameters
+        add_to_alphas (float): Addition constant for stability
+        dtype (data type): Data type of the parameters
 
-def get_dirichlet_parameters(logits, parametrization, add_to_alphas=0, dtype=torch.double):
+    Returns:
+        params (tuple): Concentration and precision parameters of the model Dirichlet
+    """
     max_val = torch.finfo(dtype).max / logits.size(-1) - 1
     alphas = torch.clip(parametrization(logits.to(dtype=dtype)) + add_to_alphas, max=max_val)
     precision = torch.sum(alphas, dim=-1, dtype=dtype)
     return alphas, precision
 
 
-def logits_to_mutual_info(logits):
-    alphas, precision = get_dirichlet_parameters(logits, torch.exp, 1.0)
+def logits_to_mutual_info(logits: torch.Tensor) -> torch.Tensor:
+    """
+    Map model logits to mutual information of model Dirichlet
 
-    unsqueezed_precision = precision.unsqueeze(1)
-    normalized_probs = alphas / unsqueezed_precision
+    Args:
+        logits (Tensor): Model logits
 
-    entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = compute_dirichlet_uncertainties(alphas,
-                                                                                                           unsqueezed_precision,
-                                                                                                           normalized_probs)
-    
-    # Max entropy is log(K) for K classes. Hence relative MI is calculated as MI/log(K)
-    # mutual_information /= torch.log(torch.tensor(logits.size(-1)))
-    
-    return mutual_information
+    Returns:
+        mutual_information (Tensor): Mutual information of the model Dirichlet
+    """
+    alphas, precision = get_dirichlet_parameters(logits, torch.exp, 1.0)
 
+    normalized_probs = alphas / precision.unsqueeze(1)
 
-def rkl_dirichlet_mediator_loss(logits, ensemble_stats, model_offset, target_offset, parametrization=torch.exp):
-    turns = torch.where(ensemble_stats[:, 0, 0] != -1)[0]
-    logits = logits[turns]
-    ensemble_stats = ensemble_stats[turns]
+    _, _, mutual_information, _, _ = compute_dirichlet_uncertainties(alphas, precision.unsqueeze(1), normalized_probs)
     
-    ensemble_stats = compute_ensemble_stats(ensemble_stats)
-
-    alphas, precision = get_dirichlet_parameters(logits, parametrization, model_offset)
-
-    unsqueezed_precision = precision.unsqueeze(1)
-    normalized_probs = alphas / unsqueezed_precision
-
-    entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = compute_dirichlet_uncertainties(alphas,
-                                                                                                           unsqueezed_precision,
-                                                                                                           normalized_probs)
-
-    stats = {
-        'alpha_min': alphas.min(),
-        'alpha_mean': alphas.mean(),
-        'precision': precision,
-        'entropy_of_expected': entropy_of_expected,
-        'mutual_info': mutual_information,
-        'mkl': mkl,
-    }
-
-    num_classes = alphas.size(-1)
-
-    ensemble_precision = ensemble_stats['precision']
-
-    ensemble_precision += target_offset * num_classes
-    ensemble_probs = ensemble_stats['mean_probs']
-
-    expected_KL_term = -1.0 * torch.sum(ensemble_probs * (torch.digamma(alphas + EPS)
-                                                          - torch.digamma(precision.unsqueeze(-1) + EPS)), dim=-1)
-    assert torch.isfinite(expected_KL_term).all(), (torch.max(alphas), torch.max(precision), alphas.dtype)
-
-    differential_negentropy_term = torch.sum(torch.lgamma(alphas + EPS), dim=-1) - torch.lgamma(precision + EPS) \
-                                   - torch.sum(
-        (alphas - 1) * (torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS)), dim=-1)
-    assert torch.isfinite(differential_negentropy_term).all()
-
-    cost = expected_KL_term - differential_negentropy_term / ensemble_precision.squeeze(-1)
+    return mutual_information
 
-    assert torch.isfinite(cost).all()
-    return torch.mean(cost), stats, ensemble_stats
 
+class RKLDirichletMediatorLoss(Module):
+    """Reverse KL Dirichlet Mediator Loss (https://arxiv.org/abs/2105.06987)"""
+
+    def __init__(self,
+                 model_offset: float = 1.0,
+                 target_offset: float = 1,
+                 ignore_index: int = -1,
+                 parameterization=torch.exp):
+        """
+        Args:
+            model_offset (float): Offset of model Dirichlet for stability
+            target_offset (float): Offset of target Dirichlet for stability
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+            parameterization (function): Mapping from logits to concentration parameters
+        """
+        super(RKLDirichletMediatorLoss, self).__init__()
+
+        self.model_offset = model_offset
+        self.target_offset = target_offset
+        self.ignore_index = ignore_index
+        self.parameterization = parameterization
+
+    def logits_to_mutual_info(self, logits: torch.Tensor) -> torch.Tensor:
+        """
+        Map model logits to mutual information of model Dirichlet
+
+        Args:
+            logits (Tensor): Model logits
+
+        Returns:
+            mutual_information (Tensor): Mutual information of the model Dirichlet
+        """
+        return logits_to_mutual_info(logits)
+
+    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            logits (Tensor): Model logits
+            targets (Tensor): Ensemble predictive distributions
+
+        Returns:
+            loss (Tensor): RKL dirichlet mediator loss value
+        """
+
+        # Remove padding
+        turns = torch.where(targets[:, 0, 0] != self.ignore_index)[0]
+        logits = logits[turns]
+        targets = targets[turns]
+
+        ensemble_stats = compute_ensemble_stats(targets)
+
+        alphas, precision = get_dirichlet_parameters(logits, self.parameterization, self.model_offset)
+
+        normalized_probs = alphas / precision.unsqueeze(1)
+
+        stats = compute_dirichlet_uncertainties(alphas, precision.unsqueeze(1), normalized_probs)
+        entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = stats
+
+        stats = {
+            'alpha_min': alphas.min(),
+            'alpha_mean': alphas.mean(),
+            'precision': precision,
+            'entropy_of_expected': entropy_of_expected,
+            'mutual_info': mutual_information,
+            'mkl': mkl,
+        }
+
+        num_classes = alphas.size(-1)
+
+        ensemble_precision = ensemble_stats['precision']
+
+        ensemble_precision += self.target_offset * num_classes
+        ensemble_probs = ensemble_stats['mean_probs']
+
+        expected_kl_term = torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS)
+        expected_kl_term = -1.0 * torch.sum(ensemble_probs * expected_kl_term, dim=-1)
+        assert torch.isfinite(expected_kl_term).all(), (torch.max(alphas), torch.max(precision), alphas.dtype)
+
+        differential_negentropy_term_ = torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS)
+        differential_negentropy_term_ *= alphas - 1.0
+        differential_negentropy_term = torch.sum(torch.lgamma(alphas + EPS), dim=-1) - torch.lgamma(precision + EPS)
+        differential_negentropy_term -= torch.sum(differential_negentropy_term_, dim=-1)
+        assert torch.isfinite(differential_negentropy_term).all()
+
+        loss = expected_kl_term - differential_negentropy_term / ensemble_precision.squeeze(-1)
+        assert torch.isfinite(loss).all()
+
+        return torch.mean(loss), stats, ensemble_stats
+
+
+class BinaryRKLDirichletMediatorLoss(RKLDirichletMediatorLoss):
+    """Binary Reverse KL Dirichlet Mediator Loss (https://arxiv.org/abs/2105.06987)"""
+
+    def __init__(self,
+                 model_offset: float = 1.0,
+                 target_offset: float = 1,
+                 ignore_index: int = -1,
+                 parameterization=torch.exp):
+        """
+        Args:
+            model_offset (float): Offset of model Dirichlet for stability
+            target_offset (float): Offset of target Dirichlet for stability
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+            parameterization (function): Mapping from logits to concentration parameters
+        """
+        super(BinaryRKLDirichletMediatorLoss, self).__init__(model_offset, target_offset,
+                                                             ignore_index, parameterization)
+
+    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            logits (Tensor): Model logits
+            targets (Tensor): Ensemble predictive distributions
+
+        Returns:
+            loss (Tensor): RKL dirichlet mediator loss value
+        """
+        # Convert single target probability p to distribution [1-p, p]
+        targets = targets.reshape(-1, targets.size(-1), 1)
+        targets = torch.cat([1 - targets, targets], -1)
+        targets[targets[:, 0, 1] == self.ignore_index] = self.ignore_index
+
+        # Convert input logits into predictive distribution [1-z, z]
+        logits = torch.sigmoid(logits).unsqueeze(1)
+        logits = torch.cat((1 - logits, logits), 1)
+        logits = -1.0 * torch.log((1 / (logits + 1e-8)) - 1)  # Inverse sigmoid
+
+        return super().forward(logits, targets)
diff --git a/convlab/dst/setsumbt/loss/kl_distillation.py b/convlab/dst/setsumbt/loss/kl_distillation.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aee234ab68054f2b4a83d6feb5e453384d89e94
--- /dev/null
+++ b/convlab/dst/setsumbt/loss/kl_distillation.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""KL Divergence Ensemble Distillation loss"""
+
+import torch
+from torch.nn import Module
+from torch.nn.functional import kl_div
+
+
+class KLDistillationLoss(Module):
+    """Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation"""
+
+    def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module:
+        """
+        Args:
+            lamb (float): Target smoothing parameter
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
+        super(KLDistillationLoss, self).__init__()
+
+        self.lamb = lamb
+        self.ignore_index = ignore_index
+    
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
+        """
+        Args:
+            inputs (Tensor): Predictive distribution
+            targets (Tensor): Target distribution (ensemble marginal)
+            temp (float): Temperature scaling coefficient for predictive distribution
+
+        Returns:
+            loss (Tensor): Loss value
+        """
+        # Assert input sizes
+        assert inputs.dim() == 2                  # Observations, predictive distribution
+        assert targets.dim() == 2                # Label for each observation
+        assert targets.size(0) == inputs.size(0)  # Equal number of observation
+
+        # Confirm predictive distribution dimension
+        if targets.size(-1) != inputs.size(-1):
+            name_error = f'Target dimension {targets.size(-1)} is not the same as the prediction dimension '
+            name_error += f'{inputs.size(-1)}.'
+            raise NameError(name_error)
+
+        # Remove observations to be ignored in loss calculation
+        inputs = torch.log(torch.softmax(inputs / temp, -1))
+        ids = torch.where(targets[:, 0] != self.ignore_index)[0]
+        inputs = inputs[ids]
+        targets = targets[ids]
+
+        # Target smoothing
+        targets = ((1 - self.lamb) * targets) + (self.lamb / targets.size(-1))
+
+        return kl_div(inputs, targets, reduction='none').sum(-1).mean()
+
+
+# Pytorch binary (two-class) KL distillation loss nn.Module
+class BinaryKLDistillationLoss(KLDistillationLoss):
+    """Binary Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation"""
+
+    def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module:
+        """
+        Args:
+            lamb (float): Target smoothing parameter
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
+        super(BinaryKLDistillationLoss, self).__init__(lamb, ignore_index)
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
+        """
+        Args:
+            inputs (Tensor): Positive-class logits, one per observation
+            targets (Tensor): Positive-class target probabilities (ensemble marginal), ignore_index marks ignored rows
+            temp (float): Temperature scaling coefficient for predictive distribution
+
+        Returns:
+            loss (Tensor): Loss value
+        """
+        # Assert input sizes
+        assert inputs.dim() == 1                 # Observations, positive-class logit
+        assert targets.dim() == 1                # Target probability for each observation
+        assert targets.size(0) == inputs.size(0)  # Equal number of observation
+
+        # Two-class logits [0, z]: softmax([0, z] / temp) == [1 - sigmoid(z / temp), sigmoid(z / temp)], so the
+        # parent applies the temperature and log-softmax exactly once.
+        inputs = torch.stack((torch.zeros_like(inputs), inputs), -1)
+
+        # Convert target probability p to [1 - p, p]; re-mark ignored rows, since 1 - ignore_index != ignore_index
+        targets = torch.stack((1 - targets, targets), -1)
+        targets[targets[:, 1] == self.ignore_index] = self.ignore_index
+        return super().forward(inputs, targets, temp)
diff --git a/convlab/dst/setsumbt/loss/labelsmoothing.py b/convlab/dst/setsumbt/loss/labelsmoothing.py
index 8fcc60afd50603cb5c2c84fd698fd11ce7fb7415..61d4b353303451eac7eb09592bdb2c5200328250 100644
--- a/convlab/dst/setsumbt/loss/labelsmoothing.py
+++ b/convlab/dst/setsumbt/loss/labelsmoothing.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inhibited Softmax Activation and Loss Functions"""
+"""Label smoothing loss function"""
 
 
 import torch
@@ -23,66 +23,97 @@ from torch.nn.functional import kl_div
 
 class LabelSmoothingLoss(Module):
     """
-    With label smoothing,
-    KL-divergence between q_{smoothed ground truth prob.}(w)
-    and p_{prob. computed by model}(w) is minimized.
+    Label smoothing loss minimises the KL-divergence between q_{smoothed ground truth prob}(w)
+    and p_{prob. computed by model}(w).
     """
-    def __init__(self, label_smoothing=0.05, ignore_index=-1):
+
+    def __init__(self, label_smoothing: float = 0.05, ignore_index: int = -1) -> Module:
+        """
+        Args:
+            label_smoothing (float): Label smoothing constant
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
         super(LabelSmoothingLoss, self).__init__()
 
         assert 0.0 < label_smoothing <= 1.0
         self.ignore_index = ignore_index
         self.label_smoothing = float(label_smoothing)
 
-    def forward(self, logits, targets):
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         """
-        output (FloatTensor): batch_size x n_classes
-        target (LongTensor): batch_size
+        Args:
+            inputs (Tensor): Predictive distribution
+            labels (Tensor): Label indices
+
+        Returns:
+            loss (Tensor): Loss value
         """
-        assert logits.dim() == 2
-        assert targets.dim() == 1
-        assert self.label_smoothing <= ((logits.size(-1) - 1) / logits.size(-1))
+        # Assert input sizes
+        assert inputs.dim() == 2
+        assert labels.dim() == 1
+        assert self.label_smoothing <= ((inputs.size(-1) - 1) / inputs.size(-1))
 
-        logits = logits[targets != self.ignore_index]
-        targets = targets[targets != self.ignore_index]
+        # Confirm predictive distribution dimension
+        if labels.max() <= inputs.size(-1):
+            dimension = inputs.size(-1)
+        else:
+            raise NameError(f'Label dimension {labels.max()} is larger than prediction dimension {inputs.size(-1)}.')
 
-        logits = torch.log(torch.softmax(logits, -1))
-        labels = torch.ones(logits.size()).float().to(logits.device)
-        labels *= self.label_smoothing / (logits.size(-1) - 1)
-        labels[range(labels.size(0)), targets] = 1.0 - self.label_smoothing
+        # Remove observations to be ignored in loss calculation
+        inputs = inputs[labels != self.ignore_index]
+        labels = labels[labels != self.ignore_index]
 
-        kl = kl_div(logits, labels, reduction='none').sum(-1).mean()
-        del logits, targets, labels
-        return kl
+        if labels.size(0) == 0.0:
+            return torch.zeros(1).float().to(labels.device).mean()
 
+        # Create target distribution
+        inputs = torch.log(torch.softmax(inputs, -1))
+        targets = torch.ones(inputs.size()).float().to(inputs.device)
+        targets *= self.label_smoothing / (dimension - 1)
+        targets[range(labels.size(0)), labels] = 1.0 - self.label_smoothing
 
-class BinaryLabelSmoothingLoss(Module):
+        return kl_div(inputs, targets, reduction='none').sum(-1).mean()
+
+
+class BinaryLabelSmoothingLoss(LabelSmoothingLoss):
     """
-    With label smoothing,
-    KL-divergence between q_{smoothed ground truth prob.}(w)
-    and p_{prob. computed by model}(w) is minimized.
+    Label smoothing loss minimises the KL-divergence between q_{smoothed ground truth prob}(w)
+    and p_{prob. computed by model}(w).
     """
-    def __init__(self, label_smoothing=0.05):
-        super(BinaryLabelSmoothingLoss, self).__init__()
 
-        assert 0.0 < label_smoothing <= 1.0
-        self.label_smoothing = float(label_smoothing)
+    def __init__(self, label_smoothing: float = 0.05, ignore_index: int = -1) -> Module:
+        """
+        Args:
+            label_smoothing (float): Label smoothing constant
+            ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
+        """
+        super(BinaryLabelSmoothingLoss, self).__init__(label_smoothing, ignore_index)
 
-    def forward(self, logits, targets):
+    def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         """
-        output (FloatTensor): batch_size x n_classes
-        target (LongTensor): batch_size
+        Args:
+            inputs (Tensor): Predictive distribution
+            labels (Tensor): Label indices
+
+        Returns:
+            loss (Tensor): Loss value
         """
-        assert logits.dim() == 1
-        assert targets.dim() == 1
+        # Assert input sizes
+        assert inputs.dim() == 1
+        assert labels.dim() == 1
         assert self.label_smoothing <= 0.5
 
-        logits = torch.sigmoid(logits).reshape(-1, 1)
-        logits = torch.log(torch.cat((1 - logits, logits), 1))
-        labels = torch.ones(logits.size()).float().to(logits.device)
-        labels *= self.label_smoothing
-        labels[range(labels.size(0)), targets.long()] = 1.0 - self.label_smoothing
+        # Remove observations to be ignored in loss calculation
+        inputs = inputs[labels != self.ignore_index]
+        labels = labels[labels != self.ignore_index]
+
+        if labels.size(0) == 0.0:
+            return torch.zeros(1).float().to(labels.device).mean()
+
+        inputs = torch.sigmoid(inputs).reshape(-1, 1)
+        inputs = torch.log(torch.cat((1 - inputs, inputs), 1))
+        targets = torch.ones(inputs.size()).float().to(inputs.device)
+        targets *= self.label_smoothing
+        targets[range(labels.size(0)), labels.long()] = 1.0 - self.label_smoothing
 
-        kl = kl_div(logits, labels, reduction='none').sum(-1).mean()
-        del logits, targets
-        return kl
+        return kl_div(inputs, targets, reduction='none').sum(-1).mean()
diff --git a/convlab/dst/setsumbt/loss/ece.py b/convlab/dst/setsumbt/loss/uncertainty_measures.py
similarity index 50%
rename from convlab/dst/setsumbt/loss/ece.py
rename to convlab/dst/setsumbt/loss/uncertainty_measures.py
index 034b9aa0bf5882aea49b08a64d7f93164208b5d9..87c89dd31c724cc7d599230c6d4a15faee9b680e 100644
--- a/convlab/dst/setsumbt/loss/ece.py
+++ b/convlab/dst/setsumbt/loss/uncertainty_measures.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,14 +13,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Expected calibration error"""
+"""Uncertainty evaluation metrics for dialogue belief tracking"""
 
 import torch
 
 
-def fill_bins(n_bins, logits):
-    assert logits.dim() == 2
-    logits = logits.max(-1)[0]
+def fill_bins(n_bins: int, probs: torch.Tensor) -> list:
+    """
+    Function to split observations into bins based on predictive probabilities
+
+    Args:
+        n_bins (int): Number of bins
+        probs (Tensor): Predictive probabilities for the observations
+
+    Returns:
+        bins (list): List of observation ids for each bin
+    """
+    assert probs.dim() == 2
+    probs = probs.max(-1)[0]
 
     step = 1.0 / n_bins
     bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
@@ -28,29 +38,49 @@ def fill_bins(n_bins, logits):
     for b in range(n_bins):
         lower, upper = bin_ranges[b], bin_ranges[b + 1]
         if b == 0:
-            ids = torch.where((logits >= lower) * (logits <= upper))[0]
+            ids = torch.where((probs >= lower) * (probs <= upper))[0]
         else:
-            ids = torch.where((logits > lower) * (logits <= upper))[0]
+            ids = torch.where((probs > lower) * (probs <= upper))[0]
         bins.append(ids)
     return bins
 
 
-def bin_confidence(bins, logits):
-    logits = logits.max(-1)[0]
+def bin_confidence(bins: list, probs: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the confidence score within each bin
+
+    Args:
+        bins (list): List of observation ids for each bin
+        probs (Tensor): Predictive probabilities for the observations
+
+    Returns:
+        scores (Tensor): Average confidence score within each bin
+    """
+    probs = probs.max(-1)[0]
 
     scores = []
     for b in bins:
         if b is not None:
-            l = logits[b]
-            scores.append(l.mean())
+            scores.append(probs[b].mean())
         else:
             scores.append(-1)
     scores = torch.tensor(scores)
     return scores
 
 
-def bin_accuracy(bins, logits, y_true):
-    y_pred = logits.argmax(-1)
+def bin_accuracy(bins: list, probs: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the accuracy score for observations in each bin
+
+    Args:
+        bins (list): List of observation ids for each bin
+        probs (Tensor): Predictive probabilities for the observations
+        y_true (Tensor): Labels for the observations
+
+    Returns:
+        acc (Tensor): Accuracies for the observations in each bin
+    """
+    y_pred = probs.argmax(-1)
 
     acc = []
     for b in bins:
@@ -68,13 +98,24 @@ def bin_accuracy(bins, logits, y_true):
     return acc
 
 
-def ece(logits, y_true, n_bins):
-    bins = fill_bins(n_bins, logits)
+def ece(probs: torch.Tensor, y_true: torch.Tensor, n_bins: int) -> float:
+    """
+    Expected calibration error calculation
 
-    scores = bin_confidence(bins, logits)
-    acc = bin_accuracy(bins, logits, y_true)
+    Args:
+        probs (Tensor): Predictive probabilities for the observations
+        y_true (Tensor): Labels for the observations
+        n_bins (int): Number of bins
 
-    n = logits.size(0)
+    Returns:
+        ece (float): Expected calibration error
+    """
+    bins = fill_bins(n_bins, probs)
+
+    scores = bin_confidence(bins, probs)
+    acc = bin_accuracy(bins, probs, y_true)
+
+    n = probs.size(0)
     bk = torch.tensor([b.size(0) for b in bins])
 
     ece = torch.abs(scores - acc) * bk / n
@@ -84,34 +125,30 @@ def ece(logits, y_true, n_bins):
     return ece
 
 
-def jg_ece(logits, y_true, n_bins):
-    y_pred = {slot: logits[slot].reshape(-1, logits[slot].size(-1)).argmax(-1) for slot in logits}
+def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float:
+    """
+    Joint goal expected calibration error calculation
+
+    Args:
+        belief_state (dict): Belief state probabilities for the dialogue turns
+        y_true (dict): Labels for the state in dialogue turns
+        n_bins (int): Number of bins
+
+    Returns:
+        ece (float): Joint goal expected calibration error
+    """
+    y_pred = {slot: bs.reshape(-1, bs.size(-1)).argmax(-1) for slot, bs in belief_state.items()}
     goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred}
     goal_acc = sum([goal_acc[slot] for slot in goal_acc])
     goal_acc = (goal_acc == len(y_true)).int()
 
-    scores = [logits[slot].reshape(-1, logits[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in logits]
+    # Confidence score is the minimum across slots, as a single bad slot prediction leads to an incorrect state prediction
+    scores = [bs.reshape(-1, bs.size(-1)).max(-1)[0].unsqueeze(0) for slot, bs in belief_state.items()]
     scores = torch.cat(scores, 0).min(0)[0]
 
-    step = 1.0 / n_bins
-    bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
-    bins = []
-    for b in range(n_bins):
-        lower, upper = bin_ranges[b], bin_ranges[b + 1]
-        if b == 0:
-            ids = torch.where((scores >= lower) * (scores <= upper))[0]
-        else:
-            ids = torch.where((scores > lower) * (scores <= upper))[0]
-        bins.append(ids)
+    bins = fill_bins(n_bins, scores.unsqueeze(-1))
 
-    conf = []
-    for b in bins:
-        if b is not None:
-            l = scores[b]
-            conf.append(l.mean())
-        else:
-            conf.append(-1)
-    conf = torch.tensor(conf)
+    conf = bin_confidence(bins, scores.unsqueeze(-1))
 
     slot = [s for s in y_true][0]
     acc = []
@@ -127,7 +164,7 @@ def jg_ece(logits, y_true, n_bins):
             acc.append(-1)
     acc = torch.tensor(acc)
 
-    n = logits[slot].reshape(-1, logits[slot].size(-1)).size(0)
+    n = belief_state[slot].reshape(-1, belief_state[slot].size(-1)).size(0)
     bk = torch.tensor([b.size(0) for b in bins])
 
     ece = torch.abs(conf - acc) * bk / n
@@ -137,12 +174,22 @@ def jg_ece(logits, y_true, n_bins):
     return ece
 
 
-def l2_acc(belief_state, labels, remove_belief=False):
+def l2_acc(belief_state: dict, labels: dict, remove_belief: bool = False) -> float:
+    """
+    Compute L2 Error of belief state prediction
+
+    Args:
+        belief_state (dict): Belief state probabilities for the dialogue turns
+        labels (dict): Labels for the state in dialogue turns
+        remove_belief (bool): Convert belief state to dialogue state
+
+    Returns:
+        err (float): L2 Error of belief state prediction
+    """
     # Get ids used for removing padding turns.
     padding = labels[list(labels.keys())[0]].reshape(-1)
     padding = torch.where(padding != -1)[0]
 
-    # l2 = []
     state = []
     labs = []
     for slot, bs in belief_state.items():
@@ -163,13 +210,8 @@ def l2_acc(belief_state, labels, remove_belief=False):
         y = torch.zeros(bs.shape).cuda()
         y[range(y.size(0)), lab] = 1.0
 
-        # err = torch.sqrt(((y - bs) ** 2).sum(-1))
-        # l2.append(err.unsqueeze(-1))
-
         state.append(bs)
         labs.append(y)
-    
-    # err = torch.cat(l2, -1).max(-1)[0]
 
     # Concatenate all slots into a single belief state
     state = torch.cat(state, -1)
diff --git a/convlab/dst/setsumbt/modeling/__init__.py b/convlab/dst/setsumbt/modeling/__init__.py
index 011a1a774e2d1a22e46e242d1812549895f2246b..59f1439948421ac365e4602b7800c94d3b8b32dd 100644
--- a/convlab/dst/setsumbt/modeling/__init__.py
+++ b/convlab/dst/setsumbt/modeling/__init__.py
@@ -1,3 +1,5 @@
 from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
 from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
-from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT, DropoutEnsembleSetSUMBT
+from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT
+
+from convlab.dst.setsumbt.modeling.temperature_scheduler import LinearTemperatureScheduler
diff --git a/convlab/dst/setsumbt/modeling/bert_nbt.py b/convlab/dst/setsumbt/modeling/bert_nbt.py
index 8b402b6be09684b27bb73acf17e578bc0e3b4bbd..6762fb3891b4720c3889d8c0809b8791f3bf7633 100644
--- a/convlab/dst/setsumbt/modeling/bert_nbt.py
+++ b/convlab/dst/setsumbt/modeling/bert_nbt.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,11 +16,10 @@
 """BERT SetSUMBT"""
 
 import torch
-import transformers
 from torch.autograd import Variable
 from transformers import BertModel, BertPreTrainedModel
 
-from convlab.dst.setsumbt.modeling.functional import _initialise, _nbt_forward
+from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead
 
 
 class BertSetSUMBT(BertPreTrainedModel):
@@ -35,59 +34,37 @@ class BertSetSUMBT(BertPreTrainedModel):
             for p in self.bert.parameters():
                 p.requires_grad = False
 
-        _initialise(self, config)
-
-    # Add new slot candidates to the model
-    def add_slot_candidates(self, slot_candidates):
-        """slot_candidates is a list of tuples for each slot.
-        - The tuples contains the slot embedding, informable value embeddings and a request indicator.
-        - If the informable value embeddings is None the slot is not informable
-        - If the request indicator is false the slot is not requestable"""
-        if self.slot_embeddings.size(0) != 0:
-            embeddings = self.slot_embeddings.detach()
-        else:
-            embeddings = torch.zeros(0)
-
-        for slot in slot_candidates:
-            if slot in self.slot_ids:
-                index = self.slot_ids[slot]
-                embeddings[index, :] = slot_candidates[slot][0]
-            else:
-                index = embeddings.size(0)
-                emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device)
-                embeddings = torch.cat((embeddings, emb), 0)
-                self.slot_ids[slot] = index
-                setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False))
-            # Add slot to relevant requestable and informable slot lists
-            if slot_candidates[slot][2]:
-                self.requestable_slot_ids[slot] = index
-            if slot_candidates[slot][1] is not None:
-                self.informable_slot_ids[slot] = index
-            
-            domain = slot.split('-', 1)[0]
-            if domain not in self.domain_ids:
-                self.domain_ids[domain] = []
-            self.domain_ids[domain].append(index)
-            self.domain_ids[domain] = list(set(self.domain_ids[domain]))
-        
-        self.slot_embeddings = Variable(embeddings, requires_grad=False)
-
-
-    # Add new value candidates to the model
-    def add_value_candidates(self, slot, value_candidates, replace=False):
-        embeddings = getattr(self, slot + '_value_embeddings')
-
-        if embeddings.size(0) == 0 or replace:
-            embeddings = value_candidates
-        else:
-            embeddings = torch.cat((embeddings, value_candidates), 0)
-        
-        setattr(self, slot + '_value_embeddings', embeddings)
-
-    
-    def forward(self, input_ids, token_type_ids, attention_mask, hidden_state=None, inform_labels=None,
-                request_labels=None, domain_labels=None, goodbye_labels=None,
-                get_turn_pooled_representation=False, calculate_inform_mutual_info=False):
+        self.setsumbt = SetSUMBTHead(config)
+        self.add_slot_candidates = self.setsumbt.add_slot_candidates
+        self.add_value_candidates = self.setsumbt.add_value_candidates
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                hidden_state: torch.Tensor = None,
+                state_labels: torch.Tensor = None,
+                request_labels: torch.Tensor = None,
+                active_domain_labels: torch.Tensor = None,
+                general_act_labels: torch.Tensor = None,
+                get_turn_pooled_representation: bool = False,
+                calculate_state_mutual_info: bool = False):
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator
+            hidden_state: Latent internal dialogue belief state
+            state_labels: Dialogue state labels
+            request_labels: User request action labels
+            active_domain_labels: Current active domain labels
+            general_act_labels: General user action labels
+            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
+            calculate_state_mutual_info: Return mutual information in the dialogue state
+
+        Returns:
+            out: Tuple containing loss, predictive distributions, model statistics and state mutual information
+        """
 
         # Encode Dialogues
         batch_size, dialogue_size, turn_size = input_ids.size()
@@ -103,9 +80,10 @@ class BertSetSUMBT(BertPreTrainedModel):
         turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
 
         if get_turn_pooled_representation:
-            return _nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size,
-                                dialogue_size, turn_size, hidden_state, inform_labels, request_labels, domain_labels,
-                                goodbye_labels, calculate_inform_mutual_info) + (bert_output.pooler_output,)
-        return _nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size, dialogue_size,
-                            turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels,
-                            calculate_inform_mutual_info)
+            return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask,
+                                 batch_size, dialogue_size, hidden_state, state_labels,
+                                 request_labels, active_domain_labels, general_act_labels,
+                                 calculate_state_mutual_info) + (bert_output.pooler_output,)
+        return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask, batch_size,
+                             dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
+                             general_act_labels, calculate_state_mutual_info)
diff --git a/convlab/dst/setsumbt/modeling/calibration_utils.py b/convlab/dst/setsumbt/modeling/calibration_utils.py
deleted file mode 100644
index 8514ac8d259162c5bcc55607bb8356de6d4b47c7..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/modeling/calibration_utils.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Discriminative models calibration"""
-
-import random
-
-import torch
-import numpy as np
-from tqdm import tqdm
-
-
-# Load logger and tensorboard summary writer
-def set_logger(logger_, tb_writer_):
-    global logger, tb_writer
-    logger = logger_
-    tb_writer = tb_writer_
-
-
-# Set seeds
-def set_seed(args):
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-    logger.info('Seed set to %d.' % args.seed)
-
-
-def get_predictions(args, model, device, dataloader):
-    logger.info("  Num Batches = %d", len(dataloader))
-
-    model.eval()
-    if args.dropout_iterations > 1:
-        model.train()
-    
-    belief_states = {slot: [] for slot in model.informable_slot_ids}
-    request_belief = {slot: [] for slot in model.requestable_slot_ids}
-    domain_belief = {dom: [] for dom in model.domain_ids}
-    greeting_belief = []
-    labels = {slot: [] for slot in model.informable_slot_ids}
-    request_labels = {slot: [] for slot in model.requestable_slot_ids}
-    domain_labels = {dom: [] for dom in model.domain_ids}
-    greeting_labels = []
-    epoch_iterator = tqdm(dataloader, desc="Iteration")
-    for step, batch in enumerate(epoch_iterator):
-        with torch.no_grad():    
-            input_ids = batch['input_ids'].to(device)
-            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
-            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
-
-            if args.dropout_iterations > 1:
-                p = {slot: [] for slot in model.informable_slot_ids}
-                for _ in range(args.dropout_iterations):
-                    p_, p_req_, p_dom_, p_bye_, _ = model(input_ids=input_ids,
-                                                        token_type_ids=token_type_ids,
-                                                        attention_mask=attention_mask)
-                    for slot in model.informable_slot_ids:
-                        p[slot].append(p_[slot].unsqueeze(0))
-                
-                mu = {slot: torch.cat(p[slot], 0).mean(0) for slot in model.informable_slot_ids}
-                sig = {slot: torch.cat(p[slot], 0).var(0) for slot in model.informable_slot_ids}
-                p = {slot: mu[slot] / torch.sqrt(1 + sig[slot]) for slot in model.informable_slot_ids}
-                p = {slot: normalise(p[slot]) for slot in model.informable_slot_ids}
-            else:
-                p, p_req, p_dom, p_bye, _ = model(input_ids=input_ids,
-                                                token_type_ids=token_type_ids,
-                                                attention_mask=attention_mask)
-            
-            for slot in model.informable_slot_ids:
-                p_ = p[slot]
-                labs = batch['labels-' + slot].to(device)
-                
-                belief_states[slot].append(p_)
-                labels[slot].append(labs)
-            
-            if p_req is not None:
-                for slot in model.requestable_slot_ids:
-                    p_ = p_req[slot]
-                    labs = batch['request-' + slot].to(device)
-
-                    request_belief[slot].append(p_)
-                    request_labels[slot].append(labs)
-                
-                for domain in model.domain_ids:
-                    p_ = p_dom[domain]
-                    labs = batch['active-' + domain].to(device)
-
-                    domain_belief[domain].append(p_)
-                    domain_labels[domain].append(labs)
-                
-                greeting_belief.append(p_bye)
-                greeting_labels.append(batch['goodbye'].to(device))
-    
-    for slot in belief_states:
-        belief_states[slot] = torch.cat(belief_states[slot], 0)
-        labels[slot] = torch.cat(labels[slot], 0)
-    if p_req is not None:
-        for slot in request_belief:
-            request_belief[slot] = torch.cat(request_belief[slot], 0)
-            request_labels[slot] = torch.cat(request_labels[slot], 0)
-        for domain in domain_belief:
-            domain_belief[domain] = torch.cat(domain_belief[domain], 0)
-            domain_labels[domain] = torch.cat(domain_labels[domain], 0)
-        greeting_belief = torch.cat(greeting_belief, 0)
-        greeting_labels = torch.cat(greeting_labels, 0)
-    else:
-        request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = [None]*6
-
-    return belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels
-
-
-def normalise(p):
-    p_shape = p.size()
-
-    p = p.reshape(-1, p_shape[-1]) + 1e-10
-    p_sum = p.sum(-1).unsqueeze(1).repeat((1, p_shape[-1]))
-    p /= p_sum
-
-    p = p.reshape(p_shape)
-
-    return p
diff --git a/convlab/dst/setsumbt/modeling/ensemble_nbt.py b/convlab/dst/setsumbt/modeling/ensemble_nbt.py
index 9f101d128c6b8c9093c4959834cf5aac35b322a2..6d3d8035a4d6f47f2ea8551050ca8da682ea0376 100644
--- a/convlab/dst/setsumbt/modeling/ensemble_nbt.py
+++ b/convlab/dst/setsumbt/modeling/ensemble_nbt.py
@@ -16,9 +16,9 @@
 """Ensemble SetSUMBT"""
 
 import os
+from shutil import copy2 as copy
 
 import torch
-import transformers
 from torch.nn import Module
 from transformers import RobertaConfig, BertConfig
 
@@ -29,8 +29,13 @@ MODELS = {'bert': BertSetSUMBT, 'roberta': RobertaSetSUMBT}
 
 
 class EnsembleSetSUMBT(Module):
+    """Ensemble SetSUMBT Model for joint ensemble prediction"""
 
     def __init__(self, config):
+        """
+        Args:
+            config (configuration): Model configuration class
+        """
         super(EnsembleSetSUMBT, self).__init__()
         self.config = config
 
@@ -38,175 +43,138 @@ class EnsembleSetSUMBT(Module):
         model_cls = MODELS[self.config.model_type]
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             setattr(self, attr, model_cls(config))
-    
 
-    # Load all ensemble memeber parameters
-    def load(self, path, config=None):
-        if config is None:
-            config = self.config
-        
+    def _load(self, path: str):
+        """
+        Load the parameters of every ensemble member
+        Args:
+            path: Directory containing an ens-<idx>/pytorch_model.bin file for each ensemble member
+        """
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             idx = attr.split('_', 1)[-1]
-            state_dict = torch.load(os.path.join(path, f'pytorch_model_{idx}.bin'))
+            state_dict = torch.load(os.path.join(path, f'ens-{idx}/pytorch_model.bin'))
             getattr(self, attr).load_state_dict(state_dict)
-    
 
-    # Add new slot candidates to the ensemble members
-    def add_slot_candidates(self, slot_candidates):
+    def add_slot_candidates(self, slot_candidates: tuple):
+        """
+        Add slots to the model ontology, the tuples should contain the slot embedding, informable value embeddings
+        and a request indicator, if the informable value embeddings is None the slot is not informable and if
+        the request indicator is false the slot is not requestable.
+
+        Args:
+            slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator
+        """
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             getattr(self, attr).add_slot_candidates(slot_candidates)
-        self.requestable_slot_ids = self.model_0.requestable_slot_ids
-        self.informable_slot_ids = self.model_0.informable_slot_ids
-        self.domain_ids = self.model_0.domain_ids
-
-
-    # Add new value candidates to the ensemble members
-    def add_value_candidates(self, slot, value_candidates, replace=False):
+        self.requestable_slot_ids = self.model_0.setsumbt.requestable_slot_ids
+        self.informable_slot_ids = self.model_0.setsumbt.informable_slot_ids
+        self.domain_ids = self.model_0.setsumbt.domain_ids
+
+    def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
+        """
+        Add value candidates for a slot
+
+        Args:
+            slot: Slot name
+            value_candidates: Value candidate embeddings
+            replace: If true existing value candidates are replaced
+        """
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             getattr(self, attr).add_value_candidates(slot, value_candidates, replace)
-        
 
-    # Forward pass of full ensemble
-    def forward(self, input_ids, attention_mask, token_type_ids=None, reduction='mean'):
-        logits, request_logits, domain_logits, goodbye_scores = [], [], [], []
-        logits = {slot: [] for slot in self.model_0.informable_slot_ids}
-        request_logits = {slot: [] for slot in self.model_0.requestable_slot_ids}
-        domain_logits = {dom: [] for dom in self.model_0.domain_ids}
-        goodbye_scores = []
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                reduction: str = 'mean') -> tuple:
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator
+            reduction: Reduction of ensemble member predictive distributions (mean, none)
+
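+        Returns:
+            out: Tuple of belief state, request, active domain and general act probabilities and the last member's fifth output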
+        """
+        belief_state_probs = {slot: [] for slot in self.informable_slot_ids}
+        request_probs = {slot: [] for slot in self.requestable_slot_ids}
+        active_domain_probs = {dom: [] for dom in self.domain_ids}
+        general_act_probs = []
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             # Prediction from each ensemble member
-            l, r, d, g, _ = getattr(self, attr)(input_ids=input_ids,
+            b, r, d, g, _ = getattr(self, attr)(input_ids=input_ids,
                                                 token_type_ids=token_type_ids,
                                                 attention_mask=attention_mask)
-            for slot in logits:
-                logits[slot].append(l[slot].unsqueeze(-2))
-            if self.config.predict_intents:
-                for slot in request_logits:
-                    request_logits[slot].append(r[slot].unsqueeze(-1))
-                for dom in domain_logits:
-                    domain_logits[dom].append(d[dom].unsqueeze(-1))
-                goodbye_scores.append(g.unsqueeze(-2))
+            for slot in belief_state_probs:
+                belief_state_probs[slot].append(b[slot].unsqueeze(-2))
+            if self.config.predict_actions:
+                for slot in request_probs:
+                    request_probs[slot].append(r[slot].unsqueeze(-1))
+                for dom in active_domain_probs:
+                    active_domain_probs[dom].append(d[dom].unsqueeze(-1))
+                general_act_probs.append(g.unsqueeze(-2))
         
-        logits = {slot: torch.cat(l, -2) for slot, l in logits.items()}
-        if self.config.predict_intents:
-            request_logits = {slot: torch.cat(l, -1) for slot, l in request_logits.items()}
-            domain_logits = {dom: torch.cat(l, -1) for dom, l in domain_logits.items()}
-            goodbye_scores = torch.cat(goodbye_scores, -2)
+        belief_state_probs = {slot: torch.cat(l, -2) for slot, l in belief_state_probs.items()}
+        if self.config.predict_actions:
+            request_probs = {slot: torch.cat(l, -1) for slot, l in request_probs.items()}
+            active_domain_probs = {dom: torch.cat(l, -1) for dom, l in active_domain_probs.items()}
+            general_act_probs = torch.cat(general_act_probs, -2)
         else:
-            request_logits = {}
-            domain_logits = {}
-            goodbye_scores = torch.tensor(0.0)
+            request_probs = {}
+            active_domain_probs = {}
+            general_act_probs = torch.tensor(0.0)
 
         # Apply reduction of ensemble to single posterior
         if reduction == 'mean':
-            logits = {slot: l.mean(-2) for slot, l in logits.items()}
-            request_logits = {slot: l.mean(-1) for slot, l in request_logits.items()}
-            domain_logits = {dom: l.mean(-1) for dom, l in domain_logits.items()}
-            goodbye_scores = goodbye_scores.mean(-2)
+            belief_state_probs = {slot: l.mean(-2) for slot, l in belief_state_probs.items()}
+            request_probs = {slot: l.mean(-1) for slot, l in request_probs.items()}
+            active_domain_probs = {dom: l.mean(-1) for dom, l in active_domain_probs.items()}
+            general_act_probs = general_act_probs.mean(-2)
         elif reduction != 'none':
             raise(NameError('Not Implemented!'))
 
-        return logits, request_logits, domain_logits, goodbye_scores, _
+        return belief_state_probs, request_probs, active_domain_probs, general_act_probs, _
     
 
     @classmethod
     def from_pretrained(cls, path):
-        if not os.path.exists(os.path.join(path, 'config.json')):
+        config_path = os.path.join(path, 'ens-0', 'config.json')
+        if not os.path.exists(config_path):
             raise(NameError('Could not find config.json in model path.'))
-        if not os.path.exists(os.path.join(path, 'pytorch_model_0.bin')):
-            raise(NameError('Could not find a model binary in the model path.'))
         
         try:
-            config = RobertaConfig.from_pretrained(path)
+            config = RobertaConfig.from_pretrained(config_path)
         except:
-            config = BertConfig.from_pretrained(path)
+            config = BertConfig.from_pretrained(config_path)
+
+        config.ensemble_size = len([dir for dir in os.listdir(path) if 'ens-' in dir])
         
         model = cls(config)
-        model.load(path)
+        model._load(path)
 
         return model
 
 
-class DropoutEnsembleSetSUMBT(Module):
-
-    def __init__(self, config):
-        super(DropoutEnsembleBeliefTracker, self).__init__()
-        self.config = config
-
-        model_cls = MODELS[self.config.model_type]
-        self.model = model_cls(config)
-        self.model.train()
-    
-
-    def load(self, path, config=None):
-        if config is None:
-            config = self.config
-        state_dict = torch.load(os.path.join(path, f'pytorch_model.bin'))
-        self.model.load_state_dict(state_dict)
-    
-
-    # Add new slot candidates to the model
-    def add_slot_candidates(self, slot_candidates):
-        self.model.add_slot_candidates(slot_candidates)
-        self.requestable_slot_ids = self.model.requestable_slot_ids
-        self.informable_slot_ids = self.model.informable_slot_ids
-        self.domain_ids = self.model.domain_ids
-
-
-    # Add new value candidates to the model
-    def add_value_candidates(self, slot, value_candidates, replace=False):
-        self.model.add_value_candidates(slot, value_candidates, replace)
-        
-    
-    def forward(self, input_ids, attention_mask, token_type_ids=None, reduction='mean'):
-
-        input_ids = input_ids.unsqueeze(0).repeat((self.config.ensemble_size, 1, 1, 1))
-        input_ids = input_ids.reshape(-1, input_ids.size(-2), input_ids.size(-1))
-        if attention_mask is not None:
-            attention_mask = attention_mask.unsqueeze(0).repeat((10, 1, 1, 1))
-            attention_mask = attention_mask.reshape(-1, attention_mask.size(-2), attention_mask.size(-1))
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids.unsqueeze(0).repeat((10, 1, 1, 1))
-            token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-2), token_type_ids.size(-1))
-        
-        self.model.train()
-        logits, request_logits, domain_logits, goodbye_scores, _ = self.model(input_ids=input_ids,
-                                                                            attention_mask=attention_mask,
-                                                                            token_type_ids=token_type_ids)
-        
-        logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-2), l.size(-1)).transpose(0, 1).transpose(1, 2)
-                for s, l in logits.items()}
-        request_logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-1)).transpose(0, 1).transpose(1, 2)
-                        for s, l in request_logits.items()}
-        domain_logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-1)).transpose(0, 1).transpose(1, 2)
-                        for s, l in domain_logits.items()}
-        goodbye_scores = goodbye_scores.reshape(self.config.ensemble_size, -1, goodbye_scores.size(-2), goodbye_scores.size(-1))
-        goodbye_scores = goodbye_scores.transpose(0, 1).transpose(1, 2)
-
-        if reduction == 'mean':
-            logits = {slot: l.mean(-2) for slot, l in logits.items()}
-            request_logits = {slot: l.mean(-1) for slot, l in request_logits.items()}
-            domain_logits = {dom: l.mean(-1) for dom, l in domain_logits.items()}
-            goodbye_scores = goodbye_scores.mean(-2)
-        elif reduction != 'none':
-            raise(NameError('Not Implemented!'))
-
-        return logits, request_logits, domain_logits, goodbye_scores, _
-    
-
-    @classmethod
-    def from_pretrained(cls, path):
-        if not os.path.exists(os.path.join(path, 'config.json')):
-            raise(NameError('Could not find config.json in model path.'))
-        if not os.path.exists(os.path.join(path, 'pytorch_model.bin')):
-            raise(NameError('Could not find a model binary in the model path.'))
-        
-        try:
-            config = RobertaConfig.from_pretrained(path)
-        except:
-            config = BertConfig.from_pretrained(path)
-        
-        model = cls(config)
-        model.load(path)
-
-        return model
+def setup_ensemble(model_path: str, ensemble_size: int):
+    """
+    Set up ensemble model directory structure.
+
+    Args:
+        model_path: Path to ensemble model directory
+        ensemble_size: Number of ensemble members
+    """
+    for i in range(ensemble_size):
+        path = os.path.join(model_path, f'ens-{i}')
+        if not os.path.exists(path):
+            os.mkdir(path)
+            os.mkdir(os.path.join(path, 'dataloaders'))
+            os.mkdir(os.path.join(path, 'database'))
+            # Add development set dataloader to each ensemble member directory
+            for set_type in ['dev']:
+                copy(os.path.join(model_path, 'dataloaders', f'{set_type}.dataloader'),
+                     os.path.join(path, 'dataloaders', f'{set_type}.dataloader'))
+            # Add training and development set ontologies to each ensemble member directory
+            for set_type in ['train', 'dev']:
+                copy(os.path.join(model_path, 'database', f'{set_type}.db'),
+                     os.path.join(path, 'database', f'{set_type}.db'))
diff --git a/convlab/dst/setsumbt/modeling/evaluation_utils.py b/convlab/dst/setsumbt/modeling/evaluation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c73d4b6d32a485a2cf2b5948dbd6a9a4d7f346cb
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/evaluation_utils.py
@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluation Utilities"""
+
+import random
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+
+def set_seed(args):
+    """
+    Seed the Python, NumPy and PyTorch random number generators
+
+    Args:
+        args (Arguments class): Arguments class containing seed and number of gpus to use
+    """
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+
+
+def get_predictions(args, model, device: torch.device, dataloader: torch.utils.data.DataLoader) -> tuple:
+    """
+    Run the model over a dataloader and collect its predictions together with the labels.
+
+    The model is put in evaluation mode and all forward passes run without gradient tracking.
+
+    Args:
+        args: Runtime arguments (not used by this function)
+        model: SetSUMBT Model
+        device: Torch device
+        dataloader: Dataloader containing eval data
+
+    Returns:
+        Tuple (belief_states, state_labels, request_probs, request_labels, active_domain_probs,
+        active_domain_labels, general_act_probs, general_act_labels). The first two entries map each
+        informable slot to the concatenation of its per-batch tensors. The remaining entries are built the
+        same way when the model returns request predictions and are all None otherwise.
+
+    Raises:
+        ValueError: If the dataloader yields no batches.
+    """
+    model.eval()
+
+    belief_states = {slot: [] for slot in model.setsumbt.informable_slot_ids}
+    request_probs = {slot: [] for slot in model.setsumbt.requestable_slot_ids}
+    active_domain_probs = {dom: [] for dom in model.setsumbt.domain_ids}
+    general_act_probs = []
+    state_labels = {slot: [] for slot in model.setsumbt.informable_slot_ids}
+    request_labels = {slot: [] for slot in model.setsumbt.requestable_slot_ids}
+    active_domain_labels = {dom: [] for dom in model.setsumbt.domain_ids}
+    general_act_labels = []
+
+    # Track explicitly whether any batch was seen and whether the model produced action predictions,
+    # rather than reading the loop variable p_req after the loop (it is undefined for an empty dataloader).
+    num_batches = 0
+    has_action_predictions = False
+    epoch_iterator = tqdm(dataloader, desc="Iteration")
+    for batch in epoch_iterator:
+        num_batches += 1
+        with torch.no_grad():
+            input_ids = batch['input_ids'].to(device)
+            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
+            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
+
+            p, p_req, p_dom, p_gen, _ = model(input_ids=input_ids, token_type_ids=token_type_ids,
+                                              attention_mask=attention_mask)
+
+            for slot in belief_states:
+                belief_states[slot].append(p[slot])
+                state_labels[slot].append(batch['state_labels-' + slot].to(device))
+
+            if p_req is not None:
+                has_action_predictions = True
+                for slot in request_probs:
+                    request_probs[slot].append(p_req[slot])
+                    request_labels[slot].append(batch['request_labels-' + slot].to(device))
+
+                for domain in active_domain_probs:
+                    active_domain_probs[domain].append(p_dom[domain])
+                    active_domain_labels[domain].append(batch['active_domain_labels-' + domain].to(device))
+
+                general_act_probs.append(p_gen)
+                general_act_labels.append(batch['general_act_labels'].to(device))
+
+    if num_batches == 0:
+        raise ValueError('Cannot get predictions: the dataloader yielded no batches.')
+
+    for slot in belief_states:
+        belief_states[slot] = torch.cat(belief_states[slot], 0)
+        state_labels[slot] = torch.cat(state_labels[slot], 0)
+    if has_action_predictions:
+        for slot in request_probs:
+            request_probs[slot] = torch.cat(request_probs[slot], 0)
+            request_labels[slot] = torch.cat(request_labels[slot], 0)
+        for domain in active_domain_probs:
+            active_domain_probs[domain] = torch.cat(active_domain_probs[domain], 0)
+            active_domain_labels[domain] = torch.cat(active_domain_labels[domain], 0)
+        general_act_probs = torch.cat(general_act_probs, 0)
+        general_act_labels = torch.cat(general_act_labels, 0)
+    else:
+        request_probs, request_labels, active_domain_probs, active_domain_labels = [None] * 4
+        general_act_probs, general_act_labels = [None] * 2
+
+    out = (belief_states, state_labels, request_probs, request_labels)
+    out += (active_domain_probs, active_domain_labels, general_act_probs, general_act_labels)
+    return out
diff --git a/convlab/dst/setsumbt/modeling/functional.py b/convlab/dst/setsumbt/modeling/functional.py
deleted file mode 100644
index 0dd083d0da080ca089e81d9ae01e5f0954243f61..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/modeling/functional.py
+++ /dev/null
@@ -1,456 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""SetSUMBT functionals"""
-
-import torch
-import transformers
-from torch.autograd import Variable
-from torch.nn import (MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout,
-                      CosineSimilarity, CrossEntropyLoss, PairwiseDistance,
-                      Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss)
-from torch.nn.init import (xavier_normal_, constant_)
-
-from convlab.dst.setsumbt.loss.bayesian import BayesianMatchingLoss, BinaryBayesianMatchingLoss, dirichlet
-from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss
-from convlab.dst.setsumbt.loss.distillation import DistillationKL, BinaryDistillationKL
-from convlab.dst.setsumbt.loss.endd_loss import rkl_dirichlet_mediator_loss, logits_to_mutual_info
-
-
-# Default belief tracker model intialisation function
-def _initialise(self, config):
-    # Slot Utterance matching attention
-    self.slot_attention = MultiheadAttention(
-        config.hidden_size, config.slot_attention_heads)
-
-    # Latent context tracker
-    # Initial state prediction
-    if not config.rnn_zero_init and config.nbt_type in ['gru', 'lstm']:
-        self.belief_init = Sequential(Linear(config.hidden_size, config.nbt_hidden_size),
-                                      ReLU(), Dropout(config.dropout_rate))
-
-    # Recurrent context tracker setup
-    if config.nbt_type == 'gru':
-        self.nbt = GRU(input_size=config.hidden_size,
-                       hidden_size=config.nbt_hidden_size,
-                       num_layers=config.nbt_layers,
-                       dropout=0.0 if config.nbt_layers == 1 else config.dropout_rate,
-                       batch_first=True)
-        # Initialise Parameters
-        xavier_normal_(self.nbt.weight_ih_l0)
-        xavier_normal_(self.nbt.weight_hh_l0)
-        constant_(self.nbt.bias_ih_l0, 0.0)
-        constant_(self.nbt.bias_hh_l0, 0.0)
-    elif config.nbt_type == 'lstm':
-        self.nbt = LSTM(input_size=config.hidden_size,
-                        hidden_size=config.nbt_hidden_size,
-                        num_layers=config.nbt_layers,
-                        dropout=0.0 if config.nbt_layers == 1 else config.dropout_rate,
-                        batch_first=True)
-        # Initialise Parameters
-        xavier_normal_(self.nbt.weight_ih_l0)
-        xavier_normal_(self.nbt.weight_hh_l0)
-        constant_(self.nbt.bias_ih_l0, 0.0)
-        constant_(self.nbt.bias_hh_l0, 0.0)
-    else:
-        raise NameError('Not Implemented')
-
-    # Feature decoder and layer norm
-    self.intermediate = Linear(config.nbt_hidden_size, config.hidden_size)
-    self.layer_norm = LayerNorm(config.hidden_size)
-
-    # Dropout
-    self.dropout = Dropout(config.dropout_rate)
-
-    # Set pooler for set similarity model
-    if self.config.set_similarity:
-        # 1D convolutional set pooler
-        if self.config.set_pooling == 'cnn':
-            self.conv_pooler = Conv1d(
-                self.config.hidden_size, self.config.hidden_size, 3)
-        # Deep averaging network set pooler
-        elif self.config.set_pooling == 'dan':
-            self.avg_net = Sequential(Linear(self.config.hidden_size, 2 * self.config.hidden_size), GELU(),
-                                      Linear(2 * self.config.hidden_size, self.config.hidden_size))
-
-    # Model ontology placeholders
-    self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False)
-    self.slot_ids = dict()
-    self.requestable_slot_ids = dict()
-    self.informable_slot_ids = dict()
-    self.domain_ids = dict()
-
-    # Matching network similarity measure
-    if config.distance_measure == 'cosine':
-        self.distance = CosineSimilarity(dim=-1, eps=1e-8)
-    elif config.distance_measure == 'euclidean':
-        self.distance = PairwiseDistance(p=2.0, eps=1e-06, keepdim=False)
-    else:
-        raise NameError('NotImplemented')
-
-    # Belief state loss function
-    if config.loss_function == 'crossentropy':
-        self.loss = CrossEntropyLoss(ignore_index=-1)
-    elif config.loss_function == 'bayesianmatching':
-        self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-    elif config.loss_function == 'labelsmoothing':
-        self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
-    elif config.loss_function == 'distillation':
-        self.loss = DistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
-        self.temp = 1.0
-    elif config.loss_function == 'distribution_distillation':
-        self.loss = rkl_dirichlet_mediator_loss
-        self.temp = 1.0
-    else:
-        raise NameError('NotImplemented')
-
-    # Intent and domain prediction heads
-    if config.predict_actions:
-        self.request_gate = Linear(config.hidden_size, 1)
-        self.goodbye_gate = Linear(config.hidden_size, 3)
-        self.domain_gate = Linear(config.hidden_size, 1)
-
-        # Intent and domain loss function
-        self.request_weight = float(self.config.user_request_loss_weight)
-        self.goodbye_weight = float(self.config.user_general_act_loss_weight)
-        self.domain_weight = float(self.config.active_domain_loss_weight)
-        if config.loss_function == 'crossentropy':
-            self.request_loss = BCEWithLogitsLoss()
-            self.goodbye_loss = CrossEntropyLoss(ignore_index=-1)
-            self.domain_loss = BCEWithLogitsLoss()
-        elif config.loss_function == 'labelsmoothing':
-            self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
-            self.goodbye_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
-            self.domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
-        elif config.loss_function == 'bayesianmatching':
-            self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-            self.goodbye_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-            self.domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
-        elif config.loss_function == 'distillation':
-            self.request_loss = BinaryDistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
-            self.goodbye_loss = DistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
-            self.domain_loss = BinaryDistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing)
-
-
-# Default belief tracker forward pass.
-def _nbt_forward(self, turn_embeddings,
-                 turn_pooled_representation,
-                 attention_mask,
-                 batch_size,
-                 dialogue_size,
-                 turn_size,
-                 hidden_state,
-                 inform_labels,
-                 request_labels,
-                 domain_labels,
-                 goodbye_labels,
-                 calculate_inform_mutual_info):
-    hidden_size = turn_embeddings.size(-1)
-    # Initialise loss
-    loss = 0.0
-
-    # Goodbye predictions
-    goodbye_probs = None
-    if self.config.predict_actions:
-        # General action prediction
-        goodbye_scores = self.goodbye_gate(
-            turn_pooled_representation.reshape(batch_size * dialogue_size, hidden_size))
-
-        # Compute loss for general action predictions (weighted loss)
-        if goodbye_labels is not None:
-            if self.config.loss_function == 'distillation':
-                goodbye_labels = goodbye_labels.reshape(-1, goodbye_labels.size(-1))
-                loss += self.goodbye_loss(goodbye_scores, goodbye_labels, self.temp) * self.goodbye_weight
-            elif self.config.loss_function == 'distribution_distillation':
-                goodbye_labels = goodbye_labels.reshape(-1, goodbye_labels.size(-2), goodbye_labels.size(-1))
-                loss += self.loss(goodbye_scores, goodbye_labels, 1.0, 1.0)[0] * self.goodbye_weight
-            else:
-                goodbye_labels = goodbye_labels.reshape(-1)
-                loss += self.goodbye_loss(goodbye_scores, goodbye_labels) * self.request_weight
-
-        # Compute general action probabilities
-        if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'distillation', 'distribution_distillation']:
-            goodbye_probs = torch.softmax(goodbye_scores, -1).reshape(batch_size, dialogue_size, -1)
-        elif self.config.loss_function in ['bayesianmatching']:
-            goodbye_probs = dirichlet(goodbye_scores.reshape(batch_size, dialogue_size, -1))
-
-    # Slot utterance matching
-    num_slots = self.slot_embeddings.size(0)
-    slot_embeddings = self.slot_embeddings.reshape(-1, hidden_size)
-    slot_embeddings = slot_embeddings.unsqueeze(1).repeat((1, batch_size * dialogue_size, 1)).to(turn_embeddings.device)
-
-    if self.config.set_similarity:
-        # Slot mask shape [num_slots * slot_len, batch_size * dialogue_size, 768]
-        slot_mask = (slot_embeddings != 0.0).float()
-
-    # Turn embeddings shape [turn_size, batch_size * dialogue_size, 768]
-    turn_embeddings = turn_embeddings.transpose(0, 1)
-    # Compute key padding mask
-    key_padding_mask = (attention_mask[:, :, 0] == 0.0)
-    key_padding_mask[key_padding_mask[:, 0] == True, :] = False
-    # Multi head attention of slot over tokens
-    hidden, _ = self.slot_attention(query=slot_embeddings,
-                                    key=turn_embeddings,
-                                    value=turn_embeddings,
-                                    key_padding_mask=key_padding_mask)  # [num_slots, batch_size * dialogue_size, 768]
-
-    # Set embeddings for all masked tokens to 0
-    attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1))
-    hidden = hidden * attention_mask
-    if self.config.set_similarity:
-        hidden = hidden * slot_mask
-    # Hidden layer shape [num_dials, num_slots, num_turns, 768]
-    hidden = hidden.transpose(0, 1).reshape(batch_size, dialogue_size, slot_embeddings.size(0), -1).transpose(1, 2)
-
-    # Latent context tracking
-    # [batch_size * num_slots, dialogue_size, 768]
-    hidden = hidden.reshape(batch_size * slot_embeddings.size(0), dialogue_size, -1)
-
-    if self.config.nbt_type == 'gru':
-        self.nbt.flatten_parameters()
-        if hidden_state is None:
-            if self.config.rnn_zero_init:
-                context = torch.zeros(self.config.nbt_layers, batch_size * slot_embeddings.size(0),
-                                      self.config.nbt_hidden_size)
-                context = context.to(turn_embeddings.device)
-            else:
-                context = self.belief_init(hidden[:, 0, :]).unsqueeze(0).repeat((self.config.nbt_layers, 1, 1))
-        else:
-            context = hidden_state.to(hidden.device)
-
-        # [batch_size, dialogue_size, nbt_hidden_size]
-        belief_embedding, context = self.nbt(hidden, context)
-    elif self.config.nbt_type == 'lstm':
-        self.nbt.flatten_parameters()
-        if self.config.rnn_zero_init:
-            context = (torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size),
-                       torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size))
-            context = (context[0].to(turn_embeddings.device),
-                       context[1].to(turn_embeddings.device))
-        else:
-            context = (self.belief_init(hidden[:, 0, :]).unsqueeze(0).repeat((self.config.nbt_layers, 1, 1)),
-                       torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size))
-            context = (context[0], context[1].to(turn_embeddings.device))
-
-        # [batch_size, dialogue_size, nbt_hidden_size]
-        belief_embedding, context = self.nbt(hidden, context)
-
-    # Decode features
-    belief_embedding = belief_embedding.reshape(batch_size, slot_embeddings.size(0), dialogue_size, -1).transpose(1, 2)
-    if self.config.set_similarity:
-        belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1,
-                                                    self.config.nbt_hidden_size)
-    # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768]
-    # Normalisation and regularisation
-    belief_embedding = self.layer_norm(self.intermediate(belief_embedding))
-    belief_embedding = self.dropout(belief_embedding)
-
-    # Pooling of the set of latent context representation
-    if self.config.set_similarity:
-        slot_mask = slot_mask.transpose(0, 1).reshape(batch_size, dialogue_size, num_slots, -1, hidden_size)
-        belief_embedding = belief_embedding * slot_mask
-
-        # Apply pooler to latent context sequence
-        if self.config.set_pooling == 'mean':
-            belief_embedding = belief_embedding.sum(-2) / slot_mask.sum(-2)
-            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1)
-        elif self.config.set_pooling == 'cnn':
-            belief_embedding = belief_embedding.reshape(-1, slot_mask.size(-2), hidden_size).transpose(1, 2)
-            belief_embedding = self.conv_pooler(belief_embedding)
-            # Mean pooling after CNN
-            belief_embedding = belief_embedding.mean(-1).reshape(batch_size, dialogue_size, num_slots, -1)
-        elif self.config.set_pooling == 'dan':
-            # sqrt N reduction
-            belief_embedding = belief_embedding.sum(-2) / torch.sqrt(torch.tensor(slot_mask.sum(-2)))
-            # Deep averaging feature extractor
-            belief_embedding = self.avg_net(belief_embedding)
-            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1)
-
-    # Perform classification
-    if self.config.predict_actions:
-        # User request prediction
-        request_probs = dict()
-        for slot, slot_id in self.requestable_slot_ids.items():
-            request_scores = self.request_gate(belief_embedding[:, :, slot_id, :])
-
-            # Store output probabilities
-            request_scores = request_scores.reshape(batch_size, dialogue_size)
-            mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size)
-            batches, dialogues = torch.where(mask == 0.0)
-            # Set request scores to 0.0 for padded turns
-            request_scores[batches, dialogues] = 0.0
-            if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'bayesianmatching',
-                                             'distillation', 'distribution_distillation']:
-                request_probs[slot] = torch.sigmoid(request_scores)
-
-            if request_labels is not None:
-                # Compute request gate loss
-                request_scores = request_scores.reshape(-1)
-                if self.config.loss_function == 'distillation':
-                    loss += self.request_loss(request_scores, request_labels[slot].reshape(-1),
-                                              self.temp) * self.request_weight
-                elif self.config.loss_function == 'distribution_distillation':
-                    scores, labs = convert_probs_to_logits(request_scores, request_labels[slot])
-                    loss += self.loss(scores, labs, 1.0, 1.0)[0] * self.request_weight
-                else:
-                    labs = request_labels[slot].reshape(-1)
-                    request_scores = request_scores[labs != -1]
-                    labs = labs[labs != -1].float()
-                    loss += self.request_loss(request_scores, labs) * self.request_weight
-
-        # Active domain prediction
-        domain_probs = dict()
-        for domain, slot_ids in self.domain_ids.items():
-            belief = belief_embedding[:, :, slot_ids, :]
-            if len(slot_ids) > 1:
-                # SqrtN reduction across all slots within a domain
-                belief = belief.sum(2) / ((belief != 0.0).float().sum(2) ** 0.5)
-            domain_scores = self.domain_gate(belief)
-
-            # Store output probabilities
-            domain_scores = domain_scores.reshape(batch_size, dialogue_size)
-            mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size)
-            batches, dialogues = torch.where(mask == 0.0)
-            domain_scores[batches, dialogues] = 0.0
-            if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'bayesianmatching', 'distillation',
-                                             'distribution_distillation']:
-                domain_probs[domain] = torch.sigmoid(domain_scores)
-
-            if domain_labels is not None:
-                # Compute domain prediction loss
-                domain_scores = domain_scores.reshape(-1)
-                if self.config.loss_function == 'distillation':
-                    loss += self.domain_loss(domain_scores, domain_labels[domain].reshape(-1),
-                                             self.temp) * self.domain_weight
-                elif self.config.loss_function == 'distribution_distillation':
-                    scores, labs = convert_probs_to_logits(domain_scores, domain_labels[domain])
-                    loss += self.loss(scores, labs, 1.0, 1.0)[0] * self.request_weight
-                else:
-                    labs = domain_labels[domain].reshape(-1)
-                    domain_scores = domain_scores[labs != -1]
-                    labs = labs[labs != -1].float()
-                    loss += self.domain_loss(domain_scores, labs) * self.domain_weight
-    else:
-        request_probs, domain_probs = None, None
-
-    # Informable slot predictions
-    inform_probs = dict()
-    out_dict = dict()
-    mutual_info = dict()
-    stats = dict()
-    for slot, slot_id in self.informable_slot_ids.items():
-        # Get slot belief embedding and value candidates
-        candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device)
-        belief = belief_embedding[:, :, slot_id, :]
-        slot_size = candidate_embeddings.size(0)
-
-        # Use similaroty matching to produce belief state
-        if self.config.distance_measure in ['cosine', 'euclidean']:
-            belief = belief.unsqueeze(2).repeat((1, 1, slot_size, 1))
-            belief = belief.reshape(-1, self.config.hidden_size)
-
-            # Pooling of set of value candidate description representation
-            if self.config.set_similarity and self.config.set_pooling == 'mean':
-                candidate_mask = (candidate_embeddings != 0.0).float()
-                candidate_embeddings = candidate_embeddings.sum(1) / candidate_mask.sum(1)
-            elif self.config.set_similarity and self.config.set_pooling == 'cnn':
-                candidate_embeddings = candidate_embeddings.transpose(1, 2)
-                candidate_embeddings = self.conv_pooler(candidate_embeddings).mean(-1)
-            elif self.config.set_similarity and self.config.set_pooling == 'dan':
-                candidate_mask = (candidate_embeddings != 0.0).float()
-                candidate_embeddings = candidate_embeddings.sum(1) / torch.sqrt(torch.tensor(candidate_mask.sum(1)))
-                candidate_embeddings = self.avg_net(candidate_embeddings)
-
-            candidate_embeddings = candidate_embeddings.unsqueeze(0).unsqueeze(0).repeat((batch_size,
-                                                                                          dialogue_size, 1, 1))
-            candidate_embeddings = candidate_embeddings.reshape(-1, self.config.hidden_size)
-
-        # Score value candidates
-        if self.config.distance_measure == 'cosine':
-            scores = self.distance(belief, candidate_embeddings)
-            # *27 here rescales the cosine similarity for better learning
-            scores = scores.reshape(batch_size * dialogue_size, -1) * 27.0
-        elif self.config.distance_measure == 'euclidean':
-            scores = -1.0 * self.distance(belief, candidate_embeddings)
-            scores = scores.reshape(batch_size * dialogue_size, -1)
-
-        # Calculate belief state
-        if self.config.loss_function in ['crossentropy', 'inhibitedce',
-                                         'labelsmoothing', 'distillation', 'distribution_distillation']:
-            probs_ = torch.softmax(scores.reshape(batch_size, dialogue_size, -1), -1)
-        elif self.config.loss_function in ['bayesianmatching']:
-            probs_ = dirichlet(scores.reshape(batch_size, dialogue_size, -1))
-
-        # Compute knowledge uncertainty in the beleif states
-        if calculate_inform_mutual_info and self.config.loss_function == 'distribution_distillation':
-            mutual_info[slot] = logits_to_mutual_info(scores).reshape(batch_size, dialogue_size)
-
-        # Set padded turn probabilities to zero
-        mask = attention_mask[self.slot_ids[slot],:, 0].reshape(batch_size, dialogue_size)
-        batches, dialogues = torch.where(mask == 0.0)
-        probs_[batches, dialogues, :] = 0.0
-        inform_probs[slot] = probs_
-
-        # Calculate belief state loss
-        if inform_labels is not None and slot in inform_labels:
-            if self.config.loss_function == 'bayesianmatching':
-                prior = torch.ones(scores.size(-1)).float().to(scores.device)
-                prior = prior * self.config.prior_constant
-                prior = prior.unsqueeze(0).repeat((scores.size(0), 1))
-
-                loss += self.loss(scores, inform_labels[slot].reshape(-1), prior=prior)
-            elif self.config.loss_function == 'distillation':
-                labels = inform_labels[slot]
-                labels = labels.reshape(-1, labels.size(-1))
-                loss += self.loss(scores, labels, self.temp)
-            elif self.config.loss_function == 'distribution_distillation':
-                labels = inform_labels[slot]
-                labels = labels.reshape(-1, labels.size(-2), labels.size(-1))
-                loss_, model_stats, ensemble_stats = self.loss(scores, labels, 1.0, 1.0)
-                loss += loss_
-
-                # Calculate stats regarding model precisions
-                precision = model_stats['precision']
-                ensemble_precision = ensemble_stats['precision']
-                stats[slot] = {'model_precision_min': precision.min(),
-                               'model_precision_max': precision.max(),
-                               'model_precision_mean': precision.mean(),
-                               'ensemble_precision_min': ensemble_precision.min(),
-                               'ensemble_precision_max': ensemble_precision.max(),
-                               'ensemble_precision_mean': ensemble_precision.mean()}
-            else:
-                loss += self.loss(scores, inform_labels[slot].reshape(-1))
-
-    # Return model outputs
-    out = inform_probs, request_probs, domain_probs, goodbye_probs, context
-    if inform_labels is not None or request_labels is not None or domain_labels is not None or goodbye_labels is not None:
-        out = (loss,) + out + (stats,)
-    if calculate_inform_mutual_info:
-        out = out + (mutual_info,)
-    return out
-
-
-# Convert binary scores and labels to 2 class classification problem for distribution distillation
-def convert_probs_to_logits(scores, labels):
-    # Convert single target probability p to distribution [1-p, p]
-    labels = labels.reshape(-1, labels.size(-1), 1)
-    labels = torch.cat([1 - labels, labels], -1)
-
-    # Convert input scores into predictive distribution [1-z, z]
-    scores = torch.sigmoid(scores).unsqueeze(1)
-    scores = torch.cat((1 - scores, scores), 1)
-    scores = -1.0 * torch.log((1 / (scores + 1e-8)) - 1)  # Inverse sigmoid
-
-    return scores, labels
diff --git a/convlab/dst/setsumbt/modeling/roberta_nbt.py b/convlab/dst/setsumbt/modeling/roberta_nbt.py
index 36920c5ca550a3295f31aaca53c2ebed8c22be37..f72d17fafa50553434b6d4dcd20b8e53d143892f 100644
--- a/convlab/dst/setsumbt/modeling/roberta_nbt.py
+++ b/convlab/dst/setsumbt/modeling/roberta_nbt.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,16 +16,19 @@
 """RoBERTa SetSUMBT"""
 
 import torch
-import transformers
-from torch.autograd import Variable
 from transformers import RobertaModel, RobertaPreTrainedModel
 
-from convlab.dst.setsumbt.modeling.functional import _initialise, _nbt_forward
+from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead
 
 
 class RobertaSetSUMBT(RobertaPreTrainedModel):
+    """Roberta based SetSUMBT model"""
 
     def __init__(self, config):
+        """
+        Args:
+            config (configuration): Model configuration class
+        """
         super(RobertaSetSUMBT, self).__init__(config)
         self.config = config
 
@@ -35,60 +38,37 @@ class RobertaSetSUMBT(RobertaPreTrainedModel):
             for p in self.roberta.parameters():
                 p.requires_grad = False
 
-        _initialise(self, config)
+        self.setsumbt = SetSUMBTHead(config)
+        self.add_slot_candidates = self.setsumbt.add_slot_candidates
+        self.add_value_candidates = self.setsumbt.add_value_candidates
     
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                hidden_state: torch.Tensor = None,
+                state_labels: torch.Tensor = None,
+                request_labels: torch.Tensor = None,
+                active_domain_labels: torch.Tensor = None,
+                general_act_labels: torch.Tensor = None,
+                get_turn_pooled_representation: bool = False,
+                calculate_state_mutual_info: bool = False):
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator
+            hidden_state: Latent internal dialogue belief state
+            state_labels: Dialogue state labels
+            request_labels: User request action labels
+            active_domain_labels: Current active domain labels
+            general_act_labels: General user action labels
+            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
+            calculate_state_mutual_info: Return mutual information in the dialogue state
 
-    # Add new slot candidates to the model
-    def add_slot_candidates(self, slot_candidates):
-        """slot_candidates is a list of tuples for each slot.
-        - The tuples contains the slot embedding, informable value embeddings and a request indicator.
-        - If the informable value embeddings is None the slot is not informable
-        - If the request indicator is false the slot is not requestable"""
-        if self.slot_embeddings.size(0) != 0:
-            embeddings = self.slot_embeddings.detach()
-        else:
-            embeddings = torch.zeros(0)
-
-        for slot in slot_candidates:
-            if slot in self.slot_ids:
-                index = self.slot_ids[slot]
-                embeddings[index, :] = slot_candidates[slot][0]
-            else:
-                index = embeddings.size(0)
-                emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device)
-                embeddings = torch.cat((embeddings, emb), 0)
-                self.slot_ids[slot] = index
-                setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False))
-            # Add slot to relevant requestable and informable slot lists
-            if slot_candidates[slot][2]:
-                self.requestable_slot_ids[slot] = index
-            if slot_candidates[slot][1] is not None:
-                self.informable_slot_ids[slot] = index
-            
-            domain = slot.split('-', 1)[0]
-            if domain not in self.domain_ids:
-                self.domain_ids[domain] = []
-            self.domain_ids[domain].append(index)
-            self.domain_ids[domain] = list(set(self.domain_ids[domain]))
-        
-        self.slot_embeddings = Variable(embeddings, requires_grad=False)
-
-
-    # Add new value candidates to the model
-    def add_value_candidates(self, slot, value_candidates, replace=False):
-        embeddings = getattr(self, slot + '_value_embeddings')
-
-        if embeddings.size(0) == 0 or replace:
-            embeddings = value_candidates
-        else:
-            embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0)
-        
-        setattr(self, slot + '_value_embeddings', embeddings)
-        
-    
-    def forward(self, input_ids, attention_mask, token_type_ids=None, hidden_state=None, inform_labels=None,
-                request_labels=None, domain_labels=None, goodbye_labels=None,
-                get_turn_pooled_representation=False, calculate_inform_mutual_info=False):
+        Returns:
+            out: Tuple containing loss, predictive distributions, model statistics and state mutual information
+        """
         if token_type_ids is not None:
             token_type_ids = None
 
@@ -106,9 +86,10 @@ class RobertaSetSUMBT(RobertaPreTrainedModel):
         turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
         
         if get_turn_pooled_representation:
-            return _nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, dialogue_size,
-                                turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels,
-                                calculate_inform_mutual_info) + (roberta_output.pooler_output,)
-        return _nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, dialogue_size,
-                            turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels,
-                            calculate_inform_mutual_info)
+            return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask,
+                                 batch_size, dialogue_size, hidden_state, state_labels,
+                                 request_labels, active_domain_labels, general_act_labels,
+                                 calculate_state_mutual_info) + (roberta_output.pooler_output,)
+        return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size,
+                             dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
+                             general_act_labels, calculate_state_mutual_info)
diff --git a/convlab/dst/setsumbt/modeling/setsumbt.py b/convlab/dst/setsumbt/modeling/setsumbt.py
new file mode 100644
index 0000000000000000000000000000000000000000..0249649f0840d66b0cec8a65c91aded906f62f85
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/setsumbt.py
@@ -0,0 +1,564 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SetSUMBT Prediction Head"""
+
+import torch
+from torch.autograd import Variable
+from torch.nn import (Module, MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout,
+                      CosineSimilarity, CrossEntropyLoss, PairwiseDistance,
+                      Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss)
+from torch.nn.init import (xavier_normal_, constant_)
+
+from convlab.dst.setsumbt.loss import (BayesianMatchingLoss, BinaryBayesianMatchingLoss,
+                                       KLDistillationLoss, BinaryKLDistillationLoss,
+                                       LabelSmoothingLoss, BinaryLabelSmoothingLoss,
+                                       RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss)
+
+
+class SlotUtteranceMatching(Module):
+    """Slot Utterance matching attention based information extractor"""
+
+    def __init__(self, hidden_size: int = 768, attention_heads: int = 12):
+        """
+        Args:
+            hidden_size (int): Dimension of token embeddings
+            attention_heads (int): Number of attention heads to use in attention module
+        """
+        super(SlotUtteranceMatching, self).__init__()
+
+        self.attention = MultiheadAttention(hidden_size, attention_heads)
+
+    def forward(self,
+                turn_embeddings: torch.Tensor,
+                attention_mask: torch.Tensor,
+                slot_embeddings: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            turn_embeddings: Embeddings for each token in each turn [n_turns, turn_length, hidden_size]
+            attention_mask: Padding mask for each turn [n_turns, turn_length, hidden_size]
+            slot_embeddings: Slot description token embeddings used as attention queries [n_slot_tokens, n_turns, hidden_size]
+
+        Returns:
+            hidden: Turn information extracted for each slot query, zeroed for turns whose first token is padded
+        """
+        turn_embeddings = turn_embeddings.transpose(0, 1)
+        # Mask padded tokens; turns whose first token is padded are fully unmasked (presumably to keep attention finite)
+        key_padding_mask = (attention_mask[:, :, 0] == 0.0)
+        key_padding_mask[key_padding_mask[:, 0], :] = False
+
+        hidden, _ = self.attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings,
+                                   key_padding_mask=key_padding_mask)
+        # Reuse the first token's mask row of every turn to zero the outputs of padded turns
+        attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1))
+        hidden = hidden * attention_mask
+
+        return hidden
+
+
+class RecurrentNeuralBeliefTracker(Module):
+    """Recurrent latent neural belief tracking module"""
+
+    def __init__(self,
+                 nbt_type: str = 'gru',
+                 rnn_zero_init: bool = False,
+                 input_size: int = 768,
+                 hidden_size: int = 300,
+                 hidden_layers: int = 1,
+                 dropout_rate: float = 0.3):
+        """
+        Args:
+            nbt_type: Type of recurrent neural network (gru/lstm)
+            rnn_zero_init: If true a learned projection of the first input initialises the RNN, else zeros (NOTE(review): name suggests the opposite)
+            input_size: Embedding size of the inputs
+            hidden_size: Hidden size of the RNN
+            hidden_layers: Number of RNN Layers
+            dropout_rate: Dropout rate
+        """
+        super(RecurrentNeuralBeliefTracker, self).__init__()
+
+        if rnn_zero_init:
+            self.belief_init = Sequential(Linear(input_size, hidden_size), ReLU(), Dropout(dropout_rate))
+        else:
+            self.belief_init = None
+
+        self.nbt_type = nbt_type
+        self.hidden_layers = hidden_layers
+        self.hidden_size = hidden_size
+        if nbt_type == 'gru':
+            self.nbt = GRU(input_size=input_size,
+                           hidden_size=hidden_size,
+                           num_layers=hidden_layers,
+                           dropout=0.0 if hidden_layers == 1 else dropout_rate,
+                           batch_first=True)
+        elif nbt_type == 'lstm':
+            self.nbt = LSTM(input_size=input_size,
+                            hidden_size=hidden_size,
+                            num_layers=hidden_layers,
+                            dropout=0.0 if hidden_layers == 1 else dropout_rate,
+                            batch_first=True)
+        else:
+            raise NameError('Not Implemented')
+
+        # Initialise the first RNN layer only: xavier normal weights and zero biases
+        xavier_normal_(self.nbt.weight_ih_l0)
+        xavier_normal_(self.nbt.weight_hh_l0)
+        constant_(self.nbt.bias_ih_l0, 0.0)
+        constant_(self.nbt.bias_hh_l0, 0.0)
+
+        # Intermediate feature mapping (RNN hidden size back to input size) and layer normalisation
+        self.intermediate = Linear(hidden_size, input_size)
+        self.layer_norm = LayerNorm(input_size)
+        self.dropout = Dropout(dropout_rate)
+
+    def forward(self, inputs: torch.Tensor, hidden_state: torch.Tensor = None) -> torch.Tensor:
+        """
+        Args:
+            inputs: Latent turn level information, batch first
+            hidden_state: Latent internal belief state, if None an initial state is created
+
+        Returns:
+            belief_embedding: Belief state embeddings (returned as a pair together with context)
+            context: Latent internal belief state after the last step (a tuple of two tensors for the lstm)
+        """
+        self.nbt.flatten_parameters()
+        if hidden_state is None:
+            if self.belief_init is None:
+                context = torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device)
+            else:
+                context = self.belief_init(inputs[:, 0, :]).unsqueeze(0).repeat((self.hidden_layers, 1, 1))
+            if self.nbt_type == "lstm":
+                context = (context, torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device))
+        else:
+            context = hidden_state.to(inputs.device)
+
+        # [batch_size, dialogue_size, nbt_hidden_size]
+        belief_embedding, context = self.nbt(inputs, context)
+
+        # Normalisation and regularisation
+        belief_embedding = self.layer_norm(self.intermediate(belief_embedding))
+        belief_embedding = self.dropout(belief_embedding)
+
+        return belief_embedding, context
+
+
+class SetPooler(Module):
+    """Token set pooler reducing a padded set of token embeddings to a single embedding"""
+
+    def __init__(self, pooling_strategy: str = 'cnn', hidden_size: int = 768):
+        """
+        Args:
+            pooling_strategy: Type of set pooler (cnn/dan/mean)
+            hidden_size: Token embedding size
+        """
+        super(SetPooler, self).__init__()
+
+        self.pooling_strategy = pooling_strategy
+        if pooling_strategy == 'cnn':
+            self.cnn_filter_size = 3
+            self.pooler = Conv1d(hidden_size, hidden_size, self.cnn_filter_size)
+        elif pooling_strategy == 'dan':
+            self.pooler = Sequential(Linear(hidden_size, 2 * hidden_size), GELU(), Linear(2 * hidden_size, hidden_size))
+
+    def forward(self, inputs, attention_mask):
+        """
+        Args:
+            inputs: Token set embeddings, pooled over dimension 1
+            attention_mask: Padding mask for the set of tokens, same layout as inputs
+
+        Returns:
+            hidden: Pooled set embeddings with the token dimension removed
+        """
+        if self.pooling_strategy == "mean":
+            hidden = inputs.sum(1) / attention_mask.sum(1)
+        elif self.pooling_strategy == "cnn":
+            hidden = self.pooler(inputs.transpose(1, 2)).mean(-1)
+        elif self.pooling_strategy == 'dan':
+            hidden = inputs.sum(1) / torch.sqrt(attention_mask.sum(1))
+            hidden = self.pooler(hidden)
+
+        return hidden
+
+
+class SetSUMBTHead(Module):
+    """SetSUMBT Prediction Head for Language Models"""
+
+    def __init__(self, config):
+        """
+        Args:
+            config (configuration): Model configuration class
+        """
+        super(SetSUMBTHead, self).__init__()
+        self.config = config
+        # Slot Utterance matching attention
+        self.slot_utterance_matching = SlotUtteranceMatching(config.hidden_size, config.slot_attention_heads)
+
+        # Latent context tracker
+        self.nbt = RecurrentNeuralBeliefTracker(config.nbt_type, config.rnn_zero_init, config.hidden_size,
+                                                config.nbt_hidden_size, config.nbt_layers, config.dropout_rate)
+
+        # Set pooler for set similarity model
+        if self.config.set_similarity:
+            self.set_pooler = SetPooler(config.set_pooling, config.hidden_size)
+
+        # Model ontology placeholders
+        self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False)
+        self.slot_ids = dict()
+        self.requestable_slot_ids = dict()
+        self.informable_slot_ids = dict()
+        self.domain_ids = dict()
+
+        # Matching network similarity measure
+        if config.distance_measure == 'cosine':
+            self.distance = CosineSimilarity(dim=-1, eps=1e-8)
+        elif config.distance_measure == 'euclidean':
+            self.distance = PairwiseDistance(p=2.0, eps=1e-6, keepdim=False)
+        else:
+            raise NameError('NotImplemented')
+
+        # User goal prediction loss function
+        if config.loss_function == 'crossentropy':
+            self.loss = CrossEntropyLoss(ignore_index=-1)
+        elif config.loss_function == 'bayesianmatching':
+            self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+        elif config.loss_function == 'labelsmoothing':
+            self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
+        elif config.loss_function == 'distillation':
+            self.loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
+            self.temp = 1.0
+        elif config.loss_function == 'distribution_distillation':
+            self.loss = RKLDirichletMediatorLoss(ignore_index=-1)
+        else:
+            raise NameError('NotImplemented')
+
+        # Intent and domain prediction heads
+        if config.predict_actions:
+            self.request_gate = Linear(config.hidden_size, 1)
+            self.general_act_gate = Linear(config.hidden_size, 3)
+            self.active_domain_gate = Linear(config.hidden_size, 1)
+
+            # Intent and domain loss function
+            self.request_weight = float(self.config.user_request_loss_weight)
+            self.general_act_weight = float(self.config.user_general_act_loss_weight)
+            self.active_domain_weight = float(self.config.active_domain_loss_weight)
+            if config.loss_function == 'crossentropy':
+                self.request_loss = BCEWithLogitsLoss()
+                self.general_act_loss = CrossEntropyLoss(ignore_index=-1)
+                self.active_domain_loss = BCEWithLogitsLoss()
+            elif config.loss_function == 'labelsmoothing':
+                self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
+                self.general_act_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
+                self.active_domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
+            elif config.loss_function == 'bayesianmatching':
+                self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+                self.general_act_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+                self.active_domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
+            elif config.loss_function == 'distillation':
+                self.request_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
+                self.general_act_loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
+                self.active_domain_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
+            elif config.loss_function == 'distribution_distillation':
+                self.request_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1)
+                self.general_act_loss = RKLDirichletMediatorLoss(ignore_index=-1)
+                self.active_domain_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1)
+
+    def add_slot_candidates(self, slot_candidates: dict):
+        """
+        Add slots to the model ontology. Each slot maps to a tuple of slot embedding, informable value embeddings
+        and a request indicator; if the informable value embeddings is None the slot is not informable and if
+        the request indicator is false the slot is not requestable.
+
+        Args:
+            slot_candidates: Mapping of slot name to (slot embedding, informable value embeddings, request indicator)
+        """
+        if self.slot_embeddings.size(0) != 0:
+            embeddings = self.slot_embeddings.detach()
+        else:
+            embeddings = torch.zeros(0)
+
+        for slot in slot_candidates:
+            if slot in self.slot_ids:
+                index = self.slot_ids[slot]
+                embeddings[index, :] = slot_candidates[slot][0]
+            else:
+                index = embeddings.size(0)
+                emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device)
+                embeddings = torch.cat((embeddings, emb), 0)
+                self.slot_ids[slot] = index
+                setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False))
+            # Add slot to relevant requestable and informable slot lists
+            if slot_candidates[slot][2]:
+                self.requestable_slot_ids[slot] = index
+            if slot_candidates[slot][1] is not None:
+                self.informable_slot_ids[slot] = index
+
+            domain = slot.split('-', 1)[0]
+            if domain not in self.domain_ids:
+                self.domain_ids[domain] = []
+            self.domain_ids[domain].append(index)
+            self.domain_ids[domain] = list(set(self.domain_ids[domain]))
+
+        self.slot_embeddings = Variable(embeddings, requires_grad=False)
+
+    def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
+        """
+        Add value candidates for a slot
+
+        Args:
+            slot: Slot name
+            value_candidates: Value candidate embeddings
+            replace: If true existing value candidates are replaced
+        """
+        embeddings = getattr(self, slot + '_value_embeddings')
+
+        if embeddings.size(0) == 0 or replace:
+            embeddings = value_candidates
+        else:
+            embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0)
+
+        setattr(self, slot + '_value_embeddings', embeddings)
+
+    def forward(self,
+                turn_embeddings: torch.Tensor,
+                turn_pooled_representation: torch.Tensor,
+                attention_mask: torch.Tensor,
+                batch_size: int,
+                dialogue_size: int,
+                hidden_state: torch.Tensor = None,
+                state_labels: torch.Tensor = None,
+                request_labels: torch.Tensor = None,
+                active_domain_labels: torch.Tensor = None,
+                general_act_labels: torch.Tensor = None,
+                calculate_state_mutual_info: bool = False):
+        """
+        Args:
+            turn_embeddings: Token embeddings in the current turn
+            turn_pooled_representation: Pooled representation of the current dialogue turn
+            attention_mask: Padding mask for the current dialogue turn
+            batch_size: Number of dialogues in the batch
+            dialogue_size: Number of turns in each dialogue
+            hidden_state: Latent internal dialogue belief state
+            state_labels: Dialogue state labels
+            request_labels: User request action labels
+            active_domain_labels: Current active domain labels
+            general_act_labels: General user action labels
+            calculate_state_mutual_info: Return mutual information in the dialogue state
+
+        Returns:
+            out: Predictions and hidden state; loss and statistics are added when any labels are given, mutual info on request
+        """
+        hidden_size = turn_embeddings.size(-1)
+        # Initialise loss
+        loss = 0.0
+
+        # General Action predictions
+        general_act_probs = None
+        if self.config.predict_actions:
+            # General action prediction
+            general_act_logits = self.general_act_gate(turn_pooled_representation.reshape(batch_size * dialogue_size,
+                                                                                          hidden_size))
+
+            # Compute loss for general action predictions (weighted loss)
+            if general_act_labels is not None:
+                if self.config.loss_function == 'distillation':
+                    general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-1))
+                    loss += self.general_act_loss(general_act_logits, general_act_labels,
+                                                  self.temp) * self.general_act_weight
+                elif self.config.loss_function == 'distribution_distillation':
+                    general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-2),
+                                                                    general_act_labels.size(-1))
+                    loss += self.general_act_loss(general_act_logits, general_act_labels)[0] * self.general_act_weight
+                else:
+                    general_act_labels = general_act_labels.reshape(-1)
+                    loss += self.general_act_loss(general_act_logits, general_act_labels) * self.general_act_weight
+
+            # Compute general action probabilities
+            general_act_probs = torch.softmax(general_act_logits, -1).reshape(batch_size, dialogue_size, -1)
+
+        # Slot utterance matching
+        num_slots = self.slot_embeddings.size(0)
+        slot_embeddings = self.slot_embeddings.reshape(-1, hidden_size)
+        slot_embeddings = slot_embeddings.unsqueeze(1).repeat((1, batch_size * dialogue_size, 1))
+        slot_embeddings = slot_embeddings.to(turn_embeddings.device)
+
+        if self.config.set_similarity:
+            # Slot mask shape [num_slots * slot_len, batch_size * dialogue_size, 768]
+            slot_mask = (slot_embeddings != 0.0).float()
+
+        hidden = self.slot_utterance_matching(turn_embeddings, attention_mask, slot_embeddings)
+
+        if self.config.set_similarity:
+            hidden = hidden * slot_mask
+        # Hidden layer shape [num_dials, num_slots, num_turns, 768]
+        hidden = hidden.transpose(0, 1).reshape(batch_size, dialogue_size, slot_embeddings.size(0), -1).transpose(1, 2)
+
+        # Latent context tracking
+        # [batch_size * num_slots, dialogue_size, 768]
+        hidden = hidden.reshape(batch_size * slot_embeddings.size(0), dialogue_size, -1)
+        belief_embedding, hidden_state = self.nbt(hidden, hidden_state)
+
+        belief_embedding = belief_embedding.reshape(batch_size, slot_embeddings.size(0),
+                                                    dialogue_size, -1).transpose(1, 2)
+        if self.config.set_similarity:
+            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1,
+                                                        self.config.hidden_size)
+        # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768]
+
+        # Pooling of the set of latent context representation
+        if self.config.set_similarity:
+            slot_mask = slot_mask.transpose(0, 1).reshape(batch_size, dialogue_size, num_slots, -1, hidden_size)
+            belief_embedding = belief_embedding * slot_mask
+
+            belief_embedding = self.set_pooler(belief_embedding.reshape(-1, slot_mask.size(-2), hidden_size),
+                                               slot_mask.reshape(-1, slot_mask.size(-2), hidden_size))
+            belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1)
+
+        # Perform classification
+        # Get padded batch, dialogue idx pairs
+        batches, dialogues = torch.where(attention_mask[:, 0, 0].reshape(batch_size, dialogue_size) == 0.0)
+
+        if self.config.predict_actions:
+            # User request prediction
+            request_probs = dict()
+            for slot, slot_id in self.requestable_slot_ids.items():
+                request_logits = self.request_gate(belief_embedding[:, :, slot_id, :])
+
+                # Store output probabilities
+                request_logits = request_logits.reshape(batch_size, dialogue_size)
+                # Set request scores to 0.0 for padded turns
+                request_logits[batches, dialogues] = 0.0
+                request_probs[slot] = torch.sigmoid(request_logits)
+
+                if request_labels is not None:
+                    # Compute request gate loss
+                    request_logits = request_logits.reshape(-1)
+                    if self.config.loss_function == 'distillation':
+                        loss += self.request_loss(request_logits, request_labels[slot].reshape(-1),
+                                                  self.temp) * self.request_weight
+                    elif self.config.loss_function == 'distribution_distillation':
+                        loss += self.request_loss(request_logits, request_labels[slot])[0] * self.request_weight
+                    else:
+                        labs = request_labels[slot].reshape(-1)
+                        request_logits = request_logits[labs != -1]
+                        labs = labs[labs != -1].float()
+                        loss += self.request_loss(request_logits, labs) * self.request_weight
+
+            # Active domain prediction
+            active_domain_probs = dict()
+            for domain, slot_ids in self.domain_ids.items():
+                belief = belief_embedding[:, :, slot_ids, :]
+                if len(slot_ids) > 1:
+                    # SqrtN reduction across all slots within a domain
+                    belief = belief.sum(2) / ((belief != 0.0).float().sum(2) ** 0.5)
+                active_domain_logits = self.active_domain_gate(belief)
+
+                # Store output probabilities
+                active_domain_logits = active_domain_logits.reshape(batch_size, dialogue_size)
+                active_domain_logits[batches, dialogues] = 0.0
+                active_domain_probs[domain] = torch.sigmoid(active_domain_logits)
+
+                if active_domain_labels is not None and domain in active_domain_labels:
+                    # Compute domain prediction loss
+                    active_domain_logits = active_domain_logits.reshape(-1)
+                    if self.config.loss_function == 'distillation':
+                        loss += self.active_domain_loss(active_domain_logits, active_domain_labels[domain].reshape(-1),
+                                                        self.temp) * self.active_domain_weight
+                    elif self.config.loss_function == 'distribution_distillation':
+                        loss += self.active_domain_loss(active_domain_logits,
+                                                        active_domain_labels[domain])[0] * self.active_domain_weight
+                    else:
+                        labs = active_domain_labels[domain].reshape(-1)
+                        active_domain_logits = active_domain_logits[labs != -1]
+                        labs = labs[labs != -1].float()
+                        loss += self.active_domain_loss(active_domain_logits, labs) * self.active_domain_weight
+        else:
+            request_probs, active_domain_probs = None, None
+
+        # Dialogue state predictions
+        belief_state_probs = dict()
+        belief_state_mutual_info = dict()
+        belief_state_stats = dict()
+        for slot, slot_id in self.informable_slot_ids.items():
+            # Get slot belief embedding and value candidates
+            candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device)
+            belief = belief_embedding[:, :, slot_id, :]
+            slot_size = candidate_embeddings.size(0)
+
+            belief = belief.unsqueeze(2).repeat((1, 1, slot_size, 1))
+            belief = belief.reshape(-1, self.config.hidden_size)
+
+            if self.config.set_similarity:
+                candidate_embeddings = self.set_pooler(candidate_embeddings, (candidate_embeddings != 0.0).float())
+            candidate_embeddings = candidate_embeddings.unsqueeze(0).unsqueeze(0).repeat((batch_size,
+                                                                                          dialogue_size, 1, 1))
+            candidate_embeddings = candidate_embeddings.reshape(-1, self.config.hidden_size)
+
+            # Score value candidates
+            if self.config.distance_measure == 'cosine':
+                logits = self.distance(belief, candidate_embeddings)
+                # *27 here rescales the cosine similarity for better learning
+                logits = logits.reshape(batch_size * dialogue_size, -1) * 27.0
+            elif self.config.distance_measure == 'euclidean':
+                logits = -1.0 * self.distance(belief, candidate_embeddings)
+                logits = logits.reshape(batch_size * dialogue_size, -1)
+
+            # Calculate belief state
+            probs_ = torch.softmax(logits.reshape(batch_size, dialogue_size, -1), -1)
+
+            # Compute knowledge uncertainty in the belief states
+            if calculate_state_mutual_info and self.config.loss_function == 'distribution_distillation':
+                belief_state_mutual_info[slot] = self.loss.logits_to_mutual_info(logits).reshape(batch_size, dialogue_size)
+
+            # Set padded turn probabilities to zero
+            probs_[batches, dialogues, :] = 0.0
+            belief_state_probs[slot] = probs_
+
+            # Calculate belief state loss
+            if state_labels is not None and slot in state_labels:
+                if self.config.loss_function == 'bayesianmatching':
+                    prior = torch.ones(logits.size(-1)).float().to(logits.device)
+                    prior = prior * self.config.prior_constant
+                    prior = prior.unsqueeze(0).repeat((logits.size(0), 1))
+
+                    loss += self.loss(logits, state_labels[slot].reshape(-1), prior=prior)
+                elif self.config.loss_function == 'distillation':
+                    labels = state_labels[slot]
+                    labels = labels.reshape(-1, labels.size(-1))
+                    loss += self.loss(logits, labels, self.temp)
+                elif self.config.loss_function == 'distribution_distillation':
+                    labels = state_labels[slot]
+                    labels = labels.reshape(-1, labels.size(-2), labels.size(-1))
+                    loss_, model_stats, ensemble_stats = self.loss(logits, labels)
+                    loss += loss_
+
+                    # Calculate stats regarding model precisions
+                    precision = model_stats['precision']
+                    ensemble_precision = ensemble_stats['precision']
+                    belief_state_stats[slot] = {'model_precision_min': precision.min(),
+                                                'model_precision_max': precision.max(),
+                                                'model_precision_mean': precision.mean(),
+                                                'ensemble_precision_min': ensemble_precision.min(),
+                                                'ensemble_precision_max': ensemble_precision.max(),
+                                                'ensemble_precision_mean': ensemble_precision.mean()}
+                else:
+                    loss += self.loss(logits, state_labels[slot].reshape(-1))
+
+        # Return model outputs, the loss is included whenever any kind of label contributed to it
+        out = belief_state_probs, request_probs, active_domain_probs, general_act_probs, hidden_state
+        if any(lab is not None for lab in (state_labels, request_labels, active_domain_labels, general_act_labels)):
+            out = (loss,) + out + (belief_state_stats,)
+        if calculate_state_mutual_info:
+            out = out + (belief_state_mutual_info,)
+        return out
diff --git a/convlab/dst/setsumbt/modeling/temperature_scheduler.py b/convlab/dst/setsumbt/modeling/temperature_scheduler.py
index fab205befe3350c9beb9d81566c813cf00b55cf2..654e83c5d1ad9dc908213cca8967a84893395b04 100644
--- a/convlab/dst/setsumbt/modeling/temperature_scheduler.py
+++ b/convlab/dst/setsumbt/modeling/temperature_scheduler.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,50 +13,70 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Temperature Scheduler Class"""
-import torch
+"""Linear Temperature Scheduler Class"""
+
 
 # Temp scheduler class for ensemble distillation
-class TemperatureScheduler:
+class LinearTemperatureScheduler:
+    """
+    Temperature scheduler object used for distribution temperature scheduling in distillation
 
-    def __init__(self, total_steps, base_temp=2.5, cycle_len=0.1):
-        self.state = {}
+    Attributes:
+        state (dict): Internal state of scheduler
+    """
+    def __init__(self, total_steps: int, base_temp: float = 2.5, cycle_len: float = 0.1):
+        """
+        Set up the schedule: the temperature is held at base_temp, then lowered linearly, then fixed at 1.0.
+
+        Args:
+            total_steps (int): Total number of training steps
+            base_temp (float): Starting temperature
+            cycle_len (float): Fraction of total steps used for scheduling cycle
+        """
+        self.state = dict()
         self.state['total_steps'] = total_steps
         self.state['current_step'] = 0
         self.state['base_temp'] = base_temp
         self.state['current_temp'] = base_temp
         self.state['cycles'] = [int(total_steps * cycle_len / 2), int(total_steps * cycle_len)]
+        # Short runs can make both cycle bounds coincide; the rate is then never applied but must stay computable
+        self.state['rate'] = (base_temp - 1.0) / max(self.state['cycles'][1] - self.state['cycles'][0], 1)
     
     def step(self):
+        """
+        Update temperature based on the schedule
+        """
         self.state['current_step'] += 1
         assert self.state['current_step'] <= self.state['total_steps']
         if self.state['current_step'] > self.state['cycles'][0]:
             if self.state['current_step'] < self.state['cycles'][1]:
-                rate = (self.state['base_temp'] - 1.0) / (self.state['cycles'][1] - self.state['cycles'][0])
-                self.state['current_temp'] -= rate
+                self.state['current_temp'] -= self.state['rate']
             else:
                 self.state['current_temp'] = 1.0
     
     def temp(self):
+        """
+        Get current temperature
+
+        Returns:
+            temp (float): Current temperature for distribution scaling
+        """
         return float(self.state['current_temp'])
     
     def state_dict(self):
-        return self.state
-    
-    def load_state_dict(self, sd):
-        self.state = sd
+        """
+        Return scheduler state
 
-
-# if __name__ == "__main__":
-#     temp_scheduler = TemperatureScheduler(100)
-#     print(temp_scheduler.state_dict())
-
-#     temp = []
-#     for i in range(100):
-#         temp.append(temp_scheduler.temp())
-#         temp_scheduler.step()
+        Returns:
+            state (dict): Dictionary format state of the scheduler
+        """
+        return self.state
     
-#     temp_scheduler.load_state_dict(temp_scheduler.state_dict())
-#     print(temp_scheduler.state_dict())
+    def load_state_dict(self, state_dict: dict):
+        """
+        Load scheduler state from dictionary
 
-#     print(temp)
+        Args:
+            state_dict (dict): Dictionary format state of the scheduler
+        """
+        self.state = state_dict
diff --git a/convlab/dst/setsumbt/modeling/training.py b/convlab/dst/setsumbt/modeling/training.py
index 259c6e1da061ad6800f92e9237324181678459cb..590b2ac7372b26262625d08691a8528ffddd82d2 100644
--- a/convlab/dst/setsumbt/modeling/training.py
+++ b/convlab/dst/setsumbt/modeling/training.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,17 +13,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Training utils"""
+"""Training and evaluation utils"""
 
 import random
 import os
 import logging
+from copy import deepcopy
 
 import torch
 from torch.nn import DataParallel
 from torch.distributions import Categorical
 import numpy as np
-from transformers import AdamW, get_linear_schedule_with_warmup
+from transformers import get_linear_schedule_with_warmup
+from torch.optim import AdamW
 from tqdm import tqdm, trange
 try:
     from apex import amp
@@ -31,7 +33,7 @@ except:
     print('Apex not used')
 
 from convlab.dst.setsumbt.utils import clear_checkpoints
-from convlab.dst.setsumbt.modeling.temperature_scheduler import TemperatureScheduler
+from convlab.dst.setsumbt.modeling import LinearTemperatureScheduler
 
 
 # Load logger and tensorboard summary writer
@@ -59,18 +61,131 @@ def set_ontology_embeddings(model, slots, load_slots=True):
     if load_slots:
         slots = {slot: embs for slot, embs in slots.items()}
         model.add_slot_candidates(slots)
-    for slot in model.informable_slot_ids:
+    try:
+        informable_slot_ids = model.setsumbt.informable_slot_ids
+    except:
+        informable_slot_ids = model.informable_slot_ids
+    for slot in informable_slot_ids:
         model.add_value_candidates(slot, values[slot], replace=True)
 
 
-def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_dev, embeddings=None, tokenizer=None):
-    """Train model!"""
+def log_info(global_step, loss, jg_acc=None, sl_acc=None, req_f1=None, dom_f1=None, gen_f1=None, stats=None):
+    """
+    Log training statistics.
+
+    Args:
+        global_step: Number of global training steps completed, or 'training_complete', 'dev' or 'test'
+        loss: Training loss
+        jg_acc: Joint goal accuracy, accuracies are only reported when this is not None
+        sl_acc: Slot accuracy
+        req_f1: Request prediction F1 score, F1 scores are only reported when this is not None
+        dom_f1: Active domain prediction F1 score
+        gen_f1: General action prediction F1 score
+        stats: Uncertainty measure statistics of model
+    """
+    if isinstance(global_step, int):
+        info = f"{global_step} steps complete, Loss since last update: {loss}. Validation set stats: "
+    elif global_step == 'training_complete':
+        info = "Training Complete, Validation set stats: "
+    elif global_step == 'dev':
+        info = f"Validation set stats: Loss: {loss}, "
+    elif global_step == 'test':
+        info = f"Test set stats: Loss: {loss}, "
+    if jg_acc is not None:
+        info += f"Joint Goal Acc: {jg_acc}, Slot Acc: {sl_acc}, "
+    if req_f1 is not None:
+        info += f"Request F1 Score: {req_f1}, Active Domain F1 Score: {dom_f1}, General Action F1 Score: {gen_f1}"
+    logger.info(info)
+
+    if isinstance(global_step, int):
+        # This function may be called with only a loss (no validation metrics), None cannot be written as a scalar
+        if jg_acc is not None:
+            tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
+            tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)
+        if req_f1 is not None:
+            tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step)
+            tb_writer.add_scalar('ActiveDomainF1Score/Dev', dom_f1, global_step)
+            tb_writer.add_scalar('GeneralActionF1Score/Dev', gen_f1, global_step)
+        tb_writer.add_scalar('Loss/Dev', loss, global_step)
+
+        if stats:
+            for slot, stats_slot in stats.items():
+                for key, item in stats_slot.items():
+                    tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step)
+
+
+def get_input_dict(batch: dict,
+                   predict_actions: bool,
+                   model_informable_slot_ids: list,
+                   model_requestable_slot_ids: list = None,
+                   model_domain_ids: list = None,
+                   device = 'cpu') -> dict:
+    """
+    Produce model input arguments
+
+    Labels are read from the 'belief_state-', 'request_probs-', 'active_domain_probs-' and 'general_act_probs'
+    entries if any key of the batch contains 'belief_state', otherwise from the corresponding '*_labels' entries.
+    Slot and domain labels missing from the batch are skipped, whereas the general act labels entry must be
+    present whenever predict_actions is set.
+
+    Args:
+        batch: Batch of data from the dataloader
+        predict_actions: Model should predict user actions if set true
+        model_informable_slot_ids: List of model dialogue state slots
+        model_requestable_slot_ids: List of model requestable slots
+        model_domain_ids: List of model domains
+        device: Current torch device in use
+
+    Returns:
+        input_dict: Dictionary containing model inputs for the batch
+    """
+    def _collect(prefix, names):
+        # Move the tensors stored under prefix + name to the device, for the names present in the batch
+        return {name: batch[prefix + name].to(device) for name in names if (prefix + name) in batch}
+
+    input_dict = {'input_ids': batch['input_ids'].to(device)}
+    # Optional inputs are passed as None when the batch does not provide them
+    for key in ('token_type_ids', 'attention_mask'):
+        input_dict[key] = batch[key].to(device) if key in batch else None
+
+    # Select the batch keys the labels are stored under
+    if any('belief_state' in key for key in batch):
+        state_key, request_key, domain_key = 'belief_state-', 'request_probs-', 'active_domain_probs-'
+        general_key = 'general_act_probs'
+    else:
+        state_key, request_key, domain_key = 'state_labels-', 'request_labels-', 'active_domain_labels-'
+        general_key = 'general_act_labels'
+
+    # Slots and domains without an entry in the batch are left out of the label dictionaries
+    input_dict['state_labels'] = _collect(state_key, model_informable_slot_ids)
+    if predict_actions:
+        # User action labels are only extracted when the model predicts user actions
+        input_dict['request_labels'] = _collect(request_key, model_requestable_slot_ids)
+        input_dict['active_domain_labels'] = _collect(domain_key, model_domain_ids)
+        input_dict['general_act_labels'] = batch[general_key].to(device)
+
+    return input_dict
+
+
+def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, slots_dev: dict):
+    """
+    Train the SetSUMBT model.
+
+    Args:
+        args: Runtime arguments
+        model: SetSUMBT Model instance to train
+        device: Torch device to use during training
+        train_dataloader: Dataloader containing the training data
+        dev_dataloader: Dataloader containing the validation set data
+        slots: Model ontology used for training
+        slots_dev: Model ontology used for evaluating on the validation set
+    """
 
     # Calculate the total number of training steps to be performed
     if args.max_training_steps > 0:
         t_total = args.max_training_steps
-        args.num_train_epochs = args.max_training_steps // (
-            (len(train_dataloader) // args.gradient_accumulation_steps) + 1)
+        args.num_train_epochs = (len(train_dataloader) // args.gradient_accumulation_steps) + 1
+        args.num_train_epochs = args.max_training_steps // args.num_train_epochs
     else:
         t_total = (len(train_dataloader) // args.gradient_accumulation_steps) * args.num_train_epochs
 
@@ -88,12 +203,12 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
         {
             "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0,
-            "lr":args.learning_rate
+            "lr": args.learning_rate
         },
     ]
 
     # Initialise the optimizer
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
+    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 
     # Initialise linear lr scheduler
     num_warmup_steps = int(t_total * args.warmup_proportion)
@@ -109,8 +224,7 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
 
     # Set up fp16 and multi gpu usage
     if args.fp16:
-        model, optimizer = amp.initialize(
-            model, optimizer, opt_level=args.fp16_opt_level)
+        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
     if args.n_gpu > 1:
         model = DataParallel(model)
 
@@ -118,7 +232,7 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
     best_model = {'joint goal accuracy': 0.0,
                   'request f1 score': 0.0,
                   'active domain f1 score': 0.0,
-                  'goodbye act f1 score': 0.0,
+                  'general act f1 score': 0.0,
                   'train loss': np.inf}
     if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')):
         logger.info("Optimizer loaded from previous run.")
@@ -136,27 +250,27 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
             model.eval()
             set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
 
-            jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader)
+            jg_acc, sl_acc, req_f1, dom_f1, gen_f1, _, _ = evaluate(args, model, device, dev_dataloader, is_train=True)
 
             # Set model back to training mode
             model.train()
             model.zero_grad()
             set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False)
         else:
-            jg_acc, req_f1, dom_f1, bye_f1 = 0.0, 0.0, 0.0, 0.0
+            jg_acc, req_f1, dom_f1, gen_f1 = 0.0, 0.0, 0.0, 0.0
 
         best_model['joint goal accuracy'] = jg_acc
         best_model['request f1 score'] = req_f1
         best_model['active domain f1 score'] = dom_f1
-        best_model['goodbye act f1 score'] = bye_f1
+        best_model['general act f1 score'] = gen_f1
 
     # Log training set up
-    logger.info("Device: %s, Number of GPUs: %s, FP16 training: %s" % (device, args.n_gpu, args.fp16))
+    logger.info(f"Device: {device}, Number of GPUs: {args.n_gpu}, FP16 training: {args.fp16}")
     logger.info("***** Running training *****")
-    logger.info("  Num Batches = %d" % len(train_dataloader))
-    logger.info("  Num Epochs = %d" % args.num_train_epochs)
-    logger.info("  Gradient Accumulation steps = %d" % args.gradient_accumulation_steps)
-    logger.info("  Total optimization steps = %d" % t_total)
+    logger.info(f"  Num Batches = {len(train_dataloader)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {t_total}")
 
     # Initialise training parameters
     global_step = 0
@@ -173,11 +287,11 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
             steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
 
             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info("  Continuing training from epoch %d" % epochs_trained)
-            logger.info("  Continuing training from global step %d" % global_step)
-            logger.info("  Will skip the first %d steps in the first epoch" % steps_trained_in_current_epoch)
+            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            logger.info(f"  Continuing training from global step {global_step}")
+            logger.info(f"  Will skip the first {steps_trained_in_current_epoch} steps in the first epoch")
         except ValueError:
-            logger.info("  Starting fine-tuning.")
+            logger.info(f"  Starting fine-tuning.")
 
     # Prepare model for training
     tr_loss, logging_loss = 0.0, 0.0
@@ -196,43 +310,15 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
                 continue
 
             # Extract all label dictionaries from the batch
-            if 'goodbye_belief' in batch:
-                labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('belief-' + slot) in batch}
-                request_labels = {slot: batch['request_belief-' + slot].to(device)
-                                  for slot in model.requestable_slot_ids
-                                  if ('request_belief-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids
-                                 if ('domain_belief-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye_belief'].to(
-                    device) if args.predict_actions else None
-            else:
-                labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('labels-' + slot) in batch}
-                request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids
-                                  if ('request-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids
-                                 if ('active-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye'].to(
-                    device) if args.predict_actions else None
-
-            # Extract all model inputs from batch
-            input_ids = batch['input_ids'].to(device)
-            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
-            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
+            input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids,
+                                        model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device)
 
             # Set up temperature scaling for the model
             if temp_scheduler is not None:
-                model.temp = temp_scheduler.temp()
+                model.setsumbt.temp = temp_scheduler.temp()
 
             # Forward pass to obtain loss
-            loss, _, _, _, _, _, stats = model(input_ids=input_ids,
-                                               token_type_ids=token_type_ids,
-                                               attention_mask=attention_mask,
-                                               inform_labels=labels,
-                                               request_labels=request_labels,
-                                               domain_labels=domain_labels,
-                                               goodbye_labels=goodbye_labels)
+            loss, _, _, _, _, _, stats = model(**input_dict)
 
             if args.n_gpu > 1:
                 loss = loss.mean()
@@ -258,7 +344,6 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
                 tb_writer.add_scalar('LearningRate', lr, global_step)
 
                 if stats:
-                    # print(stats.keys())
                     for slot, stats_slot in stats.items():
                         for key, item in stats_slot.items():
                             tb_writer.add_scalar(f'{key}_{slot}/Train', item, global_step)
@@ -273,7 +358,6 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
 
                 tr_loss += loss.float().item()
                 epoch_iterator.set_postfix(loss=loss.float().item())
-                loss = 0.0
                 global_step += 1
 
             # Save model checkpoint
@@ -286,52 +370,34 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
                     model.eval()
                     set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
 
-                    jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader)
+                    jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader,
+                                                                                   is_train=True)
                     # Log model eval information
-                    if req_f1 is not None:
-                        logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f, Dev Request F1 Score = %f, Dev Domain F1 Score = %f, Dev Goodbye F1 Score = %f'
-                                    % (global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
-                        tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
-                        tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)
-                        tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step)
-                        tb_writer.add_scalar('DomainF1Score/Dev', dom_f1, global_step)
-                        tb_writer.add_scalar('GoodbyeF1Score/Dev', bye_f1, global_step)
-                    else:
-                        logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f'
-                                    % (global_step, logging_loss / args.save_steps, jg_acc, sl_acc))
-                        tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
-                        tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)
-                    tb_writer.add_scalar('Loss/Dev', loss, global_step)
-                    if stats:
-                        for slot, stats_slot in stats.items():
-                            for key, item in stats_slot.items():
-                                tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step)
+                    log_info(global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, gen_f1, stats)
 
                     # Set model back to training mode
                     model.train()
                     model.zero_grad()
                     set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False)
                 else:
-                    jg_acc, req_f1 = 0.0, None
-                    logger.info('%i steps complete, Loss since last update = %f' % (global_step, logging_loss / args.save_steps))
+                    log_info(global_step, logging_loss / args.save_steps)
 
                 logging_loss = tr_loss
 
                 # Compute the score of the best model
                 try:
-                    best_score = (best_model['request f1 score'] * model.config.user_request_loss_weight) + \
-                        (best_model['active domain f1 score'] * model.config.active_domain_loss_weight) + \
-                        (best_model['goodbye act f1 score'] *
-                         model.config.user_general_act_loss_weight)
+                    best_score = best_model['request f1 score'] * model.config.user_request_loss_weight
+                    best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight
+                    best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight
                 except AttributeError:
                     best_score = 0.0
                 best_score += best_model['joint goal accuracy']
 
                 # Compute the score of the current model
                 try:
-                    current_score = (req_f1 * model.config.user_request_loss_weight) + \
-                        (dom_f1 * model.config.active_domain_loss_weight) + \
-                        (bye_f1 * model.config.user_general_act_loss_weight) if req_f1 is not None else 0.0
+                    current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0
+                    current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0
+                    current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0
                 except AttributeError:
                     current_score = 0.0
                 current_score += jg_acc
@@ -353,10 +419,10 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
                     if req_f1:
                         best_model['request f1 score'] = req_f1
                         best_model['active domain f1 score'] = dom_f1
-                        best_model['goodbye act f1 score'] = bye_f1
+                        best_model['general act f1 score'] = gen_f1
                     best_model['train loss'] = tr_loss / global_step
 
-                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
+                    output_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                     if not os.path.exists(output_dir):
                         os.makedirs(output_dir, exist_ok=True)
 
@@ -386,14 +452,15 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
                 epoch_iterator.close()
                 break
 
-        logger.info('Epoch %i complete, average training loss = %f' % (e + 1, tr_loss / global_step))
+        steps_trained_in_current_epoch = 0
+        logger.info(f'Epoch {e + 1} complete, average training loss = {tr_loss / global_step}')
 
         if args.max_training_steps > 0 and global_step > args.max_training_steps:
             train_iterator.close()
             break
         if args.patience > 0 and steps_since_last_update >= args.patience:
             train_iterator.close()
-            logger.info('Model has not improved for at least %i steps. Training stopped!' % args.patience)
+            logger.info(f'Model has not improved for at least {args.patience} steps. Training stopped!')
             break
 
     # Evaluate final model
@@ -401,30 +468,25 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
         model.eval()
         set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False)
 
-        jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader)
-        if req_f1 is not None:
-            logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f, Dev Request F1 Score = %f, Dev Domain F1 Score = %f, Dev Goodbye F1 Score = %f'
-                        % (tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
-        else:
-            logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f'
-                        % (tr_loss / global_step, jg_acc, sl_acc))
+        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader,
+                                                                       is_train=True)
+
+        log_info('training_complete', tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)
     else:
-        jg_acc = 0.0
         logger.info('Training complete!')
 
     # Store final model
     try:
-        best_score = (best_model['request f1 score'] * model.config.user_request_loss_weight) + \
-            (best_model['active domain f1 score'] * model.config.active_domain_loss_weight) + \
-            (best_model['goodbye act f1 score'] *
-             model.config.user_general_act_loss_weight)
+        best_score = best_model['request f1 score'] * model.config.user_request_loss_weight
+        best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight
+        best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight
     except AttributeError:
         best_score = 0.0
     best_score += best_model['joint goal accuracy']
     try:
-        current_score = (req_f1 * model.config.user_request_loss_weight) + \
-                        (dom_f1 * model.config.active_domain_loss_weight) + \
-                        (bye_f1 * model.config.user_general_act_loss_weight) if req_f1 is not None else 0.0
+        current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0
+        current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0
+        current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0
     except AttributeError:
         current_score = 0.0
     current_score += jg_acc
@@ -456,225 +518,88 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de
             torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt"))
         clear_checkpoints(args.output_dir)
     else:
-        logger.info(
-            'Final model not saved, since it is not the best performing model.')
-
-
-# Function for validation
-def train_eval(args, model, device, dev_dataloader):
-    """Evaluate Model during training!"""
-    accuracy_jg = []
-    accuracy_sl = []
-    accuracy_req = []
-    truepos_req, falsepos_req, falseneg_req = [], [], []
-    truepos_dom, falsepos_dom, falseneg_dom = [], [], []
-    truepos_bye, falsepos_bye, falseneg_bye = [], [], []
-    accuracy_dom = []
-    accuracy_bye = []
-    turns = []
-    for batch in dev_dataloader:
-        # Perform with no gradients stored
-        with torch.no_grad():
-            if 'goodbye_belief' in batch:
-                labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('belief-' + slot) in batch}
-                request_labels = {slot: batch['request_belief-' + slot].to(device) for slot in model.requestable_slot_ids
-                                  if ('request_belief-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids
-                                 if ('domain_belief-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye_belief'].to(
-                    device) if args.predict_actions else None
-            else:
-                labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('labels-' + slot) in batch}
-                request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids
-                                  if ('request-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids
-                                 if ('active-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye'].to(
-                    device) if args.predict_actions else None
-
-            input_ids = batch['input_ids'].to(device)
-            token_type_ids = batch['token_type_ids'].to(
-                device) if 'token_type_ids' in batch else None
-            attention_mask = batch['attention_mask'].to(
-                device) if 'attention_mask' in batch else None
-
-            loss, p, p_req, p_dom, p_bye, _, stats = model(input_ids=input_ids,
-                                                           token_type_ids=token_type_ids,
-                                                           attention_mask=attention_mask,
-                                                           inform_labels=labels,
-                                                           request_labels=request_labels,
-                                                           domain_labels=domain_labels,
-                                                           goodbye_labels=goodbye_labels)
-
-        jg_acc = 0.0
-        req_acc = 0.0
-        req_tp, req_fp, req_fn = 0.0, 0.0, 0.0
-        dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0
-        dom_acc = 0.0
-        for slot in model.informable_slot_ids:
-            labels = batch['labels-' + slot].to(device)
-            p_ = p[slot]
-
-            acc = (p_.argmax(-1) == labels).reshape(-1).float()
-            jg_acc += acc
-
-        if model.config.predict_actions:
-            for slot in model.requestable_slot_ids:
-                p_req_ = p_req[slot]
-                request_labels = batch['request-' + slot].to(device)
-
-                acc = (p_req_.round().int() == request_labels).reshape(-1).float()
-                tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float()
-                fp = (p_req_.round().int() * (request_labels == 0)).reshape(-1).float()
-                fn = ((1 - p_req_.round().int()) * (request_labels == 1)).reshape(-1).float()
-                req_acc += acc
-                req_tp += tp
-                req_fp += fp
-                req_fn += fn
-
-            for domain in model.domain_ids:
-                p_dom_ = p_dom[domain]
-                domain_labels = batch['active-' + domain].to(device)
+        logger.info('Final model not saved, as it is not the best performing model.')
 
-                acc = (p_dom_.round().int() == domain_labels).reshape(-1).float()
-                tp = (p_dom_.round().int() * (domain_labels == 1)).reshape(-1).float()
-                fp = (p_dom_.round().int() * (domain_labels == 0)).reshape(-1).float()
-                fn = ((1 - p_dom_.round().int()) * (domain_labels == 1)).reshape(-1).float()
-                dom_acc += acc
-                dom_tp += tp
-                dom_fp += fp
-                dom_fn += fn
-
-            goodbye_labels = batch['goodbye'].to(device)
-            bye_acc = (p_bye.argmax(-1) == goodbye_labels).reshape(-1).float().sum()
-            bye_tp = ((p_bye.argmax(-1) > 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
-            bye_fp = ((p_bye.argmax(-1) > 0) * (goodbye_labels == 0)).reshape(-1).float().sum()
-            bye_fn = ((p_bye.argmax(-1) == 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
-        else:
-            req_acc, dom_acc, bye_acc = None, None, torch.tensor(0.0)
-            req_tp, req_fp, req_fn = None, None, None
-            dom_tp, dom_fp, dom_fn = None, None, None
-            bye_tp, bye_fp, bye_fn = torch.tensor(
-                0.0), torch.tensor(0.0), torch.tensor(0.0)
-
-        sl_acc = sum(jg_acc / len(model.informable_slot_ids)).float()
-        jg_acc = sum((jg_acc / len(model.informable_slot_ids)).int()).float()
-        req_acc = sum(req_acc / len(model.requestable_slot_ids)).float() if req_acc is not None else torch.tensor(0.0)
-        req_tp = sum(req_tp / len(model.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0)
-        req_fp = sum(req_fp / len(model.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0)
-        req_fn = sum(req_fn / len(model.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0)
-        dom_tp = sum(dom_tp / len(model.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0)
-        dom_fp = sum(dom_fp / len(model.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0)
-        dom_fn = sum(dom_fn / len(model.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0)
-        dom_acc = sum(dom_acc / len(model.domain_ids)).float() if dom_acc is not None else torch.tensor(0.0)
-        n_turns = (labels >= 0).reshape(-1).sum().float().item()
-
-        accuracy_jg.append(jg_acc.item())
-        accuracy_sl.append(sl_acc.item())
-        accuracy_req.append(req_acc.item())
-        truepos_req.append(req_tp.item())
-        falsepos_req.append(req_fp.item())
-        falseneg_req.append(req_fn.item())
-        accuracy_dom.append(dom_acc.item())
-        truepos_dom.append(dom_tp.item())
-        falsepos_dom.append(dom_fp.item())
-        falseneg_dom.append(dom_fn.item())
-        accuracy_bye.append(bye_acc.item())
-        truepos_bye.append(bye_tp.item())
-        falsepos_bye.append(bye_fp.item())
-        falseneg_bye.append(bye_fn.item())
-        turns.append(n_turns)
-
-    # Global accuracy reduction across batches
-    turns = sum(turns)
-    jg_acc = sum(accuracy_jg) / turns
-    sl_acc = sum(accuracy_sl) / turns
-    if model.config.predict_actions:
-        req_acc = sum(accuracy_req) / turns
-        req_tp = sum(truepos_req)
-        req_fp = sum(falsepos_req)
-        req_fn = sum(falseneg_req)
-        req_f1 = req_tp / (req_tp + 0.5 * (req_fp + req_fn))
-        dom_acc = sum(accuracy_dom) / turns
-        dom_tp = sum(truepos_dom)
-        dom_fp = sum(falsepos_dom)
-        dom_fn = sum(falseneg_dom)
-        dom_f1 = dom_tp / (dom_tp + 0.5 * (dom_fp + dom_fn))
-        bye_tp = sum(truepos_bye)
-        bye_fp = sum(falsepos_bye)
-        bye_fn = sum(falseneg_bye)
-        bye_f1 = bye_tp / (bye_tp + 0.5 * (bye_fp + bye_fn))
-        bye_acc = sum(accuracy_bye) / turns
-    else:
-        req_acc, dom_acc, bye_acc = None, None, None
-        req_f1, dom_f1, bye_f1 = None, None, None
 
-    return jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats
+def evaluate(args, model, device, dataloader, return_eval_output=False, is_train=False):
+    """
+    Evaluate model
 
+    Args:
+        args: Runtime arguments
+        model: SetSUMBT model instance
+        device: Torch device in use
+        dataloader: Dataloader of data to evaluate on
+        return_eval_output: If true return predicted and true states for all dialogues evaluated in semantic format
+        is_train: If true model is training and no logging is performed
 
-def evaluate(args, model, device, dataloader):
-    """Evaluate Model!"""
-    # Evaluate!
-    logger.info("***** Running evaluation *****")
-    logger.info("  Num Batches = %d", len(dataloader))
+    Returns:
+        out: Evaluated model statistics
+    """
+    return_eval_output = False if is_train else return_eval_output
+    if not is_train:
+        logger.info("***** Running evaluation *****")
+        logger.info("  Num Batches = %d", len(dataloader))
 
     tr_loss = 0.0
     model.eval()
+    if return_eval_output:
+        ontology = dataloader.dataset.ontology
 
-    # logits = {slot: [] for slot in model.informable_slot_ids}
     accuracy_jg = []
     accuracy_sl = []
-    accuracy_req = []
     truepos_req, falsepos_req, falseneg_req = [], [], []
     truepos_dom, falsepos_dom, falseneg_dom = [], [], []
-    truepos_bye, falsepos_bye, falseneg_bye = [], [], []
-    accuracy_dom = []
-    accuracy_bye = []
+    truepos_gen, falsepos_gen, falseneg_gen = [], [], []
     turns = []
-    epoch_iterator = tqdm(dataloader, desc="Iteration")
+    if return_eval_output:
+        evaluation_output = []
+    epoch_iterator = tqdm(dataloader, desc="Iteration") if not is_train else dataloader
     for batch in epoch_iterator:
         with torch.no_grad():
-            if 'goodbye_belief' in batch:
-                labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('belief-' + slot) in batch}
-                request_labels = {slot: batch['request_belief-' + slot].to(device) for slot in model.requestable_slot_ids
-                                  if ('request_belief-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids
-                                 if ('domain_belief-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye_belief'].to(
-                    device) if args.predict_actions else None
-            else:
-                labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids
-                          if ('labels-' + slot) in batch}
-                request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids
-                                  if ('request-' + slot) in batch} if args.predict_actions else None
-                domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids
-                                 if ('active-' + domain) in batch} if args.predict_actions else None
-                goodbye_labels = batch['goodbye'].to(
-                    device) if args.predict_actions else None
-
-            input_ids = batch['input_ids'].to(device)
-            token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
-            attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
-
-            loss, p, p_req, p_dom, p_bye, _, _ = model(input_ids=input_ids,
-                                                       token_type_ids=token_type_ids,
-                                                       attention_mask=attention_mask,
-                                                       inform_labels=labels,
-                                                       request_labels=request_labels,
-                                                       domain_labels=domain_labels,
-                                                       goodbye_labels=goodbye_labels)
+            input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids,
+                                        model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device)
+
+            loss, p, p_req, p_dom, p_gen, _, stats = model(**input_dict)
 
         jg_acc = 0.0
+        num_inform_slots = 0.0
         req_acc = 0.0
         req_tp, req_fp, req_fn = 0.0, 0.0, 0.0
         dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0
         dom_acc = 0.0
-        for slot in model.informable_slot_ids:
+
+        if return_eval_output:
+            eval_output_batch = []
+            for dial_id, dial in enumerate(input_dict['input_ids']):
+                for turn_id, turn in enumerate(dial):
+                    if turn.sum() != 0:
+                        eval_output_batch.append({'dial_idx': dial_id,
+                                                  'utt_idx': turn_id,
+                                                  'state': dict(),
+                                                  'predictions': {'state': dict()}
+                                                  })
+
+        for slot in model.setsumbt.informable_slot_ids:
             p_ = p[slot]
-            labels = batch['labels-' + slot].to(device)
+            state_labels = batch['state_labels-' + slot].to(device)
+
+            if return_eval_output:
+                prediction = p_.argmax(-1)
+
+                for sample in eval_output_batch:
+                    dom, slt = slot.split('-', 1)
+                    lab = state_labels[sample['dial_idx']][sample['utt_idx']].item()
+                    lab = ontology[dom][slt]['possible_values'][lab] if lab != -1 else 'NOT_IN_ONTOLOGY'
+                    pred = prediction[sample['dial_idx']][sample['utt_idx']].item()
+                    pred = ontology[dom][slt]['possible_values'][pred]
+
+                    if dom not in sample['state']:
+                        sample['state'][dom] = dict()
+                        sample['predictions']['state'][dom] = dict()
+
+                    sample['state'][dom][slt] = lab if lab != 'none' else ''
+                    sample['predictions']['state'][dom][slt] = pred if pred != 'none' else ''
 
             if args.temp_scaling > 0.0:
                 p_ = torch.log(p_ + 1e-10) / args.temp_scaling
@@ -683,28 +608,21 @@ def evaluate(args, model, device, dataloader):
                 p_ = torch.log(p_ + 1e-10) / 1.0
                 p_ = torch.softmax(p_, -1)
 
-            # logits[slot].append(p_)
-
-            if args.accuracy_samples > 0:
-                dist = Categorical(probs=p_.reshape(-1, p_.size(-1)))
-                lab_sample = dist.sample((args.accuracy_samples,))
-                lab_sample = lab_sample.transpose(0, 1)
-                acc = [lab in s for lab, s in zip(labels.reshape(-1), lab_sample)]
-                acc = torch.tensor(acc).float()
-            elif args.accuracy_topn > 0:
-                labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True)
-                labs = labs[:, :args.accuracy_topn]
-                acc = [lab in s for lab, s in zip(labels.reshape(-1), labs)]
-                acc = torch.tensor(acc).float()
-            else:
-                acc = (p_.argmax(-1) == labels).reshape(-1).float()
+            acc = (p_.argmax(-1) == state_labels).reshape(-1).float()
 
             jg_acc += acc
+            num_inform_slots += (state_labels != -1).float().reshape(-1)
+
+        if return_eval_output:
+            for sample in eval_output_batch:
+                sample['dial_idx'] = batch['dialogue_ids'][sample['utt_idx']][sample['dial_idx']]
+                evaluation_output.append(deepcopy(sample))
+            eval_output_batch = []
 
         if model.config.predict_actions:
-            for slot in model.requestable_slot_ids:
+            for slot in model.setsumbt.requestable_slot_ids:
                 p_req_ = p_req[slot]
-                request_labels = batch['request-' + slot].to(device)
+                request_labels = batch['request_labels-' + slot].to(device)
 
                 acc = (p_req_.round().int() == request_labels).reshape(-1).float()
                 tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float()
@@ -715,85 +633,83 @@ def evaluate(args, model, device, dataloader):
                 req_fp += fp
                 req_fn += fn
 
-            for domain in model.domain_ids:
+            domains = [domain for domain in model.setsumbt.domain_ids if f'active_domain_labels-{domain}' in batch]
+            for domain in domains:
                 p_dom_ = p_dom[domain]
-                domain_labels = batch['active-' + domain].to(device)
+                active_domain_labels = batch['active_domain_labels-' + domain].to(device)
 
-                acc = (p_dom_.round().int() == domain_labels).reshape(-1).float()
-                tp = (p_dom_.round().int() * (domain_labels == 1)).reshape(-1).float()
-                fp = (p_dom_.round().int() * (domain_labels == 0)).reshape(-1).float()
-                fn = ((1 - p_dom_.round().int()) * (domain_labels == 1)).reshape(-1).float()
+                acc = (p_dom_.round().int() == active_domain_labels).reshape(-1).float()
+                tp = (p_dom_.round().int() * (active_domain_labels == 1)).reshape(-1).float()
+                fp = (p_dom_.round().int() * (active_domain_labels == 0)).reshape(-1).float()
+                fn = ((1 - p_dom_.round().int()) * (active_domain_labels == 1)).reshape(-1).float()
                 dom_acc += acc
                 dom_tp += tp
                 dom_fp += fp
                 dom_fn += fn
 
-            goodbye_labels = batch['goodbye'].to(device)
-            bye_acc = (p_bye.argmax(-1) == goodbye_labels).reshape(-1).float().sum()
-            bye_tp = ((p_bye.argmax(-1) > 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
-            bye_fp = ((p_bye.argmax(-1) > 0) * (goodbye_labels == 0)).reshape(-1).float().sum()
-            bye_fn = ((p_bye.argmax(-1) == 0) * (goodbye_labels > 0)).reshape(-1).float().sum()
+            general_act_labels = batch['general_act_labels'].to(device)
+            gen_tp = ((p_gen.argmax(-1) > 0) * (general_act_labels > 0)).reshape(-1).float().sum()
+            gen_fp = ((p_gen.argmax(-1) > 0) * (general_act_labels == 0)).reshape(-1).float().sum()
+            gen_fn = ((p_gen.argmax(-1) == 0) * (general_act_labels > 0)).reshape(-1).float().sum()
         else:
-            req_acc, dom_acc, bye_acc = None, None, torch.tensor(0.0)
             req_tp, req_fp, req_fn = None, None, None
             dom_tp, dom_fp, dom_fn = None, None, None
-            bye_tp, bye_fp, bye_fn = torch.tensor(
-                0.0), torch.tensor(0.0), torch.tensor(0.0)
-
-        sl_acc = sum(jg_acc / len(model.informable_slot_ids)).float()
-        jg_acc = sum((jg_acc / len(model.informable_slot_ids)).int()).float()
-        req_acc = sum(req_acc / len(model.requestable_slot_ids)).float() if req_acc is not None else torch.tensor(0.0)
-        req_tp = sum(req_tp / len(model.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0)
-        req_fp = sum(req_fp / len(model.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0)
-        req_fn = sum(req_fn / len(model.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0)
-        dom_tp = sum(dom_tp / len(model.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0)
-        dom_fp = sum(dom_fp / len(model.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0)
-        dom_fn = sum(dom_fn / len(model.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0)
-        dom_acc = sum(dom_acc / len(model.domain_ids)).float() if dom_acc is not None else torch.tensor(0.0)
-        n_turns = (labels >= 0).reshape(-1).sum().float().item()
+            gen_tp, gen_fp, gen_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
+
+        jg_acc = jg_acc[num_inform_slots > 0]
+        num_inform_slots = num_inform_slots[num_inform_slots > 0]
+        sl_acc = sum(jg_acc / num_inform_slots).float()
+        jg_acc = sum((jg_acc == num_inform_slots).int()).float()
+        if req_tp is not None and model.setsumbt.requestable_slot_ids:
+            req_tp = sum(req_tp / len(model.setsumbt.requestable_slot_ids)).float()
+            req_fp = sum(req_fp / len(model.setsumbt.requestable_slot_ids)).float()
+            req_fn = sum(req_fn / len(model.setsumbt.requestable_slot_ids)).float()
+        else:
+            req_tp, req_fp, req_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
+        dom_tp = sum(dom_tp / len(model.setsumbt.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0)
+        dom_fp = sum(dom_fp / len(model.setsumbt.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0)
+        dom_fn = sum(dom_fn / len(model.setsumbt.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0)
+        n_turns = num_inform_slots.size(0)
 
         accuracy_jg.append(jg_acc.item())
         accuracy_sl.append(sl_acc.item())
-        accuracy_req.append(req_acc.item())
         truepos_req.append(req_tp.item())
         falsepos_req.append(req_fp.item())
         falseneg_req.append(req_fn.item())
-        accuracy_dom.append(dom_acc.item())
         truepos_dom.append(dom_tp.item())
         falsepos_dom.append(dom_fp.item())
         falseneg_dom.append(dom_fn.item())
-        accuracy_bye.append(bye_acc.item())
-        truepos_bye.append(bye_tp.item())
-        falsepos_bye.append(bye_fp.item())
-        falseneg_bye.append(bye_fn.item())
+        truepos_gen.append(gen_tp.item())
+        falsepos_gen.append(gen_fp.item())
+        falseneg_gen.append(gen_fn.item())
         turns.append(n_turns)
         tr_loss += loss.item()
 
-    # for slot in logits:
-    #     logits[slot] = torch.cat(logits[slot], 0)
-
     # Global accuracy reduction across batches
     turns = sum(turns)
     jg_acc = sum(accuracy_jg) / turns
     sl_acc = sum(accuracy_sl) / turns
     if model.config.predict_actions:
-        req_acc = sum(accuracy_req) / turns
         req_tp = sum(truepos_req)
         req_fp = sum(falsepos_req)
         req_fn = sum(falseneg_req)
-        req_f1 = req_tp / (req_tp + 0.5 * (req_fp + req_fn))
-        dom_acc = sum(accuracy_dom) / turns
+        req_f1 = req_tp + 0.5 * (req_fp + req_fn)
+        req_f1 = req_tp / req_f1 if req_f1 != 0.0 else 0.0
         dom_tp = sum(truepos_dom)
         dom_fp = sum(falsepos_dom)
         dom_fn = sum(falseneg_dom)
-        dom_f1 = dom_tp / (dom_tp + 0.5 * (dom_fp + dom_fn))
-        bye_tp = sum(truepos_bye)
-        bye_fp = sum(falsepos_bye)
-        bye_fn = sum(falseneg_bye)
-        bye_f1 = bye_tp / (bye_tp + 0.5 * (bye_fp + bye_fn))
-        bye_acc = sum(accuracy_bye) / turns
+        dom_f1 = dom_tp + 0.5 * (dom_fp + dom_fn)
+        dom_f1 = dom_tp / dom_f1 if dom_f1 != 0.0 else 0.0
+        gen_tp = sum(truepos_gen)
+        gen_fp = sum(falsepos_gen)
+        gen_fn = sum(falseneg_gen)
+        gen_f1 = gen_tp + 0.5 * (gen_fp + gen_fn)
+        gen_f1 = gen_tp / gen_f1 if gen_f1 != 0.0 else 0.0
     else:
-        req_acc, dom_acc, bye_acc = None, None, None
-        req_f1, dom_f1, bye_f1 = None, None, None
+        req_f1, dom_f1, gen_f1 = None, None, None
 
-    return jg_acc, sl_acc, req_f1, dom_f1, bye_f1, tr_loss / len(dataloader)
+    if return_eval_output:
+        return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), evaluation_output
+    if is_train:
+        return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), stats
+    return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader)
diff --git a/convlab/dst/setsumbt/multiwoz/Tracker.py b/convlab/dst/setsumbt/multiwoz/Tracker.py
deleted file mode 100644
index fed1a1a62334e9a721c1f5b7de442b17bfd14b76..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/multiwoz/Tracker.py
+++ /dev/null
@@ -1,455 +0,0 @@
-import os
-import json
-import copy
-import logging
-
-import torch
-import transformers
-from transformers import (BertModel, BertConfig, BertTokenizer,
-                          RobertaModel, RobertaConfig, RobertaTokenizer)
-from convlab.dst.setsumbt.modeling import (RobertaSetSUMBT,
-                                            BertSetSUMBT)
-
-from convlab.dst.dst import DST
-from convlab.util.multiwoz.state import default_state
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
-from convlab.dst.rule.multiwoz import normalize_value
-from convlab.util.custom_util import model_downloader
-
-USE_CUDA = torch.cuda.is_available()
-
-# Map from SetSUMBT slot names to Convlab slot names
-SLOT_MAP = {'arrive by': 'arriveBy',
-            'leave at': 'leaveAt',
-            'price range': 'pricerange',
-            'trainid': 'trainID',
-            'reference': 'Ref',
-            'taxi types': 'car type'}
-
-
-class SetSUMBTTracker(DST):
-
-    def __init__(self, model_path="", model_type="roberta",
-                 get_turn_pooled_representation=False,
-                 get_confidence_scores=False,
-                 threshold='auto',
-                 return_entropy=False,
-                 return_mutual_info=False,
-                 store_full_belief_state=False):
-        super(SetSUMBTTracker, self).__init__()
-
-        self.model_type = model_type
-        self.model_path = model_path
-        self.get_turn_pooled_representation = get_turn_pooled_representation
-        self.get_confidence_scores = get_confidence_scores
-        self.threshold = threshold
-        self.return_entropy = return_entropy
-        self.return_mutual_info = return_mutual_info
-        self.store_full_belief_state = store_full_belief_state
-        if self.store_full_belief_state:
-            self.full_belief_state = {}
-        self.info_dict = {}
-
-        # Download model if needed
-        if not os.path.exists(self.model_path):
-            # Get path /.../convlab/dst/setsumbt/multiwoz/models
-            download_path = os.path.dirname(os.path.abspath(__file__))
-            download_path = os.path.join(download_path, 'models')
-            if not os.path.exists(download_path):
-                os.mkdir(download_path)
-            model_downloader(download_path, self.model_path)
-            # Downloadable model path format http://.../setsumbt_model_name.zip
-            self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '')
-            self.model_path = os.path.join(download_path, self.model_path)
-
-        # Select model type based on the encoder
-        if model_type == "roberta":
-            self.config = RobertaConfig.from_pretrained(self.model_path)
-            self.tokenizer = RobertaTokenizer
-            self.model = RobertaSetSUMBT
-        elif model_type == "bert":
-            self.config = BertConfig.from_pretrained(self.model_path)
-            self.tokenizer = BertTokenizer
-            self.model = BertSetSUMBT
-        else:
-            logging.debug("Name Error: Not Implemented")
-
-        self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu')
-
-        # Value dict for value normalisation
-        path = os.path.dirname(
-            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
-        path = os.path.join(path, 'data/multiwoz/value_dict.json')
-        self.value_dict = json.load(open(path))
-
-        self.load_weights()
-
-    def load_weights(self):
-        # Load tokenizer and model checkpoints
-        logging.info('Loading SetSUMBT pretrained model.')
-        self.tokenizer = self.tokenizer.from_pretrained(
-            self.config.tokenizer_name)
-        logging.info(
-            f'Model tokenizer loaded from {self.config.tokenizer_name}.')
-        self.model = self.model.from_pretrained(
-            self.model_path, config=self.config)
-        logging.info(f'Model loaded from {self.model_path}.')
-
-        # Transfer model to compute device and setup eval environment
-        self.model = self.model.to(self.device)
-        self.model.eval()
-        logging.info(f'Model transferred to device: {self.device}')
-
-        logging.info('Loading model ontology')
-        f = open(os.path.join(self.model_path, 'ontology.json'), 'r')
-        self.ontology = json.load(f)
-        f.close()
-
-        db = torch.load(os.path.join(self.model_path, 'ontology.db'))
-        # Get slot and value embeddings
-        slots = {slot: db[slot] for slot in db}
-        values = {slot: db[slot][1] for slot in db}
-        del db
-
-        # Load model ontology
-        self.model.add_slot_candidates(slots)
-        for slot in values:
-            self.model.add_value_candidates(slot, values[slot], replace=True)
-
-        if self.get_confidence_scores:
-            logging.info('Model will output action and state confidence scores.')
-        if self.get_confidence_scores:
-            self.get_thresholds(self.threshold)
-            logging.info('Uncertain Querying set up and thresholds set up at:')
-            logging.info(self.thresholds)
-        if self.return_entropy:
-            logging.info('Model will output state distribution entropy.')
-        if self.return_mutual_info:
-            logging.info('Model will output state distribution mutual information.')
-        logging.info('Ontology loaded successfully.')
-
-        self.det_dic = {}
-        for domain, dic in REF_USR_DA.items():
-            for key, value in dic.items():
-                assert '-' not in key
-                self.det_dic[key.lower()] = key + '-' + domain
-                self.det_dic[value.lower()] = key + '-' + domain
-
-    def get_thresholds(self, threshold='auto'):
-        self.thresholds = {}
-        for slot, value_candidates in self.ontology.items():
-            domain, slot = slot.split('-', 1)
-            slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
-            slot = slot.strip().split()[1] if 'book ' in slot else slot
-            slot = SLOT_MAP.get(slot, slot)
-
-            # Auto thresholds are set based on the number of value candidates per slot
-            if domain not in self.thresholds:
-                self.thresholds[domain] = {}
-            if threshold == 'auto':
-                thres = 1.0 / (float(len(value_candidates)) - 2.1)
-                self.thresholds[domain][slot] = max(0.05, thres)
-            else:
-                self.thresholds[domain][slot] = max(0.05, threshold)
-
-        return self.thresholds
-
-    def init_session(self):
-        self.state = default_state()
-        self.active_domains = {}
-        self.hidden_states = None
-        self.info_dict = {}
-
-    def update(self, user_act=''):
-        prev_state = self.state
-
-        # Convert dialogs into transformer input features (token_ids, masks, etc)
-        features = self.get_features(user_act)
-        # Model forward pass
-        pred_states, active_domains, user_acts, turn_pooled_representation, belief_state, entropy_, mutual_info_ = self.predict(
-            features)
-
-        if entropy_ is not None:
-            entropy = {}
-            for slot, e in entropy_.items():
-                domain, slot = slot.split('-', 1)
-                if domain not in entropy:
-                    entropy[domain] = {}
-                if 'book' in slot:
-                    assert slot.startswith('book ')
-                    slot = slot.strip().split()[1]
-                slot = SLOT_MAP.get(slot, slot)
-                entropy[domain][slot] = e
-            del entropy_
-        else:
-            entropy = None
-
-        if mutual_info_ is not None:
-            mutual_info = {}
-            for slot, mi in mutual_info_.items():
-                domain, slot = slot.split('-', 1)
-                if domain not in mutual_info:
-                    mutual_info[domain] = {}
-                if 'book' in slot:
-                    assert slot.startswith('book ')
-                    slot = slot.strip().split()[1]
-                slot = SLOT_MAP.get(slot, slot)
-                mutual_info[domain][slot] = mi[0, 0]
-        else:
-            mutual_info = None
-
-        if belief_state is not None:
-            bs_probs = {}
-            belief_state, request_dist, domain_dist, greeting_dist = belief_state
-            for slot, p in belief_state.items():
-                domain, slot = slot.split('-', 1)
-                if domain not in bs_probs:
-                    bs_probs[domain] = {}
-                if 'book' in slot:
-                    assert slot.startswith('book ')
-                    slot = slot.strip().split()[1]
-                slot = SLOT_MAP.get(slot, slot)
-                if slot not in bs_probs[domain]:
-                    bs_probs[domain][slot] = {}
-                bs_probs[domain][slot]['inform'] = p
-
-            for slot, p in request_dist.items():
-                domain, slot = slot.split('-', 1)
-                if domain not in bs_probs:
-                    bs_probs[domain] = {}
-                slot = SLOT_MAP.get(slot, slot)
-                if slot not in bs_probs[domain]:
-                    bs_probs[domain][slot] = {}
-                bs_probs[domain][slot]['request'] = p
-
-            for domain, p in domain_dist.items():
-                if domain not in bs_probs:
-                    bs_probs[domain] = {}
-                bs_probs[domain]['none'] = {'inform': p}
-
-            if 'general' not in bs_probs:
-                bs_probs['general'] = {}
-            bs_probs['general']['none'] = greeting_dist
-
-        new_domains = [d for d, active in active_domains.items() if active]
-        new_domains = [
-            d for d in new_domains if not self.active_domains.get(d, False)]
-        self.active_domains = active_domains
-
-        for domain in new_domains:
-            user_acts.append(['Inform', domain.capitalize(), 'none', 'none'])
-
-        new_belief_state = copy.deepcopy(prev_state['belief_state'])
-        # user_acts = []
-        for state, value in pred_states.items():
-            domain, slot = state.split('-', 1)
-            value = '' if value == 'none' else value
-            value = 'dontcare' if value == 'do not care' else value
-            value = 'guesthouse' if value == 'guest house' else value
-            if slot not in ['name', 'book']:
-                if domain not in new_belief_state:
-                    if domain == 'bus':
-                        continue
-                    else:
-                        logging.debug(
-                            'Error: domain <{}> not in belief state'.format(domain))
-            slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
-            assert 'semi' in new_belief_state[domain]
-            assert 'book' in new_belief_state[domain]
-            if 'book' in slot:
-                assert slot.startswith('book ')
-                slot = slot.strip().split()[1]
-            slot = SLOT_MAP.get(slot, slot)
-
-            # Uncertainty clipping of state
-            if belief_state is not None:
-                if bs_probs[domain][slot].get('inform', 1.0) < self.thresholds[domain][slot]:
-                    value = ''
-
-            domain_dic = new_belief_state[domain]
-            value = normalize_value(self.value_dict, domain, slot, value)
-            if slot in domain_dic['semi']:
-                new_belief_state[domain]['semi'][slot] = value
-                if prev_state['belief_state'][domain]['semi'][slot] != value:
-                    user_acts.append(['Inform', domain.capitalize(
-                    ), REF_USR_DA[domain.capitalize()].get(slot, slot), value])
-            elif slot in domain_dic['book']:
-                new_belief_state[domain]['book'][slot] = value
-                if prev_state['belief_state'][domain]['book'][slot] != value:
-                    user_acts.append(['Inform', domain.capitalize(
-                    ), REF_USR_DA[domain.capitalize()].get(slot, slot), value])
-            elif slot.lower() in domain_dic['book']:
-                new_belief_state[domain]['book'][slot.lower()] = value
-                if prev_state['belief_state'][domain]['book'][slot.lower()] != value:
-                    user_acts.append(['Inform', domain.capitalize(
-                    ), REF_USR_DA[domain.capitalize()].get(slot.lower(), slot.lower()), value])
-            else:
-                logging.debug(
-                    'unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format(
-                        slot, value, domain, state)
-                )
-
-        new_state = copy.deepcopy(dict(prev_state))
-        new_state['belief_state'] = new_belief_state
-        new_state['active_domains'] = self.active_domains
-        if belief_state is not None:
-            new_state['belief_state_probs'] = bs_probs
-        if entropy is not None:
-            new_state['entropy'] = entropy
-        if mutual_info is not None:
-            new_state['mutual_information'] = mutual_info
-
-        new_state['user_action'] = user_acts
-
-        user_requests = [[a, d, s, v]
-                         for a, d, s, v in user_acts if a == 'Request']
-        for act, domain, slot, value in user_requests:
-            k = REF_SYS_DA[domain].get(slot, slot)
-            domain = domain.lower()
-            if domain not in new_state['request_state']:
-                new_state['request_state'][domain] = {}
-            if k not in new_state['request_state'][domain]:
-                new_state['request_state'][domain][k] = 0
-
-        if turn_pooled_representation is not None:
-            new_state['turn_pooled_representation'] = turn_pooled_representation
-
-        self.state = new_state
-        self.info_dict = copy.deepcopy(dict(new_state))
-
-        return self.state
-
-    # Model prediction function
-
-    def predict(self, features):
-        # Forward Pass
-        mutual_info = None
-        with torch.no_grad():
-            turn_pooled_representation = None
-            if self.get_turn_pooled_representation:
-                belief_state, request, domain, goodbye, self.hidden_states, turn_pooled_representation = self.model(input_ids=features['input_ids'],
-                                                                                                                    token_type_ids=features[
-                                                                                                                        'token_type_ids'],
-                                                                                                                    attention_mask=features[
-                                                                                                                        'attention_mask'],
-                                                                                                                    hidden_state=self.hidden_states,
-                                                                                                                    get_turn_pooled_representation=True)
-            elif self.return_mutual_info:
-                belief_state, request, domain, goodbye, self.hidden_states, mutual_info = self.model(input_ids=features['input_ids'],
-                                                                                                     token_type_ids=features[
-                                                                                                         'token_type_ids'],
-                                                                                                     attention_mask=features[
-                                                                                                         'attention_mask'],
-                                                                                                     hidden_state=self.hidden_states,
-                                                                                                     get_turn_pooled_representation=False,
-                                                                                                     calculate_inform_mutual_info=True)
-            else:
-                belief_state, request, domain, goodbye, self.hidden_states = self.model(input_ids=features['input_ids'],
-                                                                                        token_type_ids=features['token_type_ids'],
-                                                                                        attention_mask=features['attention_mask'],
-                                                                                        hidden_state=self.hidden_states,
-                                                                                        get_turn_pooled_representation=False)
-
-        # Convert belief state into dialog state
-        predictions = {slot: state[0, 0, :].argmax().item()
-                       for slot, state in belief_state.items()}
-        predictions = {slot: self.ontology[slot][idx]
-                       for slot, idx in predictions.items()}
-        predictions = {s: v for s, v in predictions.items() if v != 'none'}
-
-        if self.store_full_belief_state:
-            self.full_belief_state = belief_state
-
-        # Obtain model output probabilities
-        if self.get_confidence_scores:
-            entropy = None
-            if self.return_entropy:
-                entropy = {slot: state[0, 0, :]
-                           for slot, state in belief_state.items()}
-                entropy = {slot: self.relative_entropy(
-                    p).item() for slot, p in entropy.items()}
-
-            # Confidence score is the max probability across all not "none" values candidates.
-            belief_state = {slot: state[0, 0, 1:].max().item()
-                            for slot, state in belief_state.items()}
-            request_dist = {SLOT_MAP.get(
-                slot, slot): p[0, 0].item() for slot, p in request.items()}
-            domain_dist = {domain: p[0, 0].item()
-                           for domain, p in domain.items()}
-            greeting_dist = {'bye': goodbye[0, 0, 1].item(
-            ), 'thank': goodbye[0, 0, 2].item()}
-            belief_state = (belief_state, request_dist,
-                            domain_dist, greeting_dist)
-        else:
-            belief_state = None
-            entropy = None
-
-        # Construct request action prediction
-        request = [slot for slot, p in request.items() if p[0, 0].item() > 0.5]
-        request = [slot.split('-', 1) for slot in request]
-        request = [[domain, SLOT_MAP.get(slot, slot)]
-                   for domain, slot in request]
-        request = [['Request', domain.capitalize(), REF_USR_DA[domain.capitalize()].get(
-            slot, slot), '?'] for domain, slot in request]
-
-        # Construct active domain set
-        domain = {domain: p[0, 0].item() > 0.5 for domain, p in domain.items()}
-
-        # Construct general domain action
-        goodbye = goodbye[0, 0, :].argmax(-1).item()
-        goodbye = [[], ['bye'], ['thank']][goodbye]
-        goodbye = [[act, 'general', 'none', 'none'] for act in goodbye]
-
-        user_acts = request + goodbye
-
-        return predictions, domain, user_acts, turn_pooled_representation, belief_state, entropy, mutual_info
-
-    def relative_entropy(self, probs):
-        entropy = probs * torch.log(probs + 1e-8)
-        entropy = -entropy.sum()
-        # Maximum entropy of a K dimentional distribution is ln(K)
-        entropy /= torch.log(torch.tensor(probs.size(-1)).float())
-
-        return entropy
-
-    # Convert dialog turns into model features
-    def get_features(self, user_act):
-        # Extract system utterance from dialog history
-        context = self.state['history']
-        if context:
-            if context[-1][0] != 'sys':
-                system_act = ''
-            else:
-                system_act = context[-1][-1]
-        else:
-            system_act = ''
-
-        # Tokenize dialog
-        features = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True, max_length=self.config.max_turn_len,
-                                              padding='max_length', truncation='longest_first')
-
-        input_ids = torch.tensor(features['input_ids']).reshape(
-            1, 1, -1).to(self.device) if 'input_ids' in features else None
-        token_type_ids = torch.tensor(features['token_type_ids']).reshape(
-            1, 1, -1).to(self.device) if 'token_type_ids' in features else None
-        attention_mask = torch.tensor(features['attention_mask']).reshape(
-            1, 1, -1).to(self.device) if 'attention_mask' in features else None
-        features = {'input_ids': input_ids,
-                    'token_type_ids': token_type_ids, 'attention_mask': attention_mask}
-
-        return features
-
-
-# if __name__ == "__main__":
-#     tracker = SetSUMBTTracker(model_type='roberta', model_path='/gpfs/project/niekerk/results/nbt/convlab_setsumbt_acts')
-#                         # nlu_path='/gpfs/project/niekerk/data/bert_multiwoz_all_context.zip')
-#     tracker.init_session()
-#     state = tracker.update('hey. I need a cheap restaurant.')
-#     # tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
-#     # tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
-#     # state = tracker.update('If you have something Asian that would be great.')
-#     # tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
-#     # tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
-#     # state = tracker.update('Great. Where are they located?')
-#     # tracker.state['history'].append(['usr', 'Great. Where are they located?'])
-#     print(tracker.state)
diff --git a/convlab/dst/setsumbt/multiwoz/__init__.py b/convlab/dst/setsumbt/multiwoz/__init__.py
deleted file mode 100644
index a1f1fb894a0430545c62b78f9a6f4786c4e328a8..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/multiwoz/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from convlab.dst.setsumbt.multiwoz.dataset import multiwoz21, ontology
-from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair b/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair
deleted file mode 100644
index 34df41d01e93ce27039e721e1ffb55bf9267e5a2..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair
+++ /dev/null
@@ -1,83 +0,0 @@
-it's	it is
-don't	do not
-doesn't	does not
-didn't	did not
-you'd	you would
-you're	you are
-you'll	you will
-i'm	i am
-they're	they are
-that's	that is
-what's	what is
-couldn't	could not
-i've	i have
-we've	we have
-can't	cannot
-i'd	i would
-i'd	i would
-aren't	are not
-isn't	is not
-wasn't	was not
-weren't	were not
-won't	will not
-there's	there is
-there're	there are
-. .	.
-restaurants	restaurant -s
-hotels	hotel -s
-laptops	laptop -s
-cheaper	cheap -er
-dinners	dinner -s
-lunches	lunch -s
-breakfasts	breakfast -s
-expensively	expensive -ly
-moderately	moderate -ly
-cheaply	cheap -ly
-prices	price -s
-places	place -s
-venues	venue -s
-ranges	range -s
-meals	meal -s
-locations	location -s
-areas	area -s
-policies	policy -s
-children	child -s
-kids	kid -s
-kidfriendly	kid friendly
-cards	card -s
-upmarket	expensive
-inpricey	cheap
-inches	inch -s
-uses	use -s
-dimensions	dimension -s
-driverange	drive range
-includes	include -s
-computers	computer -s
-machines	machine -s
-families	family -s
-ratings	rating -s
-constraints	constraint -s
-pricerange	price range
-batteryrating	battery rating
-requirements	requirement -s
-drives	drive -s
-specifications	specification -s
-weightrange	weight range
-harddrive	hard drive
-batterylife	battery life
-businesses	business -s
-hours	hour -s
-one	1
-two	2
-three	3
-four	4
-five	5
-six	6
-seven	7
-eight	8
-nine	9
-ten	10
-eleven	11
-twelve	12
-anywhere	any where
-good bye	goodbye
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py b/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py
deleted file mode 100644
index 2c8e98f35429ce5194d82d07a1cf0de8fee54515..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py
+++ /dev/null
@@ -1,502 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""MultiWOZ 2.1/2.3 Dialogue Dataset"""
-
-import os
-import json
-import requests
-import zipfile
-import io
-from shutil import copy2 as copy
-
-import torch
-from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
-from tqdm import tqdm
-
-from convlab.dst.setsumbt.multiwoz.dataset.utils import (clean_text, ACTIVE_DOMAINS, get_domains, set_util_domains,
-                                                        fix_delexicalisation, extract_dialogue, PRICERANGE,
-                                                        BOOLEAN, DAYS, QUANTITIES, TIME, VALUE_MAP, map_values)
-
-
-# Set up global data_directory
-def set_datadir(dir):
-    global DATA_DIR
-    DATA_DIR = dir
-
-
-def set_active_domains(domains):
-    global ACTIVE_DOMAINS
-    ACTIVE_DOMAINS = [d for d in domains if d in ACTIVE_DOMAINS]
-    set_util_domains(ACTIVE_DOMAINS)
-
-
-# MultiWOZ2.1 download link
-URL = 'https://github.com/budzianowski/multiwoz/raw/master/data/MultiWOZ_2.1.zip'
-def set_url(url):
-    global URL
-    URL = url
-
-
-# Create Dialogue examples from the dataset
-def create_examples(max_utt_len, get_requestable_slots=False, force_processing=False):
-
-    # Load or download Raw Data
-    if not os.path.exists(DATA_DIR):
-        os.mkdir(DATA_DIR)
-    if not os.path.exists(os.path.join(DATA_DIR, 'data_raw.json')):
-        # Download data archive and extract
-        archive = _download()
-        data = _extract(archive)
-
-        writer = open(os.path.join(DATA_DIR, 'data_raw.json'), 'w')
-        json.dump(data, writer, indent = 2)
-        del archive, writer
-    else:
-        reader = open(os.path.join(DATA_DIR, 'data_raw.json'), 'r')
-        data = json.load(reader)
-
-    if force_processing or not os.path.exists(os.path.join(DATA_DIR, 'data_train.json')):
-        # Preprocess all dialogues
-        data_processed = _process(data['data'], data['system_acts'])
-        # Format data and split train, test and devlopment sets
-        train, dev, test = _split_data(data_processed, data['testListFile'],
-                                                            data['valListFile'], max_utt_len)
-
-        # Write data
-        writer = open(os.path.join(DATA_DIR, 'data_train.json'), 'w')
-        json.dump(train, writer, indent = 2)
-        writer = open(os.path.join(DATA_DIR, 'data_test.json'), 'w')
-        json.dump(test, writer, indent = 2)
-        writer = open(os.path.join(DATA_DIR, 'data_dev.json'), 'w')
-        json.dump(dev, writer, indent = 2)
-        writer.flush()
-        writer.close()
-        del writer
-
-        # Extract slots and slot value candidates from the dataset
-        for set_type in ['train', 'dev', 'test']:
-            _get_ontology(set_type, get_requestable_slots)
-        
-        script_path = os.path.abspath(__file__).replace('/multiwoz21.py', '')
-        file_name = 'mwoz21_ont_request.json' if get_requestable_slots else 'mwoz21_ont.json'
-        copy(os.path.join(script_path, file_name), os.path.join(DATA_DIR, 'ontology_test.json'))
-        copy(os.path.join(script_path, 'mwoz21_slot_descriptions.json'), os.path.join(DATA_DIR, 'slot_descriptions.json'))
-
-
-# Extract slots and slot value candidates from the dataset
-def _get_ontology(set_type, get_requestable_slots=False):
-
-    datasets = ['train']
-    if set_type in ['test', 'dev']:
-        datasets.append('dev')
-        datasets.append('test')
-
-    # Load examples
-    data = []
-    for dataset in datasets:
-        reader = open(os.path.join(DATA_DIR, 'data_%s.json' % dataset), 'r')
-        data += json.load(reader)
-
-    ontology = dict()
-    for dial in data:
-        for turn in dial['dialogue']:
-            for state in turn['dialogue_state']:
-                slot, value = state
-                value = map_values(value)
-                if slot not in ontology:
-                    ontology[slot] = [value]
-                else:
-                    ontology[slot].append(value)
-
-    requestable_slots = []
-    if get_requestable_slots:
-        for dial in data:
-            for turn in dial['dialogue']:
-                for act, dom, slot, val in turn['user_acts']:
-                    if act == 'request':
-                        requestable_slots.append(f'{dom}-{slot}')
-        requestable_slots = list(set(requestable_slots))
-
-    for slot in ontology:
-        if 'price' in slot:
-            ontology[slot] = PRICERANGE
-        if 'parking' in slot or 'internet' in slot:
-            ontology[slot] = BOOLEAN
-        if 'day' in slot:
-            ontology[slot] = DAYS
-        if 'people' in slot or 'duration' in slot or 'stay' in slot:
-            ontology[slot] = QUANTITIES
-        if 'time' in slot or 'leave' in slot or 'arrive' in slot:
-            ontology[slot] = TIME
-        if 'stars' in slot:
-            ontology[slot] += [str(i) for i in range(5)]
-
-    # Sort slot values and add none and dontcare values
-    for slot in ontology:
-        ontology[slot] = list(set(ontology[slot]))
-        ontology[slot] = ['none', 'do not care'] + sorted([s for s in ontology[slot] if s not in ['none', 'do not care']])
-    for slot in requestable_slots:
-        if slot in ontology:
-            ontology[slot].append('request')
-        else:
-            ontology[slot] = ['request']
-
-    writer = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'w')
-    json.dump(ontology, writer, indent=2)
-    writer.close()
-
-
-# Convert dialogue examples to model input features and labels
-def convert_examples_to_features(set_type, tokenizer, max_turns=12, max_seq_len=64):
-
-    features = dict()
-
-    # Load examples
-    reader = open(os.path.join(DATA_DIR, 'data_%s.json' % set_type), 'r')
-    data = json.load(reader)
-
-    # Get encoder input for system, user utterance pairs
-    input_feats = []
-    for dial in data:
-        dial_feats = []
-        for turn in dial['dialogue']:
-            if len(turn['system_transcript']) == 0:
-                usr = turn['transcript']
-                dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens = True,
-                                                        max_length = max_seq_len, padding='max_length',
-                                                        truncation = 'longest_first'))
-            else:
-                usr = turn['transcript']
-                sys = turn['system_transcript']
-                dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens = True,
-                                                        max_length = max_seq_len, padding='max_length',
-                                                        truncation = 'longest_first'))
-            if len(dial_feats) >= max_turns:
-                break
-        input_feats.append(dial_feats)
-    del dial_feats
-
-    # Perform turn level padding
-    input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats]
-    if 'token_type_ids' in input_feats[0][0]:
-        token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats]
-    else:
-        token_type_ids = None
-    if 'attention_mask' in input_feats[0][0]:
-        attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats]
-    else:
-        attention_mask = None
-    del input_feats
-
-    # Create torch data tensors
-    features['input_ids'] = torch.tensor(input_ids)
-    features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
-    features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
-    del input_ids, token_type_ids, attention_mask
-
-    # Load ontology
-    reader = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r')
-    ontology = json.load(reader)
-    reader.close()
-
-    informable_slots = [slot for slot, values in ontology.items() if values != ['request']]
-    requestable_slots = [slot for slot, values in ontology.items() if 'request' in values]
-    for slot in requestable_slots:
-        ontology[slot].remove('request')
-    
-    domains = list(set(informable_slots + requestable_slots))
-    domains = list(set([slot.split('-', 1)[0] for slot in domains]))
-
-    # Create slot labels
-    for slot in informable_slots:
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial['dialogue']:
-                slots_active = [s for s, v in turn['dialogue_state']]
-                if slot in slots_active:
-                    value = [v for s, v in turn['dialogue_state'] if s == slot][0]
-                else:
-                    value = 'none'
-                if value in ontology[slot]:
-                    value = ontology[slot].index(value)
-                else:
-                    value = map_values(value)
-                    if value in ontology[slot]:
-                        value = ontology[slot].index(value)
-                    else:
-                        value = -1 # If value is not in ontology then we do not penalise the model
-                labs.append(value)
-                if len(labs) >= max_turns:
-                    break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-
-        labels = torch.tensor(labels)
-        features['labels-' + slot] = labels
-
-    for slot in requestable_slots:
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial['dialogue']:
-                slots_active = [[d, s] for i, d, s, v in turn['user_acts']]
-                if slot.split('-', 1) in slots_active:
-                    act_ = [i for i, d, s, v in turn['user_acts'] if f"{d}-{s}" == slot][0]
-                    if act_ == 'request':
-                        labs.append(1)
-                    else:
-                        labs.append(0)
-                else:
-                    labs.append(0)
-                if len(labs) >= max_turns:
-                    break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-        
-        labels = torch.tensor(labels)
-        features['request-' + slot] = labels
-    
-    # Greeting act labels (0-no greeting, 1-goodbye, 2-thank you)
-    labels = []
-    for dial in data:
-        labs = []
-        for turn in dial['dialogue']:
-            greeting_active = [i for i, d, s, v in turn['user_acts'] if i in ['bye', 'thank']]
-            if greeting_active:
-                if 'bye' in greeting_active:
-                    labs.append(1)
-                else :
-                    labs.append(2)
-            else:
-                labs.append(0)
-            if len(labs) >= max_turns:
-                break
-        labs = labs + [-1] * (max_turns - len(labs))
-        labels.append(labs)
-    
-    labels = torch.tensor(labels)
-    features['goodbye'] = labels
-
-    for domain in domains:
-        labels = []
-        for dial in data:
-            labs = []
-            for turn in dial['dialogue']:
-                if domain == turn['domain']:
-                    labs.append(1)
-                else:
-                    labs.append(0)
-                if len(labs) >= max_turns:
-                        break
-            labs = labs + [-1] * (max_turns - len(labs))
-            labels.append(labs)
-        
-        labels = torch.tensor(labels)
-        features['active-' + domain] = labels
-
-    del labels
-
-    return features
-
-
-# MultiWOZ2.1 Dataset object
-class MultiWoz21(Dataset):
-
-    def __init__(self, set_type, tokenizer, max_turns=12, max_seq_len=64):
-        self.features = convert_examples_to_features(set_type, tokenizer, max_turns, max_seq_len)
-
-    def __getitem__(self, index):
-        return {label: self.features[label][index] for label in self.features
-                if self.features[label] is not None}
-
-    def __len__(self):
-        return self.features['input_ids'].size(0)
-
-    def resample(self, size=None):
-        n_dialogues = self.__len__()
-        if not size:
-            size = n_dialogues
-
-        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
-        self.features = {label: self.features[label][dialogues] for label in self.features
-                        if self.features[label] is not None}
-        
-        return self
-
-    def to(self, device):
-        self.device = device
-        self.features = {label: self.features[label].to(device) for label in self.features
-                         if self.features[label] is not None}
-
-
-# MultiWOZ2.1 Dataset object
-class EnsembleMultiWoz21(Dataset):
-    def __init__(self, data):
-        self.features = data
-
-    def __getitem__(self, index):
-        return {label: self.features[label][index] for label in self.features
-                if self.features[label] is not None}
-
-    def __len__(self):
-        return self.features['input_ids'].size(0)
-
-    def resample(self, size=None):
-        n_dialogues = self.__len__()
-        if not size:
-            size = n_dialogues
-
-        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
-        self.features = {label: self.features[label][dialogues] for label in self.features
-                        if self.features[label] is not None}
-
-    def to(self, device):
-        self.device = device
-        self.features = {label: self.features[label].to(device) for label in self.features
-                         if self.features[label] is not None}
-
-
-# Module to create torch dataloaders
-def get_dataloader(set_type, batch_size, tokenizer, max_turns=12, max_seq_len=64, device=None, resampled_size=None):
-    data = MultiWoz21(set_type, tokenizer, max_turns, max_seq_len)
-    data.to('cpu')
-
-    if resampled_size:
-        data.resample(resampled_size)
-
-    if set_type in ['test', 'dev']:
-        sampler = SequentialSampler(data)
-    else:
-        sampler = RandomSampler(data)
-    loader = DataLoader(data, sampler=sampler, batch_size=batch_size)
-
-    return loader
-
-
-def _download(chunk_size=1048576):
-    """Download data archive.
-
-    Parameters:
-        chunk_size (int): Download chunk size. (default=1048576)
-    Returns:
-        archive: ZipFile archive object.
-    """
-    # Download the archive byte string
-    req = requests.get(URL, stream=True)
-    archive = b''
-    for n_chunks, chunk in tqdm(enumerate(req.iter_content(chunk_size=chunk_size)), desc='Download Chunk'):
-        if chunk:
-            archive += chunk
-
-    # Convert the bytestring into a zipfile object
-    archive = io.BytesIO(archive)
-    archive = zipfile.ZipFile(archive)
-
-    return archive
-
-
-def _extract(archive):
-    """Extract the json dictionaries from the archive.
-
-    Parameters:
-        archive: ZipFile archive object.
-    Returns:
-        data: Data dictionary.
-    """
-    files = [file for file in archive.filelist if ('.json' in file.filename or '.txt' in file.filename)
-            and 'MACOSX' not in file.filename]
-    objects = []
-    for file in tqdm(files, desc='File'):
-        data = archive.open(file).read()
-        # Get data objects from the files
-        try:
-            data = json.loads(data)
-        except json.decoder.JSONDecodeError:
-            data = data.decode().split('\n')
-        objects.append(data)
-
-    files = [file.filename.split('/')[-1].split('.')[0] for file in files]
-
-    data = {file: data for file, data in zip(files, objects)}
-    return data
-
-
-# Process files
-def _process(dialogue_data, acts_data):
-    print('Processing Dialogues')
-    out = {}
-    for dial_name in tqdm(dialogue_data):
-        dialogue = dialogue_data[dial_name]
-
-        prev_dom = ''
-        for turn_id, turn in enumerate(dialogue['log']):
-            dialogue['log'][turn_id]['text'] = clean_text(turn['text'])
-            if len(turn['metadata']) != 0:
-                crnt_dom = get_domains(dialogue['log'], turn_id, prev_dom)
-                prev_dom = crnt_dom
-                dialogue['log'][turn_id - 1]['domain'] = crnt_dom
-
-            dialogue['log'][turn_id] = fix_delexicalisation(turn)
-
-        out[dial_name] = dialogue
-
-    return out
-
-
-# Split data (train, dev, test)
-def _split_data(dial_data, test, dev, max_utt_len):
-    train_dials, test_dials, dev_dials = [], [], []
-    print('Formatting and Splitting Data')
-    for name in tqdm(dial_data):
-        dialogue = dial_data[name]
-        domains = []
-
-        dial = extract_dialogue(dialogue, max_utt_len)
-        if dial:
-            dialogue = dict()
-            dialogue['dialogue_idx'] = name
-            dialogue['domains'] = []
-            dialogue['dialogue'] = []
-
-            for turn_id, turn in enumerate(dial):
-                turn_dialog = dict()
-                turn_dialog['system_transcript'] = dial[turn_id - 1]['sys'] if turn_id > 0 else ''
-                turn_dialog['turn_idx'] = turn_id
-                turn_dialog['dialogue_state'] = turn['ds']
-                turn_dialog['transcript'] = turn['usr']
-                # turn_dialog['system_acts'] = dial[turn_id - 1]['sys_a'] if turn_id > 0 else []
-                turn_dialog['user_acts'] = turn['usr_a']
-                turn_dialog['domain'] = turn['domain']
-                dialogue['domains'].append(turn['domain'])
-                dialogue['dialogue'].append(turn_dialog)
-
-            dialogue['domains'] = [d for d in list(set(dialogue['domains'])) if d != '']
-            if True in [dom not in ACTIVE_DOMAINS for dom in dialogue['domains']]:
-                dialogue['domains'] = []
-            dialogue['domains'] = [dom for dom in dialogue['domains'] if dom in ACTIVE_DOMAINS]
-
-            if dialogue['domains']:
-                if name in test:
-                    test_dials.append(dialogue)
-                elif name in dev:
-                    dev_dials.append(dialogue)
-                else:
-                    train_dials.append(dialogue)
-
-    print('Number of Dialogues:\nTrain: %i\nDev: %i\nTest: %i' % (len(train_dials), len(dev_dials), len(test_dials)))
-
-    return train_dials, dev_dials, test_dials
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json
deleted file mode 100644
index b703793dc747535132b748d8b69838b4c151d8d5..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json
+++ /dev/null
@@ -1,2990 +0,0 @@
-{
-  "hotel-price range": [
-    "none",
-    "do not care",
-    "cheap",
-    "expensive",
-    "moderate"
-  ],
-  "hotel-type": [
-    "none",
-    "do not care",
-    "bed and breakfast",
-    "guest house",
-    "hotel"
-  ],
-  "hotel-parking": [
-    "none",
-    "do not care",
-    "no",
-    "yes"
-  ],
-  "hotel-book day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "hotel-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "hotel-book stay": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "train-destination": [
-    "none",
-    "do not care",
-    "bishops stortford",
-    "kings lynn",
-    "london liverpool street",
-    "centre",
-    "bishop stortford",
-    "liverpool",
-    "leicester",
-    "broxbourne",
-    "gourmet burger kitchen",
-    "copper kettle",
-    "bournemouth",
-    "stevenage",
-    "liverpool street",
-    "norwich",
-    "huntingdon marriott hotel",
-    "city centre north",
-    "taj tandoori",
-    "the copper kettle",
-    "peterborough",
-    "ely",
-    "lecester",
-    "london",
-    "willi",
-    "stansted airport",
-    "huntington marriott",
-    "cambridge",
-    "gonv",
-    "glastonbury",
-    "hol",
-    "north",
-    "birmingham new street",
-    "norway",
-    "petersborough",
-    "london kings cross",
-    "curry prince",
-    "bishops storford"
-  ],
-  "train-arrive by": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "train-departure": [
-    "none",
-    "do not care",
-    "bishops stortford",
-    "kings lynn",
-    "brookshite",
-    "london liverpool street",
-    "cam",
-    "liverpool",
-    "bro",
-    "leicester",
-    "broxbourne",
-    "norwhich",
-    "saint johns",
-    "stevenage",
-    "stansted",
-    "london liverpool",
-    "cambrid",
-    "city hall",
-    "rosas bed and breakfast",
-    "alpha-milton",
-    "wandlebury country park",
-    "norwich",
-    "liecester",
-    "stratford",
-    "peterborough",
-    "duxford",
-    "ely",
-    "london",
-    "stansted airport",
-    "lon",
-    "cambridge",
-    "panahar",
-    "cineworld",
-    "leicaster",
-    "birmingham",
-    "cafe uno",
-    "camboats",
-    "huntingdon",
-    "birmingham new street",
-    "arbu",
-    "alpha milton",
-    "east london",
-    "london kings cross",
-    "hamilton lodge",
-    "aylesbray lodge guest",
-    "el shaddai"
-  ],
-  "train-day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "train-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "hotel-stars": [
-    "none",
-    "do not care",
-    "0",
-    "1",
-    "2",
-    "3",
-    "4",
-    "5"
-  ],
-  "hotel-internet": [
-    "none",
-    "do not care",
-    "no",
-    "yes"
-  ],
-  "hotel-name": [
-    "a and b guest house",
-    "city roomz",
-    "carolina bed and breakfast",
-    "limehouse",
-    "anatolia",
-    "hamilton lodge",
-    "the lensfield hotel",
-    "rosa's bed and breakfast",
-    "gall",
-    "aylesbray lodge",
-    "kirkwood",
-    "cambridge belfry",
-    "warkworth house",
-    "gonville",
-    "belfy hotel",
-    "nus",
-    "alexander",
-    "super 5",
-    "aylesbray lodge guest house",
-    "the gonvile hotel",
-    "allenbell",
-    "nothamilton lodge",
-    "ashley hotel",
-    "autumn house",
-    "hobsons house",
-    "hotel",
-    "ashely hotel",
-    "caridge belfrey",
-    "el shaddia guest house",
-    "avalon",
-    "cote",
-    "city centre north bed and breakfast",
-    "the cambridge belfry",
-    "home from home",
-    "wandlebury coutn",
-    "wankworth house",
-    "city stop rest",
-    "the worth house",
-    "cityroomz",
-    "huntingdon marriottt hotel",
-    "none",
-    "lensfield",
-    "rosas bed and breakfast",
-    "leverton house",
-    "gonville hotel",
-    "holiday inn cambridge",
-    "do not care",
-    "archway house",
-    "lan hon",
-    "levert",
-    "acorn guest house",
-    "cambridge",
-    "the ashley hotel",
-    "el shaddai",
-    "sleeperz",
-    "alpha milton guest house",
-    "doubletree by hilton cambridge",
-    "tandoori palace",
-    "express by",
-    "express by holiday inn cambridge",
-    "north bed and breakfast",
-    "holiday inn",
-    "arbury lodge guest house",
-    "alexander bed and breakfast",
-    "huntingdon marriott hotel",
-    "royal spice",
-    "sou",
-    "finches bed and breakfast",
-    "the alpha milton",
-    "bridge guest house",
-    "the acorn guest house",
-    "kirkwood house",
-    "eraina",
-    "la margherit",
-    "lensfield hotel",
-    "marriott hotel",
-    "nusha",
-    "city centre bed and breakfast",
-    "the allenbell",
-    "university arms hotel",
-    "clare",
-    "cherr",
-    "wartworth",
-    "acorn place",
-    "lovell lodge",
-    "whale"
-  ],
-  "train-leave at": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "restaurant-price range": [
-    "none",
-    "do not care",
-    "cheap",
-    "expensive",
-    "moderate"
-  ],
-  "restaurant-food": [
-    "british food",
-    "steakhouse",
-    "turkish",
-    "sushi",
-    "north american",
-    "scottish",
-    "french",
-    "austrian",
-    "korean",
-    "eastern european",
-    "swedish",
-    "gastro pub",
-    "modern eclectic",
-    "afternoon tea",
-    "welsh",
-    "christmas",
-    "tuscan",
-    "gastropub",
-    "sri lankan",
-    "molecular gastronomy",
-    "traditional american",
-    "italian",
-    "pizza",
-    "thai",
-    "south african",
-    "creative",
-    "english",
-    "asian",
-    "lebanese",
-    "hungarian",
-    "halal",
-    "portugese",
-    "modern english",
-    "african",
-    "light bites",
-    "malaysian",
-    "venetian",
-    "traditional",
-    "chinese",
-    "vegetarian",
-    "persian",
-    "thai and chinese",
-    "scandinavian",
-    "catalan",
-    "polynesian",
-    "crossover",
-    "canapes",
-    "cantonese",
-    "north african",
-    "seafood",
-    "brazilian",
-    "south indian",
-    "australasian",
-    "belgian",
-    "barbeque",
-    "the americas",
-    "indonesian",
-    "singaporean",
-    "irish",
-    "middle eastern",
-    "dojo noodle bar",
-    "caribbean",
-    "vietnamese",
-    "modern european",
-    "russian",
-    "none",
-    "german",
-    "world",
-    "japanese",
-    "moroccan",
-    "modern global",
-    "do not care",
-    "indian",
-    "british",
-    "american",
-    "danish",
-    "panasian",
-    "swiss",
-    "basque",
-    "north indian",
-    "modern american",
-    "australian",
-    "european",
-    "corsica",
-    "greek",
-    "northern european",
-    "mediterranean",
-    "portuguese",
-    "romanian",
-    "jamaican",
-    "polish",
-    "international",
-    "unusual",
-    "latin american",
-    "asian oriental",
-    "mexican",
-    "bistro",
-    "cuban",
-    "fusion",
-    "new zealand",
-    "spanish",
-    "eritrean",
-    "afghan",
-    "kosher"
-  ],
-  "attraction-name": [
-    "downing college",
-    "fitzwilliam",
-    "clare college",
-    "ruskin gallery",
-    "sidney sussex college",
-    "great saint mary's church",
-    "cherry hinton water play park",
-    "wandlebury country park",
-    "cafe uno",
-    "place",
-    "broughton",
-    "cineworld cinema",
-    "jesus college",
-    "vue cinema",
-    "history of science museum",
-    "mumford theatre",
-    "whale of time",
-    "fitzbillies",
-    "christs church",
-    "churchill college",
-    "museum of classical archaeology",
-    "gonville and caius college",
-    "pizza",
-    "kirkwood",
-    "saint catharines college",
-    "kings college",
-    "parkside",
-    "by",
-    "st catharines college",
-    "saint john's college",
-    "cherry hinton water park",
-    "st christs college",
-    "christ's college",
-    "bangkok city",
-    "scudamores punti co",
-    "free",
-    "great saint marys church",
-    "milton country park",
-    "the fez club",
-    "soultree",
-    "autu",
-    "whipple museum of the history of science",
-    "aylesbray lodge guest house",
-    "broughton house gallery",
-    "peoples portraits exhibition",
-    "primavera",
-    "kettles yard",
-    "all saint's church",
-    "cinema cinema",
-    "regency gallery",
-    "corpus christi",
-    "corn cambridge exchange",
-    "da vinci pizzeria",
-    "school",
-    "hobsons house",
-    "cambride and country folk museum",
-    "north",
-    "da v",
-    "cambridge corn exchange",
-    "soul tree nightclub",
-    "cambridge arts theater",
-    "saint catharine's college",
-    "byard art",
-    "cambridge punter",
-    "cambridge university botanic gardens",
-    "castle galleries",
-    "museum of archaelogy and anthropogy",
-    "no specific location",
-    "cherry hinton hall",
-    "gallery at 12 a high street",
-    "parkside pools",
-    "queen's college",
-    "little saint mary's church",
-    "gallery",
-    "home from home",
-    "tenpin",
-    "the wandlebury",
-    "county folk museum",
-    "swimming pool",
-    "christs college",
-    "cafe jello museum",
-    "scott polar",
-    "christ college",
-    "cambridge museum of technology",
-    "abbey pool and astroturf pitch",
-    "king hedges learner pool",
-    "the cambridge arts theatre",
-    "the castle galleries",
-    "cambridge and country folk museum",
-    "kohinoor",
-    "scudamores punting co",
-    "sidney sussex",
-    "the man on the moon",
-    "little saint marys church",
-    "queens",
-    "the place",
-    "old school",
-    "churchill",
-    "churchills college",
-    "hughes hall",
-    "churchhill college",
-    "riverboat georgina",
-    "none",
-    "belf",
-    "cambridge temporary art",
-    "abc theatre",
-    "cambridge contemporary art museum",
-    "man on the moon",
-    "the junction",
-    "cherry hinton water play",
-    "adc theatre",
-    "gonville hotel",
-    "magdalene college",
-    "peoples portraits exhibition at girton college",
-    "boat",
-    "centre",
-    "sheep's green and lammas land park fen causeway",
-    "do not care",
-    "the mumford theatre",
-    "archway house",
-    "queens' college",
-    "williams art and antiques",
-    "funky fun house",
-    "cherry hinton village centre",
-    "camboats",
-    "cambridge",
-    "old schools",
-    "kettle's yard",
-    "whale of a time",
-    "the churchill college",
-    "cafe jello gallery",
-    "aut",
-    "salsa",
-    "city",
-    "clare hall",
-    "boating",
-    "pembroke college",
-    "kings hedges learner pool",
-    "caffe uno",
-    "lammas land park",
-    "museum",
-    "the fitzwilliam museum",
-    "the cherry hinton village centre",
-    "the cambridge corn exchange",
-    "fitzwilliam museum",
-    "museum of archaelogy and anthropology",
-    "fez club",
-    "the cambridge punter",
-    "saint johns college",
-    "emmanuel college",
-    "cambridge belf",
-    "scudamore",
-    "lynne strover gallery",
-    "king's college",
-    "whippple museum",
-    "trinity college",
-    "college in the north",
-    "sheep's green",
-    "kambar",
-    "museum of archaelogy",
-    "adc",
-    "garde",
-    "club salsa",
-    "people's portraits exhibition at girton college",
-    "botanic gardens",
-    "carol",
-    "college",
-    "gallery at twelve a high street",
-    "abbey pool and astroturf",
-    "cambridge book and print gallery",
-    "jesus green outdoor pool",
-    "scott polar museum",
-    "saint barnabas press gallery",
-    "cambridge artworks",
-    "older churches",
-    "cambridge contemporary art",
-    "cherry hinton hall and grounds",
-    "univ",
-    "jesus green",
-    "ballare",
-    "abbey pool",
-    "cambridge botanic gardens",
-    "nusha",
-    "worth house",
-    "thanh",
-    "university arms hotel",
-    "cambridge arts theatre",
-    "cafe jello",
-    "cambridge and county folk museum",
-    "the cambridge artworks",
-    "all saints church",
-    "holy trinity church",
-    "contemporary art museum",
-    "architectural churches",
-    "queens college",
-    "trinity street college"
-  ],
-  "restaurant-name": [
-    "none",
-    "do not care",
-    "hotel du vin and bistro",
-    "ask",
-    "gourmet formal kitchen",
-    "the meze bar",
-    "lan hong house",
-    "cow pizza",
-    "one seven",
-    "prezzo",
-    "maharajah tandoori restaurant",
-    "alex",
-    "shanghai",
-    "golden wok",
-    "restaurant",
-    "fitzbillies",
-    "nil",
-    "copper kettle",
-    "meghna",
-    "hk fusion",
-    "bangkok city",
-    "hobsons house",
-    "tang chinese",
-    "anatolia",
-    "ugly duckling",
-    "anatolia and efes restaurant",
-    "sitar tandoori",
-    "city stop",
-    "ashley",
-    "pizza express fen ditton",
-    "molecular gastronomy",
-    "autumn house",
-    "el shaddia guesthouse",
-    "the grafton hotel",
-    "limehouse",
-    "gardenia",
-    "not metioned",
-    "hakka",
-    "michaelhouse cafe",
-    "pipasha",
-    "meze bar",
-    "archway",
-    "molecular gastonomy",
-    "yipee noodle bar",
-    "the peking",
-    "curry prince",
-    "midsummer house restaurant",
-    "pizza hut cherry hinton",
-    "the lucky star",
-    "stazione restaurant and coffee bar",
-    "shanghi family restaurant",
-    "good luck",
-    "j restaurant",
-    "bedouin",
-    "cott",
-    "little seoul",
-    "south",
-    "thanh binh",
-    "el",
-    "efes restaurant",
-    "kohinoor",
-    "clowns",
-    "india",
-    "the slug and lettuce",
-    "shiraz",
-    "barbakan",
-    "zizzi cambridge",
-    "restaurant one seven",
-    "slug and lettuce",
-    "travellers rest",
-    "binh",
-    "worth house",
-    "broughton house gallery",
-    "chiquito",
-    "the river bar steakhouse and grill",
-    "ros",
-    "golden house",
-    "india west",
-    "cam",
-    "panahar",
-    "restaurant 22",
-    "adden",
-    "indian",
-    "hu",
-    "jinling noodle bar",
-    "darrys cookhouse and wine shop",
-    "hobson house",
-    "cambridge be",
-    "el shaddai",
-    "ac",
-    "nandos",
-    "cambridge lodge",
-    "the cow pizza kitchen and bar",
-    "charlie",
-    "rajmahal",
-    "kymmoy",
-    "cambri",
-    "backstreet bistro",
-    "galleria",
-    "restaurant 2 two",
-    "chiquito restaurant bar",
-    "royal standard",
-    "lucky star",
-    "curry king",
-    "grafton hotel restaurant",
-    "mahal of cambridge",
-    "the bedouin",
-    "nus",
-    "the kohinoor",
-    "pizza hut fenditton",
-    "camboats",
-    "the gardenia",
-    "de luca cucina and bar",
-    "nusha",
-    "european",
-    "taj tandoori",
-    "tandoori palace",
-    "golden curry",
-    "efes",
-    "loch fyne",
-    "the maharajah tandoor",
-    "lovel",
-    "restaurant 17",
-    "clowns cafe",
-    "cambridge punter",
-    "bloomsbury restaurant",
-    "la mimosa",
-    "the cambridge chop house",
-    "funky",
-    "cotto",
-    "oak bistro",
-    "restaurant two two",
-    "pipasha restaurant",
-    "river bar steakhouse and grill",
-    "royal spice",
-    "the copper kettle",
-    "graffiti",
-    "nandos city centre",
-    "saffron brasserie",
-    "cambridge chop house",
-    "sitar",
-    "kitchen and bar",
-    "the good luck chinese food takeaway",
-    "clu",
-    "la tasca",
-    "cafe uno",
-    "cote",
-    "the varsity restaurant",
-    "bri",
-    "eraina",
-    "bridge",
-    "fin",
-    "cambridge lodge restaurant",
-    "grafton",
-    "hotpot",
-    "sala thong",
-    "margherita",
-    "wise buddha",
-    "the missing sock",
-    "seasame restaurant and bar",
-    "the dojo noodle bar",
-    "restaurant alimentum",
-    "gastropub",
-    "saigon city",
-    "la margherita",
-    "pizza hut",
-    "curry garden",
-    "ashley hotel",
-    "eraina and michaelhouse cafe",
-    "the golden curry",
-    "curry queen",
-    "cow pizza kitchen and bar",
-    "the peking restaurant:",
-    "hamilton lodge",
-    "alimentum",
-    "yippee noodle bar",
-    "2 two and cote",
-    "shanghai family restaurant",
-    "grafton hotel",
-    "yes",
-    "ali baba",
-    "dif",
-    "fitzbillies restaurant",
-    "peking restaurant",
-    "lev",
-    "nirala",
-    "the alex",
-    "tandoori",
-    "city stop restaurant",
-    "rice house",
-    "cityr",
-    "yu garden",
-    "meze bar restaurant",
-    "the",
-    "don pasquale pizzeria",
-    "rice boat",
-    "the hotpot",
-    "old school",
-    "the oak bistro",
-    "sesame restaurant and bar",
-    "pizza express",
-    "the gandhi",
-    "pizza hut fen ditton",
-    "charlie chan",
-    "da vinci pizzeria",
-    "dojo noodle bar",
-    "gourmet burger kitchen",
-    "the golden house",
-    "india house",
-    "hobso",
-    "missing sock",
-    "pizza hut city centre",
-    "parkside pools",
-    "riverside brasserie",
-    "caffe uno",
-    "primavera",
-    "the nirala",
-    "wagamama",
-    "au",
-    "ian hong house",
-    "frankie and bennys",
-    "4 kings parade city centre",
-    "shiraz restaurant",
-    "scudamores punt",
-    "mahal",
-    "saint johns chop house",
-    "de luca cucina and bar riverside brasserie",
-    "cocum",
-    "la raza"
-  ],
-  "attraction-type": [
-    "none",
-    "do not care",
-    "architecture",
-    "boat",
-    "boating",
-    "camboats",
-    "church",
-    "churchills college",
-    "cinema",
-    "college",
-    "concert",
-    "concerthall",
-    "entertainment",
-    "gallery",
-    "gastropub",
-    "hiking",
-    "hotel",
-    "multiple sports",
-    "museum",
-    "museum kettles yard",
-    "night club",
-    "outdoor",
-    "park",
-    "pool",
-    "special",
-    "sports",
-    "swimming pool",
-    "theater",
-    "theatre",
-    "concert hall",
-    "local site",
-    "nightclub",
-    "hotspot"
-  ],
-  "taxi-leave at": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "taxi-destination": [
-    "none",
-    "do not care",
-    "a and b guest house",
-    "abbey pool and astroturf pitch",
-    "acorn guest house",
-    "adc theatre",
-    "addenbrookes hospital",
-    "alexander bed and breakfast",
-    "ali baba",
-    "all saints church",
-    "allenbell",
-    "alpha milton guest house",
-    "anatolia",
-    "arbury lodge guesthouse",
-    "archway house",
-    "ashley hotel",
-    "ask",
-    "attraction",
-    "autumn house",
-    "avalon",
-    "aylesbray lodge guest house",
-    "backstreet bistro",
-    "ballare",
-    "bangkok city",
-    "bedouin",
-    "birmingham new street train station",
-    "bishops stortford train station",
-    "bloomsbury restaurant",
-    "bridge guest house",
-    "broughton house gallery",
-    "broxbourne train station",
-    "byard art",
-    "cafe jello gallery",
-    "cafe uno",
-    "camboats",
-    "cambridge",
-    "cambridge and county folk museum",
-    "cambridge arts theatre",
-    "cambridge artworks",
-    "cambridge belfry",
-    "cambridge book and print gallery",
-    "cambridge chop house",
-    "cambridge contemporary art",
-    "cambridge county fair next to the city tourist museum",
-    "cambridge lodge restaurant",
-    "cambridge museum of technology",
-    "cambridge punter",
-    "cambridge road church of christ",
-    "cambridge train station",
-    "cambridge university botanic gardens",
-    "carolina bed and breakfast",
-    "castle galleries",
-    "charlie chan",
-    "cherry hinton hall and grounds",
-    "cherry hinton village centre",
-    "cherry hinton water park",
-    "cherry hinton water play",
-    "chiquito restaurant bar",
-    "christ college",
-    "churchills college",
-    "cineworld cinema",
-    "city centre north bed and breakfast",
-    "city stop restaurant",
-    "cityroomz",
-    "clare college",
-    "clare hall",
-    "clowns cafe",
-    "club salsa",
-    "cocum",
-    "copper kettle",
-    "corpus christi",
-    "cote",
-    "cotto",
-    "cow pizza kitchen and bar",
-    "curry garden",
-    "curry king",
-    "curry prince",
-    "da vinci pizzeria",
-    "darrys cookhouse and wine shop",
-    "de luca cucina and bar",
-    "dojo noodle bar",
-    "don pasquale pizzeria",
-    "downing college",
-    "efes restaurant",
-    "el shaddia guesthouse",
-    "ely train station",
-    "emmanuel college",
-    "eraina",
-    "express by holiday inn cambridge",
-    "finches bed and breakfast",
-    "finders corner newmarket road",
-    "fitzbillies restaurant",
-    "fitzwilliam museum",
-    "frankie and bennys",
-    "funky fun house",
-    "galleria",
-    "gallery at 12 a high street",
-    "gastropub",
-    "golden curry",
-    "golden house",
-    "golden wok",
-    "gonville and caius college",
-    "gonville hotel",
-    "good luck",
-    "gourmet burger kitchen",
-    "graffiti",
-    "grafton hotel restaurant",
-    "great saint marys church",
-    "hakka",
-    "hamilton lodge",
-    "hk fusion",
-    "hobsons house",
-    "holy trinity church",
-    "home from home",
-    "hotel du vin and bistro",
-    "hughes hall",
-    "huntingdon marriott hotel",
-    "ian hong",
-    "india house",
-    "j restaurant",
-    "jesus college",
-    "jesus green outdoor pool",
-    "jinling noodle bar",
-    "kambar",
-    "kettles yard",
-    "kings college",
-    "kings hedges learner pool",
-    "kirkwood house",
-    "kohinoor",
-    "kymmoy",
-    "la margherita",
-    "la mimosa",
-    "la raza",
-    "la tasca",
-    "lan hong house",
-    "leicester train station",
-    "lensfield hotel",
-    "limehouse",
-    "little saint marys church",
-    "little seoul",
-    "loch fyne",
-    "london kings cross train station",
-    "london liverpool street train station",
-    "lovell lodge",
-    "lynne strover gallery",
-    "magdalene college",
-    "mahal of cambridge",
-    "maharajah tandoori restaurant",
-    "meghna",
-    "meze bar",
-    "michaelhouse cafe",
-    "midsummer house restaurant",
-    "milton country park",
-    "mumford theatre",
-    "museum of archaelogy and anthropology",
-    "museum of classical archaeology",
-    "nandos",
-    "nandos city centre",
-    "nil",
-    "nirala",
-    "norwich train station",
-    "nusha",
-    "old schools",
-    "panahar",
-    "parkside police station",
-    "parkside pools",
-    "peking restaurant",
-    "pembroke college",
-    "peoples portraits exhibition at girton college",
-    "peterborough train station",
-    "pipasha restaurant",
-    "pizza express",
-    "pizza hut cherry hinton",
-    "pizza hut city centre",
-    "pizza hut fenditton",
-    "prezzo",
-    "primavera",
-    "queens college",
-    "rajmahal",
-    "regency gallery",
-    "restaurant 17",
-    "restaurant 2 two",
-    "restaurant alimentum",
-    "rice boat",
-    "rice house",
-    "riverboat georgina",
-    "riverside brasserie",
-    "rosas bed and breakfast",
-    "royal spice",
-    "royal standard",
-    "ruskin gallery",
-    "saffron brasserie",
-    "saigon city",
-    "saint barnabas",
-    "saint barnabas press gallery",
-    "saint catharines college",
-    "saint johns chop house",
-    "saint johns college",
-    "sala thong",
-    "scott polar museum",
-    "scudamores punting co",
-    "sesame restaurant and bar",
-    "shanghai family restaurant",
-    "sheeps green and lammas land park fen causeway",
-    "shiraz",
-    "sidney sussex college",
-    "sitar tandoori",
-    "sleeperz hotel",
-    "soul tree nightclub",
-    "st johns chop house",
-    "stansted airport train station",
-    "station road",
-    "stazione restaurant and coffee bar",
-    "stevenage train station",
-    "taj tandoori",
-    "tall monument",
-    "tandoori palace",
-    "tang chinese",
-    "tenpin",
-    "thanh binh",
-    "the anatolia",
-    "the cambridge corn exchange",
-    "the cambridge shop",
-    "the fez club",
-    "the gandhi",
-    "the gardenia",
-    "the hotpot",
-    "the junction",
-    "the lucky star",
-    "the man on the moon",
-    "the missing sock",
-    "the oak bistro",
-    "the place",
-    "the regent street city center",
-    "the river bar steakhouse and grill",
-    "the slug and lettuce",
-    "the varsity restaurant",
-    "travellers rest",
-    "trinity college",
-    "ugly duckling",
-    "university arms hotel",
-    "vue cinema",
-    "wagamama",
-    "wandlebury country park",
-    "wankworth hotel",
-    "warkworth house",
-    "whale of a time",
-    "whipple museum of the history of science",
-    "williams art and antiques",
-    "worth house",
-    "yippee noodle bar",
-    "yu garden",
-    "zizzi cambridge",
-    "leverton house",
-    "the cambridge chop house",
-    "saint john's college",
-    "churchill college",
-    "the nirala",
-    "the cow pizza kitchen and bar",
-    "christ's college",
-    "el shaddai",
-    "saint catharine's college",
-    "camb",
-    "the golden curry",
-    "little saint mary's church",
-    "country folk museum",
-    "meze bar restaurant",
-    "the cambridge belfry",
-    "the fitzwilliam museum",
-    "the lensfield hotel",
-    "pizza express fen ditton",
-    "the cambridge punter",
-    "king's college",
-    "the cherry hinton village centre",
-    "shiraz restaurant",
-    "sheep's green and lammas land park fen causeway",
-    "caffe uno",
-    "the ghandi",
-    "the copper kettle",
-    "man on the moon concert hall",
-    "alpha-milton guest house",
-    "queen's college",
-    "restaurant one seven",
-    "restaurant two two",
-    "city centre north b and b",
-    "rosa's bed and breakfast",
-    "the good luck chinese food takeaway",
-    "not museum of archaeology and anthropologymentioned",
-    "tandori in cambridge",
-    "kettle's yard",
-    "megna",
-    "grou",
-    "gallery at twelve a high street",
-    "maharajah tandoori restaurant",
-    "pizza hut fen ditton",
-    "gandhi",
-    "tranh binh",
-    "kambur",
-    "people's portraits exhibition at girton college",
-    "hotel",
-    "restaurant",
-    "the galleria",
-    "queens' college",
-    "great saint mary's church",
-    "theathre",
-    "cambridge artworks",
-    "acorn house",
-    "shiraz",
-    "riverboat georginawd",
-    "mic",
-    "the gallery at twelve",
-    "the soul tree",
-    "finches"
-  ],
-  "taxi-departure": [
-    "none",
-    "do not care",
-    "172 chestertown road",
-    "4455 woodbridge road",
-    "a and b guest house",
-    "abbey pool and astroturf pitch",
-    "acorn guest house",
-    "adc theatre",
-    "addenbrookes hospital",
-    "alexander bed and breakfast",
-    "ali baba",
-    "all saints church",
-    "allenbell",
-    "alpha milton guest house",
-    "alyesbray lodge hotel",
-    "ambridge",
-    "anatolia",
-    "arbury lodge guesthouse",
-    "archway house",
-    "ashley hotel",
-    "ask",
-    "autumn house",
-    "avalon",
-    "aylesbray lodge guest house",
-    "backstreet bistro",
-    "ballare",
-    "bangkok city",
-    "bedouin",
-    "birmingham new street train station",
-    "bishops stortford train station",
-    "bloomsbury restaurant",
-    "bridge guest house",
-    "broughton house gallery",
-    "broxbourne train station",
-    "byard art",
-    "cafe jello gallery",
-    "cafe uno",
-    "caffee uno",
-    "camboats",
-    "cambridge",
-    "cambridge and county folk museum",
-    "cambridge arts theatre",
-    "cambridge artworks",
-    "cambridge belfry",
-    "cambridge book and print gallery",
-    "cambridge chop house",
-    "cambridge contemporary art",
-    "cambridge lodge restaurant",
-    "cambridge museum of technology",
-    "cambridge punter",
-    "cambridge towninfo centre",
-    "cambridge train station",
-    "cambridge university botanic gardens",
-    "carolina bed and breakfast",
-    "castle galleries",
-    "centre of town at my hotel",
-    "charlie chan",
-    "cherry hinton hall and grounds",
-    "cherry hinton village center",
-    "cherry hinton village centre",
-    "cherry hinton water play",
-    "chiquito restaurant bar",
-    "christ college",
-    "churchills college",
-    "cineworld cinema",
-    "citiroomz",
-    "city centre north bed and breakfast",
-    "city stop restaurant",
-    "cityroomz",
-    "clair hall",
-    "clare college",
-    "clare hall",
-    "clowns cafe",
-    "club salsa",
-    "cocum",
-    "copper kettle",
-    "corpus christi",
-    "cote",
-    "cotto",
-    "cow pizza kitchen and bar",
-    "curry garden",
-    "curry king",
-    "curry prince",
-    "curry queen",
-    "da vinci pizzeria",
-    "darrys cookhouse and wine shop",
-    "de luca cucina and bar",
-    "dojo noodle bar",
-    "don pasquale pizzeria",
-    "downing college",
-    "downing street",
-    "el shaddia guesthouse",
-    "ely",
-    "ely train station",
-    "emmanuel college",
-    "eraina",
-    "express by holiday inn cambridge",
-    "finches bed and breakfast",
-    "fitzbillies restaurant",
-    "fitzwilliam museum",
-    "frankie and bennys",
-    "funky fun house",
-    "galleria",
-    "gallery at 12 a high street",
-    "girton college",
-    "golden curry",
-    "golden house",
-    "golden wok",
-    "gonville and caius college",
-    "gonville hotel",
-    "good luck",
-    "gourmet burger kitchen",
-    "graffiti",
-    "grafton hotel restaurant",
-    "great saint marys church",
-    "hakka",
-    "hamilton lodge",
-    "hobsons house",
-    "holy trinity church",
-    "home",
-    "home from home",
-    "hotel",
-    "hotel du vin and bistro",
-    "hughes hall",
-    "huntingdon marriott hotel",
-    "india house",
-    "j restaurant",
-    "jesus college",
-    "jesus green outdoor pool",
-    "jinling noodle bar",
-    "junction theatre",
-    "kambar",
-    "kettles yard",
-    "kings college",
-    "kings hedges learner pool",
-    "kings lynn train station",
-    "kirkwood house",
-    "kohinoor",
-    "kymmoy",
-    "la margherita",
-    "la mimosa",
-    "la raza",
-    "la tasca",
-    "lan hong house",
-    "lensfield hotel",
-    "leverton house",
-    "limehouse",
-    "little saint marys church",
-    "little seoul",
-    "loch fyne",
-    "london kings cross train station",
-    "london liverpool street",
-    "london liverpool street train station",
-    "lovell lodge",
-    "lynne strover gallery",
-    "magdalene college",
-    "mahal of cambridge",
-    "maharajah tandoori restaurant",
-    "meghna",
-    "meze bar",
-    "michaelhouse cafe",
-    "milton country park",
-    "mumford theatre",
-    "museum",
-    "museum of archaelogy and anthropology",
-    "museum of classical archaeology",
-    "nandos",
-    "nandos city centre",
-    "new england",
-    "nirala",
-    "norwich train station",
-    "nstaot mentioned",
-    "nusha",
-    "old schools",
-    "panahar",
-    "parkside police station",
-    "parkside pools",
-    "peking restaurant",
-    "pembroke college",
-    "peoples portraits exhibition at girton college",
-    "peterborough train station",
-    "pizza express",
-    "pizza hut cherry hinton",
-    "pizza hut city centre",
-    "pizza hut fenditton",
-    "prezzo",
-    "primavera",
-    "queens college",
-    "rajmahal",
-    "regency gallery",
-    "restaurant 17",
-    "restaurant 2 two",
-    "restaurant alimentum",
-    "rice boat",
-    "rice house",
-    "riverboat georgina",
-    "riverside brasserie",
-    "rosas bed and breakfast",
-    "royal spice",
-    "royal standard",
-    "ruskin gallery",
-    "saffron brasserie",
-    "saigon city",
-    "saint barnabas press gallery",
-    "saint catharines college",
-    "saint johns chop house",
-    "saint johns college",
-    "sala thong",
-    "scott polar museum",
-    "scudamores punting co",
-    "sesame restaurant and bar",
-    "sheeps green and lammas land park",
-    "sheeps green and lammas land park fen causeway",
-    "shiraz",
-    "sidney sussex college",
-    "sitar tandoori",
-    "soul tree nightclub",
-    "st johns college",
-    "stazione restaurant and coffee bar",
-    "stevenage train station",
-    "taj tandoori",
-    "tandoori palace",
-    "tang chinese",
-    "tenpin",
-    "thanh binh",
-    "the cambridge corn exchange",
-    "the fez club",
-    "the gallery at 12",
-    "the gandhi",
-    "the gardenia",
-    "the hotpot",
-    "the junction",
-    "the lucky star",
-    "the man on the moon",
-    "the missing sock",
-    "the oak bistro",
-    "the place",
-    "the river bar steakhouse and grill",
-    "the slug and lettuce",
-    "the varsity restaurant",
-    "travellers rest",
-    "trinity college",
-    "ugly duckling",
-    "university arms hotel",
-    "vue cinema",
-    "wagamama",
-    "wandlebury country park",
-    "warkworth house",
-    "whale of a time",
-    "whipple museum of the history of science",
-    "williams art and antiques",
-    "worth house",
-    "yippee noodle bar",
-    "yu garden",
-    "zizzi cambridge",
-    "christ's college",
-    "city centre north b and b",
-    "the lensfield hotel",
-    "alpha-milton guest house",
-    "el shaddai",
-    "churchill college",
-    "the cambridge belfry",
-    "king's college",
-    "great saint mary's church",
-    "restaurant two two",
-    "queens' college",
-    "little saint mary's church",
-    "chinese city centre",
-    "kettle's yard",
-    "pizza hut",
-    "the golden curry",
-    "rosa's bed and breakfast",
-    "the cambridge punter",
-    "the byard art museum",
-    "saint catharine's college",
-    "meze bar restaurant",
-    "the good luck chinese food takeaway",
-    "restaurant one seven",
-    "pizza hut fen ditton",
-    "the nirala",
-    "the fitzwilliam museum",
-    "st. john's college",
-    "gallery at twelve a high street",
-    "sheep's green and lammas land park fen causeway",
-    "the cherry hinton village centre",
-    "pizza express fen ditton",
-    "corpus cristi",
-    "cas",
-    "acorn house",
-    "lens",
-    "the cambridge chop house",
-    "the copper kettle",
-    "the avalon",
-    "saint john's college",
-    "aylesbray lodge",
-    "the alexander bed and breakfast",
-    "cambridge belfy",
-    "people's portraits exhibition at girton college",
-    "gonville",
-    "caffe uno",
-    "the cow pizza kitchen and bar",
-    "lovell ldoge",
-    "cinema",
-    "shiraz restaurant",
-    "park",
-    "the allenbell"
-  ],
-  "restaurant-book day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "restaurant-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "restaurant-book time": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "taxi-arrive by": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "restaurant-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west"
-  ],
-  "hotel-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west"
-  ],
-  "attraction-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west"
-  ]
-}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json
deleted file mode 100644
index b0dd00fdf6dc2b824f1f50a44e776c63ce72f14b..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json
+++ /dev/null
@@ -1,3128 +0,0 @@
-{
-  "hotel-price range": [
-    "none",
-    "do not care",
-    "cheap",
-    "expensive",
-    "moderate",
-    "request"
-  ],
-  "hotel-type": [
-    "none",
-    "do not care",
-    "bed and breakfast",
-    "guest house",
-    "hotel",
-    "request"
-  ],
-  "hotel-parking": [
-    "none",
-    "do not care",
-    "no",
-    "yes",
-    "request"
-  ],
-  "hotel-book day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "hotel-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "hotel-book stay": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "train-destination": [
-    "none",
-    "do not care",
-    "bishops stortford",
-    "kings lynn",
-    "london liverpool street",
-    "centre",
-    "bishop stortford",
-    "liverpool",
-    "leicester",
-    "broxbourne",
-    "gourmet burger kitchen",
-    "copper kettle",
-    "bournemouth",
-    "stevenage",
-    "liverpool street",
-    "norwich",
-    "huntingdon marriott hotel",
-    "city centre north",
-    "taj tandoori",
-    "the copper kettle",
-    "peterborough",
-    "ely",
-    "lecester",
-    "london",
-    "willi",
-    "stansted airport",
-    "huntington marriott",
-    "cambridge",
-    "gonv",
-    "glastonbury",
-    "hol",
-    "north",
-    "birmingham new street",
-    "norway",
-    "petersborough",
-    "london kings cross",
-    "curry prince",
-    "bishops storford"
-  ],
-  "train-arrive by": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55",
-    "request"
-  ],
-  "train-departure": [
-    "none",
-    "do not care",
-    "bishops stortford",
-    "kings lynn",
-    "brookshite",
-    "london liverpool street",
-    "cam",
-    "liverpool",
-    "bro",
-    "leicester",
-    "broxbourne",
-    "norwhich",
-    "saint johns",
-    "stevenage",
-    "stansted",
-    "london liverpool",
-    "cambrid",
-    "city hall",
-    "rosas bed and breakfast",
-    "alpha-milton",
-    "wandlebury country park",
-    "norwich",
-    "liecester",
-    "stratford",
-    "peterborough",
-    "duxford",
-    "ely",
-    "london",
-    "stansted airport",
-    "lon",
-    "cambridge",
-    "panahar",
-    "cineworld",
-    "leicaster",
-    "birmingham",
-    "cafe uno",
-    "camboats",
-    "huntingdon",
-    "birmingham new street",
-    "arbu",
-    "alpha milton",
-    "east london",
-    "london kings cross",
-    "hamilton lodge",
-    "aylesbray lodge guest",
-    "el shaddai"
-  ],
-  "train-day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "train-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "hotel-stars": [
-    "none",
-    "do not care",
-    "0",
-    "1",
-    "2",
-    "3",
-    "4",
-    "5",
-    "request"
-  ],
-  "hotel-internet": [
-    "none",
-    "do not care",
-    "no",
-    "yes",
-    "request"
-  ],
-  "hotel-name": [
-    "none",
-    "do not care",
-    "a and b guest house",
-    "city roomz",
-    "carolina bed and breakfast",
-    "limehouse",
-    "anatolia",
-    "hamilton lodge",
-    "the lensfield hotel",
-    "rosa's bed and breakfast",
-    "gall",
-    "aylesbray lodge",
-    "kirkwood",
-    "cambridge belfry",
-    "warkworth house",
-    "gonville",
-    "belfy hotel",
-    "nus",
-    "alexander",
-    "super 5",
-    "aylesbray lodge guest house",
-    "the gonvile hotel",
-    "allenbell",
-    "nothamilton lodge",
-    "ashley hotel",
-    "autumn house",
-    "hobsons house",
-    "hotel",
-    "ashely hotel",
-    "caridge belfrey",
-    "el shaddia guest house",
-    "avalon",
-    "cote",
-    "city centre north bed and breakfast",
-    "the cambridge belfry",
-    "home from home",
-    "wandlebury coutn",
-    "wankworth house",
-    "city stop rest",
-    "the worth house",
-    "cityroomz",
-    "huntingdon marriottt hotel",
-    "lensfield",
-    "rosas bed and breakfast",
-    "leverton house",
-    "gonville hotel",
-    "holiday inn cambridge",
-    "archway house",
-    "lan hon",
-    "levert",
-    "acorn guest house",
-    "cambridge",
-    "the ashley hotel",
-    "el shaddai",
-    "sleeperz",
-    "alpha milton guest house",
-    "doubletree by hilton cambridge",
-    "tandoori palace",
-    "express by",
-    "express by holiday inn cambridge",
-    "north bed and breakfast",
-    "holiday inn",
-    "arbury lodge guest house",
-    "alexander bed and breakfast",
-    "huntingdon marriott hotel",
-    "royal spice",
-    "sou",
-    "finches bed and breakfast",
-    "the alpha milton",
-    "bridge guest house",
-    "the acorn guest house",
-    "kirkwood house",
-    "eraina",
-    "la margherit",
-    "lensfield hotel",
-    "marriott hotel",
-    "nusha",
-    "city centre bed and breakfast",
-    "the allenbell",
-    "university arms hotel",
-    "clare",
-    "cherr",
-    "wartworth",
-    "acorn place",
-    "lovell lodge",
-    "whale"
-  ],
-  "train-leave at": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55",
-    "request"
-  ],
-  "restaurant-price range": [
-    "none",
-    "do not care",
-    "cheap",
-    "expensive",
-    "moderate",
-    "request"
-  ],
-  "restaurant-food": [
-    "none",
-    "do not care",
-    "british food",
-    "steakhouse",
-    "turkish",
-    "sushi",
-    "north american",
-    "scottish",
-    "french",
-    "austrian",
-    "korean",
-    "eastern european",
-    "swedish",
-    "gastro pub",
-    "modern eclectic",
-    "afternoon tea",
-    "welsh",
-    "christmas",
-    "tuscan",
-    "gastropub",
-    "sri lankan",
-    "molecular gastronomy",
-    "traditional american",
-    "italian",
-    "pizza",
-    "thai",
-    "south african",
-    "creative",
-    "english",
-    "asian",
-    "lebanese",
-    "hungarian",
-    "halal",
-    "portugese",
-    "modern english",
-    "african",
-    "light bites",
-    "malaysian",
-    "venetian",
-    "traditional",
-    "chinese",
-    "vegetarian",
-    "persian",
-    "thai and chinese",
-    "scandinavian",
-    "catalan",
-    "polynesian",
-    "crossover",
-    "canapes",
-    "cantonese",
-    "north african",
-    "seafood",
-    "brazilian",
-    "south indian",
-    "australasian",
-    "belgian",
-    "barbeque",
-    "the americas",
-    "indonesian",
-    "singaporean",
-    "irish",
-    "middle eastern",
-    "dojo noodle bar",
-    "caribbean",
-    "vietnamese",
-    "modern european",
-    "russian",
-    "german",
-    "world",
-    "japanese",
-    "moroccan",
-    "modern global",
-    "indian",
-    "british",
-    "american",
-    "danish",
-    "panasian",
-    "swiss",
-    "basque",
-    "north indian",
-    "modern american",
-    "australian",
-    "european",
-    "corsica",
-    "greek",
-    "northern european",
-    "mediterranean",
-    "portuguese",
-    "romanian",
-    "jamaican",
-    "polish",
-    "international",
-    "unusual",
-    "latin american",
-    "asian oriental",
-    "mexican",
-    "bistro",
-    "cuban",
-    "fusion",
-    "new zealand",
-    "spanish",
-    "eritrean",
-    "afghan",
-    "kosher",
-    "request"
-  ],
-  "attraction-name": [
-    "none",
-    "do not care",
-    "downing college",
-    "fitzwilliam",
-    "clare college",
-    "ruskin gallery",
-    "sidney sussex college",
-    "great saint mary's church",
-    "cherry hinton water play park",
-    "wandlebury country park",
-    "cafe uno",
-    "place",
-    "broughton",
-    "cineworld cinema",
-    "jesus college",
-    "vue cinema",
-    "history of science museum",
-    "mumford theatre",
-    "whale of time",
-    "fitzbillies",
-    "christs church",
-    "churchill college",
-    "museum of classical archaeology",
-    "gonville and caius college",
-    "pizza",
-    "kirkwood",
-    "saint catharines college",
-    "kings college",
-    "parkside",
-    "by",
-    "st catharines college",
-    "saint john's college",
-    "cherry hinton water park",
-    "st christs college",
-    "christ's college",
-    "bangkok city",
-    "scudamores punti co",
-    "free",
-    "great saint marys church",
-    "milton country park",
-    "the fez club",
-    "soultree",
-    "autu",
-    "whipple museum of the history of science",
-    "aylesbray lodge guest house",
-    "broughton house gallery",
-    "peoples portraits exhibition",
-    "primavera",
-    "kettles yard",
-    "all saint's church",
-    "cinema cinema",
-    "regency gallery",
-    "corpus christi",
-    "corn cambridge exchange",
-    "da vinci pizzeria",
-    "school",
-    "hobsons house",
-    "cambride and country folk museum",
-    "north",
-    "da v",
-    "cambridge corn exchange",
-    "soul tree nightclub",
-    "cambridge arts theater",
-    "saint catharine's college",
-    "byard art",
-    "cambridge punter",
-    "cambridge university botanic gardens",
-    "castle galleries",
-    "museum of archaelogy and anthropogy",
-    "no specific location",
-    "cherry hinton hall",
-    "gallery at 12 a high street",
-    "parkside pools",
-    "queen's college",
-    "little saint mary's church",
-    "gallery",
-    "home from home",
-    "tenpin",
-    "the wandlebury",
-    "county folk museum",
-    "swimming pool",
-    "christs college",
-    "cafe jello museum",
-    "scott polar",
-    "christ college",
-    "cambridge museum of technology",
-    "abbey pool and astroturf pitch",
-    "king hedges learner pool",
-    "the cambridge arts theatre",
-    "the castle galleries",
-    "cambridge and country folk museum",
-    "kohinoor",
-    "scudamores punting co",
-    "sidney sussex",
-    "the man on the moon",
-    "little saint marys church",
-    "queens",
-    "the place",
-    "old school",
-    "churchill",
-    "churchills college",
-    "hughes hall",
-    "churchhill college",
-    "riverboat georgina",
-    "belf",
-    "cambridge temporary art",
-    "abc theatre",
-    "cambridge contemporary art museum",
-    "man on the moon",
-    "the junction",
-    "cherry hinton water play",
-    "adc theatre",
-    "gonville hotel",
-    "magdalene college",
-    "peoples portraits exhibition at girton college",
-    "boat",
-    "centre",
-    "sheep's green and lammas land park fen causeway",
-    "the mumford theatre",
-    "archway house",
-    "queens' college",
-    "williams art and antiques",
-    "funky fun house",
-    "cherry hinton village centre",
-    "camboats",
-    "cambridge",
-    "old schools",
-    "kettle's yard",
-    "whale of a time",
-    "the churchill college",
-    "cafe jello gallery",
-    "aut",
-    "salsa",
-    "city",
-    "clare hall",
-    "boating",
-    "pembroke college",
-    "kings hedges learner pool",
-    "caffe uno",
-    "lammas land park",
-    "museum",
-    "the fitzwilliam museum",
-    "the cherry hinton village centre",
-    "the cambridge corn exchange",
-    "fitzwilliam museum",
-    "museum of archaelogy and anthropology",
-    "fez club",
-    "the cambridge punter",
-    "saint johns college",
-    "emmanuel college",
-    "cambridge belf",
-    "scudamore",
-    "lynne strover gallery",
-    "king's college",
-    "whippple museum",
-    "trinity college",
-    "college in the north",
-    "sheep's green",
-    "kambar",
-    "museum of archaelogy",
-    "adc",
-    "garde",
-    "club salsa",
-    "people's portraits exhibition at girton college",
-    "botanic gardens",
-    "carol",
-    "college",
-    "gallery at twelve a high street",
-    "abbey pool and astroturf",
-    "cambridge book and print gallery",
-    "jesus green outdoor pool",
-    "scott polar museum",
-    "saint barnabas press gallery",
-    "cambridge artworks",
-    "older churches",
-    "cambridge contemporary art",
-    "cherry hinton hall and grounds",
-    "univ",
-    "jesus green",
-    "ballare",
-    "abbey pool",
-    "cambridge botanic gardens",
-    "nusha",
-    "worth house",
-    "thanh",
-    "university arms hotel",
-    "cambridge arts theatre",
-    "cafe jello",
-    "cambridge and county folk museum",
-    "the cambridge artworks",
-    "all saints church",
-    "holy trinity church",
-    "contemporary art museum",
-    "architectural churches",
-    "queens college",
-    "trinity street college"
-  ],
-  "restaurant-name": [
-    "none",
-    "do not care",
-    "hotel du vin and bistro",
-    "ask",
-    "gourmet formal kitchen",
-    "the meze bar",
-    "lan hong house",
-    "cow pizza",
-    "one seven",
-    "prezzo",
-    "maharajah tandoori restaurant",
-    "alex",
-    "shanghai",
-    "golden wok",
-    "restaurant",
-    "fitzbillies",
-    "nil",
-    "copper kettle",
-    "meghna",
-    "hk fusion",
-    "bangkok city",
-    "hobsons house",
-    "tang chinese",
-    "anatolia",
-    "ugly duckling",
-    "anatolia and efes restaurant",
-    "sitar tandoori",
-    "city stop",
-    "ashley",
-    "pizza express fen ditton",
-    "molecular gastronomy",
-    "autumn house",
-    "el shaddia guesthouse",
-    "the grafton hotel",
-    "limehouse",
-    "gardenia",
-    "not metioned",
-    "hakka",
-    "michaelhouse cafe",
-    "pipasha",
-    "meze bar",
-    "archway",
-    "molecular gastonomy",
-    "yipee noodle bar",
-    "the peking",
-    "curry prince",
-    "midsummer house restaurant",
-    "pizza hut cherry hinton",
-    "the lucky star",
-    "stazione restaurant and coffee bar",
-    "shanghi family restaurant",
-    "good luck",
-    "j restaurant",
-    "bedouin",
-    "cott",
-    "little seoul",
-    "south",
-    "thanh binh",
-    "el",
-    "efes restaurant",
-    "kohinoor",
-    "clowns",
-    "india",
-    "the slug and lettuce",
-    "shiraz",
-    "barbakan",
-    "zizzi cambridge",
-    "restaurant one seven",
-    "slug and lettuce",
-    "travellers rest",
-    "binh",
-    "worth house",
-    "broughton house gallery",
-    "chiquito",
-    "the river bar steakhouse and grill",
-    "ros",
-    "golden house",
-    "india west",
-    "cam",
-    "panahar",
-    "restaurant 22",
-    "adden",
-    "indian",
-    "hu",
-    "jinling noodle bar",
-    "darrys cookhouse and wine shop",
-    "hobson house",
-    "cambridge be",
-    "el shaddai",
-    "ac",
-    "nandos",
-    "cambridge lodge",
-    "the cow pizza kitchen and bar",
-    "charlie",
-    "rajmahal",
-    "kymmoy",
-    "cambri",
-    "backstreet bistro",
-    "galleria",
-    "restaurant 2 two",
-    "chiquito restaurant bar",
-    "royal standard",
-    "lucky star",
-    "curry king",
-    "grafton hotel restaurant",
-    "mahal of cambridge",
-    "the bedouin",
-    "nus",
-    "the kohinoor",
-    "pizza hut fenditton",
-    "camboats",
-    "the gardenia",
-    "de luca cucina and bar",
-    "nusha",
-    "european",
-    "taj tandoori",
-    "tandoori palace",
-    "golden curry",
-    "efes",
-    "loch fyne",
-    "the maharajah tandoor",
-    "lovel",
-    "restaurant 17",
-    "clowns cafe",
-    "cambridge punter",
-    "bloomsbury restaurant",
-    "la mimosa",
-    "the cambridge chop house",
-    "funky",
-    "cotto",
-    "oak bistro",
-    "restaurant two two",
-    "pipasha restaurant",
-    "river bar steakhouse and grill",
-    "royal spice",
-    "the copper kettle",
-    "graffiti",
-    "nandos city centre",
-    "saffron brasserie",
-    "cambridge chop house",
-    "sitar",
-    "kitchen and bar",
-    "the good luck chinese food takeaway",
-    "clu",
-    "la tasca",
-    "cafe uno",
-    "cote",
-    "the varsity restaurant",
-    "bri",
-    "eraina",
-    "bridge",
-    "fin",
-    "cambridge lodge restaurant",
-    "grafton",
-    "hotpot",
-    "sala thong",
-    "margherita",
-    "wise buddha",
-    "the missing sock",
-    "seasame restaurant and bar",
-    "the dojo noodle bar",
-    "restaurant alimentum",
-    "gastropub",
-    "saigon city",
-    "la margherita",
-    "pizza hut",
-    "curry garden",
-    "ashley hotel",
-    "eraina and michaelhouse cafe",
-    "the golden curry",
-    "curry queen",
-    "cow pizza kitchen and bar",
-    "the peking restaurant:",
-    "hamilton lodge",
-    "alimentum",
-    "yippee noodle bar",
-    "2 two and cote",
-    "shanghai family restaurant",
-    "grafton hotel",
-    "yes",
-    "ali baba",
-    "dif",
-    "fitzbillies restaurant",
-    "peking restaurant",
-    "lev",
-    "nirala",
-    "the alex",
-    "tandoori",
-    "city stop restaurant",
-    "rice house",
-    "cityr",
-    "yu garden",
-    "meze bar restaurant",
-    "the",
-    "don pasquale pizzeria",
-    "rice boat",
-    "the hotpot",
-    "old school",
-    "the oak bistro",
-    "sesame restaurant and bar",
-    "pizza express",
-    "the gandhi",
-    "pizza hut fen ditton",
-    "charlie chan",
-    "da vinci pizzeria",
-    "dojo noodle bar",
-    "gourmet burger kitchen",
-    "the golden house",
-    "india house",
-    "hobso",
-    "missing sock",
-    "pizza hut city centre",
-    "parkside pools",
-    "riverside brasserie",
-    "caffe uno",
-    "primavera",
-    "the nirala",
-    "wagamama",
-    "au",
-    "ian hong house",
-    "frankie and bennys",
-    "4 kings parade city centre",
-    "shiraz restaurant",
-    "scudamores punt",
-    "mahal",
-    "saint johns chop house",
-    "de luca cucina and bar riverside brasserie",
-    "cocum",
-    "la raza"
-  ],
-  "attraction-type": [
-    "none",
-    "do not care",
-    "architecture",
-    "boat",
-    "boating",
-    "camboats",
-    "church",
-    "churchills college",
-    "cinema",
-    "college",
-    "concert",
-    "concerthall",
-    "entertainment",
-    "gallery",
-    "gastropub",
-    "hiking",
-    "hotel",
-    "multiple sports",
-    "museum",
-    "museum kettles yard",
-    "night club",
-    "outdoor",
-    "park",
-    "pool",
-    "special",
-    "sports",
-    "swimming pool",
-    "theater",
-    "theatre",
-    "concert hall",
-    "local site",
-    "nightclub",
-    "hotspot",
-    "request"
-  ],
-  "taxi-leave at": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55",
-    "request"
-  ],
-  "taxi-destination": [
-    "none",
-    "do not care",
-    "a and b guest house",
-    "abbey pool and astroturf pitch",
-    "acorn guest house",
-    "adc theatre",
-    "addenbrookes hospital",
-    "alexander bed and breakfast",
-    "ali baba",
-    "all saints church",
-    "allenbell",
-    "alpha milton guest house",
-    "anatolia",
-    "arbury lodge guesthouse",
-    "archway house",
-    "ashley hotel",
-    "ask",
-    "attraction",
-    "autumn house",
-    "avalon",
-    "aylesbray lodge guest house",
-    "backstreet bistro",
-    "ballare",
-    "bangkok city",
-    "bedouin",
-    "birmingham new street train station",
-    "bishops stortford train station",
-    "bloomsbury restaurant",
-    "bridge guest house",
-    "broughton house gallery",
-    "broxbourne train station",
-    "byard art",
-    "cafe jello gallery",
-    "cafe uno",
-    "camboats",
-    "cambridge",
-    "cambridge and county folk museum",
-    "cambridge arts theatre",
-    "cambridge artworks",
-    "cambridge belfry",
-    "cambridge book and print gallery",
-    "cambridge chop house",
-    "cambridge contemporary art",
-    "cambridge county fair next to the city tourist museum",
-    "cambridge lodge restaurant",
-    "cambridge museum of technology",
-    "cambridge punter",
-    "cambridge road church of christ",
-    "cambridge train station",
-    "cambridge university botanic gardens",
-    "carolina bed and breakfast",
-    "castle galleries",
-    "charlie chan",
-    "cherry hinton hall and grounds",
-    "cherry hinton village centre",
-    "cherry hinton water park",
-    "cherry hinton water play",
-    "chiquito restaurant bar",
-    "christ college",
-    "churchills college",
-    "cineworld cinema",
-    "city centre north bed and breakfast",
-    "city stop restaurant",
-    "cityroomz",
-    "clare college",
-    "clare hall",
-    "clowns cafe",
-    "club salsa",
-    "cocum",
-    "copper kettle",
-    "corpus christi",
-    "cote",
-    "cotto",
-    "cow pizza kitchen and bar",
-    "curry garden",
-    "curry king",
-    "curry prince",
-    "da vinci pizzeria",
-    "darrys cookhouse and wine shop",
-    "de luca cucina and bar",
-    "dojo noodle bar",
-    "don pasquale pizzeria",
-    "downing college",
-    "efes restaurant",
-    "el shaddia guesthouse",
-    "ely train station",
-    "emmanuel college",
-    "eraina",
-    "express by holiday inn cambridge",
-    "finches bed and breakfast",
-    "finders corner newmarket road",
-    "fitzbillies restaurant",
-    "fitzwilliam museum",
-    "frankie and bennys",
-    "funky fun house",
-    "galleria",
-    "gallery at 12 a high street",
-    "gastropub",
-    "golden curry",
-    "golden house",
-    "golden wok",
-    "gonville and caius college",
-    "gonville hotel",
-    "good luck",
-    "gourmet burger kitchen",
-    "graffiti",
-    "grafton hotel restaurant",
-    "great saint marys church",
-    "hakka",
-    "hamilton lodge",
-    "hk fusion",
-    "hobsons house",
-    "holy trinity church",
-    "home from home",
-    "hotel du vin and bistro",
-    "hughes hall",
-    "huntingdon marriott hotel",
-    "ian hong",
-    "india house",
-    "j restaurant",
-    "jesus college",
-    "jesus green outdoor pool",
-    "jinling noodle bar",
-    "kambar",
-    "kettles yard",
-    "kings college",
-    "kings hedges learner pool",
-    "kirkwood house",
-    "kohinoor",
-    "kymmoy",
-    "la margherita",
-    "la mimosa",
-    "la raza",
-    "la tasca",
-    "lan hong house",
-    "leicester train station",
-    "lensfield hotel",
-    "limehouse",
-    "little saint marys church",
-    "little seoul",
-    "loch fyne",
-    "london kings cross train station",
-    "london liverpool street train station",
-    "lovell lodge",
-    "lynne strover gallery",
-    "magdalene college",
-    "mahal of cambridge",
-    "maharajah tandoori restaurant",
-    "meghna",
-    "meze bar",
-    "michaelhouse cafe",
-    "midsummer house restaurant",
-    "milton country park",
-    "mumford theatre",
-    "museum of archaelogy and anthropology",
-    "museum of classical archaeology",
-    "nandos",
-    "nandos city centre",
-    "nil",
-    "nirala",
-    "norwich train station",
-    "nusha",
-    "old schools",
-    "panahar",
-    "parkside police station",
-    "parkside pools",
-    "peking restaurant",
-    "pembroke college",
-    "peoples portraits exhibition at girton college",
-    "peterborough train station",
-    "pipasha restaurant",
-    "pizza express",
-    "pizza hut cherry hinton",
-    "pizza hut city centre",
-    "pizza hut fenditton",
-    "prezzo",
-    "primavera",
-    "queens college",
-    "rajmahal",
-    "regency gallery",
-    "restaurant 17",
-    "restaurant 2 two",
-    "restaurant alimentum",
-    "rice boat",
-    "rice house",
-    "riverboat georgina",
-    "riverside brasserie",
-    "rosas bed and breakfast",
-    "royal spice",
-    "royal standard",
-    "ruskin gallery",
-    "saffron brasserie",
-    "saigon city",
-    "saint barnabas",
-    "saint barnabas press gallery",
-    "saint catharines college",
-    "saint johns chop house",
-    "saint johns college",
-    "sala thong",
-    "scott polar museum",
-    "scudamores punting co",
-    "sesame restaurant and bar",
-    "shanghai family restaurant",
-    "sheeps green and lammas land park fen causeway",
-    "shiraz",
-    "sidney sussex college",
-    "sitar tandoori",
-    "sleeperz hotel",
-    "soul tree nightclub",
-    "st johns chop house",
-    "stansted airport train station",
-    "station road",
-    "stazione restaurant and coffee bar",
-    "stevenage train station",
-    "taj tandoori",
-    "tall monument",
-    "tandoori palace",
-    "tang chinese",
-    "tenpin",
-    "thanh binh",
-    "the anatolia",
-    "the cambridge corn exchange",
-    "the cambridge shop",
-    "the fez club",
-    "the gandhi",
-    "the gardenia",
-    "the hotpot",
-    "the junction",
-    "the lucky star",
-    "the man on the moon",
-    "the missing sock",
-    "the oak bistro",
-    "the place",
-    "the regent street city center",
-    "the river bar steakhouse and grill",
-    "the slug and lettuce",
-    "the varsity restaurant",
-    "travellers rest",
-    "trinity college",
-    "ugly duckling",
-    "university arms hotel",
-    "vue cinema",
-    "wagamama",
-    "wandlebury country park",
-    "wankworth hotel",
-    "warkworth house",
-    "whale of a time",
-    "whipple museum of the history of science",
-    "williams art and antiques",
-    "worth house",
-    "yippee noodle bar",
-    "yu garden",
-    "zizzi cambridge",
-    "leverton house",
-    "the cambridge chop house",
-    "saint john's college",
-    "churchill college",
-    "the nirala",
-    "the cow pizza kitchen and bar",
-    "christ's college",
-    "el shaddai",
-    "saint catharine's college",
-    "camb",
-    "the golden curry",
-    "little saint mary's church",
-    "country folk museum",
-    "meze bar restaurant",
-    "the cambridge belfry",
-    "the fitzwilliam museum",
-    "the lensfield hotel",
-    "pizza express fen ditton",
-    "the cambridge punter",
-    "king's college",
-    "the cherry hinton village centre",
-    "shiraz restaurant",
-    "sheep's green and lammas land park fen causeway",
-    "caffe uno",
-    "the ghandi",
-    "the copper kettle",
-    "man on the moon concert hall",
-    "alpha-milton guest house",
-    "queen's college",
-    "restaurant one seven",
-    "restaurant two two",
-    "city centre north b and b",
-    "rosa's bed and breakfast",
-    "the good luck chinese food takeaway",
-    "not museum of archaeology and anthropologymentioned",
-    "tandori in cambridge",
-    "kettle's yard",
-    "megna",
-    "grou",
-    "gallery at twelve a high street",
-    "maharajah tandoori restaurant",
-    "pizza hut fen ditton",
-    "gandhi",
-    "tranh binh",
-    "kambur",
-    "people's portraits exhibition at girton college",
-    "hotel",
-    "restaurant",
-    "the galleria",
-    "queens' college",
-    "great saint mary's church",
-    "theathre",
-    "cambridge artworks",
-    "acorn house",
-    "shiraz",
-    "riverboat georginawd",
-    "mic",
-    "the gallery at twelve",
-    "the soul tree",
-    "finches"
-  ],
-  "taxi-departure": [
-    "none",
-    "do not care",
-    "172 chestertown road",
-    "4455 woodbridge road",
-    "a and b guest house",
-    "abbey pool and astroturf pitch",
-    "acorn guest house",
-    "adc theatre",
-    "addenbrookes hospital",
-    "alexander bed and breakfast",
-    "ali baba",
-    "all saints church",
-    "allenbell",
-    "alpha milton guest house",
-    "alyesbray lodge hotel",
-    "ambridge",
-    "anatolia",
-    "arbury lodge guesthouse",
-    "archway house",
-    "ashley hotel",
-    "ask",
-    "autumn house",
-    "avalon",
-    "aylesbray lodge guest house",
-    "backstreet bistro",
-    "ballare",
-    "bangkok city",
-    "bedouin",
-    "birmingham new street train station",
-    "bishops stortford train station",
-    "bloomsbury restaurant",
-    "bridge guest house",
-    "broughton house gallery",
-    "broxbourne train station",
-    "byard art",
-    "cafe jello gallery",
-    "cafe uno",
-    "caffee uno",
-    "camboats",
-    "cambridge",
-    "cambridge and county folk museum",
-    "cambridge arts theatre",
-    "cambridge artworks",
-    "cambridge belfry",
-    "cambridge book and print gallery",
-    "cambridge chop house",
-    "cambridge contemporary art",
-    "cambridge lodge restaurant",
-    "cambridge museum of technology",
-    "cambridge punter",
-    "cambridge towninfo centre",
-    "cambridge train station",
-    "cambridge university botanic gardens",
-    "carolina bed and breakfast",
-    "castle galleries",
-    "centre of town at my hotel",
-    "charlie chan",
-    "cherry hinton hall and grounds",
-    "cherry hinton village center",
-    "cherry hinton village centre",
-    "cherry hinton water play",
-    "chiquito restaurant bar",
-    "christ college",
-    "churchills college",
-    "cineworld cinema",
-    "citiroomz",
-    "city centre north bed and breakfast",
-    "city stop restaurant",
-    "cityroomz",
-    "clair hall",
-    "clare college",
-    "clare hall",
-    "clowns cafe",
-    "club salsa",
-    "cocum",
-    "copper kettle",
-    "corpus christi",
-    "cote",
-    "cotto",
-    "cow pizza kitchen and bar",
-    "curry garden",
-    "curry king",
-    "curry prince",
-    "curry queen",
-    "da vinci pizzeria",
-    "darrys cookhouse and wine shop",
-    "de luca cucina and bar",
-    "dojo noodle bar",
-    "don pasquale pizzeria",
-    "downing college",
-    "downing street",
-    "el shaddia guesthouse",
-    "ely",
-    "ely train station",
-    "emmanuel college",
-    "eraina",
-    "express by holiday inn cambridge",
-    "finches bed and breakfast",
-    "fitzbillies restaurant",
-    "fitzwilliam museum",
-    "frankie and bennys",
-    "funky fun house",
-    "galleria",
-    "gallery at 12 a high street",
-    "girton college",
-    "golden curry",
-    "golden house",
-    "golden wok",
-    "gonville and caius college",
-    "gonville hotel",
-    "good luck",
-    "gourmet burger kitchen",
-    "graffiti",
-    "grafton hotel restaurant",
-    "great saint marys church",
-    "hakka",
-    "hamilton lodge",
-    "hobsons house",
-    "holy trinity church",
-    "home",
-    "home from home",
-    "hotel",
-    "hotel du vin and bistro",
-    "hughes hall",
-    "huntingdon marriott hotel",
-    "india house",
-    "j restaurant",
-    "jesus college",
-    "jesus green outdoor pool",
-    "jinling noodle bar",
-    "junction theatre",
-    "kambar",
-    "kettles yard",
-    "kings college",
-    "kings hedges learner pool",
-    "kings lynn train station",
-    "kirkwood house",
-    "kohinoor",
-    "kymmoy",
-    "la margherita",
-    "la mimosa",
-    "la raza",
-    "la tasca",
-    "lan hong house",
-    "lensfield hotel",
-    "leverton house",
-    "limehouse",
-    "little saint marys church",
-    "little seoul",
-    "loch fyne",
-    "london kings cross train station",
-    "london liverpool street",
-    "london liverpool street train station",
-    "lovell lodge",
-    "lynne strover gallery",
-    "magdalene college",
-    "mahal of cambridge",
-    "maharajah tandoori restaurant",
-    "meghna",
-    "meze bar",
-    "michaelhouse cafe",
-    "milton country park",
-    "mumford theatre",
-    "museum",
-    "museum of archaelogy and anthropology",
-    "museum of classical archaeology",
-    "nandos",
-    "nandos city centre",
-    "new england",
-    "nirala",
-    "norwich train station",
-    "nstaot mentioned",
-    "nusha",
-    "old schools",
-    "panahar",
-    "parkside police station",
-    "parkside pools",
-    "peking restaurant",
-    "pembroke college",
-    "peoples portraits exhibition at girton college",
-    "peterborough train station",
-    "pizza express",
-    "pizza hut cherry hinton",
-    "pizza hut city centre",
-    "pizza hut fenditton",
-    "prezzo",
-    "primavera",
-    "queens college",
-    "rajmahal",
-    "regency gallery",
-    "restaurant 17",
-    "restaurant 2 two",
-    "restaurant alimentum",
-    "rice boat",
-    "rice house",
-    "riverboat georgina",
-    "riverside brasserie",
-    "rosas bed and breakfast",
-    "royal spice",
-    "royal standard",
-    "ruskin gallery",
-    "saffron brasserie",
-    "saigon city",
-    "saint barnabas press gallery",
-    "saint catharines college",
-    "saint johns chop house",
-    "saint johns college",
-    "sala thong",
-    "scott polar museum",
-    "scudamores punting co",
-    "sesame restaurant and bar",
-    "sheeps green and lammas land park",
-    "sheeps green and lammas land park fen causeway",
-    "shiraz",
-    "sidney sussex college",
-    "sitar tandoori",
-    "soul tree nightclub",
-    "st johns college",
-    "stazione restaurant and coffee bar",
-    "stevenage train station",
-    "taj tandoori",
-    "tandoori palace",
-    "tang chinese",
-    "tenpin",
-    "thanh binh",
-    "the cambridge corn exchange",
-    "the fez club",
-    "the gallery at 12",
-    "the gandhi",
-    "the gardenia",
-    "the hotpot",
-    "the junction",
-    "the lucky star",
-    "the man on the moon",
-    "the missing sock",
-    "the oak bistro",
-    "the place",
-    "the river bar steakhouse and grill",
-    "the slug and lettuce",
-    "the varsity restaurant",
-    "travellers rest",
-    "trinity college",
-    "ugly duckling",
-    "university arms hotel",
-    "vue cinema",
-    "wagamama",
-    "wandlebury country park",
-    "warkworth house",
-    "whale of a time",
-    "whipple museum of the history of science",
-    "williams art and antiques",
-    "worth house",
-    "yippee noodle bar",
-    "yu garden",
-    "zizzi cambridge",
-    "christ's college",
-    "city centre north b and b",
-    "the lensfield hotel",
-    "alpha-milton guest house",
-    "el shaddai",
-    "churchill college",
-    "the cambridge belfry",
-    "king's college",
-    "great saint mary's church",
-    "restaurant two two",
-    "queens' college",
-    "little saint mary's church",
-    "chinese city centre",
-    "kettle's yard",
-    "pizza hut",
-    "the golden curry",
-    "rosa's bed and breakfast",
-    "the cambridge punter",
-    "the byard art museum",
-    "saint catharine's college",
-    "meze bar restaurant",
-    "the good luck chinese food takeaway",
-    "restaurant one seven",
-    "pizza hut fen ditton",
-    "the nirala",
-    "the fitzwilliam museum",
-    "st. john's college",
-    "gallery at twelve a high street",
-    "sheep's green and lammas land park fen causeway",
-    "the cherry hinton village centre",
-    "pizza express fen ditton",
-    "corpus cristi",
-    "cas",
-    "acorn house",
-    "lens",
-    "the cambridge chop house",
-    "the copper kettle",
-    "the avalon",
-    "saint john's college",
-    "aylesbray lodge",
-    "the alexander bed and breakfast",
-    "cambridge belfy",
-    "people's portraits exhibition at girton college",
-    "gonville",
-    "caffe uno",
-    "the cow pizza kitchen and bar",
-    "lovell ldoge",
-    "cinema",
-    "shiraz restaurant",
-    "park",
-    "the allenbell"
-  ],
-  "restaurant-book day": [
-    "none",
-    "do not care",
-    "friday",
-    "monday",
-    "saterday",
-    "sunday",
-    "thursday",
-    "tuesday",
-    "wednesday"
-  ],
-  "restaurant-book people": [
-    "none",
-    "do not care",
-    "1",
-    "10 or more",
-    "2",
-    "3",
-    "4",
-    "5",
-    "6",
-    "7",
-    "8",
-    "9"
-  ],
-  "restaurant-book time": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55"
-  ],
-  "taxi-arrive by": [
-    "none",
-    "do not care",
-    "00:00",
-    "00:05",
-    "00:10",
-    "00:15",
-    "00:20",
-    "00:25",
-    "00:30",
-    "00:35",
-    "00:40",
-    "00:45",
-    "00:50",
-    "00:55",
-    "01:00",
-    "01:05",
-    "01:10",
-    "01:15",
-    "01:20",
-    "01:25",
-    "01:30",
-    "01:35",
-    "01:40",
-    "01:45",
-    "01:50",
-    "01:55",
-    "02:00",
-    "02:05",
-    "02:10",
-    "02:15",
-    "02:20",
-    "02:25",
-    "02:30",
-    "02:35",
-    "02:40",
-    "02:45",
-    "02:50",
-    "02:55",
-    "03:00",
-    "03:05",
-    "03:10",
-    "03:15",
-    "03:20",
-    "03:25",
-    "03:30",
-    "03:35",
-    "03:40",
-    "03:45",
-    "03:50",
-    "03:55",
-    "04:00",
-    "04:05",
-    "04:10",
-    "04:15",
-    "04:20",
-    "04:25",
-    "04:30",
-    "04:35",
-    "04:40",
-    "04:45",
-    "04:50",
-    "04:55",
-    "05:00",
-    "05:05",
-    "05:10",
-    "05:15",
-    "05:20",
-    "05:25",
-    "05:30",
-    "05:35",
-    "05:40",
-    "05:45",
-    "05:50",
-    "05:55",
-    "06:00",
-    "06:05",
-    "06:10",
-    "06:15",
-    "06:20",
-    "06:25",
-    "06:30",
-    "06:35",
-    "06:40",
-    "06:45",
-    "06:50",
-    "06:55",
-    "07:00",
-    "07:05",
-    "07:10",
-    "07:15",
-    "07:20",
-    "07:25",
-    "07:30",
-    "07:35",
-    "07:40",
-    "07:45",
-    "07:50",
-    "07:55",
-    "08:00",
-    "08:05",
-    "08:10",
-    "08:15",
-    "08:20",
-    "08:25",
-    "08:30",
-    "08:35",
-    "08:40",
-    "08:45",
-    "08:50",
-    "08:55",
-    "09:00",
-    "09:05",
-    "09:10",
-    "09:15",
-    "09:20",
-    "09:25",
-    "09:30",
-    "09:35",
-    "09:40",
-    "09:45",
-    "09:50",
-    "09:55",
-    "10:00",
-    "10:05",
-    "10:10",
-    "10:15",
-    "10:20",
-    "10:25",
-    "10:30",
-    "10:35",
-    "10:40",
-    "10:45",
-    "10:50",
-    "10:55",
-    "11:00",
-    "11:05",
-    "11:10",
-    "11:15",
-    "11:20",
-    "11:25",
-    "11:30",
-    "11:35",
-    "11:40",
-    "11:45",
-    "11:50",
-    "11:55",
-    "12:00",
-    "12:05",
-    "12:10",
-    "12:15",
-    "12:20",
-    "12:25",
-    "12:30",
-    "12:35",
-    "12:40",
-    "12:45",
-    "12:50",
-    "12:55",
-    "13:00",
-    "13:05",
-    "13:10",
-    "13:15",
-    "13:20",
-    "13:25",
-    "13:30",
-    "13:35",
-    "13:40",
-    "13:45",
-    "13:50",
-    "13:55",
-    "14:00",
-    "14:05",
-    "14:10",
-    "14:15",
-    "14:20",
-    "14:25",
-    "14:30",
-    "14:35",
-    "14:40",
-    "14:45",
-    "14:50",
-    "14:55",
-    "15:00",
-    "15:05",
-    "15:10",
-    "15:15",
-    "15:20",
-    "15:25",
-    "15:30",
-    "15:35",
-    "15:40",
-    "15:45",
-    "15:50",
-    "15:55",
-    "16:00",
-    "16:05",
-    "16:10",
-    "16:15",
-    "16:20",
-    "16:25",
-    "16:30",
-    "16:35",
-    "16:40",
-    "16:45",
-    "16:50",
-    "16:55",
-    "17:00",
-    "17:05",
-    "17:10",
-    "17:15",
-    "17:20",
-    "17:25",
-    "17:30",
-    "17:35",
-    "17:40",
-    "17:45",
-    "17:50",
-    "17:55",
-    "18:00",
-    "18:05",
-    "18:10",
-    "18:15",
-    "18:20",
-    "18:25",
-    "18:30",
-    "18:35",
-    "18:40",
-    "18:45",
-    "18:50",
-    "18:55",
-    "19:00",
-    "19:05",
-    "19:10",
-    "19:15",
-    "19:20",
-    "19:25",
-    "19:30",
-    "19:35",
-    "19:40",
-    "19:45",
-    "19:50",
-    "19:55",
-    "20:00",
-    "20:05",
-    "20:10",
-    "20:15",
-    "20:20",
-    "20:25",
-    "20:30",
-    "20:35",
-    "20:40",
-    "20:45",
-    "20:50",
-    "20:55",
-    "21:00",
-    "21:05",
-    "21:10",
-    "21:15",
-    "21:20",
-    "21:25",
-    "21:30",
-    "21:35",
-    "21:40",
-    "21:45",
-    "21:50",
-    "21:55",
-    "22:00",
-    "22:05",
-    "22:10",
-    "22:15",
-    "22:20",
-    "22:25",
-    "22:30",
-    "22:35",
-    "22:40",
-    "22:45",
-    "22:50",
-    "22:55",
-    "23:00",
-    "23:05",
-    "23:10",
-    "23:15",
-    "23:20",
-    "23:25",
-    "23:30",
-    "23:35",
-    "23:40",
-    "23:45",
-    "23:50",
-    "23:55",
-    "request"
-  ],
-  "restaurant-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west",
-    "request"
-  ],
-  "hotel-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west",
-    "request"
-  ],
-  "attraction-area": [
-    "none",
-    "do not care",
-    "centre",
-    "east",
-    "north",
-    "south",
-    "west",
-    "request"
-  ],
-  "hospital-department": [
-    "none",
-    "do not care",
-    "acute medical assessment unit",
-    "acute medicine for the elderly",
-    "antenatal",
-    "cambridge eye unit",
-    "cardiology",
-    "cardiology and coronary care unit",
-    "childrens oncology and haematology",
-    "childrens surgical and medicine",
-    "clinical decisions unit",
-    "clinical research facility",
-    "coronary care unit",
-    "diabetes and endocrinology",
-    "emergency department",
-    "gastroenterology",
-    "gynaecology",
-    "haematology",
-    "haematology and haematological oncology",
-    "haematology day unit",
-    "hepatobillary and gastrointestinal surgery regional referral centre",
-    "hepatology",
-    "infectious diseases",
-    "infusion services",
-    "inpatient occupational therapy",
-    "intermediate dependancy area",
-    "john farman intensive care unit",
-    "medical decisions unit",
-    "medicine for the elderly",
-    "neonatal unit",
-    "neurology",
-    "neurology neurosurgery",
-    "neurosciences",
-    "neurosciences critical care unit",
-    "oncology",
-    "oral and maxillofacial surgery and ent",
-    "paediatric clinic",
-    "paediatric day unit",
-    "paediatric intensive care unit",
-    "plastic and vascular surgery plastics",
-    "psychiatry",
-    "respiratory medicine",
-    "surgery",
-    "teenage cancer trust unit",
-    "transitional care",
-    "transplant high dependency unit",
-    "trauma and orthopaedics",
-    "trauma high dependency unit",
-    "urology"
-  ],
-  "police-postcode": [
-    "request"
-  ],
-  "restaurant-postcode": [
-    "request"
-  ],
-  "train-duration": [
-    "request"
-  ],
-  "train-trainid": [
-    "request"
-  ],
-  "hospital-address": [
-    "request"
-  ],
-  "restaurant-phone": [
-    "request"
-  ],
-  "hotel-phone": [
-    "request"
-  ],
-  "restaurant-address": [
-    "request"
-  ],
-  "hotel-postcode": [
-    "request"
-  ],
-  "attraction-phone": [
-    "request"
-  ],
-  "attraction-entrance fee": [
-    "request"
-  ],
-  "hotel-reference": [
-    "request"
-  ],
-  "taxi-taxi types": [
-    "request"
-  ],
-  "attraction-address": [
-    "request"
-  ],
-  "hospital-phone": [
-    "request"
-  ],
-  "attraction-postcode": [
-    "request"
-  ],
-  "police-address": [
-    "request"
-  ],
-  "taxi-taxi phone": [
-    "request"
-  ],
-  "train-price": [
-    "request"
-  ],
-  "hospital-postcode": [
-    "request"
-  ],
-  "police-phone": [
-    "request"
-  ],
-  "hotel-address": [
-    "request"
-  ],
-  "restaurant-reference": [
-    "request"
-  ],
-  "train-reference": [
-    "request"
-  ]
-}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json
deleted file mode 100644
index 87e315363ad3dfd8cadd5bd10cfd7e5047450160..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json
+++ /dev/null
@@ -1,57 +0,0 @@
-{
-  "hotel-price range": "preferred cost or price of the hotel",
-  "hotel-type": "what is the type of the hotel",
-  "hotel-parking": "does the hotel have parking",
-  "hotel-book stay": "number of nights for the hotel reservation",
-  "hotel-book day": "starting day of the hotel booking",
-  "hotel-book people": "number of people for the hotel booking",
-  "hotel-area": "area or place of the hotel",
-  "hotel-stars": "star rating of the hotel",
-  "hotel-internet": "does the hotel have internet or wifi",
-  "hotel-name": "name of the hotel",
-  "hotel-phone": "phone number of the hotel",
-  "hotel-postcode": "postcode of the hotel",
-  "hotel-reference": "booking reference of the hotel booking",
-  "hotel-address": "street address of the hotel",
-  "train-destination": "train station you want to travel to",
-  "train-day": "day of the train booking",
-  "train-departure": "train station you want to leave from",
-  "train-arrive by": "arrival time of the train",
-  "train-book people": "number of people for the train booking",
-  "train-leave at": "departure time for the train",
-  "train-duration": "duration of the train journey",
-  "train-trainid": "train identifier or number",
-  "train-price": "how much does the train trip cost",
-  "train-reference": "booking reference of the train booking",
-  "attraction-type": "type of attraction or point of interest",
-  "attraction-area": "area or place of the attraction",
-  "attraction-name": "name of the attraction",
-  "attraction-phone": "phone number of the attraction",
-  "attraction-entrance fee": "entrace fee at the attraction",
-  "attraction-address": "street address of the attraction",
-  "attraction-postcode": "postcode of the attraction",
-  "restaurant-book people": "number of people for the restaurant booking",
-  "restaurant-book day": "weekday for the restaurant booking",
-  "restaurant-book time": "time of the restaurant booking",
-  "restaurant-food": "type of food served at the restaurant",
-  "restaurant-price range": "preferred cost or price of the restaurant",
-  "restaurant-name": "name of the restaurant",
-  "restaurant-area": "area or place of the restaurant",
-  "restaurant-postcode": "postcode of the restaurant",
-  "restaurant-phone": "phone number of the restaurant",
-  "restaurant-address": "street address of the restaurant",
-  "restaurant-reference": "booking reference of the hotel booking",
-  "taxi-leave at": "what time you want the taxi to leave by",
-  "taxi-destination": "where you want the taxi to drop you off",
-  "taxi-departure": "where you want the taxi to pick you up",
-  "taxi-arrive by": "what time you to arrive at your destination",
-  "taxi-taxi types": "vehicle type of the taxi",
-  "taxi-taxi phone": "phone number of the taxi",
-  "hospital-department": "name of hospital department",
-  "hospital-address": "street address of the hospital",
-  "hospital-phone": "phone number of the hospital",
-  "hospital-postcode": "postcode of the hospital",
-  "police-postcode": "postcode of the police station",
-  "police-address": "street address of the police station",
-  "police-phone": "phone number of the police station"
-}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/ontology.py b/convlab/dst/setsumbt/multiwoz/dataset/ontology.py
deleted file mode 100644
index c6b9c3365764eb8f1c1cab5d68674dd85b39ce2a..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/ontology.py
+++ /dev/null
@@ -1,168 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Create Ontology Embeddings"""
-
-import json
-import os
-import random
-
-import torch
-import numpy as np
-
-
-# Slot mapping table for description extractions
-# SLOT_NAME_MAPPINGS = {
-#     'arrive at': 'arriveAt',
-#     'arrive by': 'arriveBy',
-#     'leave at': 'leaveAt',
-#     'leave by': 'leaveBy',
-#     'arriveby': 'arriveBy',
-#     'arriveat': 'arriveAt',
-#     'leaveat': 'leaveAt',
-#     'leaveby': 'leaveBy',
-#     'price range': 'pricerange'
-# }
-
-# Set up global data directory
-def set_datadir(dir):
-    global DATA_DIR
-    DATA_DIR = dir
-
-
-# Set seeds
-def set_seed(args):
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-
-# Get embeddings for slots and candidates
-def get_slot_candidate_embeddings(set_type, args, tokenizer, embedding_model, save_to_file=True):
-    # Get set alots and candidates
-    reader = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r')
-    ontology = json.load(reader)
-    reader.close()
-
-    reader = open(os.path.join(DATA_DIR, 'slot_descriptions.json'), 'r')
-    slot_descriptions = json.load(reader)
-    reader.close()
-
-    embedding_model.eval()
-
-    slots = dict()
-    for slot in ontology:
-        if args.use_descriptions:
-            # d, s = slot.split('-', 1)
-            # s = SLOT_NAME_MAPPINGS[s] if s in SLOT_NAME_MAPPINGS else s
-            # s = d + '-' + s
-            # if slot in slot_descriptions:
-            desc = slot_descriptions[slot]
-            # elif slot.lower() in slot_descriptions:
-            #     desc = slot_descriptions[s.lower()]
-            # else:
-            #     desc = slot.replace('-', ' ')
-        else:
-            desc = slot
-
-        # Tokenize slot and get embeddings
-        feats = tokenizer.encode_plus(desc, add_special_tokens = True,
-                                            max_length = args.max_slot_len, padding='max_length',
-                                            truncation = 'longest_first')
-
-        with torch.no_grad():
-            input_ids = torch.tensor([feats['input_ids']]).to(embedding_model.device) # [1, max_slot_len]
-            if 'token_type_ids' in feats:
-                token_type_ids = torch.tensor([feats['token_type_ids']]).to(embedding_model.device) # [1, max_slot_len]
-                if 'attention_mask' in feats:
-                    attention_mask = torch.tensor([feats['attention_mask']]).to(embedding_model.device) # [1, max_slot_len]
-                    feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids,
-                                            attention_mask=attention_mask)
-                    attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1)))
-                    feats = feats * attention_mask # [1, max_slot_len, hidden_dim]
-                else:
-                    feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids)
-            else:
-                if 'attention_mask' in feats:
-                    attention_mask = torch.tensor([feats['attention_mask']]).to(embedding_model.device)
-                    feats, pooled_feats = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
-                    attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1)))
-                    feats = feats * attention_mask # [1, max_slot_len, hidden_dim]
-                else:
-                    feats, pooled_feats = embedding_model(input_ids=input_ids) # [1, max_slot_len, hidden_dim]
-        
-        if args.set_similarity:
-            slot_emb = feats[0, :, :].detach().cpu() # [seq_len, hidden_dim]
-        else:
-            if args.candidate_pooling == 'cls' and pooled_feats is not None:
-                slot_emb = pooled_feats[0, :].detach().cpu() # [hidden_dim]
-            elif args.candidate_pooling == 'mean':
-                feats = feats.sum(1)
-                feats = torch.nn.functional.layer_norm(feats, feats.size())
-                slot_emb = feats[0, :].detach().cpu() # [hidden_dim]
-
-        # Tokenize value candidates and get embeddings
-        values = ontology[slot]
-        is_requestable = False
-        if 'request' in values:
-            is_requestable = True
-            values.remove('request')
-        if values:
-            feats = [tokenizer.encode_plus(val, add_special_tokens = True,
-                                                max_length = args.max_candidate_len, padding='max_length',
-                                                truncation = 'longest_first')
-                    for val in values]
-            with torch.no_grad():
-                input_ids = torch.tensor([f['input_ids'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len]
-                if 'token_type_ids' in feats[0]:
-                    token_type_ids = torch.tensor([f['token_type_ids'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len]
-                    if 'attention_mask' in feats[0]:
-                        attention_mask = torch.tensor([f['attention_mask'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len]
-                        feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids,
-                                                attention_mask=attention_mask)
-                        attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1)))
-                        feats = feats * attention_mask # [num_candidates, max_candidate_len, hidden_dim]
-                    else:
-                        feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids) # [num_candidates, max_candidate_len, hidden_dim]
-                else:
-                    if 'attention_mask' in feats[0]:
-                        attention_mask = torch.tensor([f['attention_mask'] for f in feats]).to(embedding_model.device)
-                        feats, pooled_feats = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
-                        attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1)))
-                        feats = feats * attention_mask # [num_candidates, max_candidate_len, hidden_dim]
-                    else:
-                        feats, pooled_feats = embedding_model(input_ids=input_ids) # [num_candidates, max_candidate_len, hidden_dim]
-            
-            if args.set_similarity:
-                feats = feats.detach().cpu() # [num_candidates, max_candidate_len, hidden_dim]
-            else:
-                if args.candidate_pooling == 'cls' and pooled_feats is not None:
-                    feats = pooled_feats.detach().cpu()
-                elif args.candidate_pooling == "mean":
-                    feats = feats.sum(1)
-                    feats = torch.nn.functional.layer_norm(feats, feats.size())
-                    feats = feats.detach().cpu()
-        else:
-            feats = None
-        slots[slot] = (slot_emb, feats, is_requestable)
-
-    # Dump tensors for use in training
-    if save_to_file:
-        writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type)
-        torch.save(slots, writer)
-    
-    return slots
diff --git a/convlab/dst/setsumbt/multiwoz/dataset/utils.py b/convlab/dst/setsumbt/multiwoz/dataset/utils.py
deleted file mode 100644
index 485dee643148ad1d37069064ebc4e2e4553b3dac..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/multiwoz/dataset/utils.py
+++ /dev/null
@@ -1,446 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
-# Code adapted from the TRADE preprocessing code (https://github.com/jasonwu0731/trade-dst)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""MultiWOZ2.1/3 data processing utilities"""
-
-import re
-import os
-
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
-from convlab.dst.rule.multiwoz import normalize_value
-
-# ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train']
-ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train', 'hospital', 'police']
-def set_util_domains(domains):
-    global ACTIVE_DOMAINS
-    ACTIVE_DOMAINS = [d for d in domains if d in ACTIVE_DOMAINS]
-
-MAPPING_PATH = os.path.abspath(__file__).replace('utils.py', 'mapping.pair')
-# Read replacement pairs from the mapping.pair file
-REPLACEMENTS = []
-for line in open(MAPPING_PATH).readlines():
-    tok_from, tok_to = line.replace('\n', '').split('\t')
-    REPLACEMENTS.append((' ' + tok_from + ' ', ' ' + tok_to + ' '))
-
-# Extract belief state from mturk annotations
-def build_dialoguestate(metadata, get_domains=False):
-    domains_list = [dom for dom in ACTIVE_DOMAINS if dom in metadata]
-    dialogue_state, domains = [], []
-    for domain in domains_list:
-        active = False
-        # Extract booking information
-        booking = []
-        for slot in sorted(metadata[domain]['book'].keys()):
-            if slot != 'booked':
-                if metadata[domain]['book'][slot] == 'not mentioned':
-                    continue
-                if metadata[domain]['book'][slot] != '':
-                    val = ['%s-book %s' % (domain, slot.strip().lower()), clean_text(metadata[domain]['book'][slot])]
-                    dialogue_state.append(val)
-                    active = True
-
-        for slot in metadata[domain]['semi']:
-            if metadata[domain]['semi'][slot] == 'not mentioned':
-                continue
-            elif metadata[domain]['semi'][slot] in ['dont care', 'dontcare', "don't care", 'don not care',
-                                                    'do not care', 'does not care']:
-                dialogue_state.append(['%s-%s' % (domain, slot.strip().lower()), 'do not care'])
-                active = True
-            elif metadata[domain]['semi'][slot]:
-                dialogue_state.append(['%s-%s' % (domain, slot.strip().lower()), clean_text(metadata[domain]['semi'][slot])])
-                active = True
-
-        if active:
-            domains.append(domain)
-
-    if get_domains:
-        return domains
-    return clean_dialoguestate(dialogue_state)
-
-
-PRICERANGE = ['do not care', 'cheap', 'moderate', 'expensive']
-BOOLEAN = ['do not care', 'yes', 'no']
-DAYS = ['do not care', 'monday', 'tuesday', 'wednesday', 'thursday',
-        'friday', 'saterday', 'sunday']
-QUANTITIES = ['do not care', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10 or more']
-TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)]
-TIME = ['do not care'] + ['%02i:%02i' % t for l in TIME for t in l]
-
-VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast',
-            'cityroomz': 'city roomz', '  ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott',
-            'worth house': 'the worth house', 'alesbray lodge guest house': 'aylesbray lodge',
-            'huntingdon hotel': 'huntingdon marriott hotel', 'huntingd': 'huntingdon marriott hotel',
-            'jamaicanchinese': 'chinese', 'barbequemodern european': 'modern european',
-            'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"}
-
-def map_values(value):
-    for old, new in VALUE_MAP.items():
-        value = value.replace(old, new)
-    return value
-
-def clean_dialoguestate(states, is_acts=False):
-    # path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))))
-    # path = os.path.join(path, 'data/multiwoz/value_dict.json')
-    # value_dict = json.load(open(path))
-    clean_state = []
-    for slot, value in states:
-        if 'pricerange' in slot:
-            d, s = slot.split('-', 1)
-            s = 'price range'
-            slot = f'{d}-{s}'
-            if value in PRICERANGE:
-                clean_state.append([slot, value])
-            elif True in [v in value for v in PRICERANGE]:
-                value = [v for v in PRICERANGE if v in value][0]
-                clean_state.append([slot, value])
-            elif value == '?' and is_acts:
-                clean_state.append([slot, value])
-            else:
-                continue
-        elif 'parking' in slot or 'internet' in slot:
-            if value in BOOLEAN:
-                clean_state.append([slot, value])
-            if value == 'free':
-                value = 'yes'
-                clean_state.append([slot, value])
-            elif True in [v in value for v in BOOLEAN]:
-                value = [v for v in BOOLEAN if v in value][0]
-                clean_state.append([slot, value])
-            elif value == '?' and is_acts:
-                clean_state.append([slot, value])
-            else:
-                continue
-        elif 'day' in slot:
-            if value in DAYS:
-                clean_state.append([slot, value])
-            elif True in [v in value for v in DAYS]:
-                value = [v for v in DAYS if v in value][0]
-                clean_state.append([slot, value])
-            else:
-                continue
-        elif 'people' in slot or 'duration' in slot or 'stay' in slot:
-            if value in QUANTITIES:
-                clean_state.append([slot, value])
-            elif True in [v in value for v in QUANTITIES]:
-                value = [v for v in QUANTITIES if v in value][0]
-                clean_state.append([slot, value])
-            elif value == '?' and is_acts:
-                clean_state.append([slot, value])
-            else:
-                try:
-                    value = int(value)
-                    if value >= 10:
-                        value = '10 or more'
-                        clean_state.append([slot, value])
-                    else:
-                        continue
-                except:
-                    continue
-        elif 'time' in slot or 'leaveat' in slot or 'arriveby' in slot:
-            if 'leaveat' in slot:
-                d, s = slot.split('-', 1)
-                s = 'leave at'
-                slot = f'{d}-{s}'
-            if 'arriveby' in slot:
-                d, s = slot.split('-', 1)
-                s = 'arrive by'
-                slot = f'{d}-{s}'
-            if value in TIME:
-                if value == 'do not care':
-                    clean_state.append([slot, value])
-                else:
-                    h, m = value.split(':')
-                    if int(m) % 5 == 0:
-                        clean_state.append([slot, value])
-                    else:
-                        m = round(int(m) / 5) * 5
-                        h = int(h)
-                        if m == 60:
-                            m = 0
-                            h += 1
-                        if h >= 24:
-                            h -= 24
-                        value = '%02i:%02i' % (h, m)
-                        clean_state.append([slot, value])
-            elif True in [v in value for v in TIME]:
-                value = [v for v in TIME if v in value][0]
-                h, m = value.split(':')
-                if int(m) % 5 == 0:
-                    clean_state.append([slot, value])
-                else:
-                    m = round(int(m) / 5) * 5
-                    h = int(h)
-                    if m == 60:
-                        m = 0
-                        h += 1
-                    if h >= 24:
-                        h -= 24
-                    value = '%02i:%02i' % (h, m)
-                    clean_state.append([slot, value])
-            elif value == '?' and is_acts:
-                clean_state.append([slot, value])
-            else:
-                continue
-        elif 'stars' in slot:
-            if len(value) == 1 or value == 'do not care':
-                clean_state.append([slot, value])
-            elif value == '?' and is_acts:
-                clean_state.append([slot, value])
-            elif len(value) > 1:
-                try:
-                    value = int(value[0])
-                    value = str(value)
-                    clean_state.append([slot, value])
-                except:
-                    continue
-        elif 'area' in slot:
-            if '|' in value:
-                value = value.split('|', 1)[0]
-            clean_state.append([slot, value])
-        else:
-            if '|' in value:
-                value = value.split('|', 1)[0]
-                value = map_values(value)
-                # d, s = slot.split('-', 1)
-                # value = normalize_value(value_dict, d, s, value)
-            clean_state.append([slot, value])
-    
-    return clean_state
-
-
-# Module to process a dialogue and check its validity
-def process_dialogue(dialogue, max_utt_len=128):
-    if len(dialogue['log']) % 2 != 0:
-        return None
-
-    # Extract user and system utterances
-    usr_utts, sys_utts = [], []
-    avg_len = sum(len(utt['text'].split(' ')) for utt in dialogue['log'])
-    avg_len = avg_len / len(dialogue['log'])
-    if avg_len > max_utt_len:
-        return None
-
-    # If the first term is a system turn then ignore dialogue
-    if dialogue['log'][0]['metadata']:
-        return None
-
-    usr, sys = None, None
-    for turn in dialogue['log']:
-        if not is_ascii(turn['text']):
-            return None
-
-        if not usr or not sys:
-            if len(turn['metadata']) == 0:
-                usr = turn
-            else:
-                sys = turn
-        
-        if usr and sys:
-            states = build_dialoguestate(sys['metadata'], get_domains = False)
-            sys['dialogue_states'] = states
-
-            usr_utts.append(usr)
-            sys_utts.append(sys)
-            usr, sys = None, None
-
-    dial_clean = dict()
-    dial_clean['usr_log'] = usr_utts
-    dial_clean['sys_log'] = sys_utts
-    return dial_clean
-
-
-# Get new domains
-def get_act_domains(prev, crnt):
-    diff = {}
-    if not prev or not crnt:
-        return diff
-
-    for ((prev_dom, prev_val), (crnt_dom, crnt_val)) in zip(prev.items(), crnt.items()):
-        assert prev_dom == crnt_dom
-        if prev_val != crnt_val:
-            diff[crnt_dom] = crnt_val
-    return diff
-
-
-# Get current domains
-def get_domains(dial_log, turn_id, prev_domain):
-    if turn_id == 1:
-        active = build_dialoguestate(dial_log[turn_id]['metadata'], get_domains=True)
-        acts = format_acts(dial_log[turn_id].get('dialog_act', {})) if not active else []
-        acts = [domain for intent, domain, slot, value in acts if domain not in ['', 'general']]
-        active += acts
-        crnt = active[0] if active else ''
-    else:
-        active = get_act_domains(dial_log[turn_id - 2]['metadata'], dial_log[turn_id]['metadata'])
-        active = list(active.keys())
-        acts = format_acts(dial_log[turn_id].get('dialog_act', {})) if not active else []
-        acts = [domain for intent, domain, slot, value in acts if domain not in ['', 'general']]
-        active += acts
-        crnt = [prev_domain] if not active else active
-        crnt = crnt[0]
-
-    return crnt
-
-
-# Function to extract dialogue info from data
-def extract_dialogue(dialogue, max_utt_len=50):
-    dialogue = process_dialogue(dialogue, max_utt_len)
-    if not dialogue:
-        return None
-
-    usr_utts = [turn['text'] for turn in dialogue['usr_log']]
-    sys_utts = [turn['text'] for turn in dialogue['sys_log']]
-    # sys_acts = [format_acts(turn['dialog_act']) if 'dialog_act' in turn else [] for turn in dialogue['sys_log']]
-    usr_acts = [format_acts(turn['dialog_act']) if 'dialog_act' in turn else [] for turn in dialogue['usr_log']]
-    dialogue_states = [turn['dialogue_states'] for turn in dialogue['sys_log']]
-    domains = [turn['domain'] for turn in dialogue['usr_log']]
-
-    # dial = [{'usr': u,'sys': s, 'usr_a': ua, 'sys_a': a, 'domain': d, 'ds': v}
-    #         for u, s, ua, a, d, v in zip(usr_utts, sys_utts, usr_acts, sys_acts, domains, dialogue_states)]
-    dial = [{'usr': u,'sys': s, 'usr_a': ua, 'domain': d, 'ds': v}
-            for u, s, ua, d, v in zip(usr_utts, sys_utts, usr_acts, domains, dialogue_states)]    
-    return dial
-
-
-def format_acts(acts):
-    new_acts = []
-    for key, item in acts.items():
-        domain, intent = key.split('-', 1)
-        if domain.lower() in ACTIVE_DOMAINS + ['general']:
-            state = []
-            for slot, value in item:
-                slot = str(REF_SYS_DA[domain].get(slot, slot)).lower() if domain in REF_SYS_DA else slot
-                value = clean_text(value)
-                slot = slot.replace('_', ' ').replace('ref', 'reference')
-                state.append([f'{domain.lower()}-{slot}', value])
-            state = clean_dialoguestate(state, is_acts=True)
-            if domain == 'general':
-                if intent in ['thank', 'bye']:
-                    state = [['general-none', 'none']]
-                else:
-                    state = []
-            for slot, value in state:
-                if slot not in ['train-people']:
-                    slot = slot.split('-', 1)[-1]
-                    new_acts.append([intent.lower(), domain.lower(), slot, value])
-    
-    return new_acts
-                
-
-# Fix act labels
-def fix_delexicalisation(turn):
-    if 'dialog_act' in turn:
-        for dom, act in turn['dialog_act'].items():
-            if 'Attraction' in dom:
-                if 'restaurant_' in turn['text']:
-                    turn['text'] = turn['text'].replace("restaurant", "attraction")
-                if 'hotel_' in turn['text']:
-                    turn['text'] = turn['text'].replace("hotel", "attraction")
-            if 'Hotel' in dom:
-                if 'attraction_' in turn['text']:
-                    turn['text'] = turn['text'].replace("attraction", "hotel")
-                if 'restaurant_' in turn['text']:
-                    turn['text'] = turn['text'].replace("restaurant", "hotel")
-            if 'Restaurant' in dom:
-                if 'attraction_' in turn['text']:
-                    turn['text'] = turn['text'].replace("attraction", "restaurant")
-                if 'hotel_' in turn['text']:
-                    turn['text'] = turn['text'].replace("hotel", "restaurant")
-
-    return turn
-
-
-# Check if a character is an ascii character
-def is_ascii(s):
-    return all(ord(c) < 128 for c in s)
-
-
-# Insert white space
-def separate_token(token, text):
-    sidx = 0
-    while True:
-        # Find next instance of token
-        sidx = text.find(token, sidx)
-        if sidx == -1:
-            break
-        # If the token is already seperated continue to next
-        if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
-                re.match('[0-9]', text[sidx + 1]):
-            sidx += 1
-            continue
-        # Create white space separation around token
-        if text[sidx - 1] != ' ':
-            text = text[:sidx] + ' ' + text[sidx:]
-            sidx += 1
-        if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
-            text = text[:sidx + 1] + ' ' + text[sidx + 1:]
-        sidx += 1
-    return text
-
-
-def clean_text(text):
-    # Replace white spaces in front and end
-    text = re.sub(r'^\s*|\s*$', '', text.strip().lower())
-
-    # Replace b&v or 'b and b' with 'bed and breakfast'
-    text = re.sub(r"b&b", "bed and breakfast", text)
-    text = re.sub(r"b and b", "bed and breakfast", text)
-
-    # Fix apostrophies
-    text = re.sub(u"(\u2018|\u2019)", "'", text)
-
-    # Correct punctuation
-    text = text.replace(';', ',')
-    text = re.sub('$\/', '', text)
-    text = text.replace('/', ' and ')
-
-    # Replace special characters
-    text = text.replace('-', ' ')
-    text = re.sub('[\"\<>@\(\)]', '', text)
-
-    # Insert white space around special tokens:
-    for token in ['?', '.', ',', '!']:
-        text = separate_token(token, text)
-
-    # insert white space for 's
-    text = separate_token('\'s', text)
-
-    # replace it's, does't, you'd ... etc
-    text = re.sub('^\'', '', text)
-    text = re.sub('\'$', '', text)
-    text = re.sub('\'\s', ' ', text)
-    text = re.sub('\s\'', ' ', text)
-
-    # Perform pair replacements listed in the mapping.pair file
-    for fromx, tox in REPLACEMENTS:
-        text = ' ' + text + ' '
-        text = text.replace(fromx, tox)[1:-1]
-
-    # Remove multiple spaces
-    text = re.sub(' +', ' ', text)
-
-    # Concatenate numbers eg '1 3' -> '13'
-    tokens = text.split()
-    i = 1
-    while i < len(tokens):
-        if re.match(u'^\d+$', tokens[i]) and \
-                re.match(u'\d+$', tokens[i - 1]):
-            tokens[i - 1] += tokens[i]
-            del tokens[i]
-        else:
-            i += 1
-    text = ' '.join(tokens)
-
-    return text
diff --git a/convlab/dst/setsumbt/predict_user_actions.py b/convlab/dst/setsumbt/predict_user_actions.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c304a569cb5e29920332ed21c8f862dd00c1e48
--- /dev/null
+++ b/convlab/dst/setsumbt/predict_user_actions.py
@@ -0,0 +1,178 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Predict dataset user action using SetSUMBT Model"""
+
+from copy import deepcopy
+import os
+import json
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+from convlab.util.custom_util import flatten_acts as flatten
+from convlab.util import load_dataset, load_policy_data
+from convlab.dst.setsumbt import SetSUMBTTracker
+
+
+def flatten_acts(acts: dict) -> list:
+    """
+    Flatten dictionary actions into [intent, domain, slot, value] lists.
+
+    Args:
+        acts: Dictionary acts
+
+    Returns:
+        flat_acts: Flattened actions, with 'none' slots and values replaced by ''
+    """
+    raw_acts = flatten(acts)
+    # 'none' placeholders become empty strings; every other value is lowercased
+    flat_acts = [[intent,
+                  domain,
+                  '' if slot == 'none' else slot,
+                  '' if value == 'none' else value.lower()]
+                 for intent, domain, slot, value in raw_acts]
+
+    return flat_acts
+
+
+def get_user_actions(context: list, system_acts: list) -> list:
+    """
+    Collect the user actions of the last turn, adding informs for changed state values.
+
+    Args:
+        context: Previous dialogue turns.
+        system_acts: List of flattened system actions.
+
+    Returns:
+        user_acts: List of flattened user actions.
+    """
+    user_acts = flatten_acts(context[-1]['dialogue_acts'])
+    if len(context) != 3:
+        return user_acts
+
+    prev_state, cur_state = context[-3]['state'], context[-1]['state']
+    for domain, substate in cur_state.items():
+        for slot, value in substate.items():
+            changed = prev_state[domain][slot] != value
+            act = ['inform', domain, slot, value]
+            if changed and act not in user_acts and act not in system_acts:
+                user_acts.append(act)
+
+    return user_acts
+
+
+def extract_dataset(dataset: str = 'multiwoz21') -> list:
+    """
+    Load the test split of the dataset and group its turns into dialogues.
+
+    Args:
+        dataset: Dataset name
+
+    Returns:
+        data: List of dialogues, each a list of dicts with utterances and flattened acts
+    """
+    loaded = load_dataset(dataset_name=dataset)
+    raw_data = load_policy_data(loaded, data_split='test', context_window_size=3)['test']
+
+    data, dialogue = [], []
+    for turn in raw_data:
+        context = turn['context']
+        # With a single-turn context there is no preceding turn to read the system side from
+        has_system_turn = len(context) > 1
+        system_acts = flatten_acts(context[-2]['dialogue_acts'] if has_system_turn else {})
+        dialogue.append({'system_utterance': context[-2]['utterance'] if has_system_turn else '',
+                         'utterance': context[-1]['utterance'],
+                         'system_actions': system_acts,
+                         'user_actions': get_user_actions(context, system_acts)})
+        if turn['terminated']:
+            data.append(dialogue)
+            dialogue = []
+
+    return data
+
+
+def unflatten_acts(acts: list) -> dict:
+    """
+    Group flat [intent, domain, slot, value] acts into the dictionary act format.
+
+    Args:
+        acts: List of flat actions.
+
+    Returns:
+        unflat_acts: Dictionary with 'categorical', 'binary' and empty 'non-categorical' acts.
+    """
+    unflat_acts = {'categorical': [], 'binary': [], 'non-categorical': []}
+    for intent, domain, slot, value in acts:
+        # Acts with a 'none' slot are only kept for the general domain
+        if domain != 'general' and slot == 'none':
+            continue
+
+        act = {'intent': intent,
+               'domain': domain,
+               'slot': '' if slot == 'none' else slot}
+        if intent == 'request' or value in ('', 'none'):
+            # Requests and acts without a value are stored without a 'value' key
+            unflat_acts['binary'].append(act)
+        else:
+            act['value'] = value
+            unflat_acts['categorical'].append(act)
+
+    return unflat_acts
+
+
+def predict_user_acts(data: list, tracker: SetSUMBTTracker) -> list:
+    """
+    Run the SetSUMBT Tracker over every turn and pair predicted with reference user actions.
+
+    Args:
+        data: List of dialogues, each a list of turn dicts as built by extract_dataset.
+        tracker: SetSUMBT Tracker; its session is re-initialised before and after each dialogue.
+
+    Returns:
+        predict_result: List of turns containing predictions and true user actions.
+    """
+    tracker.init_session()
+    predict_result = []
+    for dial_idx, dialogue in enumerate(data):
+        for turn_idx, state in enumerate(dialogue):
+            sample = {'dial_idx': dial_idx, 'turn_idx': turn_idx}
+            # NOTE(review): 'system_action' is only set after update() below - confirm this is intended
+            tracker.state['history'].append(['sys', state['system_utterance']])
+            predicted_state = deepcopy(tracker.update(state['utterance']))
+            tracker.state['history'].append(['usr', state['utterance']])
+            tracker.state['system_action'] = state['system_actions']
+            # Store the predicted and the reference user acts, both converted to the dict format
+            sample['predictions'] = {'dialogue_acts': unflatten_acts(predicted_state['user_action'])}
+            sample['dialogue_acts'] = unflatten_acts(state['user_actions'])
+
+            predict_result.append(sample)
+        # Re-initialise the tracker session before the next dialogue
+        tracker.init_session()
+
+    return predict_result
+
+
+if __name__ =="__main__":
+    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
+    parser.add_argument('--model_path', type=str, help='Path to model dir', required=True)
+    args = parser.parse_args()
+
+    dataset = extract_dataset(args.dataset_name)
+    tracker = SetSUMBTTracker(args.model_path)
+    predict_results = predict_user_acts(dataset, tracker)
+
+    os.makedirs(os.path.join(args.model_path, 'predictions'), exist_ok=True)  # may not exist yet
+    with open(os.path.join(args.model_path, 'predictions', 'test_nlu.json'), 'w') as writer:
+        json.dump(predict_results, writer, indent=2)
diff --git a/convlab/dst/setsumbt/process_mwoz_data.py b/convlab/dst/setsumbt/process_mwoz_data.py
deleted file mode 100755
index 701a523613961d83a5188fa9ab0cf786b19a5a7e..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/process_mwoz_data.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import os
-import json
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-import torch
-from tqdm import tqdm
-
-from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker
-from convlab.util.multiwoz.lexicalize import deflat_da, flat_da
-
-
-def load_data(path):
-    with open(path, 'r') as reader:
-        data = json.load(reader)
-        reader.close()
-    
-    return data
-
-
-def load_tracker(model_checkpoint):
-    model = SetSUMBTTracker(model_path=model_checkpoint)
-    model.init_session()
-
-    return model
-
-
-def process_dialogue(dial, model, get_full_belief_state):
-    model.store_full_belief_state = get_full_belief_state
-    model.init_session()
-
-    model.state['history'].append(['sys', ''])
-    processed_dial = []
-    belief_state = {}
-    for turn in dial:
-        if not turn['metadata']:
-            state = model.update(turn['text'])
-            model.state['history'].append(['usr', turn['text']])
-            
-            acts = model.state['user_action']
-            acts = [[val.replace('-', ' ') for val in act] for act in acts]
-            acts = flat_da(acts)
-            acts = deflat_da(acts)
-            turn['dialog_act'] = acts
-        else:
-            model.state['history'].append(['sys', turn['text']])
-            turn['metadata'] = model.state['belief_state']
-        
-        if get_full_belief_state:
-            for slot, probs in model.full_belief_state.items():
-                if slot not in belief_state:
-                    belief_state[slot] = [probs[0]]
-                else:
-                    belief_state[slot].append(probs[0])
-        
-        processed_dial.append(turn)
-    
-    if get_full_belief_state:
-        belief_state = {slot: torch.cat(probs, 0).cpu() for slot, probs in belief_state.items()}
-
-    return processed_dial, belief_state
-
-
-def process_dialogues(data, model, get_full_belief_state=False):
-    processed_data = {}
-    belief_states = {}
-    for dial_id, dial in tqdm(data.items()):
-        dial['log'], bs = process_dialogue(dial['log'], model, get_full_belief_state)
-        processed_data[dial_id] = dial
-        if get_full_belief_state:
-            belief_states[dial_id] = bs
-
-    return processed_data, belief_states
-
-
-def get_arguments():
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-
-    parser.add_argument('--model_path')
-    parser.add_argument('--data_path')
-    parser.add_argument('--get_full_belief_state', action='store_true')
-
-    return parser.parse_args()
-
-
-if __name__ == "__main__":
-    args = get_arguments()
-
-    print('Loading data and model...')
-    data = load_data(os.path.join(args.data_path, 'data.json'))
-    model = load_tracker(args.model_path)
-
-    print('Processing data...\n')
-    data, belief_states = process_dialogues(data, model, get_full_belief_state=args.get_full_belief_state)
-    
-    print('Saving results...\n')
-    torch.save(belief_states, os.path.join(args.data_path, 'setsumbt_belief_states.bin'))
-    with open(os.path.join(args.data_path, 'setsumbt_data.json'), 'w') as writer:
-        json.dump(data, writer, indent=2)
-        writer.close()
diff --git a/convlab/dst/setsumbt/run.py b/convlab/dst/setsumbt/run.py
index b9c9a75b86d47cd5db733b4755d6af11f08b827d..e45bf129f0c9f2c5c1fba01d4b5eb80e29a5a1f0 100644
--- a/convlab/dst/setsumbt/run.py
+++ b/convlab/dst/setsumbt/run.py
@@ -33,8 +33,8 @@ def main():
     if args.run_nbt:
         from convlab.dst.setsumbt.do.nbt import main
         main(args, config)
-    if args.run_calibration:
-        from convlab.dst.setsumbt.do.calibration import main
+    if args.run_evaluation:
+        from convlab.dst.setsumbt.do.evaluate import main
         main(args, config)
 
 
diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b620247fd4a36223fbed8c46c54615f7c69da98
--- /dev/null
+++ b/convlab/dst/setsumbt/tracker.py
@@ -0,0 +1,446 @@
+import os
+import json
+import copy
+import logging
+
+import torch
+import transformers
+from transformers import BertModel, BertConfig, BertTokenizer, RobertaModel, RobertaConfig, RobertaTokenizer
+
+from convlab.dst.setsumbt.modeling import RobertaSetSUMBT, BertSetSUMBT
+from convlab.dst.setsumbt.modeling.training import set_ontology_embeddings
+from convlab.dst.dst import DST
+from convlab.util.custom_util import model_downloader
+
+USE_CUDA = torch.cuda.is_available()  # Evaluated once at import; selects the tracker's compute device
+transformers.logging.set_verbosity_error()  # Limit transformers library logging to errors
+
+
+class SetSUMBTTracker(DST):
+    """SetSUMBT Tracker object for Convlab dialogue system"""
+
+    def __init__(self,
+                 model_path: str = "",
+                 model_type: str = "roberta",
+                 return_turn_pooled_representation: bool = False,
+                 return_confidence_scores: bool = False,
+                 confidence_threshold='auto',
+                 return_belief_state_entropy: bool = False,
+                 return_belief_state_mutual_info: bool = False,
+                 store_full_belief_state: bool = False):
+        """
+        Load a SetSUMBT model and its ontology, downloading the model first if model_path is not a local path
+
+        Args:
+            model_path: Model path or download URL
+            model_type: Transformer type (roberta/bert), any other value raises NotImplementedError
+            return_turn_pooled_representation: If true a turn level pooled representation is returned
+            return_confidence_scores: If true act confidence scores are included in the state
+            confidence_threshold: Confidence threshold value for constraints or option auto
+            return_belief_state_entropy: If true belief state distribution entropies are included in the state
+            return_belief_state_mutual_info: If true belief state distribution mutual infos are included in the state
+            store_full_belief_state: If true full belief state is stored within tracker object
+        """
+        super(SetSUMBTTracker, self).__init__()
+
+        self.model_type = model_type
+        self.model_path = model_path
+        self.return_turn_pooled_representation = return_turn_pooled_representation
+        self.return_confidence_scores = return_confidence_scores
+        self.confidence_threshold = confidence_threshold
+        self.return_belief_state_entropy = return_belief_state_entropy
+        self.return_belief_state_mutual_info = return_belief_state_mutual_info
+        self.store_full_belief_state = store_full_belief_state
+        if self.store_full_belief_state:
+            self.full_belief_state = {}
+        self.info_dict = {}
+
+        # Download model if needed
+        if not os.path.exists(self.model_path):
+            # Get path /.../convlab/dst/setsumbt/models, tolerating a directory that already exists
+            download_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
+            os.makedirs(download_path, exist_ok=True)
+            model_downloader(download_path, self.model_path)
+            # Downloadable model path format http://.../setsumbt_model_name.zip
+            self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '')
+            self.model_path = os.path.join(download_path, self.model_path)
+
+        # Select model type based on the encoder
+        if model_type == "roberta":
+            self.config = RobertaConfig.from_pretrained(self.model_path)
+            self.tokenizer = RobertaTokenizer
+            self.model = RobertaSetSUMBT
+        elif model_type == "bert":
+            self.config = BertConfig.from_pretrained(self.model_path)
+            self.tokenizer = BertTokenizer
+            self.model = BertSetSUMBT
+        else:
+            raise NotImplementedError(f"Model type <{model_type}> is not implemented, use roberta or bert.")
+
+        self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu')
+
+        self.load_weights()
+
+    def load_weights(self):
+        """Load the tokenizer, model weights, ontology and ontology embeddings found under self.model_path"""
+        logging.info('Loading SetSUMBT pretrained model.')
+        self.tokenizer = self.tokenizer.from_pretrained(self.config.tokenizer_name)
+        logging.info(f'Model tokenizer loaded from {self.config.tokenizer_name}.')
+        self.model = self.model.from_pretrained(self.model_path, config=self.config)
+        logging.info(f'Model loaded from {self.model_path}.')
+
+        # Transfer model to compute device and setup eval environment
+        self.model = self.model.to(self.device)
+        self.model.eval()
+        logging.info(f'Model transferred to device: {self.device}')
+
+        logging.info('Loading model ontology')
+        # The context manager closes the file even when the JSON fails to parse
+        with open(os.path.join(self.model_path, 'database', 'test.json'), 'r') as f:
+            self.ontology = json.load(f)
+
+        db = torch.load(os.path.join(self.model_path, 'database', 'test.db'))  # NOTE: unpickles, trusted files only
+        set_ontology_embeddings(self.model, db)
+
+        if self.return_confidence_scores:
+            logging.info('Model returns user action and belief state confidence scores.')
+            self.get_thresholds(self.confidence_threshold)
+            logging.info('Uncertain Querying set up and thresholds set up at:')
+            logging.info(self.confidence_thresholds)
+        if self.return_belief_state_entropy:
+            logging.info('Model returns belief state distribution entropy scores (Total uncertainty).')
+        if self.return_belief_state_mutual_info:
+            logging.info('Model returns belief state distribution mutual information scores (Knowledge uncertainty).')
+        logging.info('Ontology loaded successfully.')
+
+    def get_thresholds(self, threshold='auto') -> dict:
+        """
+        Build the domain and slot specific confidence thresholds and store them on the tracker
+
+        Args:
+            threshold: Fixed threshold value, or 'auto' to derive one from the candidate count of each slot
+
+        Returns:
+            confidence_thresholds: Nested dictionary domain -> slot -> threshold, thresholds are at least 0.05
+        """
+        thresholds = self.confidence_thresholds = dict()
+        for domain, substate in self.ontology.items():
+            for slot, slot_info in substate.items():
+                # A domain only gets an entry once its first slot is visited
+                slot_thresholds = thresholds.setdefault(domain, dict())
+                if threshold == 'auto':
+                    # Auto thresholds are set based on the number of value candidates per slot
+                    candidate_thres = 1.0 / (float(len(slot_info['possible_values'])) - 2.1)
+                else:
+                    candidate_thres = threshold
+                slot_thresholds[slot] = max(0.05, candidate_thres)
+
+        return self.confidence_thresholds
+
+    def init_session(self):
+        """Reset the dialogue state and the per dialogue caches of the tracker"""
+        belief_state, booked = dict(), dict()
+        self.state = {'belief_state': belief_state, 'booked': booked}
+        for domain, substate in self.ontology.items():
+            # Slots with no value candidates, or with only the '?' candidate, get no belief state entry
+            belief_state[domain] = {slot: '' for slot, slot_info in substate.items()
+                                    if slot_info['possible_values'] and slot_info['possible_values'] != ['?']}
+            booked[domain] = list()
+        self.state['history'] = []
+        self.state['system_action'] = []
+        self.state['user_action'] = []
+        self.state['terminated'] = False
+        # Tracker side caches that must not carry over from one dialogue into the next
+        self.active_domains = {}
+        self.hidden_states = None
+        self.info_dict = {}
+
+    def update(self, user_act: str = '') -> dict:
+        """
+        Update user actions and dialogue and belief states.
+
+        Args:
+            user_act: User action string of the current turn
+
+        Returns:
+            state: Updated dialogue state, which is also stored in self.state
+        """
+        prev_state = self.state
+        _output = self.predict(self.get_features(user_act))
+
+        # Format state entropy
+        if _output[5] is not None:
+            state_entropy = dict()
+            for slot, e in _output[5].items():
+                domain, slot = slot.split('-', 1)
+                if domain not in state_entropy:
+                    state_entropy[domain] = dict()
+                state_entropy[domain][slot] = e
+        else:
+            state_entropy = None
+
+        # Format state mutual information
+        if _output[6] is not None:
+            state_mutual_info = dict()
+            for slot, mi in _output[6].items():
+                domain, slot = slot.split('-', 1)
+                if domain not in state_mutual_info:
+                    state_mutual_info[domain] = dict()
+                state_mutual_info[domain][slot] = mi[0, 0]
+        else:
+            state_mutual_info = None
+
+        # Format all confidence scores
+        belief_state_confidence = None
+        if _output[4] is not None:
+            belief_state_confidence = dict()
+            belief_state_conf, request_probs, active_domain_probs, general_act_probs = _output[4]
+            for slot, p in belief_state_conf.items():
+                domain, slot = slot.split('-', 1)
+                if domain not in belief_state_confidence:
+                    belief_state_confidence[domain] = dict()
+                if slot not in belief_state_confidence[domain]:
+                    belief_state_confidence[domain][slot] = dict()
+                belief_state_confidence[domain][slot]['inform'] = p
+
+            for slot, p in request_probs.items():
+                domain, slot = slot.split('-', 1)
+                if domain not in belief_state_confidence:
+                    belief_state_confidence[domain] = dict()
+                if slot not in belief_state_confidence[domain]:
+                    belief_state_confidence[domain][slot] = dict()
+                belief_state_confidence[domain][slot]['request'] = p
+
+            for domain, p in active_domain_probs.items():
+                if domain not in belief_state_confidence:
+                    belief_state_confidence[domain] = dict()
+                belief_state_confidence[domain]['none'] = {'inform': p}
+
+            if 'general' not in belief_state_confidence:
+                belief_state_confidence['general'] = dict()
+            belief_state_confidence['general']['none'] = general_act_probs
+
+        # Get new domain activation actions
+        new_domains = [d for d, active in _output[1].items() if active]
+        new_domains = [d for d in new_domains if not self.active_domains.get(d, False)]
+        self.active_domains = _output[1]
+
+        user_acts = _output[2]
+        for domain in new_domains:
+            user_acts.append(['inform', domain, 'none', 'none'])
+
+        new_belief_state = copy.deepcopy(prev_state['belief_state'])
+        for domain, substate in _output[0].items():
+            for slot, value in substate.items():
+                value = '' if value == 'none' else value
+                value = 'dontcare' if value == 'do not care' else value
+                value = 'guesthouse' if value == 'guest house' else value
+
+                if domain not in new_belief_state:
+                    # A domain missing from the state cannot be updated: 'bus' is skipped silently, others are logged
+                    if domain != 'bus':
+                        logging.debug('Error: domain <{}> not in belief state'.format(domain))
+                    continue
+                if slot not in new_belief_state[domain]:  # Untracked slot, skipped instead of raising a KeyError
+                    logging.debug(f'Unknown slot name <{slot}> with value <{value}> of domain <{domain}>')
+                    continue
+
+                # Uncertainty clipping of state
+                if belief_state_confidence is not None:
+                    threshold = self.confidence_thresholds[domain][slot]
+                    if belief_state_confidence[domain][slot].get('inform', 1.0) < threshold:
+                        value = ''
+
+                new_belief_state[domain][slot] = value
+                if prev_state['belief_state'][domain][slot] != value:
+                    user_acts.append(['inform', domain, slot, value])
+
+        new_state = copy.deepcopy(dict(prev_state))
+        new_state['belief_state'] = new_belief_state
+        new_state['active_domains'] = self.active_domains
+        if belief_state_confidence is not None:
+            new_state['belief_state_probs'] = belief_state_confidence
+        if state_entropy is not None:
+            new_state['entropy'] = state_entropy
+        if state_mutual_info is not None:
+            new_state['mutual_information'] = state_mutual_info
+
+        user_acts = [act for act in user_acts if act not in new_state['system_action']]
+        new_state['user_action'] = user_acts
+
+        if _output[3] is not None:
+            new_state['turn_pooled_representation'] = _output[3]
+
+        self.state = new_state
+        self.info_dict = copy.deepcopy(dict(new_state))
+
+        return self.state
+
+    def predict(self, features: dict) -> tuple:
+        """
+        Model forward pass and prediction post processing.
+
+        Args:
+            features: Dictionary of model input features
+
+        Returns:
+            out: 7-tuple (state, active domains, user acts, pooled representation, confidences, entropy, mutual info)
+        """
+        state_mutual_info = None
+        with torch.no_grad():
+            turn_pooled_representation = None
+            # NOTE(review): this branch wins over the mutual information branch when both options are enabled
+            if self.return_turn_pooled_representation:
+                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
+                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
+                                      get_turn_pooled_representation=True)
+                belief_state = _outputs[0]
+                request_probs = _outputs[1]
+                active_domain_probs = _outputs[2]
+                general_act_probs = _outputs[3]
+                self.hidden_states = _outputs[4]
+                turn_pooled_representation = _outputs[5]
+            elif self.return_belief_state_mutual_info:
+                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
+                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
+                                      get_turn_pooled_representation=True, calculate_state_mutual_info=True)
+                belief_state = _outputs[0]
+                request_probs = _outputs[1]
+                active_domain_probs = _outputs[2]
+                general_act_probs = _outputs[3]
+                self.hidden_states = _outputs[4]
+                state_mutual_info = _outputs[5]
+            else:
+                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
+                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
+                                      get_turn_pooled_representation=False)
+                belief_state, request_probs, active_domain_probs, general_act_probs, self.hidden_states = _outputs
+
+        # Convert belief state into dialog state
+        dialogue_state = dict()
+        for slot, probs in belief_state.items():
+            dom, slot = slot.split('-', 1)
+            if dom not in dialogue_state:
+                dialogue_state[dom] = dict()
+            val = self.ontology[dom][slot]['possible_values'][probs[0, 0, :].argmax().item()]
+            if val != 'none':
+                dialogue_state[dom][slot] = val
+
+        if self.store_full_belief_state:
+            self.full_belief_state = belief_state
+
+        # Entropy only depends on the belief state, so it is computed whether or not confidences are returned
+        state_entropy = None
+        if self.return_belief_state_entropy:
+            state_entropy = {slot: probs[0, 0, :] for slot, probs in belief_state.items()}
+            state_entropy = {slot: self.relative_entropy(p).item() for slot, p in state_entropy.items()}
+
+        # Obtain model output probabilities
+        confidence_scores = None
+        if self.return_confidence_scores:
+            # Confidence score is the max probability across all not "none" values candidates.
+            belief_state_conf = {slot: probs[0, 0, 1:].max().item() for slot, probs in belief_state.items()}
+            _request_probs = {slot: p[0, 0].item() for slot, p in request_probs.items()}
+            _active_domain_probs = {domain: p[0, 0].item() for domain, p in active_domain_probs.items()}
+            _general_act_probs = {'bye': general_act_probs[0, 0, 1].item(), 'thank': general_act_probs[0, 0, 2].item()}
+            confidence_scores = (belief_state_conf, _request_probs, _active_domain_probs, _general_act_probs)
+
+        # Construct request action prediction
+        request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5]
+        request_acts = [slot.split('-', 1) for slot in request_acts]
+        request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts]
+
+        # Construct active domain set
+        active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()}
+
+        # Construct general domain action
+        general_acts = general_act_probs[0, 0, :].argmax(-1).item()
+        general_acts = [[], ['bye'], ['thank']][general_acts]
+        general_acts = [[act, 'general', 'none', 'none'] for act in general_acts]
+
+        user_acts = request_acts + general_acts
+
+        out = (dialogue_state, active_domains, user_acts, turn_pooled_representation, confidence_scores)
+        out += (state_entropy, state_mutual_info)
+        return out
+
+    def relative_entropy(self, probs: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the relative entropy of a probability distribution
+
+        Args:
+            probs: Probability distribution
+
+        Returns:
+            entropy: Entropy of the distribution divided by the maximum entropy for its size
+        """
+        # Maximum entropy of a K dimensional distribution is ln(K)
+        max_entropy = torch.log(torch.tensor(probs.size(-1)).float())
+        # The small constant keeps the logarithm finite for zero probabilities
+        entropy = -(probs * torch.log(probs + 1e-8)).sum()
+        entropy /= max_entropy
+        return entropy
+
+    def get_features(self, user_act: str) -> dict:
+        """
+        Tokenize the current turn and construct the model input features
+
+        Args:
+            user_act: User action string
+
+        Returns:
+            features: Model input features, the keys are input_ids, token_type_ids and attention_mask
+                      and the value is None for any field the tokenizer does not provide
+        """
+        # The system utterance is the most recent history entry, but only if the system produced it
+        history = self.state['history']
+        if history and history[-1][0] == 'sys':
+            system_act = history[-1][-1]
+        else:
+            system_act = ''
+
+        # Tokenize dialog
+        encoding = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True,
+                                              max_length=self.config.max_turn_len, padding='max_length',
+                                              truncation='longest_first')
+
+        # Every field returned by the tokenizer is reshaped to (1, 1, -1) and moved to the tracker device
+        # NOTE(review): presumably the leading axes are (batch, turn) - confirm against the model
+        features = dict()
+        for key in ('input_ids', 'token_type_ids', 'attention_mask'):
+            if key in encoding:
+                features[key] = torch.tensor(encoding[key]).reshape(1, 1, -1).to(self.device)
+            else:
+                # Fields the tokenizer does not return are passed on as None
+                features[key] = None
+
+        return features
+
+
+# if __name__ == "__main__":
+#     from convlab.policy.vector.vector_uncertainty import VectorUncertainty
+#     # from convlab.policy.vector.vector_binary import VectorBinary
+#     tracker = SetSUMBTTracker(model_path='/gpfs/project/niekerk/src/SetSUMBT/models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-labelsmoothing-Seed0-10-08-22-12-42',
+#                               return_confidence_scores=True, confidence_threshold='auto',
+#                               return_belief_state_entropy=True)
+#     vector = VectorUncertainty(use_state_total_uncertainty=True, confidence_thresholds=tracker.confidence_thresholds,
+#                                use_masking=True)
+#     # vector = VectorBinary()
+#     tracker.init_session()
+#
+#     state = tracker.update('hey. I need a cheap restaurant.')
+#     tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
+#     tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
+#     state = tracker.update('If you have something Asian that would be great.')
+#     tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
+#     tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
+#     tracker.state['system_action'] = [['inform', 'restaurant', 'food', 'chinese'],
+#                                       ['inform', 'restaurant', 'name', 'the golden wok']]
+#     state = tracker.update('Great. Where are they located?')
+#     tracker.state['history'].append(['usr', 'Great. Where are they located?'])
+#     state = tracker.state
+#     state['terminated'] = False
+#     state['booked'] = {}
+#
+#     print(state)
+#     print(vector.state_vectorize(state))
diff --git a/convlab/dst/setsumbt/utils.py b/convlab/dst/setsumbt/utils.py
index 75a6a1febd7510f1a6d152676b82d709abf177f0..ff374116a3f8e88e6219fdc8b134d40b0bee7caf 100644
--- a/convlab/dst/setsumbt/utils.py
+++ b/convlab/dst/setsumbt/utils.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,151 +15,124 @@
 # limitations under the License.
 """SetSUMBT utils"""
 
-import re
 import os
+import json
 import shutil
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-from glob import glob
 from datetime import datetime
 
-from google.cloud import storage
+from git import Repo
 
 
-def get_args(MODELS):
+def get_args(base_models: dict):
     # Get arguments
     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
 
+    # Config file usage
+    parser.add_argument('--starting_config_name', default=None, type=str)
+
     # Optional
-    parser.add_argument('--tensorboard_path',
-                        help='Path to tensorboard', default='')
+    parser.add_argument('--tensorboard_path', help='Path to tensorboard', default='')
     parser.add_argument('--logging_path', help='Path for log file', default='')
-    parser.add_argument(
-        '--seed', help='Seed value for reproducability', default=0, type=int)
+    parser.add_argument('--seed', help='Seed value for reproducibility', default=0, type=int)
 
     # DATASET (Optional)
-    parser.add_argument(
-        '--dataset', help='Dataset Name: multiwoz21/simr', default='multiwoz21')
-    parser.add_argument('--shrink_active_domains', help='Shrink active domains to only well represented test set domains',
-                        action='store_true')
-    parser.add_argument(
-        '--data_dir', help='Data storage directory', default=None)
-    parser.add_argument(
-        '--max_dialogue_len', help='Maximum number of turns per dialogue', default=12, type=int)
-    parser.add_argument(
-        '--max_turn_len', help='Maximum number of tokens per turn', default=64, type=int)
-    parser.add_argument(
-        '--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int)
-    parser.add_argument('--max_candidate_len',
-                        help='Maximum number of tokens per value candidate', default=12, type=int)
-    parser.add_argument('--force_processing', action='store_true',
-                        help='Force preprocessing of data.')
-    parser.add_argument('--data_sampling_size',
-                        help='Resampled dataset size', default=-1, type=int)
-    parser.add_argument('--use_descriptions', help='Use slot descriptions rather than slot names for embeddings',
+    parser.add_argument('--dataset', help='Dataset Name (See Convlab 3 unified format for possible datasets',
+                        default='multiwoz21')
+    parser.add_argument('--dataset_train_ratio', help='Fraction of training set to use in training', default=1.0,
+                        type=float)
+    parser.add_argument('--max_dialogue_len', help='Maximum number of turns per dialogue', default=12, type=int)
+    parser.add_argument('--max_turn_len', help='Maximum number of tokens per turn', default=64, type=int)
+    parser.add_argument('--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int)
+    parser.add_argument('--max_candidate_len', help='Maximum number of tokens per value candidate', default=12,
+                        type=int)
+    parser.add_argument('--force_processing', action='store_true', help='Force preprocessing of data.')
+    parser.add_argument('--data_sampling_size', help='Resampled dataset size', default=-1, type=int)
+    parser.add_argument('--no_descriptions', help='Do not use slot descriptions rather than slot names for embeddings',
                         action='store_true')
 
     # MODEL
     # Environment
-    parser.add_argument(
-        '--output_dir', help='Output storage directory', default=None)
-    parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta',
-                        default='roberta')
-    parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.',
-                        default=None)
+    parser.add_argument('--output_dir', help='Output storage directory', default=None)
+    parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta', default='roberta')
+    parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.', default=None)
     parser.add_argument('--candidate_embedding_model_name', default=None,
                         help='Name of the pretrained candidate embedding model.')
+    parser.add_argument('--transformers_local_files_only', help='Use local files only for huggingface transformers',
+                        action='store_true')
 
     # Architecture
     parser.add_argument('--freeze_encoder', help='No training performed on the turn encoder Bert Model',
                         action='store_true')
     parser.add_argument('--slot_attention_heads', help='Number of attention heads for slot conditioning',
                         default=12, type=int)
-    parser.add_argument('--dropout_rate', help='Dropout Rate',
-                        default=0.3, type=float)
-    parser.add_argument(
-        '--nbt_type', help='Belief Tracker type: gru/lstm', default='gru')
+    parser.add_argument('--dropout_rate', help='Dropout Rate', default=0.3, type=float)
+    parser.add_argument('--nbt_type', help='Belief Tracker type: gru/lstm', default='gru')
     parser.add_argument('--nbt_hidden_size', help='Hidden embedding size for the Neural Belief Tracker',
                         default=300, type=int)
-    parser.add_argument(
-        '--nbt_layers', help='Number of RNN layers in the NBT', default=1, type=int)
-    parser.add_argument(
-        '--rnn_zero_init', help='Zero Initialise RNN hidden states', action='store_true')
+    parser.add_argument('--nbt_layers', help='Number of RNN layers in the NBT', default=1, type=int)
+    parser.add_argument('--rnn_zero_init', help='Zero Initialise RNN hidden states', action='store_true')
     parser.add_argument('--distance_measure', default='cosine',
                         help='Similarity measure for candidate scoring: cosine/euclidean')
-    parser.add_argument(
-        '--ensemble_size', help='Number of models in ensemble', default=-1, type=int)
-    parser.add_argument('--set_similarity', action='store_true',
-                        help='Set True to not use set similarity (Model tracks latent belief state as sequence and performs semantic similarity of sets)')
-    parser.add_argument('--set_pooling', help='Set pooling method for set similarity model using single embedding distances',
+    parser.add_argument('--ensemble_size', help='Number of models in ensemble', default=-1, type=int)
+    parser.add_argument('--no_set_similarity', action='store_true', help='Set True to not use set similarity')
+    parser.add_argument('--set_pooling',
+                        help='Set pooling method for set similarity model using single embedding distances',
                         default='cnn')
-    parser.add_argument('--candidate_pooling', help='Pooling approach for non set based candidate representations: cls/mean',
+    parser.add_argument('--candidate_pooling',
+                        help='Pooling approach for non set based candidate representations: cls/mean',
                         default='mean')
-    parser.add_argument('--predict_actions', help='Model predicts user actions and active domain',
+    parser.add_argument('--no_action_prediction', help='Model does not predicts user actions and active domain',
                         action='store_true')
 
     # Loss
-    parser.add_argument('--loss_function', help='Loss Function for training: crossentropy/bayesianmatching/labelsmoothing/distillation/distribution_distillation',
+    parser.add_argument('--loss_function',
+                        help='Loss Function for training: crossentropy/bayesianmatching/labelsmoothing/...',
                         default='labelsmoothing')
     parser.add_argument('--kl_scaling_factor', help='Scaling factor for KL divergence in bayesian matching loss',
                         type=float)
     parser.add_argument('--prior_constant', help='Constant parameter for prior in bayesian matching loss',
                         type=float)
-    parser.add_argument('--ensemble_smoothing',
-                        help='Ensemble distribution smoothing constant', type=float)
-    parser.add_argument('--annealing_base_temp', help='Ensemble Distribution destillation temp annealing base temp',
+    parser.add_argument('--ensemble_smoothing', help='Ensemble distribution smoothing constant', type=float)
+    parser.add_argument('--annealing_base_temp', help='Ensemble Distribution distillation temp annealing base temp',
+                        type=float)
+    parser.add_argument('--annealing_cycle_len', help='Ensemble Distribution distillation temp annealing cycle length',
                         type=float)
-    parser.add_argument('--annealing_cycle_len', help='Ensemble Distribution destillation temp annealing cycle length',
+    parser.add_argument('--label_smoothing', help='Label smoothing coefficient.', type=float)
+    parser.add_argument('--user_goal_loss_weight', help='Weight of the user goal prediction loss. 0.0<weight<=1.0',
                         type=float)
-    parser.add_argument('--inhibiting_factor',
-                        help='Inhibiting factor for Inhibited Softmax CE', type=float)
-    parser.add_argument('--label_smoothing',
-                        help='Label smoothing coefficient.', type=float)
-    parser.add_argument(
-        '--user_goal_loss_weight', help='Weight of the user goal prediction loss. 0.0<weight<=1.0', type=float)
-    parser.add_argument(
-        '--user_request_loss_weight', help='Weight of the user request prediction loss. 0.0<weight<=1.0', type=float)
-    parser.add_argument(
-        '--user_general_act_loss_weight', help='Weight of the user general act prediction loss. 0.0<weight<=1.0', type=float)
-    parser.add_argument(
-        '--active_domain_loss_weight', help='Weight of the active domain prediction loss. 0.0<weight<=1.0', type=float)
+    parser.add_argument('--user_request_loss_weight',
+                        help='Weight of the user request prediction loss. 0.0<weight<=1.0', type=float)
+    parser.add_argument('--user_general_act_loss_weight',
+                        help='Weight of the user general act prediction loss. 0.0<weight<=1.0', type=float)
+    parser.add_argument('--active_domain_loss_weight',
+                        help='Weight of the active domain prediction loss. 0.0<weight<=1.0', type=float)
 
     # TRAINING
-    parser.add_argument('--train_batch_size',
-                        help='Training Set Batch Size', default=4, type=int)
-    parser.add_argument('--max_training_steps', help='Maximum number of training update steps',
-                        default=-1, type=int)
+    parser.add_argument('--train_batch_size', help='Training Set Batch Size', default=8, type=int)
+    parser.add_argument('--max_training_steps', help='Maximum number of training update steps', default=-1, type=int)
     parser.add_argument('--gradient_accumulation_steps', default=1, type=int,
                         help='Number of batches accumulated for one update step')
-    parser.add_argument('--num_train_epochs',
-                        help='Number of training epochs', default=50, type=int)
+    parser.add_argument('--num_train_epochs', help='Number of training epochs', default=50, type=int)
     parser.add_argument('--patience', help='Number of training steps without improving model before stopping.',
-                        default=25, type=int)
-    parser.add_argument(
-        '--weight_decay', help='Weight decay rate', default=0.01, type=float)
-    parser.add_argument('--learning_rate',
-                        help='Initial Learning Rate', default=5e-5, type=float)
-    parser.add_argument('--warmup_proportion', help='Warmup proportion for linear scheduler',
-                        default=0.2, type=float)
-    parser.add_argument(
-        '--max_grad_norm', help='Maximum norm of the loss gradients', default=1.0, type=float)
-    parser.add_argument(
-        '--save_steps', help='Number of update steps between saving model', default=-1, type=int)
-    parser.add_argument(
-        '--keep_models', help='How many model checkpoints should be kept during training', default=1, type=int)
+                        default=20, type=int)
+    parser.add_argument('--weight_decay', help='Weight decay rate', default=0.01, type=float)
+    parser.add_argument('--learning_rate', help='Initial Learning Rate', default=5e-5, type=float)
+    parser.add_argument('--warmup_proportion', help='Warmup proportion for linear scheduler', default=0.2, type=float)
+    parser.add_argument('--max_grad_norm', help='Maximum norm of the loss gradients', default=1.0, type=float)
+    parser.add_argument('--save_steps', help='Number of update steps between saving model', default=-1, type=int)
+    parser.add_argument('--keep_models', help='How many model checkpoints should be kept during training',
+                        default=1, type=int)
 
     # CALIBRATION
-    parser.add_argument(
-        '--temp_scaling', help='Temperature scaling coefficient', default=1.0, type=float)
+    parser.add_argument('--temp_scaling', help='Temperature scaling coefficient', default=1.0, type=float)
 
     # EVALUATION
-    parser.add_argument('--dev_batch_size',
-                        help='Dev Set Batch Size', default=16, type=int)
-    parser.add_argument('--test_batch_size',
-                        help='Test Set Batch Size', default=16, type=int)
+    parser.add_argument('--dev_batch_size', help='Dev Set Batch Size', default=16, type=int)
+    parser.add_argument('--test_batch_size', help='Test Set Batch Size', default=16, type=int)
 
     # COMPUTING
-    parser.add_argument(
-        '--n_gpu', help='Number of GPUs to use', default=1, type=int)
+    parser.add_argument('--n_gpu', help='Number of GPUs to use', default=1, type=int)
     parser.add_argument('--fp16', action='store_true',
                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
     parser.add_argument('--fp16_opt_level', type=str, default='O1',
@@ -167,32 +140,35 @@ def get_args(MODELS):
                              "See details at https://nvidia.github.io/apex/amp.html")
 
     # ACTIONS
-    parser.add_argument('--run_nbt', help='Run NBT script',
-                        action='store_true')
-    parser.add_argument('--run_calibration',
-                        help='Run calibration', action='store_true')
+    parser.add_argument('--run_nbt', help='Run NBT script', action='store_true')
+    parser.add_argument('--run_evaluation', help='Run evaluation script', action='store_true')
 
     # RUN_NBT ACTIONS
-    parser.add_argument(
-        '--do_train', help='Perform training', action='store_true')
-    parser.add_argument(
-        '--do_eval', help='Perform model evaluation during training', action='store_true')
-    parser.add_argument(
-        '--do_test', help='Evaulate model on test data', action='store_true')
+    parser.add_argument('--do_train', help='Perform training', action='store_true')
+    parser.add_argument('--do_eval', help='Perform model evaluation during training', action='store_true')
+    parser.add_argument('--do_test', help='Evaluate model on test data', action='store_true')
     args = parser.parse_args()
 
-    # Setup default directories
-    if not args.data_dir:
-        args.data_dir = os.path.dirname(os.path.abspath(__file__))
-        args.data_dir = os.path.join(args.data_dir, 'data')
-        os.makedirs(args.data_dir, exist_ok=True)
+    if args.starting_config_name:
+        args = get_starting_config(args)
+
+    if args.do_train:
+        args.do_eval = True
+
+    # Simplify args
+    args.set_similarity = not args.no_set_similarity
+    args.use_descriptions = not args.no_descriptions
+    args.predict_actions = not args.no_action_prediction
 
+    # Setup default directories
     if not args.output_dir:
         args.output_dir = os.path.dirname(os.path.abspath(__file__))
         args.output_dir = os.path.join(args.output_dir, 'models')
 
-        name = 'SetSUMBT'
-        name += '-Acts' if args.predict_actions else ''
+        name = 'SetSUMBT' if args.set_similarity else 'SUMBT'
+        name += '+ActPrediction' if args.predict_actions else ''
+        name += '-' + args.dataset
+        name += '-' + str(round(args.dataset_train_ratio*100)) + '%' if args.dataset_train_ratio != 1.0 else ''
         name += '-' + args.model_type
         name += '-' + args.nbt_type
         name += '-' + args.distance_measure
@@ -208,9 +184,6 @@ def get_args(MODELS):
             args.kl_scaling_factor = 0.001
         if not args.prior_constant:
             args.prior_constant = 1.0
-    if args.loss_function == 'inhibitedce':
-        if not args.inhibiting_factor:
-            args.inhibiting_factor = 1.0
     if args.loss_function == 'labelsmoothing':
         if not args.label_smoothing:
             args.label_smoothing = 0.05
@@ -233,10 +206,8 @@ def get_args(MODELS):
         if not args.active_domain_loss_weight:
             args.active_domain_loss_weight = 0.2
 
-    args.tensorboard_path = args.tensorboard_path if args.tensorboard_path else os.path.join(
-        args.output_dir, 'tb_logs')
-    args.logging_path = args.logging_path if args.logging_path else os.path.join(
-        args.output_dir, 'run.log')
+    args.tensorboard_path = args.tensorboard_path if args.tensorboard_path else os.path.join(args.output_dir, 'tb_logs')
+    args.logging_path = args.logging_path if args.logging_path else os.path.join(args.output_dir, 'run.log')
 
     # Default model_name's
     if not args.model_name_or_path:
@@ -250,30 +221,62 @@ def get_args(MODELS):
     if not args.candidate_embedding_model_name:
         args.candidate_embedding_model_name = args.model_name_or_path
 
-    if args.model_type in MODELS:
-        configClass = MODELS[args.model_type][-2]
+    if args.model_type in base_models:
+        config_class = base_models[args.model_type][-2]
     else:
         raise NameError('NotImplemented')
-    config = build_config(configClass, args)
+    config = build_config(config_class, args)
     return args, config
 
 
-def build_config(configClass, args):
-    if args.model_type == 'fasttext':
-        config = configClass.from_pretrained('bert-base-uncased')
-        config.model_type == 'fasttext'
-        config.fasttext_path = args.model_name_or_path
-        config.vocab_size = None
-    elif not os.path.exists(args.model_name_or_path):
-        config = configClass.from_pretrained(args.model_name_or_path)
+def get_starting_config(args):
+    path = os.path.dirname(os.path.realpath(__file__))
+    path = os.path.join(path, 'configs', f"{args.starting_config_name}.json")
+    reader = open(path, 'r')
+    config = json.load(reader)
+    reader.close()
+
+    if "model_type" in config:
+        if config["model_type"].lower() == 'setsumbt':
+            config["model_type"] = 'roberta'
+            config["no_set_similarity"] = False
+            config["no_descriptions"] = False
+        elif config["model_type"].lower() == 'sumbt':
+            config["model_type"] = 'bert'
+            config["no_set_similarity"] = True
+            config["no_descriptions"] = False
+
+    variables = vars(args).keys()
+    for key, value in config.items():
+        if key in variables:
+            setattr(args, key, value)
+
+    return args
+
+
+def get_git_info():
+    repo = Repo(os.path.dirname(os.path.realpath(__file__)), search_parent_directories=True)
+    branch_name = repo.active_branch.name
+    commit_hex = repo.head.object.hexsha
+
+    info = f"{branch_name}/{commit_hex}"
+    return info
+
+
+def build_config(config_class, args):
+    config = config_class.from_pretrained(args.model_name_or_path)
+    config.code_version = get_git_info()
+    if not os.path.exists(args.model_name_or_path):
         config.tokenizer_name = args.model_name_or_path
-    elif 'tod-bert' in args.model_name_or_path.lower():
-        config = configClass.from_pretrained(args.model_name_or_path)
+    try:
+        config.tokenizer_name = config.tokenizer_name
+    except AttributeError:
         config.tokenizer_name = args.model_name_or_path
-    else:
-        config = configClass.from_pretrained(args.model_name_or_path)
-    if args.candidate_embedding_model_name:
-        config.candidate_embedding_model_name = args.candidate_embedding_model_name
+    try:
+        config.candidate_embedding_model_name = config.candidate_embedding_model_name
+    except:
+        if args.candidate_embedding_model_name:
+            config.candidate_embedding_model_name = args.candidate_embedding_model_name
     config.max_dialogue_len = args.max_dialogue_len
     config.max_turn_len = args.max_turn_len
     config.max_slot_len = args.max_slot_len
diff --git a/convlab/evaluator/multiwoz_eval.py b/convlab/evaluator/multiwoz_eval.py
index b7331479e2ea2ee1a1f5c4bce7aef82546dcbd2e..c89361b2a84824198e516941b71806440a9ba3a5 100755
--- a/convlab/evaluator/multiwoz_eval.py
+++ b/convlab/evaluator/multiwoz_eval.py
@@ -3,15 +3,32 @@
 import logging
 import re
 import numpy as np
+
 from copy import deepcopy
+from data.unified_datasets.multiwoz21.preprocess import reverse_da, reverse_da_slot_name_map
+from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
 from convlab.evaluator.evaluator import Evaluator
+from data.unified_datasets.multiwoz21.preprocess import reverse_da_slot_name_map
 from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple
 from convlab.util.multiwoz.dbquery import Database
-import os
 from convlab.util import relative_import_module_from_unified_datasets
 
+# import reflect table
+REF_SYS_DA_M = {}
+for dom, ref_slots in REF_SYS_DA.items():
+    dom = dom.lower()
+    REF_SYS_DA_M[dom] = {}
+    for slot_a, slot_b in ref_slots.items():
+        if slot_a == 'Ref':
+            slot_b = 'ref'
+        REF_SYS_DA_M[dom][slot_a.lower()] = slot_b
+    REF_SYS_DA_M[dom]['none'] = 'none'
+REF_SYS_DA_M['taxi']['phone'] = 'phone'
+REF_SYS_DA_M['taxi']['car'] = 'car type'
+
 reverse_da = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da')
 
+
 requestable = \
     {'attraction': ['post', 'phone', 'addr', 'fee', 'area', 'type'],
      'restaurant': ['addr', 'phone', 'post', 'ref', 'price', 'area', 'food'],
@@ -24,13 +41,13 @@ requestable = \
 belief_domains = requestable.keys()
 
 mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'name': 'name', 'phone': 'phone',
-                          'post': 'postcode', 'price': 'pricerange'},
+                          'post': 'postcode', 'price': 'pricerange', 'ref': 'ref'},
            'hotel': {'addr': 'address', 'area': 'area', 'internet': 'internet', 'parking': 'parking', 'name': 'name',
-                     'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type'},
+                     'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type', 'ref': 'ref'},
            'attraction': {'addr': 'address', 'area': 'area', 'fee': 'entrance fee', 'name': 'name', 'phone': 'phone',
                           'post': 'postcode', 'type': 'type'},
            'train': {'id': 'trainID', 'arrive': 'arriveBy', 'day': 'day', 'depart': 'departure', 'dest': 'destination',
-                     'time': 'duration', 'leave': 'leaveAt', 'ticket': 'price'},
+                     'time': 'duration', 'leave': 'leaveAt', 'ticket': 'price', 'ref': 'ref'},
            'taxi': {'car': 'car type', 'phone': 'phone'},
            'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'},
            'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}}
@@ -39,6 +56,32 @@ mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'na
 time_re = re.compile(r'^(([01]\d|2[0-4]):([0-5]\d)|24:00)$')
 NUL_VALUE = ["", "dont care", 'not mentioned',
              "don't care", "dontcare", "do n't care"]
+REF_SYS_DA_M = {}
+for dom, ref_slots in REF_SYS_DA.items():
+    dom = dom.lower()
+    REF_SYS_DA_M[dom] = {}
+    for slot_a, slot_b in ref_slots.items():
+        if slot_a == 'Ref':
+            slot_b = 'ref'
+        REF_SYS_DA_M[dom][slot_a.lower()] = slot_b
+    REF_SYS_DA_M[dom]['none'] = 'none'
+REF_SYS_DA_M['taxi']['phone'] = 'phone'
+REF_SYS_DA_M['taxi']['car'] = 'car type'
+DEF_VAL_UNK = '?'  # Unknown
+DEF_VAL_DNC = 'dontcare'  # Do not care
+DEF_VAL_NUL = 'none'  # for none
+DEF_VAL_BOOKED = 'yes'  # for booked
+DEF_VAL_NOBOOK = 'no'  # for booked
+
+NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK]
+
+# Not sure values in inform
+DEF_VAL_UNK = '?'  # Unknown
+DEF_VAL_DNC = 'dontcare'  # Do not care
+DEF_VAL_NUL = 'none'  # for none
+DEF_VAL_BOOKED = 'yes'  # for booked
+DEF_VAL_NOBOOK = 'no'  # for booked
+NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK]
 
 
 class MultiWozEvaluator(Evaluator):
@@ -56,7 +99,8 @@ class MultiWozEvaluator(Evaluator):
         self.success = 0
         self.success_strict = 0
         self.successful_domains = []
-        logging.info(f"We check booking constraints: {self.check_book_constraints}")
+        logging.info(
+            f"We check booking constraints: {self.check_book_constraints}")
 
     def _init_dict(self):
         dic = {}
@@ -93,12 +137,19 @@ class MultiWozEvaluator(Evaluator):
         """
         self.sys_da_array = []
         self.usr_da_array = []
-        self.goal = goal
+        self.goal = deepcopy(goal)
         self.cur_domain = ''
         self.booked = self._init_dict_booked()
         self.booked_states = self._init_dict_booked()
         self.successful_domains = []
 
+    @staticmethod
+    def _convert_action(act):
+        act = unified_format(act)
+        act = reverse_da(act)
+        act = act_dict_to_flat_tuple(act)
+        return act
+
     def add_sys_da(self, da_turn, belief_state=None):
         """add sys_da into array
 
@@ -107,11 +158,7 @@ class MultiWozEvaluator(Evaluator):
                 list[intent, domain, slot, value]
         """
 
-        sys_dialog_act = da_turn
-        sys_dialog_act = unified_format(sys_dialog_act)
-        sys_dialog_act = reverse_da(sys_dialog_act)
-        sys_dialog_act = act_dict_to_flat_tuple(sys_dialog_act)
-        da_turn = sys_dialog_act
+        da_turn = self._convert_action(da_turn)
 
         for intent, domain, slot, value in da_turn:
             dom_int = '-'.join([domain, intent])
@@ -131,12 +178,15 @@ class MultiWozEvaluator(Evaluator):
                 else:
                     if not self.booked[domain] and re.match(r'^\d{8}$', value) and \
                             len(self.dbs[domain]) > int(value):
-                        self.booked[domain] = self.dbs[domain][int(value)].copy()
+                        self.booked[domain] = self.dbs[domain][int(
+                            value)].copy()
                         self.booked[domain]['Ref'] = value
                         if belief_state is not None:
-                            self.booked_states[domain] = deepcopy(belief_state[domain])
+                            self.booked_states[domain] = deepcopy(
+                                belief_state[domain])
                         else:
                             self.booked_states[domain] = None
+        self.goal = self.update_goal(self.goal, da_turn)
 
     def add_usr_da(self, da_turn):
         """add usr_da into array
@@ -145,6 +195,7 @@ class MultiWozEvaluator(Evaluator):
             da_turn:
                 list[intent, domain, slot, value]
         """
+        da_turn = self._convert_action(da_turn)
         for intent, domain, slot, value in da_turn:
             dom_int = '-'.join([domain, intent])
             domain = dom_int.split('-')[0].lower()
@@ -384,7 +435,9 @@ class MultiWozEvaluator(Evaluator):
                     goal[d]['info'][mapping[d][s]] = v
                 elif i == 'request':
                     goal[d]['reqt'].append(s)
-        TP, FP, FN, _, _, _ = self._inform_F1_goal(goal, self.sys_da_array)
+
+        TP, FP, FN, bad_inform, reqt_not_inform, inform_not_reqt = self._inform_F1_goal(
+            goal, self.sys_da_array)
         if aggregate:
             try:
                 rec = TP / (TP + FN)
@@ -405,22 +458,22 @@ class MultiWozEvaluator(Evaluator):
         """
         booking_done = self.check_booking_done(ref2goal)
         book_sess = self.book_rate(ref2goal)
-        #book_constraint_sess = self.book_rate_constrains(ref2goal)
-        book_constraint_sess = 1
+        book_constraint_sess = self.book_rate_constrains(ref2goal)
         inform_sess = self.inform_F1(ref2goal)
         goal_sess = self.final_goal_analyze()
-        #goal_sess = 1
-        # book rate == 1 & inform recall == 1
+
         if ((book_sess == 1 and inform_sess[1] == 1)
             or (book_sess == 1 and inform_sess[1] is None)
             or (book_sess is None and inform_sess[1] == 1)) \
                 and goal_sess == 1:
             self.complete = 1
             self.success = 1
-            self.success_strict = 1 if (book_constraint_sess == 1 or book_constraint_sess is None) else 0
+            self.success_strict = 1 if (
+                book_constraint_sess == 1 or book_constraint_sess is None) else 0
             return self.success if not self.check_book_constraints else self.success_strict
         else:
-            self.complete = 1 if booking_done and (inform_sess[1] == 1 or inform_sess[1] is None) else 0
+            self.complete = 1 if booking_done and (
+                inform_sess[1] == 1 or inform_sess[1] is None) else 0
             self.success = 0
             self.success_strict = 0
             return 0
@@ -473,13 +526,16 @@ class MultiWozEvaluator(Evaluator):
                 elif i == 'request':
                     goal[d]['reqt'].append(s)
 
-        book_constraints = self._book_goal_constraints(goal, self.booked_states, [domain])
-        book_constraints = np.mean(book_constraints) if book_constraints else None
+        book_constraints = self._book_goal_constraints(
+            goal, self.booked_states, [domain])
+        book_constraints = np.mean(
+            book_constraints) if book_constraints else None
 
         book_rate = self._book_rate_goal(goal, self.booked, [domain])
         book_rate = np.mean(book_rate) if book_rate else None
         match, mismatch = self._final_goal_analyze_domain(domain)
-        goal_sess = 1 if (match == 0 and mismatch == 0) else match / (match + mismatch)
+        goal_sess = 1 if (match == 0 and mismatch ==
+                          0) else match / (match + mismatch)
 
         inform = self._inform_F1_goal(goal, self.sys_da_array, [domain])
         try:
@@ -488,9 +544,10 @@ class MultiWozEvaluator(Evaluator):
             inform_rec = None
 
         if ((book_rate == 1 and inform_rec == 1) or (book_rate == 1 and inform_rec is None) or
-            (book_rate is None and inform_rec == 1)) and goal_sess == 1:
+                (book_rate is None and inform_rec == 1)) and goal_sess == 1:
             domain_success = 1
-            domain_strict_success = 1 if (book_constraints == 1 or book_constraints is None) else 0
+            domain_strict_success = 1 if (
+                book_constraints == 1 or book_constraints is None) else 0
             return domain_success if not self.check_book_constraints else domain_strict_success
         else:
             return 0
@@ -514,7 +571,7 @@ class MultiWozEvaluator(Evaluator):
         else:
             info_constraints = []
         query_result = self.database.query(
-            domain, info_constraints, soft_contraints=reqt_constraints)
+            domain, info_constraints + reqt_constraints)
         if not query_result:
             mismatch += 1
 
@@ -547,7 +604,7 @@ class MultiWozEvaluator(Evaluator):
             else:
                 info_constraints = []
             query_result = self.database.query(
-                domain, info_constraints, soft_contraints=reqt_constraints)
+                domain, info_constraints + reqt_constraints)
             if not query_result:
                 mismatch += 1
                 continue
@@ -593,3 +650,32 @@ class MultiWozEvaluator(Evaluator):
                     self.successful_domains.append(self.cur_domain)
 
         return reward
+
+    def evaluate_dialog(self, goal, user_acts, system_acts, system_states):
+
+        self.add_goal(goal.domain_goals)
+        for sys_act, sys_state, user_act in zip(system_acts, system_states, user_acts):
+            self.add_sys_da(sys_act, sys_state)
+            self.add_usr_da(user_act)
+        self.task_success()
+        return {"complete": self.complete, "success": self.success, "success_strict": self.success_strict}
+
+    def update_goal(self, goal, system_action):
+        for intent, domain, slot, val in system_action:
+            # need to reverse slot to old representation
+            if slot in reverse_da_slot_name_map:
+                slot = reverse_da_slot_name_map[slot]
+            elif domain in reverse_da_slot_name_map and slot in reverse_da_slot_name_map[domain]:
+                slot = reverse_da_slot_name_map[domain][slot]
+            else:
+                slot = slot.capitalize()
+            if intent.lower() in ['inform', 'recommend']:
+                if domain.lower() in goal:
+                    if 'reqt' in goal[domain.lower()]:
+                        if REF_SYS_DA_M.get(domain.lower(), {}).get(slot.lower(), slot.lower()) \
+                                in goal[domain.lower()]['reqt']:
+                            if val in NOT_SURE_VALS:
+                                val = '\"' + val + '\"'
+                            goal[domain.lower()]['reqt'][
+                                REF_SYS_DA_M.get(domain.lower(), {}).get(slot.lower(), slot.lower())] = val
+        return goal
diff --git a/convlab/policy/README.md b/convlab/policy/README.md
index 1990cdd6b03a38fe6f5a4f8b4eb8a9708761c590..233ea0eece10d0789791e3a97c5d1aa8311f232a 100755
--- a/convlab/policy/README.md
+++ b/convlab/policy/README.md
@@ -1,36 +1,199 @@
 # Dialog Policy
 
-In the pipeline task-oriented dialog framework, the dialog policy module
-takes as input the dialog state, and chooses the system action bases on
-it.
-
-This directory contains the interface definition of dialog policy
+In the pipeline task-oriented dialog framework, the dialogue policy module
+takes as input the dialog state, and chooses the system action based on
+it. This directory contains the interface definition of dialogue policy
 module for both system side and user simulator side, as well as some
-implementations under different sub-directories.
+implementations under different sub-directories. 
+
+An important additional module for the policy is the vectoriser which translates the dialogue state into a vectorised form that the dialogue policy network expects as input. 
+Moreover, it translates the vectorised act that the policy took back into semantic form. More information can be found in the directory /convlab/policy/vector.
+
+We currently maintain the following policies:
+
+**system policies**: GDPL, MLE, PG, PPO and VTRACE DPT
+
+**user policies**: rule, TUS, GenTUS
+
+
+## Overview
+
+Every policy directory typically has two python scripts, **1) train.py** for running an RL training and **2) a dedicated script** that implements the algorithm and loads the policy network (e.g. **ppo.py** for the ppo policy).
+Moreover, two config files define the environment and the hyper parameters for the algorithm:
+
+- config.json: defines the hyper parameters such as learning rate and algorithm related parameters
+- environment.json: defines the learning environment (MDP) for the policy. This includes the NLU, DST and NLG component for both system and user policy as well as which user policy should be used. It also defines the number of total training dialogues as well as evaluation dialogues and frequency.
+
+An example for the environment.json is the **semantic_level_config.json** in the policy subfolders.
+
+
+
+## Workflow
+
+The workflow can be generally decomposed into three steps that will be explained in more detail below:
+
+1. set up the environment configuration and policy parameters
+2. run a reinforcement learning training with the given configurations
+3. evaluate your trained models
+
+#### Set up the environment 
+
+The necessary step before starting a training is to set up the environment and policy parameters. Information about policy parameters can be found in each policy subfolder. The following example defines an environment for the policy with the rule-based dialogue state tracker, no NLU, no NLG, and the rule-based user simulator:
+
+```
+{
+	"model": {
+		"load_path": "", # specify a loading path to load a pre-trained model, omit the ending .pol.mdl
+		"use_pretrained_initialisation": false, # will download a provided ConvLab-3 model
+		"pretrained_load_path": "",
+		"seed": 0, # the seed for the experiment
+		"eval_frequency": 5, # how often evaluation should take place
+		"process_num": 4, # how many processes the evaluation should use for speed up
+		"sys_semantic_to_usr": false,
+		"num_eval_dialogues": 500 # how many dialogues should be used for evaluation
+	},
+	"vectorizer_sys": {
+		"uncertainty_vector_mul": {
+			"class_path": "convlab.policy.vector.vector_binary.VectorBinary",
+			"ini_params": {
+				"use_masking": true,
+				"manually_add_entity_names": false,
+				"seed": 0
+			}
+		}
+	},
+	"nlu_sys": {},
+	"dst_sys": {
+		"RuleDST": {
+			"class_path": "convlab.dst.rule.multiwoz.dst.RuleDST",
+			"ini_params": {}
+		}
+	},
+	"sys_nlg": {},
+	"nlu_usr": {},
+	"dst_usr": {},
+	"policy_usr": {
+		"RulePolicy": {
+			"class_path": "convlab.policy.rule.multiwoz.RulePolicy",
+			"ini_params": {
+				"character": "usr"
+			}
+		}
+	},
+	"usr_nlg": {}
+}
+```
+
+#### Executing a training
+
+Once you set up your configuration, you are ready to start an experiment by executing
+
+```sh
+$ python convlab/policy/policy_subfolder/train.py --path=your_environment_config --seed=your_seed
+```
+
+You can specify the seed either in the environment config or through the argument parser. If you do not specify an environment config, it will automatically load the default config. 
+
+Once the training has started, it will automatically generate an **experiment** folder and a corresponding experiment-TIMESTAMP folder in it. Inside that folder there are four subfolders (configs, logs, save and TB_summary):
+
+- **configs**: contains information about which config was used
+- **logs**: will save information created by a logger during training
+- **save**: a folder for saving model checkpoints
+- **TB_summary**: saves a tensorboard summary that will be later used for plotting graphs
+
+Once the training has finished, it will move the experiment-TIMESTAMP folder into the **finished_experiments** folder.
+
+#### Evaluating your models
+
+The evaluation tools can be found in the folder convlab/policy/plot_results. Please have a look in the README for detailed instructions. 
+
+#### Running Evaluation Dialogues
+
+You can run evaluation dialogues with a trained model using 
+
+```sh
+$ python convlab/policy/evaluate.py --model_name=NAME --config_path=PATH --num_dialogues=NUM --verbose
+```
+
+- model_name: specify which model is used, i.e. MLE, PPO, PG, DDPT
+- config_path: specify the config-path that was used during RL training, for instance semantic_level_config.json
+- num_dialogues: number of evaluation dialogues
+- verbose: optional. If set, the dialogues are printed to the terminal console together with their goals, which helps in analysing the behaviour of the policy.
+
+## Adding a new policy
+
+If you would like to add a new policy, start by creating a subfolder for it. Then make sure that it contains the four files mentioned in the **Overview** section.
+
+#### Algorithm script
+
+Here you define your algorithm and policy network. Please ensure that you also load a vectoriser here that is inherited from the vector/vector_base.py class. 
+
+In addition, your policy module is required to have a **predict** method where the skeleton usually looks something like:
+
+    def predict(self, state):
+        """
+        Predict a system action given state.
+        Args:
+            state (dict): Dialog state. Please refer to util/state.py
+        Returns:
+            action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, ...})
+        """
+        
+        # uses the vector class for vectorisation of the dialogue state and also creates an action mask
+        s, action_mask = self.vector.state_vectorize(state) 
+        s_vec = torch.Tensor(s)
+        mask_vec = torch.Tensor(action_mask)
+        
+        # predict an action using the policy network
+        a = self.policy.select_action(s_vec, mask_vec)
+
+        # map the action indices back to semantic actions using the vectoriser
+        action = self.vector.action_devectorize(a.detach().numpy())
+        return action
+
+#### train.py script
 
-## Interface
+The train.py script is responsible for several different functions. In the following we will provide some code or pointers on how to do these steps. Have a look at the train.py files as well.
 
-The interfaces for dialog policy are defined in policy.Policy:
+1. load the config and set seed
+    ```
+    environment_config = load_config_file(path)
+    conf = get_config(path, args)
+    seed = conf['model']['seed']
+    set_seed(seed)
+   save_config(vars(parser.parse_args()), environment_config, config_save_path)
+    ```
 
-- **predict** takes as input agent state (often the state tracked by DST)
-and outputs the next system action.
+2. save additional information (through a logger and tensorboard writer)
 
-- **init_session** reset the model variables for a new dialog session.
+    ```
+    logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
+        init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
+   ```   
+   
+   
+3. load the policy module
 
-## Rule based simulator results
+    ```
+    policy_sys = PPO(True, seed=conf['model']['seed'], vectorizer=conf['vectorizer_sys_activated'])
+    ```
+4. load the environment using the environment-config
+    ```
+   env, sess = env_config(conf, policy_sys)
+   ```
 
-| Model | Complete rate | Success rate | Average return | Turns | Average actions |
-|-------|---------------|--------------|----------------|-------|-----------------|
-| MLE   |               |              |                |       |                 |
-| PG    |               |              |                |       |                 |
-| GDPL  |               |              |                |       |                 |
-| PPO   |               |              |                |       |                 |
+5. collect dialogues and execute policy updates: use the update function of policy_sys and implement a create_episodes function.
+6. run evaluation during training and save policy checkpoints
 
-## Transformer based user simulator (TUS) results
+    ```
+    logging.info(f"Evaluating after Dialogues: {num_dialogues} - {time_now}" + '-' * 60)
+    eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path)
+    best_complete_rate, best_success_rate, best_return = \
+        save_best(policy_sys, best_complete_rate, best_success_rate, best_return,
+                  eval_dict["complete_rate"], eval_dict["success_rate_strict"],
+                  eval_dict["avg_return"], save_path)
+    policy_sys.save(save_path, "last")
+    for key in eval_dict:
+        tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['batchsz'])
+    ```
 
-| Model | Complete rate | Success rate | Average return | Turns | Average actions |
-|-------|---------------|--------------|----------------|-------|-----------------|
-| MLE   |               |              |                |       |                 |
-| PG    |               |              |                |       |                 |
-| GDPL  |               |              |                |       |                 |
-| PPO   |               |              |                |       |                 |
diff --git a/convlab/policy/evaluate.py b/convlab/policy/evaluate.py
index 78682e6ced2ccaa01e769c78c3388ad19c78c0c9..7a692261869f35e587c34a26e425d1489abdcf56 100755
--- a/convlab/policy/evaluate.py
+++ b/convlab/policy/evaluate.py
@@ -12,7 +12,8 @@ from convlab.dialog_agent.session import BiSession
 from convlab.evaluator.multiwoz_eval import MultiWozEvaluator
 from convlab.policy.rule.multiwoz import RulePolicy
 from convlab.task.multiwoz.goal_generator import GoalGenerator
-from convlab.util.custom_util import set_seed, get_config, env_config, create_goals
+from convlab.util.custom_util import set_seed, get_config, env_config, create_goals, data_goals
+from tqdm import tqdm
 
 
 def init_logging(log_dir_path, path_suffix=None):
@@ -36,7 +37,7 @@ def init_logging(log_dir_path, path_suffix=None):
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
-def evaluate(config_path, model_name, verbose=False):
+def evaluate(config_path, model_name, verbose=False, model_path="", goals_from_data=False, dialogues=500):
     seed = 0
     set_seed(seed)
 
@@ -56,9 +57,15 @@ def evaluate(config_path, model_name, verbose=False):
     elif model_name == "GDPL":
         from convlab.policy.gdpl import GDPL
         policy_sys = GDPL(vectorizer=conf['vectorizer_sys_activated'])
+    elif model_name == "DDPT":
+        from convlab.policy.vtrace_DPT import VTRACE
+        policy_sys = VTRACE(is_train=False, vectorizer=conf['vectorizer_sys_activated'])
 
     try:
-        policy_sys.load(conf['model']['load_path'])
+        if model_path:
+            policy_sys.load(model_path)
+        else:
+            policy_sys.load(conf['model']['load_path'])
     except Exception as e:
         logging.info(f"Could not load a policy: {e}")
 
@@ -68,11 +75,16 @@ def evaluate(config_path, model_name, verbose=False):
     task_success = {'Complete': [], 'Success': [],
                     'Success strict': [], 'total_return': [], 'turns': []}
 
-    dialogues = 500
     goal_generator = GoalGenerator()
-    goals = create_goals(goal_generator, num_goals=dialogues, single_domains=False, allowed_domains=None)
+    if goals_from_data:
+        logging.info("read goals from dataset...")
+        goals = data_goals(dialogues, dataset="multiwoz21", dial_ids_order=0)
+    else:
+        logging.info("create goals from goal_generator...")
+        goals = create_goals(goal_generator, num_goals=dialogues,
+                             single_domains=False, allowed_domains=None)
 
-    for seed in range(1000, 1000 + dialogues):
+    for seed in tqdm(range(1000, 1000 + dialogues)):
         set_seed(seed)
         sess.init_session(goal=goals[seed-1000])
         sys_response = []
@@ -113,7 +125,10 @@ def evaluate(config_path, model_name, verbose=False):
                 task_succ = sess.evaluator.task_success()
                 task_succ = sess.evaluator.success
                 task_succ_strict = sess.evaluator.success_strict
-                complete = sess.evaluator.complete
+                if goals_from_data:
+                    complete = sess.user_agent.policy.policy.goal.task_complete()
+                else:
+                    complete = sess.evaluator.complete
                 break
 
         if verbose:
@@ -139,17 +154,29 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_name", type=str,
                         default="PPO", help="name of model")
-    parser.add_argument("--config_path", type=str,
-                        default='', help="path of model")
-    parser.add_argument("--verbose", action='store_true',
+    parser.add_argument("-C", "--config_path", type=str,
+                        default='', help="config path defining the environment for simulation and system pipeline")
+    parser.add_argument("--model_path", type=str,
+                        default='', help="if this is set, tries to load the model weights from this path"
+                                         ", otherwise from config")
+    parser.add_argument("-N", "--num_dialogues", type=int,
+                        default=500, help="# of evaluation dialogue")
+    parser.add_argument("-V", "--verbose", action='store_true',
                         help="whether to output utterances")
     parser.add_argument("--log_path_suffix", type=str,
                         default="", help="suffix of path of log file")
     parser.add_argument("--log_dir_path", type=str,
                         default="log", help="path of log directory")
+    parser.add_argument("-D", "--goals_from_data", action='store_true',
+                        help="load goal from the dataset")
 
     args = parser.parse_args()
 
     init_logging(log_dir_path=args.log_dir_path,
                  path_suffix=args.log_path_suffix)
-    evaluate(config_path=args.config_path, model_name=args.model_name, verbose=args.verbose)
+    evaluate(config_path=args.config_path,
+             model_name=args.model_name,
+             verbose=args.verbose,
+             model_path=args.model_path,
+             goals_from_data=args.goals_from_data,
+             dialogues=args.num_dialogues)
diff --git a/convlab/policy/evaluate_distributed.py b/convlab/policy/evaluate_distributed.py
index 2b362d865880d21a220db467e1de5e1c20eafb2e..1f7b3ffe93c040e6e18aa0ccd88b8c21fbcd178c 100644
--- a/convlab/policy/evaluate_distributed.py
+++ b/convlab/policy/evaluate_distributed.py
@@ -2,17 +2,13 @@
 
 import random
 import torch
-import sys
-import torch
-from pprint import pprint
-
-import matplotlib.pyplot as plt
 import numpy as np
-from convlab.policy.rlmodule import Memory_evaluator, Transition
+
+from convlab.policy.rlmodule import Memory_evaluator
 from torch import multiprocessing as mp
 
 
-def sampler(pid, queue, evt, sess, seed_range):
+def sampler(pid, queue, evt, sess, seed_range, goals):
     """
     This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
     processes.
@@ -31,7 +27,8 @@ def sampler(pid, queue, evt, sess, seed_range):
         torch.cuda.manual_seed(seed)
         random.seed(seed)
         np.random.seed(seed)
-        sess.init_session()
+        goal = goals.pop()
+        sess.init_session(goal=goal)
         sys_response = '' if sess.sys_agent.nlg is not None else []
         sys_response = [] if sess.sys_agent.return_semantic_acts else sys_response
         total_return_success = 0.0
@@ -46,6 +43,7 @@ def sampler(pid, queue, evt, sess, seed_range):
         request = 0
         select = 0
         offer = 0
+        recommend = 0
         task_success = {}
 
         for i in range(40):
@@ -70,6 +68,8 @@ def sampler(pid, queue, evt, sess, seed_range):
                     select += 1
                 if intent.lower() == 'offerbook':
                     offer += 1
+                if intent.lower() == 'recommend':
+                    recommend += 1
 
             if session_over is True:
                 success = sess.evaluator.task_success()
@@ -84,7 +84,7 @@ def sampler(pid, queue, evt, sess, seed_range):
             task_success[key].append(success_strict)
 
         buff.push(complete, success, success_strict, total_return_complete, total_return_success, turns, avg_actions / turns,
-                  task_success, book, inform, request, select, offer)
+                  task_success, book, inform, request, select, offer, recommend)
 
     # this is end of sampling all batchsz of items.
     # when sampling is over, push all buff data into queue
@@ -92,7 +92,7 @@ def sampler(pid, queue, evt, sess, seed_range):
     evt.wait()
 
 
-def sample(sess, seedrange, process_num):
+def sample(sess, seedrange, process_num, goals):
     """
     Given batchsz number of task, the batchsz will be splited equally to each processes
     and when processes return, it merge all data and return
@@ -112,7 +112,8 @@ def sample(sess, seedrange, process_num):
     processes = []
     for i in range(process_num):
         process_args = (
-            i, queue, evt, sess, seedrange[i * num_seeds_per_thread: (i+1) * num_seeds_per_thread])
+            i, queue, evt, sess, seedrange[i * num_seeds_per_thread: (i+1) * num_seeds_per_thread],
+            goals[i * num_seeds_per_thread: (i+1) * num_seeds_per_thread])
         processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stoped.
@@ -132,13 +133,13 @@ def sample(sess, seedrange, process_num):
     return buff.get_batch()
 
 
-def evaluate_distributed(sess, seed_range, process_num):
+def evaluate_distributed(sess, seed_range, process_num, goals):
 
-    batch = sample(sess, seed_range, process_num)
-    return np.average(batch.complete), np.average(batch.success), np.average(batch.success_strict), \
-           np.average(batch.total_return_success), np.average(batch.turns), np.average(batch.avg_actions), \
-           batch.task_success, np.average(batch.book_actions), np.average(batch.inform_actions), np.average(batch.request_actions), \
-           np.average(batch.select_actions), np.average(batch.offer_actions)
+    batch = sample(sess, seed_range, process_num, goals)
+    return batch.complete, batch.success, batch.success_strict, batch.total_return_success, batch.turns, \
+           batch.avg_actions, batch.task_success, np.average(batch.book_actions), np.average(batch.inform_actions), \
+           np.average(batch.request_actions), np.average(batch.select_actions), np.average(batch.offer_actions), \
+           np.average(batch.recommend_actions)
 
 
 if __name__ == "__main__":
diff --git a/convlab/policy/gdpl/README.md b/convlab/policy/gdpl/README.md
index e9e62e96a733fa5a7e0e953df6294e97ae4499f7..7f7e4939c8fd9a2cac010885f9306c3601a4de82 100755
--- a/convlab/policy/gdpl/README.md
+++ b/convlab/policy/gdpl/README.md
@@ -1,33 +1,53 @@
-# GDPL
+# Guided Dialogue Policy Learning (GDPL)
 
-A join policy optimization and reward estimation method using adversarial inverse reinforcement learning that learns a dialog policy and builds a reward estimator simultaneously. The reward estimator evaluates the state-action pairs to guide the dialog policy at each dialog turn.
+GDPL uses the PPO algorithm to optimize the policy. The difference to vanilla PPO is that it is not using the extrinsic reward for optimization but leverages inverse reinforcement learning to train a reward estimator. This reward estimator provides the reward that should be optimized.
 
-## Train
+## Supervised pre-training
 
-Run `train.py` in the `gdpl` directory:
+If you want to obtain a supervised model for pre-training, please have a look in the MLE policy folder.
 
-```bash
-python train.py
+## RL training
+
+Starting an RL training run is as easy as executing
+
+```sh
+$ python train.py --path=your_environment_config --seed=SEED
 ```
 
-For better performance, we can do immitating learning before reinforcement learning. The immitating learning is implemented in the `mle` directory.
+One example for the environment-config is **semantic_level_config.json**, where parameters for the training are specified, for instance
 
-For example, if the trained model of immitating learning is saved at FOLDER_OF_MODEL/best_mle.pol.mdl, then you can run
+- load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl
+- process_num: the number of processes to use during evaluation to speed it up
+- num_eval_dialogues: how many evaluation dialogues should be used
+- epoch: how many training epochs to run. One epoch consists of collecting dialogues + performing an update
+- eval_frequency: after how many epochs an evaluation is performed
+- batchsz: the number of training dialogues collected before doing an update
 
-```bash
-python train.py --load_path FOLDER_OF_MODEL/best_mle
-```
+Moreover, you can specify the full dialogue pipeline here, such as the user policy, NLU for system and user, etc.
+
+Parameters that are tied to the RL algorithm and the model architecture can be changed in config.json.
+
+
+## Evaluation
 
-Note that the *.pol.mdl* suffix should not appear in the --load_path argument.
+For creating evaluation plots and running evaluation dialogues, please have a look in the README of the policy folder.
 
-## Reference
+## References
 
 ```
-@inproceedings{takanobu2019guided,
-  title={Guided Dialog Policy Learning: Reward Estimation for Multi-Domain Task-Oriented Dialog},
-  author={Takanobu, Ryuichi and Zhu, Hanlin and Huang, Minlie},
-  booktitle={EMNLP-IJCNLP},
-  pages={100--110},
-  year={2019}
+@inproceedings{takanobu-etal-2019-guided,
+    title = "Guided Dialog Policy Learning: Reward Estimation for Multi-Domain Task-Oriented Dialog",
+    author = "Takanobu, Ryuichi  and
+      Zhu, Hanlin  and
+      Huang, Minlie",
+    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
+    month = nov,
+    year = "2019",
+    address = "Hong Kong, China",
+    publisher = "Association for Computational Linguistics",
+    url = "https://aclanthology.org/D19-1010",
+    doi = "10.18653/v1/D19-1010",
+    pages = "100--110",
+    abstract = "Dialog policy decides what and how a task-oriented dialog system will respond, and plays a vital role in delivering effective conversations. Many studies apply Reinforcement Learning to learn a dialog policy with the reward function which requires elaborate design and pre-specified user goals. With the growing needs to handle complex goals across multiple domains, such manually designed reward functions are not affordable to deal with the complexity of real-world tasks. To this end, we propose Guided Dialog Policy Learning, a novel algorithm based on Adversarial Inverse Reinforcement Learning for joint reward estimation and policy optimization in multi-domain task-oriented dialog. The proposed approach estimates the reward signal and infers the user goal in the dialog sessions. The reward estimator evaluates the state-action pairs so that it can guide the dialog policy at each dialog turn. Extensive experiments on a multi-domain dialog dataset show that the dialog policy guided by the learned reward function achieves remarkably higher task success than state-of-the-art baselines.",
 }
 ```
\ No newline at end of file
diff --git a/convlab/policy/gdpl/semantic_level_config.json b/convlab/policy/gdpl/semantic_level_config.json
index e64159c7d9eb5ff9e5c910f9cd27786cfa2d16be..0dcd46662620371c7647b05d60df0402e3737f0a 100644
--- a/convlab/policy/gdpl/semantic_level_config.json
+++ b/convlab/policy/gdpl/semantic_level_config.json
@@ -5,7 +5,7 @@
 		"pretrained_load_path": "",
 		"batchsz": 1000,
 		"seed": 0,
-		"epoch": 50,
+		"epoch": 10,
 		"eval_frequency": 5,
 		"process_num": 4,
 		"sys_semantic_to_usr": false,
diff --git a/convlab/policy/gdpl/train.py b/convlab/policy/gdpl/train.py
index 3d560b5650d5eb87e02334a42d62b4bdfd097279..15e9b7d08812a53c917b7d9f2f9e6aa60bfa97c7 100755
--- a/convlab/policy/gdpl/train.py
+++ b/convlab/policy/gdpl/train.py
@@ -186,7 +186,7 @@ if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--path", type=str, default='convlab/policy/gdpl/semantic_level_config.json',
                         help="Load path for config file")
-    parser.add_argument("--seed", type=int, default=0,
+    parser.add_argument("--seed", type=int, default=None,
                         help="Seed for the policy parameter initialization")
     parser.add_argument("--pretrain", action='store_true', help="whether to pretrain the reward estimator")
     parser.add_argument("--mode", type=str, default='info',
@@ -202,7 +202,7 @@ if __name__ == '__main__':
     logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
         init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
 
-    args = [('model', 'seed', seed)]
+    args = [('model', 'seed', seed)] if seed is not None else list()
 
     environment_config = load_config_file(path)
     save_config(vars(parser.parse_args()), environment_config, config_save_path)
@@ -265,7 +265,7 @@ if __name__ == '__main__':
 
         if idx % conf['model']['eval_frequency'] == 0 and idx != 0:
             time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
-            logging.info(f"Evaluating at Epoch: {idx} - {time_now}" + '-'*60)
+            logging.info(f"Evaluating after Dialogues: {idx * conf['model']['batchsz']} - {time_now}" + '-' * 60)
 
             eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path)
 
diff --git a/convlab/policy/genTUS/README.md b/convlab/policy/genTUS/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f1a3687e6d7cbf2c2ae95666ec0300a597b91ce5
--- /dev/null
+++ b/convlab/policy/genTUS/README.md
@@ -0,0 +1,72 @@
+**GenTUS** is a data-driven user simulator with transformers, which can generate semantic actions and utterances. It is able to transfer to a new ontology in a zero-shot fashion.
+
+## Introduction
+We propose a generative transformer-based user simulator (GenTUS) in this work. GenTUS consists of an encoder-decoder structure, which can optimise both the user policy and natural language generation jointly. GenTUS generates semantic actions and natural language utterances, preserving interpretability and enhancing language variation.
+
+The code of GenTUS is in `convlab/policy/genTUS`.
+
+## Usage
+### Train GenTUS from scratch
+You need to generate the input files by `build_data.py`, then train the model by `train_model.py`.
+```
+python3 convlab/policy/genTUS/unify/build_data.py --dataset $dataset --add-history --dial-ids-order $dial_ids_order --split2ratio $split2ratio
+python3 convlab/policy/genTUS/train_model.py --data-name $dataset --dial-ids-order $dial_ids_order --split2ratio $split2ratio --batch-size 8
+```
+
+`dataset` can be `multiwoz21`, `sgd`, `tm`, `sgd+tm`, or `all`.
+`dial_ids_order` can be 0, 1 or 2
+`split2ratio` can be 0.01, 0.1 or 1
+
+The `build_data.py` script will generate three files, `train.json`, `validation.json`, and `test.json`, under the folder `convlab/policy/genTUS/unify/data/${dataset}_${dial_ids_order}_${split2ratio}`.
+We trained GenTUS on A100 or RTX6000.
+
+### Evaluate GenTUS
+```
+python3 convlab/policy/genTUS/evaluate.py --model-checkpoint $model_checkpoint --input-file $in_file --dataset $dataset --do-nlg
+```
+The `in_file` is the file generated by `build_data.py`.
+
+### Train a dialogue policy with GenTUS
+You can use it as a normal user simulator by `PipelineAgent`. For example,
+```python
+from convlab.dialog_agent import PipelineAgent
+from convlab.util.custom_util import set_seed
+
+model_checkpoint = 'convlab/policy/genTUS/unify/experiments/multiwoz21-exp'
+usr_policy = UserPolicy(model_checkpoint, mode="semantic")
+simulator = PipelineAgent(None, None, usr_policy, None, 'user')
+```
+then you can train your system with this simulator.
+
+You can also change the `mode` to `"language"`; GenTUS will then respond in natural language instead of semantic actions.
+
+
+<!---citation--->
+## Citing
+
+```
+@inproceedings{lin-etal-2022-gentus,
+    title = "{G}en{TUS}: Simulating User Behaviour and Language in Task-oriented Dialogues with Generative Transformers",
+    author = "Lin, Hsien-chin  and
+      Geishauser, Christian  and
+      Feng, Shutong  and
+      Lubis, Nurul  and
+      van Niekerk, Carel  and
+      Heck, Michael  and
+      Gasic, Milica",
+    booktitle = "Proceedings of the 23rd Annual Meeting of the Special Interest Group on Discourse and Dialogue",
+    month = sep,
+    year = "2022",
+    address = "Edinburgh, UK",
+    publisher = "Association for Computational Linguistics",
+    url = "https://aclanthology.org/2022.sigdial-1.28",
+    pages = "270--282",
+    abstract = "User simulators (USs) are commonly used to train task-oriented dialogue systems via reinforcement learning. The interactions often take place on semantic level for efficiency, but there is still a gap from semantic actions to natural language, which causes a mismatch between training and deployment environment. Incorporating a natural language generation (NLG) module with USs during training can partly deal with this problem. However, since the policy and NLG of USs are optimised separately, these simulated user utterances may not be natural enough in a given context. In this work, we propose a generative transformer-based user simulator (GenTUS). GenTUS consists of an encoder-decoder structure, which means it can optimise both the user policy and natural language generation jointly. GenTUS generates both semantic actions and natural language utterances, preserving interpretability and enhancing language variation. In addition, by representing the inputs and outputs as word sequences and by using a large pre-trained language model we can achieve generalisability in feature representation. We evaluate GenTUS with automatic metrics and human evaluation. Our results show that GenTUS generates more natural language and is able to transfer to an unseen ontology in a zero-shot fashion. In addition, its behaviour can be further shaped with reinforcement learning opening the door to training specialised user simulators.",
+}
+
+
+```
+
+## License
+
+Apache License 2.0
diff --git a/convlab/policy/genTUS/evaluate.py b/convlab/policy/genTUS/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..87de854970d2701900ba180d2bf15736071e0c1a
--- /dev/null
+++ b/convlab/policy/genTUS/evaluate.py
@@ -0,0 +1,257 @@
+import json
+import os
+import sys
+from argparse import ArgumentParser
+from pprint import pprint
+
+import torch
+from convlab.nlg.evaluate import fine_SER
+from datasets import load_metric
+
+# from convlab.policy.genTUS.pg.stepGenTUSagent import \
+#     stepGenTUSPG as UserPolicy
+from convlab.policy.genTUS.stepGenTUS import UserActionPolicy
+from tqdm import tqdm
+
+sys.path.append(os.path.dirname(os.path.dirname(
+    os.path.dirname(os.path.abspath(__file__)))))
+
+
+def arg_parser():
+    parser = ArgumentParser()
+    parser.add_argument("--model-checkpoint", type=str, help="the model path")
+    parser.add_argument("--model-weight", type=str,
+                        help="the model weight", default="")
+    parser.add_argument("--input-file", type=str, help="the testing input file",
+                        default="")
+    parser.add_argument("--generated-file", type=str, help="the generated results",
+                        default="")
+    parser.add_argument("--only-action", action="store_true")
+    parser.add_argument("--dataset", default="multiwoz")
+    parser.add_argument("--do-semantic", action="store_true",
+                        help="do semantic evaluation")
+    parser.add_argument("--do-nlg", action="store_true",
+                        help="do nlg generation")
+    parser.add_argument("--do-golden-nlg", action="store_true",
+                        help="do golden nlg generation")
+    return parser.parse_args()
+
+
+class Evaluator:
+    def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False):
+        self.dataset = dataset
+        self.model_checkpoint = model_checkpoint
+        self.model_weight = model_weight
+        # if model_weight:
+        #     self.usr_policy = UserPolicy(
+        #         self.model_checkpoint, only_action=only_action)
+        #     self.usr_policy.load(model_weight)
+        #     self.usr = self.usr_policy.usr
+        # else:
+        self.usr = UserActionPolicy(
+            model_checkpoint, only_action=only_action, dataset=self.dataset)
+        self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
+
+    def generate_results(self, f_eval, golden=False):
+        in_file = json.load(open(f_eval))
+        r = {
+            "input": [],
+            "golden_acts": [],
+            "golden_utts": [],
+            "gen_acts": [],
+            "gen_utts": []
+        }
+        for dialog in tqdm(in_file['dialog']):
+            inputs = dialog["in"]
+            labels = self.usr._parse_output(dialog["out"])
+            if golden:
+                usr_act = labels["action"]
+                usr_utt = self.usr.generate_text_from_give_semantic(
+                    inputs, usr_act)
+
+            else:
+                output = self.usr._parse_output(
+                    self.usr._generate_action(inputs))
+                usr_act = self.usr._remove_illegal_action(output["action"])
+                usr_utt = output["text"]
+            r["input"].append(inputs)
+            r["golden_acts"].append(labels["action"])
+            r["golden_utts"].append(labels["text"])
+            r["gen_acts"].append(usr_act)
+            r["gen_utts"].append(usr_utt)
+
+        return r
+
+    def read_generated_result(self, f_eval):
+        in_file = json.load(open(f_eval))
+        r = {
+            "input": [],
+            "golden_acts": [],
+            "golden_utts": [],
+            "gen_acts": [],
+            "gen_utts": []
+        }
+        for dialog in tqdm(in_file['dialog']):
+            for x in dialog:
+                r[x].append(dialog[x])
+
+        return r
+
+    def nlg_evaluation(self, input_file=None, generated_file=None, golden=False):
+        if input_file:
+            print("Force generation")
+            gen_r = self.generate_results(input_file, golden)
+
+        elif generated_file:
+            gen_r = self.read_generated_result(generated_file)
+        else:
+            print("You must specify the input_file or the generated_file")
+
+        nlg_eval = {
+            "golden": golden,
+            "metrics": {},
+            "dialog": []
+        }
+        for input, golden_act, golden_utt, gen_act, gen_utt in zip(gen_r["input"], gen_r["golden_acts"], gen_r["golden_utts"], gen_r["gen_acts"], gen_r["gen_utts"]):
+            nlg_eval["dialog"].append({
+                "input": input,
+                "golden_acts": golden_act,
+                "golden_utts": golden_utt,
+                "gen_acts": gen_act,
+                "gen_utts": gen_utt
+            })
+
+        if golden:
+            print("Calculate BLEU")
+            bleu_metric = load_metric("sacrebleu")
+            labels = [[utt] for utt in gen_r["golden_utts"]]
+
+            bleu_score = bleu_metric.compute(predictions=gen_r["gen_utts"],
+                                             references=labels,
+                                             force=True)
+            print("bleu_metric", bleu_score)
+            nlg_eval["metrics"]["bleu"] = bleu_score
+
+        else:
+            print("Calculate SER")
+            missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
+                gen_r["gen_acts"], gen_r["gen_utts"])
+
+            print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
+                "genTUSNLG", missing, total, hallucinate, missing/total))
+            nlg_eval["metrics"]["SER"] = missing/total
+
+        dir_name = self.model_checkpoint
+        json.dump(nlg_eval,
+                  open(os.path.join(dir_name, "nlg_eval.json"), 'w'),
+                  indent=2)
+        return os.path.join(dir_name, "nlg_eval.json")
+
+    def evaluation(self, input_file=None, generated_file=None):
+        force_prediction = True
+        if generated_file:
+            gen_file = json.load(open(generated_file))
+            force_prediction = False
+            if gen_file["golden"]:
+                force_prediction = True
+
+        if force_prediction:
+            in_file = json.load(open(input_file))
+            dialog_result = []
+            gen_acts, golden_acts = [], []
+            # scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
+            for dialog in tqdm(in_file['dialog']):
+                inputs = dialog["in"]
+                labels = self.usr._parse_output(dialog["out"])
+                ans_action = self.usr._remove_illegal_action(labels["action"])
+                preds = self.usr._generate_action(inputs)
+                preds = self.usr._parse_output(preds)
+                usr_action = self.usr._remove_illegal_action(preds["action"])
+
+                gen_acts.append(usr_action)
+                golden_acts.append(ans_action)
+
+                d = {"input": inputs,
+                     "golden_acts": ans_action,
+                     "gen_acts": usr_action}
+                if "text" in preds:
+                    d["golden_utts"] = labels["text"]
+                    d["gen_utts"] = preds["text"]
+                    # print("pred text", preds["text"])
+
+                dialog_result.append(d)
+        else:
+            gen_acts, golden_acts = [], []
+            for dialog in gen_file['dialog']:
+                gen_acts.append(dialog["gen_acts"])
+                golden_acts.append(dialog["golden_acts"])
+            dialog_result = gen_file['dialog']
+
+        scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
+
+        for gen_act, golden_act in zip(gen_acts, golden_acts):
+            s = f1_measure(preds=gen_act, labels=golden_act)
+            for metric in scores:
+                scores[metric].append(s[metric])
+
+        result = {}
+        for metric in scores:
+            result[metric] = sum(scores[metric])/len(scores[metric])
+            print(f"{metric}: {result[metric]}")
+
+        result["dialog"] = dialog_result
+        basename = "semantic_evaluation_result"
+        json.dump(result, open(os.path.join(
+            self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
+        # if self.model_weight:
+        #     json.dump(result, open(os.path.join(
+        #         'results', f"{basename}.json"), 'w'))
+        # else:
+        #     json.dump(result, open(os.path.join(
+        #         self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
+
+
+def f1_measure(preds, labels):
+    tp = 0
+    score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0}
+    for p in preds:
+        if p in labels:
+            tp += 1.0
+    if preds:
+        score["precision"] = tp/len(preds)
+    if labels:
+        score["recall"] = tp/len(labels)
+    if (score["precision"] + score["recall"]) > 0:
+        score["f1"] = 2*(score["precision"]*score["recall"]) / \
+            (score["precision"]+score["recall"])
+    if tp == len(preds) and tp == len(labels):
+        score["turn_acc"] = 1
+    return score
+
+
+def main():
+    args = arg_parser()
+    eval = Evaluator(args.model_checkpoint,
+                     args.dataset,
+                     args.model_weight,
+                     args.only_action)
+    print("model checkpoint", args.model_checkpoint)
+    print("generated_file", args.generated_file)
+    print("input_file", args.input_file)
+    with torch.no_grad():
+        if args.do_semantic:
+            eval.evaluation(args.input_file)
+        if args.do_nlg:
+            nlg_result = eval.nlg_evaluation(input_file=args.input_file,
+                                             generated_file=args.generated_file,
+                                             golden=args.do_golden_nlg)
+            if args.generated_file:
+                generated_file = args.generated_file
+            else:
+                generated_file = nlg_result
+            eval.evaluation(args.input_file,
+                            generated_file)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/convlab/policy/genTUS/ppo/vector.py b/convlab/policy/genTUS/ppo/vector.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c502a46f87582008ff49219f8a14844378b9ed2
--- /dev/null
+++ b/convlab/policy/genTUS/ppo/vector.py
@@ -0,0 +1,148 @@
+import json
+
+import torch
+from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
+from convlab.policy.genTUS.token_map import tokenMap
+from convlab.policy.tus.unify.Goal import Goal
+from transformers import BartTokenizer
+
+
+class stepGenTUSVector:
+    def __init__(self, model_checkpoint, max_in_len=400, max_out_len=80, allow_general_intent=True):
+        self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
+        self.vocab = len(self.tokenizer)
+        self.max_in_len = max_in_len
+        self.max_out_len = max_out_len
+        self.token_map = tokenMap(tokenizer=self.tokenizer)
+        self.token_map.default(only_action=True)
+        self.kg = KnowledgeGraph(self.tokenizer)
+        self.mentioned_domain = []
+        self.allow_general_intent = allow_general_intent
+        self.candidate_num = 5
+        if self.allow_general_intent:
+            print("---> allow_general_intent")
+
+    def init_session(self, goal: Goal):
+        self.goal = goal
+        self.mentioned_domain = []
+
+    def encode(self, raw_inputs, max_length, return_tensors="pt", truncation=True):
+        model_input = self.tokenizer(raw_inputs,
+                                     max_length=max_length,
+                                     return_tensors=return_tensors,
+                                     truncation=truncation,
+                                     padding="max_length")
+        return model_input
+
+    def decode(self, generated_so_far, skip_special_tokens=True):
+        output = self.tokenizer.decode(
+            generated_so_far, skip_special_tokens=skip_special_tokens)
+        return output
+
+    def state_vectorize(self, action, history, turn):
+        self.goal.update_user_goal(action=action)
+        inputs = json.dumps({"system": action,
+                             "goal": self.goal.get_goal_list(),
+                             "history": history,
+                             "turn": str(turn)})
+        inputs = self.encode(inputs, self.max_in_len)
+        s_vec, action_mask = inputs["input_ids"][0], inputs["attention_mask"][0]
+
+        return s_vec, action_mask
+
+    def action_vectorize(self, action, s=None):
+        # action:  [[intent, domain, slot, value], ...]
+        vec = {"vector": torch.tensor([]), "mask": torch.tensor([])}
+        if s is not None:
+            raw_inputs = self.decode(s[0])
+            self.kg.parse_input(raw_inputs)
+
+        self._append(vec, self._get_id("<s>"))
+        self._append(vec, self.token_map.get_id('start_json'))
+        self._append(vec, self.token_map.get_id('start_act'))
+
+        act_len = len(action)
+        for i, (intent, domain, slot, value) in enumerate(action):
+            if value == '?':
+                value = '<?>'
+            c_idx = {x: None for x in ["intent", "domain", "slot", "value"]}
+
+            if s is not None:
+                c_idx["intent"] = self._candidate_id(self.kg.candidate(
+                    "intent", allow_general_intent=self.allow_general_intent))
+                c_idx["domain"] = self._candidate_id(self.kg.candidate(
+                    "domain", intent=intent))
+                c_idx["slot"] = self._candidate_id(self.kg.candidate(
+                    "slot", intent=intent, domain=domain, is_mentioned=self.is_mentioned(domain)))
+                c_idx["value"] = self._candidate_id(self.kg.candidate(
+                    "value", intent=intent, domain=domain, slot=slot))
+
+            self._append(vec, self._get_id(intent), c_idx["intent"])
+            self._append(vec, self.token_map.get_id('sep_token'))
+            self._append(vec, self._get_id(domain), c_idx["domain"])
+            self._append(vec, self.token_map.get_id('sep_token'))
+            self._append(vec, self._get_id(slot), c_idx["slot"])
+            self._append(vec, self.token_map.get_id('sep_token'))
+            self._append(vec, self._get_id(value), c_idx["value"])
+
+            c_idx = [0]*self.candidate_num
+            c_idx[0] = self.token_map.get_id('end_act')[0]
+            c_idx[1] = self.token_map.get_id('sep_act')[0]
+            if i == act_len - 1:
+                x = self.token_map.get_id('end_act')
+            else:
+                x = self.token_map.get_id('sep_act')
+
+            self._append(vec, x, c_idx)
+
+        self._append(vec, self._get_id("</s>"))
+
+        # pad
+        if len(vec["vector"]) < self.max_out_len:
+            pad_len = self.max_out_len-len(vec["vector"])
+            self._append(vec, x=torch.tensor([1]*pad_len))
+        for vec_type in vec:
+            vec[vec_type] = vec[vec_type].to(torch.int64)
+
+        return vec
+
+    def _append(self, vec, x, candidate=None):
+        if type(x) is list:
+            x = torch.tensor(x)
+        mask = self._mask(x, candidate)
+        vec["vector"] = torch.cat((vec["vector"], x), dim=-1)
+        vec["mask"] = torch.cat((vec["mask"], mask), dim=0)
+
+    def _mask(self, idx, c_idx=None):
+        mask = torch.zeros(len(idx), self.candidate_num)
+        mask[:, 0] = idx
+        if c_idx is not None and len(c_idx) > 1:
+            mask[0, :] = torch.tensor(c_idx)
+
+        return mask
+
+    def _candidate_id(self, candidate):
+        """Return the first token id of each candidate, zero-padded to `candidate_num` entries."""
+        if len(candidate) > self.candidate_num:
+            print(f"too many candidates. Max = {self.candidate_num}")
+        kept = [self._get_id(c)[0] for c in candidate[:self.candidate_num]]
+        # right-pad with 0 so the result always has candidate_num entries
+        return kept + [0] * (self.candidate_num - len(kept))
+
+    def _get_id(self, value):
+        token_id = self.tokenizer(value, add_special_tokens=False)
+        return token_id["input_ids"]
+
+    def action_devectorize(self, action_id):
+        return self.decode(action_id)
+
+    def update_mentioned_domain(self, semantic_act):
+        """Record the domain (index 1) of every act, skipping domains already recorded."""
+        for domain in (act[1] for act in semantic_act):
+            if domain not in self.mentioned_domain:
+                self.mentioned_domain.append(domain)
+
+    def is_mentioned(self, domain):
+        """Return True if `domain` has been recorded in `self.mentioned_domain`."""
+        # the membership test already yields the bool the old if/return spelled out
+        return domain in self.mentioned_domain
diff --git a/convlab/policy/genTUS/stepGenTUS.py b/convlab/policy/genTUS/stepGenTUS.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b5af9f3315e9db1cc98e4618e1800155d97e670
--- /dev/null
+++ b/convlab/policy/genTUS/stepGenTUS.py
@@ -0,0 +1,655 @@
+import json
+import os
+
+import torch
+from transformers import BartTokenizer
+
+from convlab.policy.genTUS.ppo.vector import stepGenTUSVector
+from convlab.policy.genTUS.stepGenTUSmodel import stepGenTUSmodel
+from convlab.policy.genTUS.token_map import tokenMap
+from convlab.policy.genTUS.unify.Goal import Goal
+from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
+from convlab.policy.policy import Policy
+from convlab.task.multiwoz.goal_generator import GoalGenerator
+from convlab.util.custom_util import model_downloader
+
+
+DEBUG = False
+
+
+class UserActionPolicy(Policy):
+    def __init__(self, model_checkpoint, mode="semantic", only_action=True, max_turn=40, **kwargs):
+        self.mode = mode
+        # if mode == "semantic" and only_action:
+        #     # only generate semantic action in prediction
+        print("model_checkpoint", model_checkpoint)
+        self.only_action = only_action
+        if self.only_action:
+            print("change mode to semantic because only_action=True")
+            self.mode = "semantic"
+        self.max_in_len = 500
+        self.max_out_len = 100 if only_action else 200
+        max_act_len = kwargs.get("max_act_len", 2)
+        print("max_act_len", max_act_len)
+        self.max_action_len = max_act_len
+        if "max_act_len" in kwargs:
+            self.max_out_len = 30 * self.max_action_len
+            print("max_act_len", self.max_out_len)
+        self.max_turn = max_turn
+        if mode not in ["semantic", "language"]:
+            print("Unknown user mode")
+
+        self.reward = {"success":  self.max_turn*2,
+                       "fail": self.max_turn*-1}
+        self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        train_whole_model = kwargs.get("whole_model", True)
+        self.model = stepGenTUSmodel(
+            model_checkpoint, train_whole_model=train_whole_model)
+        self.model.eval()
+        self.model.to(self.device)
+        self.model.share_memory()
+
+        self.turn_level_reward = kwargs.get("turn_level_reward", True)
+        self.cooperative = kwargs.get("cooperative", True)
+
+        dataset = kwargs.get("dataset", "")
+        self.kg = KnowledgeGraph(
+            tokenizer=self.tokenizer,
+            dataset=dataset)
+
+        self.goal_gen = GoalGenerator()
+
+        self.vector = stepGenTUSVector(
+            model_checkpoint, self.max_in_len, self.max_out_len)
+        self.norm_reward = False
+
+        self.action_penalty = kwargs.get("action_penalty", False)
+        self.usr_act_penalize = kwargs.get("usr_act_penalize", 0)
+        self.goal_list_type = kwargs.get("goal_list_type", "normal")
+        self.update_mode = kwargs.get("update_mode", "normal")
+        self.max_history = kwargs.get("max_history", 3)
+        self.init_session()
+
+    def _update_seq(self, sub_seq: list, pos: int):
+        """Write `sub_seq` into row 0 of `self.seq` starting at `pos`; return the next position."""
+        # enumerate supplies the running offset the old loop tracked by hand
+        for offset, token in enumerate(sub_seq):
+            self.seq[0, pos + offset] = token
+        return pos + len(sub_seq)
+
+    def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True):
+        # TODO no duplicate
+        self.kg.parse_input(raw_inputs)
+        model_input = self.vector.encode(raw_inputs, self.max_in_len)
+        # start token
+        self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
+        pos = self._update_seq([0], 0)
+        pos = self._update_seq(self.token_map.get_id('start_json'), pos)
+        pos = self._update_seq(self.token_map.get_id('start_act'), pos)
+
+        # get semantic actions
+        for act_len in range(self.max_action_len):
+            pos = self._get_semantic_action(
+                model_input, pos, mode, allow_general_intent)
+
+            terminate, token_name = self._stop_semantic(
+                model_input, pos, act_len)
+            pos = self._update_seq(self.token_map.get_id(token_name), pos)
+
+            if terminate:
+                break
+
+        if self.only_action:
+            # return semantic action. Don't need to generate text
+            return self.vector.decode(self.seq[0, :pos])
+
+        # TODO remove illegal action here?
+
+        # get text output
+        pos = self._update_seq(self.token_map.get_id("start_text"), pos)
+
+        text = self._get_text(model_input, pos)
+
+        return text
+
+    def generate_text_from_give_semantic(self, raw_inputs, semantic_action):
+        self.kg.parse_input(raw_inputs)
+        model_input = self.vector.encode(raw_inputs, self.max_in_len)
+        self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
+        pos = self._update_seq([0], 0)
+        pos = self._update_seq(self.token_map.get_id('start_json'), pos)
+        pos = self._update_seq(self.token_map.get_id('start_act'), pos)
+
+        if len(semantic_action) == 0:
+            pos = self._update_seq(self.token_map.get_id("end_act"), pos)
+
+        for act_id, (intent, domain, slot, value) in enumerate(semantic_action):
+            pos = self._update_seq(self.kg._get_token_id(intent), pos)
+            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+            pos = self._update_seq(self.kg._get_token_id(domain), pos)
+            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+            pos = self._update_seq(self.kg._get_token_id(slot), pos)
+            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+            pos = self._update_seq(self.kg._get_token_id(value), pos)
+
+            if act_id == len(semantic_action) - 1:
+                token_name = "end_act"
+            else:
+                token_name = "sep_act"
+            pos = self._update_seq(self.token_map.get_id(token_name), pos)
+        pos = self._update_seq(self.token_map.get_id("start_text"), pos)
+
+        raw_output = self._get_text(model_input, pos)
+        return self._parse_output(raw_output)["text"]
+
+    def _get_text(self, model_input, pos):
+        s_pos = pos
+        for i in range(s_pos, self.max_out_len):
+            next_token_logits = self.model.get_next_token_logits(
+                model_input, self.seq[:1, :pos])
+            next_token = torch.argmax(next_token_logits, dim=-1)
+
+            if self._stop_text(next_token):
+                # text = self.vector.decode(self.seq[0, s_pos:pos])
+                # text = self._norm_str(text)
+                # return self.vector.decode(self.seq[0, :s_pos]) + text + '"}'
+                break
+
+            pos = self._update_seq([next_token], pos)
+        text = self.vector.decode(self.seq[0, s_pos:pos])
+        text = self._norm_str(text)
+        return self.vector.decode(self.seq[0, :s_pos]) + text + '"}'
+        # TODO return None
+
+    def _stop_text(self, next_token):
+        """Return True when `next_token` equals the `end_json` or `end_json_2` token id."""
+        # ids are looked up lazily and in the same order as before
+        for name in ("end_json", "end_json_2"):
+            if next_token == self.token_map.get_id(name)[0]:
+                return True
+        return False
+
+    @staticmethod
+    def _norm_str(text: str):
+        """Strip outer double quotes, turn inner ones into single quotes, drop backslashes."""
+        # outer quotes are stripped first, so only inner ones are rewritten
+        unquoted = text.strip('"').replace('"', "'")
+        return unquoted.replace('\\', "")
+
+    def _stop_semantic(self, model_input, pos, act_length=0):
+        """Decide whether the action list ends after the act just generated.
+
+        Compares the next-token logits of the `end_act` and `sep_act` tokens
+        and stops when `end_act` scores higher, or unconditionally once
+        `act_length` reaches `max_action_len - 1`.
+
+        Returns:
+            (terminate, token_name) where token_name is "end_act" or "sep_act".
+        """
+        logits = self.model.get_next_token_logits(
+            model_input, self.seq[:1, :pos])
+        # same lookup order as before: sep_act first, then end_act
+        sep_score, end_score = (
+            logits[:, self.token_map.get_id(name)[0]].item()
+            for name in ('sep_act', 'end_act'))
+        terminate = end_score > sep_score
+        # the act limit overrides the model's preference
+        if act_length >= self.max_action_len - 1:
+            terminate = True
+        token_name = "end_act" if terminate else "sep_act"
+        return terminate, token_name
+
+    def _get_semantic_action(self, model_input, pos, mode="max", allow_general_intent=True):
+
+        intent = self._get_intent(
+            model_input, self.seq[:1, :pos], mode, allow_general_intent)
+        pos = self._update_seq(intent["token_id"], pos)
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+
+        # get domain
+        domain = self._get_domain(
+            model_input, self.seq[:1, :pos], intent["token_name"], mode)
+        pos = self._update_seq(domain["token_id"], pos)
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+
+        # get slot
+        slot = self._get_slot(
+            model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode)
+        pos = self._update_seq(slot["token_id"], pos)
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+
+        # get value
+
+        value = self._get_value(
+            model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], slot["token_name"], mode)
+        pos = self._update_seq(value["token_id"], pos)
+
+        return pos
+
+    def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True):
+        next_token_logits = self.model.get_next_token_logits(
+            model_input, generated_so_far)
+
+        return self.kg.get_intent(next_token_logits, mode, allow_general_intent)
+
+    def _get_domain(self, model_input, generated_so_far, intent, mode="max"):
+        next_token_logits = self.model.get_next_token_logits(
+            model_input, generated_so_far)
+
+        return self.kg.get_domain(next_token_logits, intent, mode)
+
+    def _get_slot(self, model_input, generated_so_far, intent, domain, mode="max"):
+        next_token_logits = self.model.get_next_token_logits(
+            model_input, generated_so_far)
+        is_mentioned = self.vector.is_mentioned(domain)
+        return self.kg.get_slot(next_token_logits, intent, domain, mode, is_mentioned)
+
+    def _get_value(self, model_input, generated_so_far, intent, domain, slot, mode="max"):
+        next_token_logits = self.model.get_next_token_logits(
+            model_input, generated_so_far)
+
+        return self.kg.get_value(next_token_logits, intent, domain, slot, mode)
+
+    def _remove_illegal_action(self, action):
+        """Drop malformed and duplicate acts and map a `<?>` value back to `?`."""
+        new_action = []
+        for act in action:
+            if len(act) != 4:
+                print("illegal action:", action)
+                continue
+            if "<?>" in act[-1]:
+                act = [*act[:3], "?"]
+            if act not in new_action:
+                new_action.append(act)
+        return new_action
+
+    def _parse_output(self, in_str):
+        """Parse decoded model output as JSON; return an empty action dict if it is invalid."""
+        in_str = str(in_str).replace('<s>', '').replace(
+            '</s>', '').replace('o"clock', "o'clock")
+        action = {"action": [], "text": ""}
+        try:
+            action = json.loads(in_str)
+        except ValueError:  # json.JSONDecodeError is a ValueError subclass
+            print("invalid action:", in_str)
+            print("-"*20)
+        return action
+
+    def predict(self, sys_act, mode="max", allow_general_intent=True):
+        # raw_sys_act = sys_act
+        # sys_act = sys_act[:5]
+        # update goal
+        # TODO
+        allow_general_intent = False
+        self.model.eval()
+
+        if not self.add_sys_from_reward:
+            self.goal.update_user_goal(action=sys_act, char="sys")
+            self.sys_acts.append(sys_act)  # for terminate conversation
+
+        # update constraint
+        self.time_step += 2
+
+        history = []
+        if self.usr_acts:
+            if self.max_history == 1:
+                history = self.usr_acts[-1]
+            else:
+                history = self.usr_acts[-1*self.max_history:]
+        inputs = json.dumps({"system": sys_act,
+                             "goal": self.goal.get_goal_list(),
+                             "history": history,
+                             "turn": str(int(self.time_step/2))})
+        with torch.no_grad():
+            raw_output = self._generate_action(
+                raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
+        output = self._parse_output(raw_output)
+        self.semantic_action = self._remove_illegal_action(output["action"])
+        if not self.only_action:
+            self.utterance = output["text"]
+
+        # TODO
+        if self.is_finish():
+            self.semantic_action, self.utterance = self._good_bye()
+
+        # if self.is_finish():
+        #     print("terminated")
+
+        # if self.is_finish():
+        #     good_bye = self._good_bye()
+        #     self.goal.add_usr_da(good_bye)
+        #     return good_bye
+
+        self.goal.update_user_goal(action=self.semantic_action, char="usr")
+        self.vector.update_mentioned_domain(self.semantic_action)
+        self.usr_acts.append(self.semantic_action)
+
+        # if self._usr_terminate(usr_action):
+        #     print("terminated by user")
+        #     self.terminated = True
+
+        del inputs
+
+        if self.mode == "language":
+            # print("in", sys_act)
+            # print("out", self.utterance)
+            return self.utterance
+        else:
+            return self.semantic_action
+
+    def init_session(self, goal=None):
+        self.token_map = tokenMap(tokenizer=self.tokenizer)
+        self.token_map.default(only_action=self.only_action)
+        self.time_step = 0
+        remove_domain = "police"  # remove police domain in inference
+
+        if not goal:
+            self._new_goal(remove_domain=remove_domain)
+        else:
+            self._read_goal(goal)
+
+        self.vector.init_session(goal=self.goal)
+
+        self.terminated = False
+        self.add_sys_from_reward = False
+        self.sys_acts = []
+        self.usr_acts = []
+        self.semantic_action = []
+        self.utterance = ""
+
+    def _read_goal(self, data_goal):
+        self.goal = Goal(goal=data_goal)
+
+    def _new_goal(self, remove_domain="police", domain_len=None):
+        self.goal = Goal(goal_generator=self.goal_gen)
+        # keep_generate_goal = True
+        # # domain_len = 1
+        # while keep_generate_goal:
+        #     self.goal = Goal(goal_generator=self.goal_gen,
+        #                      goal_list_type=self.goal_list_type,
+        #                      update_mode=self.update_mode)
+        #     if (domain_len and len(self.goal.domains) != domain_len) or \
+        #             (remove_domain and remove_domain in self.goal.domains):
+        #         keep_generate_goal = True
+        #     else:
+        #         keep_generate_goal = False
+
+    def load(self, model_path):
+        self.model.load_state_dict(torch.load(
+            model_path, map_location=self.device))
+        # self.model = BartForConditionalGeneration.from_pretrained(
+        #     model_checkpoint)
+
+    def get_goal(self):
+        """Return the raw goal if available, else rebuild it from `goal.domain_goals`.
+
+        Intents map to goal sections: inform -> info, request -> reqt,
+        book -> book. Unknown intents are reported and skipped.
+        """
+        if self.goal.raw_goal is not None:
+            return self.goal.raw_goal
+        slot_types = {"inform": "info", "request": "reqt", "book": "book"}
+        goal = {}
+        for domain in self.goal.domain_goals:
+            goal.setdefault(domain, {})
+            for intent in self.goal.domain_goals[domain]:
+                if intent not in slot_types:
+                    # previously fell through with `slot_type` unbound or
+                    # left over from the preceding intent
+                    print("unknown slot type")
+                    continue
+                section = goal[domain].setdefault(slot_types[intent], {})
+                section.update(self.goal.domain_goals[domain][intent])
+        return goal
+
+    def get_reward(self, sys_response=None):
+        """Compute the reward for the current turn and set `self.success`.
+
+        When `sys_response` is given it is first folded into the goal and the
+        system-act history. A finished dialogue yields the configured
+        success/fail reward; otherwise the reward is -1 plus the optional
+        turn-level term, and `self.success` is None.
+        """
+        self.add_sys_from_reward = sys_response is not None
+        if self.add_sys_from_reward:
+            self.goal.update_user_goal(action=sys_response, char="sys")
+            self.goal.add_sys_da(sys_response)  # for evaluation
+            self.sys_acts.append(sys_response)  # for terminate conversation
+
+        if not self.is_finish():
+            # dialogue still running: per-turn cost plus optional shaping
+            reward = -1
+            if self.turn_level_reward:
+                reward += self.turn_reward()
+            self.success = None
+        else:
+            succeeded = self.is_success()
+            reward = self.reward["success" if succeeded else "fail"]
+            self.success = succeeded
+
+        if self.norm_reward:
+            reward = (reward - 20)/60
+        return reward
+
+    def _system_action_penalty(self):
+        """Penalise the latest system turn by -1 for every act beyond three.
+
+        Returns 0 when there is no system turn yet or it has at most three
+        acts.
+        """
+        # TODO only penalize the slots not in user goal
+        free_action_len = 3
+        if len(self.sys_acts) < 1:
+            return 0
+        surplus = len(self.sys_acts[-1]) - free_action_len
+        # max() clamps turns within the free allowance to no penalty
+        return -1 * max(surplus, 0)
+
+    def turn_reward(self):
+        """Return the sum of the new-act, reply and user-act-length terms."""
+        # evaluated in the original order: new-act, reply, then length penalty
+        new_act = self._new_act_reward()
+        reply = self._reply_reward()
+        return new_act + reply + self._usr_act_len()
+
+    def _usr_act_len(self):
+        """Return (2 - n) * usr_act_penalize for a last user turn of n > 2 acts, else 0."""
+        extra = len(self.usr_acts[-1]) - 2
+        if extra <= 0:
+            return 0
+        return -extra * self.usr_act_penalize
+
+    def _new_act_reward(self):
+        """Reward acts not seen in the previous user turn and penalise repeats.
+
+        General-domain acts are neutral; every other act of the last turn is
+        worth -1 if it also appeared in the turn before and +1 otherwise.
+        """
+        last_act = self.usr_acts[-1]
+        if last_act != self.semantic_action:
+            print(f"---> why? last {last_act} usr {self.semantic_action}")
+        if len(self.usr_acts) < 2:
+            # nothing to compare against on the first user turn
+            return 0
+        previous = self.usr_acts[-2]
+        # general-domain acts contribute nothing to the sum
+        scored = (act for act in last_act if act[1].lower() != "general")
+        return sum(-1 if act in previous else 1 for act in scored)
+
+    def _reply_reward(self):
+        if self.cooperative:
+            return self._cooperative_reply_reward()
+        else:
+            return self._non_cooperative_reply_reward()
+
+    def _non_cooperative_reply_reward(self):
+        """Score answers to the system's latest requests with a reply budget.
+
+        Among requested (domain, slot) pairs the user informed this turn, the
+        first `max_len` earn +1, the next one 0 and any further ones -5.
+        Unanswered requests are not scored.
+        """
+        max_len = 1
+        reqts = [[act[1], act[2]]
+                 for act in self.sys_acts[-1] if act[0] == "request"]
+        infos = [[act[1], act[2]]
+                 for act in self.usr_acts[-1] if act[0] == "inform"]
+        reward = 0
+        reply_len = 0
+        for req in reqts:
+            if req not in infos:
+                continue
+            if reply_len < max_len:
+                reward += 1
+            elif reply_len > max_len:
+                reward -= 5
+            # previously never incremented, leaving the 0 and -5 tiers dead
+            reply_len += 1
+        return reward
+
+    def _cooperative_reply_reward(self):
+        """Score how fully the last user turn answers the last system requests.
+
+        Every (domain, slot) pair the system requested is worth +1 when the
+        user informed that same pair in the latest turn and -1 otherwise.
+
+        Returns:
+            The summed score, or 0 when the system requested nothing.
+        """
+        requested = [[act[1], act[2]]
+                     for act in self.sys_acts[-1] if act[0] == "request"]
+        informed = [[act[1], act[2]]
+                    for act in self.usr_acts[-1] if act[0] == "inform"]
+        # starting from 0 covers the no-request case without a special branch
+        score = 0
+        for req in requested:
+            score += 1 if req in informed else -1
+        return score
+
+    def _usr_terminate(self):
+        """Return True if the current user action contains a `thank` or `bye` intent."""
+        closing_intents = ['thank', 'bye']
+        # any() short-circuits on the first closing act, like the old loop
+        return any(act[0] in closing_intents for act in self.semantic_action)
+
+    def is_finish(self):
+        """Return whether the dialogue is over and mirror the result in `self.terminated`.
+
+        The rule-based check runs first; the user's own thank/bye act is only
+        consulted when the rules do not already end the conversation.
+        """
+        # stop by model generation?
+        finished = bool(self._finish_conversation_rule() or self._usr_terminate())
+        self.terminated = finished
+        return finished
+
+    def is_success(self):
+        """Return True when `self.goal.task_complete()` is truthy."""
+        # should mentioned all slots
+        # (a further check on goal.all_mentioned()["complete"] > 0.6 was left
+        # commented out by the authors and is not applied)
+        task_complete = self.goal.task_complete()
+        return bool(task_complete)
+
+    def _good_bye(self):
+        """Return the closing (semantic_action, utterance) pair for this dialogue.
+
+        A successful dialogue ends with a `thank` act, otherwise with `bye`.
+        Both parts are always returned; `predict` later hands back the one
+        matching `self.mode`.
+        """
+        # NOTE: the placeholders differ in case ("none" vs "None"); kept as-is
+        # because consumers outside this block may rely on the exact strings.
+        if self.is_success():
+            return [['thank', 'general', 'none', 'none']], "thank you. bye"
+        return [["bye", "general", "None", "None"]], "bye"
+
+    def _finish_conversation_rule(self):
+        """Return True when a rule ends the dialogue regardless of the user's act.
+
+        That is: the task succeeded, `time_step` exceeds `max_turn`, or more
+        than four system acts are recorded and the last three are identical.
+        """
+        if self.is_success() or self.time_step > self.max_turn:
+            return True
+        return len(self.sys_acts) > 4 and \
+            self.sys_acts[-1] == self.sys_acts[-2] == self.sys_acts[-3]
+
+    def is_terminated(self):
+        # Is there any action to say?
+        self.is_finish()
+        return self.terminated
+
+
+class UserPolicy(Policy):
+    """Wrap a `UserActionPolicy` and load its weights from `model_checkpoint`."""
+
+    def __init__(self,
+                 model_checkpoint,
+                 mode="semantic",
+                 only_action=True,
+                 sample=False,
+                 action_penalty=False,
+                 **kwargs):
+        # makedirs (unlike mkdir) also creates missing intermediate directories
+        model_dir = os.path.dirname(model_checkpoint)
+        if not os.path.exists(model_dir):
+            os.makedirs(model_dir, exist_ok=True)
+            model_downloader(model_dir,
+                             "https://zenodo.org/record/7372442/files/multiwoz21-exp.zip")
+
+        self.policy = UserActionPolicy(
+            model_checkpoint,
+            mode=mode,
+            only_action=only_action,
+            action_penalty=action_penalty,
+            **kwargs)
+        self.policy.load(os.path.join(
+            model_checkpoint, "pytorch_model.bin"))
+        self.sample = sample
+
+    def predict(self, sys_act, mode="max"):
+        """Predict the next user act; `self.sample` (not `mode`) picks sample vs max."""
+        mode = "sample" if self.sample else "max"
+        return self.policy.predict(sys_act, mode)
+
+    def init_session(self, goal=None):
+        self.policy.init_session(goal)
+
+    def is_terminated(self):
+        return self.policy.is_terminated()
+
+    def get_reward(self, sys_response=None):
+        return self.policy.get_reward(sys_response)
+
+    def get_goal(self):
+        if hasattr(self.policy, 'get_goal'):
+            return self.policy.get_goal()
+        return None
+
+
+if __name__ == "__main__":
+    import os
+
+    from convlab.dialog_agent import PipelineAgent
+    # from convlab.nlu.jointBERT.multiwoz import BERTNLU
+    from convlab.util.custom_util import set_seed
+
+    set_seed(20220220)
+    # Test semantic level behaviour
+    model_checkpoint = 'convlab/policy/genTUS/unify/experiments/multiwoz21-exp'
+    usr_policy = UserPolicy(
+        model_checkpoint,
+        mode="semantic")
+    # usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
+    usr_nlu = None  # BERTNLU()
+    usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
+    print(usr.policy.get_goal())
+
+    print(usr.response([]))
+    print(usr.policy.policy.goal.status)
+    print(usr.response([["request", "attraction", "area", "?"]]))
+    print(usr.policy.policy.goal.status)
+    print(usr.response([["request", "attraction", "area", "?"]]))
+    print(usr.policy.policy.goal.status)
diff --git a/convlab/policy/genTUS/stepGenTUSmodel.py b/convlab/policy/genTUS/stepGenTUSmodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2eaf7bc5064262808120aac4a9cbe2eb007d863
--- /dev/null
+++ b/convlab/policy/genTUS/stepGenTUSmodel.py
@@ -0,0 +1,114 @@
+
+import json
+
+import torch
+from torch.nn.functional import softmax, one_hot, cross_entropy
+
+from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
+from convlab.policy.genTUS.token_map import tokenMap
+from convlab.policy.genTUS.utils import append_tokens
+from transformers import (BartConfig, BartForConditionalGeneration,
+                          BartTokenizer)
+
+
+class stepGenTUSmodel(BartForConditionalGeneration):
+    """BART conditional-generation model with log-probability helpers for GenTUS."""
+
+    def __init__(self, model_checkpoint, train_whole_model=True, **kwargs):
+        config = BartConfig.from_pretrained(model_checkpoint)
+        super().__init__(config, **kwargs)
+        self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
+        self.vocab = len(self.tokenizer)
+        self.kg = KnowledgeGraph(self.tokenizer)
+        self.action_kg = KnowledgeGraph(self.tokenizer)
+        self.token_map = tokenMap(self.tokenizer)
+        # only_action doesn't matter. it is only used for get_log_prob
+        self.token_map.default(only_action=True)
+        if not train_whole_model:
+            # freeze everything except fc1/fc2 of the last decoder layer
+            for param in self.parameters():
+                param.requires_grad = False
+            for param in self.model.decoder.layers[-1].fc1.parameters():
+                param.requires_grad = True
+            for param in self.model.decoder.layers[-1].fc2.parameters():
+                param.requires_grad = True
+
+    def get_trainable_param(self):
+        """Return an iterator over the parameters with ``requires_grad`` set."""
+        return filter(lambda p: p.requires_grad, self.parameters())
+
+    def get_next_token_logits(self, model_input, generated_so_far):
+        """Return the logits at the last decoder position for ``model_input``."""
+        outputs = self.forward(
+            input_ids=model_input["input_ids"].to(self.device),
+            attention_mask=model_input["attention_mask"].to(self.device),
+            decoder_input_ids=generated_so_far,
+            return_dict=True)
+        return outputs.logits[:, -1, :]
+
+    def get_log_prob(self, s, a, action_mask, prob_mask):
+        """Score decoder ids ``a`` given input ``s``: logits at t are matched with a[:, t+1]."""
+        output = self.forward(input_ids=s,
+                              attention_mask=action_mask,
+                              decoder_input_ids=a)
+        return self._norm_prob(a[:, 1:].long(),
+                               output.logits[:, :-1, :],
+                               prob_mask[:, 1:, :].long())
+
+    def _norm_prob(self, a, prob, mask):
+        """Sum, over targets whose id is not 1, log(p(target) / ``_base``) per batch row."""
+        prob = softmax(prob, -1)
+        base = self._base(prob, mask).to(self.device)  # [b, seq_len]
+        # gather reads each target's probability without building a one-hot tensor
+        prob = prob.gather(-1, a.unsqueeze(-1)).squeeze(-1)
+        prob = torch.log(prob / base)
+        # id 1 is treated as padding; masked_fill (unlike multiplying by 0) also drops inf/nan
+        return prob.masked_fill(a == 1, 0.0).sum(-1)
+
+    @staticmethod
+    def _base(prob, mask):
+        """Return, per position, the summed probability of the positive ids listed in ``mask``."""
+        batch_size, seq_len, dim = prob.shape
+        base = torch.zeros(batch_size, seq_len)
+        for b in range(batch_size):
+            for s in range(seq_len):
+                temp = [prob[b, s, c] for c in mask[b, s, :] if c > 0]
+                base[b, s] = torch.sum(torch.tensor(temp))
+        return base
+
+
+if __name__ == "__main__":
+    import os
+    from convlab.util.custom_util import set_seed
+    from convlab.policy.genTUS.stepGenTUS import UserActionPolicy
+    set_seed(0)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    model_checkpoint = 'results/genTUS-22-01-31-09-21/'
+    usr = UserActionPolicy(model_checkpoint=model_checkpoint)
+    usr.model.load_state_dict(torch.load(
+        os.path.join(model_checkpoint, "pytorch_model.bin"), map_location=device))
+    usr.model.eval()
+
+    test_file = "convlab/policy/genTUS/data/goal_status_validation_v1.json"
+    # read through a context manager so the file handle is closed promptly
+    with open(test_file) as f:
+        data = json.load(f)
+    test_id = 20
+    inputs = usr.tokenizer(data["dialog"][test_id]["in"],
+                           max_length=400,
+                           return_tensors="pt",
+                           truncation=True)
+    actions = [data["dialog"][test_id]["out"],
+               data["dialog"][test_id+100]["out"]]
+
+    for action in actions:
+        action = json.loads(action)
+        vec = usr.vector.action_vectorize(
+            action["action"], s=inputs["input_ids"])
+        print({"action": action["action"]})
+        print("get_log_prob", usr.model.get_log_prob(
+            inputs["input_ids"],
+            torch.unsqueeze(vec["vector"], 0),
+            inputs["attention_mask"],
+            torch.unsqueeze(vec["mask"], 0)))
diff --git a/convlab/policy/genTUS/token_map.py b/convlab/policy/genTUS/token_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..7825c2880928c40f68284b0c3199932cd1cfc477
--- /dev/null
+++ b/convlab/policy/genTUS/token_map.py
@@ -0,0 +1,64 @@
+import json
+
+
+class tokenMap:
+    """Registry of named strings and their token ids, addressable by name
+    or by the first token id of their tokenization."""
+
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.token_name = {}
+        self.hash_map = {}
+        self.debug = False
+        self.default()
+
+    def default(self, only_action=False):
+        """(Re-)register the format tokens; ``only_action`` swaps in a closing end_act."""
+        self.format_tokens = {
+            'start_json': '{"action": [',   # 49643, 10845, 7862, 646
+            'start_act': '["',              # 49329
+            'sep_token': '", "',            # 1297('",'), 22
+            'sep_act': '"], ["',               # 49177
+            'end_act': '"]], "',            # 42248, 7479, 22
+            'start_text': 'text": "',       # 29015, 7862, 22
+            'end_json': '}',                 # 24303
+            'end_json_2': '"}'                 # 48805
+        }
+        if only_action:
+            self.format_tokens['end_act'] = '"]]}'
+        for name, text in self.format_tokens.items():
+            self.add_token(name, text)
+
+    def add_token(self, token_name, value):
+        """Tokenize ``value`` and index it by ``token_name`` and by its first token id."""
+        if token_name in self.token_name and self.debug:
+            print(f"---> duplicate token: {token_name}({value})!!!!!!!")
+        token_id = self.tokenizer(
+            str(value), add_special_tokens=False)["input_ids"]
+        self.token_name[token_name] = {"value": value, "token_id": token_id}
+        hash_id = token_id[0]
+        if hash_id in self.hash_map and self.debug:
+            print(
+                f"---> conflict hash number {hash_id}: {self.hash_map[hash_id]['name']} and {token_name}")
+        self.hash_map[hash_id] = {
+            "name": token_name, "value": value, "token_id": token_id}
+
+    def get_info(self, hash_id):
+        """Return the entry registered under first-token id ``hash_id``."""
+        return self.hash_map[hash_id]
+
+    def get_id(self, token_name):
+        """Return the token-id list stored for ``token_name``."""
+        return self.token_name[token_name]["token_id"]
+
+    def get_token_value(self, token_name):
+        """Return the original value stored for ``token_name``."""
+        return self.token_name[token_name]["value"]
+
+    def token_name_is_in(self, token_name):
+        """Return True if ``token_name`` has been registered."""
+        return token_name in self.token_name
+
+    def hash_id_is_in(self, hash_id):
+        """Return True if some registered entry starts with token id ``hash_id``."""
+        return hash_id in self.hash_map
diff --git a/convlab/policy/genTUS/train_model.py b/convlab/policy/genTUS/train_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2162417461d692514e3b27742dbfd477491fc24e
--- /dev/null
+++ b/convlab/policy/genTUS/train_model.py
@@ -0,0 +1,258 @@
+import json
+import os
+import sys
+from argparse import ArgumentParser
+from datetime import datetime
+from pprint import pprint
+import numpy as np
+import torch
+import transformers
+from datasets import Dataset, load_metric
+from tqdm import tqdm
+from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
+                          BartForConditionalGeneration, BartTokenizer,
+                          DataCollatorForSeq2Seq, Seq2SeqTrainer,
+                          Seq2SeqTrainingArguments)
+
+sys.path.append(os.path.dirname(os.path.dirname(
+    os.path.dirname(os.path.abspath(__file__)))))  # three dirname() steps up from this file
+
+os.environ["WANDB_DISABLED"] = "true"  # NOTE(review): presumably disables Weights & Biases logging — confirm
+
+METRIC = load_metric("sacrebleu")  # used for the "bleu" score in gentus_compute_metrics
+TOKENIZER = BartTokenizer.from_pretrained("facebook/bart-base")
+TOKENIZER.add_tokens(["<?>"])  # extra token; train() resizes the model embeddings to match
+MAX_IN_LEN = 500  # passed to train() as max_input_length by main()
+MAX_OUT_LEN = 500  # passed to train() as max_target_length; also given to batch_decode
+
+
+def arg_parser():
+    """Build the training CLI parser and return the parsed ``sys.argv`` options."""
+    cli = ArgumentParser()
+    cli.add_argument("--model-type", default="unify", type=str,
+                     help="unify or multiwoz")
+    cli.add_argument("--data-name", default="multiwoz21", type=str,
+                     help="multiwoz21, sgd, tm1, tm2, tm3, sgd+tm, or all")
+    for flag, kind, default in (("--dial-ids-order", int, 0),
+                                ("--split2ratio", float, 1),
+                                ("--batch-size", int, 16),
+                                ("--model-checkpoint", str, "facebook/bart-base")):
+        cli.add_argument(flag, type=kind, default=default)
+    return cli.parse_args()
+
+
+def gentus_compute_metrics(eval_preds):
+    """Compute BLEU on utterance text and precision/recall/F1 on dialogue acts.
+
+    ``eval_preds`` unpacks into (predictions, labels); every returned score
+    is rounded to four decimals.
+    """
+    preds, labels = eval_preds
+    preds = preds[0] if isinstance(preds, tuple) else preds
+
+    # Replace -100 in the labels as we can't decode them.
+    labels = np.where(labels != -100, labels, TOKENIZER.pad_token_id)
+    decode_kwargs = {"skip_special_tokens": True, "max_length": MAX_OUT_LEN}
+    decoded_preds = TOKENIZER.batch_decode(preds, **decode_kwargs)
+    decoded_labels = TOKENIZER.batch_decode(labels, **decode_kwargs)
+
+    act, text = postprocess_text(decoded_preds, decoded_labels)
+    bleu = METRIC.compute(
+        predictions=text["preds"], references=text["labels"])
+
+    scores = {"bleu": bleu["score"]}
+    scores.update(
+        f1_measure(pred_acts=act["preds"], label_acts=act["labels"]))
+
+    return {name: round(value, 4) for name, value in scores.items()}
+
+
+def postprocess_text(preds, labels):
+    """Parse decoded strings into parallel act/text lists, leaving out pairs
+    whose parsed label has an empty "text"."""
+    act = {"preds": [], "labels": []}
+    text = {"preds": [], "labels": []}
+    for raw_pred, raw_label in zip(preds, labels):
+        raw_pred = raw_pred.strip()
+        predicted = parse_output(raw_pred)
+        gold = parse_output(raw_label.strip())
+        if len(gold["text"]) >= 1:
+            act["preds"].append(predicted.get("action", []))
+            text["preds"].append(predicted.get("text", raw_pred))
+            act["labels"].append(gold["action"])
+            text["labels"].append([gold["text"]])
+    return act, text
+
+
+def parse_output(in_str):
+    """Parse generated JSON into a dict; return an empty action/text dict on failure."""
+    in_str = in_str.replace('<s>', '').replace('</s>', '').replace('<\\s>', '')
+    try:
+        output = json.loads(in_str)
+    except ValueError:  # json.JSONDecodeError is a ValueError
+        output = None
+    return output if isinstance(output, dict) else {"action": [], "text": ""}
+
+
+def f1_measure(pred_acts, label_acts):
+    """Return precision, recall and F1 averaged over paired act lists.
+
+    All three are 0 when there is nothing to average (the original divided
+    by zero in that case).
+    """
+    scores = [tp_fn_fp(pred, label)
+              for pred, label in zip(pred_acts, label_acts)]
+    return {m: (sum(s[m] for s in scores) / len(scores) if scores else 0)
+            for m in ("precision", "recall", "f1")}
+
+
+def tp_fn_fp(pred, label):
+    """Compute precision, recall and F1 of ``pred`` acts against ``label`` acts.
+
+    Every element of ``pred`` counts as a true positive if it occurs in
+    ``label`` and as a false positive otherwise; every element of ``label``
+    missing from ``pred`` is a false negative.  A score whose denominator
+    is zero is reported as 0.
+    """
+    tp = float(sum(1 for act in pred if act in label))
+    fp = float(sum(1 for act in pred if act not in label))
+    fn = float(sum(1 for act in label if act not in pred))
+
+    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+    f1 = 0
+    if (precision + recall) > 0:
+        f1 = (2 * precision * recall) / (precision + recall)
+
+    return {"precision": precision, "recall": recall, "f1": f1}
+
+
+class TrainerHelper:
+    """Locate GenTUS data/experiment folders and tokenize the dialogue files."""
+
+    def __init__(self, tokenizer, max_input_length=500, max_target_length=500):
+        print("transformers version is: ", transformers.__version__)
+        self.tokenizer = tokenizer
+        self.max_input_length = max_input_length
+        self.max_target_length = max_target_length
+        self.base_name = "convlab/policy/genTUS"
+        self.dir_name = ""
+
+    def _get_data_folder(self, model_type, data_name, dial_ids_order=0, split2ratio=1):
+        """Remember ``dir_name`` for this data setup and return its data folder."""
+        if model_type not in ["unify", "multiwoz"]:
+            print("Unknown model type. Currently only support unify and multiwoz")
+        self.dir_name = f"{data_name}_{dial_ids_order}_{split2ratio}"
+        return os.path.join(self.base_name, model_type, 'data', self.dir_name)
+
+    def get_model_folder(self, model_type):
+        """Return the experiment folder for ``model_type``, creating it if needed."""
+        folder_name = os.path.join(
+            self.base_name, model_type, "experiments", self.dir_name)
+        # exist_ok avoids the check-then-create race of exists() + makedirs()
+        os.makedirs(folder_name, exist_ok=True)
+        return folder_name
+
+    def parse_data(self, model_type, data_name, dial_ids_order=0, split2ratio=1):
+        """Load the train/validation/test JSON files and return tokenized Datasets."""
+        data_folder = self._get_data_folder(
+            model_type, data_name, dial_ids_order, split2ratio)
+
+        raw_data = {}
+        for d_type in ["train", "validation", "test"]:
+            f_name = os.path.join(data_folder, f"{d_type}.json")
+            with open(f_name) as f:  # close the handle instead of leaking it
+                raw_data[d_type] = json.load(f)
+        return {d_type: Dataset.from_dict(self._preprocess(data["dialog"]))
+                for d_type, data in raw_data.items()}
+
+    def _preprocess(self, examples):
+        """Tokenize "in"/"out" pairs into input_ids, attention_mask and labels."""
+        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+        if isinstance(examples, dict):
+            examples = [examples]
+        for example in tqdm(examples):
+            inputs = self.tokenizer(example["in"],
+                                    max_length=self.max_input_length,
+                                    truncation=True)
+            # Setup the tokenizer for targets
+            with self.tokenizer.as_target_tokenizer():
+                labels = self.tokenizer(example["out"],
+                                        max_length=self.max_target_length,
+                                        truncation=True)
+            for key in ["input_ids", "attention_mask"]:
+                model_inputs[key].append(inputs[key])
+            model_inputs["labels"].append(labels["input_ids"])
+
+        return model_inputs
+
+
+def train(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"):
+    """Fine-tune a BART checkpoint on the prepared GenTUS data and save it.
+
+    Output goes to a timestamped sub-folder of the experiment folder;
+    evaluation runs every epoch on the "test" split.
+    """
+    helper = TrainerHelper(
+        tokenizer=TOKENIZER, max_input_length=max_input_length, max_target_length=max_target_length)
+    data = helper.parse_data(model_type=model_type,
+                             data_name=data_name,
+                             dial_ids_order=dial_ids_order,
+                             split2ratio=split2ratio)
+
+    model = BartForConditionalGeneration.from_pretrained(model_checkpoint)
+    # match the embedding matrix to the size of the shared tokenizer
+    model.resize_token_embeddings(len(TOKENIZER))
+
+    model_root = helper.get_model_folder(model_type)
+    model_dir = os.path.join(
+        model_root, datetime.now().strftime('%y-%m-%d-%H-%M'))
+
+    args = Seq2SeqTrainingArguments(
+        model_dir,
+        evaluation_strategy="epoch",
+        learning_rate=2e-5,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        weight_decay=0.01,
+        save_total_limit=2,
+        num_train_epochs=5,
+        predict_with_generate=True,
+        fp16=torch.cuda.is_available(),  # fp16 is enabled exactly when CUDA is available
+        push_to_hub=False,
+        generation_max_length=max_target_length,
+        logging_dir=os.path.join(model_dir, 'log')
+    )
+    collator = DataCollatorForSeq2Seq(TOKENIZER, model=model, padding=True)
+
+    # customize this trainer
+    trainer = Seq2SeqTrainer(
+        model=model,
+        args=args,
+        train_dataset=data["train"],
+        eval_dataset=data["test"],
+        data_collator=collator,
+        tokenizer=TOKENIZER,
+        compute_metrics=gentus_compute_metrics)
+    print("start training...")
+    trainer.train()
+    print("saving model...")
+    trainer.save_model()
+
+
+def main():
+    """Read the CLI options and run ``train`` with them.
+
+    Sequence-length limits come from MAX_IN_LEN / MAX_OUT_LEN, not the CLI.
+    """
+    args = arg_parser()
+    print("---> data_name", args.data_name)
+    train(args.model_type, args.data_name, args.dial_ids_order,
+          args.split2ratio, batch_size=args.batch_size,
+          max_input_length=MAX_IN_LEN, max_target_length=MAX_OUT_LEN,
+          model_checkpoint=args.model_checkpoint)
+
+
+if __name__ == "__main__":  # run training when executed as a script
+    main()
+    # NOTE(review): "sgd+tm: 46000" is an unexplained author note — meaning unclear from this file
diff --git a/convlab/policy/genTUS/unify/Goal.py b/convlab/policy/genTUS/unify/Goal.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a77b090a7266d07f653e707bd4b749b6a6114bb
--- /dev/null
+++ b/convlab/policy/genTUS/unify/Goal.py
@@ -0,0 +1,233 @@
+"""
+The user goal for unify data format
+"""
+import json
+from convlab.policy.tus.unify.Goal import old_goal2list
+from convlab.task.multiwoz.goal_generator import GoalGenerator
+from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal
+from convlab.util.custom_util import slot_mapping
+DEF_VAL_UNK = '?'  # Unknown
+DEF_VAL_DNC = 'dontcare'  # Do not care
+DEF_VAL_NUL = 'none'  # for none
+NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, ""]
+
+NOT_MENTIONED = "not mentioned"
+FULFILLED = "fulfilled"
+REQUESTED = "requested"
+CONFLICT = "conflict"
+
+
+class Goal:
+    """ User Goal Model Class. """
+
+    def __init__(self, goal=None, goal_generator=None):
+        """
+        Create a new Goal from a dialog or from a goal generator.
+        Args:
+            goal: a list of [domain, intent, slot, value] rows (from a
+                dialog), a dict goal, an ABUS goal, or None
+            goal_generator: source of an ABUS goal when ``goal`` is empty
+        """
+        self.domains = []
+        self.domain_goals = {}
+        self.status = {}
+        self.invert_slot_mapping = {v: k for k, v in slot_mapping.items()}
+        self.raw_goal = None
+
+        self._init_goal_from_data(goal, goal_generator)
+        self._init_status()
+
+    def __str__(self):
+        """Return the domain goals as indented JSON between banner lines."""
+        goal_json = json.dumps(self.domain_goals, indent=4)
+        return '-----Goal-----\n' + goal_json + '\n-----Goal-----'
+
+    def _init_goal_from_data(self, goal=None, goal_generator=None):
+        """Fill ``domains`` and ``domain_goals`` from the given goal.
+
+        Dict and ABUS goals (including one drawn from ``goal_generator``) are
+        kept in ``raw_goal`` and converted with ``old_goal2list``; with no
+        goal at all the goal simply stays empty.
+        """
+        if not goal and goal_generator:
+            goal = ABUS_Goal(goal_generator)
+            self.raw_goal = goal.domain_goals
+            goal = old_goal2list(goal.domain_goals)
+        elif isinstance(goal, dict):
+            self.raw_goal = goal
+            goal = old_goal2list(goal)
+        elif isinstance(goal, ABUS_Goal):
+            self.raw_goal = goal.domain_goals
+            goal = old_goal2list(goal.domain_goals)
+
+        # be careful of this order
+        # (``goal or []``: a missing goal used to raise TypeError here)
+        for domain, intent, slot, value in goal or []:
+            if domain == "none":
+                continue
+            if domain not in self.domains:
+                self.domains.append(domain)
+                self.domain_goals[domain] = {}
+            if intent not in self.domain_goals[domain]:
+                self.domain_goals[domain][intent] = {}
+
+            if value:
+                self.domain_goals[domain][intent][slot] = value
+            elif intent == "request":
+                self.domain_goals[domain][intent][slot] = DEF_VAL_UNK
+            else:
+                print(
+                    f"unknown no value intent {domain}, {intent}, {slot}")
+
+    def _init_status(self):
+        """Give every goal slot a status entry marked "not mentioned".
+
+        Values are stored as strings (task_complete searches them for "?").
+        """
+        for domain, domain_goal in self.domain_goals.items():
+            domain_status = self.status.setdefault(domain, {})
+            for slot_type, sub_domain_goal in domain_goal.items():
+                slot_status = domain_status.setdefault(slot_type, {})
+                for slot, value in sub_domain_goal.items():
+                    slot_status[slot] = {
+                        "value": str(value),
+                        "status": NOT_MENTIONED}
+
+    def get_goal_list(self, data_goal=None):
+        """Return the goal as [intent, domain, slot, value, status] rows.
+
+        Rows follow ``data_goal`` ([domain, intent, slot, value] rows) when it
+        is given, otherwise the order of ``domain_goals``.  ``data_goal``
+        entries that are not part of this goal (the constructor drops e.g.
+        the "none" domain) are skipped instead of raising KeyError.
+        """
+        if data_goal:
+            # make sure the order!!!
+            keys = [(d, i, s) for d, i, s, _ in data_goal
+                    if s in self.domain_goals.get(d, {}).get(i, {})]
+        else:
+            keys = [(d, i, s) for d, goal in self.domain_goals.items()
+                    for i, sub_goal in goal.items() for s in sub_goal]
+        return [[i, d, s, self.domain_goals[d][i][s], self._get_status(d, i, s)]
+                for d, i, s in keys]
+
+    def _get_status(self, domain, intent, slot):
+        """Return the slot's status, or NOT_MENTIONED if it is not tracked."""
+        entry = self.status.get(domain, {}).get(intent, {}).get(slot)
+        return NOT_MENTIONED if entry is None else entry["status"]
+
+    def task_complete(self):
+        """
+        Check whether the goal is complete.
+        Returns:
+            (boolean): False if a tracked slot is still "not mentioned", is in
+                conflict (except "arrive by"/"leave at"), or has "?" in its value.
+        """
+        for domain, domain_goal in self.status.items():
+            if domain not in self.domain_goals:
+                continue
+            for slot_type, sub_domain_goal in domain_goal.items():
+                if slot_type not in self.domain_goals[domain]:
+                    continue
+                for slot, status in sub_domain_goal.items():
+                    if slot not in self.domain_goals[domain][slot_type]:
+                        continue
+                    # for strict success, turn this on
+                    if status["status"] in [NOT_MENTIONED, CONFLICT]:
+                        if status["status"] == CONFLICT and slot in ["arrive by", "leave at"]:
+                            continue
+                        return False
+                    if "?" in status["value"]:
+                        return False
+
+        return True
+
+    # TODO change to update()?
+    def update_user_goal(self, action, char="usr"):
+        """Apply an action list from the user ("usr") or the system ("sys")."""
+        if char == "usr":
+            self._user_action_update(action)
+        elif char == "sys":
+            self._system_action_update(action)
+        else:
+            print("!!!UNKNOWN CHAR!!!")
+
+    def _user_action_update(self, action):
+        """Mark goal slots the user informed (fulfilled) or requested."""
+        for intent, domain, slot, _ in action:
+            goal_intent = self._check_slot_and_intent(domain, slot)
+            if not goal_intent:
+                continue
+            # fulfilled by user
+            if is_inform(intent):
+                self._set_status(goal_intent, domain, slot, FULFILLED)
+            # requested by user
+            if is_request(intent):
+                self._set_status(goal_intent, domain, slot, REQUESTED)
+
+    def _system_action_update(self, action):
+        """Update slot statuses (and requested values) from system acts."""
+        for intent, domain, slot, value in action:
+            goal_intent = self._check_slot_and_intent(domain, slot)
+            if not goal_intent:
+                continue
+            # fulfill request by system
+            if is_inform(intent) and is_request(goal_intent):
+                self._set_status(goal_intent, domain, slot, FULFILLED)
+                self._set_goal(goal_intent, domain, slot, value)
+
+            if is_inform(intent) and is_inform(goal_intent):
+                # fulfill inform by system
+                if value == self.domain_goals[domain][goal_intent][slot]:
+                    self._set_status(goal_intent, domain, slot, FULFILLED)
+                # conflict system inform
+                else:
+                    self._set_status(goal_intent, domain, slot, CONFLICT)
+            # requested by system
+            if is_request(intent) and is_inform(goal_intent):
+                self._set_status(goal_intent, domain, slot, REQUESTED)
+
+    def _set_status(self, intent, domain, slot, status):
+        """Set the status of an existing status entry."""
+        self.status[domain][intent][slot]["status"] = status
+
+    def _set_goal(self, intent, domain, slot, value):
+        """Overwrite a slot's goal value and mirror it into its status entry."""
+        self.domain_goals[domain][intent][slot] = value
+        # stored as a string, like _init_status does, so that the
+        # ``"?" in value`` test in task_complete works for any value type
+        self.status[domain][intent][slot]["value"] = str(value)
+
+    def _check_slot_and_intent(self, domain, slot):
+        """Return the goal intent that holds ``slot`` in ``domain``, or ""."""
+        for intent, slots in self.domain_goals.get(domain, {}).items():
+            if slot in slots:
+                return intent
+        return ""
+
+
+def is_inform(intent):
+    """Return True when ``intent`` contains "inform" (substring match for
+    strings, so e.g. "inform" and "_inform_intent" both qualify)."""
+    return "inform" in intent
+
+
+def is_request(intent):
+    """Return True when ``intent`` contains "request" (substring match for
+    strings, so e.g. "request" and "_request_alts" both qualify)."""
+    return "request" in intent
+
+
+def transform_data_act(data_action):
+    """Flatten dialogue acts into [intent, domain, slot, value] rows.
+
+    ``data_action`` maps to lists of act dicts; the keys are ignored.  A
+    missing or empty "value" becomes "?" when the intent contains
+    "request" and "none" otherwise.
+    """
+    return [
+        [act["intent"], act["domain"], act["slot"],
+         act.get("value", "") or ("?" if "request" in act["intent"] else "none")]
+        for dialog_act in data_action.values()
+        for act in dialog_act
+    ]
diff --git a/convlab/policy/genTUS/unify/build_data.py b/convlab/policy/genTUS/unify/build_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..50873a1d4b6ffaa8ee49c84a4b088ae56ad13554
--- /dev/null
+++ b/convlab/policy/genTUS/unify/build_data.py
@@ -0,0 +1,211 @@
+import json
+import os
+import sys
+from argparse import ArgumentParser
+
+from tqdm import tqdm
+
+from convlab.policy.genTUS.unify.Goal import Goal, transform_data_act
+from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset
+
+
+sys.path.append(os.path.dirname(os.path.dirname(
+    os.path.dirname(os.path.abspath(__file__)))))  # NOTE(review): runs after the convlab imports above, so it cannot help them resolve — confirm intent
+
+
+def arg_parser():
+    """Parse the data-building command-line options from ``sys.argv``."""
+    cli = ArgumentParser()
+    cli.add_argument("--dataset", type=str, default="multiwoz21",
+                     help="the dataset, such as multiwoz21, sgd, tm1, tm2, and tm3.")
+    cli.add_argument("--dial-ids-order", type=int, default=0)
+    cli.add_argument("--split2ratio", type=float, default=1)
+    # boolean switches, all off by default
+    for switch in ("--random-order", "--no-status", "--add-history"):
+        cli.add_argument(switch, action="store_true")
+    cli.add_argument("--remove-domain", type=str, default="")
+    return cli.parse_args()
+
+class DataBuilder:
+    """Turn unified-format dialogs into GenTUS {"in", "out"} training examples."""
+
+    def __init__(self, dataset='multiwoz21'):
+        self.dataset = dataset
+
+    def setup_data(self,
+                   raw_data,
+                   random_order=False,
+                   no_status=False,
+                   add_history=False,
+                   remove_domain=None):
+        """Build the examples of every split in ``raw_data``.
+
+        Returns {split: {"dialog": [example, ...]}}.  ``remove_domain`` is
+        accepted but not used by this method.
+        """
+        examples = {}
+        for data_split, dialogs in raw_data.items():
+            collected = []
+            for dialog in tqdm(dialogs, ascii=True):
+                collected += self._one_dialog(dialog=dialog,
+                                              add_history=add_history,
+                                              random_order=random_order,
+                                              no_status=no_status)
+            examples[data_split] = {"dialog": collected}
+
+        return examples
+
+    def _one_dialog(self, dialog, add_history=True, random_order=False, no_status=False):
+        """Return one {"in", "out"} example per even-indexed turn that has acts.
+
+        The goal status is updated with the preceding turn's acts before the
+        input string is built and with this turn's acts afterwards.
+        """
+        example = []
+        data_goal = self.norm_domain_goal(create_goal(dialog))
+        if not data_goal:
+            return example
+
+        user_goal = Goal(goal=data_goal)
+        history = []
+        turns = dialog["turns"]
+        for turn_id in range(0, len(turns), 2):
+            sys_act = self._get_sys_act(dialog, turn_id)
+            user_goal.update_user_goal(action=sys_act, char="sys")
+            usr_goal_str = self._user_goal_str(
+                user_goal, data_goal, random_order, no_status)
+
+            raw_act = transform_data_act(turns[turn_id]["dialogue_acts"])
+            usr_act = self.norm_domain(raw_act)
+            user_goal.update_user_goal(action=usr_act, char="usr")
+            # change value "?" to "<?>"
+            usr_act = self._modify_act(usr_act)
+
+            in_str = self._dump_in_str(
+                sys_act, usr_goal_str, history, turn_id, add_history)
+            out_str = self._dump_out_str(usr_act, turns[turn_id]["utterance"])
+
+            history.append(usr_act)
+            if usr_act:
+                example.append({"in": in_str, "out": out_str})
+
+        return example
+
+    def _get_sys_act(self, dialog, turn_id):
+        """Return the normalized acts of the turn before ``turn_id`` ([] at turn 0)."""
+        if turn_id <= 0:
+            return []
+        previous = dialog["turns"][turn_id - 1]["dialogue_acts"]
+        return self.norm_domain(transform_data_act(previous))
+
+    def _user_goal_str(self, user_goal, data_goal, random_order, no_status):
+        """Return the goal rows for the model input.
+
+        Rows follow ``data_goal`` unless ``random_order`` is set (then the
+        goal's own order is used); ``no_status`` strips the status field.
+        """
+        if random_order:
+            rows = user_goal.get_goal_list()
+        else:
+            rows = user_goal.get_goal_list(data_goal=data_goal)
+        return self._remove_status(rows) if no_status else rows
+
+    def _dump_in_str(self, sys_act, usr_goal_str, history, turn_id, add_history):
+        """Serialize the model input (system act, goal, optional history) as JSON."""
+        in_str = {"system": self._modify_act(sys_act), "goal": usr_goal_str}
+        if add_history:
+            # only the three most recent history entries are included
+            in_str["history"] = history[-3:] if history else []
+            in_str["turn"] = str(int(turn_id/2))
+        return json.dumps(in_str)
+
+    def _dump_out_str(self, usr_act, text):
+        """Serialize the target (user act and utterance text) as JSON."""
+        return json.dumps({"action": usr_act, "text": text})
+
+    @staticmethod
+    def _norm_intent(intent):
+        """Prefix a fixed set of intent names with "_"; return others unchanged."""
+        prefixed = ("inform_intent", "negate_intent", "affirm_intent", "request_alts")
+        return f"_{intent}" if intent in prefixed else intent
+
+    @staticmethod
+    def _norm_fields(intent, domain, slot, value):
+        """Return (domain, slot, value) with the shared normalization applied.
+
+        The domain is cut at its first "_"; an empty domain or slot becomes
+        "none"; an empty value becomes "<?>" for a "request" intent and
+        "none" otherwise.
+        """
+        if "_" in domain:
+            domain = domain.split('_')[0]
+        if not value:
+            value = "<?>" if intent == "request" else "none"
+        return domain or "none", slot or "none", value
+
+    def norm_domain(self, x):
+        """Normalize [intent, domain, slot, value] rows; falsy ``x`` is returned as is."""
+        if not x:
+            return x
+        norm_result = []
+        for intent, domain, slot, value in x:
+            domain, slot, value = self._norm_fields(intent, domain, slot, value)
+            norm_result.append([self._norm_intent(intent), domain, slot, value])
+        return norm_result
+
+    def norm_domain_goal(self, x):
+        """Normalize [domain, intent, slot, value] rows; falsy ``x`` is returned as is."""
+        if not x:
+            return x
+        norm_result = []
+        # take care of the order!
+        for domain, intent, slot, value in x:
+            domain, slot, value = self._norm_fields(intent, domain, slot, value)
+            norm_result.append([domain, self._norm_intent(intent), slot, value])
+        return norm_result
+
+    @staticmethod
+    def _remove_status(goal_list):
+        """Return the rows cut down to their first four fields."""
+        return [[goal[0], goal[1], goal[2], goal[3]] for goal in goal_list]
+
+    @staticmethod
+    def _modify_act(act):
+        """Return the act rows with every "?" value replaced by "<?>"."""
+        return [[i, d, s, "<?>" if value == "?" else value]
+                for i, d, s, value in act]
+
+
+if __name__ == "__main__":
+    args = arg_parser()
+
+    base_name = "convlab/policy/genTUS/unify/data"
+    dir_name = f"{args.dataset}_{args.dial_ids_order}_{args.split2ratio}"
+    folder_name = os.path.join(base_name, dir_name)
+    remove_domain = args.remove_domain
+
+    # exist_ok replaces the racy exists()-then-makedirs() check
+    os.makedirs(folder_name, exist_ok=True)
+
+    dataset = load_experiment_dataset(
+        data_name=args.dataset,
+        dial_ids_order=args.dial_ids_order,
+        split2ratio=args.split2ratio)
+    data_builder = DataBuilder(dataset=args.dataset)
+    data = data_builder.setup_data(
+        raw_data=dataset,
+        random_order=args.random_order,
+        no_status=args.no_status,
+        add_history=args.add_history,
+        remove_domain=remove_domain)
+
+    # NOTE(review): setup_data ignores remove_domain, so it only affects the file names below
+    for data_type, examples in data.items():
+        if remove_domain:
+            file_name = os.path.join(
+                folder_name, f"no{remove_domain}_{data_type}.json")
+        else:
+            file_name = os.path.join(folder_name, f"{data_type}.json")
+        # the context manager flushes and closes the file once it is written
+        with open(file_name, 'w') as f:
+            json.dump(examples, f, indent=2)
diff --git a/convlab/policy/genTUS/unify/knowledge_graph.py b/convlab/policy/genTUS/unify/knowledge_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..68af13e481fe4799dfc2a6f3763b526611eabd9c
--- /dev/null
+++ b/convlab/policy/genTUS/unify/knowledge_graph.py
@@ -0,0 +1,252 @@
+import json
+from random import choices
+
+from convlab.policy.genTUS.token_map import tokenMap
+
+from transformers import BartTokenizer
+
+DEBUG = False
+DATASET = "unify"
+
+
+class KnowledgeGraph:
+    """Hold the user goal parsed from one input and pick act tokens from model scores.
+
+    `parse_input` rebuilds `user_goal` (domain -> slot -> values); the `get_*`
+    methods then choose among the allowed candidates using `outputs`.
+    """
+
+    def __init__(self, tokenizer: BartTokenizer, ontology_file=None, dataset="multiwoz21"):
+        """Choose the intent vocabulary for `dataset`; `ontology_file` is accepted but unused."""
+        print("dataset", dataset)
+        self.debug = DEBUG
+        self.tokenizer = tokenizer
+
+        if "multiwoz" in dataset:
+            self.domain_intent = ["inform", "request"]
+            self.general_intent = ["thank", "bye"]
+        # use sgd dataset intents as default
+        else:
+            self.domain_intent = ["_inform_intent", "_negate_intent", "_affirm_intent",
+                                  "inform", "request", "affirm", "negate", "select",
+                                  "_request_alts"]
+            self.general_intent = ["thank_you", "goodbye"]
+
+        self.general_domain = "none"
+        self.kg_map = {"intent": tokenMap(tokenizer=self.tokenizer)}
+
+        for intent in self.domain_intent + self.general_intent:
+            self.kg_map["intent"].add_token(intent, intent)
+
+        self.init()
+
+    def init(self):
+        """Recreate the domain/slot/value token maps (the intent map is left alone)."""
+        for map_type in ["domain", "slot", "value"]:
+            self.kg_map[map_type] = tokenMap(tokenizer=self.tokenizer)
+        self.add_token("<?>", "value")
+
+    def parse_input(self, in_str):
+        """Rebuild `user_goal` from a JSON string with "system" and "goal" entries.
+
+        Goal entries are 5-item lists (the last item is ignored); system acts have 4 items.
+        """
+        self.init()
+        inputs = json.loads(in_str)
+        self.sys_act = inputs["system"]
+        self.user_goal = {}
+        self._add_none_domain()
+        for intent, domain, slot, value, _ in inputs["goal"]:
+            self._update_user_goal(intent, domain, slot, value, source="goal")
+
+        for intent, domain, slot, value in self.sys_act:
+            self._update_user_goal(intent, domain, slot, value, source="sys")
+
+    def _add_none_domain(self):
+        """Seed `user_goal` with the "none" domain and register the "none" tokens."""
+        # NOTE(review): the value is the string "none" while other domains store lists
+        # (["none"]); left unchanged because the output of `_filter_slot` depends on it.
+        self.user_goal["none"] = {"none": "none"}
+        self.add_token("none", "domain")
+        self.add_token("none", "slot")
+        self.add_token("none", "value")
+
+    def _update_user_goal(self, intent, domain, slot, value, source="goal"):
+        """Record `value` under `user_goal[domain][slot]` and register any new tokens.
+
+        From system acts (`source="sys"`) only requests are kept, stored as "dontcare".
+        """
+        if value == "?":
+            value = "<?>"
+
+        if source == "sys":
+            if intent != "request":
+                return
+            value = "dontcare"  # user can "dontcare" system request
+
+        if domain not in self.user_goal:
+            self.user_goal[domain] = {"none": ["none"]}
+            self.add_token(domain, "domain")
+            self.add_token("none", "slot")
+            self.add_token("none", "value")
+
+        if slot not in self.user_goal[domain]:
+            self.user_goal[domain][slot] = []
+            self.add_token(slot, "slot")  # fixed: the slot map used to receive `domain`
+
+        # normalise BEFORE the membership test so non-string/escaped values de-duplicate
+        value = json.dumps(str(value))[1:-1]
+        if value not in self.user_goal[domain][slot]:
+            self.user_goal[domain][slot].append(value)
+            self.add_token(value, "value")  # add_token swaps '"' for "'" itself
+
+    def add_token(self, token_name, map_type):
+        """Register `token_name` in `kg_map[map_type]` unless it is already present."""
+        if map_type == "value":
+            token_name = token_name.replace('"', "'")
+        if not self.kg_map[map_type].token_name_is_in(token_name):
+            self.kg_map[map_type].add_token(token_name, token_name)
+
+    def _get_max_score(self, outputs, candidate_list, map_type):
+        """Map each candidate's score to its {"token_id", "token_name"} entry.
+
+        The score is `outputs[:, first_token_id].item()`; it is also the dict key, so
+        candidates with equal scores overwrite each other. An empty list yields "none".
+        """
+        score = {}
+        if not candidate_list:
+            print(f"ERROR: empty candidate list for {map_type}")
+            score[1] = {"token_id": self._get_token_id("none"), "token_name": "none"}
+
+        for name in candidate_list:
+            token_id = self._get_token_id(name)  # tokenise once, reuse for key and entry
+            score[outputs[:, token_id[0]].item()] = {"token_id": token_id, "token_name": name}
+        return score
+
+    def _select(self, score, mode="max"):
+        """Return the key of `score` to use: the largest ("max") or a weighted draw ("sample").
+
+        Raises:
+            ValueError: if `mode` is neither "max" nor "sample".
+        """
+        probs = list(score)
+        if mode == "max":
+            return max(probs)
+        if mode == "sample":
+            # the scores themselves are used as sampling weights
+            return choices(probs, weights=probs, k=1)[0]
+        raise ValueError(f"unknown select mode: {mode}")
+
+    def _get_max_domain_token(self, outputs, candidates, map_type, mode="max"):
+        """Score `candidates` and return the selected {"token_id", "token_name"}."""
+        score = self._get_max_score(outputs, candidates, map_type)
+        picked = score[self._select(score, mode)]
+        return {"token_id": picked["token_id"], "token_name": picked["token_name"]}
+
+    def _fixed_token(self, token_name):
+        """Build the token map for a token that needs no scoring (e.g. "none", "<?>")."""
+        return {"token_id": self._get_token_id(token_name), "token_name": token_name}
+
+    def candidate(self, candidate_type, **kwargs):
+        """List the allowed tokens for `candidate_type`: "intent", "domain", "slot", else values.
+
+        Keyword arguments read: `allow_general_intent`, `intent`, `domain`, `slot`, `is_mentioned`.
+        """
+        intent = kwargs.get("intent")
+        if candidate_type == "intent":
+            if kwargs.get("allow_general_intent", True):
+                return self.domain_intent + self.general_intent
+            return self.domain_intent
+        if candidate_type == "domain":
+            if intent in self.general_intent:
+                return [self.general_domain]
+            return list(self.user_goal)
+        # slot and value candidates share the general-intent shortcut
+        if intent in self.general_intent:
+            return ["none"]
+        if candidate_type == "slot":
+            return self._filter_slot(intent, kwargs["domain"], kwargs["is_mentioned"])
+        if intent.lower() == "request":
+            return ["<?>"]
+        return self._filter_value(intent, kwargs["domain"], kwargs["slot"])
+
+    def get_intent(self, outputs, mode="max", allow_general_intent=True):
+        """Return the {"token_id", "token_name"} entry of the selected intent."""
+        # TODO request?
+        candidates = self.candidate("intent", allow_general_intent=allow_general_intent)
+        score = self._get_max_score(outputs, candidates, "intent")
+        return score[self._select(score, mode)]
+
+    def get_domain(self, outputs, intent, mode="max"):
+        """Return the token map of the domain chosen for `intent`.
+
+        Raises:
+            ValueError: if `intent` is neither a general nor a domain intent.
+        """
+        if intent in self.general_intent:
+            return self._fixed_token(self.general_domain)
+        if intent not in self.domain_intent:
+            raise ValueError(f"unknown intent: {intent}")
+        return self._get_max_domain_token(
+            outputs=outputs, candidates=self.candidate("domain", intent=intent),
+            map_type="domain", mode=mode)
+
+    def get_slot(self, outputs, intent, domain, mode="max", is_mentioned=False):
+        """Return the token map of the slot chosen for `intent` within `domain`.
+
+        Raises:
+            ValueError: if `intent` is neither a general nor a domain intent.
+        """
+        if intent in self.general_intent:
+            return self._fixed_token("none")
+        if intent not in self.domain_intent:
+            raise ValueError(f"unknown intent: {intent}")
+        slot_list = self.candidate(
+            candidate_type="slot", intent=intent, domain=domain, is_mentioned=is_mentioned)
+        return self._get_max_domain_token(
+            outputs=outputs, candidates=slot_list, map_type="slot", mode=mode)
+
+    def get_value(self, outputs, intent, domain, slot, mode="max"):
+        """Return the token map of the value chosen for `intent`, `domain` and `slot`.
+
+        General intents and the "none" slot yield "none"; requests yield "<?>".
+
+        Raises:
+            ValueError: if `intent` matches none of the known intents.
+        """
+        if intent in self.general_intent or slot.lower() == "none":
+            return self._fixed_token("none")
+        if intent.lower() == "request":
+            return self._fixed_token("<?>")
+        if intent not in self.domain_intent:
+            raise ValueError(f"unknown intent: {intent}")
+        # TODO should not none ?
+        value_list = self.candidate(
+            candidate_type="value", intent=intent, domain=domain, slot=slot)
+        return self._get_max_domain_token(
+            outputs=outputs, candidates=value_list, map_type="value", mode=mode)
+
+    def _filter_slot(self, intent, domain, is_mentioned=True):
+        """List the slots of `domain` that keep at least one value after `_filter_value`.
+
+        "none" is appended when `is_mentioned` is false and the intent is not a request.
+        """
+        # a slot qualifies only if its filtered value list is non-empty
+        slot_list = [slot for slot in self.user_goal[domain]
+                     if self._filter_value(intent, domain, slot)]
+        if not is_mentioned and intent.lower() != "request":
+            slot_list.append("none")
+        return slot_list
+
+    def _filter_value(self, intent, domain, slot):
+        """Copy the values of `domain`/`slot` without "none" and, unless requesting, "?"/"<?>"."""
+        value_list = list(self.user_goal[domain][slot])
+        unwanted = ["none"] if intent.lower() == "request" else ["none", "?", "<?>"]
+        for token in unwanted:
+            if token in value_list:
+                value_list.remove(token)  # removes the first occurrence only, as before
+        return value_list
+
+    def _get_token_id(self, token):
+        """Tokenise `token` without special tokens and return its list of input ids."""
+        return self.tokenizer(token, add_special_tokens=False)["input_ids"]
diff --git a/convlab/policy/genTUS/utils.py b/convlab/policy/genTUS/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..39c822dd35f985790d53a2834e8b6fe437864f24
--- /dev/null
+++ b/convlab/policy/genTUS/utils.py
@@ -0,0 +1,5 @@
+import torch
+
+
+def append_tokens(tokens, new_token, device):  # returns `tokens` with `new_token` joined on dim 1
+    return torch.cat([tokens, torch.tensor([new_token]).to(device)], 1)
diff --git a/convlab/policy/mle/README.md b/convlab/policy/mle/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c13140497508d79ec6faedfc588fa4f6af7043f7
--- /dev/null
+++ b/convlab/policy/mle/README.md
@@ -0,0 +1,22 @@
+# Maximum Likelihood Estimator (MLE)
+
+MLE learns an MLP model in a supervised way using a provided dataset. The trained model can be used as an initialisation point for RL training with PPO or GDPL, for instance.
+
+## Supervised Training
+
+Starting a training is as easy as executing
+
+```sh
+$ python train.py --dataset_name=DATASET_NAME --seed=SEED --eval_freq=FREQ
+```
+
+The dataset name can be "multiwoz21" or "sgd" for instance. The first time you run that command, it will take longer as the dataset needs to be pre-processed. The evaluation frequency determines after how many epochs the model is evaluated.
+
+Other hyperparameters such as learning rate or number of epochs can be set in the config.json file.
+
+We provide a model trained on multiwoz21 on hugging-face: https://huggingface.co/ConvLab/mle-policy-multiwoz21
+
+
+## Evaluation
+
+Evaluation on the validation data set takes place during training.
\ No newline at end of file
diff --git a/convlab/policy/mle/loader.py b/convlab/policy/mle/loader.py
index ebc01a0149dbeabf872da4ceb4ba184bfb54fd22..bb898ab4a30ad957db632b4eb74fca581f05f05b 100755
--- a/convlab/policy/mle/loader.py
+++ b/convlab/policy/mle/loader.py
@@ -2,6 +2,9 @@ import os
 import pickle
 import torch
 import torch.utils.data as data
+from copy import deepcopy
+
+from tqdm import tqdm
 
 from convlab.policy.vector.vector_binary import VectorBinary
 from convlab.util import load_policy_data, load_dataset
@@ -12,18 +15,20 @@ from convlab.policy.vector.dataset import ActDataset
 
 class PolicyDataVectorizer:
     
-    def __init__(self, dataset_name='multiwoz21', vector=None):
+    def __init__(self, dataset_name='multiwoz21', vector=None, dst=None):
         self.dataset_name = dataset_name
         if vector is None:
             self.vector = VectorBinary(dataset_name)
         else:
             self.vector = vector
+        self.dst = dst
         self.process_data()
 
     def process_data(self):
-
-        processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
-                                     f'processed_data/{self.dataset_name}_{type(self.vector).__name__}')
+        name = f"{self.dataset_name}_"
+        name += f"{type(self.dst).__name__}_" if self.dst is not None else ""
+        name += f"{type(self.vector).__name__}"
+        processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), name)
         if os.path.exists(processed_dir):
             print('Load processed data file')
             self._load_data(processed_dir)
@@ -42,15 +47,27 @@ class PolicyDataVectorizer:
             self.data[split] = []
             raw_data = data_split[split]
 
-            for data_point in raw_data:
-                state = default_state()
+            if self.dst is not None:
+                self.dst.init_session()
+
+            for data_point in tqdm(raw_data):
+                if self.dst is None:
+                    state = default_state()
+
+                    state['belief_state'] = data_point['context'][-1]['state']
+                    state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts'])
+                else:
+                    last_system_utt = data_point['context'][-2]['utterance'] if len(data_point['context']) > 1 else ''
+                    self.dst.state['history'].append(['sys', last_system_utt])
 
-                state['belief_state'] = data_point['context'][-1]['state']
-                state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts'])
-                last_system_act = data_point['context'][-2]['dialogue_acts'] \
-                    if len(data_point['context']) > 1 else {}
+                    usr_utt = data_point['context'][-1]['utterance']
+                    state = deepcopy(self.dst.update(usr_utt))
+                    self.dst.state['history'].append(['usr', usr_utt])
+                last_system_act = data_point['context'][-2]['dialogue_acts'] if len(data_point['context']) > 1 else {}
                 state['system_action'] = flatten_acts(last_system_act)
                 state['terminated'] = data_point['terminated']
+                if self.dst is not None and state['terminated']:
+                    self.dst.init_session()
                 state['booked'] = data_point['booked']
                 dialogue_act = flatten_acts(data_point['dialogue_acts'])
 
diff --git a/convlab/policy/mle/train.py b/convlab/policy/mle/train.py
index 2b82a476e5db119fe80b44517fd1b0d5e21fcfa6..c2477760c8a189cbacd978151755fa790996a421 100755
--- a/convlab/policy/mle/train.py
+++ b/convlab/policy/mle/train.py
@@ -137,15 +137,6 @@ class MLE_Trainer(MLE_Trainer_Abstract):
     def __init__(self, manager, vector, cfg):
         self._init_data(manager, cfg)
 
-        try:
-            self.use_entropy = manager.use_entropy
-            self.use_mutual_info = manager.use_mutual_info
-            self.use_confidence_scores = manager.use_confidence_scores
-        except:
-            self.use_entropy = False
-            self.use_mutual_info = False
-            self.use_confidence_scores = False
-
         # override the loss defined in the MLE_Trainer_Abstract to support pos_weight
         pos_weight = cfg['pos_weight'] * torch.ones(vector.da_dim).to(device=DEVICE)
         self.multi_entropy_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
@@ -161,6 +152,10 @@ def arg_parser():
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--eval_freq", type=int, default=1)
     parser.add_argument("--dataset_name", type=str, default="multiwoz21")
+    parser.add_argument("--use_masking", action='store_true')
+
+    parser.add_argument("--dst", type=str, default=None)
+    parser.add_argument("--dst_args", type=str, default=None)
 
     args = parser.parse_args()
     return args
@@ -181,8 +176,28 @@ if __name__ == '__main__':
     set_seed(args.seed)
     logging.info(f"Seed used: {args.seed}")
 
-    vector = VectorBinary(dataset_name=args.dataset_name, use_masking=False)
-    manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector)
+    if args.dst is None:  # no tracker requested
+        vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking)
+        dst = None
+    elif args.dst == "setsumbt":
+        # NOTE(review): eval() executes whatever is passed in --dst_args; trusted input only
+        pairs = [arg.split('=', 1) for arg in (args.dst_args or '').split(', ') if '=' in arg]
+        dst_args = {key: eval(value) for key, value in pairs}
+        from convlab.dst.setsumbt import SetSUMBTTracker
+        dst = SetSUMBTTracker(**dst_args)
+        if dst.return_confidence_scores:
+            from convlab.policy.vector.vector_uncertainty import VectorUncertainty
+            vector = VectorUncertainty(dataset_name=args.dataset_name, use_masking=args.use_masking,
+                                       manually_add_entity_names=False,
+                                       use_confidence_scores=dst.return_confidence_scores,
+                                       confidence_thresholds=dst.confidence_thresholds,
+                                       use_state_total_uncertainty=dst.return_belief_state_entropy,
+                                       use_state_knowledge_uncertainty=dst.return_belief_state_mutual_info)
+        else:
+            vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking)
+    else:
+        raise NameError(f"Tracker: {args.dst} not implemented.")
+    manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector, dst=dst)
     agent = MLE_Trainer(manager, vector, cfg)
 
     logging.info('Start training')
diff --git a/convlab/policy/pg/README.md b/convlab/policy/pg/README.md
index d61365276fe2d8e932ea41f32ad663742a17ca15..23032f827237d2364f28933f44631a6e981d7a60 100755
--- a/convlab/policy/pg/README.md
+++ b/convlab/policy/pg/README.md
@@ -1,36 +1,49 @@
-# REINFORCE
+# Policy Gradient (PG)
 
-A simple stochastic gradient algorithm for policy gradient reinforcement learning. We adapt REINFORCE to the dialog policy.
+PG is an on-policy reinforcement learning algorithm that uses the policy gradient theorem to perform policy updates, directly using the return as the value estimate.
+
+## Supervised pre-training
 
-## Train
+If you want to obtain a supervised model for pre-training, please have a look in the MLE policy folder.
 
-Run `train.py` in the `pg` directory:
+## RL training
 
-```bash
-python train.py
+Starting an RL training is as easy as executing
+
+```sh
+$ python train.py --path=your_environment_config --seed=SEED
 ```
 
-For better performance, we can do immitating learning before reinforcement learning. The immitating learning is implemented in the `mle` directory.
+One example for the environment-config is **semantic_level_config.json**, where parameters for the training are specified, for instance
 
-For example, if the trained model of immitating learning is saved at FOLDER_OF_MODEL/best_mle.pol.mdl, then you can run
+- load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl
+- process_num: the number of processes to use during evaluation to speed it up
+- num_eval_dialogues: how many evaluation dialogues should be used
+- epoch: how many training epochs to run. One epoch consists of collecting dialogues + performing an update
+- eval_frequency: after how many epochs perform an evaluation
+- batchsz: the number of training dialogues collected before doing an update
 
-```bash
-python train.py --load_path FOLDER_OF_MODEL/best_mle
-```
+Moreover, you can specify the full dialogue pipeline here, such as the user policy, NLU for system and user, etc.
+
+Parameters that are tied to the RL algorithm and the model architecture can be changed in config.json.
+
+
+## Evaluation
 
-Note that the *.pol.mdl* suffix should not appear in the --load_path argument.
+For creating evaluation plots and running evaluation dialogues, please have a look in the README of the policy folder.
 
-## Reference
+## References
 
 ```
-@article{williams1992simple,
-  title={Simple statistical gradient-following algorithms for connectionist reinforcement learning},
-  author={Williams, Ronald J},
-  journal={Machine learning},
-  volume={8},
-  number={3-4},
-  pages={229--256},
-  year={1992},
-  publisher={Springer}
+@inproceedings{NIPS1999_464d828b,
+ author = {Sutton, Richard S and McAllester, David and Singh, Satinder and Mansour, Yishay},
+ booktitle = {Advances in Neural Information Processing Systems},
+ editor = {S. Solla and T. Leen and K. M\"{u}ller},
+ pages = {},
+ publisher = {MIT Press},
+ title = {Policy Gradient Methods for Reinforcement Learning with Function Approximation},
+ url = {https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf},
+ volume = {12},
+ year = {1999}
 }
 ```
\ No newline at end of file
diff --git a/convlab/policy/pg/train.py b/convlab/policy/pg/train.py
index 05d51015b38072ffc00b72ae11f9d1ba0ee0ade0..72a52e26bf54f558ca5fc24e0fa7fbbf8f25608e 100755
--- a/convlab/policy/pg/train.py
+++ b/convlab/policy/pg/train.py
@@ -184,7 +184,7 @@ if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--path", type=str, default='convlab/policy/pg/semantic_level_config.json',
                         help="Load path for config file")
-    parser.add_argument("--seed", type=int, default=0,
+    parser.add_argument("--seed", type=int, default=None,
                         help="Seed for the policy parameter initialization")
     parser.add_argument("--mode", type=str, default='info',
                         help="Set level for logger")
@@ -199,7 +199,7 @@ if __name__ == '__main__':
     logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
         init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
 
-    args = [('model', 'seed', seed)]
+    args = [('model', 'seed', seed)] if seed is not None else list()
 
     environment_config = load_config_file(path)
     save_config(vars(parser.parse_args()), environment_config, config_save_path)
@@ -261,7 +261,7 @@ if __name__ == '__main__':
 
         if idx % conf['model']['eval_frequency'] == 0 and idx != 0:
             time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
-            logging.info(f"Evaluating at Epoch: {idx} - {time_now}" + '-'*60)
+            logging.info(f"Evaluating after Dialogues: {idx * conf['model']['batchsz']} - {time_now}" + '-' * 60)
 
             eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path)
 
diff --git a/convlab/policy/plot_results/README.md b/convlab/policy/plot_results/README.md
index 724dda6dfdeac7ca4cc103c445fde531b062df06..216d3a2df783ad647625134599c9c0df1fcc26cf 100644
--- a/convlab/policy/plot_results/README.md
+++ b/convlab/policy/plot_results/README.md
@@ -38,3 +38,5 @@ The file structure of the exp_dir is like this:
                     └── events.* 
 
 If you want to truncate the figure to a certain number of training dialogues on the x-axis, use the argument `--max-dialogues`.
+
+This script will automatically generate plots in the folder **--out-file** showing several performance metrics such as success rate and return, but also additional information such as the action distributions.
\ No newline at end of file
diff --git a/convlab/policy/plot_results/example_map.json b/convlab/policy/plot_results/example_map.json
index 72b5c6cc721b487cd313d2af7731b39607230a1a..7f5cd6242d93a70dafed36262a375c90e3702511 100644
--- a/convlab/policy/plot_results/example_map.json
+++ b/convlab/policy/plot_results/example_map.json
@@ -1,10 +1,18 @@
 [
   {
-    "dir": "ppo",
-    "legend": "PPO"
+    "dir": "scratch",
+    "legend": "scratch"
+  },
+    {
+    "dir": "sgd",
+    "legend": "SGD"
   },
   {
-    "dir": "pg",
-    "legend": "PG"
+    "dir": "mwoz",
+    "legend": "1%MWOZ"
+  },
+    {
+    "dir": "sgd_mw",
+    "legend": "SGD->1%MWOZ"
   }
 ]
\ No newline at end of file
diff --git a/convlab/policy/plot_results/plot.py b/convlab/policy/plot_results/plot.py
index 8dca2b84b79ff2c150ec4cb7389bc6e6ae43d5f4..c4032da663ec7a27af455c498b5243fb70500208 100644
--- a/convlab/policy/plot_results/plot.py
+++ b/convlab/policy/plot_results/plot.py
@@ -7,9 +7,12 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import seaborn as sns
+import sys
 from tensorboard.backend.event_processing import event_accumulator
 from tqdm import tqdm
 
+from convlab.policy.plot_results.plot_action_distributions import plot_distributions
+
 
 def get_args():
     parser = argparse.ArgumentParser(description='Export tensorboard data')
@@ -22,8 +25,14 @@ def get_args():
     parser.add_argument("--max-dialogues", type=int, default=0)
     parser.add_argument("--fill-between", type=float, default=0.3,
                         help="the transparency of the std err area")
+    parser.add_argument("--fontsize", type=int, default=18)
+    parser.add_argument("--font", type=str, default="Times New Roman")
+    parser.add_argument("--figure-size", type=str, help="Format 'width,height', eg '6,5'", default='6,5')
+    parser.add_argument("--figure-face-color", type=str, default='#E6E6E6')
 
     args = parser.parse_args()
+    args.figure_size = eval(args.figure_size)
+    plt.rcParams["font.family"] = args.font
     return args
 
 
@@ -56,53 +65,60 @@ def read_tb_data(in_path):
     return df
 
 
-def plot(data, out_file, plot_type="complete_rate", show_image=False, fill_between=0.3, max_dialogues=0, y_label=''):
+def plot(data, out_file, plot_type="complete_rate", show_image=False, fill_between=0.3, max_dialogues=0, y_label='',
+         fontsize=16, figsize=(12, 8), facecolor='#E6E6E6'):
 
     legends = [alg for alg in data]
     clrs = sns.color_palette("husl", len(legends))
-    plt.figure(plot_type)
-
-    with sns.axes_style("darkgrid"):
-        for i, alg in enumerate(legends):
-
-            max_step = min([len(d[plot_type]) for d in data[alg]])
-            if max_dialogues > 0:
-                max_length = min([len([s for s in d['steps'] if s <= max_dialogues]) for d in data[alg]])
-                max_step = min([max_length, max_step])
-            print("max_step: ", max_step)
-
-            value = np.array([d[plot_type][:max_step] for d in data[alg]])
-            step = np.array([d['steps'][:max_step] for d in data[alg]][0])
-            mean, err = np.mean(value, axis=0), np.std(value, axis=0)
-            plt.plot(
-                step, mean, c=clrs[i], label=alg)
-
-            plt.fill_between(
-                step, mean - err,
-                mean + err, alpha=fill_between, facecolor=clrs[i])
-        # locs, labels = plt.xticks()
-        # plt.xticks(locs, labels)
-        #plt.yticks(np.arange(10) / 10)
-        #plt.yticks([0.5, 0.6, 0.7])
-        plt.xlabel('Training dialogues')
-        if len(y_label) > 0:
-            plt.ylabel(y_label)
-        else:
-            plt.ylabel(plot_type)
-        plt.legend(fancybox=True, shadow=False, ncol=1, loc='lower left')
-        plt.savefig(out_file, bbox_inches='tight')
-
-        if show_image:
-            plt.show()
+    plt.figure(plot_type, figsize=figsize)
+    plt.gca().patch.set_facecolor(facecolor)
+    plt.grid(color='w', linestyle='solid', alpha=0.5)
+
+    largest_max = -sys.maxsize
+    smallest_min = sys.maxsize
+    for i, alg in enumerate(legends):
+
+        max_step = min([len(d[plot_type]) for d in data[alg]])
+        if max_dialogues > 0:
+            max_length = min([len([s for s in d['steps'] if s <= max_dialogues]) for d in data[alg]])
+            max_step = min([max_length, max_step])
+
+        value = np.array([d[plot_type][:max_step] for d in data[alg]])
+        step = np.array([d['steps'][:max_step] for d in data[alg]][0])
+        seeds_used = value.shape[0]
+        mean, err = np.mean(value, axis=0), np.std(value, axis=0)
+        err = err / np.sqrt(seeds_used)
+        plt.plot(
+            step, mean, c=clrs[i], label=alg)
+        plt.fill_between(
+            step, mean - err,
+            mean + err, alpha=fill_between, facecolor=clrs[i])
+        largest_max = mean.max() if mean.max() > largest_max else largest_max
+        smallest_min = mean.min() if mean.min() < smallest_min else smallest_min
+
+    plt.xlabel('Training Dialogues', fontsize=fontsize)
+    #plt.gca().yaxis.set_major_locator(plt.MultipleLocator(round((largest_max - smallest_min) / 10.0, 2)))
+    if len(y_label) > 0:
+        plt.ylabel(y_label.title(), fontsize=fontsize)
+    else:
+        plt.ylabel(plot_type.title(), fontsize=fontsize)
+    plt.xticks(fontsize=fontsize-4)
+    plt.yticks(fontsize=fontsize-4)
+    plt.legend(fancybox=True, shadow=False, ncol=1, loc='best', fontsize=fontsize)
+    plt.savefig(out_file + ".pdf", bbox_inches='tight', dpi=400, pad_inches=0)
+
+    if show_image:
+        plt.show()
 
 
 if __name__ == "__main__":
     args = get_args()
 
-    y_label_dict = {"complete_rate": 'Complete rate', "success_rate": 'Success rate', 'turns': 'Average turns',
-                    'avg_return': 'Average Return'}
+    y_label_dict = {"complete_rate": 'Complete Rate', "success_rate": 'Success Rate', 'turns': 'Average Turns',
+                    'avg_return': 'Average Return', "success_rate_strict": 'Strict Success Rate',
+                    "avg_actions": "Average Actions"}
 
-    for plot_type in ["complete_rate", "success_rate", 'turns', 'avg_return']:
+    for plot_type in ["complete_rate", "success_rate", "success_rate_strict", 'turns', 'avg_return', 'avg_actions']:
         file_name, file_extension = os.path.splitext(args.out_file)
         os.makedirs(file_name, exist_ok=True)
         fig_name = f"{file_name}_{plot_type}{file_extension}"
@@ -114,4 +130,11 @@ if __name__ == "__main__":
              plot_type=plot_type,
              fill_between=args.fill_between,
              max_dialogues=args.max_dialogues,
-             y_label=y_label_dict[plot_type])
+             y_label=y_label_dict[plot_type],
+             fontsize=args.fontsize,
+             figsize=args.figure_size,
+             facecolor=args.figure_face_color)
+
+    plot_distributions(args.dir, json.load(open(args.map_file)), args.out_file, fontsize=args.fontsize, font=args.font,
+                       figsize=args.figure_size, facecolor=args.figure_face_color)
+
diff --git a/convlab/policy/plot_results/plot_action_distributions.py b/convlab/policy/plot_results/plot_action_distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a8f923a594718f011f3ba742d7b964a0da11029
--- /dev/null
+++ b/convlab/policy/plot_results/plot_action_distributions.py
@@ -0,0 +1,158 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import seaborn as sns
+import pandas as pd
+
+
+def extract_action_distributions_across_seeds(algorithm_dir_path):
+    '''
+    We extract the information directly from the train_INFO.log file. An evaluation step has either of the two forms:
+
+    Evaluating at start - 2022-11-03-08-53-38------------------------------------------------------------
+    Complete: 0.636+-0.02, Success: 0.51+-0.02, Success strict: 0.432+-0.02, Average......
+    **OR**
+    Evaluating after Dialogues: 1000 - 2022-11-03-09-18-42------------------------------------------------------------
+    Complete: 0.786+-0.02, Success: 0.686+-0.02, Success strict: 0.634+-0.02, Average Return: 24.42.......
+    '''
+
+    seed_dir_paths = [f.path for f in os.scandir(
+        algorithm_dir_path) if f.is_dir()]
+    seed_dir_names = [f.name for f in os.scandir(
+        algorithm_dir_path) if f.is_dir()]
+
+    # Returned dict has the form {0: {book: [], inform: [], ..}, 1000: {book: [], inform: [], ..}, ...}
+    # where 0 and 1000 are evaluation steps and each list gets one value appended per seed log that reports the action
+    distribution_per_step_dict = {}
+
+    for seed_dir_name, seed_dir_path in zip(seed_dir_names, seed_dir_paths):  # NOTE(review): seed_dir_name is never read
+
+        with open(os.path.join(seed_dir_path, 'logs', 'train_INFO.log'), "r") as f:
+            evaluation_found = False
+            num_dialogues = 0  # step key used until an "Evaluating after Dialogues: N" header is seen
+            for line in f:
+                line = line.strip()
+                if "evaluating at" in line.lower() or "evaluating after" in line.lower():
+                    evaluation_found = True
+                    if not "at start" in line.lower():
+                        num_dialogues = int(line.split(" ")[3])  # 4th token is N in "Evaluating after Dialogues: N - ..."
+                    continue
+                if evaluation_found:  # the line directly after a header carries the metrics
+                    # keeps the comma-separated entries such as "book actions: 0.3", "inform actions: 0.4" (not "average actions")
+                    action_distribution_string = [a for a in line.split(", ")
+                                                  if "actions" in a.lower() and "average actions" not in a.lower()]
+
+                    if num_dialogues in distribution_per_step_dict:
+                        for action_string in action_distribution_string:
+                            action = action_string.lower().split(" ")[0]  # first word, e.g. "book"
+                            distribution = float(
+                                action_string.lower().split(": ")[-1])
+                            if action in distribution_per_step_dict[num_dialogues]:
+                                distribution_per_step_dict[num_dialogues][action].append(
+                                    distribution)
+                            else:
+                                distribution_per_step_dict[num_dialogues][action] = [
+                                    distribution]
+                    else:
+                        distribution_per_step_dict[num_dialogues] = {}  # first seed to report this step
+                        for action_string in action_distribution_string:
+                            action = action_string.lower().split(" ")[0]
+                            distribution = float(
+                                action_string.lower().split(": ")[-1])
+                            distribution_per_step_dict[num_dialogues][action] = [
+                                distribution]
+
+                    evaluation_found = False  # only the first line after each header is parsed
+
+    return distribution_per_step_dict
+
+
+def plot_distributions(dir_path, alg_maps, output_dir, fill_between=0.3, fontsize=16, font="Times New Roman",
+                       figsize=(12, 8), facecolor='#E6E6E6'):
+    plt.rcParams["font.family"] = font  # changes the global matplotlib rcParams, not just these figures
+    clrs = sns.color_palette("husl", len(alg_maps))  # one colour per algorithm
+
+    alg_paths = [os.path.join(dir_path, alg_map['dir'])
+                 for alg_map in alg_maps]
+    action_distributions = [
+        extract_action_distributions_across_seeds(path) for path in alg_paths]
+    possible_actions = action_distributions[0][0].keys()  # NOTE(review): KeyError if the first algorithm has no step-0 entry
+
+    create_bar_plots(action_distributions, alg_maps,
+                     possible_actions, output_dir,
+                     fontsize, figsize, facecolor)
+
+    for action in possible_actions:  # one line plot (and one PDF) per action
+        plt.clf()
+        plt.figure(figsize=figsize)  # NOTE(review): a new figure per action; none is closed in this function
+        plt.gca().patch.set_facecolor(facecolor)
+        plt.grid(color='w', linestyle='solid', alpha=0.5)
+
+        largest_max = 0
+        smallest_min = 1
+        for i, alg_distribution in enumerate(action_distributions):
+            steps = alg_distribution.keys()
+            try:
+                distributions = np.array(
+                    [alg_distribution[step][action] for step in steps])
+                # expected shape: (num_evaluations, num_seeds); the reductions below run over the seed axis
+                mean, std_dev = np.mean(distributions, axis=1), np.std(
+                    distributions, axis=1)
+                seeds_used = distributions.shape[1]
+                std_error = std_dev / np.sqrt(seeds_used)
+
+                # with sns.axes_style("darkgrid"):
+                plt.plot(steps, mean, c=clrs[i],
+                         label=f"{alg_maps[i]['legend']}")
+                plt.fill_between(
+                    steps, mean - std_error,
+                    mean + std_error, alpha=fill_between, facecolor=clrs[i])
+
+                largest_max = mean.max() if mean.max() > largest_max else largest_max
+                smallest_min = mean.min() if mean.min() < smallest_min else smallest_min
+
+            except Exception as e:
+                # catch if an algorithm does not have a specific action
+                print(e)
+        print(action)
+        if round((largest_max - smallest_min) / 10.0, 2) > 0:  # y tick spacing = a tenth of the mean range; skipped if it rounds to 0
+            plt.gca().yaxis.set_major_locator(plt.MultipleLocator(
+                round((largest_max - smallest_min) / 10.0, 2)))
+        plt.xticks(fontsize=fontsize-4, rotation=0)
+        plt.yticks(fontsize=fontsize-4)
+        plt.xlabel('Training Dialogues', fontsize=fontsize)
+        plt.ylabel(f"{action.title()} Intent Probability", fontsize=fontsize)
+        plt.legend(fancybox=True, shadow=False, ncol=1, loc='best')
+        plt.savefig(
+            output_dir + f'/{action}_probability.pdf', bbox_inches='tight',
+            dpi=400, pad_inches=0)
+
+
+def create_bar_plots(action_distributions, alg_maps, possible_actions, output_dir, fontsize, figsize, facecolor):
+
+    max_step = max(action_distributions[0].keys())  # last evaluation step of the first algorithm
+    final_distributions = [distribution[max_step]  # NOTE(review): assumes every algorithm logged max_step
+                           for distribution in action_distributions]
+
+    df_list = []  # one row per action: [title-cased action, mean over seeds for each algorithm...]
+    for action in possible_actions:
+        action_list = [action.title()]
+        for distribution in final_distributions:
+            action_list.append(np.mean(distribution[action]))
+        df_list.append(action_list)
+
+    df = pd.DataFrame(df_list, columns=[  # the 'Probabilities' column holds the action names used on the x axis
+                      'Probabilities'] + [alg_map["legend"] for alg_map in alg_maps])
+    plt.figure(figsize = figsize)  # NOTE(review): df.plot below also receives figsize; this figure may be redundant - confirm
+    plt.rcParams.update({'font.size': fontsize})
+    fig = df.plot(x='Probabilities', kind='bar', stacked=False,
+                  rot=0, grid=True, color=sns.color_palette("husl", len(alg_maps)),
+                  fontsize=fontsize, figsize=figsize).get_figure()
+    plt.gca().patch.set_facecolor(facecolor)
+    plt.grid(color='w', linestyle='solid', alpha=0.5)
+    plt.yticks(np.arange(0, 1, 0.1), fontsize=fontsize-4)  # ticks at 0.0 ... 0.9 (arange excludes the stop value)
+    plt.xticks(fontsize=fontsize-4)
+    plt.xlabel('Intents', fontsize=fontsize)
+    plt.ylabel('Probability', fontsize=fontsize)
+    fig.savefig(os.path.join(output_dir, "final_action_probabilities.pdf"),
+                dpi=400, bbox_inches='tight', pad_inches=0)
diff --git a/convlab/policy/ppo/README.md b/convlab/policy/ppo/README.md
index 6bf87252ffbba511d32e275a079fc7cba2e954a9..c762253ce4fe769bb2c540b39d39df713881a7f3 100755
--- a/convlab/policy/ppo/README.md
+++ b/convlab/policy/ppo/README.md
@@ -1,34 +1,55 @@
-# PPO
+# Proximal Policy Optimization (PPO)
 
-A policy optimization method in policy based reinforcement learning that uses
-multiple epochs of stochastic gradient ascent and a constant
-clipping mechanism as the soft constraint to perform each policy update. We adapt PPO to the dialog policy.
+Proximal Policy Optimization (Schulman et al., 2017) is an on-policy reinforcement learning algorithm. The architecture used is a simple MLP and is thus not transferable to new ontologies.
 
-## Train
+## Supervised pre-training
 
-Run `train.py` in the `ppo` directory:
+If you want to obtain a supervised model for pre-training, please have a look in the MLE policy folder.
 
-```bash
-python train.py
+## RL training
+
+Starting an RL training run is as easy as executing
+
+```sh
+$ python train.py --path=your_environment_config --seed=SEED
 ```
 
-For better performance, we can do immitating learning before reinforcement learning. The immitating learning is implemented in the `mle` directory.
+One example for the environment-config is **semantic_level_config.json**, where parameters for the training are specified, for instance
 
-For example, if the trained model of immitating learning is saved at FOLDER_OF_MODEL/best_mle.pol.mdl, then you can run
+- load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl
+- process_num: the number of processes to use during evaluation to speed it up
+- num_eval_dialogues: how many evaluation dialogues should be used
+- epoch: how many training epochs to run. One epoch consists of collecting dialogues + performing an update
+- eval_frequency: the number of training epochs between two evaluations
+- batchsz: the number of training dialogues collected before doing an update
 
-```bash
-python train.py --load_path FOLDER_OF_MODEL/best_mle
-```
+Moreover, you can specify the full dialogue pipeline here, such as the user policy, NLU for system and user, etc.
+
+Parameters that are tied to the RL algorithm and the model architecture can be changed in config.json.
+
+
+## Evaluation
 
-Note that the *.pol.mdl* suffix should not appear in the --load_path argument.
+For creating evaluation plots and running evaluation dialogues, please have a look in the README of the policy folder.
 
-## Reference
+## References
 
 ```
-@article{schulman2017proximal,
-  title={Proximal policy optimization algorithms},
-  author={Schulman, John and Wolski, Filip and Dhariwal, Prafulla and Radford, Alec and Klimov, Oleg},
-  journal={arXiv preprint arXiv:1707.06347},
-  year={2017}
+@article{DBLP:journals/corr/SchulmanWDRK17,
+  author    = {John Schulman and
+               Filip Wolski and
+               Prafulla Dhariwal and
+               Alec Radford and
+               Oleg Klimov},
+  title     = {Proximal Policy Optimization Algorithms},
+  journal   = {CoRR},
+  volume    = {abs/1707.06347},
+  year      = {2017},
+  url       = {http://arxiv.org/abs/1707.06347},
+  eprinttype = {arXiv},
+  eprint    = {1707.06347},
+  timestamp = {Mon, 13 Aug 2018 16:47:34 +0200},
+  biburl    = {https://dblp.org/rec/journals/corr/SchulmanWDRK17.bib},
+  bibsource = {dblp computer science bibliography, https://dblp.org}
 }
 ```
\ No newline at end of file
diff --git a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
new file mode 100644
index 0000000000000000000000000000000000000000..0e8774e20898c2855f368127f8e14e193ac2c21d
--- /dev/null
+++ b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
@@ -0,0 +1,44 @@
+{
+	"model": {
+		"load_path": "convlab/policy/ppo/pretrained_models/mle",
+		"pretrained_load_path": "",
+		"use_pretrained_initialisation": false,
+		"batchsz": 500,
+		"seed": 0,
+		"epoch": 50,
+		"eval_frequency": 5,
+		"process_num": 1,
+		"num_eval_dialogues": 500,
+		"sys_semantic_to_usr": false
+	},
+	"vectorizer_sys": {
+		"uncertainty_vector_mul": {
+			"class_path": "convlab.policy.vector.vector_binary.VectorBinary",
+			"ini_params": {
+				"use_masking": true,
+				"manually_add_entity_names": true,
+				"seed": 0
+			}
+		}
+	},
+	"nlu_sys": {},
+	"dst_sys": {
+		"RuleDST": {
+			"class_path": "convlab.dst.rule.multiwoz.dst.RuleDST",
+			"ini_params": {}
+		}
+	},
+	"sys_nlg": {},
+	"nlu_usr": {},
+	"dst_usr": {},
+	"policy_usr": {
+		"RulePolicy": {
+			"class_path": "convlab.policy.genTUS.stepGenTUS.UserPolicy",
+			"ini_params": {
+				"model_checkpoint": "convlab/policy/genTUS/unify/experiments/multiwoz21_0_1.0",
+				"character": "usr"
+			}
+		}
+	},
+	"usr_nlg": {}
+}
diff --git a/convlab/policy/ppo/semantic_level_config.json b/convlab/policy/ppo/semantic_level_config.json
index 381acac72bd2c28e870e091e7c36d499c1bbd76c..b9908c9cb7717515775221227f3fba19636d20dc 100644
--- a/convlab/policy/ppo/semantic_level_config.json
+++ b/convlab/policy/ppo/semantic_level_config.json
@@ -1,12 +1,12 @@
 {
 	"model": {
-		"load_path": "convlab/policy/mle/experiments/experiment_2022-05-23-14-08-43/save/supervised",
+		"load_path": "",
 		"use_pretrained_initialisation": false,
 		"pretrained_load_path": "",
 		"batchsz": 1000,
 		"seed": 0,
-		"epoch": 50,
-		"eval_frequency": 5,
+		"epoch": 10,
+		"eval_frequency": 1,
 		"process_num": 4,
 		"sys_semantic_to_usr": false,
 		"num_eval_dialogues": 500
@@ -40,4 +40,4 @@
 		}
 	},
 	"usr_nlg": {}
-}
\ No newline at end of file
+}
diff --git a/convlab/policy/ppo/setsumbt_end_baseline_config.json b/convlab/policy/ppo/setsumbt_config.json
similarity index 53%
rename from convlab/policy/ppo/setsumbt_end_baseline_config.json
rename to convlab/policy/ppo/setsumbt_config.json
index ea84dd7670a3a27dcc267e476c894bf939c9ba10..5a13ee82fcbf24c0b13112106d2b97f115966e1a 100644
--- a/convlab/policy/ppo/setsumbt_end_baseline_config.json
+++ b/convlab/policy/ppo/setsumbt_config.json
@@ -1,22 +1,22 @@
 {
 	"model": {
-		"load_path": "supervised",
+		"load_path": "/gpfs/project/niekerk/src/ConvLab3/convlab/policy/mle/experiments/experiment_2022-11-13-12-56-34/save/supervised",
 		"pretrained_load_path": "",
 		"use_pretrained_initialisation": false,
 		"batchsz": 1000,
 		"seed": 0,
 		"epoch": 50,
 		"eval_frequency": 5,
-		"process_num": 4,
+		"process_num": 2,
 		"num_eval_dialogues": 500,
-		"sys_semantic_to_usr": false
+		"sys_semantic_to_usr": true
 	},
 	"vectorizer_sys": {
 		"uncertainty_vector_mul": {
-			"class_path": "convlab.policy.vector.vector_multiwoz_uncertainty.MultiWozVector",
+			"class_path": "convlab.policy.vector.vector_binary.VectorBinary",
 			"ini_params": {
 				"use_masking": false,
-				"manually_add_entity_names": false,
+				"manually_add_entity_names": true,
 				"seed": 0
 			}
 		}
@@ -24,12 +24,9 @@
 	"nlu_sys": {},
 	"dst_sys": {
 		"setsumbt-mul": {
-			"class_path": "convlab.dst.setsumbt.multiwoz.Tracker.SetSUMBTTracker",
+			"class_path": "convlab.dst.setsumbt.SetSUMBTTracker",
 			"ini_params": {
-				"model_path": "https://zenodo.org/record/5497808/files/setsumbt_end.zip",
-				"get_confidence_scores": true,
-				"return_mutual_info": false,
-				"return_entropy": true
+				"model_path": "/gpfs/project/niekerk/models/setsumbt_models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-distribution_distillation-Seed0-30-08-22-15-00"
 			}
 		}
 	},
@@ -41,16 +38,7 @@
 			}
 		}
 	},
-	"nlu_usr": {
-		"BERTNLU": {
-			"class_path": "convlab.nlu.jointBERT.multiwoz.BERTNLU",
-			"ini_params": {
-				"mode": "sys",
-				"config_file": "multiwoz_sys_context.json",
-				"model_file": "https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip"
-			}
-		}
-	},
+	"nlu_usr": {},
 	"dst_usr": {},
 	"policy_usr": {
 		"RulePolicy": {
@@ -65,7 +53,7 @@
 			"class_path": "convlab.nlg.template.multiwoz.TemplateNLG",
 			"ini_params": {
 				"is_user": true,
-				"label_noise": 0.0,
+				"label_noise": 0.05,
 				"text_noise": 0.0
 			}
 		}
diff --git a/convlab/policy/ppo/setsumbt_unc_config.json b/convlab/policy/ppo/setsumbt_unc_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..6b7d115aafa53a2bfc5c58672e67086f04d5884d
--- /dev/null
+++ b/convlab/policy/ppo/setsumbt_unc_config.json
@@ -0,0 +1,65 @@
+{
+	"model": {
+		"load_path": "/gpfs/project/niekerk/src/ConvLab3/convlab/policy/mle/experiments/experiment_2022-11-10-10-37-30/save/supervised",
+		"pretrained_load_path": "",
+		"use_pretrained_initialisation": false,
+		"batchsz": 1000,
+		"seed": 0,
+		"epoch": 50,
+		"eval_frequency": 5,
+		"process_num": 2,
+		"num_eval_dialogues": 500,
+		"sys_semantic_to_usr": true
+	},
+	"vectorizer_sys": {
+		"uncertainty_vector_mul": {
+			"class_path": "convlab.policy.vector.vector_uncertainty.VectorUncertainty",
+			"ini_params": {
+				"use_masking": false,
+				"manually_add_entity_names": true,
+				"seed": 0,
+				"use_confidence_scores": true,
+				"use_state_knowledge_uncertainty": true
+			}
+		}
+	},
+	"nlu_sys": {},
+	"dst_sys": {
+		"setsumbt-mul": {
+			"class_path": "convlab.dst.setsumbt.SetSUMBTTracker",
+			"ini_params": {
+				"model_path": "/gpfs/project/niekerk/models/setsumbt_models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-distribution_distillation-Seed0-30-08-22-15-00",
+				"return_confidence_scores": true,
+				"return_belief_state_mutual_info": true
+			}
+		}
+	},
+	"sys_nlg": {
+		"TemplateNLG": {
+			"class_path": "convlab.nlg.template.multiwoz.TemplateNLG",
+			"ini_params": {
+				"is_user": false
+			}
+		}
+	},
+	"nlu_usr": {},
+	"dst_usr": {},
+	"policy_usr": {
+		"RulePolicy": {
+			"class_path": "convlab.policy.rule.multiwoz.RulePolicy",
+			"ini_params": {
+				"character": "usr"
+			}
+		}
+	},
+	"usr_nlg": {
+		"TemplateNLG": {
+			"class_path": "convlab.nlg.template.multiwoz.TemplateNLG",
+			"ini_params": {
+				"is_user": true,
+				"label_noise": 0.05,
+				"text_noise": 0.0
+			}
+		}
+	}
+}
\ No newline at end of file
diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py
index 899e91525ba67de7015b1b7fafeab516e4b02f2b..703a55005b8c07578b85765a626d9871deebf26e 100755
--- a/convlab/policy/ppo/train.py
+++ b/convlab/policy/ppo/train.py
@@ -184,7 +184,7 @@ if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--path", type=str, default='convlab/policy/ppo/semantic_level_config.json',
                         help="Load path for config file")
-    parser.add_argument("--seed", type=int, default=0,
+    parser.add_argument("--seed", type=int, default=None,
                         help="Seed for the policy parameter initialization")
     parser.add_argument("--mode", type=str, default='info',
                         help="Set level for logger")
@@ -199,7 +199,7 @@ if __name__ == '__main__':
     logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
         init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
 
-    args = [('model', 'seed', seed)]
+    args = [('model', 'seed', seed)] if seed is not None else list()
 
     environment_config = load_config_file(path)
     save_config(vars(parser.parse_args()), environment_config, config_save_path)
@@ -228,13 +228,6 @@ if __name__ == '__main__':
 
     env, sess = env_config(conf, policy_sys)
 
-    # Setup uncertainty thresholding
-    if env.sys_dst:
-        try:
-            if env.sys_dst.use_confidence_scores:
-                policy_sys.vector.setup_uncertain_query(env.sys_dst.thresholds)
-        except:
-            logging.info('Uncertainty threshold not set.')
 
     policy_sys.current_time = current_time
     policy_sys.log_dir = config_save_path.replace('configs', 'logs')
@@ -261,7 +254,7 @@ if __name__ == '__main__':
 
         if idx % conf['model']['eval_frequency'] == 0 and idx != 0:
             time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
-            logging.info(f"Evaluating at Epoch: {idx} - {time_now}" + '-'*60)
+            logging.info(f"Evaluating after Dialogues: {idx * conf['model']['batchsz']} - {time_now}" + '-' * 60)
 
             eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path)
 
diff --git a/convlab/policy/ppo/tus_semantic_level_config.json b/convlab/policy/ppo/tus_semantic_level_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..9d56646cc857de85b47a1f9925a0e4bf89d8b524
--- /dev/null
+++ b/convlab/policy/ppo/tus_semantic_level_config.json
@@ -0,0 +1,43 @@
+{
+	"model": {
+		"load_path": "convlab/policy/ppo/pretrained_models/mle",
+		"use_pretrained_initialisation": false,
+		"pretrained_load_path": "",
+		"batchsz": 1000,
+		"seed": 0,
+		"epoch": 50,
+		"eval_frequency": 5,
+		"process_num": 4,
+		"sys_semantic_to_usr": false,
+		"num_eval_dialogues": 500
+	},
+	"vectorizer_sys": {
+		"uncertainty_vector_mul": {
+			"class_path": "convlab.policy.vector.vector_binary.VectorBinary",
+			"ini_params": {
+				"use_masking": true,
+				"manually_add_entity_names": false,
+				"seed": 0
+			}
+		}
+	},
+	"nlu_sys": {},
+	"dst_sys": {
+		"RuleDST": {
+			"class_path": "convlab.dst.rule.multiwoz.dst.RuleDST",
+			"ini_params": {}
+		}
+	},
+	"sys_nlg": {},
+	"nlu_usr": {},
+	"dst_usr": {},
+	"policy_usr": {
+		"TUSPolicy": {
+			"class_path": "convlab.policy.tus.unify.TUS.UserPolicy",
+			"ini_params": {
+				"config": "convlab/policy/tus/unify/exp/multiwoz.json"
+			}
+		}
+	},
+	"usr_nlg": {}
+}
diff --git a/convlab/policy/rlmodule.py b/convlab/policy/rlmodule.py
index db46026656d908b2453a9143b3f482ce7378e382..18844ca375967cb9524dacd3a81a4f85a2cf9041 100755
--- a/convlab/policy/rlmodule.py
+++ b/convlab/policy/rlmodule.py
@@ -287,7 +287,7 @@ class Value(nn.Module):
 Transition_evaluator = namedtuple('Transition_evaluator',
                                   ('complete', 'success', 'success_strict', 'total_return_complete', 'total_return_success', 'turns',
                                    'avg_actions', 'task_success', 'book_actions', 'inform_actions', 'request_actions', 'select_actions',
-                                   'offer_actions'))
+                                   'offer_actions', 'recommend_actions'))
 
 
 class Memory_evaluator(object):
diff --git a/convlab/policy/tus/README.md b/convlab/policy/tus/README.md
index 2976e17bcbd5662ff4802eef9efb2219172f95af..9424edcf3795220a582809e1455c942f300c7740 100644
--- a/convlab/policy/tus/README.md
+++ b/convlab/policy/tus/README.md
@@ -1,30 +1,48 @@
-**TUS** is a domain-independent user simulator with transformers for task-oriented dialogue systems. It is based on the [ConvLab-2](https://github.com/thu-coai/ConvLab-2) framework. Therefore, you should follow their instruction to install the package.
+**TUS** is a domain-independent user simulator with transformers for task-oriented dialogue systems.
 
 ## Introduction
-Our model is a domain-independent user simulator, which means it is not based on any domain-dependent freatures and the output representation is also domain-independent. Therefore, it can easily adapt to a new domain, without additional feature engineering and model retraining.
+Our model is a domain-independent user simulator, which means its input and output representations are domain agnostic. Therefore, it can easily adapt to a new domain, without additional feature engineering and model retraining.
 
-The code of TUS is in `convlab/policy/tus` and a rule-based DST of user is also created in `convlab/dst/rule/multiwoz/dst.py` based on the rule-based DST in `convlab/dst/rule/multiwoz/dst.py`.
+The code of TUS is in `convlab/policy/tus`.
 
-## How to run the model
-### Train the user simulator
-`python3 convlab/policy/tus/multiwoz/train.py --user_config convlab/policy/tus/multiwoz/exp/default.json`
+## Usage
+### Train TUS from scratch
 
-One default configuration is placed in `convlab/policy/tus/multiwoz/exp/default.json`. They can be modified based on your requirements. For example, the output directory can be specified in the configuration (`model_dir`).
+```
+python3 convlab/policy/tus/unify/train.py --dataset $dataset --dial-ids-order $dial_ids_order --split2ratio $split2ratio --user-config $config
+```
+
+`dataset` can be `multiwoz21`, `sgd`, `tm`, `sgd+tm`, or `all`.
+`dial_ids_order` can be 0, 1 or 2.
+`split2ratio` can be 0.01, 0.1 or 1.
+Default configurations are placed in `convlab/policy/tus/unify/exp`. They can be modified based on your requirements. 
+
+For example, you can train TUS for multiwoz21 by 
+`python3 convlab/policy/tus/unify/train.py --dataset multiwoz21 --dial-ids-order 0 --split2ratio 1 --user-config "convlab/policy/tus/unify/exp/multiwoz.json"`
+
+### Evaluate TUS
 
 ### Train a dialogue policy with TUS
 You can use it as a normal user simulator by `PipelineAgent`. For example,
 ```python
 import json
 from convlab.dialog_agent.agent import PipelineAgent
-from convlab.dst.rule.multiwoz.usr_dst import UserRuleDST
-from convlab.policy.tus.multiwoz.TUS import UserPolicy
+from convlab.policy.tus.unify.TUS import UserPolicy
 
-user_config_file = "convlab/policy/tus/multiwoz/exp/default.json"
-dst_usr = UserRuleDST()
+user_config_file = "convlab/policy/tus/unify/exp/multiwoz.json"
 user_config = json.load(open(user_config_file))
 policy_usr = UserPolicy(user_config)
-simulator = PipelineAgent(None, dst_usr, policy_usr, None, 'user')
+simulator = PipelineAgent(None, None, policy_usr, None, 'user')
+```
+then you can train your system with this simulator.
+
+An example config that trains a PPO policy with TUS at the semantic level is provided in `convlab/policy/ppo/tus_semantic_level_config.json`.
+You can train a PPO policy as follows:
+```
+config="convlab/policy/ppo/tus_semantic_level_config.json"
+python3 convlab/policy/ppo/train.py --path $config
 ```
+Note: you should either name your pretrained policy `convlab/policy/ppo/pretrained_models/mle` or modify the `load_path` of `model` in the config `convlab/policy/ppo/tus_semantic_level_config.json`.
 
 
 <!---citation--->
diff --git a/convlab/policy/tus/multiwoz/Goal.py b/convlab/policy/tus/multiwoz/Goal.py
index e337a1bd3b5668589f130ae2e2ec5a1615fdd985..c26e942723fd9e76901572dd2a85fd7433c4b275 100644
--- a/convlab/policy/tus/multiwoz/Goal.py
+++ b/convlab/policy/tus/multiwoz/Goal.py
@@ -44,10 +44,10 @@ class Goal(object):
         elif goal_generator is None and goal is not None:
             self.domains = []
             self.domain_goals = {}
-            for domain in goal:
-                if domain in SysDa2Goal and goal[domain]:  # TODO check order
+            for domain in goal.domains:
+                if domain in SysDa2Goal and goal.domain_goals[domain]:  # TODO check order
                     self.domains.append(domain)
-                    self.domain_goals[domain] = goal[domain]
+                    self.domain_goals[domain] = goal.domain_goals[domain]
         else:
             print("Warning!!! One of goal_generator or goal should not be None!!!")
 
diff --git a/convlab/policy/tus/multiwoz/TUS.py b/convlab/policy/tus/multiwoz/TUS.py
index 725098d9162d6900746aa7398c14a8aafe49d786..f8f70ab3e918073079dec850dac9ad78ff79f36a 100644
--- a/convlab/policy/tus/multiwoz/TUS.py
+++ b/convlab/policy/tus/multiwoz/TUS.py
@@ -18,7 +18,8 @@ from convlab.util.custom_util import model_downloader
 from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple
 from convlab.util import relative_import_module_from_unified_datasets
 
-reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value'])
+reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets(
+    'multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value'])
 
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diff --git a/convlab/policy/tus/unify/Goal.py b/convlab/policy/tus/unify/Goal.py
new file mode 100644
index 0000000000000000000000000000000000000000..610469bb4246cf6c16ef4271c89827cab6421437
--- /dev/null
+++ b/convlab/policy/tus/unify/Goal.py
@@ -0,0 +1,302 @@
+import time
+import json
+from convlab.policy.tus.unify.util import split_slot_name, slot_name_map
+from convlab.util.custom_util import slot_mapping
+
+from random import sample, shuffle
+from pprint import pprint
+DEF_VAL_UNK = '?'  # Unknown
+DEF_VAL_DNC = 'dontcare'  # Do not care
+DEF_VAL_NUL = 'none'  # for none
+DEF_VAL_BOOKED = 'yes'  # for booked
+DEF_VAL_NOBOOK = 'no'  # for booked
+NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK, ""]
+
+
+# only support user goal from dataset
+
+
+def is_time(goal, status):
+    # the pair counts as "time" only when both values parse as %H:%M
+    both_parse = isTimeFormat(goal) and isTimeFormat(status)
+    return both_parse
+
+
+def isTimeFormat(input):  # True iff `input` parses with the '%H:%M' format
+    try:
+        time.strptime(input, '%H:%M')
+        return True
+    except ValueError:  # raised by strptime when the text does not match
+        return False
+
+
+def old_goal2list(goal: dict, reorder=False) -> list:
+    """Flatten a nested goal dict into [domain, intent, slot, value] entries.
+
+    'info' and 'book' slots become "inform" entries carrying the goal value;
+    'reqt' slots become "request" entries valued DEF_VAL_UNK. Entries are
+    shuffled within each (domain, slot type) group. `reorder` is unused.
+    """
+    flat_goal = []
+    for domain, domain_goal in goal.items():
+        for slot_type in ('info', 'book', 'reqt'):
+            if slot_type not in domain_goal:
+                continue
+            is_inform = slot_type != 'reqt'
+            group = []
+            for slot in domain_goal[slot_type]:
+                # the flat name map is consulted before the per-domain one
+                name = slot
+                if slot in slot_name_map:
+                    name = slot_name_map[slot]
+                elif slot in slot_name_map[domain]:
+                    name = slot_name_map[domain][slot]
+                value = domain_goal[slot_type][slot] if is_inform else DEF_VAL_UNK
+                group.append([domain, "inform" if is_inform else "request",
+                              slot_mapping.get(name, name), value])
+            shuffle(group)
+            flat_goal.extend(group)
+    return flat_goal
+
+
+class Goal(object):
+    """ User Goal Model Class. """
+
+    def __init__(self, goal: list = None):
+        """
+        Build a goal from a flat list of goal entries.
+        Args:
+            goal (list): [domain, intent, slot, value] entries; it is iterated
+                by init_goal_status, so the None default cannot be used as is
+        """
+        self.goal = goal
+        self.max_domain_len = 6
+        self.max_slot_len = 20
+        self.local_id = {}
+
+        self.domains = []
+        # goal: {domain: {"info": {slot: value}, "reqt": {slot:?}}, ...}
+        self.domain_goals = {}
+        # status: {domain: {slot: value}}
+        self.status = {}
+        self.user_history = {}
+        self.init_goal_status(goal)
+        self.init_local_id()
+
+    def __str__(self):
+        # the JSON dump of domain_goals, framed by two banner lines
+        banner = '-----Goal-----'
+        return banner + '\n' + json.dumps(self.domain_goals, indent=4) + '\n' + banner
+
+    def init_goal_status(self, goal):
+        """Populate domains, domain_goals and user_history from goal entries."""
+        for domain, intent, slot, value in goal:  # check this order
+            if domain not in self.domains:
+                self.domains.append(domain)
+                self.domain_goals[domain] = {}
+
+            # "book" domain is not clear for unify data format
+            if "request" in intent.lower():
+                if "reqt" not in self.domain_goals[domain]:
+                    self.domain_goals[domain]["reqt"] = {}
+                self.domain_goals[domain]["reqt"][slot] = DEF_VAL_UNK
+
+            elif "info" in intent.lower():  # substring test, so "inform" matches
+                if "info" not in self.domain_goals[domain]:
+                    self.domain_goals[domain]["info"] = {}
+                self.domain_goals[domain]["info"][slot] = value
+
+            self.user_history[f"{domain}-{slot}"] = value
+
+    def task_complete(self):
+        """
+        Check that every goal constraint and request has been satisfied.
+        Returns:
+            (boolean): True when the goal is accomplished.
+        """
+        for domain in self.domain_goals:
+            if domain not in self.status:
+                # the domain has never been mentioned in the dialogue
+                return False
+            if "info" in self.domain_goals[domain]:
+                for slot in self.domain_goals[domain]["info"]:
+                    if slot not in self.status[domain]:
+                        # the slot has never been mentioned
+                        return False
+                    goal = self.domain_goals[domain]["info"][slot].lower()
+                    status = self.status[domain][slot].lower()
+                    if goal != status and not is_time(goal, status):
+                        # mismatch; values that are both %H:%M times are exempt
+                        return False
+            if "reqt" in self.domain_goals[domain]:
+                for slot in self.domain_goals[domain]["reqt"]:
+                    if self.domain_goals[domain]["reqt"][slot] == DEF_VAL_UNK:
+                        # the request slot still holds the unknown marker
+                        return False
+
+        return True
+
+    def init_local_id(self):
+        """Build position-coded id vectors for every goal domain and slot."""
+        # local_id = {
+        #     "domain 1": {
+        #         "ID": [1, 0, 0],
+        #         "SLOT": {
+        #             "slot 1": [1, 0, 0],
+        #             "slot 2": [0, 1, 0]}}}
+        for domain_id, domain in enumerate(self.domains):
+            self._init_domain_id(domain)
+            self._update_domain_id(domain, domain_id)
+            slot_id = 0
+            for slot_type in ["info", "book", "reqt"]:
+                for slot in self.domain_goals[domain].get(slot_type, {}):
+                    self._init_slot_id(domain, slot)
+                    self._update_slot_id(domain, slot, slot_id)
+                    slot_id += 1
+
+    def insert_local_id(self, new_slot_name):
+        """Register ids for a slot that is missing from self.local_id."""
+        domain, slot = split_slot_name(new_slot_name)
+        if domain not in self.local_id:
+            self._init_domain_id(domain)
+            # NOTE(review): self.domains is not extended, so each new domain
+            # gets the same id, len(self.domains) + 1 -- confirm this is intended
+            domain_id = len(self.domains) + 1
+            self._update_domain_id(domain, domain_id)
+            self._init_slot_id(domain, slot)
+            self._update_slot_id(domain, slot, 0)  # first slot of a new domain
+        else:
+            slot_id = len(self.local_id[domain]["SLOT"]) + 1
+            self._init_slot_id(domain, slot)
+            self._update_slot_id(domain, slot, slot_id)
+
+    def get_slot_id(self, slot_name):
+        """Return (domain id, slot id) for slot_name, registering it if unseen."""
+        # both ids are the vectors kept in self.local_id
+        domain, slot = split_slot_name(slot_name)
+        if domain in self.local_id and slot in self.local_id[domain]["SLOT"]:
+            return self.local_id[domain]["ID"], self.local_id[domain]["SLOT"][slot]
+        else:  # a slot not in original user goal
+            self.insert_local_id(slot_name)
+            domain_id, slot_id = self.get_slot_id(slot_name)
+            return domain_id, slot_id
+
+    def action_list(self, sys_act=None):
+        """List goal slot names, with unseen slots from sys_act put in front."""
+        ordered = list(self.user_history)
+
+        for _, domain, slot, _ in sys_act or []:
+            key = f"{domain}-{slot}"
+            # each newcomer goes to index 0, so later ones end up first
+            if key not in ordered:
+                ordered.insert(0, key)
+        return ordered
+
+    def update(self, action: list = None, char: str = "system"):
+        """Fold [intent, domain, slot, value] acts into status and goal."""
+        if char not in ["user", "system"]:
+            print(f"unknown role: {char}")
+        self._update_status(action, char)
+        self._update_goal(action, char)
+        return self.status  # the live dict, not a copy
+
+    def _update_status(self, action: list, char: str):
+        """Record each act in self.status; `char` is not used."""
+        for intent, domain, slot, value in action:
+            if slot == "arrive by":
+                slot = "arriveBy"
+            elif slot == "leave at":
+                slot = "leaveAt"
+            if domain not in self.status:
+                self.status[domain] = {}
+            if "info" in intent:  # substring test, so "inform" matches
+                self.status[domain][slot] = value
+            elif "request" in intent:
+                self.status[domain][slot] = DEF_VAL_UNK
+
+    def _update_goal(self, action: list, char: str):
+        """Fill requested goal slots from informed values; `char` is not used."""
+        for intent, domain, slot, value in action:
+            if slot == "arrive by":
+                slot = "arriveBy"
+            elif slot == "leave at":
+                slot = "leaveAt"
+            if "info" not in intent:
+                continue
+            if self._check_update_request(domain, slot) and value != "?":
+                # the request slot now holds the informed value
+                self.domain_goals[domain]['reqt'][slot] = value
+
+    def _update_slot(self, domain, slot, value):
+        self.domain_goals[domain]['reqt'][slot] = value
+
+    def _check_update_request(self, domain, slot):
+        """Tell whether `slot` is one of the request slots of `domain`.
+
+        False when the domain is not in the goal or has no 'reqt' section.
+        """
+        domain_goal = self.domain_goals.get(domain)
+        if domain_goal is None or 'reqt' not in domain_goal:
+            return False
+        return slot in domain_goal['reqt']
+
+    def _check_value(self, value=None):
+        """Return True for a concrete value: non-empty and not in NOT_SURE_VALS."""
+        # an empty/None value is rejected before the NOT_SURE_VALS lookup
+        if not value:
+            return False
+        return value not in NOT_SURE_VALS
+
+    def _init_domain_id(self, domain):
+        self.local_id[domain] = {"ID": [0] * self.max_domain_len, "SLOT": {}}
+
+    def _init_slot_id(self, domain, slot):
+        self.local_id[domain]["SLOT"][slot] = [0] * self.max_slot_len
+
+    def _update_domain_id(self, domain, domain_id):
+        """Set position domain_id of the domain's id vector, if it fits."""
+        if domain_id < self.max_domain_len:
+            self.local_id[domain]["ID"][domain_id] = 1
+        else:
+            print(f"too many domains: {domain_id} > {self.max_domain_len}")
+
+    def _update_slot_id(self, domain, slot, slot_id):
+        if slot_id < self.max_slot_len:
+            self.local_id[domain]["SLOT"][slot][slot_id] = 1
+        else:
+            print(
+                f"too many slots, {slot_id} > {self.max_slot_len}")
+
+
+if __name__ == "__main__":
+    data_goal = [["restaurant", "inform", "cuisine", "punjabi"],
+                 ["restaurant", "inform", "city", "milpitas"],
+                 ["restaurant", "request", "price_range", ""],
+                 ["restaurant", "request", "street_address", ""]]
+    goal = Goal(data_goal)
+    print(goal)
+    action = {"char": "system",
+              "action": [["request", "restaurant", "cuisine", "?"], ["request", "restaurant", "city", "?"]]}
+    goal.update(action["action"], action["char"])
+    print(goal.status)
+    print("complete:", goal.task_complete())
+    action = {"char": "user",
+              "action": [["inform", "restaurant", "cuisine", "punjabi"], ["inform", "restaurant", "city", "milpitas"]]}
+    goal.update(action["action"], action["char"])
+    print(goal.status)
+    print("complete:", goal.task_complete())
+    action = {"char": "system",
+              "action": [["inform", "restaurant", "price_range", "cheap"]]}
+    goal.update(action["action"], action["char"])
+    print(goal.status)
+    print("complete:", goal.task_complete())
+    action = {"char": "user",
+              "action": [["request", "restaurant", "street_address", ""]]}
+    goal.update(action["action"], action["char"])
+    print(goal.status)
+    print("complete:", goal.task_complete())
+    action = {"char": "system",
+              "action": [["inform", "restaurant", "street_address", "ABCD"]]}
+    goal.update(action["action"], action["char"])
+    print(goal.status)
+    print("complete:", goal.task_complete())
diff --git a/convlab/policy/tus/unify/TUS.py b/convlab/policy/tus/unify/TUS.py
new file mode 100644
index 0000000000000000000000000000000000000000..be931eca16b03164fea094d55c08b0c85acdbd64
--- /dev/null
+++ b/convlab/policy/tus/unify/TUS.py
@@ -0,0 +1,456 @@
+import json
+import os
+import random
+from copy import deepcopy
+
+import torch
+from convlab.policy.policy import Policy
+from convlab.policy.tus.multiwoz.transformer import TransformerActionPrediction
+from convlab.policy.tus.unify.Goal import Goal
+from convlab.policy.tus.unify.usermanager import BinaryFeature
+from convlab.policy.tus.unify.util import create_goal, split_slot_name
+from convlab.util import (load_dataset,
+                          relative_import_module_from_unified_datasets)
+from convlab.util.custom_util import model_downloader
+from convlab.task.multiwoz.goal_generator import GoalGenerator
+from convlab.policy.tus.unify.Goal import old_goal2list
+from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal
+
+
+reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets(
+    'multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value'])
+
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+DEF_VAL_UNK = '?'  # Unknown
+DEF_VAL_DNC = 'dontcare'  # Do not care
+DEF_VAL_NUL = 'none'  # for none
+DEF_VAL_BOOKED = 'yes'  # for booked
+DEF_VAL_NOBOOK = 'no'  # for booked
+Inform = "inform"
+Request = "request"
+NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK, ""]
+
+
+# TODO not ready for unify dataformat now
+class UserActionPolicy(Policy):
+    def __init__(self, config, pretrain=True, dataset="multiwoz21"):
+        """Set up the user model; `config` is a dict or a JSON file path."""
+        Policy.__init__(self)
+        self.dataset = dataset
+        if isinstance(config, str):
+            with open(config) as config_file:
+                self.config = json.load(config_file)
+        else:
+            self.config = config
+
+        self.feat_handler = BinaryFeature(self.config)
+        # look the key up in the parsed config: `config` may be a path string
+        if "num_token" not in self.config:
+            raise KeyError("num_token")
+        self.user = TransformerActionPrediction(self.config).to(device=DEVICE)
+        if pretrain:
+            # the checkpoint name is fixed; config["model_name"] is not read
+            model_path = os.path.join(self.config["model_dir"], "model-non-zero")
+            print(f"loading model from {model_path}...")
+            self.load(model_path)
+        self.user.eval()
+        self.use_domain_mask = self.config.get("domain_mask", False)
+        self.max_turn = 40
+        self.mentioned_domain = []
+        self.reward = {"success": 40, "fail": -20}
+        self.sys_acts = []
+        self.goal_gen = GoalGenerator()
+        self.raw_goal = None
+
+    def _no_offer(self, system_in):
+        """Terminate and return True when any system act is a nooffer."""
+        for intent, _domain, _slot, _value in system_in:
+            if intent.lower() == "nooffer":
+                self.terminated = True
+                return True
+        return False
+
+    def predict(self, sys_dialog_act, mode="max"):
+        """Update the goal with the system acts and return the user acts."""
+        self.predict_action_list = self.goal.action_list(sys_dialog_act)
+        cur_state = self.goal.update(action=sys_dialog_act, char="system")
+        self.sys_acts.append(sys_dialog_act)
+
+        # need better way to handle this
+        if self._no_offer(sys_dialog_act):
+            return [["bye", "general", "none", "none"]]
+
+        # presumably counts the system turn and the user turn -- TODO confirm
+        self.time_step += 2
+
+        feature, mask = self.feat_handler.get_feature(
+            all_slot=self.predict_action_list,
+            user_goal=self.goal,
+            cur_state=cur_state,
+            pre_state=self.sys_history_state,
+            sys_action=sys_dialog_act,
+            usr_action=self.pre_usr_act)
+        feature = torch.tensor([feature], dtype=torch.float).to(DEVICE)
+        mask = torch.tensor([mask], dtype=torch.bool).to(DEVICE)
+
+        self.sys_history_state = cur_state
+
+        usr_output = self.user.forward(feature, mask)
+        usr_action = self.transform_usr_act(
+            usr_output, self.predict_action_list, mode)
+        domains = [act[1] for act in usr_action]
+        none_slot_acts = self._add_none_slot_act(domains)
+        usr_action = none_slot_acts + usr_action
+
+        self.pre_usr_act = deepcopy(usr_action)
+
+        if len(usr_action) < 1:
+            print("EMPTY ACTION")
+
+        # convert user action to unify data format
+        norm_usr_action = []
+        for intent, domain, slot, value in usr_action:
+            intent = intent
+            # normalize_domain_slot_value is not applied here, so the acts
+            # are copied through unchanged
+            norm_usr_action.append([intent, domain, slot, value])
+
+        cur_state = self.goal.update(action=norm_usr_action, char="user")
+
+        return norm_usr_action
+
+        # return usr_action
+
+    def init_session(self, goal=None):
+        """Reset the per-dialogue state and install the user goal.
+
+        `goal` may be a unify Goal or an ABUS goal (converted with
+        old_goal2list); anything else is replaced by a freshly generated
+        ABUS goal.
+        """
+        self.mentioned_domain = []
+        self.time_step = 0
+        self.topic = 'NONE'
+
+        if not isinstance(goal, (ABUS_Goal, Goal)):
+            goal = ABUS_Goal(self.goal_gen)
+        # get_goal() hands this nested goal dict to outside users
+        self.raw_goal = goal.domain_goals
+        if isinstance(goal, ABUS_Goal):
+            goal = Goal(old_goal2list(goal.domain_goals))
+
+        self.read_goal(goal)
+        self.feat_handler.initFeatureHandeler(self.goal)
+
+        if self.config.get("reorder", False):
+            self.predict_action_list = self.goal.action_list()
+        else:
+            # NOTE(review): no action_list attribute is set in this class
+            self.predict_action_list = self.action_list
+        self.sys_history_state = None  # to save sys history
+        self.terminated = False
+
+        self.pre_usr_act = None
+        self.sys_acts = []
+
+    def read_goal(self, data_goal):
+        """Store data_goal as self.goal, wrapping a raw goal list in Goal."""
+        if not isinstance(data_goal, Goal):
+            data_goal = Goal(goal=data_goal)
+        self.goal = data_goal
+
+    # def new_goal(self, remove_domain="police", domain_len=None):
+    #     keep_generate_goal = True
+    #     while keep_generate_goal:
+    #         self.goal = Goal(goal_generator=self.goal_gen)
+    #         if (domain_len and len(self.goal.domains) != domain_len) or \
+    #                 (remove_domain and remove_domain in self.goal.domains):
+    #             keep_generate_goal = True
+    #         else:
+    #             keep_generate_goal = False
+
+    def load(self, model_path=None):
+        self.user.load_state_dict(torch.load(model_path, map_location=DEVICE))
+
+    def load_state_dict(self, model=None):
+        self.user.load_state_dict(model)
+
+    def _get_goal(self):
+        # internal usage
+        return self.goal.domain_goals
+
+    def get_goal(self):
+        # for outside usage, e.g. evaluator
+        return self.raw_goal
+
+    def get_reward(self):
+        """Return the reward for the current point of the dialogue.
+
+        self.reward["success"] once the goal is complete, self.reward["fail"]
+        once time_step has reached max_turn, and 0 otherwise.
+        """
+        # goal completion is tested before the turn limit
+        if self.goal.task_complete():
+            return self.reward["success"]
+
+        if self.time_step >= self.max_turn:
+            return self.reward["fail"]
+
+        # the dialogue is still running
+        return 0
+
+    def _add_none_slot_act(self, domains):
+        """Emit an inform-none act for each domain mentioned for the first time."""
+        first_mentions = []
+        for name in (domain.lower() for domain in domains):
+            if name != 'general' and name not in self.mentioned_domain:
+                first_mentions.append([Inform, name, "none", "none"])
+                self.mentioned_domain.append(name)
+        return first_mentions
+
+    def _finish_conversation(self):
+        """Return (finished, closing acts); the acts are [[]] when not finished."""
+        if self.goal.task_complete():
+            return True, [['thank', 'general', 'none', 'none']]
+
+        if self.time_step > self.max_turn:
+            return True, [["bye", "general", "none", "none"]]
+
+        if len(self.sys_acts) >= 3:  # stop when the last three system turns repeat
+            if self.sys_acts[-1] == self.sys_acts[-2] and self.sys_acts[-2] == self.sys_acts[-3]:
+                return True, [["bye", "general", "none", "none"]]
+
+        return False, [[]]
+
+    def transform_usr_act(self, usr_output, action_list, mode="max"):
+        """Decode the model output into user acts, or close the dialogue.
+
+        Closing acts from _finish_conversation are returned as they are.
+        """
+        is_finish, usr_action = self._finish_conversation()
+        if is_finish:
+            self.terminated = True
+            return usr_action
+
+        usr_action = self._get_acts(
+            usr_output, action_list, mode)
+
+        # if usr_action is empty, sample at least one
+        while not usr_action:
+            usr_action = self._get_acts(
+                usr_output, action_list, mode="pick-one")
+
+        if self.use_domain_mask:
+            domain_mask = self._get_prediction_domain(torch.round(
+                torch.sigmoid(usr_output[0, 0, :])).tolist())
+            usr_action = self._mask_user_action(usr_action, domain_mask)
+
+        return usr_action
+
+    def _get_acts(self, usr_output, action_list, mode="max"):
+        """Choose an output class per slot (argmax or sampled) and build acts."""
+        score = {}
+        for index, slot_name in enumerate(action_list):
+            weights = self.user.softmax(usr_output[0, index + 1, :])
+            if mode == "max":
+                o = torch.argmax(usr_output[0, index + 1, :]).item()
+            elif mode == "sample" or mode == "pick-one":
+                o = random.choices(
+                    range(self.config["out_dim"]),
+                    weights=weights,
+                    k=1)
+                o = o[0]
+            else:
+                print("(BUG) unknown mode")  # NOTE(review): o is not assigned on this path
+            v = weights[o]
+            score[slot_name] = {"output": o, "weight": v}
+
+        usr_action = self._append_actions(action_list, score)
+
+        if mode == "sample" and len(usr_action) > 3:
+            slot_names = []
+            outputs = []
+            scores = []
+            for index, slot_name in enumerate(action_list):
+                weights = self.user.softmax(usr_output[0, index + 1, :])
+                o = torch.argmax(usr_output[0, index + 1, 1:]).item() + 1
+                slot_names.append(slot_name)
+                outputs.append(o)
+                scores.append(weights[o].item())
+            slot_name = random.choices(
+                slot_names,
+                weights=scores,
+                k=3)  # only the first of the three draws is used
+            slot_name = slot_name[0]
+            score[slot_name]["output"] = outputs[slot_names.index(slot_name)]
+            score[slot_name]["weight"] = scores[slot_names.index(slot_name)]
+            # rebuild the acts with the drawn slot set to its best non-zero output
+            usr_action = self._append_actions(action_list, score)
+
+        if mode == "pick-one" and not usr_action:
+            # no act was produced: force one slot to its best non-zero output
+            slot_names = []
+            outputs = []
+            scores = []
+            for index, slot_name in enumerate(action_list):
+                weights = self.user.softmax(usr_output[0, index + 1, :])
+                o = torch.argmax(usr_output[0, index + 1, 1:]).item() + 1
+                slot_names.append(slot_name)
+                outputs.append(o)
+                scores.append(weights[o].item())
+            slot_name = random.choices(
+                slot_names,
+                weights=scores,
+                k=1)
+            slot_name = slot_name[0]
+            score[slot_name]["output"] = outputs[slot_names.index(slot_name)]
+            score[slot_name]["weight"] = scores[slot_names.index(slot_name)]
+            # rebuild the acts with the forced output
+            usr_action = self._append_actions(action_list, score)
+
+        return usr_action
+
+    def _append_actions(self, action_list, score):
+        """Build the flat list of user acts from the per-slot outputs in `score`."""
+        collected = []
+        for slot_name in action_list:
+            domain, slot = split_slot_name(slot_name)
+            produced, acts = self._add_user_action(
+                output=score[slot_name]["output"],
+                domain=domain, slot=slot)
+            if produced:
+                collected += acts
+        return collected
+
+    def _mask_user_action(self, usr_action, mask):
+        """Keep only the acts whose lower-cased domain appears in `mask`."""
+        return [
+            [intent, domain, slot, value]
+            for intent, domain, slot, value in usr_action
+            if domain.lower() in mask]
+
+    def _get_prediction_domain(self, domain_output):
+        """Name the active domains: index 0 is 'general', then the goal domains."""
+        predict_domain = ['general'] if domain_output[0] > 0 else []
+        goal_domains = self.goal.domains
+        for index, value in enumerate(domain_output[1:]):
+            if value > 0 and index < len(goal_domains):
+                predict_domain.append(goal_domains[index])
+        return predict_domain
+
+    def _add_user_action(self, output, domain, slot):
+        """Translate one slot's output class into a user act.
+
+        Output 1 requests the slot ("?"), 2 informs "dontcare", 3 repeats
+        the value held in self.sys_history_state and 4 takes the value from
+        the 'info' part of the user goal. Any other output, or an empty
+        value, yields no act.
+
+        Returns (is_action, act), with act == [[]] when there is no act.
+        """
+        goal = self._get_goal()
+        is_action = False
+        act = [[]]
+        value = None
+
+        # get intent
+        if output == 1:
+            intent = Request
+        else:
+            intent = Inform
+
+        # "?"
+        if output == 1:  # "?"
+            value = DEF_VAL_UNK
+
+        # "dontcare"
+        elif output == 2:
+            value = DEF_VAL_DNC
+
+        # system
+        elif output == 3 and domain in self.sys_history_state:
+            value = self.sys_history_state[domain].get(
+                slot, "")
+
+        elif output == 4 and domain in goal:  # usr
+            for slot_type in ["info"]:
+                if slot_type in goal[domain] and slot in goal[domain][slot_type]:
+                    value = goal[domain][slot_type][slot]
+
+        if value:
+            is_action, act = self._form_action(
+                intent, domain, slot, value)
+
+        return is_action, act
+
+    def _get_action_slot(self, domain, slot):
+        return slot
+
+    def _form_action(self, intent, domain, slot, value):
+        action_slot = self._get_action_slot(domain, slot)
+        if action_slot:
+            return True, [[intent, domain, action_slot, value]]
+        return False, [[]]
+
+    def is_terminated(self):
+        # Is there any action to say?
+        return self.terminated
+
+    def _slot_type(self, domain, slot):
+        slot_type = ""
+        if slot in self.sys_history_state[domain]["book"]:
+            slot_type = "book"
+        elif slot in self.sys_history_state[domain]["semi"]:
+            slot_type = "semi"
+
+        return slot_type
+
+
+class UserPolicy(Policy):
+    def __init__(self, config, dial_ids_order=0):
+        """Wrap UserActionPolicy; `config` is a dict or a JSON file path."""
+        if isinstance(config, str):
+            with open(config) as config_file:
+                self.config = json.load(config_file)
+        else:
+            self.config = config
+        self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}'
+        print("model_dir", self.config['model_dir'])
+        if not os.path.exists(self.config["model_dir"]):
+            # model dir missing: pass its parent dir and the archive URL on
+            model_downloader(os.path.dirname(self.config["model_dir"]),
+                             "https://zenodo.org/record/7369429/files/multiwoz_0.zip")
+        self.slot2dbattr = {
+            'open hours': 'openhours',
+            'price range': 'pricerange',
+            'arrive by': 'arriveBy',
+            'leave at': 'leaveAt',
+            'train id': 'trainID'
+        }
+        self.dbattr2slot = {v: k for k, v in self.slot2dbattr.items()}
+
+        self.policy = UserActionPolicy(self.config)
+
+    def predict(self, state):
+        """Run the wrapped policy and map db attribute names back to slot names."""
+        raw_act = self.policy.predict(state)
+        # dbattr2slot maps e.g. 'arriveBy' back to 'arrive by'
+        return [
+            [intent, domain, self.dbattr2slot.get(slot, slot), value]
+            for intent, domain, slot, value in raw_act
+        ]
+
+    def init_session(self, goal=None):
+        self.policy.init_session(goal)
+
+    def is_terminated(self):
+        return self.policy.is_terminated()
+
+    def get_reward(self):
+        return self.policy.get_reward()
+
+    def get_goal(self):
+        if hasattr(self.policy, 'get_goal'):
+            return self.policy.get_goal()
+        return None
diff --git a/convlab/policy/tus/unify/analysis.py b/convlab/policy/tus/unify/analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2461e3a024b0c83209521f3627ea66cf3a0ffd3
--- /dev/null
+++ b/convlab/policy/tus/unify/analysis.py
@@ -0,0 +1,260 @@
+import argparse
+import json
+import os
+
+import pandas as pd
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from convlab.policy.rule.multiwoz import RulePolicy
+from convlab.policy.tus.unify.Goal import Goal
+from convlab.policy.tus.unify.TUS import UserPolicy
+from convlab.policy.tus.unify.usermanager import TUSDataManager
+from convlab.policy.tus.unify.util import create_goal, parse_dialogue_act
+from convlab.util import load_dataset
+
+
+def check_device():
+    if torch.cuda.is_available():
+        print("using GPU")
+        return torch.device('cuda')
+    else:
+        print("using CPU")
+        return torch.device('cpu')
+
+
+class Analysis:
+    def __init__(self, config, analysis_dir='user-analysis-result', show_dialog=False, save_dialog=True):
+        if not os.path.exists(analysis_dir):
+            os.makedirs(analysis_dir)
+        self.dialog_dir = os.path.join(analysis_dir, 'dialog')
+        if not os.path.exists(self.dialog_dir):
+            os.makedirs(self.dialog_dir)
+        self.dir = analysis_dir
+        self.config = config
+        self.device = check_device()
+        self.show_dialog = show_dialog
+        self.save_dialog = save_dialog
+        self.max_turn = 40
+
+    def get_usr(self, usr="tus", load_path=None):
+        # if using "tus", we read config
+        # for the other user simulators, we read load_path
+        usr = usr.lower()
+        if usr == "tus":
+            policy_usr = UserPolicy(self.config)
+        else:
+            print(f"Unsupport user type: {usr}")
+        # TODO VHUS
+
+        return policy_usr
+
+    def data_interact_test(self, test_data, usr="tus", user_mode=None, load_path=None):
+        if user_mode:
+            # origin_model_name = "-".join(self.config["model_name"].split('-')[:-1])
+            self.config["model_name"] = f"model-{user_mode}"
+
+        result = []
+        label = []
+        policy_usr = self.get_usr(usr=usr, load_path=load_path)
+
+        for dialog in tqdm(test_data):
+            if self.show_dialog:
+                print(f"dialog_id: {dialog['dialog_id']}")
+            goal = Goal(create_goal(dialog))
+
+            sys_act = []
+            policy_usr.init_session(goal=goal)
+            if not policy_usr.get_goal():
+                continue
+            turn_num = len(dialog["turns"])
+            start = 0
+            if dialog["turns"][0]["speaker"] == "system":
+                start = 1
+            for turn_id in range(start, turn_num, 2):
+                if turn_id > 0:
+                    sys_act = parse_dialogue_act(
+                        dialog["turns"][turn_id - 1]["dialogue_acts"])
+                usr_act = policy_usr.predict(sys_act)
+                golden_usr = parse_dialogue_act(
+                    dialog["turns"][turn_id]["dialogue_acts"])
+                result.append(usr_act)
+                label.append(golden_usr)
+
+        for domain in [None]:
+
+            statistic = self._data_f1(result, label, domain)
+            ana_result = {}
+            for stat_type in statistic:
+                s = statistic[stat_type]["success"] / \
+                    statistic[stat_type]["count"]
+                ana_result[stat_type] = s
+            ana_result["f1"] = 2/((1/ana_result["precision"])
+                                  * (1/ana_result["recall"]))
+
+            print(user_mode)
+            for stat_type in ana_result:
+                print(f'{stat_type}: {ana_result[stat_type]}')
+            col = [c for c in ana_result]
+            df_f1 = pd.DataFrame([ana_result[c] for c in col], col)
+            print(df_f1)
+            if domain:
+                df_f1.to_csv(os.path.join(
+                    self.dir, f'{domain}-{user_mode}_data_scores.csv'))
+            else:
+                df_f1.to_csv(os.path.join(
+                    self.config["model_dir"], f'{user_mode}_data_scores.csv'))
+
+    def _extract_domain_related_actions(self, actions, select_domain):
+        #
+        domain_related_acts = []
+        for act in actions:
+            domain = act[1].lower()
+            if domain == select_domain:
+                domain_related_acts.append(act)
+        return domain_related_acts
+
+    def _data_f1(self, result, label, domain=None):
+        #
+        statistic = {}
+        for stat_type in ["precision", "recall", "turn_acc"]:
+            statistic[stat_type] = {"success": 0, "count": 0}
+
+        for r, l in zip(result, label):
+            if domain:
+                r = self._extract_domain_related_actions(r, domain)
+                l = self._extract_domain_related_actions(l, domain)
+
+            if self._skip(l, r, domain):
+                continue
+            # check turn accuracy
+            turn_acc, tp, fp, fn = self._check(r, l)
+            if self.show_dialog:
+                print(r, l)
+                print(turn_acc, tp, fp, fn)
+            if turn_acc:
+                statistic["turn_acc"]["success"] += 1
+            statistic["turn_acc"]["count"] += 1
+            statistic["precision"]["success"] += tp
+            statistic["precision"]["count"] += tp + fp
+            statistic["recall"]["success"] += tp
+            statistic["recall"]["count"] += tp + fn
+
+        return statistic
+
+    @staticmethod
+    def _skip(label, result, domain=None):
+        #
+        ignore = False
+        if domain:
+            if not label and not result:
+                ignore = True
+        else:
+            if not label:
+                ignore = True
+            for intent, domain, slot, value in label:
+                if intent.lower() in ["thank", "bye"]:
+                    ignore = True
+
+        return ignore
+
+    def _check(self, r, l):
+        #
+        # TODO domain check
+        # [['Inform', 'Attraction', 'Addr', 'dontcare']] [['thank', 'general', 'none', 'none']]
+        # skip this one
+        turn_acc = True
+        tp = 0
+        fp = 0
+        fn = 0
+        for a in r:
+            is_none_slot, is_in = self._is_in(a, l)
+            if is_none_slot:
+                continue
+
+            if is_in:
+                tp += 1
+            else:
+                fp += 1
+                turn_acc = False
+
+        for a in l:
+            is_none_slot, is_in = self._is_in(a, r)
+            if is_none_slot:
+                continue
+
+            if is_in:
+                tp += 1
+            else:
+                fn += 1
+                turn_acc = False
+
+        return turn_acc, tp/2, fp, fn
+
+    @staticmethod
+    def _is_in(a, acts):
+        #
+        is_none_slot = False
+        intent, domain, slot, value = a
+        if slot.lower() == "none" or domain.lower() == "general":
+            is_none_slot = True
+            return is_none_slot, True
+        if a in acts:
+            return is_none_slot, True
+        else:
+            for i, d, s, v in acts:
+                if i == intent and d == domain and s == slot:
+                    return is_none_slot, True
+            return is_none_slot, False
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--analysis_dir", type=str,
+                        default="user-analysis-result")
+    parser.add_argument("--user_config", type=str,
+                        default="convlab/policy/tus/multiwoz/exp/multiwoz.json")
+    parser.add_argument("--user_mode", type=str, default="")
+    parser.add_argument("--do_data", action="store_true")
+    parser.add_argument("--usr", type=str, default="tus")
+    parser.add_argument("--domain", type=str, default="",
+                        help="the user goal must contain a specific domain")
+    parser.add_argument("--load_path", type=str, default="",
+                        help="load path for certain models")
+    parser.add_argument("--dataset", type=str, default="multiwoz21",
+                        help="data type")
+    parser.add_argument("--dial_ids_order", type=int, default=0)
+
+    args = parser.parse_args()
+
+    analysis_dir = os.path.join(f"{args.analysis_dir}-{args.usr}")
+
+    if not os.path.exists(os.path.join(analysis_dir)):
+        os.makedirs(analysis_dir)
+
+    config = json.load(open(args.user_config))
+    if args.user_mode:
+        config["model_name"] = config["model_name"] + '-' + args.user_mode
+
+    # config["model_dir"] = f'{config["model_dir"]}_{args.dial_ids_order}'
+    # with open(config["all_slot"]) as f:
+    #     action_list = [line.strip() for line in f]
+    # config["num_token"] = len(action_list)
+
+    ana = Analysis(config, analysis_dir=analysis_dir)
+
+    if args.usr == "tus" and args.do_data:
+        test_data = load_dataset(args.dataset,
+                                 dial_ids_order=args.dial_ids_order)["test"]
+        if args.user_mode:
+            ana.data_interact_test(test_data=test_data,
+                                   usr=args.usr,
+                                   user_mode=args.user_mode,
+                                   load_path=args.load_path)
+        else:
+            for user_mode in ["loss", "total", "turn", "non-zero"]:
+                ana.data_interact_test(test_data=test_data,
+                                       usr=args.usr,
+                                       user_mode=user_mode,
+                                       load_path=args.load_path)
diff --git a/convlab/policy/tus/unify/exp/all.json b/convlab/policy/tus/unify/exp/all.json
new file mode 100644
index 0000000000000000000000000000000000000000..b45b1b80725eeb8eb82cf582794cebfee09b3cd2
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/all.json
@@ -0,0 +1,32 @@
+{
+    "model_dir": "convlab/policy/tus/unify/all",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/default.json b/convlab/policy/tus/unify/exp/default.json
new file mode 100644
index 0000000000000000000000000000000000000000..5d681849f8b478f8d6b3efe4e86d937e90277f56
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/default.json
@@ -0,0 +1,32 @@
+{
+    "model_dir": "convlab/policy/tus/multiwoz/default",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/multiwoz.json b/convlab/policy/tus/unify/exp/multiwoz.json
new file mode 100644
index 0000000000000000000000000000000000000000..46d90c8b4e76d43f74491b4fd9259cab849c2299
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/multiwoz.json
@@ -0,0 +1,32 @@
+{
+    "model_dir": "convlab/policy/tus/unify/multiwoz",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/multiwoz_001.json b/convlab/policy/tus/unify/exp/multiwoz_001.json
new file mode 100644
index 0000000000000000000000000000000000000000..5aa3ec7665b82102ef1979e857f8318e5ad09e9a
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/multiwoz_001.json
@@ -0,0 +1,32 @@
+{
+    "model_dir": "convlab/policy/tus/unify/multiwoz001",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/multiwoz_01.json b/convlab/policy/tus/unify/exp/multiwoz_01.json
new file mode 100644
index 0000000000000000000000000000000000000000..34fc3587ed0e037c11e910ee42511b583bd6adf8
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/multiwoz_01.json
@@ -0,0 +1,32 @@
+{
+    "model_dir": "convlab/policy/tus/unify/multiwoz01",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/pretrain-multiwoz.json b/convlab/policy/tus/unify/exp/pretrain-multiwoz.json
new file mode 100644
index 0000000000000000000000000000000000000000..a101658d85b23dc5d1634cc454c9de36f9769ae3
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/pretrain-multiwoz.json
@@ -0,0 +1,33 @@
+{
+    "model_dir": "convlab/policy/tus/unify/pretrain-multiwoz",
+    "pretrain": "convlab/policy/tus/unify/sgd+tm",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/pretrain-multiwoz_001.json b/convlab/policy/tus/unify/exp/pretrain-multiwoz_001.json
new file mode 100644
index 0000000000000000000000000000000000000000..cca2bef73572c6932cf74a4fc540079bb60327ac
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/pretrain-multiwoz_001.json
@@ -0,0 +1,33 @@
+{
+    "model_dir": "convlab/policy/tus/unify/pretrain-multiwoz001",
+    "pretrain": "convlab/policy/tus/unify/sgd+tm",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/pretrain-multiwoz_01.json b/convlab/policy/tus/unify/exp/pretrain-multiwoz_01.json
new file mode 100644
index 0000000000000000000000000000000000000000..a662dfc5df459819076e58f680388abeb919137c
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/pretrain-multiwoz_01.json
@@ -0,0 +1,33 @@
+{
+    "model_dir": "convlab/policy/tus/unify/multiwoz01",
+    "pretrain": "convlab/policy/tus/unify/sgd+tm",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/sgd+tm.json b/convlab/policy/tus/unify/exp/sgd+tm.json
new file mode 100644
index 0000000000000000000000000000000000000000..d16e23aba23a1342f0ef5bca4e07933fa998a065
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/sgd+tm.json
@@ -0,0 +1,32 @@
+{
+    "model_dir": "convlab/policy/tus/unify/sgd+tm",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/sgd.json b/convlab/policy/tus/unify/exp/sgd.json
new file mode 100644
index 0000000000000000000000000000000000000000..8961ee35bcbf9e90ed4ba80761e1548e9be970bc
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/sgd.json
@@ -0,0 +1,32 @@
+{
+    "model_dir": "convlab/policy/tus/unify/sgd",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/tm1.json b/convlab/policy/tus/unify/exp/tm1.json
new file mode 100644
index 0000000000000000000000000000000000000000..7ab472b7fdc5795b0452f3d53709842d8a0a7df5
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/tm1.json
@@ -0,0 +1,32 @@
+{
+    "model_dir": "convlab/policy/tus/unify/tm1",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/tm2.json b/convlab/policy/tus/unify/exp/tm2.json
new file mode 100644
index 0000000000000000000000000000000000000000..47430090f570afffa5e6b3f07d7383986efe29d8
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/tm2.json
@@ -0,0 +1,32 @@
+{
+    "model_dir": "convlab/policy/tus/unify/tm2",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/exp/tm3.json b/convlab/policy/tus/unify/exp/tm3.json
new file mode 100644
index 0000000000000000000000000000000000000000..fb1857cacbae8006c14a9b616d58aed5312aa303
--- /dev/null
+++ b/convlab/policy/tus/unify/exp/tm3.json
@@ -0,0 +1,32 @@
+{
+    "model_dir": "convlab/policy/tus/unify/tm3",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 0.0001,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 79,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/train.py b/convlab/policy/tus/unify/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbaac0ab6772eca7cc2d0481ebb5c88afbe74197
--- /dev/null
+++ b/convlab/policy/tus/unify/train.py
@@ -0,0 +1,285 @@
+import json
+import os
+
+import torch
+import torch.optim as optim
+from tqdm import tqdm
+from convlab.policy.tus.multiwoz.analysis import Analysis
+from convlab.util import load_dataset, load_ontology
+
+
+def check_device():
+    if torch.cuda.is_available():
+        print("using GPU")
+        return torch.device('cuda')
+    else:
+        print("using CPU")
+        return torch.device('cpu')
+
+
+class Trainer:
+    def __init__(self, model, config):
+        self.model = model
+        self.config = config
+        self.num_epoch = self.config["num_epoch"]
+        self.batch_size = self.config["batch_size"]
+        self.device = check_device()
+        print(self.device)
+        self.optimizer = optim.Adam(
+            model.parameters(), lr=self.config["learning_rate"])
+
+        self.ana = Analysis(config)
+
+    def training(self, train_data, test_data=None):
+
+        self.model = self.model.to(self.device)
+        if not os.path.exists(self.config["model_dir"]):
+            os.makedirs(self.config["model_dir"])
+
+        save_path = os.path.join(
+            self.config["model_dir"], self.config["model_name"])
+
+        # best = [0, 0, 0]
+        best = {"loss": 100}
+        lowest_loss = 100
+        for epoch in range(self.num_epoch):
+            print("epoch {}".format(epoch))
+            total_loss = self.train_epoch(train_data)
+            print("loss: {}".format(total_loss))
+            if test_data is not None:
+                acc = self.eval(test_data)
+
+            if total_loss < lowest_loss:
+                best["loss"] = total_loss
+                print(f"save model in {save_path}-loss")
+                torch.save(self.model.state_dict(), f"{save_path}-loss")
+
+            for acc_type in acc:
+                if acc_type not in best:
+                    best[acc_type] = 0
+                temp = acc[acc_type]["correct"] / acc[acc_type]["total"]
+                if best[acc_type] < temp:
+                    best[acc_type] = temp
+                    print(f"save model in {save_path}-{acc_type}")
+                    torch.save(self.model.state_dict(),
+                               f"{save_path}-{acc_type}")
+            if epoch < 10 and epoch > 5:
+                print(f"save model in {save_path}-{epoch}")
+                torch.save(self.model.state_dict(),
+                           f"{save_path}-{epoch}")
+            print(f"save latest model in {save_path}")
+            torch.save(self.model.state_dict(), save_path)
+
+    def train_epoch(self, data_loader):
+        self.model.train()
+        total_loss = 0
+        result = {}
+        # result = {"id": {"slot": {"prediction": [],"label": []}}}
+        count = 0
+        for i, data in enumerate(tqdm(data_loader, ascii=True, desc="Training"), 0):
+            input_feature = data["input"].to(self.device)
+            mask = data["mask"].to(self.device)
+            label = data["label"].to(self.device)
+            if self.config.get("domain_traget", True):
+                domain = data["domain"].to(self.device)
+            else:
+                domain = None
+            self.optimizer.zero_grad()
+
+            loss, output = self.model(input_feature, mask, label, domain)
+
+            loss.backward()
+            self.optimizer.step()
+            total_loss += float(loss)
+            count += 1
+
+        return total_loss / count
+
+    def eval(self, test_data):
+        self.model.zero_grad()
+        self.model.eval()
+
+        result = {}
+
+        with torch.no_grad():
+            correct, total, non_zero_correct, non_zero_total = 0, 0, 0, 0
+            for i, data in enumerate(tqdm(test_data, ascii=True, desc="Evaluation"), 0):
+                input_feature = data["input"].to(self.device)
+                mask = data["mask"].to(self.device)
+                label = data["label"].to(self.device)
+                output = self.model(input_feature, mask)
+                r = parse_result(output, label)
+                for r_type in r:
+                    if r_type not in result:
+                        result[r_type] = {"correct": 0, "total": 0}
+                    for n in result[r_type]:
+                        result[r_type][n] += float(r[r_type][n])
+
+        for r_type in result:
+            temp = result[r_type]['correct'] / result[r_type]['total']
+            print(f"{r_type}: {temp}")
+
+        return result
+
+
+def parse_result(prediction, label):
+    # result = {"id": {"slot": {"prediction": [],"label": []}}}
+    # dialog_index = ["dialog-id"_"slot-name", "dialog-id"_"slot-name", ...]
+    # prdiction = [0, 1, 0, ...] # after max
+
+    _, arg_prediction = torch.max(prediction.data, -1)
+    batch_size, token_num = label.shape
+    result = {
+        "non-zero": {"correct": 0, "total": 0},
+        "total": {"correct": 0, "total": 0},
+        "turn": {"correct": 0, "total": 0}
+    }
+
+    for batch_num in range(batch_size):
+        turn_acc = True
+        for element in range(token_num):
+            result["total"]["total"] += 1
+            if label[batch_num][element] > 0:
+                result["non-zero"]["total"] += 1
+
+            if arg_prediction[batch_num][element + 1] == label[batch_num][element]:
+                if label[batch_num][element] > 0:
+                    result["non-zero"]["correct"] += 1
+                result["total"]["correct"] += 1
+
+            elif arg_prediction[batch_num][element + 1] == 0 and label[batch_num][element] < 0:
+                result["total"]["correct"] += 1
+
+            else:
+                if label[batch_num][element] >= 0:
+                    turn_acc = False
+
+        result["turn"]["total"] += 1
+        if turn_acc:
+            result["turn"]["correct"] += 1
+
+    return result
+
+
+def f1(target, result):
+    target_len = 0
+    result_len = 0
+    tp = 0
+    for t, r in zip(target, result):
+        if t:
+            target_len += 1
+        if r:
+            result_len += 1
+        if r == t and t:
+            tp += 1
+    precision = 0
+    recall = 0
+    if result_len:
+        precision = tp / result_len
+    if target_len:
+        recall = tp / target_len
+    f1_score = 2 / (1 / precision + 1 / recall)
+    return f1_score, precision, recall
+
+
+def save_data(data, file_name, file_dir):
+    if not os.path.exists(file_dir):
+        os.makedirs(file_dir)
+    f_name = os.path.join(file_dir, file_name)
+    torch.save(data, f_name)
+    # with open(f_name, 'wb') as data_obj:
+    #     pickle.dump(data, data_obj, pickle.HIGHEST_PROTOCOL)
+    print(f"save data to {f_name}")
+
+
+if __name__ == "__main__":
+    import argparse
+    import os
+    from convlab.policy.tus.multiwoz.transformer import \
+        TransformerActionPrediction
+    from convlab.policy.tus.unify.usermanager import \
+        TUSDataManager
+    from torch.utils.data import DataLoader
+    from convlab.policy.tus.unify.util import update_config_file
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--user-config", type=str,
+                        default="convlab/policy/tus/unify/exp/default.json")
+    parser.add_argument("--force-read-data", '-f', action='store_true',
+                        help="Force to read data from scratch")
+    parser.add_argument("--dataset", type=str, default="multiwoz21")
+    parser.add_argument("--dial-ids-order", type=int, default=0)
+    parser.add_argument("--split2ratio", type=float, default=1)
+
+    args = parser.parse_args()
+    config_file = open(args.user_config)
+    config = json.load(config_file)
+    config_file.close()
+    if args.dataset == "all":
+        print("merge all datasets...")
+        all_dataset = ["multiwoz21", "sgd", "tm1", "tm2", "tm3"]
+        datasets = {}
+        for dataset in all_dataset:
+            datasets[dataset] = load_dataset(dataset,
+                                             dial_ids_order=args.dial_ids_order,
+                                             split2ratio={'train': args.split2ratio})
+        # merge dataset
+        raw_data = {}
+        for data_type in ["train", "test"]:
+            raw_data[data_type] = []
+            for dataset in all_dataset:
+                raw_data[data_type] += datasets[dataset][data_type]
+
+    elif args.dataset == "sgd+tm":
+        print("merge multiple datasets...")
+        all_dataset = ["sgd", "tm1", "tm2", "tm3"]
+        datasets = {}
+        for dataset in all_dataset:
+            datasets[dataset] = load_dataset(dataset,
+                                             dial_ids_order=args.dial_ids_order,
+                                             split2ratio={'train': args.split2ratio})
+        # merge dataset
+        raw_data = {}
+        for data_type in ["train", "test"]:
+            raw_data[data_type] = []
+            for dataset in all_dataset:
+                raw_data[data_type] += datasets[dataset][data_type]
+
+    else:
+        print(f"load single dataset {args.dataset}/{args.split2ratio}")
+        raw_data = load_dataset(args.dataset,
+                                dial_ids_order=args.dial_ids_order,
+                                split2ratio={'train': args.split2ratio})
+
+    batch_size = config["batch_size"]
+
+    # load data with "load_data"
+
+    # check train/test data
+    data = {"train": {}, "test": {}}
+    for data_type in data:
+        data[data_type]["data"] = TUSDataManager(
+            config, raw_data[data_type])
+
+    # check the embed_dim and update it
+    embed_dim = data["train"]["data"].features["input"].shape[-1]
+    if embed_dim != config["embed_dim"]:
+        config["embed_dim"] = embed_dim
+        update_config_file(file_name=args.user_config,
+                           attribute="embed_dim", value=embed_dim)
+
+    train_data = DataLoader(data["train"]["data"],
+                            batch_size=batch_size, shuffle=True)
+    test_data = DataLoader(data["test"]["data"],
+                           batch_size=batch_size, shuffle=True)
+
+    model = TransformerActionPrediction(config)
+
+    if "pretrain" in config:
+        pretrain_weight = os.path.join(
+            f'{config["pretrain"]}_{args.dial_ids_order}', f"model-loss")
+        print(f"fine tune based on {pretrain_weight}...")
+        model.load_state_dict(torch.load(
+            pretrain_weight, map_location=check_device()))
+    trainer = Trainer(model, config)
+    trainer.training(train_data, test_data)
diff --git a/convlab/policy/tus/unify/usermanager.py b/convlab/policy/tus/unify/usermanager.py
new file mode 100644
index 0000000000000000000000000000000000000000..3192d3d578ca3698424b16f47933d6863208f77c
--- /dev/null
+++ b/convlab/policy/tus/unify/usermanager.py
@@ -0,0 +1,515 @@
+import json
+import os
+from collections import Counter
+
+import torch
+from convlab.policy.tus.unify.Goal import Goal
+from convlab.policy.tus.unify.util import parse_dialogue_act, parse_user_goal, metadata2state, int2onehot, create_goal, split_slot_name
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+NOT_MENTIONED = "not mentioned"
+
+
+def dic2list(da2goal):
+    """Return the unique "<domain>-<value>" names of a {domain: {slot: value}} mapping.
+
+    Slots whose value is None are skipped; names keep their first-seen order.
+    """
+    names = (
+        f"{domain}-{value}"
+        for domain, slot_map in da2goal.items()
+        for value in slot_map.values() if value is not None)
+    return list(dict.fromkeys(names))
+
+
+class TUSDataManager(Dataset):
+    def __init__(self,
+                 config,
+                 data,
+                 max_turns=12):
+
+        self.config = config
+        self.feature_handler = BinaryFeature(self.config)
+        self.features = self.process(data, max_turns)
+
+    def __getitem__(self, index):
+        return {label: self.features[label][index] if self.features[label] is not None else None
+                for label in self.features}
+
+    def __len__(self):
+        return self.features['input'].size(0)
+
+    def resample(self, size=None):
+        """Draw `size` samples with replacement (default: as many as there are now)."""
+        n_dialogues = self.__len__()
+        size = size or n_dialogues
+        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
+        # "id" is a plain Python list and cannot be indexed with a tensor like the others
+        self.features = {k: v[dialogues] if torch.is_tensor(v) else [v[i] for i in dialogues.tolist()]
+                         for k, v in self.features.items()}
+
+    def to(self, device):
+        """Move all tensor features to `device`; the "id" list of dialogue ids is left as is."""
+        self.device = device
+        self.features = {k: v.to(device) if torch.is_tensor(v) else v for k, v in self.features.items()}
+
+    def process(self, data, max_turns):
+
+        feature = {"id": [], "input": [],
+                   "label": [], "mask": [], "domain": []}
+        # TODO remove dst. trace in user goal
+        for dialog in tqdm(data, ascii=True, desc="Processing"):
+            # TODO build user goal from history
+            user_goal = Goal(create_goal(dialog))
+            if not user_goal.domain_goals:
+                continue
+
+            # if one domain is removed, we skip all data related to this domain
+            # remove police at default
+            if "police" in user_goal.domains:
+                continue
+
+            turn_num = len(dialog["turns"])
+            pre_state = {}
+            sys_act = []
+            self.feature_handler.initFeatureHandeler(user_goal)
+
+            start = 0
+            if dialog["turns"][0]["speaker"] == "system":
+                start = 1
+
+            for turn_id in range(start, turn_num, 2):
+                # dialog start from user
+                action_list = user_goal.action_list(sys_act)
+                if turn_id > 0:
+                    # cur_state = data[dialog_id]["log"][turn_id-1]["metadata"]
+                    sys_act = parse_dialogue_act(
+                        dialog["turns"][turn_id - 1]["dialogue_acts"])
+                cur_state = user_goal.update(action=sys_act, char="system")
+
+                usr_act = parse_dialogue_act(
+                    dialog["turns"][turn_id]["dialogue_acts"])
+
+                input_feature, mask = self.feature_handler.get_feature(
+                    action_list, user_goal, cur_state, pre_state, sys_act)  # TODO why
+                label = self.feature_handler.generate_label(
+                    action_list, user_goal, cur_state, usr_act)
+                domain_label = self.feature_handler.domain_label(
+                    user_goal, usr_act)
+                # pre_state = user_goal.update(action=usr_act, char="user") # trick?
+                feature["id"].append(dialog["dialogue_id"])
+                feature["input"].append(input_feature)
+                feature["mask"].append(mask)
+                feature["label"].append(label)
+                feature["domain"].append(domain_label)
+
+        print("label distribution")
+        label_distribution = Counter()
+        for label in feature["label"]:
+            label_distribution += Counter(label)
+        print(label_distribution)
+        feature["input"] = torch.tensor(feature["input"], dtype=torch.float)
+        feature["label"] = torch.tensor(feature["label"], dtype=torch.long)
+        feature["mask"] = torch.tensor(feature["mask"], dtype=torch.bool)
+        feature["domain"] = torch.tensor(feature["domain"], dtype=torch.float)
+        for feat_type in ["input", "label", "mask", "domain"]:
+            print("{}: {}".format(feat_type, feature[feat_type].shape))
+        return feature
+
+
+class Feature:
+    def __init__(self, config):
+        self.config = config
+        self.intents = ["inform", "request", "recommend", "select",
+                        "book", "nobook", "offerbook", "offerbooked", "nooffer"]
+        self.general_intent = ["reqmore", "bye", "thank", "welcome", "greet"]
+        self.default_values = ["none", "?", "dontcare"]
+        # path = os.path.dirname(
+        #     os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+        # path = os.path.join(path, 'data/multiwoz/all_value.json')
+        # self.all_values = json.load(open(path))
+
+    def initFeatureHandeler(self, goal: Goal):
+        self.goal = goal
+        self.domain_list = goal.domains
+        usr = parse_user_goal(goal)
+        self.constrains = {}  # slot: fulfill
+        self.requirements = {}  # slot: fulfill
+        self.pre_usr = []
+        self.all_slot = None
+        self.user_feat_hist = []
+        for slot in usr:
+            if usr[slot] != "?":
+                self.constrains[slot] = NOT_MENTIONED
+
+    def get_feature(self, all_slot, user_goal, cur_state, pre_state=None, sys_action=None, usr_action=None, state_vectorize=False):
+        """ 
+        given current dialog information and return the input feature 
+        user_goal: Goal()
+        cur_state: dict, {domain: "semi": {slot: value}, "book": {slot: value, "booked": []}}("metadata" in the data set)
+        sys_action: [[intent, domain, slot, value]]
+        """
+
+        feature = []
+        usr = parse_user_goal(user_goal)
+
+        if sys_action and not state_vectorize:
+            self.update_constrain(sys_action)
+
+        cur = metadata2state(cur_state)
+        pre = {}
+        if pre_state != None:
+            pre = metadata2state(pre_state)
+        if not self.pre_usr and not state_vectorize:
+            self.pre_usr = [0] * len(all_slot)
+
+        usr_act_feat = self.get_user_action_feat(
+            all_slot, user_goal, usr_action)
+
+        for slot in all_slot:
+            feat = self.slot_feature(
+                slot, usr, cur, pre, sys_action, usr_act_feat)
+            feature.append(feat)
+
+        if not state_vectorize:
+            self.user_feat_hist.append(feature)
+
+        feature, mask = self.pad_feature(max_memory=self.config["window"])
+
+        return feature, mask
+
+    def slot_feature(self, slot, user_goal, current_state, previous_state, sys_action, usr_action):
+        pass
+
+    def pad_feature(self, max_memory=5):
+        """Flatten the newest `max_memory` turns into one padded sequence plus its padding mask."""
+        feature = []
+        num_feat = len(self.user_feat_hist)
+        feat_dim = len(self.user_feat_hist[0][0])
+        # newest turn first: "CLS" precedes the newest turn, "SEP" each older one
+        for feat_index in range(num_feat - 1, max(num_feat - 1 - max_memory, -1), -1):
+            if feat_index == num_feat - 1:
+                special_token = self.slot_feature(
+                    "CLS", {}, {}, {}, [], [])
+            else:
+                special_token = self.slot_feature(
+                    "SEP", {}, {}, {}, [], [])
+            feature += [special_token]
+            feature += self.user_feat_hist[feat_index]
+
+        max_len = max_memory * self.config["num_token"]
+        if len(feature) < max_len:
+            # build the mask BEFORE padding; afterwards len(feature) == max_len and nothing is flagged
+            num_pad = max_len - len(feature)
+            mask = [False] * len(feature) + [True] * num_pad
+            feature += [[0] * feat_dim] * num_pad
+        else:
+            mask = [False] * max_len
+        return feature[:max_len], mask[:max_len]
+
+    def domain_label(self, user_goal, dialog_act):
+        """Multi-hot label of the goal domains addressed by the user act.
+
+        Index i + 1 marks user_goal.domains[i]; index 0 marks "no goal domain touched".
+        """
+        goal_domains = user_goal.domains
+        labels = [0] * self.config["out_dim"]
+        hit_positions = [goal_domains.index(domain) + 1 for intent, domain, slot, value in dialog_act
+                         if domain in goal_domains]
+        for position in hit_positions:
+            labels[position] = 1
+        if not hit_positions:
+            labels[0] = 1
+        return labels
+
+    def generate_label(self, action_list: list, user_goal, cur_state, dialog_act):
+        # label = "none", "?", "dontcare", "system", "user", "change"
+
+        labels = [-1] * self.config["num_token"]
+
+        usr = parse_user_goal(user_goal)
+        cur = metadata2state(cur_state)
+        # print("usr", usr)
+        # print("cur", cur)
+        # print(action_list)
+        for intent, domain, slot, value in dialog_act:
+            # domain = domain.lower()
+            # value = value.lower()
+            # slot = slot.lower()
+            # name = util.act2slot(intent, domain, slot, value, self.all_values)
+            name = f"{domain}-{slot}"
+
+            if name not in action_list:
+                # print(f"Not handle name {name} in getting label")
+                continue
+            name_id = action_list.index(name)
+            if name_id >= self.config["num_token"]:
+                continue
+            if value == "?":
+                labels[name_id] = 1
+            elif value == "dontcare":
+                labels[name_id] = 2
+            elif name in cur and value == cur[name]:
+                labels[name_id] = 3
+            elif name in usr and value == usr[name]:
+                labels[name_id] = 4
+            elif (name in cur or name in usr) and value not in [cur.get(name), usr.get(name)]:
+                labels[name_id] = 5
+
+        for name in action_list:
+            domain = name.split('-')[0]
+            name_id = action_list.index(name)
+            if name_id < len(labels):
+                if labels[name_id] < 0 and domain in self.domain_list:
+                    labels[name_id] = 0
+
+        self.pre_usr = labels
+
+        return labels
+
+    def get_user_action_feat(self, all_slot, user_goal, usr_act):
+        pass
+
+    def update_constrain(self, action):
+        """ 
+        update constrain status by system actions
+        action = [[intent, domain, slot, name]]
+        """
+        for intent, domain, slot, value in action:
+            # domain = domain.lower()
+            # slot = slot.lower()
+            if domain in self.domain_list:
+                # slot = SysDa2Goal[domain].get(slot, "none")
+                slot_name = f"{domain}-{slot}"
+            elif domain == "booking":
+                if slot.lower() == "ref":
+                    continue
+                # slot = SysDa2Goal[domain].get(slot, "none")
+                # domain = util.get_booking_domain(
+                #     slot, value, self.all_values, self.domain_list)
+                if not domain:
+                    continue  # work around
+                slot_name = f"{domain}-{slot}"
+
+            else:
+                continue
+            if value != "?":
+                self.constrains[slot_name] = value
+
+    @staticmethod
+    def concatenate_subvectors(vec_list):
+        """Join all vectors of `vec_list`, in order, into one flat list.
+
+        An empty `vec_list` gives []."""
+        return [item for sub_vec in vec_list for item in sub_vec]
+
+
+class BinaryFeature(Feature):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def slot_feature(self, slot, user_goal, current_state, previous_state, sys_action, usr_action):
+        feat = []
+        feat += self._special_token(slot)
+        feat += self._value_representation(
+            slot, current_state.get(slot, NOT_MENTIONED))
+        feat += self._value_representation(
+            slot, user_goal.get(slot, NOT_MENTIONED))
+        feat += self._is_constrain_request(slot, user_goal)
+        feat += self._is_fulfill(slot, user_goal)
+        if self.config.get("conflict", True):
+            feat += self._conflict_check(user_goal, current_state, slot)
+        if self.config.get("domain_feat", False):
+            # feat += self.domain_feat(slot)
+            if slot in ["CLS", "SEP"]:
+                feat += [0] * (self.goal.max_domain_len +
+                               self.goal.max_slot_len)
+            else:
+                domain_feat, slot_feat = self.goal.get_slot_id(slot)
+                feat += domain_feat + slot_feat
+        feat += self._first_mention_detection(
+            previous_state, current_state, slot)
+        feat += self._just_mention(slot, sys_action)
+        feat += self._action_representation(slot, sys_action)
+        # need change from 0 to domain predictor
+        if slot in ["CLS", "SEP"]:
+            feat += [0] * self.config["out_dim"]
+        else:
+            feat += usr_action[slot]
+        return feat
+
+    def get_user_action_feat(self, all_slot, user_goal, usr_act):
+        if usr_act:
+            usr_label = self.generate_label(
+                all_slot, user_goal, {}, usr_act)
+            self.pre_usr = usr_label
+        usr_act_feat = {}
+        for index, slot in zip(self.pre_usr, all_slot):
+            usr_act_feat[slot] = int2onehot(index, self.config["out_dim"])
+        return usr_act_feat
+
+    def _special_token(self, slot):
+        special_token = ["CLS", "SEP"]
+        feat = [0] * len(special_token)
+        if slot in special_token:
+            feat[special_token.index(slot)] = 1
+        return feat
+
+    def _is_constrain_request(self, feature_slot, user_goal):
+        if feature_slot in ["CLS", "SEP"]:
+            return [0, 0]
+        # [is_constrain, is_request]
+        value = user_goal.get(feature_slot, NOT_MENTIONED)
+        if value == "?":
+            return [0, 1]
+        elif value == NOT_MENTIONED:
+            return [0, 0]
+        else:
+            return [1, 0]
+
+    def _is_fulfill(self, feature_slot, user_goal):
+        if feature_slot in ["CLS", "SEP"]:
+            return [0]
+
+        if feature_slot in user_goal and user_goal.get(feature_slot) == self.constrains.get(feature_slot):
+            return [1]
+        return [0]
+
+    def _just_mention(self, feature_slot, sys_action):
+        """Return [1] when the latest system action mentions `feature_slot`, else [0].
+
+        `sys_action` is a list of [intent, domain, slot, value] entries. The special
+        "CLS"/"SEP" tokens and a missing/empty action always yield [0].
+        """
+        if feature_slot in ["CLS", "SEP"]:
+            return [0]
+        if not sys_action:
+            return [0]
+
+        # The former guard `if domain in sys_action` compared a domain string with
+        # the [intent, domain, slot, value] lists themselves, so it never matched and
+        # this feature was always 0. Match on the "<domain>-<slot>" name directly.
+        # TODO: resolve the "booking" domain to a goal domain first (the disabled
+        # code used util.get_booking_domain for that).
+        sys_action_slot = [
+            f"{domain}-{slot}" for intent, domain, slot, value in sys_action]
+
+        if feature_slot in sys_action_slot:
+            return [1]
+        return [0]
+
+    def _action_representation(self, feature_slot, action):
+
+        gen_vec = [0] * len(self.general_intent)
+        # ["none", "?", other]
+        intent2act = {intent: [0] * 3 for intent in self.intents}
+
+        if action is None or feature_slot in ["CLS", "SEP"]:
+            return self._concatenate_action_vector(intent2act, gen_vec)
+        for intent, domain, slot, value in action:
+            # domain = domain.lower()
+            # slot = slot.lower()
+            # value = value.lower()
+
+            # general
+            if domain == "general":
+                self._update_general_action(gen_vec, intent)
+            else:
+                # if domain == "booking":
+                #     domain = util.get_booking_domain(
+                #         slot, value, self.all_values, self.domain_list)
+                self._update_intent2act(
+                    feature_slot, intent2act,
+                    domain, intent, slot, value)
+                # TODO special slots, "choice, ref, none"
+
+        return self._concatenate_action_vector(intent2act, gen_vec)
+
+    def _update_general_action(self, vec, intent):
+        if intent in self.general_intent:
+            vec[self.general_intent.index(intent)] = 1
+
+    def _update_intent2act(self, feature_slot, intent2act, domain, intent, slot, value):
+        """Set the [none, "?", other] flag of `intent` in `intent2act` if the act concerns `feature_slot`."""
+        feature_domain, feature_slot = split_slot_name(feature_slot)
+        # intents without an entry in intent2act are skipped instead of raising KeyError
+        if feature_domain != domain or intent not in intent2act:
+            return
+        if slot == "none":  # None slot
+            intent2act[intent][2] = 1
+        elif slot == feature_slot:
+            if value == "none":
+                intent2act[intent][0] = 1
+            elif value == "?":
+                intent2act[intent][1] = 1
+            else:
+                intent2act[intent][2] = 1
+
+    def _concatenate_action_vector(self, intent2act, general):
+        feat = []
+        for intent in intent2act:
+            feat += intent2act[intent]
+        feat += general
+        return feat
+
+    def _value_representation(self, slot, value):
+        if slot in ["CLS", "SEP"]:
+            return [0, 0, 0, 0]
+        if value == NOT_MENTIONED:
+            return [1, 0, 0, 0]
+        else:
+            temp_vector = [0] * (len(self.default_values) + 1)
+            if value in self.default_values:
+                temp_vector[self.default_values.index(value)] = 1
+            else:
+                temp_vector[-1] = 1
+
+            return temp_vector
+
+    def _conflict_check(self, user_goal, system_state, slot):
+        # conflict = [1] else [0]
+        if slot in ["CLS", "SEP"]:
+            return [0]
+        usr = user_goal.get(slot, NOT_MENTIONED)
+        sys = system_state.get(slot, NOT_MENTIONED)
+        if usr in [NOT_MENTIONED, "none", ""] and sys in [NOT_MENTIONED, "none", ""]:
+            return [0]
+
+        if usr != sys or (usr == "?" and sys == "?"):
+            # print(f"{slot} has different value: {usr} and {sys}.")
+            # conflict = uniform(0.2, 1)
+            conflict = 1
+            return [conflict]
+        return [0]
+
+    def _first_mention_detection(self, pre_state, cur_state, slot):
+        if slot in ["CLS", "SEP"]:
+            return [0]
+
+        first_mention = [1]
+        not_first_mention = [0]
+        cur = cur_state.get(slot, NOT_MENTIONED)
+        if pre_state is None:
+            if cur not in [NOT_MENTIONED, "none"]:
+                return first_mention
+            else:
+                return not_first_mention
+
+        pre = pre_state.get(slot, NOT_MENTIONED)
+
+        if pre in [NOT_MENTIONED, "none"] and cur not in [NOT_MENTIONED, "none"]:
+            return first_mention
+
+        return not_first_mention  # hasn't been mentioned
+
+
+if __name__ == "__main__":
+    from convlab.util import load_dataset, load_ontology
+    data = load_dataset("multiwoz21")
+    config = json.load(open("convlab/policy/tus/unify/exp/default.json"))
+    # test = TUSDataManager(config, data["test"])
+    train = TUSDataManager(config, data["train"])
+    # data = load_dataset("sgd")
+    # config = json.load(open("convlab/policy/tus/unify/exp/default.json"))
+    # data_manager = TUSDataManager(config, data["test"])
diff --git a/convlab/policy/tus/unify/util.py b/convlab/policy/tus/unify/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..d65f72a06e181e66bfe0d7ac0c60f0c03a56ad43
--- /dev/null
+++ b/convlab/policy/tus/unify/util.py
@@ -0,0 +1,245 @@
+from convlab.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
+from convlab.util import load_dataset
+
+import json
+
+NOT_MENTIONED = "not mentioned"
+
+
+def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2ratio=1):
+    ratio = {'train': split2ratio, 'validation': split2ratio}
+    if data_name == "all" or data_name == "sgd+tm" or data_name == "tm":
+        print("merge all datasets...")
+        if data_name == "all":
+            all_dataset = ["multiwoz21", "sgd", "tm1", "tm2", "tm3"]
+        if data_name == "sgd+tm":
+            all_dataset = ["sgd", "tm1", "tm2", "tm3"]
+        if data_name == "tm":
+            all_dataset = ["tm1", "tm2", "tm3"]
+
+        datasets = {}
+        for name in all_dataset:
+            datasets[name] = load_dataset(
+                name,
+                dial_ids_order=dial_ids_order,
+                split2ratio=ratio)
+        raw_data = merge_dataset(datasets, all_dataset[0])
+
+    else:
+        print(f"load single dataset {data_name}/{split2ratio}")
+        raw_data = load_dataset(data_name,
+                                dial_ids_order=dial_ids_order,
+                                split2ratio=ratio)
+    return raw_data
+
+
+def merge_dataset(datasets, data_name):
+    data_split = [x for x in datasets[data_name]]
+    raw_data = {}
+    for data_type in data_split:
+        raw_data[data_type] = []
+        for dataname, dataset in datasets.items():
+            print(f"merge {dataname}...")
+            raw_data[data_type] += dataset[data_type]
+    return raw_data
+
+
+def int2onehot(index, output_dim=6, remove_zero=False):
+    """Return a one-hot list of length `output_dim` with position `index` set.
+
+    A negative `index` gives an all-zero list. With `remove_zero=True` only
+    `index == 0` sets a bit (position 0); any other index gives all zeros.
+    """
+    one_hot = [0] * output_dim
+    if (index == 0) if remove_zero else (index >= 0):
+        one_hot[index] = 1
+    return one_hot
+
+
+def parse_user_goal(raw_goal):
+    """flatten user goal structure"""
+    goal = raw_goal.domain_goals
+    user_goal = {}
+    for domain in goal:
+        # if domain not in UsrDa2Goal:
+        #     continue
+        for slot_type in goal[domain]:
+            if slot_type in ["fail_info", "fail_book", "booked"]:
+                continue  # TODO [fail_info] fix in the future
+            if slot_type in ["info", "book", "reqt"]:
+                for slot in goal[domain][slot_type]:
+                    slot_name = f"{domain}-{slot}"
+                    user_goal[slot_name] = goal[domain][slot_type][slot]
+
+    return user_goal
+
+
+def parse_dialogue_act(dialogue_act):
+    """ transfer action from dict to list """
+    actions = []
+    for action_type in dialogue_act:
+        for act in dialogue_act[action_type]:
+            domain = act["domain"]
+            if "value" in act:
+                actions.append(
+                    [act["intent"], domain, act["slot"], act["value"]])
+            else:
+                if act["intent"] == "request":
+                    actions.append(
+                        [act["intent"], domain, act["slot"], "?"])
+                else:
+                    slot = act.get("slot", "none")
+                    value = act.get("value", "none")
+                    actions.append(
+                        [act["intent"], domain, slot, value])
+
+    return actions
+
+
+def metadata2state(metadata):
+    """
+    parse metadata in the data set or dst
+    """
+    slot_value = {}
+
+    for domain in metadata:
+        for slot in metadata[domain]:
+            slot_name = f"{domain}-{slot}"
+            value = metadata[domain][slot]
+            if not value or value == NOT_MENTIONED:
+                value = "none"
+            slot_value[slot_name] = value
+
+    return slot_value
+
+
+def get_booking_domain(slot, value, all_values, domain_list):
+    """ 
+    find the domain for domain booking, excluding slot "ref"
+    """
+    found = ""
+    if not slot:
+        return found
+    slot = slot.lower()
+    value = value.lower()
+    for domain in domain_list:
+        if slot in all_values["all_value"][domain] \
+                and value in all_values["all_value"][domain][slot]:
+            found = domain
+    return found
+
+
+def update_config_file(file_name, attribute, value):
+    with open(file_name, 'r') as config_file:
+        config = json.load(config_file)
+
+    config[attribute] = value
+    print(config)
+    with open(file_name, 'w') as config_file:
+        json.dump(config, config_file)
+    print(f"update {attribute} = {value}")
+
+
+def create_goal(dialog) -> list:
+    # a list of {'intent': ..., 'domain': ..., 'slot': ..., 'value': ...}
+    dicts = []
+    for turn in dialog['turns']:
+        # print(turn['speaker'])
+        # assert (i % 2 == 0) == (turn['speaker'] == 'user')
+        # if i % 2 == 0:
+        if turn['speaker'] == 'user':
+            dicts += turn['dialogue_acts']['categorical']
+            dicts += turn['dialogue_acts']['binary']
+            dicts += turn['dialogue_acts']['non-categorical']
+    tuples = []
+    for d in dicts:
+        if "value" not in d:
+            if d['intent'] == "request":
+                value = "?"
+            else:
+                value = "none"
+        else:
+            value = d["value"]
+        slot = d.get("slot", "none")
+        domain = d['domain']  # .split('_')[0]
+        tuples.append(
+            (d['domain'], d['intent'], slot, value)
+        )
+
+    user_goal = []  # a list of (domain, intent, slot, value)
+    for domain, intent, slot, value in tuples:
+        # if slot == "intent":
+        #     continue
+        if not slot:
+            continue
+        if intent == "inform" and value == "":
+            continue
+        # if intent == "request" and value != "":
+        #     intent = "inform"
+        user_goal.append((domain, intent, slot, value))
+    user_goal = unique_list(user_goal)
+    inform_slots = {(domain, slot) for (domain, intent, slot,
+                                        value) in user_goal if intent == "inform"}
+    user_goal = [(domain, intent, slot, value) for (domain, intent, slot, value)
+                 in user_goal if not (intent == "request" and (domain, slot) in inform_slots)]
+    return user_goal
+
+
+def unique_list(list_):
+    r = []
+    for x in list_:
+        if x not in r:
+            r.append(x)
+    return r
+
+
+def split_slot_name(slot_name):
+    """Split "<domain>-<slot>" at the first '-' into (domain, slot).
+
+    Later '-' stay in the slot; without any '-' the slot is ''."""
+    domain, _, slot = slot_name.partition('-')
+    return domain, slot
+
+
+# copy from data.unified_datasets.multiwoz21
+slot_name_map = {
+    'addr': "address",
+    'post': "postcode",
+    'pricerange': "price range",
+    'arrive': "arrive by",
+    'arriveby': "arrive by",
+    'leave': "leave at",
+    'leaveat': "leave at",
+    'depart': "departure",
+    'dest': "destination",
+    'fee': "entrance fee",
+    'open': 'open hours',
+    'car': "type",
+    'car type': "type",
+    'ticket': 'price',
+    'trainid': 'train id',
+    'id': 'train id',
+    'people': 'book people',
+    'stay': 'book stay',
+    'none': '',
+    'attraction': {
+        'price': 'entrance fee'
+    },
+    'hospital': {},
+    'hotel': {
+        'day': 'book day', 'price': "price range"
+    },
+    'restaurant': {
+        'day': 'book day', 'time': 'book time', 'price': "price range"
+    },
+    'taxi': {},
+    'train': {
+        'day': 'day', 'time': "duration"
+    },
+    'police': {},
+    'booking': {}
+}
+
+if __name__ == "__main__":
+    print(split_slot_name("restaurant-search-location"))
+    print(split_slot_name("sports-day.match"))
diff --git a/convlab/policy/vector/README.md b/convlab/policy/vector/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..107ba733f2075204ed47c098a046d2f7805fa7b3
--- /dev/null
+++ b/convlab/policy/vector/README.md
@@ -0,0 +1,18 @@
+# Vectoriser
+
+The vectoriser is a module used by the policy network and has several functionalities
+
+1. it translates the semantic dialogue act into a vector representation usable for the policy network
+2. it translates the policy network output back into a lexicalized semantic act
+3. it creates an action masking that the policy can use to forbid illegal actions
+
+There is a **vector_base** class that has many functionalities already implemented. All other vector classes are inherited from the base class.
+
+If you build a new vectoriser, you need at least the following method:
+
+    
+    def state_vectorize(self, state):
+        # translates the semantic dialogue state into vector representation
+        # will be used by the policy module
+    
+See the implemented vector classes for examples.
\ No newline at end of file
diff --git a/convlab/policy/vector/dataset.py b/convlab/policy/vector/dataset.py
index 0aa1b7ad879f2d814cece3a98bb457b71ad99033..5b233e6659abc18da69a3efe6bb2d52185aa30fd 100755
--- a/convlab/policy/vector/dataset.py
+++ b/convlab/policy/vector/dataset.py
@@ -18,6 +18,26 @@ class ActDataset(data.Dataset):
         return self.num_total
 
 
+class ActDatasetKG(data.Dataset):
+    def __init__(self, action_batch, a_masks, current_domain_mask_batch, non_current_domain_mask_batch):
+        self.action_batch = action_batch
+        self.a_masks = a_masks
+        self.current_domain_mask_batch = current_domain_mask_batch
+        self.non_current_domain_mask_batch = non_current_domain_mask_batch
+        self.num_total = len(action_batch)
+
+    def __getitem__(self, index):
+        action = self.action_batch[index]
+        action_mask = self.a_masks[index]
+        current_domain_mask = self.current_domain_mask_batch[index]
+        non_current_domain_mask = self.non_current_domain_mask_batch[index]
+
+        return action, action_mask, current_domain_mask, non_current_domain_mask, index
+
+    def __len__(self):
+        return self.num_total
+
+
 class ActStateDataset(data.Dataset):
     def __init__(self, s_s, a_s, next_s):
         self.s_s = s_s
@@ -32,4 +52,4 @@ class ActStateDataset(data.Dataset):
         return s, a, next_s
     
     def __len__(self):
-        return self.num_total
\ No newline at end of file
+        return self.num_total
diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py
index 8b7d8ff0ddafed41efc91b249003ae55c525bc93..8f72144ce37a970fe4855a19d5bc8002fc2b4034 100644
--- a/convlab/policy/vector/vector_base.py
+++ b/convlab/policy/vector/vector_base.py
@@ -2,10 +2,11 @@
 import os
 import sys
 import numpy as np
+import logging
 
 from copy import deepcopy
 from convlab.policy.vec import Vector
-from convlab.util.custom_util import flatten_acts
+from convlab.util.custom_util import flatten_acts, timeout
 from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da
 from convlab.util import load_ontology, load_database, load_dataset
 
@@ -22,18 +23,20 @@ class VectorBase(Vector):
 
         super().__init__()
 
+        logging.info(f"Vectorizer: Data set used is {dataset_name}")
         self.set_seed(seed)
         self.ontology = load_ontology(dataset_name)
         try:
             # execute to make sure that the database exists or is downloaded otherwise
-            load_database(dataset_name)
+            if dataset_name == "multiwoz21":
+                load_database(dataset_name)
             # the following two lines are needed for pickling correctly during multi-processing
             exec(f'from data.unified_datasets.{dataset_name}.database import Database')
             self.db = eval('Database()')
             self.db_domains = self.db.domains
         except Exception as e:
             self.db = None
-            self.db_domains = None
+            self.db_domains = []
             print(f"VectorBase: {e}")
 
         self.dataset_name = dataset_name
@@ -272,6 +275,8 @@ class VectorBase(Vector):
         2. If there is an entity available, can not say NoOffer or NoBook
         '''
         mask_list = np.zeros(self.da_dim)
+        if number_entities_dict is None:
+            return mask_list
         for i in range(self.da_dim):
             action = self.vec2act[i]
             domain, intent, slot, value = action.split('-')
diff --git a/convlab/policy/vector/vector_binary.py b/convlab/policy/vector/vector_binary.py
index 9efde57562c30d2f41aecf2331d23f4669f4ae14..e780dc645043f4775b208479abf022dccce649a5 100755
--- a/convlab/policy/vector/vector_binary.py
+++ b/convlab/policy/vector/vector_binary.py
@@ -8,7 +8,7 @@ from .vector_base import VectorBase
 class VectorBinary(VectorBase):
 
     def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True,
-                 seed=0):
+                 seed=0, **kwargs):
 
         super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed)
 
@@ -65,7 +65,7 @@ class VectorBinary(VectorBase):
         return state_vec, mask
 
     def get_mask(self, domain_active_dict, number_entities_dict):
-        domain_mask = self.compute_domain_mask(domain_active_dict)
+        #domain_mask = self.compute_domain_mask(domain_active_dict)
         entity_mask = self.compute_entity_mask(number_entities_dict)
         general_mask = self.compute_general_mask()
         mask = entity_mask + general_mask
diff --git a/convlab/policy/vector/vector_multiwoz_uncertainty.py b/convlab/policy/vector/vector_multiwoz_uncertainty.py
deleted file mode 100644
index 6a0850f4ce7c791f7da99c9c8b9e632eac543ba2..0000000000000000000000000000000000000000
--- a/convlab/policy/vector/vector_multiwoz_uncertainty.py
+++ /dev/null
@@ -1,238 +0,0 @@
-# -*- coding: utf-8 -*-
-import sys
-import os
-import numpy as np
-import logging
-from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da
-from convlab.util.multiwoz.state import default_state
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
-from .vector_binary import VectorBinary as VectorBase
-
-DEFAULT_INTENT_FILEPATH = os.path.join(
-    os.path.dirname(os.path.dirname(os.path.dirname(
-        os.path.dirname(os.path.abspath(__file__))))),
-    'data/multiwoz/trackable_intent.json'
-)
-
-
-SLOT_MAP = {'taxi_types': 'car type'}
-
-
-class MultiWozVector(VectorBase):
-
-    def __init__(self, voc_file=None, voc_opp_file=None, character='sys',
-                 intent_file=DEFAULT_INTENT_FILEPATH,
-                 use_confidence_scores=False,
-                 use_entropy=False,
-                 use_mutual_info=False,
-                 use_masking=False,
-                 manually_add_entity_names=False,
-                 seed=0,
-                 shrink=False):
-
-        self.use_confidence_scores = use_confidence_scores
-        self.use_entropy = use_entropy
-        self.use_mutual_info = use_mutual_info
-        self.thresholds = None
-
-        super().__init__(voc_file, voc_opp_file, character, intent_file, use_masking, manually_add_entity_names, seed)
-
-    def get_state_dim(self):
-        self.belief_state_dim = 0
-        for domain in self.belief_domains:
-            for slot in default_state()['belief_state'][domain.lower()]['semi']:
-                # Dim 1 - indicator/confidence score
-                # Dim 2 - Entropy (Total uncertainty) / Mutual information (knowledge unc)
-                slot_dim = 1 if not self.use_entropy else 2
-                slot_dim += 1 if self.use_mutual_info else 0
-                self.belief_state_dim += slot_dim
-
-        self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \
-            len(self.db_domains) + 6 * len(self.db_domains) + 1
-
-    def dbquery_domain(self, domain):
-        """
-        query entities of specified domain
-        Args:
-            domain string:
-                domain to query
-        Returns:
-            entities list:
-                list of entities of the specified domain
-        """
-        # Get all user constraints
-        constraint = self.state[domain.lower()]['semi']
-        constraint = {k: i for k, i in constraint.items() if i and i not in ['dontcare', "do n't care", "do not care"]}
-
-        # Remove constraints for which the uncertainty is high
-        if self.confidence_scores is not None and self.use_confidence_scores and self.thresholds != None:
-            # Collect threshold values for each domain-slot pair
-            thres = self.thresholds.get(domain.lower(), {})
-            thres = {k: thres.get(k, 0.05) for k in constraint}
-            # Get confidence scores for each constraint
-            probs = self.confidence_scores.get(domain.lower(), {})
-            probs = {k: probs.get(k, {}).get('inform', 1.0)
-                     for k in constraint}
-
-            # Filter out constraints for which confidence is lower than threshold
-            constraint = {k: i for k, i in constraint.items()
-                          if probs[k] >= thres[k]}
-
-        return self.db.query(domain.lower(), constraint.items())
-
-    # Add thresholds for db_queries
-    def setup_uncertain_query(self, thresholds):
-        self.use_confidence_scores = True
-        self.thresholds = thresholds
-        logging.info('DB Search uncertainty activated.')
-
-    def vectorize_user_act_confidence_scores(self, state, opp_action):
-        """Return confidence scores for the user actions"""
-        opp_act_vec = np.zeros(self.da_opp_dim)
-        for da in self.opp2vec:
-            domain, intent, slot, value = da.split('-')
-            if domain.lower() in state['belief_state_probs']:
-                # Map slot name to match user actions
-                slot = REF_SYS_DA[domain].get(
-                    slot, slot) if domain in REF_SYS_DA else slot
-                slot = slot if slot else 'none'
-                slot = SLOT_MAP.get(slot, slot)
-                domain = domain.lower()
-
-                if slot in state['belief_state_probs'][domain]:
-                    prob = state['belief_state_probs'][domain][slot]
-                elif slot.lower() in state['belief_state_probs'][domain]:
-                    prob = state['belief_state_probs'][domain][slot.lower()]
-                else:
-                    prob = {}
-
-                intent = intent.lower()
-                if intent in prob:
-                    prob = float(prob[intent])
-                elif da in opp_action:
-                    prob = 1.0
-                else:
-                    prob = 0.0
-            elif da in opp_action:
-                prob = 1.0
-            else:
-                prob = 0.0
-            opp_act_vec[self.opp2vec[da]] = prob
-
-        return opp_act_vec
-
-    def state_vectorize(self, state):
-        """vectorize a state
-
-        Args:
-            state (dict):
-                Dialog state
-            action (tuple):
-                Dialog act
-        Returns:
-            state_vec (np.array):
-                Dialog state vector
-        """
-        self.state = state['belief_state']
-        self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None
-        domain_active_dict = {}
-        for domain in self.belief_domains:
-            domain_active_dict[domain] = False
-
-        # when character is sys, to help query database when da is booking-book
-        # update current domain according to user action
-        if self.character == 'sys':
-            action = state['user_action']
-            for intent, domain, slot, value in action:
-                domain_active_dict[domain] = True
-
-        action = state['user_action'] if self.character == 'sys' else state['system_action']
-        opp_action = delexicalize_da(action, self.requestable)
-        opp_action = flat_da(opp_action)
-        if 'belief_state_probs' in state and self.use_confidence_scores:
-            opp_act_vec = self.vectorize_user_act_confidence_scores(
-                state, opp_action)
-        else:
-            opp_act_vec = np.zeros(self.da_opp_dim)
-            for da in opp_action:
-                if da in self.opp2vec:
-                    prob = 1.0
-                    opp_act_vec[self.opp2vec[da]] = prob
-
-        action = state['system_action'] if self.character == 'sys' else state['user_action']
-        action = delexicalize_da(action, self.requestable)
-        action = flat_da(action)
-        last_act_vec = np.zeros(self.da_dim)
-        for da in action:
-            if da in self.act2vec:
-                last_act_vec[self.act2vec[da]] = 1.
-
-        belief_state = np.zeros(self.belief_state_dim)
-        i = 0
-        for domain in self.belief_domains:
-            if self.use_confidence_scores and 'belief_state_probs' in state:
-                for slot in state['belief_state'][domain.lower()]['semi']:
-                    if slot in state['belief_state_probs'][domain.lower()]:
-                        prob = state['belief_state_probs'][domain.lower()
-                                                           ][slot]
-                        prob = prob['inform'] if 'inform' in prob else None
-                    if prob:
-                        belief_state[i] = float(prob)
-                    i += 1
-            else:
-                for slot, value in state['belief_state'][domain.lower()]['semi'].items():
-                    if value and value != 'not mentioned':
-                        belief_state[i] = 1.
-                    i += 1
-            if 'active_domains' in state:
-                domain_active = state['active_domains'][domain.lower()]
-                domain_active_dict[domain] = domain_active
-            else:
-                if [slot for slot, value in state['belief_state'][domain.lower()]['semi'].items() if value]:
-                    domain_active_dict[domain] = True
-
-        # Add knowledge and/or total uncertainty to the belief state
-        if self.use_entropy and 'entropy' in state:
-            for domain in self.belief_domains:
-                for slot in state['belief_state'][domain.lower()]['semi']:
-                    if slot in state['entropy'][domain.lower()]:
-                        belief_state[i] = float(
-                            state['entropy'][domain.lower()][slot])
-                    i += 1
-
-        if self.use_mutual_info and 'mutual_information' in state:
-            for domain in self.belief_domains:
-                for slot in state['belief_state'][domain.lower()]['semi']:
-                    if slot in state['mutual_information'][domain.lower()]:
-                        belief_state[i] = float(
-                            state['mutual_information'][domain.lower()][slot])
-                    i += 1
-
-        book = np.zeros(len(self.db_domains))
-        for i, domain in enumerate(self.db_domains):
-            if state['belief_state'][domain.lower()]['book']['booked']:
-                book[i] = 1.
-
-        degree, number_entities_dict = self.pointer()
-
-        final = 1. if state['terminated'] else 0.
-
-        state_vec = np.r_[opp_act_vec, last_act_vec,
-                          belief_state, book, degree, final]
-        assert len(state_vec) == self.state_dim
-
-        if self.use_mask is not None:
-            # None covers the case for policies that don't use masking at all, so do not expect an output "state_vec, mask"
-            if self.use_mask:
-                domain_mask = self.compute_domain_mask(domain_active_dict)
-                entity_mask = self.compute_entity_mask(number_entities_dict)
-                general_mask = self.compute_general_mask()
-                mask = domain_mask + entity_mask + general_mask
-                for i in range(self.da_dim):
-                    mask[i] = -int(bool(mask[i])) * sys.maxsize
-            else:
-                mask = np.zeros(self.da_dim)
-
-            return state_vec, mask
-        else:
-            return state_vec
diff --git a/convlab/policy/vector/vector_nodes.py b/convlab/policy/vector/vector_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e073669effc518cee4efd1f03d25bbd501b65af
--- /dev/null
+++ b/convlab/policy/vector/vector_nodes.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+import sys
+import numpy as np
+import logging
+
+from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da
+from .vector_base import VectorBase
+
+
+class VectorNodes(VectorBase):
+
+    def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True,
+                 seed=0, filter_state=True):
+
+        super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed)
+        self.filter_state = filter_state
+        logging.info(f"We filter state by active domains: {self.filter_state}")
+
+    def get_state_dim(self):
+        self.belief_state_dim = 0
+
+        for domain in self.ontology['state']:
+            for slot in self.ontology['state'][domain]:
+                self.belief_state_dim += 1
+
+        self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \
+            len(self.db_domains) + 6 * len(self.db_domains) + 1
+
+    def init_kg_graph(self):
+        self.kg_info = []
+
+    def add_graph_node(self, domain, node_type, description, value):
+
+        node = {"domain": domain, "node_type": node_type, "description": description, "value": value}
+        self.kg_info.append(node)
+
+    def state_vectorize(self, state):
+        """vectorize a state
+
+        Args:
+            state (dict):
+                Dialog state
+            action (tuple):
+                Dialog act
+        Returns:
+            state_vec (np.array):
+                Dialog state vector
+        """
+        self.state = state['belief_state']
+        domain_active_dict = self.init_domain_active_dict()
+        self.init_kg_graph()
+
+        # when character is sys, to help query database when da is booking-book
+        # update current domain according to user action
+        if self.character == 'sys':
+            action = state['user_action']
+            for intent, domain, slot, value in action:
+                domain_active_dict[domain] = True
+
+        self.get_user_act_feature(state)
+        self.get_sys_act_feature(state)
+        domain_active_dict = self.get_user_goal_feature(state, domain_active_dict)
+        self.get_general_features(state, domain_active_dict)
+
+        if self.db is not None:
+            number_entities_dict = self.get_db_features()
+        else:
+            number_entities_dict = None
+
+        if self.filter_state:
+            self.kg_info = self.filter_inactive_domains(domain_active_dict)
+
+        if self.use_mask:
+            mask = self.get_mask(domain_active_dict, number_entities_dict)
+            for i in range(self.da_dim):
+                mask[i] = -int(bool(mask[i])) * sys.maxsize
+        else:
+            mask = np.zeros(self.da_dim)
+
+        return np.zeros(1), mask
+
+    def get_mask(self, domain_active_dict, number_entities_dict):
+        #domain_mask = self.compute_domain_mask(domain_active_dict)
+        entity_mask = self.compute_entity_mask(number_entities_dict)
+        general_mask = self.compute_general_mask()
+        mask = entity_mask + general_mask
+        return mask
+
+    def get_db_features(self):
+
+        degree, number_entities_dict = self.pointer()
+        feature_type = 'db'
+        for domain, num_entities in number_entities_dict.items():
+            description = f"db-{domain}-entities".lower()
+            # self.add_graph_node(domain, feature_type, description, int(num_entities > 0))
+            self.add_graph_node(domain, feature_type, description, min(num_entities, 5) / 5)
+        return number_entities_dict
+
+    def get_user_goal_feature(self, state, domain_active_dict):
+
+        feature_type = 'user goal'
+        for domain in self.belief_domains:
+            # the if case is needed because SGD only saves the dialogue state info for active domains
+            if domain in state['belief_state']:
+                for slot, value in state['belief_state'][domain].items():
+                    description = f"user goal-{domain}-{slot}".lower()
+                    value = 1.0 if (value and value != "not mentioned") else 0.0
+                    self.add_graph_node(domain, feature_type, description, value)
+
+                if [slot for slot, value in state['belief_state'][domain].items() if value]:
+                    domain_active_dict[domain] = True
+        return domain_active_dict
+
+    def get_sys_act_feature(self, state):
+
+        feature_type = 'last system act'
+        action = state['system_action'] if self.character == 'sys' else state['user_action']
+        action = delexicalize_da(action, self.requestable)
+        action = flat_da(action)
+        for da in action:
+            if da in self.act2vec:
+                domain = da.split('-')[0]
+                description = "system-" + da
+                value = 1.0
+                self.add_graph_node(domain, feature_type, description.lower(), value)
+
+    def get_user_act_feature(self, state):
+        # user-act feature
+        feature_type = 'user act'
+        action = state['user_action'] if self.character == 'sys' else state['system_action']
+        opp_action = delexicalize_da(action, self.requestable)
+        opp_action = flat_da(opp_action)
+
+        for da in opp_action:
+            if da in self.opp2vec:
+                domain = da.split('-')[0]
+                description = "user-" + da
+                value = 1.0
+                self.add_graph_node(domain, feature_type, description.lower(), value)
+
+    def get_general_features(self, state, domain_active_dict):
+
+        feature_type = 'general'
+        if 'booked' in state:
+            for i, domain in enumerate(self.db_domains):
+                if domain in state['booked']:
+                    description = f"general-{domain}-booked".lower()
+                    value = 1.0 if state['booked'][domain] else 0.0
+                    self.add_graph_node(domain, feature_type, description, value)
+
+        for domain in self.domains:
+            if domain == 'general':
+                continue
+            value = 1.0 if domain_active_dict[domain] else 0
+            description = f"general-{domain}".lower()
+            self.add_graph_node(domain, feature_type, description, value)
+
+    def filter_inactive_domains(self, domain_active_dict):
+
+        kg_filtered = []
+        for node in self.kg_info:
+            domain = node['domain']
+            if domain in domain_active_dict:
+                if domain_active_dict[domain]:
+                    kg_filtered.append(node)
+            else:
+                kg_filtered.append(node)
+
+        return kg_filtered
+
diff --git a/convlab/policy/vector/vector_uncertainty.py b/convlab/policy/vector/vector_uncertainty.py
new file mode 100644
index 0000000000000000000000000000000000000000..7da05449dbe247cacf01bac4970ee669cd670c44
--- /dev/null
+++ b/convlab/policy/vector/vector_uncertainty.py
@@ -0,0 +1,166 @@
+# -*- coding: utf-8 -*-
+import sys
+import numpy as np
+import logging
+
+from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da
+from convlab.policy.vector.vector_binary import VectorBinary
+
+
+class VectorUncertainty(VectorBinary):
+    """Vectorise state and state uncertainty predictions"""
+
+    def __init__(self,
+                 dataset_name: str = 'multiwoz21',
+                 character: str = 'sys',
+                 use_masking: bool = False,
+                 manually_add_entity_names: bool = True,
+                 seed: str = 0,
+                 use_confidence_scores: bool = True,
+                 confidence_thresholds: dict = None,
+                 use_state_total_uncertainty: bool = False,
+                 use_state_knowledge_uncertainty: bool = False):
+        """
+        Args:
+            dataset_name: Name of environment dataset
+            character: Character of the agent (sys/usr)
+            use_masking: If true certain actions are masked during devectorisation
+            manually_add_entity_names: If true inform entity name actions are manually added
+            seed: Seed
+            use_confidence_scores: If true confidence scores are used in state vectorisation
+            confidence_thresholds: If true confidence thresholds are used in database querying
+            use_state_total_uncertainty: If true state entropy is added to the state vector
+            use_state_knowledge_uncertainty: If true state mutual information is added to the state vector
+        """
+
+        self.use_confidence_scores = use_confidence_scores
+        self.use_state_total_uncertainty = use_state_total_uncertainty
+        self.use_state_knowledge_uncertainty = use_state_knowledge_uncertainty
+        if confidence_thresholds is not None:
+            self.setup_uncertain_query(confidence_thresholds)
+
+        super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed)
+
+    def get_state_dim(self):
+        self.belief_state_dim = 0
+
+        for domain in self.ontology['state']:
+            for slot in self.ontology['state'][domain]:
+                # Dim 1 - indicator/confidence score
+                # Dim 2 - Entropy (Total uncertainty) / Mutual information (knowledge unc)
+                slot_dim = 1 if not self.use_state_total_uncertainty else 2
+                slot_dim += 1 if self.use_state_knowledge_uncertainty else 0
+                self.belief_state_dim += slot_dim
+
+        self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \
+            len(self.db_domains) + 6 * len(self.db_domains) + 1
+
+    # Add thresholds for db_queries
+    def setup_uncertain_query(self, confidence_thresholds):
+        self.use_confidence_scores = True
+        self.confidence_thresholds = confidence_thresholds
+        logging.info('DB Search uncertainty activated.')
+
+    def dbquery_domain(self, domain):
+        """
+        query entities of specified domain
+        Args:
+            domain string:
+                domain to query
+        Returns:
+            entities list:
+                list of entities of the specified domain
+        """
+        # Get all user constraints
+        constraints = {slot: value for slot, value in self.state[domain].items()
+                       if slot and value not in ['dontcare',
+                                                 "do n't care", "do not care"]} if domain in self.state else dict()
+
+        # Remove constraints for which the uncertainty is high
+        if self.confidence_scores is not None and self.use_confidence_scores and self.confidence_thresholds is not None:
+            # Collect threshold values for each domain-slot pair
+            threshold = self.confidence_thresholds.get(domain, dict())
+            threshold = {slot: threshold.get(slot, 0.05) for slot in constraints}
+            # Get confidence scores for each constraint
+            probs = self.confidence_scores.get(domain, dict())
+            probs = {slot: probs.get(slot, {}).get('inform', 1.0) for slot in constraints}
+
+            # Filter out constraints for which confidence is lower than threshold
+            constraints = {slot: value for slot, value in constraints.items() if probs[slot] >= threshold[slot]}
+
+        return self.db.query(domain, constraints.items(), topk=10)
+
+    def vectorize_user_act(self, state):
+        """Return confidence scores for the user actions"""
+        self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None
+        action = state['user_action'] if self.character == 'sys' else state['system_action']
+        opp_action = delexicalize_da(action, self.requestable)
+        opp_action = flat_da(opp_action)
+        opp_act_vec = np.zeros(self.da_opp_dim)
+        for da in opp_action:
+            if da in self.opp2vec:
+                if 'belief_state_probs' in state and self.use_confidence_scores:
+                    domain, intent, slot, value = da.split('-')
+                    if domain in state['belief_state_probs']:
+                        slot = slot if slot else 'none'
+                        if slot in state['belief_state_probs'][domain]:
+                            prob = state['belief_state_probs'][domain][slot]
+                        elif slot.lower() in state['belief_state_probs'][domain]:
+                            prob = state['belief_state_probs'][domain][slot.lower()]
+                        else:
+                            prob = dict()
+
+                        if intent in prob:
+                            prob = float(prob[intent])
+                        else:
+                            prob = 1.0
+                    else:
+                        prob = 1.0
+                else:
+                    prob = 1.0
+                opp_act_vec[self.opp2vec[da]] = prob
+
+        return opp_act_vec
+
+    def vectorize_belief_state(self, state, domain_active_dict):
+        belief_state = np.zeros(self.belief_state_dim)
+        i = 0
+        for domain in self.belief_domains:
+            if self.use_confidence_scores and 'belief_state_probs' in state:
+                for slot in state['belief_state'][domain]:
+                    prob = None
+                    if slot in state['belief_state_probs'][domain]:
+                        prob = state['belief_state_probs'][domain][slot]
+                        prob = prob['inform'] if 'inform' in prob else None
+                    if prob:
+                        belief_state[i] = float(prob)
+                    i += 1
+            else:
+                for slot, value in state['belief_state'][domain].items():
+                    if value and value != 'not mentioned':
+                        belief_state[i] = 1.
+                    i += 1
+
+            if 'active_domains' in state:
+                domain_active = state['active_domains'][domain]
+                domain_active_dict[domain] = domain_active
+            else:
+                if [slot for slot, value in state['belief_state'][domain].items() if value]:
+                    domain_active_dict[domain] = True
+
+        # Add knowledge and/or total uncertainty to the belief state
+        if self.use_state_total_uncertainty and 'entropy' in state:
+            for domain in self.belief_domains:
+                for slot in state['belief_state'][domain]:
+                    if slot in state['entropy'][domain]:
+                        belief_state[i] = float(state['entropy'][domain][slot])
+                    i += 1
+
+        if self.use_state_knowledge_uncertainty and 'mutual_information' in state:
+            for domain in self.belief_domains:
+                for slot in state['belief_state'][domain]:
+                    if slot in state['mutual_information'][domain]:
+                        belief_state[i] = float(state['mutual_information'][domain][slot])
+                    i += 1
+
+        return belief_state, domain_active_dict
diff --git a/convlab/policy/vtrace_DPT/README.md b/convlab/policy/vtrace_DPT/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..002a8a050cc8bf573761a1b5ba2276d844a6db7d
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/README.md
@@ -0,0 +1,105 @@
+# Dynamic Dialogue Policy Transformer (DDPT)
+
+The dynamic dialogue policy transformer (Geishauser et al. 2022) is a model built for continual reinforcement learning. It uses a pre-trained RoBERTa language model to construct embeddings for each piece of state information and for each domain, slot and value in the action set. As a consequence, it can be used with different ontologies and is able to deal with new state information as well as new actions. The backbone architecture is a transformer encoder-decoder.
+
+It uses the CLEAR algorithm (Rolnick et al. 2019) for continual reinforcement learning, which builds on top of VTRACE (Espeholt et al. 2018). The current folder only supports training in a stationary environment and no continual learning; in this setting VTRACE is used as the algorithm.
+
+## Supervised pre-training
+
+If you want to pre-train the model on a dataset, use the command
+
+```sh
+$ python supervised/train_supervised.py --dataset_name=DATASET_NAME --seed=SEED --model_path=""
+```
+
+The first time you run that command, it will take longer as the dataset needs to be pre-processed.
+
+This will create a corresponding experiments folder under supervised/experiments, where the model is saved in /save.
+
+You can specify the dataset that you would like to use, e.g. "multiwoz21" or "sgd". You can also specify a model_path if you already have a pre-trained model, for instance when you first train on SGD before you fine-tune on multiwoz21 data.
+
+You can specify hyperparameters such as epoch, supervised_lr and data_percentage (how much of the data you want to use) in the config.json file.
+
+We provide several models trained with supervised learning on Hugging Face to reproduce the results:
+
+- pre-trained on SGD: https://huggingface.co/ConvLab/ddpt-policy-sgd
+- pre-trained on 1% multiwoz21: https://huggingface.co/ConvLab/ddpt-policy-0.01multiwoz21
+- pre-trained on SGD and afterwards on 1% multiwoz21: https://huggingface.co/ConvLab/ddpt-policy-sgd_0.01multiwoz21
+
+## RL training
+
+Starting RL training is as easy as executing
+
+```sh
+$ python train.py --path=your_environment_config --seed=SEED
+```
+
+One example of an environment config is **semantic_level_config.json**, where parameters for the training are specified, for instance:
+
+- load_path: path to a pre-trained model used to initialise the model; omit the ending .pol.mdl
+- process_num: the number of processes to use during evaluation to speed it up
+- num_eval_dialogues: how many evaluation dialogues should be used
+- eval_frequency: after how many training dialogues an evaluation should be performed
+- total_dialogues: how many training dialogues should be done in total
+- new_dialogues: how many new dialogues should be collected before a policy update
+
+Moreover, you can specify the full dialogue pipeline here, such as the user policy, NLU for system and user, etc.
+
+Parameters that are tied to the RL algorithm and the model architecture can be changed in config.json.
+
+
+## Evaluation
+
+For creating evaluation plots and running evaluation dialogues, please have a look at the README in the policy folder.
+
+## References
+
+```
+@inproceedings{geishauser-etal-2022-dynamic,
+    title = "Dynamic Dialogue Policy for Continual Reinforcement Learning",
+    author = "Geishauser, Christian  and
+      van Niekerk, Carel  and
+      Lin, Hsien-chin  and
+      Lubis, Nurul  and
+      Heck, Michael  and
+      Feng, Shutong  and
+      Ga{\v{s}}i{\'c}, Milica",
+    booktitle = "Proceedings of the 29th International Conference on Computational Linguistics",
+    month = oct,
+    year = "2022",
+    address = "Gyeongju, Republic of Korea",
+    publisher = "International Committee on Computational Linguistics",
+    url = "https://aclanthology.org/2022.coling-1.21",
+    pages = "266--284",
+    abstract = "Continual learning is one of the key components of human learning and a necessary requirement of artificial intelligence. As dialogue can potentially span infinitely many topics and tasks, a task-oriented dialogue system must have the capability to continually learn, dynamically adapting to new challenges while preserving the knowledge it already acquired. Despite the importance, continual reinforcement learning of the dialogue policy has remained largely unaddressed. The lack of a framework with training protocols, baseline models and suitable metrics, has so far hindered research in this direction. In this work we fill precisely this gap, enabling research in dialogue policy optimisation to go from static to dynamic learning. We provide a continual learning algorithm, baseline architectures and metrics for assessing continual learning models. Moreover, we propose the dynamic dialogue policy transformer (DDPT), a novel dynamic architecture that can integrate new knowledge seamlessly, is capable of handling large state spaces and obtains significant zero-shot performance when being exposed to unseen domains, without any growth in network parameter size. We validate the strengths of DDPT in simulation with two user simulators as well as with humans.",
+}
+
+@inproceedings{NEURIPS2019_fa7cdfad,
+ author = {Rolnick, David and Ahuja, Arun and Schwarz, Jonathan and Lillicrap, Timothy and Wayne, Gregory},
+ booktitle = {Advances in Neural Information Processing Systems},
+ editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
+ pages = {},
+ publisher = {Curran Associates, Inc.},
+ title = {Experience Replay for Continual Learning},
+ url = {https://proceedings.neurips.cc/paper/2019/file/fa7cdfad1a5aaf8370ebeda47a1ff1c3-Paper.pdf},
+ volume = {32},
+ year = {2019}
+}
+
+@InProceedings{pmlr-v80-espeholt18a,
+  title = 	 {{IMPALA}: Scalable Distributed Deep-{RL} with Importance Weighted Actor-Learner Architectures},
+  author =       {Espeholt, Lasse and Soyer, Hubert and Munos, Remi and Simonyan, Karen and Mnih, Vlad and Ward, Tom and Doron, Yotam and Firoiu, Vlad and Harley, Tim and Dunning, Iain and Legg, Shane and Kavukcuoglu, Koray},
+  booktitle = 	 {Proceedings of the 35th International Conference on Machine Learning},
+  pages = 	 {1407--1416},
+  year = 	 {2018},
+  editor = 	 {Dy, Jennifer and Krause, Andreas},
+  volume = 	 {80},
+  series = 	 {Proceedings of Machine Learning Research},
+  month = 	 {10--15 Jul},
+  publisher =    {PMLR},
+  pdf = 	 {http://proceedings.mlr.press/v80/espeholt18a/espeholt18a.pdf},
+  url = 	 {https://proceedings.mlr.press/v80/espeholt18a.html},
+  abstract = 	 {In this work we aim to solve a large collection of tasks using a single reinforcement learning agent with a single set of parameters. A key challenge is to handle the increased amount of data and extended training time. We have developed a new distributed agent IMPALA (Importance Weighted Actor-Learner Architecture) that not only uses resources more efficiently in single-machine training but also scales to thousands of machines without sacrificing data efficiency or resource utilisation. We achieve stable learning at high throughput by combining decoupled acting and learning with a novel off-policy correction method called V-trace. We demonstrate the effectiveness of IMPALA for multi-task reinforcement learning on DMLab-30 (a set of 30 tasks from the DeepMind Lab environment (Beattie et al., 2016)) and Atari57 (all available Atari games in Arcade Learning Environment (Bellemare et al., 2013a)). Our results show that IMPALA is able to achieve better performance than previous agents with less data, and crucially exhibits positive transfer between tasks as a result of its multi-task approach.}
+}
+
+```
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/__init__.py b/convlab/policy/vtrace_DPT/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13469e5bc59921487499878f90bc146bc0a43e4c
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/__init__.py
@@ -0,0 +1 @@
+from convlab.policy.vtrace_DPT.vtrace import VTRACE
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/config.json b/convlab/policy/vtrace_DPT/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..13362e74f11fd3705820028366b145c7d66ea4d9
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/config.json
@@ -0,0 +1,76 @@
+{
+	"batchsz": 64,
+	"epoch": 40,
+	"gamma": 0.99,
+	"policy_lr": 5e-06,
+	"supervised_lr": 1e-05,
+	"entropy_weight": 0.01,
+	"value_lr": 0.0001,
+	"save_dir": "save",
+	"log_dir": "log",
+	"save_per_epoch": 5000,
+	"hidden_size": 256,
+	"load": "save/best",
+	"logging_mode": "INFO",
+	"use_cer": true,
+	"memory_size": 5000,
+	"behaviour_cloning_weight": 0.1,
+	"supervised_weight": 0.0,
+	"online_offline_ratio": 0.20,
+	"smoothed_value_function": false,
+	"use_reservoir_sampling": false,
+	"seed": 0,
+	"lambda": 1,
+	"tau": 0.001,
+	"policy_freq": 1,
+	"print_per_batch": 400,
+	"c": 1.0,
+	"rho_bar": 1,
+	"max_length": 10,
+	"noisy_linear": false,
+	"dataset_name": "multiwoz21",
+	"data_percentage": 1.0,
+	"dialogue_order": 0,
+	"multiwoz_like": false,
+	"regularization_weight": 0.0,
+
+	"enc_input_dim": 128,
+	"enc_nhead": 2,
+	"enc_d_hid": 128,
+	"enc_nlayers": 4,
+	"enc_dropout": 0.1,
+
+	"dec_input_dim": 128,
+	"dec_nhead": 2,
+	"dec_d_hid": 128,
+	"dec_nlayers": 2,
+	"dec_dropout": 0.0,
+
+	"action_embedding_dim": 128,
+	"domain_embedding_dim": 64,
+	"value_embedding_dim": 12,
+	"node_embedding_dim": 128,
+	"roberta_path": "",
+	"node_attention": true,
+	"semantic_descriptions": true,
+	"freeze_roberta": true,
+	"use_pooled": false,
+	"mean": true,
+	"roberta_actions": true,
+	"independent_descriptions": true,
+	"random_matrix": false,
+	"distance_metric": false,
+
+	"verbose": false,
+	"ignore_features": [],
+	"domains_removed": ["hospital", "police", "train", "hotel", "attraction", "taxi"],
+	"only_active_values": false,
+	"permuted_data": false,
+	"need_weights": false,
+
+	"cls_dim": 128,
+	"independent": true,
+	"old_critic": false,
+	"pos_weight": 5,
+	"weight_decay": 0.00001
+}
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/create_descriptions.py b/convlab/policy/vtrace_DPT/create_descriptions.py
new file mode 100644
index 0000000000000000000000000000000000000000..5357fafafcf4972aa8eebaecf8c56e9fba79afe7
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/create_descriptions.py
@@ -0,0 +1,66 @@
+import os
+import json
+
+from convlab.policy.vector.vector_binary import VectorBinary
+from convlab.util import load_ontology, load_database
+from convlab.util.custom_util import timeout
+
+
+def create_description_dicts(name='multiwoz21'):
+    """Write the semantic feature descriptions of dataset ``name`` to JSON.
+
+    Keys cover user goal slots, database/booking entries, domains, system acts
+    and user acts; each maps to a space-separated description. The mapping is
+    dumped to ``descriptions/semantic_information_descriptions_{name}.json``
+    in this module's directory.
+    """
+    # NOTE(review): never read afterwards; presumably instantiating it makes
+    # sure the action_dicts vocabulary files opened below exist - confirm.
+    vector = VectorBinary(name)
+    ontology = load_ontology(name)
+    default_state = ontology['state']
+    domains = list(ontology['domains'].keys())
+
+    # A database is only loaded for multiwoz21; otherwise no db keys are added.
+    is_multiwoz = name == "multiwoz21"
+    db = load_database(name) if is_multiwoz else None
+    db_domains = db.domains if is_multiwoz else []
+
+    module_dir = os.path.dirname(os.path.abspath(__file__))
+    policy_dir = os.path.dirname(module_dir)
+    voc_file = os.path.join(policy_dir, f'vector/action_dicts/{name}_VectorBinary/sys_da_voc.txt')
+    voc_opp_file = os.path.join(policy_dir, f'vector/action_dicts/{name}_VectorBinary/user_da_voc.txt')
+    with open(voc_file) as sys_handle:
+        da_voc = sys_handle.read().splitlines()
+    with open(voc_opp_file) as usr_handle:
+        da_voc_opp = usr_handle.read().splitlines()
+
+    descriptions = {}
+    for raw_domain, slots in default_state.items():
+        dom = raw_domain.lower()
+        for slot in slots:
+            descriptions[f"user goal-{dom}-{slot.lower()}"] = f"user goal {dom} {slot}"
+
+    for dom in (entry.lower() for entry in db_domains or ()):
+        descriptions[f"db-{dom}-entities"] = f"data base {dom} number of entities"
+        descriptions[f"general-{dom}-booked"] = f"general {dom} booked"
+    for dom in (entry.lower() for entry in domains):
+        descriptions[f"general-{dom}"] = f"domain {dom}"
+
+    # Each vocabulary line must split into exactly domain-intent-slot-value.
+    for act in da_voc:
+        dom, intent, slot, value = act.split("-")
+        descriptions[f"system-{act.lower()}"] = f"last system act {dom.lower()} {intent} {slot} {value}"
+    for act in da_voc_opp:
+        lowered = act.lower()
+        dom, intent, slot, value = lowered.split("-")
+        descriptions[f"user-{lowered}"] = f"user act {dom} {intent} {slot} {value}"
+
+    out_dir = os.path.join(module_dir, "descriptions")
+    os.makedirs(out_dir, exist_ok=True)
+    with open(os.path.join(out_dir, f'semantic_information_descriptions_{name}.json'), "w") as out:
+        json.dump(descriptions, out)
+
+
+if __name__ == '__main__':  # run as a script: build the description file for the default dataset (multiwoz21)
+    create_description_dicts()
diff --git a/convlab/policy/vtrace_DPT/descriptions/semantic_information_descriptions_multiwoz21.json b/convlab/policy/vtrace_DPT/descriptions/semantic_information_descriptions_multiwoz21.json
new file mode 100644
index 0000000000000000000000000000000000000000..f0757293c3f05fbd7b77dbf356f3e7c4fc578371
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/descriptions/semantic_information_descriptions_multiwoz21.json
@@ -0,0 +1 @@
+{"user goal-attraction-type": "user goal attraction type", "user goal-attraction-name": "user goal attraction name", "user goal-attraction-area": "user goal attraction area", "user goal-hotel-name": "user goal hotel name", "user goal-hotel-area": "user goal hotel area", "user goal-hotel-parking": "user goal hotel parking", "user goal-hotel-price range": "user goal hotel price range", "user goal-hotel-stars": "user goal hotel stars", "user goal-hotel-internet": "user goal hotel internet", "user goal-hotel-type": "user goal hotel type", "user goal-hotel-book stay": "user goal hotel book stay", "user goal-hotel-book day": "user goal hotel book day", "user goal-hotel-book people": "user goal hotel book people", "user goal-restaurant-food": "user goal restaurant food", "user goal-restaurant-price range": "user goal restaurant price range", "user goal-restaurant-name": "user goal restaurant name", "user goal-restaurant-area": "user goal restaurant area", "user goal-restaurant-book time": "user goal restaurant book time", "user goal-restaurant-book day": "user goal restaurant book day", "user goal-restaurant-book people": "user goal restaurant book people", "user goal-taxi-leave at": "user goal taxi leave at", "user goal-taxi-destination": "user goal taxi destination", "user goal-taxi-departure": "user goal taxi departure", "user goal-taxi-arrive by": "user goal taxi arrive by", "user goal-train-leave at": "user goal train leave at", "user goal-train-destination": "user goal train destination", "user goal-train-day": "user goal train day", "user goal-train-arrive by": "user goal train arrive by", "user goal-train-departure": "user goal train departure", "user goal-train-book people": "user goal train book people", "user goal-hospital-department": "user goal hospital department", "db-restaurant-entities": "data base restaurant number of entities", "general-restaurant-booked": "general restaurant booked", "db-hotel-entities": "data base hotel number of entities", 
"general-hotel-booked": "general hotel booked", "db-attraction-entities": "data base attraction number of entities", "general-attraction-booked": "general attraction booked", "db-train-entities": "data base train number of entities", "general-train-booked": "general train booked", "db-hospital-entities": "data base hospital number of entities", "general-hospital-booked": "general hospital booked", "db-police-entities": "data base police number of entities", "general-police-booked": "general police booked", "general-attraction": "domain attraction", "general-hotel": "domain hotel", "general-taxi": "domain taxi", "general-restaurant": "domain restaurant", "general-train": "domain train", "general-police": "domain police", "general-hospital": "domain hospital", "general-general": "domain general", "system-attraction-inform-address-1": "last system act attraction inform address 1", "system-attraction-inform-address-2": "last system act attraction inform address 2", "system-attraction-inform-address-3": "last system act attraction inform address 3", "system-attraction-inform-area-1": "last system act attraction inform area 1", "system-attraction-inform-area-2": "last system act attraction inform area 2", "system-attraction-inform-area-3": "last system act attraction inform area 3", "system-attraction-inform-choice-1": "last system act attraction inform choice 1", "system-attraction-inform-choice-2": "last system act attraction inform choice 2", "system-attraction-inform-choice-3": "last system act attraction inform choice 3", "system-attraction-inform-entrance fee-1": "last system act attraction inform entrance fee 1", "system-attraction-inform-entrance fee-2": "last system act attraction inform entrance fee 2", "system-attraction-inform-name-1": "last system act attraction inform name 1", "system-attraction-inform-name-2": "last system act attraction inform name 2", "system-attraction-inform-name-3": "last system act attraction inform name 3", 
"system-attraction-inform-name-4": "last system act attraction inform name 4", "system-attraction-inform-phone-1": "last system act attraction inform phone 1", "system-attraction-inform-postcode-1": "last system act attraction inform postcode 1", "system-attraction-inform-type-1": "last system act attraction inform type 1", "system-attraction-inform-type-2": "last system act attraction inform type 2", "system-attraction-inform-type-3": "last system act attraction inform type 3", "system-attraction-inform-type-4": "last system act attraction inform type 4", "system-attraction-inform-type-5": "last system act attraction inform type 5", "system-attraction-nooffer-area-1": "last system act attraction nooffer area 1", "system-attraction-nooffer-none-none": "last system act attraction nooffer none none", "system-attraction-nooffer-type-1": "last system act attraction nooffer type 1", "system-attraction-recommend-address-1": "last system act attraction recommend address 1", "system-attraction-recommend-address-2": "last system act attraction recommend address 2", "system-attraction-recommend-area-1": "last system act attraction recommend area 1", "system-attraction-recommend-entrance fee-1": "last system act attraction recommend entrance fee 1", "system-attraction-recommend-name-1": "last system act attraction recommend name 1", "system-attraction-recommend-phone-1": "last system act attraction recommend phone 1", "system-attraction-recommend-postcode-1": "last system act attraction recommend postcode 1", "system-attraction-recommend-type-1": "last system act attraction recommend type 1", "system-attraction-request-area-?": "last system act attraction request area ?", "system-attraction-request-entrance fee-?": "last system act attraction request entrance fee ?", "system-attraction-request-name-?": "last system act attraction request name ?", "system-attraction-request-type-?": "last system act attraction request type ?", "system-attraction-select-none-none": "last system 
act attraction select none none", "system-attraction-select-type-1": "last system act attraction select type 1", "system-attraction-select-type-2": "last system act attraction select type 2", "system-attraction-select-type-3": "last system act attraction select type 3", "system-general-bye-none-none": "last system act general bye none none", "system-general-greet-none-none": "last system act general greet none none", "system-general-reqmore-none-none": "last system act general reqmore none none", "system-general-welcome-none-none": "last system act general welcome none none", "system-hospital-inform-address-1": "last system act hospital inform address 1", "system-hospital-inform-department-1": "last system act hospital inform department 1", "system-hospital-inform-phone-1": "last system act hospital inform phone 1", "system-hospital-inform-postcode-1": "last system act hospital inform postcode 1", "system-hospital-request-department-?": "last system act hospital request department ?", "system-hotel-book-none-none": "last system act hotel book none none", "system-hotel-inform-address-1": "last system act hotel inform address 1", "system-hotel-inform-address-2": "last system act hotel inform address 2", "system-hotel-inform-area-1": "last system act hotel inform area 1", "system-hotel-inform-area-2": "last system act hotel inform area 2", "system-hotel-inform-book day-1": "last system act hotel inform book day 1", "system-hotel-inform-book people-1": "last system act hotel inform book people 1", "system-hotel-inform-book stay-1": "last system act hotel inform book stay 1", "system-hotel-inform-choice-1": "last system act hotel inform choice 1", "system-hotel-inform-choice-2": "last system act hotel inform choice 2", "system-hotel-inform-choice-3": "last system act hotel inform choice 3", "system-hotel-inform-internet-1": "last system act hotel inform internet 1", "system-hotel-inform-name-1": "last system act hotel inform name 1", "system-hotel-inform-name-2": "last 
system act hotel inform name 2", "system-hotel-inform-name-3": "last system act hotel inform name 3", "system-hotel-inform-parking-1": "last system act hotel inform parking 1", "system-hotel-inform-phone-1": "last system act hotel inform phone 1", "system-hotel-inform-postcode-1": "last system act hotel inform postcode 1", "system-hotel-inform-price range-1": "last system act hotel inform price range 1", "system-hotel-inform-price range-2": "last system act hotel inform price range 2", "system-hotel-inform-ref-1": "last system act hotel inform ref 1", "system-hotel-inform-stars-1": "last system act hotel inform stars 1", "system-hotel-inform-stars-2": "last system act hotel inform stars 2", "system-hotel-inform-type-1": "last system act hotel inform type 1", "system-hotel-inform-type-2": "last system act hotel inform type 2", "system-hotel-nooffer-area-1": "last system act hotel nooffer area 1", "system-hotel-nooffer-none-none": "last system act hotel nooffer none none", "system-hotel-nooffer-price range-1": "last system act hotel nooffer price range 1", "system-hotel-nooffer-stars-1": "last system act hotel nooffer stars 1", "system-hotel-nooffer-type-1": "last system act hotel nooffer type 1", "system-hotel-offerbook-name-1": "last system act hotel offerbook name 1", "system-hotel-recommend-address-1": "last system act hotel recommend address 1", "system-hotel-recommend-area-1": "last system act hotel recommend area 1", "system-hotel-recommend-internet-1": "last system act hotel recommend internet 1", "system-hotel-recommend-name-1": "last system act hotel recommend name 1", "system-hotel-recommend-parking-1": "last system act hotel recommend parking 1", "system-hotel-recommend-price range-1": "last system act hotel recommend price range 1", "system-hotel-recommend-stars-1": "last system act hotel recommend stars 1", "system-hotel-recommend-type-1": "last system act hotel recommend type 1", "system-hotel-request-area-?": "last system act hotel request area ?", 
"system-hotel-request-book day-?": "last system act hotel request book day ?", "system-hotel-request-book people-?": "last system act hotel request book people ?", "system-hotel-request-book stay-?": "last system act hotel request book stay ?", "system-hotel-request-internet-?": "last system act hotel request internet ?", "system-hotel-request-name-?": "last system act hotel request name ?", "system-hotel-request-parking-?": "last system act hotel request parking ?", "system-hotel-request-price range-?": "last system act hotel request price range ?", "system-hotel-request-stars-?": "last system act hotel request stars ?", "system-hotel-request-type-?": "last system act hotel request type ?", "system-hotel-select-area-1": "last system act hotel select area 1", "system-hotel-select-area-2": "last system act hotel select area 2", "system-hotel-select-name-1": "last system act hotel select name 1", "system-hotel-select-none-none": "last system act hotel select none none", "system-hotel-select-price range-1": "last system act hotel select price range 1", "system-hotel-select-price range-2": "last system act hotel select price range 2", "system-hotel-select-stars-1": "last system act hotel select stars 1", "system-hotel-select-type-1": "last system act hotel select type 1", "system-hotel-select-type-2": "last system act hotel select type 2", "system-police-inform-address-1": "last system act police inform address 1", "system-police-inform-name-1": "last system act police inform name 1", "system-police-inform-phone-1": "last system act police inform phone 1", "system-police-inform-postcode-1": "last system act police inform postcode 1", "system-restaurant-book-none-none": "last system act restaurant book none none", "system-restaurant-inform-address-1": "last system act restaurant inform address 1", "system-restaurant-inform-address-2": "last system act restaurant inform address 2", "system-restaurant-inform-area-1": "last system act restaurant inform area 1", 
"system-restaurant-inform-area-2": "last system act restaurant inform area 2", "system-restaurant-inform-book day-1": "last system act restaurant inform book day 1", "system-restaurant-inform-book people-1": "last system act restaurant inform book people 1", "system-restaurant-inform-book time-1": "last system act restaurant inform book time 1", "system-restaurant-inform-choice-1": "last system act restaurant inform choice 1", "system-restaurant-inform-choice-2": "last system act restaurant inform choice 2", "system-restaurant-inform-choice-3": "last system act restaurant inform choice 3", "system-restaurant-inform-food-1": "last system act restaurant inform food 1", "system-restaurant-inform-food-2": "last system act restaurant inform food 2", "system-restaurant-inform-food-3": "last system act restaurant inform food 3", "system-restaurant-inform-food-4": "last system act restaurant inform food 4", "system-restaurant-inform-name-1": "last system act restaurant inform name 1", "system-restaurant-inform-name-2": "last system act restaurant inform name 2", "system-restaurant-inform-name-3": "last system act restaurant inform name 3", "system-restaurant-inform-name-4": "last system act restaurant inform name 4", "system-restaurant-inform-phone-1": "last system act restaurant inform phone 1", "system-restaurant-inform-postcode-1": "last system act restaurant inform postcode 1", "system-restaurant-inform-postcode-2": "last system act restaurant inform postcode 2", "system-restaurant-inform-price range-1": "last system act restaurant inform price range 1", "system-restaurant-inform-price range-2": "last system act restaurant inform price range 2", "system-restaurant-inform-ref-1": "last system act restaurant inform ref 1", "system-restaurant-nobook-book time-1": "last system act restaurant nobook book time 1", "system-restaurant-nooffer-area-1": "last system act restaurant nooffer area 1", "system-restaurant-nooffer-food-1": "last system act restaurant nooffer food 1", 
"system-restaurant-nooffer-none-none": "last system act restaurant nooffer none none", "system-restaurant-nooffer-price range-1": "last system act restaurant nooffer price range 1", "system-restaurant-offerbook-name-1": "last system act restaurant offerbook name 1", "system-restaurant-recommend-address-1": "last system act restaurant recommend address 1", "system-restaurant-recommend-area-1": "last system act restaurant recommend area 1", "system-restaurant-recommend-food-1": "last system act restaurant recommend food 1", "system-restaurant-recommend-name-1": "last system act restaurant recommend name 1", "system-restaurant-recommend-phone-1": "last system act restaurant recommend phone 1", "system-restaurant-recommend-postcode-1": "last system act restaurant recommend postcode 1", "system-restaurant-recommend-price range-1": "last system act restaurant recommend price range 1", "system-restaurant-request-area-?": "last system act restaurant request area ?", "system-restaurant-request-book day-?": "last system act restaurant request book day ?", "system-restaurant-request-book people-?": "last system act restaurant request book people ?", "system-restaurant-request-book time-?": "last system act restaurant request book time ?", "system-restaurant-request-food-?": "last system act restaurant request food ?", "system-restaurant-request-name-?": "last system act restaurant request name ?", "system-restaurant-request-price range-?": "last system act restaurant request price range ?", "system-restaurant-select-area-1": "last system act restaurant select area 1", "system-restaurant-select-area-2": "last system act restaurant select area 2", "system-restaurant-select-food-1": "last system act restaurant select food 1", "system-restaurant-select-food-2": "last system act restaurant select food 2", "system-restaurant-select-food-3": "last system act restaurant select food 3", "system-restaurant-select-name-1": "last system act restaurant select name 1", 
"system-restaurant-select-none-none": "last system act restaurant select none none", "system-restaurant-select-price range-1": "last system act restaurant select price range 1", "system-restaurant-select-price range-2": "last system act restaurant select price range 2", "system-taxi-book-none-none": "last system act taxi book none none", "system-taxi-inform-arrive by-1": "last system act taxi inform arrive by 1", "system-taxi-inform-departure-1": "last system act taxi inform departure 1", "system-taxi-inform-destination-1": "last system act taxi inform destination 1", "system-taxi-inform-leave at-1": "last system act taxi inform leave at 1", "system-taxi-inform-none-none": "last system act taxi inform none none", "system-taxi-inform-phone-1": "last system act taxi inform phone 1", "system-taxi-inform-type-1": "last system act taxi inform type 1", "system-taxi-request-arrive by-?": "last system act taxi request arrive by ?", "system-taxi-request-departure-?": "last system act taxi request departure ?", "system-taxi-request-destination-?": "last system act taxi request destination ?", "system-taxi-request-leave at-?": "last system act taxi request leave at ?", "system-train-book-none-none": "last system act train book none none", "system-train-inform-arrive by-1": "last system act train inform arrive by 1", "system-train-inform-arrive by-2": "last system act train inform arrive by 2", "system-train-inform-book people-1": "last system act train inform book people 1", "system-train-inform-choice-1": "last system act train inform choice 1", "system-train-inform-choice-2": "last system act train inform choice 2", "system-train-inform-day-1": "last system act train inform day 1", "system-train-inform-departure-1": "last system act train inform departure 1", "system-train-inform-destination-1": "last system act train inform destination 1", "system-train-inform-duration-1": "last system act train inform duration 1", "system-train-inform-leave at-1": "last system act train 
inform leave at 1", "system-train-inform-leave at-2": "last system act train inform leave at 2", "system-train-inform-leave at-3": "last system act train inform leave at 3", "system-train-inform-none-none": "last system act train inform none none", "system-train-inform-price-1": "last system act train inform price 1", "system-train-inform-ref-1": "last system act train inform ref 1", "system-train-inform-train id-1": "last system act train inform train id 1", "system-train-offerbook-arrive by-1": "last system act train offerbook arrive by 1", "system-train-offerbook-destination-1": "last system act train offerbook destination 1", "system-train-offerbook-leave at-1": "last system act train offerbook leave at 1", "system-train-offerbook-none-none": "last system act train offerbook none none", "system-train-offerbook-train id-1": "last system act train offerbook train id 1", "system-train-request-arrive by-?": "last system act train request arrive by ?", "system-train-request-book people-?": "last system act train request book people ?", "system-train-request-day-?": "last system act train request day ?", "system-train-request-departure-?": "last system act train request departure ?", "system-train-request-destination-?": "last system act train request destination ?", "system-train-request-leave at-?": "last system act train request leave at ?", "system-train-select-leave at-1": "last system act train select leave at 1", "system-train-select-none-none": "last system act train select none none", "user-attraction-inform-area-1": "user act attraction inform area 1", "user-attraction-inform-name-1": "user act attraction inform name 1", "user-attraction-inform-none-none": "user act attraction inform none none", "user-attraction-inform-type-1": "user act attraction inform type 1", "user-attraction-request-address-?": "user act attraction request address ?", "user-attraction-request-area-?": "user act attraction request area ?", "user-attraction-request-entrance fee-?": 
"user act attraction request entrance fee ?", "user-attraction-request-phone-?": "user act attraction request phone ?", "user-attraction-request-postcode-?": "user act attraction request postcode ?", "user-attraction-request-type-?": "user act attraction request type ?", "user-general-bye-none-none": "user act general bye none none", "user-general-greet-none-none": "user act general greet none none", "user-general-thank-none-none": "user act general thank none none", "user-hospital-inform-department-1": "user act hospital inform department 1", "user-hospital-inform-none-none": "user act hospital inform none none", "user-hospital-request-address-?": "user act hospital request address ?", "user-hospital-request-phone-?": "user act hospital request phone ?", "user-hospital-request-postcode-?": "user act hospital request postcode ?", "user-hotel-inform-area-1": "user act hotel inform area 1", "user-hotel-inform-book day-1": "user act hotel inform book day 1", "user-hotel-inform-book people-1": "user act hotel inform book people 1", "user-hotel-inform-book stay-1": "user act hotel inform book stay 1", "user-hotel-inform-internet-1": "user act hotel inform internet 1", "user-hotel-inform-name-1": "user act hotel inform name 1", "user-hotel-inform-none-none": "user act hotel inform none none", "user-hotel-inform-parking-1": "user act hotel inform parking 1", "user-hotel-inform-price range-1": "user act hotel inform price range 1", "user-hotel-inform-stars-1": "user act hotel inform stars 1", "user-hotel-inform-type-1": "user act hotel inform type 1", "user-hotel-request-address-?": "user act hotel request address ?", "user-hotel-request-area-?": "user act hotel request area ?", "user-hotel-request-internet-?": "user act hotel request internet ?", "user-hotel-request-parking-?": "user act hotel request parking ?", "user-hotel-request-phone-?": "user act hotel request phone ?", "user-hotel-request-postcode-?": "user act hotel request postcode ?", "user-hotel-request-price 
range-?": "user act hotel request price range ?", "user-hotel-request-ref-?": "user act hotel request ref ?", "user-hotel-request-stars-?": "user act hotel request stars ?", "user-hotel-request-type-?": "user act hotel request type ?", "user-police-inform-name-1": "user act police inform name 1", "user-police-inform-none-none": "user act police inform none none", "user-police-request-address-?": "user act police request address ?", "user-police-request-phone-?": "user act police request phone ?", "user-police-request-postcode-?": "user act police request postcode ?", "user-restaurant-inform-area-1": "user act restaurant inform area 1", "user-restaurant-inform-book day-1": "user act restaurant inform book day 1", "user-restaurant-inform-book people-1": "user act restaurant inform book people 1", "user-restaurant-inform-book time-1": "user act restaurant inform book time 1", "user-restaurant-inform-food-1": "user act restaurant inform food 1", "user-restaurant-inform-name-1": "user act restaurant inform name 1", "user-restaurant-inform-none-none": "user act restaurant inform none none", "user-restaurant-inform-price range-1": "user act restaurant inform price range 1", "user-restaurant-request-address-?": "user act restaurant request address ?", "user-restaurant-request-area-?": "user act restaurant request area ?", "user-restaurant-request-food-?": "user act restaurant request food ?", "user-restaurant-request-phone-?": "user act restaurant request phone ?", "user-restaurant-request-postcode-?": "user act restaurant request postcode ?", "user-restaurant-request-price range-?": "user act restaurant request price range ?", "user-restaurant-request-ref-?": "user act restaurant request ref ?", "user-taxi-inform-arrive by-1": "user act taxi inform arrive by 1", "user-taxi-inform-departure-1": "user act taxi inform departure 1", "user-taxi-inform-destination-1": "user act taxi inform destination 1", "user-taxi-inform-leave at-1": "user act taxi inform leave at 1", 
"user-taxi-inform-none-none": "user act taxi inform none none", "user-taxi-request-phone-?": "user act taxi request phone ?", "user-taxi-request-type-?": "user act taxi request type ?", "user-train-inform-arrive by-1": "user act train inform arrive by 1", "user-train-inform-book people-1": "user act train inform book people 1", "user-train-inform-day-1": "user act train inform day 1", "user-train-inform-departure-1": "user act train inform departure 1", "user-train-inform-destination-1": "user act train inform destination 1", "user-train-inform-leave at-1": "user act train inform leave at 1", "user-train-inform-none-none": "user act train inform none none", "user-train-request-arrive by-?": "user act train request arrive by ?", "user-train-request-duration-?": "user act train request duration ?", "user-train-request-leave at-?": "user act train request leave at ?", "user-train-request-price-?": "user act train request price ?", "user-train-request-ref-?": "user act train request ref ?", "user-train-request-train id-?": "user act train request train id ?"}
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/descriptions/semantic_information_descriptions_sgd.json b/convlab/policy/vtrace_DPT/descriptions/semantic_information_descriptions_sgd.json
new file mode 100644
index 0000000000000000000000000000000000000000..2539067dfd07e25e35359218ed4487a6da057910
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/descriptions/semantic_information_descriptions_sgd.json
@@ -0,0 +1 @@
+{"user goal-banks_1-account_type": "user goal banks_1 account_type", "user goal-banks_1-recipient_account_type": "user goal banks_1 recipient_account_type", "user goal-banks_1-balance": "user goal banks_1 balance", "user goal-banks_1-amount": "user goal banks_1 amount", "user goal-banks_1-recipient_account_name": "user goal banks_1 recipient_account_name", "user goal-buses_1-from_location": "user goal buses_1 from_location", "user goal-buses_1-to_location": "user goal buses_1 to_location", "user goal-buses_1-from_station": "user goal buses_1 from_station", "user goal-buses_1-to_station": "user goal buses_1 to_station", "user goal-buses_1-leaving_date": "user goal buses_1 leaving_date", "user goal-buses_1-leaving_time": "user goal buses_1 leaving_time", "user goal-buses_1-fare": "user goal buses_1 fare", "user goal-buses_1-travelers": "user goal buses_1 travelers", "user goal-buses_1-transfers": "user goal buses_1 transfers", "user goal-buses_2-origin": "user goal buses_2 origin", "user goal-buses_2-destination": "user goal buses_2 destination", "user goal-buses_2-origin_station_name": "user goal buses_2 origin_station_name", "user goal-buses_2-destination_station_name": "user goal buses_2 destination_station_name", "user goal-buses_2-departure_date": "user goal buses_2 departure_date", "user goal-buses_2-price": "user goal buses_2 price", "user goal-buses_2-departure_time": "user goal buses_2 departure_time", "user goal-buses_2-group_size": "user goal buses_2 group_size", "user goal-buses_2-fare_type": "user goal buses_2 fare_type", "user goal-calendar_1-event_date": "user goal calendar_1 event_date", "user goal-calendar_1-event_time": "user goal calendar_1 event_time", "user goal-calendar_1-event_location": "user goal calendar_1 event_location", "user goal-calendar_1-event_name": "user goal calendar_1 event_name", "user goal-calendar_1-available_start_time": "user goal calendar_1 available_start_time", "user goal-calendar_1-available_end_time": "user goal 
calendar_1 available_end_time", "user goal-events_1-category": "user goal events_1 category", "user goal-events_1-subcategory": "user goal events_1 subcategory", "user goal-events_1-event_name": "user goal events_1 event_name", "user goal-events_1-date": "user goal events_1 date", "user goal-events_1-time": "user goal events_1 time", "user goal-events_1-number_of_seats": "user goal events_1 number_of_seats", "user goal-events_1-city_of_event": "user goal events_1 city_of_event", "user goal-events_1-event_location": "user goal events_1 event_location", "user goal-events_1-address_of_location": "user goal events_1 address_of_location", "user goal-events_2-event_type": "user goal events_2 event_type", "user goal-events_2-category": "user goal events_2 category", "user goal-events_2-event_name": "user goal events_2 event_name", "user goal-events_2-date": "user goal events_2 date", "user goal-events_2-time": "user goal events_2 time", "user goal-events_2-number_of_tickets": "user goal events_2 number_of_tickets", "user goal-events_2-city": "user goal events_2 city", "user goal-events_2-venue": "user goal events_2 venue", "user goal-events_2-venue_address": "user goal events_2 venue_address", "user goal-flights_1-passengers": "user goal flights_1 passengers", "user goal-flights_1-seating_class": "user goal flights_1 seating_class", "user goal-flights_1-origin_city": "user goal flights_1 origin_city", "user goal-flights_1-destination_city": "user goal flights_1 destination_city", "user goal-flights_1-origin_airport": "user goal flights_1 origin_airport", "user goal-flights_1-destination_airport": "user goal flights_1 destination_airport", "user goal-flights_1-departure_date": "user goal flights_1 departure_date", "user goal-flights_1-return_date": "user goal flights_1 return_date", "user goal-flights_1-number_stops": "user goal flights_1 number_stops", "user goal-flights_1-outbound_departure_time": "user goal flights_1 outbound_departure_time", "user 
goal-flights_1-outbound_arrival_time": "user goal flights_1 outbound_arrival_time", "user goal-flights_1-inbound_arrival_time": "user goal flights_1 inbound_arrival_time", "user goal-flights_1-inbound_departure_time": "user goal flights_1 inbound_departure_time", "user goal-flights_1-price": "user goal flights_1 price", "user goal-flights_1-refundable": "user goal flights_1 refundable", "user goal-flights_1-airlines": "user goal flights_1 airlines", "user goal-flights_2-passengers": "user goal flights_2 passengers", "user goal-flights_2-seating_class": "user goal flights_2 seating_class", "user goal-flights_2-origin": "user goal flights_2 origin", "user goal-flights_2-destination": "user goal flights_2 destination", "user goal-flights_2-origin_airport": "user goal flights_2 origin_airport", "user goal-flights_2-destination_airport": "user goal flights_2 destination_airport", "user goal-flights_2-departure_date": "user goal flights_2 departure_date", "user goal-flights_2-return_date": "user goal flights_2 return_date", "user goal-flights_2-number_stops": "user goal flights_2 number_stops", "user goal-flights_2-outbound_departure_time": "user goal flights_2 outbound_departure_time", "user goal-flights_2-outbound_arrival_time": "user goal flights_2 outbound_arrival_time", "user goal-flights_2-inbound_arrival_time": "user goal flights_2 inbound_arrival_time", "user goal-flights_2-inbound_departure_time": "user goal flights_2 inbound_departure_time", "user goal-flights_2-fare": "user goal flights_2 fare", "user goal-flights_2-is_redeye": "user goal flights_2 is_redeye", "user goal-flights_2-airlines": "user goal flights_2 airlines", "user goal-homes_1-area": "user goal homes_1 area", "user goal-homes_1-address": "user goal homes_1 address", "user goal-homes_1-property_name": "user goal homes_1 property_name", "user goal-homes_1-phone_number": "user goal homes_1 phone_number", "user goal-homes_1-furnished": "user goal homes_1 furnished", "user goal-homes_1-pets_allowed": 
"user goal homes_1 pets_allowed", "user goal-homes_1-rent": "user goal homes_1 rent", "user goal-homes_1-visit_date": "user goal homes_1 visit_date", "user goal-homes_1-number_of_beds": "user goal homes_1 number_of_beds", "user goal-homes_1-number_of_baths": "user goal homes_1 number_of_baths", "user goal-hotels_1-destination": "user goal hotels_1 destination", "user goal-hotels_1-number_of_rooms": "user goal hotels_1 number_of_rooms", "user goal-hotels_1-check_in_date": "user goal hotels_1 check_in_date", "user goal-hotels_1-number_of_days": "user goal hotels_1 number_of_days", "user goal-hotels_1-star_rating": "user goal hotels_1 star_rating", "user goal-hotels_1-hotel_name": "user goal hotels_1 hotel_name", "user goal-hotels_1-street_address": "user goal hotels_1 street_address", "user goal-hotels_1-phone_number": "user goal hotels_1 phone_number", "user goal-hotels_1-price_per_night": "user goal hotels_1 price_per_night", "user goal-hotels_1-has_wifi": "user goal hotels_1 has_wifi", "user goal-hotels_2-where_to": "user goal hotels_2 where_to", "user goal-hotels_2-number_of_adults": "user goal hotels_2 number_of_adults", "user goal-hotels_2-check_in_date": "user goal hotels_2 check_in_date", "user goal-hotels_2-check_out_date": "user goal hotels_2 check_out_date", "user goal-hotels_2-rating": "user goal hotels_2 rating", "user goal-hotels_2-address": "user goal hotels_2 address", "user goal-hotels_2-phone_number": "user goal hotels_2 phone_number", "user goal-hotels_2-total_price": "user goal hotels_2 total_price", "user goal-hotels_2-has_laundry_service": "user goal hotels_2 has_laundry_service", "user goal-hotels_3-location": "user goal hotels_3 location", "user goal-hotels_3-number_of_rooms": "user goal hotels_3 number_of_rooms", "user goal-hotels_3-check_in_date": "user goal hotels_3 check_in_date", "user goal-hotels_3-check_out_date": "user goal hotels_3 check_out_date", "user goal-hotels_3-average_rating": "user goal hotels_3 average_rating", "user 
goal-hotels_3-hotel_name": "user goal hotels_3 hotel_name", "user goal-hotels_3-street_address": "user goal hotels_3 street_address", "user goal-hotels_3-phone_number": "user goal hotels_3 phone_number", "user goal-hotels_3-price": "user goal hotels_3 price", "user goal-hotels_3-pets_welcome": "user goal hotels_3 pets_welcome", "user goal-media_1-title": "user goal media_1 title", "user goal-media_1-genre": "user goal media_1 genre", "user goal-media_1-subtitles": "user goal media_1 subtitles", "user goal-media_1-directed_by": "user goal media_1 directed_by", "user goal-movies_1-price": "user goal movies_1 price", "user goal-movies_1-number_of_tickets": "user goal movies_1 number_of_tickets", "user goal-movies_1-show_type": "user goal movies_1 show_type", "user goal-movies_1-theater_name": "user goal movies_1 theater_name", "user goal-movies_1-show_time": "user goal movies_1 show_time", "user goal-movies_1-show_date": "user goal movies_1 show_date", "user goal-movies_1-genre": "user goal movies_1 genre", "user goal-movies_1-street_address": "user goal movies_1 street_address", "user goal-movies_1-location": "user goal movies_1 location", "user goal-movies_1-movie_name": "user goal movies_1 movie_name", "user goal-music_1-song_name": "user goal music_1 song_name", "user goal-music_1-artist": "user goal music_1 artist", "user goal-music_1-album": "user goal music_1 album", "user goal-music_1-genre": "user goal music_1 genre", "user goal-music_1-year": "user goal music_1 year", "user goal-music_1-playback_device": "user goal music_1 playback_device", "user goal-music_2-song_name": "user goal music_2 song_name", "user goal-music_2-artist": "user goal music_2 artist", "user goal-music_2-album": "user goal music_2 album", "user goal-music_2-genre": "user goal music_2 genre", "user goal-music_2-playback_device": "user goal music_2 playback_device", "user goal-rentalcars_1-type": "user goal rentalcars_1 type", "user goal-rentalcars_1-car_name": "user goal rentalcars_1 
car_name", "user goal-rentalcars_1-pickup_location": "user goal rentalcars_1 pickup_location", "user goal-rentalcars_1-pickup_date": "user goal rentalcars_1 pickup_date", "user goal-rentalcars_1-pickup_time": "user goal rentalcars_1 pickup_time", "user goal-rentalcars_1-pickup_city": "user goal rentalcars_1 pickup_city", "user goal-rentalcars_1-dropoff_date": "user goal rentalcars_1 dropoff_date", "user goal-rentalcars_1-total_price": "user goal rentalcars_1 total_price", "user goal-rentalcars_2-car_type": "user goal rentalcars_2 car_type", "user goal-rentalcars_2-car_name": "user goal rentalcars_2 car_name", "user goal-rentalcars_2-pickup_location": "user goal rentalcars_2 pickup_location", "user goal-rentalcars_2-pickup_date": "user goal rentalcars_2 pickup_date", "user goal-rentalcars_2-pickup_time": "user goal rentalcars_2 pickup_time", "user goal-rentalcars_2-pickup_city": "user goal rentalcars_2 pickup_city", "user goal-rentalcars_2-dropoff_date": "user goal rentalcars_2 dropoff_date", "user goal-rentalcars_2-total_price": "user goal rentalcars_2 total_price", "user goal-restaurants_1-restaurant_name": "user goal restaurants_1 restaurant_name", "user goal-restaurants_1-date": "user goal restaurants_1 date", "user goal-restaurants_1-time": "user goal restaurants_1 time", "user goal-restaurants_1-serves_alcohol": "user goal restaurants_1 serves_alcohol", "user goal-restaurants_1-has_live_music": "user goal restaurants_1 has_live_music", "user goal-restaurants_1-phone_number": "user goal restaurants_1 phone_number", "user goal-restaurants_1-street_address": "user goal restaurants_1 street_address", "user goal-restaurants_1-party_size": "user goal restaurants_1 party_size", "user goal-restaurants_1-price_range": "user goal restaurants_1 price_range", "user goal-restaurants_1-city": "user goal restaurants_1 city", "user goal-restaurants_1-cuisine": "user goal restaurants_1 cuisine", "user goal-ridesharing_1-destination": "user goal ridesharing_1 destination", 
"user goal-ridesharing_1-shared_ride": "user goal ridesharing_1 shared_ride", "user goal-ridesharing_1-ride_fare": "user goal ridesharing_1 ride_fare", "user goal-ridesharing_1-approximate_ride_duration": "user goal ridesharing_1 approximate_ride_duration", "user goal-ridesharing_1-number_of_riders": "user goal ridesharing_1 number_of_riders", "user goal-ridesharing_2-destination": "user goal ridesharing_2 destination", "user goal-ridesharing_2-ride_type": "user goal ridesharing_2 ride_type", "user goal-ridesharing_2-ride_fare": "user goal ridesharing_2 ride_fare", "user goal-ridesharing_2-wait_time": "user goal ridesharing_2 wait_time", "user goal-ridesharing_2-number_of_seats": "user goal ridesharing_2 number_of_seats", "user goal-services_1-stylist_name": "user goal services_1 stylist_name", "user goal-services_1-phone_number": "user goal services_1 phone_number", "user goal-services_1-average_rating": "user goal services_1 average_rating", "user goal-services_1-is_unisex": "user goal services_1 is_unisex", "user goal-services_1-street_address": "user goal services_1 street_address", "user goal-services_1-city": "user goal services_1 city", "user goal-services_1-appointment_date": "user goal services_1 appointment_date", "user goal-services_1-appointment_time": "user goal services_1 appointment_time", "user goal-services_2-dentist_name": "user goal services_2 dentist_name", "user goal-services_2-phone_number": "user goal services_2 phone_number", "user goal-services_2-address": "user goal services_2 address", "user goal-services_2-city": "user goal services_2 city", "user goal-services_2-appointment_date": "user goal services_2 appointment_date", "user goal-services_2-appointment_time": "user goal services_2 appointment_time", "user goal-services_2-offers_cosmetic_services": "user goal services_2 offers_cosmetic_services", "user goal-services_3-doctor_name": "user goal services_3 doctor_name", "user goal-services_3-phone_number": "user goal services_3 
phone_number", "user goal-services_3-average_rating": "user goal services_3 average_rating", "user goal-services_3-street_address": "user goal services_3 street_address", "user goal-services_3-city": "user goal services_3 city", "user goal-services_3-appointment_date": "user goal services_3 appointment_date", "user goal-services_3-appointment_time": "user goal services_3 appointment_time", "user goal-services_3-type": "user goal services_3 type", "user goal-travel_1-location": "user goal travel_1 location", "user goal-travel_1-attraction_name": "user goal travel_1 attraction_name", "user goal-travel_1-category": "user goal travel_1 category", "user goal-travel_1-phone_number": "user goal travel_1 phone_number", "user goal-travel_1-free_entry": "user goal travel_1 free_entry", "user goal-travel_1-good_for_kids": "user goal travel_1 good_for_kids", "user goal-weather_1-precipitation": "user goal weather_1 precipitation", "user goal-weather_1-humidity": "user goal weather_1 humidity", "user goal-weather_1-wind": "user goal weather_1 wind", "user goal-weather_1-temperature": "user goal weather_1 temperature", "user goal-weather_1-city": "user goal weather_1 city", "user goal-weather_1-date": "user goal weather_1 date", "user goal-alarm_1-alarm_time": "user goal alarm_1 alarm_time", "user goal-alarm_1-alarm_name": "user goal alarm_1 alarm_name", "user goal-alarm_1-new_alarm_time": "user goal alarm_1 new_alarm_time", "user goal-alarm_1-new_alarm_name": "user goal alarm_1 new_alarm_name", "user goal-banks_2-account_type": "user goal banks_2 account_type", "user goal-banks_2-recipient_account_type": "user goal banks_2 recipient_account_type", "user goal-banks_2-account_balance": "user goal banks_2 account_balance", "user goal-banks_2-transfer_amount": "user goal banks_2 transfer_amount", "user goal-banks_2-recipient_name": "user goal banks_2 recipient_name", "user goal-banks_2-transfer_time": "user goal banks_2 transfer_time", "user goal-flights_3-passengers": "user goal 
flights_3 passengers", "user goal-flights_3-flight_class": "user goal flights_3 flight_class", "user goal-flights_3-origin_city": "user goal flights_3 origin_city", "user goal-flights_3-destination_city": "user goal flights_3 destination_city", "user goal-flights_3-origin_airport_name": "user goal flights_3 origin_airport_name", "user goal-flights_3-destination_airport_name": "user goal flights_3 destination_airport_name", "user goal-flights_3-departure_date": "user goal flights_3 departure_date", "user goal-flights_3-return_date": "user goal flights_3 return_date", "user goal-flights_3-number_stops": "user goal flights_3 number_stops", "user goal-flights_3-outbound_departure_time": "user goal flights_3 outbound_departure_time", "user goal-flights_3-outbound_arrival_time": "user goal flights_3 outbound_arrival_time", "user goal-flights_3-inbound_arrival_time": "user goal flights_3 inbound_arrival_time", "user goal-flights_3-inbound_departure_time": "user goal flights_3 inbound_departure_time", "user goal-flights_3-price": "user goal flights_3 price", "user goal-flights_3-number_checked_bags": "user goal flights_3 number_checked_bags", "user goal-flights_3-airlines": "user goal flights_3 airlines", "user goal-flights_3-arrives_next_day": "user goal flights_3 arrives_next_day", "user goal-hotels_4-location": "user goal hotels_4 location", "user goal-hotels_4-number_of_rooms": "user goal hotels_4 number_of_rooms", "user goal-hotels_4-check_in_date": "user goal hotels_4 check_in_date", "user goal-hotels_4-stay_length": "user goal hotels_4 stay_length", "user goal-hotels_4-star_rating": "user goal hotels_4 star_rating", "user goal-hotels_4-place_name": "user goal hotels_4 place_name", "user goal-hotels_4-street_address": "user goal hotels_4 street_address", "user goal-hotels_4-phone_number": "user goal hotels_4 phone_number", "user goal-hotels_4-price_per_night": "user goal hotels_4 price_per_night", "user goal-hotels_4-smoking_allowed": "user goal hotels_4 
smoking_allowed", "user goal-media_2-movie_name": "user goal media_2 movie_name", "user goal-media_2-genre": "user goal media_2 genre", "user goal-media_2-subtitle_language": "user goal media_2 subtitle_language", "user goal-media_2-director": "user goal media_2 director", "user goal-media_2-actors": "user goal media_2 actors", "user goal-media_2-price": "user goal media_2 price", "user goal-movies_2-title": "user goal movies_2 title", "user goal-movies_2-genre": "user goal movies_2 genre", "user goal-movies_2-aggregate_rating": "user goal movies_2 aggregate_rating", "user goal-movies_2-starring": "user goal movies_2 starring", "user goal-movies_2-director": "user goal movies_2 director", "user goal-restaurants_2-restaurant_name": "user goal restaurants_2 restaurant_name", "user goal-restaurants_2-date": "user goal restaurants_2 date", "user goal-restaurants_2-time": "user goal restaurants_2 time", "user goal-restaurants_2-has_seating_outdoors": "user goal restaurants_2 has_seating_outdoors", "user goal-restaurants_2-has_vegetarian_options": "user goal restaurants_2 has_vegetarian_options", "user goal-restaurants_2-phone_number": "user goal restaurants_2 phone_number", "user goal-restaurants_2-rating": "user goal restaurants_2 rating", "user goal-restaurants_2-address": "user goal restaurants_2 address", "user goal-restaurants_2-number_of_seats": "user goal restaurants_2 number_of_seats", "user goal-restaurants_2-price_range": "user goal restaurants_2 price_range", "user goal-restaurants_2-location": "user goal restaurants_2 location", "user goal-restaurants_2-category": "user goal restaurants_2 category", "user goal-services_4-therapist_name": "user goal services_4 therapist_name", "user goal-services_4-phone_number": "user goal services_4 phone_number", "user goal-services_4-address": "user goal services_4 address", "user goal-services_4-city": "user goal services_4 city", "user goal-services_4-appointment_date": "user goal services_4 appointment_date", "user 
goal-services_4-appointment_time": "user goal services_4 appointment_time", "user goal-services_4-type": "user goal services_4 type", "user goal-buses_3-from_city": "user goal buses_3 from_city", "user goal-buses_3-to_city": "user goal buses_3 to_city", "user goal-buses_3-from_station": "user goal buses_3 from_station", "user goal-buses_3-to_station": "user goal buses_3 to_station", "user goal-buses_3-departure_date": "user goal buses_3 departure_date", "user goal-buses_3-departure_time": "user goal buses_3 departure_time", "user goal-buses_3-price": "user goal buses_3 price", "user goal-buses_3-additional_luggage": "user goal buses_3 additional_luggage", "user goal-buses_3-num_passengers": "user goal buses_3 num_passengers", "user goal-buses_3-category": "user goal buses_3 category", "user goal-events_3-event_type": "user goal events_3 event_type", "user goal-events_3-event_name": "user goal events_3 event_name", "user goal-events_3-date": "user goal events_3 date", "user goal-events_3-time": "user goal events_3 time", "user goal-events_3-number_of_tickets": "user goal events_3 number_of_tickets", "user goal-events_3-price_per_ticket": "user goal events_3 price_per_ticket", "user goal-events_3-city": "user goal events_3 city", "user goal-events_3-venue": "user goal events_3 venue", "user goal-events_3-venue_address": "user goal events_3 venue_address", "user goal-flights_4-number_of_tickets": "user goal flights_4 number_of_tickets", "user goal-flights_4-seating_class": "user goal flights_4 seating_class", "user goal-flights_4-origin_airport": "user goal flights_4 origin_airport", "user goal-flights_4-destination_airport": "user goal flights_4 destination_airport", "user goal-flights_4-departure_date": "user goal flights_4 departure_date", "user goal-flights_4-return_date": "user goal flights_4 return_date", "user goal-flights_4-is_nonstop": "user goal flights_4 is_nonstop", "user goal-flights_4-outbound_departure_time": "user goal flights_4 
outbound_departure_time", "user goal-flights_4-outbound_arrival_time": "user goal flights_4 outbound_arrival_time", "user goal-flights_4-inbound_arrival_time": "user goal flights_4 inbound_arrival_time", "user goal-flights_4-inbound_departure_time": "user goal flights_4 inbound_departure_time", "user goal-flights_4-price": "user goal flights_4 price", "user goal-flights_4-airlines": "user goal flights_4 airlines", "user goal-homes_2-intent": "user goal homes_2 intent", "user goal-homes_2-area": "user goal homes_2 area", "user goal-homes_2-address": "user goal homes_2 address", "user goal-homes_2-property_name": "user goal homes_2 property_name", "user goal-homes_2-phone_number": "user goal homes_2 phone_number", "user goal-homes_2-has_garage": "user goal homes_2 has_garage", "user goal-homes_2-in_unit_laundry": "user goal homes_2 in_unit_laundry", "user goal-homes_2-price": "user goal homes_2 price", "user goal-homes_2-visit_date": "user goal homes_2 visit_date", "user goal-homes_2-number_of_beds": "user goal homes_2 number_of_beds", "user goal-homes_2-number_of_baths": "user goal homes_2 number_of_baths", "user goal-media_3-title": "user goal media_3 title", "user goal-media_3-genre": "user goal media_3 genre", "user goal-media_3-subtitle_language": "user goal media_3 subtitle_language", "user goal-media_3-starring": "user goal media_3 starring", "user goal-messaging_1-location": "user goal messaging_1 location", "user goal-messaging_1-contact_name": "user goal messaging_1 contact_name", "user goal-movies_3-movie_title": "user goal movies_3 movie_title", "user goal-movies_3-genre": "user goal movies_3 genre", "user goal-movies_3-percent_rating": "user goal movies_3 percent_rating", "user goal-movies_3-cast": "user goal movies_3 cast", "user goal-movies_3-directed_by": "user goal movies_3 directed_by", "user goal-music_3-track": "user goal music_3 track", "user goal-music_3-artist": "user goal music_3 artist", "user goal-music_3-album": "user goal music_3 album", 
"user goal-music_3-genre": "user goal music_3 genre", "user goal-music_3-year": "user goal music_3 year", "user goal-music_3-device": "user goal music_3 device", "user goal-payment_1-payment_method": "user goal payment_1 payment_method", "user goal-payment_1-amount": "user goal payment_1 amount", "user goal-payment_1-receiver": "user goal payment_1 receiver", "user goal-payment_1-private_visibility": "user goal payment_1 private_visibility", "user goal-rentalcars_3-car_type": "user goal rentalcars_3 car_type", "user goal-rentalcars_3-car_name": "user goal rentalcars_3 car_name", "user goal-rentalcars_3-pickup_location": "user goal rentalcars_3 pickup_location", "user goal-rentalcars_3-start_date": "user goal rentalcars_3 start_date", "user goal-rentalcars_3-pickup_time": "user goal rentalcars_3 pickup_time", "user goal-rentalcars_3-city": "user goal rentalcars_3 city", "user goal-rentalcars_3-end_date": "user goal rentalcars_3 end_date", "user goal-rentalcars_3-price_per_day": "user goal rentalcars_3 price_per_day", "user goal-rentalcars_3-add_insurance": "user goal rentalcars_3 add_insurance", "user goal-trains_1-from": "user goal trains_1 from", "user goal-trains_1-to": "user goal trains_1 to", "user goal-trains_1-from_station": "user goal trains_1 from_station", "user goal-trains_1-to_station": "user goal trains_1 to_station", "user goal-trains_1-date_of_journey": "user goal trains_1 date_of_journey", "user goal-trains_1-journey_start_time": "user goal trains_1 journey_start_time", "user goal-trains_1-total": "user goal trains_1 total", "user goal-trains_1-number_of_adults": "user goal trains_1 number_of_adults", "user goal-trains_1-class": "user goal trains_1 class", "user goal-trains_1-trip_protection": "user goal trains_1 trip_protection", "general-banks_1": "domain banks_1", "general-buses_1": "domain buses_1", "general-buses_2": "domain buses_2", "general-calendar_1": "domain calendar_1", "general-events_1": "domain events_1", "general-events_2": "domain 
events_2", "general-flights_1": "domain flights_1", "general-flights_2": "domain flights_2", "general-homes_1": "domain homes_1", "general-hotels_1": "domain hotels_1", "general-hotels_2": "domain hotels_2", "general-hotels_3": "domain hotels_3", "general-media_1": "domain media_1", "general-movies_1": "domain movies_1", "general-music_1": "domain music_1", "general-music_2": "domain music_2", "general-rentalcars_1": "domain rentalcars_1", "general-rentalcars_2": "domain rentalcars_2", "general-restaurants_1": "domain restaurants_1", "general-ridesharing_1": "domain ridesharing_1", "general-ridesharing_2": "domain ridesharing_2", "general-services_1": "domain services_1", "general-services_2": "domain services_2", "general-services_3": "domain services_3", "general-travel_1": "domain travel_1", "general-weather_1": "domain weather_1", "general-alarm_1": "domain alarm_1", "general-banks_2": "domain banks_2", "general-flights_3": "domain flights_3", "general-hotels_4": "domain hotels_4", "general-media_2": "domain media_2", "general-movies_2": "domain movies_2", "general-restaurants_2": "domain restaurants_2", "general-services_4": "domain services_4", "general-buses_3": "domain buses_3", "general-events_3": "domain events_3", "general-flights_4": "domain flights_4", "general-homes_2": "domain homes_2", "general-media_3": "domain media_3", "general-messaging_1": "domain messaging_1", "general-movies_3": "domain movies_3", "general-music_3": "domain music_3", "general-payment_1": "domain payment_1", "general-rentalcars_3": "domain rentalcars_3", "general-trains_1": "domain trains_1", "system--goodbye-none-none": "last system act  goodbye none none", "system--req_more-none-none": "last system act  req_more none none", "system-alarm_1-confirm-new_alarm_name-1": "last system act alarm_1 confirm new_alarm_name 1", "system-alarm_1-confirm-new_alarm_time-1": "last system act alarm_1 confirm new_alarm_time 1", "system-alarm_1-inform_count-count-1": "last system act alarm_1 
inform_count count 1", "system-alarm_1-notify_success-none-none": "last system act alarm_1 notify_success none none", "system-alarm_1-offer-alarm_name-1": "last system act alarm_1 offer alarm_name 1", "system-alarm_1-offer-alarm_time-1": "last system act alarm_1 offer alarm_time 1", "system-alarm_1-offer_intent-addalarm-1": "last system act alarm_1 offer_intent AddAlarm 1", "system-alarm_1-request-new_alarm_time-?": "last system act alarm_1 request new_alarm_time ?", "system-banks_1-confirm-account_type-1": "last system act banks_1 confirm account_type 1", "system-banks_1-confirm-amount-1": "last system act banks_1 confirm amount 1", "system-banks_1-confirm-recipient_account_name-1": "last system act banks_1 confirm recipient_account_name 1", "system-banks_1-confirm-recipient_account_type-1": "last system act banks_1 confirm recipient_account_type 1", "system-banks_1-notify_success-none-none": "last system act banks_1 notify_success none none", "system-banks_1-offer-account_type-1": "last system act banks_1 offer account_type 1", "system-banks_1-offer-balance-1": "last system act banks_1 offer balance 1", "system-banks_1-offer_intent-transfermoney-1": "last system act banks_1 offer_intent TransferMoney 1", "system-banks_1-request-account_type-?": "last system act banks_1 request account_type ?", "system-banks_1-request-amount-?": "last system act banks_1 request amount ?", "system-banks_1-request-recipient_account_name-?": "last system act banks_1 request recipient_account_name ?", "system-banks_2-confirm-account_type-1": "last system act banks_2 confirm account_type 1", "system-banks_2-confirm-recipient_account_type-1": "last system act banks_2 confirm recipient_account_type 1", "system-banks_2-confirm-recipient_name-1": "last system act banks_2 confirm recipient_name 1", "system-banks_2-confirm-transfer_amount-1": "last system act banks_2 confirm transfer_amount 1", "system-banks_2-inform-transfer_time-1": "last system act banks_2 inform transfer_time 1", 
"system-banks_2-notify_success-none-none": "last system act banks_2 notify_success none none", "system-banks_2-offer-account_balance-1": "last system act banks_2 offer account_balance 1", "system-banks_2-offer-account_type-1": "last system act banks_2 offer account_type 1", "system-banks_2-offer_intent-transfermoney-1": "last system act banks_2 offer_intent TransferMoney 1", "system-banks_2-request-account_type-?": "last system act banks_2 request account_type ?", "system-banks_2-request-recipient_name-?": "last system act banks_2 request recipient_name ?", "system-banks_2-request-transfer_amount-?": "last system act banks_2 request transfer_amount ?", "system-buses_1-confirm-from_location-1": "last system act buses_1 confirm from_location 1", "system-buses_1-confirm-leaving_date-1": "last system act buses_1 confirm leaving_date 1", "system-buses_1-confirm-leaving_time-1": "last system act buses_1 confirm leaving_time 1", "system-buses_1-confirm-to_location-1": "last system act buses_1 confirm to_location 1", "system-buses_1-confirm-travelers-1": "last system act buses_1 confirm travelers 1", "system-buses_1-inform-from_station-1": "last system act buses_1 inform from_station 1", "system-buses_1-inform-to_station-1": "last system act buses_1 inform to_station 1", "system-buses_1-inform-transfers-1": "last system act buses_1 inform transfers 1", "system-buses_1-inform_count-count-1": "last system act buses_1 inform_count count 1", "system-buses_1-notify_failure-none-none": "last system act buses_1 notify_failure none none", "system-buses_1-notify_success-none-none": "last system act buses_1 notify_success none none", "system-buses_1-offer-fare-1": "last system act buses_1 offer fare 1", "system-buses_1-offer-leaving_time-1": "last system act buses_1 offer leaving_time 1", "system-buses_1-offer-transfers-1": "last system act buses_1 offer transfers 1", "system-buses_1-offer_intent-buybusticket-1": "last system act buses_1 offer_intent BuyBusTicket 1", 
"system-buses_1-request-from_location-?": "last system act buses_1 request from_location ?", "system-buses_1-request-leaving_date-?": "last system act buses_1 request leaving_date ?", "system-buses_1-request-leaving_time-?": "last system act buses_1 request leaving_time ?", "system-buses_1-request-to_location-?": "last system act buses_1 request to_location ?", "system-buses_1-request-travelers-?": "last system act buses_1 request travelers ?", "system-buses_2-confirm-departure_date-1": "last system act buses_2 confirm departure_date 1", "system-buses_2-confirm-departure_time-1": "last system act buses_2 confirm departure_time 1", "system-buses_2-confirm-destination-1": "last system act buses_2 confirm destination 1", "system-buses_2-confirm-fare_type-1": "last system act buses_2 confirm fare_type 1", "system-buses_2-confirm-group_size-1": "last system act buses_2 confirm group_size 1", "system-buses_2-confirm-origin-1": "last system act buses_2 confirm origin 1", "system-buses_2-inform-destination_station_name-1": "last system act buses_2 inform destination_station_name 1", "system-buses_2-inform-origin_station_name-1": "last system act buses_2 inform origin_station_name 1", "system-buses_2-inform_count-count-1": "last system act buses_2 inform_count count 1", "system-buses_2-notify_failure-none-none": "last system act buses_2 notify_failure none none", "system-buses_2-notify_success-none-none": "last system act buses_2 notify_success none none", "system-buses_2-offer-departure_time-1": "last system act buses_2 offer departure_time 1", "system-buses_2-offer-fare_type-1": "last system act buses_2 offer fare_type 1", "system-buses_2-offer-price-1": "last system act buses_2 offer price 1", "system-buses_2-offer_intent-buybusticket-1": "last system act buses_2 offer_intent BuyBusTicket 1", "system-buses_2-request-departure_date-?": "last system act buses_2 request departure_date ?", "system-buses_2-request-departure_time-?": "last system act buses_2 request 
departure_time ?", "system-buses_2-request-destination-?": "last system act buses_2 request destination ?", "system-buses_2-request-group_size-?": "last system act buses_2 request group_size ?", "system-buses_2-request-origin-?": "last system act buses_2 request origin ?", "system-buses_3-confirm-additional_luggage-1": "last system act buses_3 confirm additional_luggage 1", "system-buses_3-confirm-departure_date-1": "last system act buses_3 confirm departure_date 1", "system-buses_3-confirm-departure_time-1": "last system act buses_3 confirm departure_time 1", "system-buses_3-confirm-from_city-1": "last system act buses_3 confirm from_city 1", "system-buses_3-confirm-num_passengers-1": "last system act buses_3 confirm num_passengers 1", "system-buses_3-confirm-to_city-1": "last system act buses_3 confirm to_city 1", "system-buses_3-inform-category-1": "last system act buses_3 inform category 1", "system-buses_3-inform-from_station-1": "last system act buses_3 inform from_station 1", "system-buses_3-inform-to_station-1": "last system act buses_3 inform to_station 1", "system-buses_3-inform_count-count-1": "last system act buses_3 inform_count count 1", "system-buses_3-notify_failure-none-none": "last system act buses_3 notify_failure none none", "system-buses_3-notify_success-none-none": "last system act buses_3 notify_success none none", "system-buses_3-offer-departure_time-1": "last system act buses_3 offer departure_time 1", "system-buses_3-offer-price-1": "last system act buses_3 offer price 1", "system-buses_3-offer_intent-buybusticket-1": "last system act buses_3 offer_intent BuyBusTicket 1", "system-buses_3-request-departure_date-?": "last system act buses_3 request departure_date ?", "system-buses_3-request-departure_time-?": "last system act buses_3 request departure_time ?", "system-buses_3-request-from_city-?": "last system act buses_3 request from_city ?", "system-buses_3-request-num_passengers-?": "last system act buses_3 request num_passengers ?", 
"system-buses_3-request-to_city-?": "last system act buses_3 request to_city ?", "system-calendar_1-confirm-event_date-1": "last system act calendar_1 confirm event_date 1", "system-calendar_1-confirm-event_location-1": "last system act calendar_1 confirm event_location 1", "system-calendar_1-confirm-event_name-1": "last system act calendar_1 confirm event_name 1", "system-calendar_1-confirm-event_time-1": "last system act calendar_1 confirm event_time 1", "system-calendar_1-inform_count-count-1": "last system act calendar_1 inform_count count 1", "system-calendar_1-notify_success-none-none": "last system act calendar_1 notify_success none none", "system-calendar_1-offer-available_end_time-1": "last system act calendar_1 offer available_end_time 1", "system-calendar_1-offer-available_start_time-1": "last system act calendar_1 offer available_start_time 1", "system-calendar_1-offer-event_date-1": "last system act calendar_1 offer event_date 1", "system-calendar_1-offer-event_name-1": "last system act calendar_1 offer event_name 1", "system-calendar_1-offer-event_time-1": "last system act calendar_1 offer event_time 1", "system-calendar_1-offer_intent-addevent-1": "last system act calendar_1 offer_intent AddEvent 1", "system-calendar_1-request-event_date-?": "last system act calendar_1 request event_date ?", "system-calendar_1-request-event_location-?": "last system act calendar_1 request event_location ?", "system-calendar_1-request-event_name-?": "last system act calendar_1 request event_name ?", "system-calendar_1-request-event_time-?": "last system act calendar_1 request event_time ?", "system-events_1-confirm-city_of_event-1": "last system act events_1 confirm city_of_event 1", "system-events_1-confirm-date-1": "last system act events_1 confirm date 1", "system-events_1-confirm-event_name-1": "last system act events_1 confirm event_name 1", "system-events_1-confirm-number_of_seats-1": "last system act events_1 confirm number_of_seats 1", 
"system-events_1-inform-address_of_location-1": "last system act events_1 inform address_of_location 1", "system-events_1-inform-event_location-1": "last system act events_1 inform event_location 1", "system-events_1-inform-subcategory-1": "last system act events_1 inform subcategory 1", "system-events_1-inform-time-1": "last system act events_1 inform time 1", "system-events_1-inform_count-count-1": "last system act events_1 inform_count count 1", "system-events_1-notify_failure-none-none": "last system act events_1 notify_failure none none", "system-events_1-notify_success-none-none": "last system act events_1 notify_success none none", "system-events_1-offer-date-1": "last system act events_1 offer date 1", "system-events_1-offer-event_location-1": "last system act events_1 offer event_location 1", "system-events_1-offer-event_name-1": "last system act events_1 offer event_name 1", "system-events_1-offer-time-1": "last system act events_1 offer time 1", "system-events_1-offer_intent-buyeventtickets-1": "last system act events_1 offer_intent BuyEventTickets 1", "system-events_1-request-category-?": "last system act events_1 request category ?", "system-events_1-request-city_of_event-?": "last system act events_1 request city_of_event ?", "system-events_1-request-date-?": "last system act events_1 request date ?", "system-events_1-request-event_name-?": "last system act events_1 request event_name ?", "system-events_1-request-number_of_seats-?": "last system act events_1 request number_of_seats ?", "system-events_2-confirm-city-1": "last system act events_2 confirm city 1", "system-events_2-confirm-date-1": "last system act events_2 confirm date 1", "system-events_2-confirm-event_name-1": "last system act events_2 confirm event_name 1", "system-events_2-confirm-number_of_tickets-1": "last system act events_2 confirm number_of_tickets 1", "system-events_2-inform-category-1": "last system act events_2 inform category 1", "system-events_2-inform-time-1": "last system 
act events_2 inform time 1", "system-events_2-inform-venue-1": "last system act events_2 inform venue 1", "system-events_2-inform-venue_address-1": "last system act events_2 inform venue_address 1", "system-events_2-inform_count-count-1": "last system act events_2 inform_count count 1", "system-events_2-notify_success-none-none": "last system act events_2 notify_success none none", "system-events_2-offer-date-1": "last system act events_2 offer date 1", "system-events_2-offer-event_name-1": "last system act events_2 offer event_name 1", "system-events_2-offer-venue-1": "last system act events_2 offer venue 1", "system-events_2-offer_intent-buyeventtickets-1": "last system act events_2 offer_intent BuyEventTickets 1", "system-events_2-request-city-?": "last system act events_2 request city ?", "system-events_2-request-date-?": "last system act events_2 request date ?", "system-events_2-request-event_name-?": "last system act events_2 request event_name ?", "system-events_2-request-event_type-?": "last system act events_2 request event_type ?", "system-events_2-request-number_of_tickets-?": "last system act events_2 request number_of_tickets ?", "system-events_3-confirm-city-1": "last system act events_3 confirm city 1", "system-events_3-confirm-date-1": "last system act events_3 confirm date 1", "system-events_3-confirm-event_name-1": "last system act events_3 confirm event_name 1", "system-events_3-confirm-number_of_tickets-1": "last system act events_3 confirm number_of_tickets 1", "system-events_3-inform-price_per_ticket-1": "last system act events_3 inform price_per_ticket 1", "system-events_3-inform-venue_address-1": "last system act events_3 inform venue_address 1", "system-events_3-inform_count-count-1": "last system act events_3 inform_count count 1", "system-events_3-notify_success-none-none": "last system act events_3 notify_success none none", "system-events_3-offer-date-1": "last system act events_3 offer date 1", "system-events_3-offer-event_name-1": 
"last system act events_3 offer event_name 1", "system-events_3-offer-time-1": "last system act events_3 offer time 1", "system-events_3-offer-venue-1": "last system act events_3 offer venue 1", "system-events_3-offer_intent-buyeventtickets-1": "last system act events_3 offer_intent BuyEventTickets 1", "system-events_3-request-city-?": "last system act events_3 request city ?", "system-events_3-request-date-?": "last system act events_3 request date ?", "system-events_3-request-event_name-?": "last system act events_3 request event_name ?", "system-events_3-request-event_type-?": "last system act events_3 request event_type ?", "system-events_3-request-number_of_tickets-?": "last system act events_3 request number_of_tickets ?", "system-flights_1-confirm-airlines-1": "last system act flights_1 confirm airlines 1", "system-flights_1-confirm-departure_date-1": "last system act flights_1 confirm departure_date 1", "system-flights_1-confirm-destination_city-1": "last system act flights_1 confirm destination_city 1", "system-flights_1-confirm-inbound_departure_time-1": "last system act flights_1 confirm inbound_departure_time 1", "system-flights_1-confirm-origin_city-1": "last system act flights_1 confirm origin_city 1", "system-flights_1-confirm-outbound_departure_time-1": "last system act flights_1 confirm outbound_departure_time 1", "system-flights_1-confirm-passengers-1": "last system act flights_1 confirm passengers 1", "system-flights_1-confirm-return_date-1": "last system act flights_1 confirm return_date 1", "system-flights_1-confirm-seating_class-1": "last system act flights_1 confirm seating_class 1", "system-flights_1-inform-destination_airport-1": "last system act flights_1 inform destination_airport 1", "system-flights_1-inform-inbound_arrival_time-1": "last system act flights_1 inform inbound_arrival_time 1", "system-flights_1-inform-number_stops-1": "last system act flights_1 inform number_stops 1", "system-flights_1-inform-origin_airport-1": "last system 
act flights_1 inform origin_airport 1", "system-flights_1-inform-outbound_arrival_time-1": "last system act flights_1 inform outbound_arrival_time 1", "system-flights_1-inform-refundable-1": "last system act flights_1 inform refundable 1", "system-flights_1-inform_count-count-1": "last system act flights_1 inform_count count 1", "system-flights_1-notify_failure-none-none": "last system act flights_1 notify_failure none none", "system-flights_1-notify_success-none-none": "last system act flights_1 notify_success none none", "system-flights_1-offer-airlines-1": "last system act flights_1 offer airlines 1", "system-flights_1-offer-inbound_departure_time-1": "last system act flights_1 offer inbound_departure_time 1", "system-flights_1-offer-number_stops-1": "last system act flights_1 offer number_stops 1", "system-flights_1-offer-outbound_departure_time-1": "last system act flights_1 offer outbound_departure_time 1", "system-flights_1-offer-price-1": "last system act flights_1 offer price 1", "system-flights_1-offer_intent-reserveonewayflight-1": "last system act flights_1 offer_intent ReserveOnewayFlight 1", "system-flights_1-offer_intent-reserveroundtripflights-1": "last system act flights_1 offer_intent ReserveRoundtripFlights 1", "system-flights_1-request-airlines-?": "last system act flights_1 request airlines ?", "system-flights_1-request-departure_date-?": "last system act flights_1 request departure_date ?", "system-flights_1-request-destination_city-?": "last system act flights_1 request destination_city ?", "system-flights_1-request-inbound_departure_time-?": "last system act flights_1 request inbound_departure_time ?", "system-flights_1-request-origin_city-?": "last system act flights_1 request origin_city ?", "system-flights_1-request-outbound_departure_time-?": "last system act flights_1 request outbound_departure_time ?", "system-flights_1-request-return_date-?": "last system act flights_1 request return_date ?", 
"system-flights_2-inform-destination_airport-1": "last system act flights_2 inform destination_airport 1", "system-flights_2-inform-is_redeye-1": "last system act flights_2 inform is_redeye 1", "system-flights_2-inform-origin_airport-1": "last system act flights_2 inform origin_airport 1", "system-flights_2-inform-outbound_arrival_time-1": "last system act flights_2 inform outbound_arrival_time 1", "system-flights_2-inform_count-count-1": "last system act flights_2 inform_count count 1", "system-flights_2-offer-airlines-1": "last system act flights_2 offer airlines 1", "system-flights_2-offer-fare-1": "last system act flights_2 offer fare 1", "system-flights_2-offer-inbound_departure_time-1": "last system act flights_2 offer inbound_departure_time 1", "system-flights_2-offer-number_stops-1": "last system act flights_2 offer number_stops 1", "system-flights_2-offer-outbound_departure_time-1": "last system act flights_2 offer outbound_departure_time 1", "system-flights_2-request-departure_date-?": "last system act flights_2 request departure_date ?", "system-flights_2-request-destination-?": "last system act flights_2 request destination ?", "system-flights_2-request-origin-?": "last system act flights_2 request origin ?", "system-flights_2-request-return_date-?": "last system act flights_2 request return_date ?", "system-flights_3-inform-arrives_next_day-1": "last system act flights_3 inform arrives_next_day 1", "system-flights_3-inform-destination_airport_name-1": "last system act flights_3 inform destination_airport_name 1", "system-flights_3-inform-origin_airport_name-1": "last system act flights_3 inform origin_airport_name 1", "system-flights_3-inform-outbound_arrival_time-1": "last system act flights_3 inform outbound_arrival_time 1", "system-flights_3-inform_count-count-1": "last system act flights_3 inform_count count 1", "system-flights_3-offer-airlines-1": "last system act flights_3 offer airlines 1", "system-flights_3-offer-inbound_departure_time-1": 
"last system act flights_3 offer inbound_departure_time 1", "system-flights_3-offer-number_stops-1": "last system act flights_3 offer number_stops 1", "system-flights_3-offer-outbound_departure_time-1": "last system act flights_3 offer outbound_departure_time 1", "system-flights_3-offer-price-1": "last system act flights_3 offer price 1", "system-flights_3-request-departure_date-?": "last system act flights_3 request departure_date ?", "system-flights_3-request-destination_city-?": "last system act flights_3 request destination_city ?", "system-flights_3-request-origin_city-?": "last system act flights_3 request origin_city ?", "system-flights_3-request-return_date-?": "last system act flights_3 request return_date ?", "system-flights_4-inform-inbound_arrival_time-1": "last system act flights_4 inform inbound_arrival_time 1", "system-flights_4-inform-number_of_tickets-1": "last system act flights_4 inform number_of_tickets 1", "system-flights_4-inform-outbound_arrival_time-1": "last system act flights_4 inform outbound_arrival_time 1", "system-flights_4-inform-seating_class-1": "last system act flights_4 inform seating_class 1", "system-flights_4-inform_count-count-1": "last system act flights_4 inform_count count 1", "system-flights_4-offer-airlines-1": "last system act flights_4 offer airlines 1", "system-flights_4-offer-inbound_departure_time-1": "last system act flights_4 offer inbound_departure_time 1", "system-flights_4-offer-is_nonstop-1": "last system act flights_4 offer is_nonstop 1", "system-flights_4-offer-outbound_departure_time-1": "last system act flights_4 offer outbound_departure_time 1", "system-flights_4-offer-price-1": "last system act flights_4 offer price 1", "system-flights_4-request-departure_date-?": "last system act flights_4 request departure_date ?", "system-flights_4-request-destination_airport-?": "last system act flights_4 request destination_airport ?", "system-flights_4-request-origin_airport-?": "last system act flights_4 request 
origin_airport ?", "system-flights_4-request-return_date-?": "last system act flights_4 request return_date ?", "system-homes_1-confirm-property_name-1": "last system act homes_1 confirm property_name 1", "system-homes_1-confirm-visit_date-1": "last system act homes_1 confirm visit_date 1", "system-homes_1-inform-furnished-1": "last system act homes_1 inform furnished 1", "system-homes_1-inform-pets_allowed-1": "last system act homes_1 inform pets_allowed 1", "system-homes_1-inform-phone_number-1": "last system act homes_1 inform phone_number 1", "system-homes_1-inform_count-count-1": "last system act homes_1 inform_count count 1", "system-homes_1-notify_failure-none-none": "last system act homes_1 notify_failure none none", "system-homes_1-notify_success-none-none": "last system act homes_1 notify_success none none", "system-homes_1-offer-address-1": "last system act homes_1 offer address 1", "system-homes_1-offer-number_of_baths-1": "last system act homes_1 offer number_of_baths 1", "system-homes_1-offer-number_of_beds-1": "last system act homes_1 offer number_of_beds 1", "system-homes_1-offer-property_name-1": "last system act homes_1 offer property_name 1", "system-homes_1-offer-rent-1": "last system act homes_1 offer rent 1", "system-homes_1-offer_intent-schedulevisit-1": "last system act homes_1 offer_intent ScheduleVisit 1", "system-homes_1-request-area-?": "last system act homes_1 request area ?", "system-homes_1-request-number_of_beds-?": "last system act homes_1 request number_of_beds ?", "system-homes_1-request-visit_date-?": "last system act homes_1 request visit_date ?", "system-homes_2-confirm-property_name-1": "last system act homes_2 confirm property_name 1", "system-homes_2-confirm-visit_date-1": "last system act homes_2 confirm visit_date 1", "system-homes_2-inform-has_garage-1": "last system act homes_2 inform has_garage 1", "system-homes_2-inform-in_unit_laundry-1": "last system act homes_2 inform in_unit_laundry 1", 
"system-homes_2-inform-phone_number-1": "last system act homes_2 inform phone_number 1", "system-homes_2-inform_count-count-1": "last system act homes_2 inform_count count 1", "system-homes_2-notify_success-none-none": "last system act homes_2 notify_success none none", "system-homes_2-offer-address-1": "last system act homes_2 offer address 1", "system-homes_2-offer-price-1": "last system act homes_2 offer price 1", "system-homes_2-offer-property_name-1": "last system act homes_2 offer property_name 1", "system-homes_2-offer_intent-schedulevisit-1": "last system act homes_2 offer_intent ScheduleVisit 1", "system-homes_2-request-area-?": "last system act homes_2 request area ?", "system-homes_2-request-intent-?": "last system act homes_2 request intent ?", "system-homes_2-request-number_of_baths-?": "last system act homes_2 request number_of_baths ?", "system-homes_2-request-number_of_beds-?": "last system act homes_2 request number_of_beds ?", "system-homes_2-request-visit_date-?": "last system act homes_2 request visit_date ?", "system-hotels_1-confirm-check_in_date-1": "last system act hotels_1 confirm check_in_date 1", "system-hotels_1-confirm-destination-1": "last system act hotels_1 confirm destination 1", "system-hotels_1-confirm-hotel_name-1": "last system act hotels_1 confirm hotel_name 1", "system-hotels_1-confirm-number_of_days-1": "last system act hotels_1 confirm number_of_days 1", "system-hotels_1-confirm-number_of_rooms-1": "last system act hotels_1 confirm number_of_rooms 1", "system-hotels_1-inform-has_wifi-1": "last system act hotels_1 inform has_wifi 1", "system-hotels_1-inform-phone_number-1": "last system act hotels_1 inform phone_number 1", "system-hotels_1-inform-price_per_night-1": "last system act hotels_1 inform price_per_night 1", "system-hotels_1-inform-street_address-1": "last system act hotels_1 inform street_address 1", "system-hotels_1-inform_count-count-1": "last system act hotels_1 inform_count count 1", 
"system-hotels_1-notify_success-none-none": "last system act hotels_1 notify_success none none", "system-hotels_1-offer-hotel_name-1": "last system act hotels_1 offer hotel_name 1", "system-hotels_1-offer-star_rating-1": "last system act hotels_1 offer star_rating 1", "system-hotels_1-offer_intent-reservehotel-1": "last system act hotels_1 offer_intent ReserveHotel 1", "system-hotels_1-request-check_in_date-?": "last system act hotels_1 request check_in_date ?", "system-hotels_1-request-destination-?": "last system act hotels_1 request destination ?", "system-hotels_1-request-hotel_name-?": "last system act hotels_1 request hotel_name ?", "system-hotels_1-request-number_of_days-?": "last system act hotels_1 request number_of_days ?", "system-hotels_2-confirm-check_in_date-1": "last system act hotels_2 confirm check_in_date 1", "system-hotels_2-confirm-check_out_date-1": "last system act hotels_2 confirm check_out_date 1", "system-hotels_2-confirm-number_of_adults-1": "last system act hotels_2 confirm number_of_adults 1", "system-hotels_2-confirm-where_to-1": "last system act hotels_2 confirm where_to 1", "system-hotels_2-inform-has_laundry_service-1": "last system act hotels_2 inform has_laundry_service 1", "system-hotels_2-inform-phone_number-1": "last system act hotels_2 inform phone_number 1", "system-hotels_2-inform-total_price-1": "last system act hotels_2 inform total_price 1", "system-hotels_2-inform_count-count-1": "last system act hotels_2 inform_count count 1", "system-hotels_2-notify_success-none-none": "last system act hotels_2 notify_success none none", "system-hotels_2-offer-address-1": "last system act hotels_2 offer address 1", "system-hotels_2-offer-rating-1": "last system act hotels_2 offer rating 1", "system-hotels_2-offer_intent-bookhouse-1": "last system act hotels_2 offer_intent BookHouse 1", "system-hotels_2-request-check_in_date-?": "last system act hotels_2 request check_in_date ?", "system-hotels_2-request-check_out_date-?": "last system 
act hotels_2 request check_out_date ?", "system-hotels_2-request-number_of_adults-?": "last system act hotels_2 request number_of_adults ?", "system-hotels_2-request-where_to-?": "last system act hotels_2 request where_to ?", "system-hotels_3-confirm-check_in_date-1": "last system act hotels_3 confirm check_in_date 1", "system-hotels_3-confirm-check_out_date-1": "last system act hotels_3 confirm check_out_date 1", "system-hotels_3-confirm-hotel_name-1": "last system act hotels_3 confirm hotel_name 1", "system-hotels_3-confirm-location-1": "last system act hotels_3 confirm location 1", "system-hotels_3-confirm-number_of_rooms-1": "last system act hotels_3 confirm number_of_rooms 1", "system-hotels_3-inform-pets_welcome-1": "last system act hotels_3 inform pets_welcome 1", "system-hotels_3-inform-phone_number-1": "last system act hotels_3 inform phone_number 1", "system-hotels_3-inform-price-1": "last system act hotels_3 inform price 1", "system-hotels_3-inform-street_address-1": "last system act hotels_3 inform street_address 1", "system-hotels_3-inform_count-count-1": "last system act hotels_3 inform_count count 1", "system-hotels_3-notify_success-none-none": "last system act hotels_3 notify_success none none", "system-hotels_3-offer-average_rating-1": "last system act hotels_3 offer average_rating 1", "system-hotels_3-offer-hotel_name-1": "last system act hotels_3 offer hotel_name 1", "system-hotels_3-offer_intent-reservehotel-1": "last system act hotels_3 offer_intent ReserveHotel 1", "system-hotels_3-request-check_in_date-?": "last system act hotels_3 request check_in_date ?", "system-hotels_3-request-check_out_date-?": "last system act hotels_3 request check_out_date ?", "system-hotels_3-request-hotel_name-?": "last system act hotels_3 request hotel_name ?", "system-hotels_3-request-location-?": "last system act hotels_3 request location ?", "system-hotels_4-confirm-check_in_date-1": "last system act hotels_4 confirm check_in_date 1", 
"system-hotels_4-confirm-location-1": "last system act hotels_4 confirm location 1", "system-hotels_4-confirm-number_of_rooms-1": "last system act hotels_4 confirm number_of_rooms 1", "system-hotels_4-confirm-place_name-1": "last system act hotels_4 confirm place_name 1", "system-hotels_4-confirm-stay_length-1": "last system act hotels_4 confirm stay_length 1", "system-hotels_4-inform-phone_number-1": "last system act hotels_4 inform phone_number 1", "system-hotels_4-inform-price_per_night-1": "last system act hotels_4 inform price_per_night 1", "system-hotels_4-inform-smoking_allowed-1": "last system act hotels_4 inform smoking_allowed 1", "system-hotels_4-inform-street_address-1": "last system act hotels_4 inform street_address 1", "system-hotels_4-inform_count-count-1": "last system act hotels_4 inform_count count 1", "system-hotels_4-notify_success-none-none": "last system act hotels_4 notify_success none none", "system-hotels_4-offer-place_name-1": "last system act hotels_4 offer place_name 1", "system-hotels_4-offer-star_rating-1": "last system act hotels_4 offer star_rating 1", "system-hotels_4-offer_intent-reservehotel-1": "last system act hotels_4 offer_intent ReserveHotel 1", "system-hotels_4-request-check_in_date-?": "last system act hotels_4 request check_in_date ?", "system-hotels_4-request-location-?": "last system act hotels_4 request location ?", "system-hotels_4-request-stay_length-?": "last system act hotels_4 request stay_length ?", "system-media_1-confirm-subtitles-1": "last system act media_1 confirm subtitles 1", "system-media_1-confirm-title-1": "last system act media_1 confirm title 1", "system-media_1-inform-directed_by-1": "last system act media_1 inform directed_by 1", "system-media_1-inform-genre-1": "last system act media_1 inform genre 1", "system-media_1-inform_count-count-1": "last system act media_1 inform_count count 1", "system-media_1-notify_failure-none-none": "last system act media_1 notify_failure none none", 
"system-media_1-notify_success-none-none": "last system act media_1 notify_success none none", "system-media_1-offer-title-1": "last system act media_1 offer title 1", "system-media_1-offer-title-2": "last system act media_1 offer title 2", "system-media_1-offer-title-3": "last system act media_1 offer title 3", "system-media_1-offer_intent-playmovie-1": "last system act media_1 offer_intent PlayMovie 1", "system-media_1-request-genre-?": "last system act media_1 request genre ?", "system-media_1-request-title-?": "last system act media_1 request title ?", "system-media_2-confirm-movie_name-1": "last system act media_2 confirm movie_name 1", "system-media_2-confirm-subtitle_language-1": "last system act media_2 confirm subtitle_language 1", "system-media_2-inform-price-1": "last system act media_2 inform price 1", "system-media_2-inform_count-count-1": "last system act media_2 inform_count count 1", "system-media_2-notify_success-none-none": "last system act media_2 notify_success none none", "system-media_2-offer-movie_name-1": "last system act media_2 offer movie_name 1", "system-media_2-offer-movie_name-2": "last system act media_2 offer movie_name 2", "system-media_2-offer-movie_name-3": "last system act media_2 offer movie_name 3", "system-media_2-offer_intent-rentmovie-1": "last system act media_2 offer_intent RentMovie 1", "system-media_2-request-genre-?": "last system act media_2 request genre ?", "system-media_3-confirm-subtitle_language-1": "last system act media_3 confirm subtitle_language 1", "system-media_3-confirm-title-1": "last system act media_3 confirm title 1", "system-media_3-inform-starring-1": "last system act media_3 inform starring 1", "system-media_3-inform_count-count-1": "last system act media_3 inform_count count 1", "system-media_3-notify_success-none-none": "last system act media_3 notify_success none none", "system-media_3-offer-title-1": "last system act media_3 offer title 1", "system-media_3-offer-title-2": "last system act media_3 
offer title 2", "system-media_3-offer-title-3": "last system act media_3 offer title 3", "system-media_3-offer_intent-playmovie-1": "last system act media_3 offer_intent PlayMovie 1", "system-media_3-request-genre-?": "last system act media_3 request genre ?", "system-messaging_1-confirm-contact_name-1": "last system act messaging_1 confirm contact_name 1", "system-messaging_1-confirm-location-1": "last system act messaging_1 confirm location 1", "system-messaging_1-notify_success-none-none": "last system act messaging_1 notify_success none none", "system-messaging_1-request-contact_name-?": "last system act messaging_1 request contact_name ?", "system-messaging_1-request-location-?": "last system act messaging_1 request location ?", "system-movies_1-confirm-location-1": "last system act movies_1 confirm location 1", "system-movies_1-confirm-movie_name-1": "last system act movies_1 confirm movie_name 1", "system-movies_1-confirm-number_of_tickets-1": "last system act movies_1 confirm number_of_tickets 1", "system-movies_1-confirm-show_date-1": "last system act movies_1 confirm show_date 1", "system-movies_1-confirm-show_time-1": "last system act movies_1 confirm show_time 1", "system-movies_1-confirm-show_type-1": "last system act movies_1 confirm show_type 1", "system-movies_1-inform-genre-1": "last system act movies_1 inform genre 1", "system-movies_1-inform-price-1": "last system act movies_1 inform price 1", "system-movies_1-inform-street_address-1": "last system act movies_1 inform street_address 1", "system-movies_1-inform_count-count-1": "last system act movies_1 inform_count count 1", "system-movies_1-notify_failure-none-none": "last system act movies_1 notify_failure none none", "system-movies_1-notify_success-none-none": "last system act movies_1 notify_success none none", "system-movies_1-offer-movie_name-1": "last system act movies_1 offer movie_name 1", "system-movies_1-offer-movie_name-2": "last system act movies_1 offer movie_name 2", 
"system-movies_1-offer-movie_name-3": "last system act movies_1 offer movie_name 3", "system-movies_1-offer-show_time-1": "last system act movies_1 offer show_time 1", "system-movies_1-offer-theater_name-1": "last system act movies_1 offer theater_name 1", "system-movies_1-offer_intent-buymovietickets-1": "last system act movies_1 offer_intent BuyMovieTickets 1", "system-movies_1-request-location-?": "last system act movies_1 request location ?", "system-movies_1-request-movie_name-?": "last system act movies_1 request movie_name ?", "system-movies_1-request-number_of_tickets-?": "last system act movies_1 request number_of_tickets ?", "system-movies_1-request-show_date-?": "last system act movies_1 request show_date ?", "system-movies_1-request-show_time-?": "last system act movies_1 request show_time ?", "system-movies_1-request-show_type-?": "last system act movies_1 request show_type ?", "system-movies_2-inform_count-count-1": "last system act movies_2 inform_count count 1", "system-movies_2-offer-aggregate_rating-1": "last system act movies_2 offer aggregate_rating 1", "system-movies_2-offer-title-1": "last system act movies_2 offer title 1", "system-movies_3-inform-cast-1": "last system act movies_3 inform cast 1", "system-movies_3-inform-directed_by-1": "last system act movies_3 inform directed_by 1", "system-movies_3-inform-genre-1": "last system act movies_3 inform genre 1", "system-movies_3-inform_count-count-1": "last system act movies_3 inform_count count 1", "system-movies_3-offer-movie_title-1": "last system act movies_3 offer movie_title 1", "system-movies_3-offer-percent_rating-1": "last system act movies_3 offer percent_rating 1", "system-music_1-confirm-playback_device-1": "last system act music_1 confirm playback_device 1", "system-music_1-confirm-song_name-1": "last system act music_1 confirm song_name 1", "system-music_1-inform-album-1": "last system act music_1 inform album 1", "system-music_1-inform-genre-1": "last system act music_1 inform 
genre 1", "system-music_1-inform-year-1": "last system act music_1 inform year 1", "system-music_1-inform_count-count-1": "last system act music_1 inform_count count 1", "system-music_1-notify_success-none-none": "last system act music_1 notify_success none none", "system-music_1-offer-album-1": "last system act music_1 offer album 1", "system-music_1-offer-artist-1": "last system act music_1 offer artist 1", "system-music_1-offer-song_name-1": "last system act music_1 offer song_name 1", "system-music_1-offer_intent-playsong-1": "last system act music_1 offer_intent PlaySong 1", "system-music_1-request-song_name-?": "last system act music_1 request song_name ?", "system-music_2-confirm-playback_device-1": "last system act music_2 confirm playback_device 1", "system-music_2-confirm-song_name-1": "last system act music_2 confirm song_name 1", "system-music_2-inform-genre-1": "last system act music_2 inform genre 1", "system-music_2-inform_count-count-1": "last system act music_2 inform_count count 1", "system-music_2-notify_success-none-none": "last system act music_2 notify_success none none", "system-music_2-offer-album-1": "last system act music_2 offer album 1", "system-music_2-offer-artist-1": "last system act music_2 offer artist 1", "system-music_2-offer-song_name-1": "last system act music_2 offer song_name 1", "system-music_2-offer_intent-playmedia-1": "last system act music_2 offer_intent PlayMedia 1", "system-music_2-request-song_name-?": "last system act music_2 request song_name ?", "system-music_3-confirm-device-1": "last system act music_3 confirm device 1", "system-music_3-confirm-track-1": "last system act music_3 confirm track 1", "system-music_3-inform-genre-1": "last system act music_3 inform genre 1", "system-music_3-inform-year-1": "last system act music_3 inform year 1", "system-music_3-inform_count-count-1": "last system act music_3 inform_count count 1", "system-music_3-notify_success-none-none": "last system act music_3 notify_success none 
none", "system-music_3-offer-album-1": "last system act music_3 offer album 1", "system-music_3-offer-artist-1": "last system act music_3 offer artist 1", "system-music_3-offer-track-1": "last system act music_3 offer track 1", "system-music_3-offer_intent-playmedia-1": "last system act music_3 offer_intent PlayMedia 1", "system-payment_1-confirm-amount-1": "last system act payment_1 confirm amount 1", "system-payment_1-confirm-payment_method-1": "last system act payment_1 confirm payment_method 1", "system-payment_1-confirm-private_visibility-1": "last system act payment_1 confirm private_visibility 1", "system-payment_1-confirm-receiver-1": "last system act payment_1 confirm receiver 1", "system-payment_1-notify_success-none-none": "last system act payment_1 notify_success none none", "system-payment_1-request-amount-?": "last system act payment_1 request amount ?", "system-payment_1-request-payment_method-?": "last system act payment_1 request payment_method ?", "system-payment_1-request-receiver-?": "last system act payment_1 request receiver ?", "system-rentalcars_1-confirm-dropoff_date-1": "last system act rentalcars_1 confirm dropoff_date 1", "system-rentalcars_1-confirm-pickup_date-1": "last system act rentalcars_1 confirm pickup_date 1", "system-rentalcars_1-confirm-pickup_location-1": "last system act rentalcars_1 confirm pickup_location 1", "system-rentalcars_1-confirm-pickup_time-1": "last system act rentalcars_1 confirm pickup_time 1", "system-rentalcars_1-confirm-type-1": "last system act rentalcars_1 confirm type 1", "system-rentalcars_1-inform-car_name-1": "last system act rentalcars_1 inform car_name 1", "system-rentalcars_1-inform-total_price-1": "last system act rentalcars_1 inform total_price 1", "system-rentalcars_1-inform_count-count-1": "last system act rentalcars_1 inform_count count 1", "system-rentalcars_1-notify_success-none-none": "last system act rentalcars_1 notify_success none none", "system-rentalcars_1-offer-car_name-1": "last 
system act rentalcars_1 offer car_name 1", "system-rentalcars_1-offer-pickup_date-1": "last system act rentalcars_1 offer pickup_date 1", "system-rentalcars_1-offer-pickup_location-1": "last system act rentalcars_1 offer pickup_location 1", "system-rentalcars_1-offer-type-1": "last system act rentalcars_1 offer type 1", "system-rentalcars_1-offer_intent-reservecar-1": "last system act rentalcars_1 offer_intent ReserveCar 1", "system-rentalcars_1-request-dropoff_date-?": "last system act rentalcars_1 request dropoff_date ?", "system-rentalcars_1-request-pickup_city-?": "last system act rentalcars_1 request pickup_city ?", "system-rentalcars_1-request-pickup_date-?": "last system act rentalcars_1 request pickup_date ?", "system-rentalcars_1-request-pickup_location-?": "last system act rentalcars_1 request pickup_location ?", "system-rentalcars_1-request-pickup_time-?": "last system act rentalcars_1 request pickup_time ?", "system-rentalcars_1-request-type-?": "last system act rentalcars_1 request type ?", "system-rentalcars_2-confirm-car_type-1": "last system act rentalcars_2 confirm car_type 1", "system-rentalcars_2-confirm-dropoff_date-1": "last system act rentalcars_2 confirm dropoff_date 1", "system-rentalcars_2-confirm-pickup_date-1": "last system act rentalcars_2 confirm pickup_date 1", "system-rentalcars_2-confirm-pickup_location-1": "last system act rentalcars_2 confirm pickup_location 1", "system-rentalcars_2-confirm-pickup_time-1": "last system act rentalcars_2 confirm pickup_time 1", "system-rentalcars_2-inform-car_name-1": "last system act rentalcars_2 inform car_name 1", "system-rentalcars_2-inform-total_price-1": "last system act rentalcars_2 inform total_price 1", "system-rentalcars_2-inform_count-count-1": "last system act rentalcars_2 inform_count count 1", "system-rentalcars_2-notify_success-none-none": "last system act rentalcars_2 notify_success none none", "system-rentalcars_2-offer-car_name-1": "last system act rentalcars_2 offer car_name 1", 
"system-rentalcars_2-offer-car_type-1": "last system act rentalcars_2 offer car_type 1", "system-rentalcars_2-offer-pickup_date-1": "last system act rentalcars_2 offer pickup_date 1", "system-rentalcars_2-offer-pickup_location-1": "last system act rentalcars_2 offer pickup_location 1", "system-rentalcars_2-offer_intent-reservecar-1": "last system act rentalcars_2 offer_intent ReserveCar 1", "system-rentalcars_2-request-car_type-?": "last system act rentalcars_2 request car_type ?", "system-rentalcars_2-request-dropoff_date-?": "last system act rentalcars_2 request dropoff_date ?", "system-rentalcars_2-request-pickup_city-?": "last system act rentalcars_2 request pickup_city ?", "system-rentalcars_2-request-pickup_date-?": "last system act rentalcars_2 request pickup_date ?", "system-rentalcars_2-request-pickup_location-?": "last system act rentalcars_2 request pickup_location ?", "system-rentalcars_2-request-pickup_time-?": "last system act rentalcars_2 request pickup_time ?", "system-rentalcars_3-confirm-add_insurance-1": "last system act rentalcars_3 confirm add_insurance 1", "system-rentalcars_3-confirm-car_type-1": "last system act rentalcars_3 confirm car_type 1", "system-rentalcars_3-confirm-end_date-1": "last system act rentalcars_3 confirm end_date 1", "system-rentalcars_3-confirm-pickup_location-1": "last system act rentalcars_3 confirm pickup_location 1", "system-rentalcars_3-confirm-pickup_time-1": "last system act rentalcars_3 confirm pickup_time 1", "system-rentalcars_3-confirm-start_date-1": "last system act rentalcars_3 confirm start_date 1", "system-rentalcars_3-inform-car_name-1": "last system act rentalcars_3 inform car_name 1", "system-rentalcars_3-inform-price_per_day-1": "last system act rentalcars_3 inform price_per_day 1", "system-rentalcars_3-inform_count-count-1": "last system act rentalcars_3 inform_count count 1", "system-rentalcars_3-notify_success-none-none": "last system act rentalcars_3 notify_success none none", 
"system-rentalcars_3-offer-car_name-1": "last system act rentalcars_3 offer car_name 1", "system-rentalcars_3-offer-car_type-1": "last system act rentalcars_3 offer car_type 1", "system-rentalcars_3-offer-pickup_location-1": "last system act rentalcars_3 offer pickup_location 1", "system-rentalcars_3-offer_intent-reservecar-1": "last system act rentalcars_3 offer_intent ReserveCar 1", "system-rentalcars_3-request-add_insurance-?": "last system act rentalcars_3 request add_insurance ?", "system-rentalcars_3-request-car_type-?": "last system act rentalcars_3 request car_type ?", "system-rentalcars_3-request-city-?": "last system act rentalcars_3 request city ?", "system-rentalcars_3-request-end_date-?": "last system act rentalcars_3 request end_date ?", "system-rentalcars_3-request-pickup_location-?": "last system act rentalcars_3 request pickup_location ?", "system-rentalcars_3-request-pickup_time-?": "last system act rentalcars_3 request pickup_time ?", "system-rentalcars_3-request-start_date-?": "last system act rentalcars_3 request start_date ?", "system-restaurants_1-confirm-city-1": "last system act restaurants_1 confirm city 1", "system-restaurants_1-confirm-date-1": "last system act restaurants_1 confirm date 1", "system-restaurants_1-confirm-party_size-1": "last system act restaurants_1 confirm party_size 1", "system-restaurants_1-confirm-restaurant_name-1": "last system act restaurants_1 confirm restaurant_name 1", "system-restaurants_1-confirm-time-1": "last system act restaurants_1 confirm time 1", "system-restaurants_1-inform-cuisine-1": "last system act restaurants_1 inform cuisine 1", "system-restaurants_1-inform-has_live_music-1": "last system act restaurants_1 inform has_live_music 1", "system-restaurants_1-inform-phone_number-1": "last system act restaurants_1 inform phone_number 1", "system-restaurants_1-inform-price_range-1": "last system act restaurants_1 inform price_range 1", "system-restaurants_1-inform-serves_alcohol-1": "last system act 
restaurants_1 inform serves_alcohol 1", "system-restaurants_1-inform-street_address-1": "last system act restaurants_1 inform street_address 1", "system-restaurants_1-inform_count-count-1": "last system act restaurants_1 inform_count count 1", "system-restaurants_1-notify_failure-none-none": "last system act restaurants_1 notify_failure none none", "system-restaurants_1-notify_success-none-none": "last system act restaurants_1 notify_success none none", "system-restaurants_1-offer-city-1": "last system act restaurants_1 offer city 1", "system-restaurants_1-offer-date-1": "last system act restaurants_1 offer date 1", "system-restaurants_1-offer-party_size-1": "last system act restaurants_1 offer party_size 1", "system-restaurants_1-offer-restaurant_name-1": "last system act restaurants_1 offer restaurant_name 1", "system-restaurants_1-offer-time-1": "last system act restaurants_1 offer time 1", "system-restaurants_1-offer_intent-reserverestaurant-1": "last system act restaurants_1 offer_intent ReserveRestaurant 1", "system-restaurants_1-request-city-?": "last system act restaurants_1 request city ?", "system-restaurants_1-request-cuisine-?": "last system act restaurants_1 request cuisine ?", "system-restaurants_1-request-restaurant_name-?": "last system act restaurants_1 request restaurant_name ?", "system-restaurants_1-request-time-?": "last system act restaurants_1 request time ?", "system-restaurants_2-confirm-date-1": "last system act restaurants_2 confirm date 1", "system-restaurants_2-confirm-location-1": "last system act restaurants_2 confirm location 1", "system-restaurants_2-confirm-number_of_seats-1": "last system act restaurants_2 confirm number_of_seats 1", "system-restaurants_2-confirm-restaurant_name-1": "last system act restaurants_2 confirm restaurant_name 1", "system-restaurants_2-confirm-time-1": "last system act restaurants_2 confirm time 1", "system-restaurants_2-inform-address-1": "last system act restaurants_2 inform address 1", 
"system-restaurants_2-inform-has_seating_outdoors-1": "last system act restaurants_2 inform has_seating_outdoors 1", "system-restaurants_2-inform-has_vegetarian_options-1": "last system act restaurants_2 inform has_vegetarian_options 1", "system-restaurants_2-inform-phone_number-1": "last system act restaurants_2 inform phone_number 1", "system-restaurants_2-inform-price_range-1": "last system act restaurants_2 inform price_range 1", "system-restaurants_2-inform-rating-1": "last system act restaurants_2 inform rating 1", "system-restaurants_2-inform_count-count-1": "last system act restaurants_2 inform_count count 1", "system-restaurants_2-notify_failure-none-none": "last system act restaurants_2 notify_failure none none", "system-restaurants_2-notify_success-none-none": "last system act restaurants_2 notify_success none none", "system-restaurants_2-offer-date-1": "last system act restaurants_2 offer date 1", "system-restaurants_2-offer-location-1": "last system act restaurants_2 offer location 1", "system-restaurants_2-offer-number_of_seats-1": "last system act restaurants_2 offer number_of_seats 1", "system-restaurants_2-offer-restaurant_name-1": "last system act restaurants_2 offer restaurant_name 1", "system-restaurants_2-offer-time-1": "last system act restaurants_2 offer time 1", "system-restaurants_2-offer_intent-reserverestaurant-1": "last system act restaurants_2 offer_intent ReserveRestaurant 1", "system-restaurants_2-request-category-?": "last system act restaurants_2 request category ?", "system-restaurants_2-request-location-?": "last system act restaurants_2 request location ?", "system-restaurants_2-request-restaurant_name-?": "last system act restaurants_2 request restaurant_name ?", "system-restaurants_2-request-time-?": "last system act restaurants_2 request time ?", "system-ridesharing_1-confirm-destination-1": "last system act ridesharing_1 confirm destination 1", "system-ridesharing_1-confirm-number_of_riders-1": "last system act ridesharing_1 
confirm number_of_riders 1", "system-ridesharing_1-confirm-shared_ride-1": "last system act ridesharing_1 confirm shared_ride 1", "system-ridesharing_1-inform-approximate_ride_duration-1": "last system act ridesharing_1 inform approximate_ride_duration 1", "system-ridesharing_1-inform-ride_fare-1": "last system act ridesharing_1 inform ride_fare 1", "system-ridesharing_1-notify_success-none-none": "last system act ridesharing_1 notify_success none none", "system-ridesharing_1-request-destination-?": "last system act ridesharing_1 request destination ?", "system-ridesharing_1-request-number_of_riders-?": "last system act ridesharing_1 request number_of_riders ?", "system-ridesharing_1-request-shared_ride-?": "last system act ridesharing_1 request shared_ride ?", "system-ridesharing_2-confirm-destination-1": "last system act ridesharing_2 confirm destination 1", "system-ridesharing_2-confirm-number_of_seats-1": "last system act ridesharing_2 confirm number_of_seats 1", "system-ridesharing_2-confirm-ride_type-1": "last system act ridesharing_2 confirm ride_type 1", "system-ridesharing_2-inform-ride_fare-1": "last system act ridesharing_2 inform ride_fare 1", "system-ridesharing_2-inform-wait_time-1": "last system act ridesharing_2 inform wait_time 1", "system-ridesharing_2-notify_success-none-none": "last system act ridesharing_2 notify_success none none", "system-ridesharing_2-request-destination-?": "last system act ridesharing_2 request destination ?", "system-ridesharing_2-request-number_of_seats-?": "last system act ridesharing_2 request number_of_seats ?", "system-ridesharing_2-request-ride_type-?": "last system act ridesharing_2 request ride_type ?", "system-services_1-confirm-appointment_date-1": "last system act services_1 confirm appointment_date 1", "system-services_1-confirm-appointment_time-1": "last system act services_1 confirm appointment_time 1", "system-services_1-confirm-stylist_name-1": "last system act services_1 confirm stylist_name 1", 
"system-services_1-inform-average_rating-1": "last system act services_1 inform average_rating 1", "system-services_1-inform-is_unisex-1": "last system act services_1 inform is_unisex 1", "system-services_1-inform-phone_number-1": "last system act services_1 inform phone_number 1", "system-services_1-inform-street_address-1": "last system act services_1 inform street_address 1", "system-services_1-inform_count-count-1": "last system act services_1 inform_count count 1", "system-services_1-notify_failure-none-none": "last system act services_1 notify_failure none none", "system-services_1-notify_success-none-none": "last system act services_1 notify_success none none", "system-services_1-offer-appointment_date-1": "last system act services_1 offer appointment_date 1", "system-services_1-offer-appointment_time-1": "last system act services_1 offer appointment_time 1", "system-services_1-offer-city-1": "last system act services_1 offer city 1", "system-services_1-offer-stylist_name-1": "last system act services_1 offer stylist_name 1", "system-services_1-offer_intent-bookappointment-1": "last system act services_1 offer_intent BookAppointment 1", "system-services_1-request-appointment_date-?": "last system act services_1 request appointment_date ?", "system-services_1-request-appointment_time-?": "last system act services_1 request appointment_time ?", "system-services_1-request-city-?": "last system act services_1 request city ?", "system-services_2-confirm-appointment_date-1": "last system act services_2 confirm appointment_date 1", "system-services_2-confirm-appointment_time-1": "last system act services_2 confirm appointment_time 1", "system-services_2-confirm-dentist_name-1": "last system act services_2 confirm dentist_name 1", "system-services_2-inform-address-1": "last system act services_2 inform address 1", "system-services_2-inform-offers_cosmetic_services-1": "last system act services_2 inform offers_cosmetic_services 1", 
"system-services_2-inform-phone_number-1": "last system act services_2 inform phone_number 1", "system-services_2-inform_count-count-1": "last system act services_2 inform_count count 1", "system-services_2-notify_failure-none-none": "last system act services_2 notify_failure none none", "system-services_2-notify_success-none-none": "last system act services_2 notify_success none none", "system-services_2-offer-appointment_date-1": "last system act services_2 offer appointment_date 1", "system-services_2-offer-appointment_time-1": "last system act services_2 offer appointment_time 1", "system-services_2-offer-city-1": "last system act services_2 offer city 1", "system-services_2-offer-dentist_name-1": "last system act services_2 offer dentist_name 1", "system-services_2-offer_intent-bookappointment-1": "last system act services_2 offer_intent BookAppointment 1", "system-services_2-request-appointment_date-?": "last system act services_2 request appointment_date ?", "system-services_2-request-appointment_time-?": "last system act services_2 request appointment_time ?", "system-services_2-request-city-?": "last system act services_2 request city ?", "system-services_3-confirm-appointment_date-1": "last system act services_3 confirm appointment_date 1", "system-services_3-confirm-appointment_time-1": "last system act services_3 confirm appointment_time 1", "system-services_3-confirm-doctor_name-1": "last system act services_3 confirm doctor_name 1", "system-services_3-inform-average_rating-1": "last system act services_3 inform average_rating 1", "system-services_3-inform-phone_number-1": "last system act services_3 inform phone_number 1", "system-services_3-inform-street_address-1": "last system act services_3 inform street_address 1", "system-services_3-inform_count-count-1": "last system act services_3 inform_count count 1", "system-services_3-notify_failure-none-none": "last system act services_3 notify_failure none none", 
"system-services_3-notify_success-none-none": "last system act services_3 notify_success none none", "system-services_3-offer-appointment_date-1": "last system act services_3 offer appointment_date 1", "system-services_3-offer-appointment_time-1": "last system act services_3 offer appointment_time 1", "system-services_3-offer-city-1": "last system act services_3 offer city 1", "system-services_3-offer-doctor_name-1": "last system act services_3 offer doctor_name 1", "system-services_3-offer-type-1": "last system act services_3 offer type 1", "system-services_3-offer_intent-bookappointment-1": "last system act services_3 offer_intent BookAppointment 1", "system-services_3-request-appointment_date-?": "last system act services_3 request appointment_date ?", "system-services_3-request-appointment_time-?": "last system act services_3 request appointment_time ?", "system-services_3-request-city-?": "last system act services_3 request city ?", "system-services_3-request-type-?": "last system act services_3 request type ?", "system-services_4-confirm-appointment_date-1": "last system act services_4 confirm appointment_date 1", "system-services_4-confirm-appointment_time-1": "last system act services_4 confirm appointment_time 1", "system-services_4-confirm-therapist_name-1": "last system act services_4 confirm therapist_name 1", "system-services_4-inform-address-1": "last system act services_4 inform address 1", "system-services_4-inform-phone_number-1": "last system act services_4 inform phone_number 1", "system-services_4-inform_count-count-1": "last system act services_4 inform_count count 1", "system-services_4-notify_failure-none-none": "last system act services_4 notify_failure none none", "system-services_4-notify_success-none-none": "last system act services_4 notify_success none none", "system-services_4-offer-appointment_date-1": "last system act services_4 offer appointment_date 1", "system-services_4-offer-appointment_time-1": "last system act services_4 offer 
appointment_time 1", "system-services_4-offer-city-1": "last system act services_4 offer city 1", "system-services_4-offer-therapist_name-1": "last system act services_4 offer therapist_name 1", "system-services_4-offer-type-1": "last system act services_4 offer type 1", "system-services_4-offer_intent-bookappointment-1": "last system act services_4 offer_intent BookAppointment 1", "system-services_4-request-appointment_date-?": "last system act services_4 request appointment_date ?", "system-services_4-request-appointment_time-?": "last system act services_4 request appointment_time ?", "system-services_4-request-city-?": "last system act services_4 request city ?", "system-services_4-request-type-?": "last system act services_4 request type ?", "system-trains_1-confirm-class-1": "last system act trains_1 confirm class 1", "system-trains_1-confirm-date_of_journey-1": "last system act trains_1 confirm date_of_journey 1", "system-trains_1-confirm-from-1": "last system act trains_1 confirm from 1", "system-trains_1-confirm-journey_start_time-1": "last system act trains_1 confirm journey_start_time 1", "system-trains_1-confirm-number_of_adults-1": "last system act trains_1 confirm number_of_adults 1", "system-trains_1-confirm-to-1": "last system act trains_1 confirm to 1", "system-trains_1-confirm-trip_protection-1": "last system act trains_1 confirm trip_protection 1", "system-trains_1-inform-from_station-1": "last system act trains_1 inform from_station 1", "system-trains_1-inform-to_station-1": "last system act trains_1 inform to_station 1", "system-trains_1-inform_count-count-1": "last system act trains_1 inform_count count 1", "system-trains_1-notify_success-none-none": "last system act trains_1 notify_success none none", "system-trains_1-offer-journey_start_time-1": "last system act trains_1 offer journey_start_time 1", "system-trains_1-offer-total-1": "last system act trains_1 offer total 1", "system-trains_1-offer_intent-gettraintickets-1": "last system act 
trains_1 offer_intent GetTrainTickets 1", "system-trains_1-request-date_of_journey-?": "last system act trains_1 request date_of_journey ?", "system-trains_1-request-from-?": "last system act trains_1 request from ?", "system-trains_1-request-number_of_adults-?": "last system act trains_1 request number_of_adults ?", "system-trains_1-request-to-?": "last system act trains_1 request to ?", "system-trains_1-request-trip_protection-?": "last system act trains_1 request trip_protection ?", "system-travel_1-inform-free_entry-1": "last system act travel_1 inform free_entry 1", "system-travel_1-inform-good_for_kids-1": "last system act travel_1 inform good_for_kids 1", "system-travel_1-inform-phone_number-1": "last system act travel_1 inform phone_number 1", "system-travel_1-inform_count-count-1": "last system act travel_1 inform_count count 1", "system-travel_1-offer-attraction_name-1": "last system act travel_1 offer attraction_name 1", "system-travel_1-offer-category-1": "last system act travel_1 offer category 1", "system-travel_1-request-location-?": "last system act travel_1 request location ?", "system-weather_1-inform-humidity-1": "last system act weather_1 inform humidity 1", "system-weather_1-inform-wind-1": "last system act weather_1 inform wind 1", "system-weather_1-offer-precipitation-1": "last system act weather_1 offer precipitation 1", "system-weather_1-offer-temperature-1": "last system act weather_1 offer temperature 1", "system-weather_1-request-city-?": "last system act weather_1 request city ?", "user--affirm-none-none": "user act  affirm none none", "user--goodbye-none-none": "user act  goodbye none none", "user--negate-none-none": "user act  negate none none", "user--thank_you-none-none": "user act  thank_you none none", "user-alarm_1-affirm_intent-none-none": "user act alarm_1 affirm_intent none none", "user-alarm_1-inform-new_alarm_name-1": "user act alarm_1 inform new_alarm_name 1", "user-alarm_1-inform-new_alarm_time-1": "user act alarm_1 inform 
new_alarm_time 1", "user-alarm_1-inform_intent-addalarm-1": "user act alarm_1 inform_intent addalarm 1", "user-alarm_1-inform_intent-getalarms-1": "user act alarm_1 inform_intent getalarms 1", "user-alarm_1-negate_intent-none-none": "user act alarm_1 negate_intent none none", "user-alarm_1-select-none-none": "user act alarm_1 select none none", "user-banks_1-inform-account_type-1": "user act banks_1 inform account_type 1", "user-banks_1-inform-amount-1": "user act banks_1 inform amount 1", "user-banks_1-inform-recipient_account_name-1": "user act banks_1 inform recipient_account_name 1", "user-banks_1-inform-recipient_account_type-1": "user act banks_1 inform recipient_account_type 1", "user-banks_1-inform_intent-checkbalance-1": "user act banks_1 inform_intent checkbalance 1", "user-banks_1-inform_intent-transfermoney-1": "user act banks_1 inform_intent transfermoney 1", "user-banks_1-negate_intent-none-none": "user act banks_1 negate_intent none none", "user-banks_1-request_alts-none-none": "user act banks_1 request_alts none none", "user-banks_1-select-none-none": "user act banks_1 select none none", "user-banks_2-affirm_intent-none-none": "user act banks_2 affirm_intent none none", "user-banks_2-inform-account_type-1": "user act banks_2 inform account_type 1", "user-banks_2-inform-recipient_account_type-1": "user act banks_2 inform recipient_account_type 1", "user-banks_2-inform-recipient_name-1": "user act banks_2 inform recipient_name 1", "user-banks_2-inform-transfer_amount-1": "user act banks_2 inform transfer_amount 1", "user-banks_2-inform_intent-checkbalance-1": "user act banks_2 inform_intent checkbalance 1", "user-banks_2-inform_intent-transfermoney-1": "user act banks_2 inform_intent transfermoney 1", "user-banks_2-negate_intent-none-none": "user act banks_2 negate_intent none none", "user-banks_2-request-transfer_time-?": "user act banks_2 request transfer_time ?", "user-banks_2-request_alts-none-none": "user act banks_2 request_alts none none", 
"user-banks_2-select-none-none": "user act banks_2 select none none", "user-buses_1-affirm_intent-none-none": "user act buses_1 affirm_intent none none", "user-buses_1-inform-from_location-1": "user act buses_1 inform from_location 1", "user-buses_1-inform-leaving_date-1": "user act buses_1 inform leaving_date 1", "user-buses_1-inform-leaving_time-1": "user act buses_1 inform leaving_time 1", "user-buses_1-inform-to_location-1": "user act buses_1 inform to_location 1", "user-buses_1-inform-travelers-1": "user act buses_1 inform travelers 1", "user-buses_1-inform_intent-buybusticket-1": "user act buses_1 inform_intent buybusticket 1", "user-buses_1-inform_intent-findbus-1": "user act buses_1 inform_intent findbus 1", "user-buses_1-negate_intent-none-none": "user act buses_1 negate_intent none none", "user-buses_1-request-from_station-?": "user act buses_1 request from_station ?", "user-buses_1-request-to_station-?": "user act buses_1 request to_station ?", "user-buses_1-request-transfers-?": "user act buses_1 request transfers ?", "user-buses_1-request_alts-none-none": "user act buses_1 request_alts none none", "user-buses_1-select-none-none": "user act buses_1 select none none", "user-buses_2-affirm_intent-none-none": "user act buses_2 affirm_intent none none", "user-buses_2-inform-departure_date-1": "user act buses_2 inform departure_date 1", "user-buses_2-inform-departure_time-1": "user act buses_2 inform departure_time 1", "user-buses_2-inform-destination-1": "user act buses_2 inform destination 1", "user-buses_2-inform-fare_type-1": "user act buses_2 inform fare_type 1", "user-buses_2-inform-group_size-1": "user act buses_2 inform group_size 1", "user-buses_2-inform-origin-1": "user act buses_2 inform origin 1", "user-buses_2-inform_intent-buybusticket-1": "user act buses_2 inform_intent buybusticket 1", "user-buses_2-inform_intent-findbus-1": "user act buses_2 inform_intent findbus 1", "user-buses_2-negate_intent-none-none": "user act buses_2 negate_intent 
none none", "user-buses_2-request-destination_station_name-?": "user act buses_2 request destination_station_name ?", "user-buses_2-request-origin_station_name-?": "user act buses_2 request origin_station_name ?", "user-buses_2-request-price-?": "user act buses_2 request price ?", "user-buses_2-request_alts-none-none": "user act buses_2 request_alts none none", "user-buses_2-select-none-none": "user act buses_2 select none none", "user-buses_3-inform-additional_luggage-1": "user act buses_3 inform additional_luggage 1", "user-buses_3-inform-category-1": "user act buses_3 inform category 1", "user-buses_3-inform-departure_date-1": "user act buses_3 inform departure_date 1", "user-buses_3-inform-departure_time-1": "user act buses_3 inform departure_time 1", "user-buses_3-inform-from_city-1": "user act buses_3 inform from_city 1", "user-buses_3-inform-num_passengers-1": "user act buses_3 inform num_passengers 1", "user-buses_3-inform-to_city-1": "user act buses_3 inform to_city 1", "user-buses_3-inform_intent-buybusticket-1": "user act buses_3 inform_intent buybusticket 1", "user-buses_3-inform_intent-findbus-1": "user act buses_3 inform_intent findbus 1", "user-buses_3-negate_intent-none-none": "user act buses_3 negate_intent none none", "user-buses_3-request-category-?": "user act buses_3 request category ?", "user-buses_3-request-from_station-?": "user act buses_3 request from_station ?", "user-buses_3-request-to_station-?": "user act buses_3 request to_station ?", "user-buses_3-request_alts-none-none": "user act buses_3 request_alts none none", "user-buses_3-select-none-none": "user act buses_3 select none none", "user-calendar_1-inform-event_date-1": "user act calendar_1 inform event_date 1", "user-calendar_1-inform-event_location-1": "user act calendar_1 inform event_location 1", "user-calendar_1-inform-event_name-1": "user act calendar_1 inform event_name 1", "user-calendar_1-inform-event_time-1": "user act calendar_1 inform event_time 1", 
"user-calendar_1-inform_intent-addevent-1": "user act calendar_1 inform_intent addevent 1", "user-calendar_1-inform_intent-getavailabletime-1": "user act calendar_1 inform_intent getavailabletime 1", "user-calendar_1-inform_intent-getevents-1": "user act calendar_1 inform_intent getevents 1", "user-calendar_1-negate_intent-none-none": "user act calendar_1 negate_intent none none", "user-calendar_1-request_alts-none-none": "user act calendar_1 request_alts none none", "user-calendar_1-select-none-none": "user act calendar_1 select none none", "user-events_1-affirm_intent-none-none": "user act events_1 affirm_intent none none", "user-events_1-inform-category-1": "user act events_1 inform category 1", "user-events_1-inform-city_of_event-1": "user act events_1 inform city_of_event 1", "user-events_1-inform-date-1": "user act events_1 inform date 1", "user-events_1-inform-event_name-1": "user act events_1 inform event_name 1", "user-events_1-inform-number_of_seats-1": "user act events_1 inform number_of_seats 1", "user-events_1-inform-subcategory-1": "user act events_1 inform subcategory 1", "user-events_1-inform_intent-buyeventtickets-1": "user act events_1 inform_intent buyeventtickets 1", "user-events_1-inform_intent-findevents-1": "user act events_1 inform_intent findevents 1", "user-events_1-negate_intent-none-none": "user act events_1 negate_intent none none", "user-events_1-request-address_of_location-?": "user act events_1 request address_of_location ?", "user-events_1-request-event_location-?": "user act events_1 request event_location ?", "user-events_1-request-subcategory-?": "user act events_1 request subcategory ?", "user-events_1-request-time-?": "user act events_1 request time ?", "user-events_1-request_alts-none-none": "user act events_1 request_alts none none", "user-events_1-select-none-none": "user act events_1 select none none", "user-events_2-affirm_intent-none-none": "user act events_2 affirm_intent none none", "user-events_2-inform-category-1": 
"user act events_2 inform category 1", "user-events_2-inform-city-1": "user act events_2 inform city 1", "user-events_2-inform-date-1": "user act events_2 inform date 1", "user-events_2-inform-event_name-1": "user act events_2 inform event_name 1", "user-events_2-inform-event_type-1": "user act events_2 inform event_type 1", "user-events_2-inform-number_of_tickets-1": "user act events_2 inform number_of_tickets 1", "user-events_2-inform_intent-buyeventtickets-1": "user act events_2 inform_intent buyeventtickets 1", "user-events_2-inform_intent-findevents-1": "user act events_2 inform_intent findevents 1", "user-events_2-inform_intent-geteventdates-1": "user act events_2 inform_intent geteventdates 1", "user-events_2-negate_intent-none-none": "user act events_2 negate_intent none none", "user-events_2-request-category-?": "user act events_2 request category ?", "user-events_2-request-time-?": "user act events_2 request time ?", "user-events_2-request-venue-?": "user act events_2 request venue ?", "user-events_2-request-venue_address-?": "user act events_2 request venue_address ?", "user-events_2-request_alts-none-none": "user act events_2 request_alts none none", "user-events_2-select-none-none": "user act events_2 select none none", "user-events_3-inform-city-1": "user act events_3 inform city 1", "user-events_3-inform-date-1": "user act events_3 inform date 1", "user-events_3-inform-event_name-1": "user act events_3 inform event_name 1", "user-events_3-inform-event_type-1": "user act events_3 inform event_type 1", "user-events_3-inform-number_of_tickets-1": "user act events_3 inform number_of_tickets 1", "user-events_3-inform_intent-buyeventtickets-1": "user act events_3 inform_intent buyeventtickets 1", "user-events_3-inform_intent-findevents-1": "user act events_3 inform_intent findevents 1", "user-events_3-negate_intent-none-none": "user act events_3 negate_intent none none", "user-events_3-request-price_per_ticket-?": "user act events_3 request 
price_per_ticket ?", "user-events_3-request-venue_address-?": "user act events_3 request venue_address ?", "user-events_3-request_alts-none-none": "user act events_3 request_alts none none", "user-events_3-select-none-none": "user act events_3 select none none", "user-flights_1-affirm_intent-none-none": "user act flights_1 affirm_intent none none", "user-flights_1-inform-airlines-1": "user act flights_1 inform airlines 1", "user-flights_1-inform-departure_date-1": "user act flights_1 inform departure_date 1", "user-flights_1-inform-destination_city-1": "user act flights_1 inform destination_city 1", "user-flights_1-inform-inbound_departure_time-1": "user act flights_1 inform inbound_departure_time 1", "user-flights_1-inform-origin_city-1": "user act flights_1 inform origin_city 1", "user-flights_1-inform-outbound_departure_time-1": "user act flights_1 inform outbound_departure_time 1", "user-flights_1-inform-passengers-1": "user act flights_1 inform passengers 1", "user-flights_1-inform-refundable-1": "user act flights_1 inform refundable 1", "user-flights_1-inform-return_date-1": "user act flights_1 inform return_date 1", "user-flights_1-inform-seating_class-1": "user act flights_1 inform seating_class 1", "user-flights_1-inform_intent-reserveonewayflight-1": "user act flights_1 inform_intent reserveonewayflight 1", "user-flights_1-inform_intent-reserveroundtripflights-1": "user act flights_1 inform_intent reserveroundtripflights 1", "user-flights_1-inform_intent-searchonewayflight-1": "user act flights_1 inform_intent searchonewayflight 1", "user-flights_1-inform_intent-searchroundtripflights-1": "user act flights_1 inform_intent searchroundtripflights 1", "user-flights_1-negate_intent-none-none": "user act flights_1 negate_intent none none", "user-flights_1-request-destination_airport-?": "user act flights_1 request destination_airport ?", "user-flights_1-request-inbound_arrival_time-?": "user act flights_1 request inbound_arrival_time ?", 
"user-flights_1-request-number_stops-?": "user act flights_1 request number_stops ?", "user-flights_1-request-origin_airport-?": "user act flights_1 request origin_airport ?", "user-flights_1-request-outbound_arrival_time-?": "user act flights_1 request outbound_arrival_time ?", "user-flights_1-request-refundable-?": "user act flights_1 request refundable ?", "user-flights_1-request_alts-none-none": "user act flights_1 request_alts none none", "user-flights_1-select-none-none": "user act flights_1 select none none", "user-flights_2-inform-airlines-1": "user act flights_2 inform airlines 1", "user-flights_2-inform-departure_date-1": "user act flights_2 inform departure_date 1", "user-flights_2-inform-destination-1": "user act flights_2 inform destination 1", "user-flights_2-inform-origin-1": "user act flights_2 inform origin 1", "user-flights_2-inform-passengers-1": "user act flights_2 inform passengers 1", "user-flights_2-inform-return_date-1": "user act flights_2 inform return_date 1", "user-flights_2-inform-seating_class-1": "user act flights_2 inform seating_class 1", "user-flights_2-inform_intent-searchonewayflight-1": "user act flights_2 inform_intent searchonewayflight 1", "user-flights_2-inform_intent-searchroundtripflights-1": "user act flights_2 inform_intent searchroundtripflights 1", "user-flights_2-request-destination_airport-?": "user act flights_2 request destination_airport ?", "user-flights_2-request-is_redeye-?": "user act flights_2 request is_redeye ?", "user-flights_2-request-origin_airport-?": "user act flights_2 request origin_airport ?", "user-flights_2-request-outbound_arrival_time-?": "user act flights_2 request outbound_arrival_time ?", "user-flights_2-request_alts-none-none": "user act flights_2 request_alts none none", "user-flights_2-select-none-none": "user act flights_2 select none none", "user-flights_3-inform-airlines-1": "user act flights_3 inform airlines 1", "user-flights_3-inform-departure_date-1": "user act flights_3 inform 
departure_date 1", "user-flights_3-inform-destination_city-1": "user act flights_3 inform destination_city 1", "user-flights_3-inform-flight_class-1": "user act flights_3 inform flight_class 1", "user-flights_3-inform-number_checked_bags-1": "user act flights_3 inform number_checked_bags 1", "user-flights_3-inform-origin_city-1": "user act flights_3 inform origin_city 1", "user-flights_3-inform-passengers-1": "user act flights_3 inform passengers 1", "user-flights_3-inform-return_date-1": "user act flights_3 inform return_date 1", "user-flights_3-inform_intent-searchonewayflight-1": "user act flights_3 inform_intent searchonewayflight 1", "user-flights_3-inform_intent-searchroundtripflights-1": "user act flights_3 inform_intent searchroundtripflights 1", "user-flights_3-request-arrives_next_day-?": "user act flights_3 request arrives_next_day ?", "user-flights_3-request-destination_airport_name-?": "user act flights_3 request destination_airport_name ?", "user-flights_3-request-origin_airport_name-?": "user act flights_3 request origin_airport_name ?", "user-flights_3-request-outbound_arrival_time-?": "user act flights_3 request outbound_arrival_time ?", "user-flights_3-request_alts-none-none": "user act flights_3 request_alts none none", "user-flights_3-select-none-none": "user act flights_3 select none none", "user-flights_4-inform-airlines-1": "user act flights_4 inform airlines 1", "user-flights_4-inform-departure_date-1": "user act flights_4 inform departure_date 1", "user-flights_4-inform-destination_airport-1": "user act flights_4 inform destination_airport 1", "user-flights_4-inform-number_of_tickets-1": "user act flights_4 inform number_of_tickets 1", "user-flights_4-inform-origin_airport-1": "user act flights_4 inform origin_airport 1", "user-flights_4-inform-return_date-1": "user act flights_4 inform return_date 1", "user-flights_4-inform-seating_class-1": "user act flights_4 inform seating_class 1", "user-flights_4-inform_intent-searchonewayflight-1": 
"user act flights_4 inform_intent searchonewayflight 1", "user-flights_4-inform_intent-searchroundtripflights-1": "user act flights_4 inform_intent searchroundtripflights 1", "user-flights_4-request-inbound_arrival_time-?": "user act flights_4 request inbound_arrival_time ?", "user-flights_4-request-number_of_tickets-?": "user act flights_4 request number_of_tickets ?", "user-flights_4-request-outbound_arrival_time-?": "user act flights_4 request outbound_arrival_time ?", "user-flights_4-request-seating_class-?": "user act flights_4 request seating_class ?", "user-flights_4-request_alts-none-none": "user act flights_4 request_alts none none", "user-flights_4-select-none-none": "user act flights_4 select none none", "user-homes_1-affirm_intent-none-none": "user act homes_1 affirm_intent none none", "user-homes_1-inform-area-1": "user act homes_1 inform area 1", "user-homes_1-inform-furnished-1": "user act homes_1 inform furnished 1", "user-homes_1-inform-number_of_baths-1": "user act homes_1 inform number_of_baths 1", "user-homes_1-inform-number_of_beds-1": "user act homes_1 inform number_of_beds 1", "user-homes_1-inform-pets_allowed-1": "user act homes_1 inform pets_allowed 1", "user-homes_1-inform-visit_date-1": "user act homes_1 inform visit_date 1", "user-homes_1-inform_intent-findapartment-1": "user act homes_1 inform_intent findapartment 1", "user-homes_1-inform_intent-schedulevisit-1": "user act homes_1 inform_intent schedulevisit 1", "user-homes_1-negate_intent-none-none": "user act homes_1 negate_intent none none", "user-homes_1-request-furnished-?": "user act homes_1 request furnished ?", "user-homes_1-request-pets_allowed-?": "user act homes_1 request pets_allowed ?", "user-homes_1-request-phone_number-?": "user act homes_1 request phone_number ?", "user-homes_1-request_alts-none-none": "user act homes_1 request_alts none none", "user-homes_1-select-none-none": "user act homes_1 select none none", "user-homes_2-affirm_intent-none-none": "user act homes_2 
affirm_intent none none", "user-homes_2-inform-area-1": "user act homes_2 inform area 1", "user-homes_2-inform-intent-1": "user act homes_2 inform intent 1", "user-homes_2-inform-number_of_baths-1": "user act homes_2 inform number_of_baths 1", "user-homes_2-inform-number_of_beds-1": "user act homes_2 inform number_of_beds 1", "user-homes_2-inform-property_name-1": "user act homes_2 inform property_name 1", "user-homes_2-inform-visit_date-1": "user act homes_2 inform visit_date 1", "user-homes_2-inform_intent-findhomebyarea-1": "user act homes_2 inform_intent findhomebyarea 1", "user-homes_2-inform_intent-schedulevisit-1": "user act homes_2 inform_intent schedulevisit 1", "user-homes_2-request-has_garage-?": "user act homes_2 request has_garage ?", "user-homes_2-request-in_unit_laundry-?": "user act homes_2 request in_unit_laundry ?", "user-homes_2-request-phone_number-?": "user act homes_2 request phone_number ?", "user-homes_2-request_alts-none-none": "user act homes_2 request_alts none none", "user-homes_2-select-none-none": "user act homes_2 select none none", "user-hotels_1-affirm_intent-none-none": "user act hotels_1 affirm_intent none none", "user-hotels_1-inform-check_in_date-1": "user act hotels_1 inform check_in_date 1", "user-hotels_1-inform-destination-1": "user act hotels_1 inform destination 1", "user-hotels_1-inform-has_wifi-1": "user act hotels_1 inform has_wifi 1", "user-hotels_1-inform-hotel_name-1": "user act hotels_1 inform hotel_name 1", "user-hotels_1-inform-number_of_days-1": "user act hotels_1 inform number_of_days 1", "user-hotels_1-inform-number_of_rooms-1": "user act hotels_1 inform number_of_rooms 1", "user-hotels_1-inform-star_rating-1": "user act hotels_1 inform star_rating 1", "user-hotels_1-inform_intent-reservehotel-1": "user act hotels_1 inform_intent reservehotel 1", "user-hotels_1-inform_intent-searchhotel-1": "user act hotels_1 inform_intent searchhotel 1", "user-hotels_1-negate_intent-none-none": "user act hotels_1 negate_intent 
none none", "user-hotels_1-request-has_wifi-?": "user act hotels_1 request has_wifi ?", "user-hotels_1-request-phone_number-?": "user act hotels_1 request phone_number ?", "user-hotels_1-request-price_per_night-?": "user act hotels_1 request price_per_night ?", "user-hotels_1-request-street_address-?": "user act hotels_1 request street_address ?", "user-hotels_1-request_alts-none-none": "user act hotels_1 request_alts none none", "user-hotels_1-select-none-none": "user act hotels_1 select none none", "user-hotels_2-affirm_intent-none-none": "user act hotels_2 affirm_intent none none", "user-hotels_2-inform-check_in_date-1": "user act hotels_2 inform check_in_date 1", "user-hotels_2-inform-check_out_date-1": "user act hotels_2 inform check_out_date 1", "user-hotels_2-inform-has_laundry_service-1": "user act hotels_2 inform has_laundry_service 1", "user-hotels_2-inform-number_of_adults-1": "user act hotels_2 inform number_of_adults 1", "user-hotels_2-inform-rating-1": "user act hotels_2 inform rating 1", "user-hotels_2-inform-where_to-1": "user act hotels_2 inform where_to 1", "user-hotels_2-inform_intent-bookhouse-1": "user act hotels_2 inform_intent bookhouse 1", "user-hotels_2-inform_intent-searchhouse-1": "user act hotels_2 inform_intent searchhouse 1", "user-hotels_2-negate_intent-none-none": "user act hotels_2 negate_intent none none", "user-hotels_2-request-has_laundry_service-?": "user act hotels_2 request has_laundry_service ?", "user-hotels_2-request-phone_number-?": "user act hotels_2 request phone_number ?", "user-hotels_2-request-total_price-?": "user act hotels_2 request total_price ?", "user-hotels_2-request_alts-none-none": "user act hotels_2 request_alts none none", "user-hotels_2-select-none-none": "user act hotels_2 select none none", "user-hotels_3-affirm_intent-none-none": "user act hotels_3 affirm_intent none none", "user-hotels_3-inform-check_in_date-1": "user act hotels_3 inform check_in_date 1", "user-hotels_3-inform-check_out_date-1": "user 
act hotels_3 inform check_out_date 1", "user-hotels_3-inform-hotel_name-1": "user act hotels_3 inform hotel_name 1", "user-hotels_3-inform-location-1": "user act hotels_3 inform location 1", "user-hotels_3-inform-number_of_rooms-1": "user act hotels_3 inform number_of_rooms 1", "user-hotels_3-inform-pets_welcome-1": "user act hotels_3 inform pets_welcome 1", "user-hotels_3-inform_intent-reservehotel-1": "user act hotels_3 inform_intent reservehotel 1", "user-hotels_3-inform_intent-searchhotel-1": "user act hotels_3 inform_intent searchhotel 1", "user-hotels_3-negate_intent-none-none": "user act hotels_3 negate_intent none none", "user-hotels_3-request-pets_welcome-?": "user act hotels_3 request pets_welcome ?", "user-hotels_3-request-phone_number-?": "user act hotels_3 request phone_number ?", "user-hotels_3-request-price-?": "user act hotels_3 request price ?", "user-hotels_3-request-street_address-?": "user act hotels_3 request street_address ?", "user-hotels_3-request_alts-none-none": "user act hotels_3 request_alts none none", "user-hotels_3-select-none-none": "user act hotels_3 select none none", "user-hotels_4-affirm_intent-none-none": "user act hotels_4 affirm_intent none none", "user-hotels_4-inform-check_in_date-1": "user act hotels_4 inform check_in_date 1", "user-hotels_4-inform-location-1": "user act hotels_4 inform location 1", "user-hotels_4-inform-number_of_rooms-1": "user act hotels_4 inform number_of_rooms 1", "user-hotels_4-inform-smoking_allowed-1": "user act hotels_4 inform smoking_allowed 1", "user-hotels_4-inform-star_rating-1": "user act hotels_4 inform star_rating 1", "user-hotels_4-inform-stay_length-1": "user act hotels_4 inform stay_length 1", "user-hotels_4-inform_intent-reservehotel-1": "user act hotels_4 inform_intent reservehotel 1", "user-hotels_4-inform_intent-searchhotel-1": "user act hotels_4 inform_intent searchhotel 1", "user-hotels_4-negate_intent-none-none": "user act hotels_4 negate_intent none none", 
"user-hotels_4-request-phone_number-?": "user act hotels_4 request phone_number ?", "user-hotels_4-request-price_per_night-?": "user act hotels_4 request price_per_night ?", "user-hotels_4-request-smoking_allowed-?": "user act hotels_4 request smoking_allowed ?", "user-hotels_4-request-street_address-?": "user act hotels_4 request street_address ?", "user-hotels_4-request_alts-none-none": "user act hotels_4 request_alts none none", "user-hotels_4-select-none-none": "user act hotels_4 select none none", "user-media_1-affirm_intent-none-none": "user act media_1 affirm_intent none none", "user-media_1-inform-directed_by-1": "user act media_1 inform directed_by 1", "user-media_1-inform-genre-1": "user act media_1 inform genre 1", "user-media_1-inform-subtitles-1": "user act media_1 inform subtitles 1", "user-media_1-inform-title-1": "user act media_1 inform title 1", "user-media_1-inform_intent-findmovies-1": "user act media_1 inform_intent findmovies 1", "user-media_1-inform_intent-playmovie-1": "user act media_1 inform_intent playmovie 1", "user-media_1-negate_intent-none-none": "user act media_1 negate_intent none none", "user-media_1-request-directed_by-?": "user act media_1 request directed_by ?", "user-media_1-request-genre-?": "user act media_1 request genre ?", "user-media_1-request_alts-none-none": "user act media_1 request_alts none none", "user-media_1-select-title-1": "user act media_1 select title 1", "user-media_2-affirm_intent-none-none": "user act media_2 affirm_intent none none", "user-media_2-inform-actors-1": "user act media_2 inform actors 1", "user-media_2-inform-director-1": "user act media_2 inform director 1", "user-media_2-inform-genre-1": "user act media_2 inform genre 1", "user-media_2-inform-subtitle_language-1": "user act media_2 inform subtitle_language 1", "user-media_2-inform_intent-findmovies-1": "user act media_2 inform_intent findmovies 1", "user-media_2-inform_intent-rentmovie-1": "user act media_2 inform_intent rentmovie 1", 
"user-media_2-request-price-?": "user act media_2 request price ?", "user-media_2-select-movie_name-1": "user act media_2 select movie_name 1", "user-media_3-affirm_intent-none-none": "user act media_3 affirm_intent none none", "user-media_3-inform-genre-1": "user act media_3 inform genre 1", "user-media_3-inform-starring-1": "user act media_3 inform starring 1", "user-media_3-inform-subtitle_language-1": "user act media_3 inform subtitle_language 1", "user-media_3-inform-title-1": "user act media_3 inform title 1", "user-media_3-inform_intent-findmovies-1": "user act media_3 inform_intent findmovies 1", "user-media_3-inform_intent-playmovie-1": "user act media_3 inform_intent playmovie 1", "user-media_3-negate_intent-none-none": "user act media_3 negate_intent none none", "user-media_3-request-starring-?": "user act media_3 request starring ?", "user-media_3-request_alts-none-none": "user act media_3 request_alts none none", "user-media_3-select-title-1": "user act media_3 select title 1", "user-messaging_1-inform-contact_name-1": "user act messaging_1 inform contact_name 1", "user-messaging_1-inform-location-1": "user act messaging_1 inform location 1", "user-messaging_1-inform_intent-sharelocation-1": "user act messaging_1 inform_intent sharelocation 1", "user-movies_1-inform-genre-1": "user act movies_1 inform genre 1", "user-movies_1-inform-location-1": "user act movies_1 inform location 1", "user-movies_1-inform-movie_name-1": "user act movies_1 inform movie_name 1", "user-movies_1-inform-number_of_tickets-1": "user act movies_1 inform number_of_tickets 1", "user-movies_1-inform-show_date-1": "user act movies_1 inform show_date 1", "user-movies_1-inform-show_time-1": "user act movies_1 inform show_time 1", "user-movies_1-inform-show_type-1": "user act movies_1 inform show_type 1", "user-movies_1-inform-theater_name-1": "user act movies_1 inform theater_name 1", "user-movies_1-inform_intent-buymovietickets-1": "user act movies_1 inform_intent buymovietickets 
1", "user-movies_1-inform_intent-findmovies-1": "user act movies_1 inform_intent findmovies 1", "user-movies_1-inform_intent-gettimesformovie-1": "user act movies_1 inform_intent gettimesformovie 1", "user-movies_1-negate_intent-none-none": "user act movies_1 negate_intent none none", "user-movies_1-request-genre-?": "user act movies_1 request genre ?", "user-movies_1-request-price-?": "user act movies_1 request price ?", "user-movies_1-request-street_address-?": "user act movies_1 request street_address ?", "user-movies_1-request_alts-none-none": "user act movies_1 request_alts none none", "user-movies_1-select-movie_name-1": "user act movies_1 select movie_name 1", "user-movies_1-select-none-none": "user act movies_1 select none none", "user-movies_2-inform-director-1": "user act movies_2 inform director 1", "user-movies_2-inform-genre-1": "user act movies_2 inform genre 1", "user-movies_2-inform-starring-1": "user act movies_2 inform starring 1", "user-movies_2-inform_intent-findmovies-1": "user act movies_2 inform_intent findmovies 1", "user-movies_2-request_alts-none-none": "user act movies_2 request_alts none none", "user-movies_2-select-none-none": "user act movies_2 select none none", "user-movies_3-inform-cast-1": "user act movies_3 inform cast 1", "user-movies_3-inform-directed_by-1": "user act movies_3 inform directed_by 1", "user-movies_3-inform-genre-1": "user act movies_3 inform genre 1", "user-movies_3-inform_intent-findmovies-1": "user act movies_3 inform_intent findmovies 1", "user-movies_3-request-cast-?": "user act movies_3 request cast ?", "user-movies_3-request-directed_by-?": "user act movies_3 request directed_by ?", "user-movies_3-request-genre-?": "user act movies_3 request genre ?", "user-movies_3-select-none-none": "user act movies_3 select none none", "user-music_1-affirm_intent-none-none": "user act music_1 affirm_intent none none", "user-music_1-inform-album-1": "user act music_1 inform album 1", "user-music_1-inform-artist-1": "user 
act music_1 inform artist 1", "user-music_1-inform-genre-1": "user act music_1 inform genre 1", "user-music_1-inform-playback_device-1": "user act music_1 inform playback_device 1", "user-music_1-inform-song_name-1": "user act music_1 inform song_name 1", "user-music_1-inform-year-1": "user act music_1 inform year 1", "user-music_1-inform_intent-lookupsong-1": "user act music_1 inform_intent lookupsong 1", "user-music_1-inform_intent-playsong-1": "user act music_1 inform_intent playsong 1", "user-music_1-request-album-?": "user act music_1 request album ?", "user-music_1-request-genre-?": "user act music_1 request genre ?", "user-music_1-request-year-?": "user act music_1 request year ?", "user-music_1-request_alts-none-none": "user act music_1 request_alts none none", "user-music_1-select-none-none": "user act music_1 select none none", "user-music_2-affirm_intent-none-none": "user act music_2 affirm_intent none none", "user-music_2-inform-album-1": "user act music_2 inform album 1", "user-music_2-inform-artist-1": "user act music_2 inform artist 1", "user-music_2-inform-genre-1": "user act music_2 inform genre 1", "user-music_2-inform-playback_device-1": "user act music_2 inform playback_device 1", "user-music_2-inform-song_name-1": "user act music_2 inform song_name 1", "user-music_2-inform_intent-lookupmusic-1": "user act music_2 inform_intent lookupmusic 1", "user-music_2-inform_intent-playmedia-1": "user act music_2 inform_intent playmedia 1", "user-music_2-request-genre-?": "user act music_2 request genre ?", "user-music_2-request_alts-none-none": "user act music_2 request_alts none none", "user-music_2-select-none-none": "user act music_2 select none none", "user-music_3-affirm_intent-none-none": "user act music_3 affirm_intent none none", "user-music_3-inform-album-1": "user act music_3 inform album 1", "user-music_3-inform-artist-1": "user act music_3 inform artist 1", "user-music_3-inform-device-1": "user act music_3 inform device 1", 
"user-music_3-inform-genre-1": "user act music_3 inform genre 1", "user-music_3-inform-year-1": "user act music_3 inform year 1", "user-music_3-inform_intent-lookupmusic-1": "user act music_3 inform_intent lookupmusic 1", "user-music_3-inform_intent-playmedia-1": "user act music_3 inform_intent playmedia 1", "user-music_3-negate_intent-none-none": "user act music_3 negate_intent none none", "user-music_3-request-genre-?": "user act music_3 request genre ?", "user-music_3-request-year-?": "user act music_3 request year ?", "user-music_3-request_alts-none-none": "user act music_3 request_alts none none", "user-music_3-select-none-none": "user act music_3 select none none", "user-payment_1-inform-amount-1": "user act payment_1 inform amount 1", "user-payment_1-inform-payment_method-1": "user act payment_1 inform payment_method 1", "user-payment_1-inform-private_visibility-1": "user act payment_1 inform private_visibility 1", "user-payment_1-inform-receiver-1": "user act payment_1 inform receiver 1", "user-payment_1-inform_intent-makepayment-1": "user act payment_1 inform_intent makepayment 1", "user-payment_1-inform_intent-requestpayment-1": "user act payment_1 inform_intent requestpayment 1", "user-rentalcars_1-affirm_intent-none-none": "user act rentalcars_1 affirm_intent none none", "user-rentalcars_1-inform-dropoff_date-1": "user act rentalcars_1 inform dropoff_date 1", "user-rentalcars_1-inform-pickup_city-1": "user act rentalcars_1 inform pickup_city 1", "user-rentalcars_1-inform-pickup_date-1": "user act rentalcars_1 inform pickup_date 1", "user-rentalcars_1-inform-pickup_location-1": "user act rentalcars_1 inform pickup_location 1", "user-rentalcars_1-inform-pickup_time-1": "user act rentalcars_1 inform pickup_time 1", "user-rentalcars_1-inform-type-1": "user act rentalcars_1 inform type 1", "user-rentalcars_1-inform_intent-getcarsavailable-1": "user act rentalcars_1 inform_intent getcarsavailable 1", "user-rentalcars_1-inform_intent-reservecar-1": "user act 
rentalcars_1 inform_intent reservecar 1", "user-rentalcars_1-negate_intent-none-none": "user act rentalcars_1 negate_intent none none", "user-rentalcars_1-request-car_name-?": "user act rentalcars_1 request car_name ?", "user-rentalcars_1-request-total_price-?": "user act rentalcars_1 request total_price ?", "user-rentalcars_1-request_alts-none-none": "user act rentalcars_1 request_alts none none", "user-rentalcars_1-select-none-none": "user act rentalcars_1 select none none", "user-rentalcars_2-affirm_intent-none-none": "user act rentalcars_2 affirm_intent none none", "user-rentalcars_2-inform-car_type-1": "user act rentalcars_2 inform car_type 1", "user-rentalcars_2-inform-dropoff_date-1": "user act rentalcars_2 inform dropoff_date 1", "user-rentalcars_2-inform-pickup_city-1": "user act rentalcars_2 inform pickup_city 1", "user-rentalcars_2-inform-pickup_date-1": "user act rentalcars_2 inform pickup_date 1", "user-rentalcars_2-inform-pickup_location-1": "user act rentalcars_2 inform pickup_location 1", "user-rentalcars_2-inform-pickup_time-1": "user act rentalcars_2 inform pickup_time 1", "user-rentalcars_2-inform_intent-getcarsavailable-1": "user act rentalcars_2 inform_intent getcarsavailable 1", "user-rentalcars_2-inform_intent-reservecar-1": "user act rentalcars_2 inform_intent reservecar 1", "user-rentalcars_2-negate_intent-none-none": "user act rentalcars_2 negate_intent none none", "user-rentalcars_2-request-car_name-?": "user act rentalcars_2 request car_name ?", "user-rentalcars_2-request-total_price-?": "user act rentalcars_2 request total_price ?", "user-rentalcars_2-request_alts-none-none": "user act rentalcars_2 request_alts none none", "user-rentalcars_2-select-none-none": "user act rentalcars_2 select none none", "user-rentalcars_3-affirm_intent-none-none": "user act rentalcars_3 affirm_intent none none", "user-rentalcars_3-inform-add_insurance-1": "user act rentalcars_3 inform add_insurance 1", "user-rentalcars_3-inform-car_type-1": "user act 
rentalcars_3 inform car_type 1", "user-rentalcars_3-inform-city-1": "user act rentalcars_3 inform city 1", "user-rentalcars_3-inform-end_date-1": "user act rentalcars_3 inform end_date 1", "user-rentalcars_3-inform-pickup_location-1": "user act rentalcars_3 inform pickup_location 1", "user-rentalcars_3-inform-pickup_time-1": "user act rentalcars_3 inform pickup_time 1", "user-rentalcars_3-inform-start_date-1": "user act rentalcars_3 inform start_date 1", "user-rentalcars_3-inform_intent-getcarsavailable-1": "user act rentalcars_3 inform_intent getcarsavailable 1", "user-rentalcars_3-inform_intent-reservecar-1": "user act rentalcars_3 inform_intent reservecar 1", "user-rentalcars_3-negate_intent-none-none": "user act rentalcars_3 negate_intent none none", "user-rentalcars_3-request-car_name-?": "user act rentalcars_3 request car_name ?", "user-rentalcars_3-request-price_per_day-?": "user act rentalcars_3 request price_per_day ?", "user-rentalcars_3-request_alts-none-none": "user act rentalcars_3 request_alts none none", "user-rentalcars_3-select-none-none": "user act rentalcars_3 select none none", "user-restaurants_1-affirm_intent-none-none": "user act restaurants_1 affirm_intent none none", "user-restaurants_1-inform-city-1": "user act restaurants_1 inform city 1", "user-restaurants_1-inform-cuisine-1": "user act restaurants_1 inform cuisine 1", "user-restaurants_1-inform-date-1": "user act restaurants_1 inform date 1", "user-restaurants_1-inform-has_live_music-1": "user act restaurants_1 inform has_live_music 1", "user-restaurants_1-inform-party_size-1": "user act restaurants_1 inform party_size 1", "user-restaurants_1-inform-price_range-1": "user act restaurants_1 inform price_range 1", "user-restaurants_1-inform-restaurant_name-1": "user act restaurants_1 inform restaurant_name 1", "user-restaurants_1-inform-serves_alcohol-1": "user act restaurants_1 inform serves_alcohol 1", "user-restaurants_1-inform-time-1": "user act restaurants_1 inform time 1", 
"user-restaurants_1-inform_intent-findrestaurants-1": "user act restaurants_1 inform_intent findrestaurants 1", "user-restaurants_1-inform_intent-reserverestaurant-1": "user act restaurants_1 inform_intent reserverestaurant 1", "user-restaurants_1-negate_intent-none-none": "user act restaurants_1 negate_intent none none", "user-restaurants_1-request-cuisine-?": "user act restaurants_1 request cuisine ?", "user-restaurants_1-request-has_live_music-?": "user act restaurants_1 request has_live_music ?", "user-restaurants_1-request-phone_number-?": "user act restaurants_1 request phone_number ?", "user-restaurants_1-request-price_range-?": "user act restaurants_1 request price_range ?", "user-restaurants_1-request-serves_alcohol-?": "user act restaurants_1 request serves_alcohol ?", "user-restaurants_1-request-street_address-?": "user act restaurants_1 request street_address ?", "user-restaurants_1-request_alts-none-none": "user act restaurants_1 request_alts none none", "user-restaurants_1-select-none-none": "user act restaurants_1 select none none", "user-restaurants_2-affirm_intent-none-none": "user act restaurants_2 affirm_intent none none", "user-restaurants_2-inform-category-1": "user act restaurants_2 inform category 1", "user-restaurants_2-inform-date-1": "user act restaurants_2 inform date 1", "user-restaurants_2-inform-has_vegetarian_options-1": "user act restaurants_2 inform has_vegetarian_options 1", "user-restaurants_2-inform-location-1": "user act restaurants_2 inform location 1", "user-restaurants_2-inform-number_of_seats-1": "user act restaurants_2 inform number_of_seats 1", "user-restaurants_2-inform-price_range-1": "user act restaurants_2 inform price_range 1", "user-restaurants_2-inform-restaurant_name-1": "user act restaurants_2 inform restaurant_name 1", "user-restaurants_2-inform-time-1": "user act restaurants_2 inform time 1", "user-restaurants_2-inform_intent-findrestaurants-1": "user act restaurants_2 inform_intent findrestaurants 1", 
"user-restaurants_2-inform_intent-reserverestaurant-1": "user act restaurants_2 inform_intent reserverestaurant 1", "user-restaurants_2-request-address-?": "user act restaurants_2 request address ?", "user-restaurants_2-request-has_seating_outdoors-?": "user act restaurants_2 request has_seating_outdoors ?", "user-restaurants_2-request-has_vegetarian_options-?": "user act restaurants_2 request has_vegetarian_options ?", "user-restaurants_2-request-phone_number-?": "user act restaurants_2 request phone_number ?", "user-restaurants_2-request-price_range-?": "user act restaurants_2 request price_range ?", "user-restaurants_2-request-rating-?": "user act restaurants_2 request rating ?", "user-restaurants_2-request_alts-none-none": "user act restaurants_2 request_alts none none", "user-restaurants_2-select-none-none": "user act restaurants_2 select none none", "user-ridesharing_1-inform-destination-1": "user act ridesharing_1 inform destination 1", "user-ridesharing_1-inform-number_of_riders-1": "user act ridesharing_1 inform number_of_riders 1", "user-ridesharing_1-inform-shared_ride-1": "user act ridesharing_1 inform shared_ride 1", "user-ridesharing_1-inform_intent-getride-1": "user act ridesharing_1 inform_intent getride 1", "user-ridesharing_1-request-approximate_ride_duration-?": "user act ridesharing_1 request approximate_ride_duration ?", "user-ridesharing_1-request-ride_fare-?": "user act ridesharing_1 request ride_fare ?", "user-ridesharing_2-inform-destination-1": "user act ridesharing_2 inform destination 1", "user-ridesharing_2-inform-number_of_seats-1": "user act ridesharing_2 inform number_of_seats 1", "user-ridesharing_2-inform-ride_type-1": "user act ridesharing_2 inform ride_type 1", "user-ridesharing_2-inform_intent-getride-1": "user act ridesharing_2 inform_intent getride 1", "user-ridesharing_2-request-ride_fare-?": "user act ridesharing_2 request ride_fare ?", "user-ridesharing_2-request-wait_time-?": "user act ridesharing_2 request wait_time ?", 
"user-services_1-affirm_intent-none-none": "user act services_1 affirm_intent none none", "user-services_1-inform-appointment_date-1": "user act services_1 inform appointment_date 1", "user-services_1-inform-appointment_time-1": "user act services_1 inform appointment_time 1", "user-services_1-inform-city-1": "user act services_1 inform city 1", "user-services_1-inform-is_unisex-1": "user act services_1 inform is_unisex 1", "user-services_1-inform-stylist_name-1": "user act services_1 inform stylist_name 1", "user-services_1-inform_intent-bookappointment-1": "user act services_1 inform_intent bookappointment 1", "user-services_1-inform_intent-findprovider-1": "user act services_1 inform_intent findprovider 1", "user-services_1-negate_intent-none-none": "user act services_1 negate_intent none none", "user-services_1-request-average_rating-?": "user act services_1 request average_rating ?", "user-services_1-request-is_unisex-?": "user act services_1 request is_unisex ?", "user-services_1-request-phone_number-?": "user act services_1 request phone_number ?", "user-services_1-request-street_address-?": "user act services_1 request street_address ?", "user-services_1-request_alts-none-none": "user act services_1 request_alts none none", "user-services_1-select-none-none": "user act services_1 select none none", "user-services_2-affirm_intent-none-none": "user act services_2 affirm_intent none none", "user-services_2-inform-appointment_date-1": "user act services_2 inform appointment_date 1", "user-services_2-inform-appointment_time-1": "user act services_2 inform appointment_time 1", "user-services_2-inform-city-1": "user act services_2 inform city 1", "user-services_2-inform-dentist_name-1": "user act services_2 inform dentist_name 1", "user-services_2-inform_intent-bookappointment-1": "user act services_2 inform_intent bookappointment 1", "user-services_2-inform_intent-findprovider-1": "user act services_2 inform_intent findprovider 1", 
"user-services_2-negate_intent-none-none": "user act services_2 negate_intent none none", "user-services_2-request-address-?": "user act services_2 request address ?", "user-services_2-request-offers_cosmetic_services-?": "user act services_2 request offers_cosmetic_services ?", "user-services_2-request-phone_number-?": "user act services_2 request phone_number ?", "user-services_2-request_alts-none-none": "user act services_2 request_alts none none", "user-services_2-select-none-none": "user act services_2 select none none", "user-services_3-affirm_intent-none-none": "user act services_3 affirm_intent none none", "user-services_3-inform-appointment_date-1": "user act services_3 inform appointment_date 1", "user-services_3-inform-appointment_time-1": "user act services_3 inform appointment_time 1", "user-services_3-inform-city-1": "user act services_3 inform city 1", "user-services_3-inform-doctor_name-1": "user act services_3 inform doctor_name 1", "user-services_3-inform-type-1": "user act services_3 inform type 1", "user-services_3-inform_intent-bookappointment-1": "user act services_3 inform_intent bookappointment 1", "user-services_3-inform_intent-findprovider-1": "user act services_3 inform_intent findprovider 1", "user-services_3-negate_intent-none-none": "user act services_3 negate_intent none none", "user-services_3-request-average_rating-?": "user act services_3 request average_rating ?", "user-services_3-request-phone_number-?": "user act services_3 request phone_number ?", "user-services_3-request-street_address-?": "user act services_3 request street_address ?", "user-services_3-request_alts-none-none": "user act services_3 request_alts none none", "user-services_3-select-none-none": "user act services_3 select none none", "user-services_4-affirm_intent-none-none": "user act services_4 affirm_intent none none", "user-services_4-inform-appointment_date-1": "user act services_4 inform appointment_date 1", "user-services_4-inform-appointment_time-1": 
"user act services_4 inform appointment_time 1", "user-services_4-inform-city-1": "user act services_4 inform city 1", "user-services_4-inform-type-1": "user act services_4 inform type 1", "user-services_4-inform_intent-bookappointment-1": "user act services_4 inform_intent bookappointment 1", "user-services_4-inform_intent-findprovider-1": "user act services_4 inform_intent findprovider 1", "user-services_4-negate_intent-none-none": "user act services_4 negate_intent none none", "user-services_4-request-address-?": "user act services_4 request address ?", "user-services_4-request-phone_number-?": "user act services_4 request phone_number ?", "user-services_4-request_alts-none-none": "user act services_4 request_alts none none", "user-services_4-select-none-none": "user act services_4 select none none", "user-trains_1-affirm_intent-none-none": "user act trains_1 affirm_intent none none", "user-trains_1-inform-class-1": "user act trains_1 inform class 1", "user-trains_1-inform-date_of_journey-1": "user act trains_1 inform date_of_journey 1", "user-trains_1-inform-from-1": "user act trains_1 inform from 1", "user-trains_1-inform-number_of_adults-1": "user act trains_1 inform number_of_adults 1", "user-trains_1-inform-to-1": "user act trains_1 inform to 1", "user-trains_1-inform-trip_protection-1": "user act trains_1 inform trip_protection 1", "user-trains_1-inform_intent-findtrains-1": "user act trains_1 inform_intent findtrains 1", "user-trains_1-inform_intent-gettraintickets-1": "user act trains_1 inform_intent gettraintickets 1", "user-trains_1-negate_intent-none-none": "user act trains_1 negate_intent none none", "user-trains_1-request-from_station-?": "user act trains_1 request from_station ?", "user-trains_1-request-to_station-?": "user act trains_1 request to_station ?", "user-trains_1-request_alts-none-none": "user act trains_1 request_alts none none", "user-trains_1-select-none-none": "user act trains_1 select none none", "user-travel_1-inform-category-1": 
"user act travel_1 inform category 1", "user-travel_1-inform-free_entry-1": "user act travel_1 inform free_entry 1", "user-travel_1-inform-good_for_kids-1": "user act travel_1 inform good_for_kids 1", "user-travel_1-inform-location-1": "user act travel_1 inform location 1", "user-travel_1-inform_intent-findattractions-1": "user act travel_1 inform_intent findattractions 1", "user-travel_1-request-free_entry-?": "user act travel_1 request free_entry ?", "user-travel_1-request-good_for_kids-?": "user act travel_1 request good_for_kids ?", "user-travel_1-request-phone_number-?": "user act travel_1 request phone_number ?", "user-travel_1-request_alts-none-none": "user act travel_1 request_alts none none", "user-travel_1-select-none-none": "user act travel_1 select none none", "user-weather_1-inform-city-1": "user act weather_1 inform city 1", "user-weather_1-inform-date-1": "user act weather_1 inform date 1", "user-weather_1-inform_intent-getweather-1": "user act weather_1 inform_intent getweather 1", "user-weather_1-request-humidity-?": "user act weather_1 request humidity ?", "user-weather_1-request-wind-?": "user act weather_1 request wind ?", "user-weather_1-request_alts-none-none": "user act weather_1 request_alts none none", "user-weather_1-select-none-none": "user act weather_1 select none none"}
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/memory.py b/convlab/policy/vtrace_DPT/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9e13eb69bf68fa309ad552d62ff752c90ff230f
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/memory.py
@@ -0,0 +1,192 @@
+# Modified by Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+import os
+import json, random
+import torch
+import pickle
+
+import logging
+from queue import PriorityQueue
+
+from convlab.util.custom_util import set_seed
+
+
+class Memory:
+
+    def __init__(self, seed=0):
+        """Read the replay settings from the config.json next to this file, create empty buffers, call set_seed."""
+        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json')
+        with open(config_path, 'r') as config_file:
+            settings = json.load(config_file)
+
+        # one list of episodes is kept per key; reset() creates them as attributes of the same name
+        self.data_keys = ['states', 'actions', 'rewards', 'small_actions', 'mu', 'action_masks', 'critic_value',
+                          'description_idx_list', 'value_list', 'current_domain_mask', 'non_current_domain_mask']
+        self.size = 0  # total experiences stored
+        self.number_episodes = 0
+        self.second_r = False
+        self.reward_weight = 1.0
+        self.priority_queue = PriorityQueue()
+        self.batch_size = settings.get('batchsz', 32)
+        self.max_size = settings.get('memory_size', 2000)
+        self.reservoir_sampling = settings.get("use_reservoir_sampling", False)
+        logging.info(f"We use reservoir sampling: {self.reservoir_sampling}")
+        self.reset()
+        set_seed(seed)
+
+    def set_seed(self, seed):
+        """Seed the numpy, torch and stdlib generators, plus every CUDA device when CUDA is available."""
+        for seeder in (np.random.seed, torch.random.manual_seed, random.seed, torch.manual_seed):
+            seeder(seed)
+        if not torch.cuda.is_available():
+            return
+        torch.cuda.manual_seed_all(seed)
+
+    def reset(self):  # every buffer restarts as a list holding a single empty (open) episode
+        for buffer_name in self.data_keys:
+            setattr(self, buffer_name, [[]])
+
+    def update_episode(self, state_list, action_list, reward_list, small_act_list, mu_list, action_mask_list,
+                       critic_value_list, description_idx_list, value_list, current_domain_mask, non_current_domain_mask):
+        """
+        Store one finished episode and open a fresh empty slot behind it.
+
+        Each argument holds the per-turn entries of the episode for the buffer of the matching name; rewards
+        are divided by 40.0 before they are stored. If the buffers already hold more than max_size entries,
+        one stored episode is dropped first: a uniformly drawn one, or, with reservoir sampling, the one whose
+        entry has the smallest random key in the priority queue.
+
+        NOTE(review): the queue stores list positions; pop() shifts later episodes, so positions queued
+        earlier may no longer refer to the same episode -- confirm whether this is intended.
+        """
+        stored = len(self.states)
+        if stored > self.max_size:
+            if self.reservoir_sampling:
+                # entries are (random key, position) tuples, the smallest key comes out first
+                remove_index = self.priority_queue.get()[1]
+            else:
+                # the last two positions (newest episode and open slot) are never drawn
+                remove_index = random.choice(range(stored - 2))
+            for name in self.data_keys:
+                getattr(self, name).pop(remove_index)
+
+        episode = {
+            'states': state_list,
+            'actions': action_list,
+            'rewards': [r/40.0 for r in reward_list],
+            'small_actions': small_act_list,
+            'mu': mu_list,
+            'action_masks': action_mask_list,
+            'critic_value': critic_value_list,
+            'description_idx_list': description_idx_list,
+            'value_list': value_list,
+            'current_domain_mask': current_domain_mask,
+            'non_current_domain_mask': non_current_domain_mask,
+        }
+        for name, entries in episode.items():
+            buffer = getattr(self, name)
+            buffer[-1] = entries  # fill the open slot ...
+            buffer.append([])     # ... and open the next one
+
+        self.number_episodes += 1
+        if self.reservoir_sampling:
+            self.priority_queue.put((torch.randn(1), len(self.states) - 2))
+
+    def update(self, state, action, reward, next_state, done):
+        """Record a single transition; thin wrapper around add_experience that keeps its default mu."""
+        self.add_experience(state=state, action=action, reward=reward, next_state=next_state, done=done)
+
+    def add_experience(self, state, action, reward, next_state, done, mu=None):
+        """Append one transition (reward divided by 40.0) to the open episode; done also closes the episode.
+
+        NOTE(review): next_states and masks are not in data_keys, so reset() never creates them, and mask stays
+        unbound for a non-dict action -- this path looks unusable unless prepared elsewhere; please confirm.
+        """
+        reward = reward / 40.0
+        if isinstance(action, dict):
+            mu = action.get('mu')
+            action_index = action.get('action_index')
+            mask = action.get('mask')
+        else:
+            action_index = action
+
+        if done:
+            self.states[-1].append(state)
+            self.actions[-1].append(action_index)
+            self.rewards[-1].append(reward)
+            self.next_states[-1].append(next_state)
+            self.mu[-1].append(mu)
+            self.masks[-1].append(mask)
+
+            self.states.append([])
+            self.actions.append([])
+            self.rewards.append([])
+            self.next_states.append([])
+            self.mu.append([])
+            self.masks.append([])
+
+            if len(self.states) > self.max_size:
+                # delete the oldest episode when max-size is reached
+                for k in self.data_keys:
+                    getattr(self, k).pop(0)
+            else:
+                self.number_episodes += 1
+
+        else:
+            self.states[-1].append(state)
+            self.actions[-1].append(action_index)
+            self.rewards[-1].append(reward)
+            self.next_states[-1].append(next_state)
+            self.mu[-1].append(mu)
+            self.masks[-1].append(mask)
+
+        # Actually occupied size of memory
+        if self.size < self.max_size:
+            self.size += 1
+
+    def sample(self, online_offline_ratio=0.0):
+        '''
+        Returns (batch, num_online) where batch holds up to batch_size stored episodes.
+        The batch is a dict: every name in data_keys maps to the list of that buffer's entries for the
+        sampled episodes, e.g. batch['states'][i] and batch['rewards'][i] belong to the same episode.
+        If at most batch_size episodes are stored, all of them are returned in storage order and
+        num_online is 0. Otherwise episodes are drawn uniformly with replacement; for
+        online_offline_ratio != 0 the batch starts with the num_online most recent episodes, where
+        num_online = int(online_offline_ratio * batch_size) capped at batch_size, and is filled up
+        with draws from the older episodes.
+        The open slot at the end of the buffers (the episode under construction) is never sampled.
+        '''
+        number_episodes = len(self.states) - 1
+        num_online = 0
+
+        # Sample batch-size many episodes. Valid ids are 0 .. number_episodes - 1, so number_episodes is
+        # the exclusive upper bound; a bound of number_episodes - 1 would never return the newest episode
+        # and would yield one online episode less than num_online reports.
+        if number_episodes <= self.batch_size:
+            batch_ids = list(range(number_episodes))
+        elif online_offline_ratio != 0:
+            num_online = min(int(online_offline_ratio * self.batch_size), self.batch_size)
+            first_online = number_episodes - num_online
+            batch_ids_online = list(range(first_online, number_episodes))
+            batch_ids_offline = np.random.randint(first_online, size=self.batch_size - num_online).tolist()
+            batch_ids = batch_ids_online + batch_ids_offline
+        else:
+            batch_ids = np.random.randint(number_episodes, size=self.batch_size).tolist()
+
+        batch = {k: [getattr(self, k)[index] for index in batch_ids] for k in self.data_keys}
+        return batch, num_online
+
+    def save(self, path):
+        """Pickle the memory to <path>/vtrace.memory with the queue replaced by its plain backing list."""
+        queue, self.priority_queue = self.priority_queue, self.priority_queue.queue
+        with open(path + f'/vtrace.memory', "wb") as f:
+            pickle.dump(self, f)
+        self.priority_queue = queue  # PriorityQueue cannot be pickled; restore it so the memory stays usable
+
+    def build_priority_queue(self, queue_list):
+        """Replace the priority queue with a new PriorityQueue filled with the elements of queue_list."""
+        self.priority_queue = restored = PriorityQueue()
+        for entry in queue_list:
+            restored.put(entry)
diff --git a/convlab/policy/vtrace_DPT/multiprocessing_helper.py b/convlab/policy/vtrace_DPT/multiprocessing_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..945cf4e066265543a0fa19aa8cf2efc84000b93b
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/multiprocessing_helper.py
@@ -0,0 +1,144 @@
+from torch import multiprocessing as mp
+from copy import deepcopy
+
+import logging
+import torch
+import time
+from convlab.util.custom_util import set_seed
+
+torch.multiprocessing.set_sharing_strategy('file_system')
+# Force the 'spawn' start method and rebind mp to the matching context used by the helpers below.
+try:
+    mp.set_start_method('spawn', force=True)
+    mp = mp.get_context('spawn')
+except RuntimeError:
+    pass  # deliberately ignored; mp then keeps whatever it was bound to when the error was raised
+
+
+# we use a queue for every process to guarantee reproducibility
+# queues are used for job submission, while episode queues are used for pushing dialogues inside
+def get_queues(train_processes):
+    """Create one job queue and one episode queue per training process.
+
+    Returns the pair (job queues, episode queues), each a list of train_processes SimpleQueue objects.
+    """
+    queues = [mp.SimpleQueue() for _ in range(train_processes)]
+    episode_queues = [mp.SimpleQueue() for _ in range(train_processes)]
+    return queues, episode_queues
+
+
+# this is our target function for the processes
+def create_episodes_process(do_queue, put_queue, environment, policy, seed, metric_queue):
+    """
+    Worker loop: take goals from do_queue, run one dialogue per goal and hand the episode to put_queue.
+
+    A dialogue lasts at most 40 turns. Once the environment reports done, a metric dict is put on
+    metric_queue and the per-turn lists are put on put_queue as one tuple; the description index list
+    also fills the first tuple position because no separate state list is collected. The string
+    'stop' on do_queue ends the loop.
+    """
+    traj_len = 40
+    set_seed(seed)
+    while True:
+        # get() blocks until a job arrives, so an idle worker does not spin on empty()
+        item = do_queue.get()
+        if item == 'stop':
+            print("Got stop signal.")
+            break
+        s = environment.reset(item)
+        rl_return = 0
+        user_act_list, sys_act_list, s_vec_list, action_list, reward_list, small_act_list, action_mask_list, mu_list, \
+        trajectory_list, vector_mask_list, critic_value_list, description_idx_list, value_list, current_domain_mask, \
+        non_current_domain_mask = \
+            [], [], [], [], [], [], [], [], [], [], [], [], [], [], []
+
+        for t in range(traj_len):
+            s_vec, mask = policy.vector.state_vectorize(s)
+            with torch.no_grad():
+                a = policy.predict(s)
+
+            action_list.append(policy.info_dict['big_act'].detach())
+            small_act_list.append(policy.info_dict['small_act'])
+            action_mask_list.append(policy.info_dict['action_mask'])
+            mu_list.append(policy.info_dict['a_prob'].detach())
+            critic_value_list.append(policy.info_dict['critic_value'])
+            vector_mask_list.append(torch.Tensor(mask))
+            description_idx_list.append(policy.info_dict["description_idx_list"])
+            value_list.append(policy.info_dict["value_list"])
+            current_domain_mask.append(policy.info_dict["current_domain_mask"])
+            non_current_domain_mask.append(policy.info_dict["non_current_domain_mask"])
+            sys_act_list.append(policy.vector.action_vectorize(a))
+            trajectory_list.extend([s['user_action'], a])
+
+            # interact with env
+            next_s, r, done = environment.step(a)
+            rl_return += r
+            reward_list.append(torch.Tensor([r]))
+            next_s_vec, next_mask = policy.vector.state_vectorize(next_s)
+            s = next_s  # update per step
+
+            if done:
+                metric_queue.put({"success": environment.evaluator.success_strict, "return": rl_return,
+                                  "avg_actions": torch.stack(action_list).sum(dim=-1).mean().item(),
+                                  "turns": t, "goal": item.domain_goals})
+                put_queue.put((description_idx_list, action_list, reward_list, small_act_list, mu_list,
+                               action_mask_list, critic_value_list, description_idx_list, value_list,
+                               current_domain_mask, non_current_domain_mask))
+                break
+
+
+def start_processes(train_processes, queues, episode_queues, env, policy_sys, seed, metric_queue):
+    """Create one daemon worker per training process, start them in index order and return the list."""
+    logging.info("Spawning processes..")
+    # every worker gets its own job/episode queue pair but the same env, policy, seed and metric queue
+    workers = [mp.Process(target=create_episodes_process,
+                          args=(queues[idx], episode_queues[idx], env, policy_sys, seed, metric_queue))
+               for idx in range(train_processes)]
+    for idx, worker in enumerate(workers):
+        worker.daemon = True
+        worker.start()
+        logging.info(f"Started process {idx}")
+    return workers
+
+
+def terminate_processes(processes, queues):
+    """Send 'stop' to every worker's job queue, wait two seconds, then terminate the workers."""
+    logging.info("Terminating processes..")
+    for idx, _ in enumerate(processes):
+        queues[idx].put('stop')
+    time.sleep(2)
+    for idx, worker in enumerate(processes):
+        worker.terminate()
+        logging.info(f"Terminated process {idx}")
+
+
+def submit_jobs(num_jobs, queues, episode_queues, train_processes, memory, goals, metric_queue):
+    """Hand out up to num_jobs goals round-robin, wait for the resulting dialogues and store them in memory.
+
+    Goals are created by the caller and popped from the end of goals, so which process runs which goal does
+    not depend on timing. Returns (time at which waiting started, list of collected metric dicts).
+    """
+    metrics = []
+    submitted = 0
+    for job in range(num_jobs):
+        if goals:
+            queues[job % train_processes].put(goals.pop())
+            submitted += 1
+    time_now = time.time()
+    collected_dialogues = 0
+    # one dialogue list per process: arrival order across processes may differ between runs, the order
+    # within one process does not, so the lists are simply concatenated in process order at the end
+    episode_list = [[] for _ in range(train_processes)]
+    # wait only for jobs that were really submitted; waiting for num_jobs would never end if goals ran out
+    while collected_dialogues < submitted:
+        for b in range(train_processes):
+            if not episode_queues[b].empty():
+                metrics.append(metric_queue.get())
+                dialogue = episode_queues[b].get()
+                episode_list[b].append(deepcopy(dialogue))
+                del dialogue
+                collected_dialogues += 1
+    for b in range(train_processes):
+        for dialogue in episode_list[b]:
+            memory.update_episode(*dialogue)
+    return time_now, metrics
diff --git a/convlab/policy/vtrace_DPT/semantic_level_config.json b/convlab/policy/vtrace_DPT/semantic_level_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..7f6d266751a32642c9842af3bca0d80a33a97beb
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/semantic_level_config.json
@@ -0,0 +1,48 @@
+{"goals": {"single_domains": false, "allowed_domains": null},
+
+	"model": {
+		"load_path": "",
+		"use_pretrained_initialisation": false,
+		"pretrained_load_path": "",
+		"seed": 0,
+        "process_num": 4,
+		"eval_frequency": 1000,
+        "num_eval_dialogues": 500,
+        "process_num_train": 1,
+        "total_dialogues": 10000,
+        "update_rounds": 1,
+		"new_dialogues": 2,
+		"sys_semantic_to_usr": false
+	},
+	"vectorizer_sys": {
+		"uncertainty_vector_mul": {
+			"class_path": "convlab.policy.vector.vector_nodes.VectorNodes",
+			"ini_params": {
+				"use_masking": true,
+				"manually_add_entity_names": true,
+				"seed": 0,
+				"dataset_name": "multiwoz21",
+				"filter_state": true
+			}
+		}
+	},
+	"nlu_sys": {},
+	"dst_sys": {
+		"RuleDST": {
+			"class_path": "convlab.dst.rule.multiwoz.dst.RuleDST",
+			"ini_params": {}
+		}
+	},
+	"sys_nlg": {},
+	"nlu_usr": {},
+	"dst_usr": {},
+	"policy_usr": {
+		"RulePolicy": {
+			"class_path": "convlab.policy.rule.multiwoz.RulePolicy",
+			"ini_params": {
+				"character": "usr"
+			}
+		}
+	},
+	"usr_nlg": {}
+}
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/supervised/loader.py b/convlab/policy/vtrace_DPT/supervised/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dbbcaafb95672711d17715294c6dcd9ded9326b
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/supervised/loader.py
@@ -0,0 +1,169 @@
+import os
+import pickle
+import torch
+import time
+import torch.utils.data as data
+
+from convlab.policy.vector.vector_binary import VectorBinary
+from convlab.util import load_policy_data, load_dataset
+from convlab.util.custom_util import flatten_acts
+from convlab.util.multiwoz.state import default_state
+from convlab.policy.vector.dataset import ActDatasetKG
+from tqdm import tqdm
+
+mwoz_domains = ['restaurant', 'hotel', 'train', 'taxi', 'attraction']  # substrings matched in is_multiwoz_like
+
+
+class PolicyDataVectorizer:
+
+    def __init__(self, dataset_name='multiwoz21', vector=None, percentage=1.0, dialogue_order=0):
+        """Remember the dataset settings and the vectoriser, then load or build the processed data.
+
+        Without an explicit vector a VectorBinary for dataset_name is created.
+        """
+        self.dataset_name, self.percentage, self.dialogue_order = dataset_name, percentage, dialogue_order
+        # construct the default vectoriser only when none was passed in
+        self.vector = vector if vector is not None else VectorBinary(dataset_name)
+        self.process_data()
+
+    def process_data(self):
+        """Load the cached splits if all three pickles exist (a bare directory may stem from an aborted build), else build them."""
+        processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                     f'processed_data/{self.dataset_name}_{type(self.vector).__name__}')
+        if self.percentage != 1.0:
+            processed_dir += f"_{self.percentage}_{self.dialogue_order}"
+        if all(os.path.isfile(os.path.join(processed_dir, f'{part}.pkl')) for part in ('train', 'validation', 'test')):
+            print('Load processed data file')
+            self._load_data(processed_dir)
+        else:
+            print('Start preprocessing the dataset, this can take a while..')
+            self._build_data(processed_dir)
+
+    def _build_data(self, processed_dir):
+        """Vectorise every sample of each dataset split and pickle the result as <split>.pkl in processed_dir."""
+        self.data = {}
+        os.makedirs(processed_dir, exist_ok=True)
+        dataset = load_dataset(self.dataset_name, dial_ids_order=self.dialogue_order,
+                               split2ratio={'train': self.percentage, 'validation': self.percentage,
+                                            'test': self.percentage})
+        data_split = load_policy_data(dataset, context_window_size=2)
+
+        for split in data_split:
+            self.data[split] = []
+            raw_data = data_split[split]
+
+            for data_point in tqdm(raw_data):
+                state = default_state()
+                # the last context turn provides belief state and user acts, the one before it the system acts
+                state['belief_state'] = data_point['context'][-1]['state']
+                state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts'])
+                last_system_act = data_point['context'][-2]['dialogue_acts'] \
+                    if len(data_point['context']) > 1 else {}
+                state['system_action'] = flatten_acts(last_system_act)
+                state['terminated'] = data_point['terminated']
+                if 'booked' in data_point:
+                    state['booked'] = data_point['booked']
+                dialogue_act = flatten_acts(data_point['dialogue_acts'])
+                # NOTE(review): the stored "state" is self.vector.kg_info, presumably set by state_vectorize -- confirm
+                vectorized_state, mask = self.vector.state_vectorize(state)
+                vectorized_action = self.vector.action_vectorize(dialogue_act)
+                self.data[split].append({"state": self.vector.kg_info, "action": vectorized_action, "mask": mask,
+                                         "terminated": state['terminated']})
+
+            with open(os.path.join(processed_dir, '{}.pkl'.format(split)), 'wb') as f:
+                pickle.dump(self.data[split], f)
+
+        print("Data processing done.")
+
+    def _load_data(self, processed_dir):  # fills self.data with the three pickled splits from processed_dir
+        self.data = {}
+        for split in ('train', 'validation', 'test'):
+            with open(os.path.join(processed_dir, '{}.pkl'.format(split)), 'rb') as handle:
+                self.data[split] = pickle.load(handle)
+
+    def is_multiwoz_like(self, item):
+        """
+        Tell whether a processed sample touches a MultiWOZ-like domain.
+
+        Returns True if, for at least one node of item['state'], one of the names in mwoz_domains is a
+        substring of the node's lower-cased 'domain' entry; False otherwise, also for an empty state.
+        """
+        for node in item['state']:
+            node_domain = node['domain'].lower()
+            # plain substring test, so e.g. "restaurants_1" counts as restaurant-like
+            if any(mw_domain in node_domain for mw_domain in mwoz_domains):
+                return True
+
+        return False
+
+    def create_dataset(self, part, batchsz, policy, multiwoz_like=False):
+        """
+        Build the supervised data for one split, or load it from the cache file below data/<dataset_name>/.
+
+        Returns the shuffling DataLoader over an ActDatasetKG, the max_length from policy.get_action_masks and
+        the lists of small actions, description indices, values and kg states of the kept samples.
+        """
+        print('Start creating {} dataset'.format(part))
+        time_now = time.time()
+        root_dir = os.path.dirname(os.path.abspath(__file__))
+        data_dir = os.path.join(root_dir, "data", self.dataset_name)
+        os.makedirs(data_dir, exist_ok=True)
+        file_path = os.path.join(data_dir, part)
+        if multiwoz_like:
+            file_path += "mw"
+        if self.percentage != 1.0:
+            file_path += f"_{self.percentage}_{self.dialogue_order}"
+        # NOTE(review): the cache name encodes split, multiwoz_like and percentage/order only, not the
+        # vector or policy in use -- a cache written with other settings would be loaded as is; confirm
+        if os.path.exists(file_path):
+            action_batch, a_masks, max_length, small_act_batch, \
+            current_domain_mask_batch, non_current_domain_mask_batch, \
+            description_batch, value_batch, kg_list = torch.load(file_path)
+            print(f"Loaded data from {file_path}")
+        else:
+            print("Creating data from scratch.")
+            action_batch, small_act_batch, \
+            current_domain_mask_batch, non_current_domain_mask_batch, \
+            description_batch, value_batch = [], [], [], [], [], []
+            kg_list = []
+            for num, item in tqdm(enumerate(self.data[part])):
+                # skip samples without any action or with an empty state, and (optionally) non-MultiWOZ ones
+                if item['action'].sum() == 0 or len(item['state']) == 0:
+                    continue
+                if multiwoz_like:
+                    if not self.is_multiwoz_like(item):
+                        continue
+                action_batch.append(torch.Tensor(item['action']))
+
+                kg = [item['state']]
+                kg_list.append(item['state'])
+
+                description_idx_list, value_list = policy.get_descriptions_and_values(kg)
+                description_batch.append(description_idx_list)
+                value_batch.append(value_list)
+
+                current_domains = policy.get_current_domains(kg)
+                current_domain_mask = policy.action_embedder.get_current_domain_mask(current_domains[0], current=True)
+                non_current_domain_mask = policy.action_embedder.get_current_domain_mask(current_domains[0], current=False)
+                current_domain_mask_batch.append(current_domain_mask)
+                non_current_domain_mask_batch.append(non_current_domain_mask)
+                small_act_batch.append(torch.Tensor(policy.action_embedder.real_action_to_small_action_list(torch.Tensor(item['action']))))
+
+            print("Creating action masks..")
+            a_masks, max_length = policy.get_action_masks(action_batch)
+            action_batch = torch.stack(action_batch)
+            current_domain_mask_batch = torch.stack(current_domain_mask_batch)
+            non_current_domain_mask_batch = torch.stack(non_current_domain_mask_batch)
+            print(f"Finished data set, time spent: {time.time() - time_now}")
+            torch.save([action_batch, a_masks, max_length, small_act_batch,
+                        current_domain_mask_batch, non_current_domain_mask_batch,
+                        description_batch, value_batch, kg_list], file_path)
+
+        dataset = ActDatasetKG(action_batch, a_masks, current_domain_mask_batch, non_current_domain_mask_batch)
+        dataloader = data.DataLoader(dataset, batchsz, True)
+        print("NUMBER OF EXAMPLES:", len(current_domain_mask_batch))
+        return dataloader, max_length, small_act_batch, description_batch, value_batch, kg_list
+
+
+if __name__ == '__main__':
+    data_loader = PolicyDataVectorizer()  # defaults: 'multiwoz21' with a VectorBinary; loads or builds the data
diff --git a/convlab/policy/vtrace_DPT/supervised/train_supervised.py b/convlab/policy/vtrace_DPT/supervised/train_supervised.py
new file mode 100644
index 0000000000000000000000000000000000000000..1807a671da7e2938173a18277cd21980ee577a11
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/supervised/train_supervised.py
@@ -0,0 +1,245 @@
+import argparse
+import os
+import torch
+import logging
+import json
+import sys
+
+from torch import optim
+from copy import deepcopy
+from convlab.policy.vtrace_DPT.supervised.loader import PolicyDataVectorizer
+from convlab.util.custom_util import set_seed, init_logging, save_config
+from convlab.util.train_util import to_device
+from convlab.policy.vtrace_DPT.transformer_model.EncoderDecoder import EncoderDecoder
+from convlab.policy.vector.vector_nodes import VectorNodes
+
+root_dir = os.path.dirname(  # four dirname() steps up from this file's absolute path
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+sys.path.append(root_dir)  # NOTE(review): executed after the convlab imports above, so it cannot affect them
+# Use the GPU when CUDA is available, otherwise fall back to the CPU.
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class MLE_Trainer:
+    """Behaviour-cloning trainer for a dialogue policy.
+
+    Fits the policy with Adam on the negative log-likelihood of the dataset actions,
+    plus an entropy term and an L1 penalty towards the weights it started from.
+    """
+
+    def __init__(self, manager, cfg, policy):
+        """Keep the policy, build its optimiser and load the three data splits.
+
+        Args:
+            manager: provides create_dataset(split, batchsz, policy, multiwoz_like).
+            cfg: mapping with 'supervised_lr', 'entropy_weight',
+                'regularization_weight', 'multiwoz_like', 'batchsz' and 'save_dir'.
+            policy: model to train; a deep copy is kept as regularisation anchor.
+        """
+        self.start_policy = deepcopy(policy)
+        self.policy = policy
+        self.policy_optim = optim.Adam(list(self.policy.parameters()), lr=cfg['supervised_lr'])
+        self.entropy_weight = cfg['entropy_weight']
+        self.regularization_weight = cfg['regularization_weight']
+        self._init_data(manager, cfg)
+
+    def _init_data(self, manager, cfg):
+        """Ask the manager for the train, validation and test datasets.
+
+        Each call returns a six-tuple that is stored as the data_*, max_length_*,
+        small_act_*, descriptions_*, values_* and kg_* attributes of that split.
+        """
+        multiwoz_like = cfg['multiwoz_like']
+        self.data_train, self.max_length_train, self.small_act_train, self.descriptions_train, self.values_train, \
+            self.kg_train = manager.create_dataset('train', cfg['batchsz'], self.policy, multiwoz_like)
+        self.data_valid, self.max_length_valid, self.small_act_valid, self.descriptions_valid, self.values_valid, \
+            self.kg_valid = manager.create_dataset('validation', cfg['batchsz'], self.policy, multiwoz_like)
+        self.data_test, self.max_length_test, self.small_act_test, self.descriptions_test, self.values_test, \
+            self.kg_test = manager.create_dataset('test', cfg['batchsz'], self.policy, multiwoz_like)
+        self.save_dir = cfg['save_dir']
+
+    def policy_loop(self, data):
+        """Compute the loss terms for one training batch.
+
+        Returns:
+            (negative mean log-probability, negative entropy, weight penalty).
+        """
+        actions, action_masks, current_domain_mask, non_current_domain_mask, indices = to_device(data)
+
+        # Side information is stored per example and looked up through the batch indices.
+        small_act_batch = [self.small_act_train[i].to(DEVICE) for i in indices]
+        description_batch = [self.descriptions_train[i].to(DEVICE) for i in indices]
+        value_batch = [self.values_train[i].to(DEVICE) for i in indices]
+
+        log_prob, entropy = self.policy.get_log_prob(actions, action_masks, self.max_length_train, small_act_batch,
+                                                     current_domain_mask, non_current_domain_mask,
+                                                     description_batch, value_batch)
+        return -1 * log_prob.mean(), -entropy, self.weight_loss()
+
+    def weight_loss(self):
+        """Return the summed absolute difference between the current and the initial
+        policy parameters, divided by the number of trainable parameter elements."""
+        num_params = sum(p.numel() for p in self.policy.parameters() if p.requires_grad)
+        loss = 0
+        for paramA, paramB in zip(self.policy.parameters(), self.start_policy.parameters()):
+            loss += torch.sum(torch.abs(paramA - paramB.detach()))
+        return loss / num_params
+
+    def imitating(self):
+        """Run one epoch of behaviour cloning on the training loader.
+
+        The policy is put in train mode for the epoch and in eval mode afterwards.
+        """
+        self.policy.train()
+        a_loss = 0.
+        for i, data in enumerate(self.data_train):
+            self.policy_optim.zero_grad()
+            loss_a, entropy_loss, weight_loss = self.policy_loop(data)
+            a_loss += loss_a.item()
+            loss_a = loss_a + self.entropy_weight * entropy_loss + self.regularization_weight * weight_loss
+
+            # Running mean over windows of exactly 20 batches (the first window included).
+            if (i + 1) % 20 == 0:
+                print("LOSS:", a_loss / 20.0)
+                a_loss = 0
+            loss_a.backward()
+            # Replace NaN gradient entries by zero before the optimiser step.
+            for p in self.policy.parameters():
+                if p.grad is not None:
+                    p.grad[torch.isnan(p.grad)] = 0.0
+            self.policy_optim.step()
+
+        self.policy.eval()
+
+    @staticmethod
+    def _confusion_counts(a, target):
+        """Return (TP, FP, FN) from the non-zero positions of a and of target."""
+        real = set(map(tuple, target.nonzero().tolist()))
+        predict = set(map(tuple, a.nonzero().tolist()))
+        return len(real & predict), len(predict - real), len(real - predict)
+
+    @staticmethod
+    def _scores(tp, fp, fn):
+        """Return (precision, recall, F1); an empty denominator gives 0.0, not an error."""
+        prec = tp / (tp + fp) if tp + fp else 0.0
+        rec = tp / (tp + fn) if tp + fn else 0.0
+        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
+        return prec, rec, f1
+
+    def _evaluate(self, loader, kg_list):
+        """Select an action for every example of loader and compare it with the target.
+
+        Returns:
+            (TP, FP, FN, mean number of selected actions, mean number of target actions).
+        """
+        a_TP, a_FP, a_FN = 0, 0, 0
+        average_actions, average_target_actions, counter = 0, 0, 0
+        for data in loader:
+            counter += 1
+            target_a, action_masks, current_domain_mask, non_current_domain_mask, indices = to_device(data)
+            a = torch.stack([self.policy.select_action([kg_list[i]]) for i in indices])
+            TP, FP, FN = self._confusion_counts(a, target_a)
+            a_TP, a_FP, a_FN = a_TP + TP, a_FP + FP, a_FN + FN
+            average_actions += a.float().sum(dim=-1).mean()
+            average_target_actions += target_a.float().sum(dim=-1).mean()
+        counter = max(counter, 1)  # keeps the averages defined for an empty loader
+        return a_TP, a_FP, a_FN, average_actions / counter, average_target_actions / counter
+
+    def validate(self):
+        """Score the policy on the validation split and return (precision, recall, F1)."""
+        a_TP, a_FP, a_FN, avg_actions, avg_target_actions = self._evaluate(self.data_valid, self.kg_valid)
+        logging.info(f"Average actions: {avg_actions}")
+        logging.info(f"Average target actions: {avg_target_actions}")
+        return self._scores(a_TP, a_FP, a_FN)
+
+    def test(self):
+        """Score the policy on the test split and print TP, FP, FN and F1."""
+        # Uses the same five-item batches and select_action path as validate().
+        a_TP, a_FP, a_FN, _, _ = self._evaluate(self.data_test, self.kg_test)
+        print(a_TP, a_FP, a_FN, self._scores(a_TP, a_FP, a_FN)[2])
+
+    def save(self, directory, epoch):
+        """Save the policy state dict as <directory>/supervised.pol.mdl and log the epoch."""
+        os.makedirs(directory, exist_ok=True)
+        torch.save(self.policy.state_dict(), os.path.join(directory, 'supervised.pol.mdl'))
+        logging.info('<<dialog policy>> epoch {}: saved network to mdl'.format(epoch))
+
+
+def arg_parser():
+    """Parse the command-line options of the supervised training script."""
+    cli = argparse.ArgumentParser()
+    # (flag, type, default) for every supported option.
+    options = (("--seed", int, 0), ("--eval_freq", int, 1),
+               ("--dataset_name", str, "multiwoz21"), ("--model_path", str, ""))
+    for flag, kind, fallback in options:
+        cli.add_argument(flag, type=kind, default=fallback)
+    # parse_args() reads the process arguments and returns an argparse.Namespace.
+    return cli.parse_args()
+
+
+if __name__ == '__main__':
+
+    args = arg_parser()
+
+    # config.json is read from the parent of the directory that holds this script.
+    root_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    with open(os.path.join(root_directory, 'config.json'), 'r') as f:
+        cfg = json.load(f)
+    cfg['dataset_name'] = args.dataset_name
+
+    logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
+        init_logging(os.path.dirname(os.path.abspath(__file__)), "info")
+    save_config(vars(args), cfg, config_save_path)
+
+    set_seed(args.seed)
+    logging.info(f"Seed used: {args.seed}")
+    logging.info(f"Batch size: {cfg['batchsz']}")
+    logging.info(f"Epochs: {cfg['epoch']}")
+    logging.info(f"Learning rate: {cfg['supervised_lr']}")
+    logging.info(f"Entropy weight: {cfg['entropy_weight']}")
+    logging.info(f"Regularization weight: {cfg['regularization_weight']}")
+    logging.info(f"Only use multiwoz like domains: {cfg['multiwoz_like']}")
+    logging.info(f"We use: {cfg['data_percentage']*100}% of the data")
+    logging.info(f"Dialogue order used: {cfg['dialogue_order']}")
+
+    vector = VectorNodes(dataset_name=args.dataset_name, use_masking=False, filter_state=True)
+    manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector,
+                                   percentage=cfg['data_percentage'], dialogue_order=cfg["dialogue_order"])
+    policy = EncoderDecoder(**cfg, action_dict=vector.act2vec).to(device=DEVICE)
+    # Warm-starting is best effort: a checkpoint that cannot be loaded is reported
+    # with its reason and training continues from the freshly built policy.
+    if args.model_path:
+        try:
+            policy.load_state_dict(torch.load(args.model_path, map_location=DEVICE))
+            logging.info(f"Loaded model from {args.model_path}")
+        except Exception as err:
+            logging.warning(f"Didnt load a model from {args.model_path}: {err}")
+    else:
+        logging.info("Didnt load a model")
+    agent = MLE_Trainer(manager, cfg, policy)
+
+    logging.info('Start training')
+
+    best_recall = best_precision = best_f1 = 0.0
+    precision = recall = f1 = 0
+
+    for e in range(cfg['epoch']):
+        agent.imitating()
+        logging.info(f"Epoch: {e}")
+
+        # Epochs without an evaluation log the metrics of the most recent one again.
+        if e % args.eval_freq == 0:
+            precision, recall, f1 = agent.validate()
+        logging.info(f"Precision: {precision}")
+        logging.info(f"Recall: {recall}")
+        logging.info(f"F1: {f1}")
+
+        best_precision = max(best_precision, precision)
+        best_recall = max(best_recall, recall)
+        # The checkpoint is written only when the validation F1 improves.
+        if f1 > best_f1:
+            best_f1 = f1
+            agent.save(save_path, e)
+        logging.info(f"Best Precision: {best_precision}")
+        logging.info(f"Best Recall: {best_recall}")
+        logging.info(f"Best F1: {best_f1}")
diff --git a/convlab/policy/vtrace_DPT/train.py b/convlab/policy/vtrace_DPT/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..f441da29cf44fb366daf9f5afd3f2f21d1bd420d
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/train.py
@@ -0,0 +1,224 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Sun Jul 14 16:14:07 2019
+@author: Chris Geishauser
+"""
+
+import sys
+import os
+import logging
+import time
+import torch
+
+from torch import multiprocessing as mp
+from argparse import ArgumentParser
+from convlab.policy.vtrace_DPT import VTRACE
+from convlab.policy.vtrace_DPT.memory import Memory
+from convlab.policy.vtrace_DPT.multiprocessing_helper import get_queues, start_processes, submit_jobs, \
+    terminate_processes
+from convlab.task.multiwoz.goal_generator import GoalGenerator
+from convlab.util.custom_util import set_seed, init_logging, save_config, move_finished_training, env_config, \
+    eval_policy, log_start_args, save_best, load_config_file, create_goals, get_config
+from datetime import datetime
+
+sys.path.append(os.path.dirname(os.path.dirname(
+    os.path.dirname(os.path.abspath(__file__)))))
+# Use the GPU when CUDA is available, otherwise the CPU; `device` is an alias of DEVICE.
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = DEVICE
+# Force the 'spawn' start method and rebind `mp` to the matching context.
+try:
+    mp.set_start_method('spawn', force=True)
+    mp = mp.get_context('spawn')
+except RuntimeError:
+    pass  # a RuntimeError raised by either call above is ignored
+
+
+def create_episodes(environment, policy, num_episodes, memory, goals):
+    """Roll out num_episodes dialogues with policy and store the finished ones in memory.
+
+    One goal is popped from goals per dialogue and every dialogue runs for at most
+    traj_len turns. A dialogue that has not terminated by then is not written to
+    memory but still counts towards num_episodes.
+    """
+    sampled_num = 0
+    traj_len = 40  # maximum number of turns per dialogue
+    while sampled_num < num_episodes:
+        goal = goals.pop()
+        s = environment.reset(goal)
+        # NOTE(review): user_act_list and s_vec_list are never appended to in this function.
+        user_act_list, sys_act_list, s_vec_list, action_list, reward_list, small_act_list, action_mask_list, mu_list, \
+        trajectory_list, vector_mask_list, critic_value_list, description_idx_list, value_list, current_domain_mask, \
+        non_current_domain_mask = \
+            [], [], [], [], [], [], [], [], [], [], [], [], [], [], []
+        # NOTE(review): sys_act_list, trajectory_list and vector_mask_list are filled below but never read here.
+        for t in range(traj_len):
+            s_vec, mask = policy.vector.state_vectorize(s)
+            with torch.no_grad():
+                a = policy.predict(s)
+            # Collect the policy.info_dict entries read after the prediction (the 'kg' entry is disabled):
+            # s_vec_list.append(policy.info_dict['kg'])
+            action_list.append(policy.info_dict['big_act'].detach())
+            small_act_list.append(policy.info_dict['small_act'])
+            action_mask_list.append(policy.info_dict['action_mask'])
+            mu_list.append(policy.info_dict['a_prob'].detach())
+            critic_value_list.append(policy.info_dict['critic_value'])
+            vector_mask_list.append(torch.Tensor(mask))
+            description_idx_list.append(policy.info_dict["description_idx_list"])
+            value_list.append(policy.info_dict["value_list"])
+            current_domain_mask.append(policy.info_dict["current_domain_mask"])
+            non_current_domain_mask.append(policy.info_dict["non_current_domain_mask"])
+            sys_act_list.append(policy.vector.action_vectorize(a))
+            trajectory_list.extend([s['user_action'], a])
+            # interact with env
+            next_s, r, done = environment.step(a)
+            reward_list.append(torch.Tensor([r]))
+            next_s_vec, next_mask = policy.vector.state_vectorize(next_s)  # results are not used afterwards
+            # update per step
+            s = next_s
+            # NOTE(review): description_idx_list is passed as both the 1st and the 8th argument - confirm.
+            if done:
+                memory.update_episode(description_idx_list, action_list, reward_list, small_act_list, mu_list,
+                                      action_mask_list, critic_value_list, description_idx_list, value_list,
+                                      current_domain_mask, non_current_domain_mask)
+                break
+
+        sampled_num += 1
+
+
+def log_train_configs():
+    """Log the seed, the start time and the dialogue settings held in module-level globals."""
+    logging.info(f"Train seed is {seed}")
+    logging.info(f"Start of Training: {time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}")
+    logging.info("Number of processes for training: " + str(train_processes))
+    logging.info("Number of new dialogues per update: " + str(new_dialogues))
+    logging.info("Number of total dialogues: " + str(total_dialogues))
+
+
+if __name__ == '__main__':
+    time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
+    begin_time = datetime.now()
+    parser = ArgumentParser()
+    parser.add_argument("--path", type=str, default='convlab/policy/vtrace_DPT/semantic_level_config.json',
+                        help="Load path for config file")
+    parser.add_argument("--seed", type=int, default=None,
+                        help="Seed for the policy parameter initialization")
+    parser.add_argument("--mode", type=str, default='info',
+                        help="Set level for logger")
+    # type=bool would turn every non-empty string, including "False", into True.
+    parser.add_argument("--save_eval_dials", default=False,
+                        type=lambda text: text.strip().lower() in ("1", "true", "t", "yes", "y", "on"),
+                        help="Flag for saving dialogue_info during evaluation")
+
+    cli_args = parser.parse_args()  # parsed once and reused below
+    path = cli_args.path
+    seed = cli_args.seed
+    mode = cli_args.mode
+    save_eval = cli_args.save_eval_dials
+
+    logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
+        init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
+
+    args = [('model', 'seed', seed)] if seed is not None else list()
+
+    environment_config = load_config_file(path)
+
+    conf = get_config(path, args)
+    seed = conf['model']['seed']
+    set_seed(seed)
+
+    policy_sys = VTRACE(is_train=True, seed=seed, vectorizer=conf['vectorizer_sys_activated'],
+                        load_path=conf['model']['load_path'])
+    policy_sys.share_memory()
+    memory = Memory(seed=seed)
+    policy_sys.current_time = current_time
+    policy_sys.log_dir = config_save_path.replace('configs', 'logs')
+    policy_sys.save_dir = save_path
+
+    save_config(vars(cli_args), environment_config, config_save_path, policy_config=policy_sys.cfg)
+
+    env, sess = env_config(conf, policy_sys)
+
+    # Setup uncertainty thresholding
+    if env.sys_dst:
+        try:
+            if env.sys_dst.use_confidence_scores:
+                policy_sys.vector.setup_uncertain_query(env.sys_dst.thresholds)
+        except Exception:
+            logging.info('Uncertainty threshold not set.')
+
+    single_domains = conf['goals']['single_domains']
+    allowed_domains = conf['goals']['allowed_domains']
+    logging.info(f"Single domains only: {single_domains}")
+    logging.info(f"Allowed domains {allowed_domains}")
+
+    logging.info(f"Evaluating at start - {time_now}" + '-'*60)
+    time_now = time.time()
+    eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path,
+                            single_domain_goals=single_domains, allowed_domains=allowed_domains)
+    logging.info(f"Finished evaluating, time spent: {time.time() - time_now}")
+
+    for key in eval_dict:
+        tb_writer.add_scalar(key, eval_dict[key], 0)
+    best_complete_rate = eval_dict['complete_rate']
+    best_success_rate = eval_dict['success_rate_strict']
+    best_return = eval_dict['avg_return']
+
+    train_processes = conf['model']["process_num_train"]
+
+    if train_processes > 1:
+        # We use multiprocessing
+        queues, episode_queues = get_queues(train_processes)
+        online_metric_queue = mp.SimpleQueue()
+        processes = start_processes(train_processes, queues, episode_queues, env, policy_sys, seed,
+                                    online_metric_queue)
+    goal_generator = GoalGenerator()
+
+    num_dialogues = 0
+    new_dialogues = conf['model']["new_dialogues"]
+    total_dialogues = conf['model']["total_dialogues"]
+
+    log_train_configs()
+
+    while num_dialogues < total_dialogues:
+
+        goals = create_goals(goal_generator, new_dialogues, single_domains=single_domains,
+                             allowed_domains=allowed_domains)
+        if train_processes > 1:
+            time_now, metrics = submit_jobs(new_dialogues, queues, episode_queues, train_processes, memory, goals,
+                                            online_metric_queue)
+        else:
+            create_episodes(env, policy_sys, new_dialogues, memory, goals)
+        num_dialogues += new_dialogues
+
+        for r in range(conf['model']['update_rounds']):
+            if num_dialogues > 50:
+                policy_sys.update(memory)
+                torch.cuda.empty_cache()
+
+        if num_dialogues % conf['model']['eval_frequency'] == 0:
+            time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
+            logging.info(f"Evaluating after Dialogues: {num_dialogues} - {time_now}" + '-' * 60)
+
+            eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path,
+                                    single_domain_goals=single_domains, allowed_domains=allowed_domains)
+
+            best_complete_rate, best_success_rate, best_return = \
+                save_best(policy_sys, best_complete_rate, best_success_rate, best_return,
+                          eval_dict["complete_rate"], eval_dict["success_rate_strict"],
+                          eval_dict["avg_return"], save_path)
+            policy_sys.save(save_path, "last")
+            for key in eval_dict:
+                tb_writer.add_scalar(key, eval_dict[key], num_dialogues)
+
+    logging.info("End of Training: " +
+                 time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
+
+    if train_processes > 1:
+        terminate_processes(processes, queues)
+    # Append the elapsed time of this run; the file is closed even if the write fails.
+    with open(os.path.join(dir_path, "time.txt"), "a") as f:
+        f.write(str(datetime.now() - begin_time))
+
+    move_finished_training(dir_path, os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "finished_experiments"))
diff --git a/convlab/policy/vtrace_DPT/transformer_model/EncoderCritic.py b/convlab/policy/vtrace_DPT/transformer_model/EncoderCritic.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d64deffe2a8b4c74072a3f6f625bfb47f359685
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/transformer_model/EncoderCritic.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import logging
+
+from torch.nn.utils.rnn import pad_sequence
+from .noisy_linear import NoisyLinear
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class EncoderCritic(nn.Module):
+    """Critic that encodes knowledge graphs and reads one value off a CLS node.
+
+    A learnable CLS vector is prepended to the embedded nodes of every graph and
+    the encoder output at that position is passed through a linear head.
+    """
+
+    def __init__(self, node_embedder, encoder, cls_dim=128, independent=True, enc_nhead=2, noisy_linear=False,
+                 **kwargs):
+        """Keep the given sub-modules and create the CLS parameter and the output head.
+
+        independent only appears in the log message; extra keyword arguments are ignored.
+        """
+        super().__init__()
+
+        self.node_embedder = node_embedder
+        self.encoder = encoder
+        # Learnable CLS vector, initialised from a standard normal distribution.
+        self.cls = torch.nn.Parameter(torch.randn(cls_dim))
+        if noisy_linear:
+            logging.info("EncoderCritic: We use noisy linear layers.")
+        head_type = NoisyLinear if noisy_linear else torch.nn.Linear
+        self.linear = head_type(cls_dim, 1).to(DEVICE)
+        self.num_heads = enc_nhead
+
+        logging.info(f"Initialised critic. Independent: {independent}")
+
+    def forward(self, descriptions_list, value_list):
+        """Apply the linear head to the encoder output at index 0 (the CLS node) of every graph."""
+        cls_states = self.encode_kg(descriptions_list, value_list)[:, 0, :]
+        return self.linear(cls_states)
+
+    def encode_kg(self, descriptions_list, value_list):
+        """Embed, pad and encode the graphs; return the encoder output with axes 0 and 1 swapped."""
+        padding_mask = self.compute_mask(descriptions_list)
+        node_embeddings = self.embedd_nodes(descriptions_list, value_list)
+        padded = pad_sequence(node_embeddings, batch_first=False).to(DEVICE)
+        encoded, _ = self.encoder(padded, src_key_padding_mask=padding_mask)
+        # size [num_graphs, max_num_nodes, enc_input_dim]
+        return encoded.permute(1, 0, 2)
+
+    def embedd_nodes(self, descriptions_list, value_list):
+        """Embed the nodes of all graphs in one call; return one CLS-prefixed tensor per graph."""
+        # The whole batch is flattened so the node embedder runs only once.
+        all_descriptions = torch.stack(
+            [descr for graph in descriptions_list for descr in graph]).to(DEVICE)
+        all_values = torch.stack([val for graph in value_list for val in graph])
+        flat_nodes = self.node_embedder(all_descriptions, all_values).to(DEVICE)
+
+        # Slice the flat result back into graphs, each led by the CLS vector.
+        cls_row = self.cls.unsqueeze(0)
+        per_graph, start = [], 0
+        for graph in descriptions_list:
+            stop = start + len(graph)
+            per_graph.append(torch.cat([cls_row, flat_nodes[start:stop, :]], dim=0))
+            start = stop
+        return per_graph
+
+    def compute_mask(self, kg_list, all=True):
+        """Build a boolean mask of shape (len(kg_list), longest graph + 1) on DEVICE.
+
+        With all=True, row i is False on its first len(kg_list[i]) + 1 entries and
+        True elsewhere; with all=False only entry (i, i) of row i is False.
+        """
+        # we add 1 for the cls_node in every graph
+        sizes = [len(kg) + 1 for kg in kg_list]
+        mask = torch.ones((len(kg_list), max(sizes)))
+
+        for row, size in enumerate(sizes):
+            if all:
+                mask[row, :size] = 0
+            else:
+                mask[row, row] = 0
+
+        return mask.bool().to(DEVICE)
+
+    def compute_mask_extended(self, kg_list):
+        """Build a boolean (num_heads * num_graphs, max_size, max_size) mask on DEVICE.
+
+        Node rows are False only at positions of nodes sharing their 'domain';
+        the CLS row is False on every real position and padding rows are all False.
+        """
+        max_size = max(len(kg) + 1 for kg in kg_list)
+        mask = torch.ones((len(kg_list), max_size, max_size))
+
+        for g_idx, kg in enumerate(kg_list):
+            n_real = len(kg) + 1  # nodes plus the CLS slot at position 0
+            # One row template per domain: zero at every node of that domain.
+            domain_rows = {}
+            for n_idx, node in enumerate(kg):
+                row = domain_rows.setdefault(node['domain'], torch.ones(max_size))
+                row[n_idx + 1] = 0
+            for n_idx, node in enumerate(kg):
+                mask[g_idx, n_idx + 1] = domain_rows[node['domain']]
+
+            mask[g_idx, 0, :n_real] = 0
+            mask[g_idx, n_real:, :] = 0
+
+        return mask.repeat(self.num_heads, 1, 1).bool().to(DEVICE)
diff --git a/convlab/policy/vtrace_DPT/transformer_model/EncoderDecoder.py b/convlab/policy/vtrace_DPT/transformer_model/EncoderDecoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc2065f57dac4f17108f4db219368b4bc592dbee
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/transformer_model/EncoderDecoder.py
@@ -0,0 +1,487 @@
+from torch.nn.utils.rnn import pad_sequence
+from .node_embedder import NodeEmbedderRoberta
+from .transformer import TransformerModelEncoder, TransformerModelDecoder
+from .action_embedder import ActionEmbedder
+from torch.distributions.categorical import Categorical
+from .noisy_linear import NoisyLinear
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+import sys
+import logging
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class EncoderDecoder(nn.Module):
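+    '''
+    Encoder-decoder policy network over knowledge graphs.
+
+    Nodes are embedded with NodeEmbedderRoberta and encoded by a transformer encoder. A transformer
+    decoder then produces a sequence of action tokens in the repeating order domain, intent,
+    slot-value; ActionEmbedder maps these tokens to and from real actions.
+    '''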
+    def __init__(self, enc_input_dim, enc_nhead, enc_d_hid, enc_nlayers, enc_dropout,
+                 dec_input_dim, dec_nhead, dec_d_hid, dec_nlayers, dec_dropout,
+                 action_embedding_dim, action_dict, domain_embedding_dim, value_embedding_dim,
+                 node_embedding_dim, roberta_path="", node_attention=True, max_length=25, semantic_descriptions=True,
+                 freeze_roberta=True, use_pooled=False, verbose=False, mean=False, ignore_features=None,
+                 only_active_values=False, roberta_actions=False, independent_descriptions=False, need_weights=True,
+                 random_matrix=False, distance_metric=False, noisy_linear=False, dataset_name='multiwoz21', **kwargs):
+        '''
+        Create the node embedder, transformer encoder/decoder, action embedder and output heads.
+
+        :param enc_input_dim, enc_nhead, enc_d_hid, enc_nlayers, enc_dropout: forwarded to TransformerModelEncoder
+        :param dec_input_dim: input size of the action projector and the current-domain predictor
+        :param dec_nhead, dec_d_hid, dec_nlayers, dec_dropout: forwarded to TransformerModelDecoder
+        :param action_embedding_dim: decoder model size and size of the step-type embeddings
+        :param action_dict, domain_embedding_dim, value_embedding_dim: forwarded to ActionEmbedder
+        :param node_embedding_dim, roberta_path, semantic_descriptions, freeze_roberta, use_pooled, mean,
+               dataset_name: forwarded to NodeEmbedderRoberta
+        :param max_length: maximum number of decoding steps in select_action
+        :param ignore_features: node types that select_action removes from the knowledge graph
+        :param only_active_values: if True, select_action drops nodes whose value is 0.0
+        :param roberta_actions: if True, the node embedder is handed to ActionEmbedder
+        :param need_weights: forwarded to the encoder and decoder
+        :param random_matrix, distance_metric: forwarded to ActionEmbedder
+        :param noisy_linear: if True, NoisyLinear replaces torch.nn.Linear for the two output heads
+        :param node_attention, verbose: stored on the instance
+        :param independent_descriptions, kwargs: accepted but not used in this constructor
+        '''
+        super().__init__()
+        self.node_embedder = NodeEmbedderRoberta(
+            node_embedding_dim, freeze_roberta=freeze_roberta, use_pooled=use_pooled,
+            roberta_path=roberta_path, semantic_descriptions=semantic_descriptions, mean=mean,
+            dataset_name=dataset_name).to(DEVICE)
+        #TODO: Encoder input dim should be same as projection dim or use another linear layer?
+        self.encoder = TransformerModelEncoder(
+            enc_input_dim, enc_nhead, enc_d_hid, enc_nlayers, enc_dropout, need_weights).to(DEVICE)
+        self.decoder = TransformerModelDecoder(
+            action_embedding_dim, dec_nhead, dec_d_hid, dec_nlayers, dec_dropout, need_weights).to(DEVICE)
+
+        # the action embedder only receives the node embedder when actions are embedded with roberta
+        embedder_options = dict(random_matrix=random_matrix, distance_metric=distance_metric)
+        if roberta_actions:
+            embedder_options['node_embedder'] = self.node_embedder
+        self.action_embedder = ActionEmbedder(action_dict, domain_embedding_dim, value_embedding_dim,
+                                              action_embedding_dim, **embedder_options).to(DEVICE)
+
+        #TODO: Ignore features for better robustness and simulating absence of certain information
+        self.ignore_features = ignore_features
+        self.node_attention = node_attention
+        self.freeze_roberta = freeze_roberta
+        self.max_length = max_length
+        self.verbose = verbose
+        self.only_active_values = only_active_values
+        self.num_heads = enc_nhead
+        self.action_embedding_dim = action_embedding_dim
+        # embeddings for "domain", "intent", "slot" and "start"
+        self.embedding = nn.Embedding(4, action_embedding_dim).to(DEVICE)
+        self.info_dict = {}
+
+        if noisy_linear:
+            logging.info("EncoderDecoder: We use noisy linear layers.")
+        head_cls = NoisyLinear if noisy_linear else torch.nn.Linear
+        self.action_projector = head_cls(dec_input_dim, action_embedding_dim).to(DEVICE)
+        self.current_domain_predictor = head_cls(dec_input_dim, 1).to(DEVICE)
+        self.softmax = torch.nn.Softmax(dim=-1)
+        self.sigmoid = torch.nn.Sigmoid()
+
+        # counters updated by select_action
+        self.num_book = 0
+        self.num_nobook = 0
+        self.num_selected = 0
+
+    def get_current_domain_mask(self, kg_list, current=True):
+        '''
+        Return the action embedder's current-domain mask for the first knowledge graph.
+
+        :param kg_list: list of knowledge graphs; only the current domains of kg_list[0] are used
+        :param current: forwarded to self.action_embedder.get_current_domain_mask
+        :return: the mask moved to DEVICE
+        '''
+        domains_of_first_kg = self.get_current_domains(kg_list)[0]
+        domain_mask = self.action_embedder.get_current_domain_mask(domains_of_first_kg, current=current)
+        return domain_mask.to(DEVICE)
+
+    def get_descriptions_and_values(self, kg_list):
+        '''
+        Turn the first knowledge graph into description indices and a column of node values.
+
+        :param kg_list: list of knowledge graphs; only kg_list[0] is read
+        :return: (description indices from self.node_embedder.description_2_idx,
+                  tensor of shape (num_nodes, 1) holding node['value']), both on DEVICE
+        '''
+        first_kg = kg_list[0]
+        description_idx_list = self.node_embedder.description_2_idx(first_kg).to(DEVICE)
+        node_values = [node['value'] for node in first_kg]
+        value_list = torch.Tensor(node_values).unsqueeze(1).to(DEVICE)
+        return description_idx_list, value_list
+
+    def select_action(self, kg_list, mask=None, eval=False):
+        '''
+        Autoregressively decode a multi-action for the first knowledge graph in kg_list.
+
+        Decoding runs for at most self.max_length steps in triples: a domain (or 'eos') at
+        t % 3 == 0, an intent at t % 3 == 1 and a slot-value pair at t % 3 == 2. The tensors that
+        are needed to recompute log-probs for RL training are stored in self.info_dict.
+
+        :param kg_list: list of knowledge graphs, each a list of node dicts; only kg_list[0] is embedded
+        :param mask: forwarded to self.action_embedder.get_legal_mask to mark illegal actions
+        :param eval: if True, domain steps are decoded greedily (argmax); intent and slot-value
+                     steps are always sampled
+        :return: multi-action, as produced by self.action_embedder.small_action_list_to_real_actions
+        '''
+
+        # ignore_features defaults to None in the constructor; treat that as "ignore nothing"
+        ignored_node_types = self.ignore_features if self.ignore_features is not None else ()
+        kg_list = [[node for node in kg if node['node_type'] not in ignored_node_types] for kg in kg_list]
+        # this is a bug during supervised training that they use ticket instead of people in book information
+        kg_list = [[node for node in kg if node['description'] != "user goal-train-ticket"] for kg in kg_list]
+
+        # current domains are read before inactive values are filtered out below
+        current_domains = self.get_current_domains(kg_list)
+        legal_mask = self.action_embedder.get_legal_mask(mask)
+
+        if self.only_active_values:
+            kg_list = [[node for node in kg if node['value'] != 0.0] for kg in kg_list]
+
+        description_idx_list, value_list = self.get_descriptions_and_values(kg_list)
+        encoded_nodes, _ = self.encode_kg([description_idx_list], [value_list])
+        encoded_nodes = encoded_nodes.to(DEVICE)
+
+        active_domains = set([node['domain'].lower() for node in kg_list[0]] + ['general', 'booking'])
+
+        # first decoder input: "start" embedding (index 3) plus the "domain" step-type embedding (index 0)
+        decoder_input = self.embedding(torch.Tensor([3]).long().to(DEVICE)) + self.embedding(torch.Tensor([0]).to(DEVICE).long())
+        decoder_input = decoder_input.view(1, 1, -1).to(DEVICE)
+        action_mask = self.action_embedder.get_action_mask(start=True)
+        action_mask = action_mask + legal_mask
+        action_mask = action_mask.bool().float()
+        action_mask_list = [action_mask]
+        action_list = []
+        action_list_num = []
+        attention_weights_list = []
+
+        current_domain_mask = self.action_embedder.get_current_domain_mask(current_domains[0], current=True).to(DEVICE)
+        non_current_domain_mask = self.action_embedder.get_current_domain_mask(current_domains[0], current=False).to(DEVICE)
+
+        for t in range(self.max_length):
+            decoder_output, att_weights_decoder = self.decoder(decoder_input, encoded_nodes.permute(1, 0, 2))
+            attention_weights_list.append(att_weights_decoder)
+            action_logits = self.action_embedder(self.action_projector(decoder_output))
+
+            if t % 3 == 0:
+                # We need to choose a domain
+                current_domain_empty = float((len(current_domains[0]) == 0))
+                # we mask taking a current domain if there is none
+                pick_current_domain_prob = self.sigmoid(
+                    self.current_domain_predictor(decoder_output) - current_domain_empty * sys.maxsize)
+
+                # only pick from current domains
+                action_logits_current_domain = action_logits - (
+                            action_mask + current_domain_mask).bool().float() * sys.maxsize
+                action_distribution_current_domain = self.softmax(action_logits_current_domain)
+                action_distribution_current_domain = (action_distribution_current_domain * pick_current_domain_prob)
+
+                # only pick from non-current domains
+                action_logits_non_current_domain = action_logits - (
+                            action_mask + non_current_domain_mask).bool().float() * sys.maxsize
+                action_distribution_non_current_domain = self.softmax(action_logits_non_current_domain)
+                action_distribution_non_current_domain = (
+                            action_distribution_non_current_domain * (1.0 - pick_current_domain_prob))
+
+                action_distribution = action_distribution_non_current_domain + action_distribution_current_domain
+                action_distribution = (action_distribution / action_distribution.sum(dim=-1, keepdim=True)).squeeze(-1)
+
+            else:
+                action_logits = action_logits - action_mask * sys.maxsize
+                action_distribution = self.softmax(action_logits).squeeze(-1)
+
+            if not eval or t % 3 != 0:
+                dist = Categorical(action_distribution)
+                # NOTE(review): the global RNG state is restored right after sampling, so sampling
+                # here does not advance it -- presumably intentional; confirm.
+                rand_state = torch.random.get_rng_state()
+                action = dist.sample().tolist()[-1]
+                torch.random.set_rng_state(rand_state)
+                semantic_action = self.action_embedder.small_action_dict_reversed[action[-1]]
+                action_list.append(semantic_action)
+                action_list_num.append(action[-1])
+            else:
+                # greedy choice from the distribution of the most recent time-step
+                action = action_distribution[-1, -1, :]
+                action = torch.argmax(action).item()
+                semantic_action = self.action_embedder.small_action_dict_reversed[action]
+                action_list.append(semantic_action)
+                action_list_num.append(action)
+
+            #prepare for next step
+            next_input = self.action_embedder.action_projector(self.action_embedder.action_embeddings[action]).view(1, 1, -1) + \
+                         self.embedding(torch.Tensor([(t + 1) % 3]).to(DEVICE).long())
+            decoder_input = torch.cat([decoder_input, next_input], dim=0)
+
+            if t % 3 == 0:
+                # We chose a domain
+                if semantic_action == 'eos':
+                    break
+                chosen_domain = semantic_action
+                # focus only on the chosen domain information
+
+                action_mask = self.action_embedder.get_action_mask(domain=semantic_action, start=False)
+                action_mask = action_mask + self.action_embedder.get_legal_mask(mask, domain=semantic_action)
+                action_mask = action_mask.bool().float()
+            elif t % 3 == 1:
+                # We chose an intent
+                if semantic_action == "book":
+                    self.num_book += 1
+                if semantic_action == "nobook":
+                    self.num_nobook += 1
+
+                action_mask = self.action_embedder.get_action_mask(domain=chosen_domain,
+                                                                   intent=semantic_action, start=False)
+                action_mask = action_mask + self.action_embedder.get_legal_mask(mask, domain=chosen_domain,
+                                                                                intent=semantic_action)
+                action_mask = action_mask.bool().float()
+            else:
+                # We chose a slot-value pair
+                action_mask = self.action_embedder.get_action_mask(start=False)
+                action_mask = action_mask + self.action_embedder.get_legal_mask(mask)
+                action_mask = action_mask.bool().float()
+
+            action_mask_list.append(action_mask)
+
+        self.num_selected += 1
+
+        # without a final 'eos' the mask appended after the last step belongs to no decoded action
+        if action_list[-1] != 'eos':
+            action_mask_list = action_mask_list[:-1]
+
+        self.info_dict["kg"] = kg_list[0]
+        self.info_dict["small_act"] = torch.Tensor(action_list_num)
+        self.info_dict["action_mask"] = torch.stack(action_mask_list)
+        self.info_dict["description_idx_list"] = description_idx_list
+        self.info_dict["value_list"] = value_list
+        self.info_dict["semantic_action"] = action_list
+        self.info_dict["current_domain_mask"] = current_domain_mask
+        self.info_dict["non_current_domain_mask"] = non_current_domain_mask
+        self.info_dict["active_domains"] = active_domains
+        self.info_dict["attention_weights"] = attention_weights_list
+
+        if self.verbose:
+            print("NEW SELECTION **************************")
+            print(f"KG: {kg_list}")
+            print(f"Active Domains: {active_domains}")
+            print(f"Semantic Act: {action_list}")
+            # the requested entry is missing when decoding stopped after a single step
+            try:
+                attention_to_show = attention_weights_list[1][1][1]
+            except IndexError:
+                attention_to_show = None
+            print("Attention:", attention_to_show)
+
+        return self.action_embedder.small_action_list_to_real_actions(action_list)
+
+    def get_log_prob(self, actions, action_mask_list, max_length, action_targets,
+                 current_domain_mask, non_current_domain_mask, descriptions_list, value_list, no_slots=False):
+        '''
+        Compute the summed log-probability of every action sequence and an entropy estimate.
+
+        All arguments except no_slots are forwarded unchanged to self.get_prob.
+
+        :param no_slots: if True, the log-probs of the time-steps with t % 3 == 2 are zeroed
+                         before summing
+        :return: (log-probs summed over the last dimension, scalar entropy)
+        '''
+        action_probs, entropy_probs = self.get_prob(
+            actions, action_mask_list, max_length, action_targets,
+            current_domain_mask, non_current_domain_mask, descriptions_list, value_list)
+        log_probs = torch.log(action_probs)
+
+        # probabilities outside [0.00001, 1.0] are replaced by 1 so that p * log(p) contributes 0
+        unit_probs = torch.ones(entropy_probs.size()).to(DEVICE)
+        outside_range = (entropy_probs < 0.00001) | (entropy_probs > 1.0)
+        entropy_probs = torch.where(outside_range, unit_probs, entropy_probs)
+        entropy = -(entropy_probs * torch.log(entropy_probs)).sum(-1).sum(-1).mean()
+
+        # sometimes a domain will be masked because it is inactive due to labelling error. Will ignore these cases.
+        log_probs[log_probs == -float("Inf")] = 0
+
+        if no_slots:
+            time_steps = torch.arange(0, max_length)
+            # 0 at the slot-value steps (t % 3 == 2), 1 everywhere else
+            keep_step = (time_steps % 3 != 2).float().view(1, -1).to(DEVICE)
+            log_probs *= keep_step
+
+        return log_probs.sum(-1), entropy
+
+    def get_prob(self, actions, action_mask_list, max_length, action_targets,
+                 current_domain_mask, non_current_domain_mask, descriptions_list, value_list):
+        '''
+        Compute the probability of every target action token for a batch of action sequences.
+
+        Domain steps (t % 3 == 0) mix a distribution over current domains with one over
+        non-current domains, weighted by the current-domain predictor; all other steps use a
+        plain masked softmax.
+
+        :param actions: batch of actions; only its length is used when building the decoder input
+        :param action_mask_list: mask subtracted (scaled by sys.maxsize) from the action logits
+                                 -- presumably (batch, max_length, num_actions); verify against caller
+        :param max_length: padded sequence length
+        :param action_targets: per-sample tensors of target action indices
+        :param current_domain_mask, non_current_domain_mask: per-sample domain masks; a time
+                                 dimension is inserted at position 1
+        :param descriptions_list, value_list: knowledge-graph inputs for self.encode_kg
+        :return: (action_probs with padded steps set to 1,
+                  entropy_probs: the general distribution with padded steps set to 1)
+        '''
+        if not self.freeze_roberta:
+            self.node_embedder.form_embedded_descriptions()
+
+        current_domain_mask = current_domain_mask.unsqueeze(1).to(DEVICE)
+        non_current_domain_mask = non_current_domain_mask.unsqueeze(1).to(DEVICE)
+
+        encoded_nodes, _ = self.encode_kg(descriptions_list, value_list)
+        encoder_mask = self.compute_mask(descriptions_list)
+        padded_decoder_input, padded_action_targets = self.get_decoder_tensors(actions, max_length, action_targets)
+        # produce decoder mask to not attend to future time-steps
+        decoder_mask = torch.triu(torch.ones(max_length, max_length) * float('-inf'), diagonal=1)
+
+        decoder_output, _ = self.decoder(padded_decoder_input.permute(1, 0, 2).to(DEVICE),
+                                         encoded_nodes.permute(1, 0, 2).to(DEVICE),
+                                         decoder_mask.to(DEVICE), encoder_mask.to(DEVICE))
+
+        pick_current_domain_prob = self.sigmoid(self.current_domain_predictor(decoder_output.permute(1, 0, 2)).clone())
+
+        action_logits = self.action_embedder(self.action_projector(decoder_output.permute(1, 0, 2)))
+
+        # do the general mask for intent and slots, domain must be treated separately
+        action_logits_general = action_logits - action_mask_list * sys.maxsize
+        action_distribution_general = self.softmax(action_logits_general)
+
+        # only pick from current domains
+        action_logits_current_domain = action_logits - (action_mask_list + current_domain_mask).bool().float() * sys.maxsize
+        action_distribution_current_domain = self.softmax(action_logits_current_domain)
+        action_distribution_current_domain = (action_distribution_current_domain * pick_current_domain_prob)
+
+        # only pick from non-current domains
+        action_logits_non_current_domain = action_logits - (action_mask_list + non_current_domain_mask).bool().float() * sys.maxsize
+        action_distribution_non_current_domain = self.softmax(action_logits_non_current_domain)
+        action_distribution_non_current_domain = (action_distribution_non_current_domain * (1.0 - pick_current_domain_prob))
+
+        action_distribution_domain = action_distribution_non_current_domain + action_distribution_current_domain
+
+        # indicator tensors selecting the domain steps (t % 3 == 0) and all remaining steps
+        time_steps = torch.arange(0, max_length)
+        non_domain_steps = torch.where(time_steps % 3 == 0, torch.zeros(max_length), torch.ones(max_length))\
+            .view(1, -1, 1).to(DEVICE)
+        domain_steps = torch.where(time_steps % 3 == 0, torch.ones(max_length), torch.zeros(max_length))\
+            .view(1, -1, 1).to(DEVICE)
+
+        action_distribution_domain = (action_distribution_domain * domain_steps)
+        action_distribution_general = (action_distribution_general * non_domain_steps)
+        action_distribution = action_distribution_domain + action_distribution_general
+        # make sure it sums up to 1 in every time-step
+        action_distribution = (action_distribution / action_distribution.sum(dim=-1, keepdim=True))
+
+        # squeeze only the gathered dimension: a bare squeeze() would also drop the time axis when
+        # max_length == 1, and the product with the (batch, 1) helper would broadcast to (batch, batch)
+        action_probs = action_distribution.gather(-1, padded_action_targets.long().unsqueeze(-1)).squeeze(-1)
+        # padded time-steps can have very low probability, taking log can be unstable. This prevents it.
+        action_prob_helper = torch.Tensor(
+            [[1] * len(target) + [0] * (max_length - len(target)) for target in action_targets]).to(DEVICE)
+        action_prob_helper_rev = 1.0 - action_prob_helper
+        # set all padded time-steps to 0 probability
+        action_probs = action_probs * action_prob_helper
+        # set padded time-steps to probability 1, so that log will be 0
+        action_probs = action_probs + action_prob_helper_rev
+
+        entropy_probs = action_distribution_general * action_prob_helper.unsqueeze(-1) + action_prob_helper_rev.unsqueeze(-1)
+        return action_probs, entropy_probs
+
+    def get_current_domains(self, kg_list):
+        '''
+        Collect, for every knowledge graph, the lower-cased domains of its 'user act' nodes.
+
+        :param kg_list: list of knowledge graphs, each a list of node dicts with the keys
+                        'node_type' and 'domain'
+        :return: one list per knowledge graph with the distinct domains in order of first appearance
+        '''
+        current_domains = []
+        for kg in kg_list:
+            curr_list = []
+            for node in kg:
+                if node['node_type'] != 'user act':
+                    continue
+                domain = node['domain'].lower()
+                # compare against this graph's own list so that every domain is reported once
+                if domain not in curr_list:
+                    curr_list.append(domain)
+            current_domains.append(curr_list)
+        return current_domains
+
+    def get_decoder_tensors(self, actions, max_length, action_targets):
+        '''
+        Build the padded decoder input embeddings and the padded target indices.
+
+        The decoder input is the target sequence without its last token, preceded by the "start"
+        embedding; a step-type embedding (t % 3) is added to every position.
+
+        :param actions: batch of actions; only len(actions) is used (to repeat the start token)
+        :param max_length: length every sequence is padded to
+        :param action_targets: per-sample 1-D tensors of action indices
+        :return: (decoder input embeddings, zero-padded action targets)
+        '''
+        input_rows = []
+        target_rows = []
+        for act in action_targets:
+            act = act.to(DEVICE)
+            padding = torch.zeros(max_length - len(act)).to(DEVICE)
+            # the input drops the last token; the start token added below restores the length
+            input_rows.append(torch.cat([act[:-1], padding], dim=-1))
+            target_rows.append(torch.cat([act, padding], dim=-1))
+
+        padded_decoder_input = torch.stack(input_rows, dim=0).to(DEVICE).long()
+        padded_action_targets = torch.stack(target_rows, dim=0).to(DEVICE)
+
+        # Map the action indices to the action embeddings that are fed to the decoder
+        embedded_actions = self.action_embedder.action_projector(
+            self.action_embedder.action_embeddings[padded_decoder_input])
+
+        # Prepend the "start" token (embedding index 3) to every sequence
+        start_token = self.embedding(torch.Tensor([3]).to(DEVICE).long()).to(DEVICE)
+        decoder_input = torch.cat([start_token.repeat(len(actions), 1, 1), embedded_actions], dim=1)
+
+        # Add "domain", "intent" or "slot" token to input so model knows what to predict
+        step_types = torch.remainder(torch.Tensor(range(max_length)).to(DEVICE), 3).long()
+        decoder_input += self.embedding(step_types)
+
+        return decoder_input, padded_action_targets
+
+    def encode_kg(self, descriptions_list, value_list):
+        '''
+        Embed, pad and encode a batch of knowledge graphs.
+
+        :param descriptions_list: per-graph description indices
+        :param value_list: per-graph node values
+        :return: (encoder output with its first two dimensions swapped, encoder attention weights)
+        '''
+        padding_mask = self.compute_mask(descriptions_list)
+        node_embeddings = self.embedd_nodes(descriptions_list, value_list)
+        # pad_sequence with batch_first=False puts the node dimension first
+        padded_graphs = pad_sequence(node_embeddings, batch_first=False).to(DEVICE)
+        encoder_output, att_weights = self.encoder(padded_graphs, src_key_padding_mask=padding_mask)
+
+        return encoder_output.permute(1, 0, 2), att_weights
+
+    def embedd_nodes(self, descriptions_list, value_list):
+        '''
+        Embed the nodes of all knowledge graphs in one call and split the result per graph.
+
+        :param descriptions_list: per-graph description indices
+        :param value_list: per-graph node values
+        :return: list with one tensor of embedded nodes per knowledge graph
+        '''
+        graph_sizes = [len(graph) for graph in descriptions_list]
+
+        # we view the kg_list as one huge knowledge graph to embed all nodes simultaneously
+        all_descriptions = torch.stack(
+            [description for graph in descriptions_list for description in graph]).to(DEVICE)
+        all_values = torch.stack(
+            [value for graph_values in value_list for value in graph_values])
+        all_embedded = self.node_embedder(all_descriptions, all_values).to(DEVICE)
+
+        # running offsets mark where each individual knowledge graph starts and ends
+        boundaries = [0]
+        for size in graph_sizes:
+            boundaries.append(boundaries[-1] + size)
+
+        return [all_embedded[begin:end, :] for begin, end in zip(boundaries, boundaries[1:])]
+
+    def compute_mask(self, descriptions_list):
+        '''
+        Build the padding mask for a batch of knowledge graphs; True marks padded positions.
+
+        :param descriptions_list: per-graph description indices; len() of an entry is its node count
+        :return: bool tensor of shape (len(descriptions_list), largest node count) on DEVICE
+        '''
+        graph_sizes = [len(graph) for graph in descriptions_list]
+        longest = max(graph_sizes)
+
+        # a position is padding exactly when its index is not below the size of its graph
+        positions = torch.arange(longest).unsqueeze(0)
+        sizes = torch.tensor(graph_sizes).unsqueeze(1)
+
+        return (positions >= sizes).to(DEVICE)
+
+    def compute_mask_extended(self, kg_list):
+        '''
+        Build a per-graph square mask in which a node is unmasked only towards nodes of its domain.
+
+        True marks masked entries. Rows that belong to padding are fully unmasked.
+
+        :param kg_list: list of knowledge graphs, each a list of node dicts with a 'domain' key
+        :return: bool tensor of shape (self.num_heads * len(kg_list), max_size, max_size) on DEVICE
+        '''
+        graph_sizes = [len(graph) for graph in kg_list]
+        width = max(graph_sizes)
+        mask = torch.ones((len(kg_list), width, width))
+
+        for graph_idx, (graph, size) in enumerate(zip(kg_list, graph_sizes)):
+            # one row template per domain: zero at the positions of all nodes of that domain
+            domain_rows = {}
+            for position, node in enumerate(graph):
+                domain = node['domain']
+                if domain not in domain_rows:
+                    domain_rows[domain] = torch.ones(width)
+                domain_rows[domain][position] = 0
+
+            # the templates are complete now, so they can be copied into the node rows
+            for position, node in enumerate(graph):
+                mask[graph_idx, position] = domain_rows[node['domain']]
+
+            mask[graph_idx, size:, :] = 0
+
+        mask = mask.repeat(self.num_heads, 1, 1)
+
+        return mask.bool().to(DEVICE)
+
+    def get_action_masks(self, actions):
+        '''
+        Build the per-time-step action masks for a batch of real actions.
+
+        Each action is converted to its small-action sequence. A start mask is followed by one
+        mask per decoded token: after a domain the mask is restricted to that domain, after an
+        intent to domain and intent, and after a slot-value pair it is the unrestricted non-start
+        mask. A domain token 'eos' ends the sequence. Every mask sequence is zero-padded to the
+        longest target sequence.
+
+        :param actions: batch of real actions
+        :return: (stacked action masks on DEVICE, length of the longest target sequence)
+        '''
+        embedder = self.action_embedder
+
+        action_targets = [embedder.real_action_to_small_action_list(act) for act in actions]
+        max_length = max([len(target) for target in action_targets])
+
+        semantic_acts = [embedder.real_action_to_small_action_list(act, semantic=True) for act in actions]
+
+        padded_masks = []
+        for _, act_sequence in tqdm(enumerate(semantic_acts)):
+            step_masks = [embedder.get_action_mask(start=True)]
+
+            for t, act in enumerate(act_sequence):
+                step_type = t % 3
+
+                if step_type == 0:
+                    # We chose a domain
+                    if act == 'eos':
+                        break
+                    chosen_domain = act
+                    # TODO: Decoder encoder mask is unfinished, the self-attention mask needs to be
+                    # modified before the attention can focus on the chosen domain only
+                    step_masks.append(embedder.get_action_mask(domain=act, start=False))
+                elif step_type == 1:
+                    # We chose an intent; this step is never the first one, so start is False
+                    step_masks.append(embedder.get_action_mask(domain=chosen_domain,
+                                                               intent=act, start=False))
+                else:
+                    # We chose a slot-value pair
+                    step_masks.append(embedder.get_action_mask(start=False))
+
+            # pad action mask to get list of max_length
+            stacked = torch.stack(step_masks).to(DEVICE)
+            padding = torch.zeros(max_length - len(step_masks), len(embedder.small_action_dict)).to(DEVICE)
+            padded_masks.append(torch.cat([stacked, padding], dim=0))
+
+        action_mask_list = torch.stack(padded_masks).to(DEVICE)
+        return action_mask_list, max_length
+
+
diff --git a/convlab/policy/vtrace_DPT/transformer_model/TransformerLayerCustom.py b/convlab/policy/vtrace_DPT/transformer_model/TransformerLayerCustom.py
new file mode 100644
index 0000000000000000000000000000000000000000..0414114acc58860d1391572f3d587247b4675c64
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/transformer_model/TransformerLayerCustom.py
@@ -0,0 +1,160 @@
+from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoderLayer, TransformerDecoder
+from torch import Tensor
+from typing import Optional
+from torch.nn import ModuleList
+from torch.nn import functional as F
+
+import copy
+
+
+class TransformerDecoderLayerCustom(TransformerDecoderLayer):
+    r"""TransformerDecoderLayer whose forward additionally returns its attention weights.
+
+    Args:
+        need_weights: forwarded to both attention modules; when False the returned weights are
+            whatever the attention modules return in that case.
+        All other arguments are handed to TransformerDecoderLayer.
+    """
+
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
+                 layer_norm_eps=1e-5, batch_first=False, device=None, dtype=None, need_weights=False):
+        # Forward by keyword: newer torch releases place additional parameters (e.g. norm_first)
+        # before device/dtype, so positional forwarding would bind them to the wrong names.
+        super().__init__(d_model, nhead, dim_feedforward=dim_feedforward, dropout=dropout,
+                         activation=activation, layer_norm_eps=layer_norm_eps,
+                         batch_first=batch_first, device=device, dtype=dtype)
+        self.need_weights = need_weights
+
+    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
+                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None):
+        r"""Pass the inputs (and mask) through the decoder layer.
+
+        Args:
+            tgt: the sequence to the decoder layer (required).
+            memory: the sequence from the last layer of the encoder (required).
+            tgt_mask: the mask for the tgt sequence (optional).
+            memory_mask: the mask for the memory sequence (optional).
+            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+            memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+        Returns:
+            The layer output and a tuple (self-attention weights, encoder-decoder attention weights).
+
+        Shape:
+            see the docs in Transformer class.
+        """
+        # self-attention block, residual connection followed by layer norm
+        tgt2, self_att_weights = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
+                              key_padding_mask=tgt_key_padding_mask, need_weights=self.need_weights)
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+        # attention over the encoder memory
+        tgt2, att_weights = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
+                                   key_padding_mask=memory_key_padding_mask, need_weights=self.need_weights)
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+        # position-wise feed-forward block
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+        return tgt, (self_att_weights, att_weights)
+
+
+class TransformerEncoderLayerCustom(TransformerEncoderLayer):
+    r"""TransformerEncoderLayer whose forward additionally returns its self-attention weights.
+
+    Args:
+        need_weights: forwarded to the self-attention module; when False the returned weights are
+            whatever the attention module returns in that case.
+        All other arguments are handed to TransformerEncoderLayer.
+    """
+
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
+                 layer_norm_eps=1e-5, batch_first=False, device=None, dtype=None, need_weights=False):
+        # Forward by keyword: newer torch releases place additional parameters (e.g. norm_first)
+        # before device/dtype, so positional forwarding would bind them to the wrong names.
+        super().__init__(d_model, nhead, dim_feedforward=dim_feedforward, dropout=dropout,
+                         activation=activation, layer_norm_eps=layer_norm_eps,
+                         batch_first=batch_first, device=device, dtype=dtype)
+
+        self.need_weights = need_weights
+
+    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
+                src_key_padding_mask: Optional[Tensor] = None):
+        r"""Pass the input through the encoder layer.
+
+        Args:
+            src: the sequence to the encoder layer (required).
+            src_mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Returns:
+            The layer output and the self-attention weights.
+
+        Shape:
+            see the docs in Transformer class.
+        """
+        # self-attention block, residual connection followed by layer norm
+        src2, attention_weights = self.self_attn(src, src, src, attn_mask=src_mask,
+                              key_padding_mask=src_key_padding_mask, need_weights=self.need_weights)
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+        # position-wise feed-forward block
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src, attention_weights
+
+
+class TransformerEncoderCustom(TransformerEncoder):
+    r"""Stack of encoder layers that also collects the attention weights of every layer.
+
+    Args:
+        encoder_layer: layer that is deep-copied num_layers times; its forward must return
+            a pair (output, attention weights).
+        num_layers: number of copies in the stack.
+        norm: optional module applied to the final output.
+        need_weights: stored on the instance.
+    """
+
+    def __init__(self, encoder_layer, num_layers, norm=None, need_weights=False):
+        # start the lookup after TransformerEncoder, i.e. its own __init__ is not executed
+        super(TransformerEncoder, self).__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+        self.need_weights = need_weights
+
+    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None):
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+            src: the sequence to the encoder (required).
+            mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Returns:
+            The final output and a list with the attention weights of each layer.
+
+        Shape:
+            see the docs in Transformer class.
+        """
+        hidden = src
+        weights_per_layer = []
+
+        for layer in self.layers:
+            hidden, layer_weights = layer(hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
+            weights_per_layer.append(layer_weights)
+
+        if self.norm is None:
+            return hidden, weights_per_layer
+
+        return self.norm(hidden), weights_per_layer
+
+
+class TransformerDecoderCustom(TransformerDecoder):
+    r"""Stack of decoder layers that also collects the attention weights of every layer.
+
+    Args:
+        decoder_layer: layer that is deep-copied num_layers times; its forward must return
+            a pair (output, attention weights).
+        num_layers: number of copies in the stack.
+        norm: optional module applied to the final output.
+        need_weights: stored on the instance.
+    """
+
+    def __init__(self, decoder_layer, num_layers, norm=None, need_weights=False):
+        # start the lookup after TransformerDecoder, i.e. its own __init__ is not executed
+        super(TransformerDecoder, self).__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+        self.need_weights = need_weights
+
+    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
+                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
+                memory_key_padding_mask: Optional[Tensor] = None):
+        r"""Pass the inputs (and mask) through the decoder layer in turn.
+
+        Args:
+            tgt: the sequence to the decoder (required).
+            memory: the sequence from the last layer of the encoder (required).
+            tgt_mask: the mask for the tgt sequence (optional).
+            memory_mask: the mask for the memory sequence (optional).
+            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+            memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+        Returns:
+            The final output and a list with the attention weights returned by each layer.
+
+        Shape:
+            see the docs in Transformer class.
+        """
+        hidden = tgt
+        weights_per_layer = []
+
+        for layer in self.layers:
+            hidden, layer_weights = layer(
+                hidden, memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask)
+            weights_per_layer.append(layer_weights)
+
+        if self.norm is None:
+            return hidden, weights_per_layer
+
+        return self.norm(hidden), weights_per_layer
+
+
+def _get_clones(module, N):
+    """Return a ModuleList holding N independent deep copies of module."""
+    clones = ModuleList()
+    for _ in range(N):
+        clones.append(copy.deepcopy(module))
+    return clones
+
+
+def _get_activation_fn(activation):
+    """Map the activation name "relu" or "gelu" to the matching torch.nn.functional function.
+
+    Raises:
+        RuntimeError: if activation is neither "relu" nor "gelu".
+    """
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+
+    message = "activation should be relu/gelu, not {}".format(activation)
+    raise RuntimeError(message)
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/transformer_model/__init__.py b/convlab/policy/vtrace_DPT/transformer_model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py b/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..283443eb2daf55c56427e6de01f9cf6a2b4049cf
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py
@@ -0,0 +1,292 @@
+import os
+import torch
+import torch.nn as nn
+import logging
+import json
+
+from copy import deepcopy
+from convlab.policy.vtrace_DPT.transformer_model.noisy_linear import NoisyLinear
+
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class ActionEmbedder(nn.Module):
+    '''
+    Obtains the action-dictionary with all actions and creates embeddings for domain, intent and slot-value pairs
+    The embeddings are used for creating the domain, intent and slot-value actions in the EncoderDecoder
+    '''
+    def __init__(self, action_dict, embedding_dim, value_embedding_dim, action_embedding_dim, node_embedder=None,
+                 random_matrix=False, distance_metric=False):
+        super(ActionEmbedder, self).__init__()
+        # Index tables for the lower-cased parts of the "domain-intent-slot-value" action names
+        self.domain_dict, self.intent_dict, self.slot_dict, self.value_dict, self.slot_value_dict \
+            = self.create_dicts(action_dict)
+
+        # create_dicts adds 'eos' to domain_dict, i.e. the EOS token is considered a "domain"
+        self.action_dict = dict((key.lower(), value) for key, value in action_dict.items())
+        self.action_dict_reversed = dict((value, key) for key, value in self.action_dict.items())
+        self.embed_domain = torch.randn(len(self.domain_dict), embedding_dim)
+        self.embed_intent = torch.randn(len(self.intent_dict), embedding_dim)
+        self.embed_slot = torch.randn(len(self.slot_dict), embedding_dim - value_embedding_dim)
+        self.embed_value = torch.randn(len(self.value_dict), value_embedding_dim)
+        self.embed_rest = torch.randn(1, embedding_dim)     #Pad token
+        self.use_random_matrix = random_matrix
+        self.distance_metric = distance_metric
+        self.forbidden_domains = []
+        # Either train the random embeddings above or use fixed embeddings from the node embedder
+        if not node_embedder:
+            logging.info("We train action embeddings from scratch.")
+            self.action_embeddings, self.small_action_dict = self.create_action_embeddings(embedding_dim)
+            self.action_embeddings.requires_grad = True
+            self.action_embeddings = nn.Parameter(self.action_embeddings)
+        else:
+            logging.info("We use Roberta to embed actions.")
+            self.dataset_name = node_embedder.dataset_name
+            self.create_action_embeddings_roberta(node_embedder)
+            self.action_embeddings.requires_grad = False
+            embedding_dim = 768
+        # In the node-embedder case 768 becomes the input size of the projections below
+        # (presumably the RoBERTa hidden size - TODO confirm)
+
+        self.small_action_dict_reversed = dict((value, key) for key, value in self.small_action_dict.items())
+        # NOTE(review): the random matrix is scaled by 1/sqrt(768) regardless of embedding_dim
+        self.linear = torch.nn.Linear(embedding_dim, action_embedding_dim).to(DEVICE)
+        #self.linear = NoisyLinear(embedding_dim, action_embedding_dim).to(DEVICE)
+        self.random_matrix = torch.randn(embedding_dim, action_embedding_dim).to(DEVICE) / \
+                             torch.sqrt(torch.Tensor([768])).to(DEVICE)
+
+    def action_projector(self, actions):
+        """Project actions with the fixed random matrix if enabled, otherwise with the linear layer."""
+        if not self.use_random_matrix:
+            return self.linear(actions)
+        return torch.matmul(actions, self.random_matrix).to(DEVICE)
+
+    def forward(self, state):
+        """Score state against every projected action embedding (dot product or negative L2 distance)."""
+        # state [batch-size, action_dim], self.action_embeddings [num_actions, embedding_dim]
+        projected_actions = self.action_projector(self.action_embeddings)
+
+        if self.distance_metric:
+            # We use distance metric for similarity as in SUMBT
+            return -torch.cdist(state, projected_actions, p=2)
+
+        # We use scalar product for similarity
+        scores = torch.matmul(state, projected_actions.permute(1, 0))
+        return scores
+
+    def get_legal_mask(self, legal_mask, domain="", intent=""):
+        """Return a mask over small_action_dict: 0 at entries that may be selected next, 1 elsewhere.
+
+        :param legal_mask: flags indexed like action_dict_reversed (0 is treated as allowed), or None
+        :param domain: already selected domain, empty if none was chosen yet
+        :param intent: already selected intent, empty if none was chosen yet
+        :return: tensor on DEVICE; all zeros when legal_mask is None
+        """
+        vocab_size = len(self.small_action_dict)
+        if legal_mask is None:
+            return torch.zeros(vocab_size).to(DEVICE)
+
+        def is_legal(fragment):
+            # Some action whose name contains the fragment is flagged 0 in legal_mask
+            return any(fragment in self.action_dict_reversed[pos] and flag == 0
+                       for pos, flag in enumerate(legal_mask))
+
+        if not domain:
+            # check whether we can use that domain, at the moment we want to allow all domains
+            unmasked = list(self.domain_dict)
+        elif not intent:
+            # Domain was selected, check intents that are allowed
+            unmasked = [name for name in self.intent_dict if is_legal(f"{domain}-{name}")]
+        else:
+            # Selected domain and intent, need slot-value
+            unmasked = [pair for pair in self.slot_value_dict if is_legal(f"{domain}-{intent}-{pair}")]
+        action_mask = torch.ones(vocab_size)
+        for name in unmasked:
+            action_mask[self.small_action_dict[name]] = 0
+        return action_mask.to(DEVICE)
+
+    def get_action_mask(self, domain=None, intent="", start=False):
+        """Return a mask over small_action_dict with 0 at the entries selectable for the given domain/intent."""
+        action_mask = torch.ones(len(self.small_action_dict))
+
+        # This is for predicting end of sequence token <eos>
+        if not start and domain is None:
+            action_mask[self.small_action_dict['eos']] = 0
+
+        if domain is None:
+            # TODO: all domains are allowed for now (supervised training check), except forbidden_domains
+            for domain in self.domain_dict:
+                if domain not in self.forbidden_domains:
+                    action_mask[self.small_action_dict[domain]] = 0
+            if start:
+                action_mask[self.small_action_dict['eos']] = 1
+            # Only active domains can be selected
+            #for domain in active_domains:
+            #    action_mask[self.small_action_dict[domain]] = 0
+
+        elif not intent:
+            # Domain was selected, need intent now
+            for intent in self.intent_dict:
+                domain_intent = f"{domain}-{intent}"
+                valid = self.is_valid(domain_intent + "-")
+                if valid:
+                    action_mask[self.small_action_dict[intent]] = 0
+        else:
+            # Selected domain and intent, need slot-value
+            for slot_value in self.slot_value_dict:
+                domain_intent_slot = f"{domain}-{intent}-{slot_value}"
+                valid = self.is_valid(domain_intent_slot)
+                if valid:
+                    action_mask[self.small_action_dict[slot_value]] = 0
+        # At least one entry must have been unmasked
+        assert not torch.equal(action_mask, torch.ones(len(self.small_action_dict)))
+
+        return action_mask.to(DEVICE)
+
+    def get_current_domain_mask(self, current_domains, current=True):
+        """Return a mask with 0 at the current domains, or at all other domains if current is False."""
+        if current:
+            unmasked = current_domains
+        else:
+            unmasked = [name for name in self.domain_dict if name not in current_domains]
+
+        action_mask = torch.ones(len(self.small_action_dict))
+        for name in unmasked:
+            action_mask[self.small_action_dict[name]] = 0
+
+        return action_mask.to(DEVICE)
+
+    def is_valid(self, part_action):
+        """Check whether a partial action name is the prefix of a known action.
+
+        :param part_action: prefix such as "hotel-inform-", compared against the keys of self.action_dict
+        :return: True if at least one key of self.action_dict starts with part_action
+        """
+        return any(act.startswith(part_action) for act in self.action_dict)
+
+    def create_action_embeddings(self, embedding_dim):
+        """Stack the random domain, intent, slot-value and pad vectors into one embedding matrix.
+
+        :return: (matrix moved to DEVICE, dict mapping each small-action name to its row index)
+        """
+        num_rows = len(self.domain_dict) + len(self.intent_dict) + len(self.slot_value_dict) + 1
+        matrix = torch.zeros((num_rows, embedding_dim))
+        vocab = {}
+        def register(name, vector):
+            # The row that gets written is always the current vocabulary size
+            matrix[len(vocab)] = vector
+            vocab[name] = len(vocab)
+
+        for domain, idx in self.domain_dict.items():
+            register(domain, self.embed_domain[idx])
+        for intent, idx in self.intent_dict.items():
+            register(intent, self.embed_intent[idx])
+        for slot_value in self.slot_value_dict:
+            slot, value = slot_value.split("-")
+            slot_vector = torch.cat((self.embed_slot[self.slot_dict[slot]], self.embed_value[self.value_dict[value]]))
+            register(slot_value, slot_vector)
+        register('pad', self.embed_rest[0])      #add the PAD token
+        return matrix.to(DEVICE), vocab
+
+    def create_action_embeddings_roberta(self, node_embedder):
+        """Embed the small-action names with node_embedder and cache the result next to this file.
+
+        Sets self.action_embeddings (moved to DEVICE) and self.small_action_dict.
+        :param node_embedder: object providing embed_sentences(list of str); self.dataset_name,
+            which names the cache files, must already be set
+        """
+        sentences = []
+        small_action_dict = {}
+        for domain in self.domain_dict:
+            sentences.append(domain)
+            small_action_dict[domain] = len(small_action_dict)
+        for intent in self.intent_dict:
+            sentences.append(intent)
+            small_action_dict[intent] = len(small_action_dict)
+        for slot_value in self.slot_value_dict:
+            slot, value = slot_value.split("-")
+            sentences.append(f"{slot} {value}")
+            small_action_dict[slot_value] = len(small_action_dict)
+        sentences.append("pad")     #add the PAD token
+        small_action_dict['pad'] = len(small_action_dict)
+
+        cache_dir = os.path.dirname(os.path.abspath(__file__))
+        action_embeddings_path = os.path.join(cache_dir, f'action_embeddings_{self.dataset_name}.pt')
+        small_action_dict_path = os.path.join(cache_dir, f'small_action_dict_{self.dataset_name}.json')
+
+        if os.path.exists(action_embeddings_path):
+            self.action_embeddings = torch.load(action_embeddings_path).to(DEVICE)
+        else:
+            self.action_embeddings = node_embedder.embed_sentences(sentences).to(DEVICE)
+            torch.save(self.action_embeddings, action_embeddings_path)
+
+        # An existing dictionary file is not read back: the freshly built dictionary is always used.
+        # NOTE(review): cached embeddings can disagree with it if the action set changed - confirm.
+        if not os.path.exists(small_action_dict_path):
+            with open(small_action_dict_path, 'w') as f:
+                json.dump(small_action_dict, f)
+        self.small_action_dict = small_action_dict
+
+    def create_dicts(self, action_dict):
+        """Index every domain, intent, slot, value and slot-value pair found in the action names.
+
+        :param action_dict: iterable of "domain-intent-slot-value" names (exactly four "-" parts)
+        :return: (domain_dict, intent_dict, slot_dict, value_dict, slot_value_dict); each maps a
+            lower-cased name to its order of first appearance, and domain_dict additionally
+            receives 'eos' with index len(domain_dict)
+        """
+        domain_dict, intent_dict, slot_dict = {}, {}, {}
+        value_dict, slot_value_dict = {}, {}
+        for action in action_dict:
+            domain, intent, slot, value = action.lower().split('-')
+            tables_and_keys = ((domain_dict, domain), (intent_dict, intent), (slot_dict, slot),
+                               (value_dict, value), (slot_value_dict, slot + "-" + value))
+            for table, key in tables_and_keys:
+                # setdefault leaves an already known name at its first index
+                table.setdefault(key, len(table))
+
+        # EOS shares the index space of the domains
+        domain_dict['eos'] = len(domain_dict)
+
+        return domain_dict, intent_dict, slot_dict, value_dict, slot_value_dict
+
+    def small_action_list_to_real_actions(self, small_action_list):
+        """Turn a flat [domain, intent, slot-value, ..., 'eos'] sequence into a multi-hot action vector.
+
+        Reading stops at the first 'eos'; an unfinished trailing triple is ignored.
+        """
+        action_vector = torch.zeros(len(self.action_dict))
+        prefix = ""
+        for position, act in enumerate(small_action_list):
+            if act == 'eos':
+                break
+            if position % 3 == 2:
+                # The third element completes the "domain-intent-slot-value" key
+                action_vector[self.action_dict[prefix + act]] = 1
+                prefix = ""
+            else:
+                prefix += f"{act}-"
+        return action_vector
+
+    def real_action_to_small_action_list(self, action, semantic=False, permute=False):
+        '''
+        :param action: 0/1 vector over the full action set, indexed like action_dict_reversed
+        :param semantic: return the names instead of their small_action_dict indices
+        :param permute: move the last (domain, intent, slot-value) triple to the front
+        :return: e.g. [hotel, request, address-?, taxi, inform, phone-1, eos], as names or indices
+        '''
+
+        names = []
+        for position, flag in enumerate(action):
+            if flag == 1:
+                names.extend(self.action_dict_reversed[position].split("-", 2))
+
+        if permute and len(names) > 3:
+            # Rotate so that the last triple comes first
+            names = names[-3:] + names[:-3]
+        names.append("eos")
+
+        if not semantic:
+            return [self.small_action_dict[name] for name in names]
+        return names
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt
new file mode 100644
index 0000000000000000000000000000000000000000..2f7e5e9cdf949b082b9ae752c09d2eafeb322bfb
Binary files /dev/null and b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt differ
diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_sgd.pt b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_sgd.pt
new file mode 100644
index 0000000000000000000000000000000000000000..67e557ce5ef7a0bfd40c60ae5a03937b47b92de2
Binary files /dev/null and b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_sgd.pt differ
diff --git a/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt
new file mode 100644
index 0000000000000000000000000000000000000000..5c6193163b6a1fc788b959eff2d0153b3a4bf7cb
Binary files /dev/null and b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt differ
diff --git a/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_sgd.pt b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_sgd.pt
new file mode 100644
index 0000000000000000000000000000000000000000..619824588654a36cab3bf795e6fe94527b04ba68
Binary files /dev/null and b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_sgd.pt differ
diff --git a/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py b/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab57df122369e9457c78e04fe2e64eab389d5abd
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py
@@ -0,0 +1,123 @@
+import os, json, logging
+import torch
+import torch.nn as nn
+
+from transformers import RobertaTokenizer, RobertaModel
+from convlab.policy.vtrace_DPT.transformer_model.noisy_linear import NoisyLinear
+from convlab.policy.vtrace_DPT.create_descriptions import create_description_dicts
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class NodeEmbedderRoberta(nn.Module):
+    '''
+    Class to build node embeddings
+    Nodes have attributes: Description, value and node type that are used for building embedding
+    '''
+
+    def __init__(self, projection_dim, freeze_roberta=True, use_pooled=False, max_length=25, roberta_path="",
+                 description_dict=None, semantic_descriptions=True, mean=False, dataset_name="multiwoz21"):
+        super(NodeEmbedderRoberta, self).__init__()
+        # NOTE(review): the description_dict argument is never used; init_description_dict loads it from file
+        self.dataset_name = dataset_name
+        self.max_length = max_length
+        self.description_size = 768
+        self.projection_dim = projection_dim
+        self.feature_projection = torch.nn.Linear(2 * self.description_size, projection_dim).to(DEVICE)
+        #self.feature_projection = NoisyLinear(2 * self.description_size, projection_dim).to(DEVICE)
+        self.value_embedding = torch.nn.Linear(1, self.description_size).to(DEVICE)
+        # feature_projection takes a value embedding concatenated with a description embedding (see forward)
+        self.semantic_descriptions = semantic_descriptions
+        self.init_description_dict()
+        # Descriptions are indexed in the iteration order of description_dict
+        self.description2idx = dict((descr, i) for i, descr in enumerate(self.description_dict))
+        self.idx2description = dict((i, descr) for descr, i in self.description2idx.items())
+        self.use_pooled = use_pooled
+        self.mean = mean
+        self.embedded_descriptions = None
+        # Load cached description embeddings if the file exists; RoBERTa is only instantiated on a cache miss
+        if roberta_path:
+            embedded_descriptions_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                                      f'embedded_descriptions_{self.dataset_name}.pt')
+            if os.path.exists(embedded_descriptions_path):
+                self.embedded_descriptions = torch.load(embedded_descriptions_path).to(DEVICE)
+            else:
+                logging.info(f"Loading Roberta from path {roberta_path}")
+                self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+                self.roberta_model = RobertaModel.from_pretrained(roberta_path).to(DEVICE)
+
+        else:
+            embedded_descriptions_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                                      f'embedded_descriptions_base_{self.dataset_name}.pt')
+            if os.path.exists(embedded_descriptions_path):
+                self.embedded_descriptions = torch.load(embedded_descriptions_path).to(DEVICE)
+            else:
+                self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+                self.roberta_model = RobertaModel.from_pretrained("roberta-base").to(DEVICE)
+        # NOTE(review): on a cache hit tokenizer/roberta_model are never set, so embed_sentences cannot run later
+        if self.embedded_descriptions is None:
+            if freeze_roberta:
+                for param in self.roberta_model.parameters():
+                    param.requires_grad = False
+            # We embed descriptions beforehand and only make a lookup for better efficiency
+            self.form_embedded_descriptions()
+            torch.save(self.embedded_descriptions, embedded_descriptions_path)
+
+        logging.info(f"Embedding semantic descriptions: {semantic_descriptions}")
+        logging.info(f"Embedded descriptions successfully. Size: {self.embedded_descriptions.size()}")
+        logging.info(f"Data set used for descriptions: {dataset_name}")
+
+    def form_embedded_descriptions(self):
+        """Embed all description texts in index order and store them in self.embedded_descriptions."""
+        texts = [self.description_dict[self.idx2description[i]] for i in range(len(self.description_dict))]
+        self.embedded_descriptions = self.embed_sentences(texts)
+
+    def description_2_idx(self, kg_info):
+        """Return the description indices of the nodes in kg_info as a long tensor."""
+        indices = [self.description2idx[node["description"]] for node in kg_info]
+        return torch.Tensor(indices).long()
+
+    def forward(self, description_idx, values):
+        """Concatenate embedded values with the looked-up description embeddings and project them."""
+        # Descriptions were embedded beforehand, so this is only an index lookup
+        descriptions = self.embedded_descriptions[description_idx]
+        embedded_values = self.value_embedding(values)
+
+        # The value embedding comes first in the concatenation along the last dimension
+        features = torch.cat((embedded_values, descriptions), dim=-1).to(DEVICE)
+
+        node_embedding = self.feature_projection(features).to(DEVICE)
+        return node_embedding
+
+    def embed_sentences(self, sentences):
+        """Encode sentences with self.roberta_model and return one vector per sentence.
+
+        The vector is the mean over tokens whose id is not 1 (self.mean), otherwise the pooler
+        output (self.use_pooled) or the hidden state at position 0.
+        """
+        tokenized = [self.tokenizer.encode_plus(sen, add_special_tokens=True, max_length=self.max_length,
+                                                padding='max_length') for sen in sentences]
+
+        input_ids = torch.Tensor([feat['input_ids'] for feat in tokenized]).long().to(DEVICE)
+        attention_mask = torch.Tensor([feat['attention_mask'] for feat in tokenized]).long().to(DEVICE)
+
+        roberta_output = self.roberta_model(input_ids, attention_mask)
+        output_states = roberta_output.last_hidden_state
+        if self.mean:
+            # Token id 1 is excluded from the mean (presumably RoBERTa's <pad> id - TODO confirm); the
+            # comparison runs on the whole tensor instead of looping over single device elements.
+            length_mask = (input_ids != 1).float().unsqueeze(-1)
+            return (output_states * length_mask).sum(dim=1) / length_mask.sum(dim=1)
+        return roberta_output.pooler_output if self.use_pooled else output_states[:, 0, :]
+
+    def init_description_dict(self):
+        """Load self.description_dict from the JSON description file selected by semantic_descriptions."""
+        # NOTE(review): presumably this call (re)creates the description files - confirm
+        create_description_dicts(self.dataset_name)
+        root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        relative_path = 'information_descriptions.json'
+        if self.semantic_descriptions:
+            relative_path = f'descriptions/semantic_information_descriptions_{self.dataset_name}.json'
+        with open(os.path.join(root_dir, relative_path), "r") as f:
+            self.description_dict = json.load(f)
+
diff --git a/convlab/policy/vtrace_DPT/transformer_model/noisy_linear.py b/convlab/policy/vtrace_DPT/transformer_model/noisy_linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..90fcb9fd38b07f4dc2a0ee60833c4d1803812e54
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/transformer_model/noisy_linear.py
@@ -0,0 +1,34 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class NoisyLinear(nn.Linear):
+    """Linear layer whose weight and bias get Gaussian noise with a learnable scale on every forward."""
+    def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
+        super(NoisyLinear, self).__init__(in_features, out_features, bias=bias)
+        w = torch.full((out_features, in_features), sigma_init)
+        self.sigma_weight = nn.Parameter(w)
+        z = torch.zeros(out_features, in_features)
+        self.register_buffer("epsilon_weight", z)
+        if bias:
+            w = torch.full((out_features,), sigma_init)
+            self.sigma_bias = nn.Parameter(w)
+            z = torch.zeros(out_features)
+            self.register_buffer("epsilon_bias", z)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        std = math.sqrt(3 / self.in_features)
+        self.weight.data.uniform_(-std, std)
+        if self.bias is not None:  # bias is None when the layer was built with bias=False
+            self.bias.data.uniform_(-std, std)
+
+    def forward(self, input):
+        self.epsilon_weight.normal_()  # a fresh noise sample is drawn on every call
+        bias = self.bias
+        if bias is not None:
+            bias = bias + self.sigma_bias * self.epsilon_bias.normal_()
+        v = self.sigma_weight * self.epsilon_weight.data + self.weight
+        return F.linear(input, v, bias)
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict.json b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict.json
new file mode 100644
index 0000000000000000000000000000000000000000..0d5bd2002fa7d0082e7589b80ae3664781732ece
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict.json
@@ -0,0 +1 @@
+{"attraction": 0, "general": 1, "hospital": 2, "hotel": 3, "police": 4, "restaurant": 5, "taxi": 6, "train": 7, "eos": 8, "inform": 9, "nooffer": 10, "recommend": 11, "request": 12, "select": 13, "bye": 14, "greet": 15, "reqmore": 16, "welcome": 17, "book": 18, "offerbook": 19, "nobook": 20, "address-1": 21, "address-2": 22, "address-3": 23, "area-1": 24, "area-2": 25, "area-3": 26, "choice-1": 27, "choice-2": 28, "choice-3": 29, "entrance fee-1": 30, "entrance fee-2": 31, "name-1": 32, "name-2": 33, "name-3": 34, "name-4": 35, "phone-1": 36, "postcode-1": 37, "type-1": 38, "type-2": 39, "type-3": 40, "type-4": 41, "type-5": 42, "none-none": 43, "area-?": 44, "entrance fee-?": 45, "name-?": 46, "type-?": 47, "department-1": 48, "department-?": 49, "book day-1": 50, "book people-1": 51, "book stay-1": 52, "internet-1": 53, "parking-1": 54, "price range-1": 55, "price range-2": 56, "ref-1": 57, "stars-1": 58, "stars-2": 59, "book day-?": 60, "book people-?": 61, "book stay-?": 62, "internet-?": 63, "parking-?": 64, "price range-?": 65, "stars-?": 66, "book time-1": 67, "food-1": 68, "food-2": 69, "food-3": 70, "food-4": 71, "postcode-2": 72, "book time-?": 73, "food-?": 74, "arrive by-1": 75, "departure-1": 76, "destination-1": 77, "leave at-1": 78, "arrive by-?": 79, "departure-?": 80, "destination-?": 81, "leave at-?": 82, "arrive by-2": 83, "day-1": 84, "duration-1": 85, "leave at-2": 86, "leave at-3": 87, "price-1": 88, "train id-1": 89, "day-?": 90, "pad": 91}
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json
new file mode 100644
index 0000000000000000000000000000000000000000..0d5bd2002fa7d0082e7589b80ae3664781732ece
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json
@@ -0,0 +1 @@
+{"attraction": 0, "general": 1, "hospital": 2, "hotel": 3, "police": 4, "restaurant": 5, "taxi": 6, "train": 7, "eos": 8, "inform": 9, "nooffer": 10, "recommend": 11, "request": 12, "select": 13, "bye": 14, "greet": 15, "reqmore": 16, "welcome": 17, "book": 18, "offerbook": 19, "nobook": 20, "address-1": 21, "address-2": 22, "address-3": 23, "area-1": 24, "area-2": 25, "area-3": 26, "choice-1": 27, "choice-2": 28, "choice-3": 29, "entrance fee-1": 30, "entrance fee-2": 31, "name-1": 32, "name-2": 33, "name-3": 34, "name-4": 35, "phone-1": 36, "postcode-1": 37, "type-1": 38, "type-2": 39, "type-3": 40, "type-4": 41, "type-5": 42, "none-none": 43, "area-?": 44, "entrance fee-?": 45, "name-?": 46, "type-?": 47, "department-1": 48, "department-?": 49, "book day-1": 50, "book people-1": 51, "book stay-1": 52, "internet-1": 53, "parking-1": 54, "price range-1": 55, "price range-2": 56, "ref-1": 57, "stars-1": 58, "stars-2": 59, "book day-?": 60, "book people-?": 61, "book stay-?": 62, "internet-?": 63, "parking-?": 64, "price range-?": 65, "stars-?": 66, "book time-1": 67, "food-1": 68, "food-2": 69, "food-3": 70, "food-4": 71, "postcode-2": 72, "book time-?": 73, "food-?": 74, "arrive by-1": 75, "departure-1": 76, "destination-1": 77, "leave at-1": 78, "arrive by-?": 79, "departure-?": 80, "destination-?": 81, "leave at-?": 82, "arrive by-2": 83, "day-1": 84, "duration-1": 85, "leave at-2": 86, "leave at-3": 87, "price-1": 88, "train id-1": 89, "day-?": 90, "pad": 91}
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_sgd.json b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_sgd.json
new file mode 100644
index 0000000000000000000000000000000000000000..c48573262f50cabe1d4cd3811638ebf4c046a7e0
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_sgd.json
@@ -0,0 +1 @@
+{"": 0, "alarm_1": 1, "banks_1": 2, "banks_2": 3, "buses_1": 4, "buses_2": 5, "buses_3": 6, "calendar_1": 7, "events_1": 8, "events_2": 9, "events_3": 10, "flights_1": 11, "flights_2": 12, "flights_3": 13, "flights_4": 14, "homes_1": 15, "homes_2": 16, "hotels_1": 17, "hotels_2": 18, "hotels_3": 19, "hotels_4": 20, "media_1": 21, "media_2": 22, "media_3": 23, "messaging_1": 24, "movies_1": 25, "movies_2": 26, "movies_3": 27, "music_1": 28, "music_2": 29, "music_3": 30, "payment_1": 31, "rentalcars_1": 32, "rentalcars_2": 33, "rentalcars_3": 34, "restaurants_1": 35, "restaurants_2": 36, "ridesharing_1": 37, "ridesharing_2": 38, "services_1": 39, "services_2": 40, "services_3": 41, "services_4": 42, "trains_1": 43, "travel_1": 44, "weather_1": 45, "eos": 46, "goodbye": 47, "req_more": 48, "confirm": 49, "inform_count": 50, "notify_success": 51, "offer": 52, "offer_intent": 53, "request": 54, "inform": 55, "notify_failure": 56, "none-none": 57, "new_alarm_name-1": 58, "new_alarm_time-1": 59, "count-1": 60, "alarm_name-1": 61, "alarm_time-1": 62, "addalarm-1": 63, "new_alarm_time-?": 64, "account_type-1": 65, "amount-1": 66, "recipient_account_name-1": 67, "recipient_account_type-1": 68, "balance-1": 69, "transfermoney-1": 70, "account_type-?": 71, "amount-?": 72, "recipient_account_name-?": 73, "recipient_name-1": 74, "transfer_amount-1": 75, "transfer_time-1": 76, "account_balance-1": 77, "recipient_name-?": 78, "transfer_amount-?": 79, "from_location-1": 80, "leaving_date-1": 81, "leaving_time-1": 82, "to_location-1": 83, "travelers-1": 84, "from_station-1": 85, "to_station-1": 86, "transfers-1": 87, "fare-1": 88, "buybusticket-1": 89, "from_location-?": 90, "leaving_date-?": 91, "leaving_time-?": 92, "to_location-?": 93, "travelers-?": 94, "departure_date-1": 95, "departure_time-1": 96, "destination-1": 97, "fare_type-1": 98, "group_size-1": 99, "origin-1": 100, "destination_station_name-1": 101, "origin_station_name-1": 102, "price-1": 103, "departure_date-?": 
104, "departure_time-?": 105, "destination-?": 106, "group_size-?": 107, "origin-?": 108, "additional_luggage-1": 109, "from_city-1": 110, "num_passengers-1": 111, "to_city-1": 112, "category-1": 113, "from_city-?": 114, "num_passengers-?": 115, "to_city-?": 116, "event_date-1": 117, "event_location-1": 118, "event_name-1": 119, "event_time-1": 120, "available_end_time-1": 121, "available_start_time-1": 122, "addevent-1": 123, "event_date-?": 124, "event_location-?": 125, "event_name-?": 126, "event_time-?": 127, "city_of_event-1": 128, "date-1": 129, "number_of_seats-1": 130, "address_of_location-1": 131, "subcategory-1": 132, "time-1": 133, "buyeventtickets-1": 134, "category-?": 135, "city_of_event-?": 136, "date-?": 137, "number_of_seats-?": 138, "city-1": 139, "number_of_tickets-1": 140, "venue-1": 141, "venue_address-1": 142, "city-?": 143, "event_type-?": 144, "number_of_tickets-?": 145, "price_per_ticket-1": 146, "airlines-1": 147, "destination_city-1": 148, "inbound_departure_time-1": 149, "origin_city-1": 150, "outbound_departure_time-1": 151, "passengers-1": 152, "return_date-1": 153, "seating_class-1": 154, "destination_airport-1": 155, "inbound_arrival_time-1": 156, "number_stops-1": 157, "origin_airport-1": 158, "outbound_arrival_time-1": 159, "refundable-1": 160, "reserveonewayflight-1": 161, "reserveroundtripflights-1": 162, "airlines-?": 163, "destination_city-?": 164, "inbound_departure_time-?": 165, "origin_city-?": 166, "outbound_departure_time-?": 167, "return_date-?": 168, "is_redeye-1": 169, "arrives_next_day-1": 170, "destination_airport_name-1": 171, "origin_airport_name-1": 172, "is_nonstop-1": 173, "destination_airport-?": 174, "origin_airport-?": 175, "property_name-1": 176, "visit_date-1": 177, "furnished-1": 178, "pets_allowed-1": 179, "phone_number-1": 180, "address-1": 181, "number_of_baths-1": 182, "number_of_beds-1": 183, "rent-1": 184, "schedulevisit-1": 185, "area-?": 186, "number_of_beds-?": 187, "visit_date-?": 188, 
"has_garage-1": 189, "in_unit_laundry-1": 190, "intent-?": 191, "number_of_baths-?": 192, "check_in_date-1": 193, "hotel_name-1": 194, "number_of_days-1": 195, "number_of_rooms-1": 196, "has_wifi-1": 197, "price_per_night-1": 198, "street_address-1": 199, "star_rating-1": 200, "reservehotel-1": 201, "check_in_date-?": 202, "hotel_name-?": 203, "number_of_days-?": 204, "check_out_date-1": 205, "number_of_adults-1": 206, "where_to-1": 207, "has_laundry_service-1": 208, "total_price-1": 209, "rating-1": 210, "bookhouse-1": 211, "check_out_date-?": 212, "number_of_adults-?": 213, "where_to-?": 214, "location-1": 215, "pets_welcome-1": 216, "average_rating-1": 217, "location-?": 218, "place_name-1": 219, "stay_length-1": 220, "smoking_allowed-1": 221, "stay_length-?": 222, "subtitles-1": 223, "title-1": 224, "directed_by-1": 225, "genre-1": 226, "title-2": 227, "title-3": 228, "playmovie-1": 229, "genre-?": 230, "title-?": 231, "movie_name-1": 232, "subtitle_language-1": 233, "movie_name-2": 234, "movie_name-3": 235, "rentmovie-1": 236, "starring-1": 237, "contact_name-1": 238, "contact_name-?": 239, "show_date-1": 240, "show_time-1": 241, "show_type-1": 242, "theater_name-1": 243, "buymovietickets-1": 244, "movie_name-?": 245, "show_date-?": 246, "show_time-?": 247, "show_type-?": 248, "aggregate_rating-1": 249, "cast-1": 250, "movie_title-1": 251, "percent_rating-1": 252, "playback_device-1": 253, "song_name-1": 254, "album-1": 255, "year-1": 256, "artist-1": 257, "playsong-1": 258, "song_name-?": 259, "playmedia-1": 260, "device-1": 261, "track-1": 262, "payment_method-1": 263, "private_visibility-1": 264, "receiver-1": 265, "payment_method-?": 266, "receiver-?": 267, "dropoff_date-1": 268, "pickup_date-1": 269, "pickup_location-1": 270, "pickup_time-1": 271, "type-1": 272, "car_name-1": 273, "reservecar-1": 274, "dropoff_date-?": 275, "pickup_city-?": 276, "pickup_date-?": 277, "pickup_location-?": 278, "pickup_time-?": 279, "type-?": 280, "car_type-1": 281, 
"car_type-?": 282, "add_insurance-1": 283, "end_date-1": 284, "start_date-1": 285, "price_per_day-1": 286, "add_insurance-?": 287, "end_date-?": 288, "start_date-?": 289, "party_size-1": 290, "restaurant_name-1": 291, "cuisine-1": 292, "has_live_music-1": 293, "price_range-1": 294, "serves_alcohol-1": 295, "reserverestaurant-1": 296, "cuisine-?": 297, "restaurant_name-?": 298, "time-?": 299, "has_seating_outdoors-1": 300, "has_vegetarian_options-1": 301, "number_of_riders-1": 302, "shared_ride-1": 303, "approximate_ride_duration-1": 304, "ride_fare-1": 305, "number_of_riders-?": 306, "shared_ride-?": 307, "ride_type-1": 308, "wait_time-1": 309, "ride_type-?": 310, "appointment_date-1": 311, "appointment_time-1": 312, "stylist_name-1": 313, "is_unisex-1": 314, "bookappointment-1": 315, "appointment_date-?": 316, "appointment_time-?": 317, "dentist_name-1": 318, "offers_cosmetic_services-1": 319, "doctor_name-1": 320, "therapist_name-1": 321, "class-1": 322, "date_of_journey-1": 323, "from-1": 324, "journey_start_time-1": 325, "to-1": 326, "trip_protection-1": 327, "total-1": 328, "gettraintickets-1": 329, "date_of_journey-?": 330, "from-?": 331, "to-?": 332, "trip_protection-?": 333, "free_entry-1": 334, "good_for_kids-1": 335, "attraction_name-1": 336, "humidity-1": 337, "wind-1": 338, "precipitation-1": 339, "temperature-1": 340, "pad": 341}
\ No newline at end of file
diff --git a/convlab/policy/vtrace_DPT/transformer_model/transformer.py b/convlab/policy/vtrace_DPT/transformer_model/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb26a5319f1ddc6feeccf35b4096ec6f934efde5
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/transformer_model/transformer.py
@@ -0,0 +1,99 @@
+import math
+import torch
+
+from torch import nn, Tensor
+from .TransformerLayerCustom import TransformerEncoderCustom, TransformerEncoderLayerCustom, \
+    TransformerDecoderLayerCustom, TransformerDecoderCustom
+
+
+class TransformerModelEncoder(nn.Module):
+    """Positional encoding followed by a stack of custom transformer encoder layers."""
+    def __init__(self, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5, need_weights=False):
+        super().__init__()
+        self.model_type = 'TransformerEncoder'
+        self.pos_encoder = PositionalEncoding(d_model, dropout)
+        encoder_layers = TransformerEncoderLayerCustom(d_model, nhead, d_hid, dropout, need_weights=need_weights)
+        self.encoder = TransformerEncoderCustom(encoder_layers, nlayers, need_weights=need_weights)
+        self.d_model = d_model
+
+    def init_weights(self) -> None:
+        """Optional re-initialisation (not called by __init__): matrices ~ U(-0.1, 0.1), biases zero."""
+        initrange = 0.1
+        for name, param in self.named_parameters():
+            if param.dim() > 1:
+                param.data.uniform_(-initrange, initrange)
+            elif name.endswith('bias'):
+                param.data.zero_()
+
+    def forward(self, src, mask=None, src_key_padding_mask=None):
+        """
+        Args:
+            src: Tensor, shape [seq_len, batch_size, embedding_dim] (the layout PositionalEncoding expects)
+            mask, src_key_padding_mask: optional masks, forwarded to the encoder unchanged
+        Returns:
+            (output, attention_weights) as produced by the custom encoder
+        """
+        src = self.pos_encoder(src)
+        output, attention_weights = self.encoder(src, mask=mask, src_key_padding_mask=src_key_padding_mask)
+        return output, attention_weights
+
+
+class TransformerModelDecoder(nn.Module):
+    """Positional encoding followed by a stack of custom transformer decoder layers."""
+    def __init__(self, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5, need_weights=False):
+        super().__init__()
+        self.model_type = 'TransformerDecoder'
+        self.pos_encoder = PositionalEncoding(d_model, dropout)
+        decoder_layers = TransformerDecoderLayerCustom(d_model, nhead, d_hid, dropout, need_weights=need_weights)
+        self.decoder = TransformerDecoderCustom(decoder_layers, nlayers, need_weights=need_weights)
+        self.d_model = d_model
+
+    def init_weights(self) -> None:
+        """Optional re-initialisation (not called by __init__): matrices ~ U(-0.1, 0.1), biases zero."""
+        initrange = 0.1
+        for name, param in self.named_parameters():
+            if param.dim() > 1:
+                param.data.uniform_(-initrange, initrange)
+            elif name.endswith('bias'):
+                param.data.zero_()
+
+    def forward(self, decoder_input, encoder_output, tgt_mask=None, memory_key_padding_mask=None):
+        """
+        Args:
+            decoder_input: Tensor, shape [seq_len, batch_size, embedding_dim] (layout PositionalEncoding expects)
+            encoder_output: memory handed to the decoder unchanged
+            tgt_mask, memory_key_padding_mask: optional masks, forwarded unchanged (no target padding mask is used)
+        Returns:
+            (output, att_weights) as produced by the custom decoder
+        """
+        decoder_input = self.pos_encoder(decoder_input)
+        output, att_weights = self.decoder(tgt=decoder_input, memory=encoder_output, tgt_mask=tgt_mask,
+                                          tgt_key_padding_mask=None, memory_key_padding_mask=memory_key_padding_mask)
+        return output, att_weights
+
+
+def generate_square_subsequent_mask(sz: int) -> Tensor:
+    """Return an sz x sz float mask: -inf strictly above the diagonal, 0 on and below it."""
+    return torch.full((sz, sz), float('-inf')).triu(diagonal=1)
+
+
+class PositionalEncoding(nn.Module):
+    """Adds fixed sinusoidal position encodings to a sequence-first tensor, then applies dropout."""
+
+    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+        angles = position * div_term  # shape [max_len, ceil(d_model / 2)]
+        pe = torch.zeros(max_len, 1, d_model)
+        pe[:, 0, 0::2] = torch.sin(angles)
+        # An odd d_model has one cosine column fewer than sine columns, so trim before assigning.
+        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
+        self.register_buffer('pe', pe)  # registered as a buffer, i.e. kept on the module but not a parameter
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Return dropout(x + pe[:seq_len]); x has shape [seq_len, batch_size, embedding_dim]."""
+        x = x + self.pe[:x.size(0)]
+        return self.dropout(x)
diff --git a/convlab/policy/vtrace_DPT/vtrace.py b/convlab/policy/vtrace_DPT/vtrace.py
new file mode 100644
index 0000000000000000000000000000000000000000..b03662c60539a2e3aa80a5132618a3f7563a0f09
--- /dev/null
+++ b/convlab/policy/vtrace_DPT/vtrace.py
@@ -0,0 +1,368 @@
+import numpy as np
+import logging
+import json
+import os
+import sys
+import torch
+import torch.nn as nn
+
+from torch import optim
+from convlab.policy.vtrace_DPT.transformer_model.EncoderDecoder import EncoderDecoder
+from convlab.policy.vtrace_DPT.transformer_model.EncoderCritic import EncoderCritic
+from ... import Policy
+from ...util.custom_util import set_seed
+
+# Make the repository root (four levels above this file) importable, then pick the compute device.
+root_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), *([os.pardir] * 4)))
+sys.path.append(root_dir)
+DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+
+class VTRACE(nn.Module, Policy):
+
+    def __init__(self, is_train=True, seed=0, vectorizer=None, load_path=""):
+        """Read config.json, build the policy and critic, then try to restore checkpoints from load_path."""
+        super(VTRACE, self).__init__()
+
+        dir_name = os.path.dirname(os.path.abspath(__file__))
+        self.config_path = os.path.join(dir_name, 'config.json')
+        # Hyper-parameters are read from the config.json that sits next to this file.
+        with open(self.config_path, 'r') as f:
+            cfg = json.load(f)
+
+        self.cfg = cfg
+        self.save_dir = os.path.join(dir_name, cfg['save_dir'])
+        self.save_per_epoch = cfg['save_per_epoch']
+        self.gamma = cfg['gamma']
+        self.tau = cfg['tau']
+        self.is_train = is_train
+        self.entropy_weight = cfg.get('entropy_weight', 0.0)
+        self.behaviour_cloning_weight = cfg.get('behaviour_cloning_weight', 0.0)
+        self.online_offline_ratio = cfg.get('online_offline_ratio', 0.0)
+        self.hidden_size = cfg['hidden_size']
+        self.policy_freq = cfg['policy_freq']
+        self.seed = seed
+        self.total_it = 0
+        self.rho_bar = cfg.get('rho_bar', 10)
+        self.c = cfg['c']
+        self.info_dict = {}
+        self.use_regularization = False
+        self.supervised_weight = cfg.get('supervised_weight', 0.0)
+        # rho_bar and c are the caps applied to the importance ratios in get_loss.
+        logging.info(f"Entropy weight: {self.entropy_weight}")
+        logging.info(f"Online-Offline-ratio: {self.online_offline_ratio}")
+        logging.info(f"Behaviour cloning weight: {self.behaviour_cloning_weight}")
+        logging.info(f"Supervised weight: {self.supervised_weight}")
+
+        set_seed(seed)
+
+        self.last_action = None
+        # NOTE(review): the default vectorizer=None fails below, because act2vec is read from it.
+        self.vector = vectorizer
+        self.policy = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE)
+        self.value_helper = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE)
+        # Best effort: if loading raises, the error is printed and construction continues.
+        try:
+            self.load_policy(load_path)
+        except Exception as e:
+            print(f"Could not load the policy, Exception: {e}")
+        # cfg['independent'] decides whose embedder/encoder the critic is built on: value_helper's or the policy's.
+        if self.cfg['independent']:
+            self.value = EncoderCritic(self.value_helper.node_embedder, self.value_helper.encoder, **self.cfg).to(
+                device=DEVICE)
+        else:
+            self.value = EncoderCritic(self.policy.node_embedder, self.policy.encoder, **self.cfg).to(device=DEVICE)
+
+        try:
+            self.load_value(load_path)
+        except Exception as e:
+            print(f"Could not load the critic, Exception: {e}")
+        # A single Adam optimizer with separate parameter groups for policy and critic.
+        self.optimizer = optim.Adam([
+            {'params': self.policy.parameters(), 'lr': cfg['policy_lr'], 'betas': (0.0, 0.999)},
+            {'params': self.value.parameters(), 'lr': cfg['value_lr']}
+        ])
+
+        try:
+            self.load_optimizer_dicts(load_path)
+        except Exception as e:
+            print(f"Could not load optimiser dicts, Exception: {e}")
+
+    def predict(self, state):
+        """
+        Select a system action for the given dialog state.
+
+        Also replaces self.info_dict with the policy's info_dict and adds 'big_act', 'a_prob', 'critic_value'.
+        Args:
+            state (dict): Dialog state. Please refer to util/state.py
+        Returns:
+            action : whatever vector.action_devectorize builds from the selected action vector
+        """
+        if not self.is_train:
+            # Evaluation: switch off gradient tracking, policy parameters first, then critic parameters.
+            for network in (self.policy, self.value):
+                for weight in network.parameters():
+                    weight.requires_grad = False
+
+        # The state vector itself is not used below; only the action mask is.
+        _, legal_mask = self.vector.state_vectorize(state)
+        chosen = self.policy.select_action([self.vector.kg_info], mask=legal_mask, eval=not self.is_train)
+        chosen = chosen.detach().cpu()
+        self.info_dict = self.policy.info_dict
+        info = self.info_dict
+
+        descriptions, slot_values = info["description_idx_list"], info["value_list"]
+        in_domain = info["current_domain_mask"].unsqueeze(0)
+        out_of_domain = info["non_current_domain_mask"].unsqueeze(0)
+        small_act = info['small_act']
+
+        # Re-evaluate the chosen action as a batch of one to obtain its probability under the policy.
+        prob, _ = self.policy.get_prob(chosen.unsqueeze(0), info['action_mask'].unsqueeze(0), len(small_act),
+                                       [small_act], in_domain, out_of_domain, [descriptions], [slot_values])
+
+        info['big_act'] = chosen
+        info['a_prob'] = prob.prod()
+        info['critic_value'] = self.value([descriptions], [slot_values]).squeeze()
+
+        return self.vector.action_devectorize(chosen.detach().numpy())
+
+    def update(self, memory):
+        """Run one optimisation step on the critic loss plus, when present, the actor loss from memory."""
+        actor_loss, critic_loss = self.get_loss(memory)
+        total = critic_loss
+        if actor_loss is not None:
+            total += actor_loss  # get_loss returns None for the actor on steps without a policy update
+        self.optimizer.zero_grad()
+        total.backward()
+
+        # Sequence kept as-is: clip critic gradients, zero NaN policy gradients, clip policy gradients.
+        torch.nn.utils.clip_grad_norm_(self.value.parameters(), 40)
+        for weight in self.policy.parameters():
+            if weight.grad is not None:
+                weight.grad.masked_fill_(torch.isnan(weight.grad), 0.0)
+        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 10)
+        self.optimizer.step()
+
+    def get_loss(self, memory):
+        """
+        Compute the V-trace actor and critic losses for one batch sampled from memory.
+
+        Side effects: sets self.is_train = True, increments self.total_it and re-enables gradients.
+        Returns:
+            (actor_loss, critic_loss); actor_loss is None unless total_it is a multiple of policy_freq.
+        """
+        # Training mode is forced, so the old `else: return np.nan` branch was unreachable and is gone.
+        self.is_train = True
+        self.total_it += 1
+
+        for param in self.policy.parameters():
+            param.requires_grad = True
+        for param in self.value.parameters():
+            param.requires_grad = True
+
+        batch, num_online = self.get_batch(memory)
+
+        action_masks, actions, critic_v, current_domain_mask, description_batch, max_length, mu, \
+        non_current_domain_mask, rewards, small_actions, unflattened_states, value_batch \
+            = self.prepare_batch(batch)
+
+        with torch.no_grad():
+            values = self.value(description_batch, value_batch).squeeze(-1)
+
+            pi_prob, _ = self.policy.get_prob(actions, action_masks, max_length, small_actions,
+                                              current_domain_mask, non_current_domain_mask,
+                                              description_batch, value_batch)
+            pi_prob = pi_prob.prod(dim=-1)
+
+            # Importance ratios pi / mu, capped at rho_bar and c (mu comes from the sampled batch).
+            rho = torch.min(torch.Tensor([self.rho_bar]).to(DEVICE), pi_prob / mu)
+            cs = torch.min(torch.Tensor([self.c]).to(DEVICE), pi_prob / mu)
+
+            vtrace_target, advantages = self.compute_vtrace_advantage(unflattened_states, rewards, rho, cs, values)
+
+        # Compute critic loss
+        current_v = self.value(description_batch, value_batch).to(DEVICE)
+        critic_loss = torch.square(vtrace_target.unsqueeze(-1).to(DEVICE) - current_v).mean()
+
+        if self.use_regularization:
+            # Behaviour cloning on the buffer data; num_online becomes a count of flattened steps.
+            num_online = sum([len(reward_list) for reward_list in batch['rewards'][:num_online]])
+            behaviour_loss_critic = torch.square(
+                critic_v[num_online:].unsqueeze(-1).to(DEVICE) - current_v[num_online:]).mean()
+            critic_loss += self.behaviour_cloning_weight * behaviour_loss_critic
+
+        actor_loss = None
+
+        # Delayed policy updates
+        if self.total_it % self.policy_freq == 0:
+            actor_loss, entropy = self.policy.get_log_prob(actions, action_masks, max_length, small_actions,
+                                                           current_domain_mask, non_current_domain_mask,
+                                                           description_batch, value_batch)
+            actor_loss = -1 * actor_loss
+            actor_loss = actor_loss * (advantages.to(DEVICE) * rho)
+            actor_loss = actor_loss.mean() - entropy * self.entropy_weight
+            if self.use_regularization:
+                log_prob, entropy = self.policy.get_log_prob(actions[num_online:], action_masks[num_online:],
+                                                             max_length, small_actions[num_online:],
+                                                             current_domain_mask[num_online:],
+                                                             non_current_domain_mask[num_online:],
+                                                             description_batch[num_online:],
+                                                             value_batch[num_online:])
+                actor_loss = actor_loss - log_prob.mean() * self.behaviour_cloning_weight
+
+        return actor_loss, critic_loss
+
+    def get_batch(self, memory):
+        """Sample (batch, num_online); online_offline_ratio is passed on only when regularizing or equal to 1.0."""
+        use_configured_ratio = self.use_regularization or self.online_offline_ratio == 1.0
+        ratio = self.online_offline_ratio if use_configured_ratio else 0.0
+        # Unpacked and re-packed so the result is always a fresh 2-tuple, as before.
+        batch, num_online = memory.sample(ratio)
+        return batch, num_online
+
+    def prepare_batch(self, batch):
+        """
+        Flatten a batch of per-episode lists into per-step lists and tensors on DEVICE.
+
+        Returns a 12-tuple: action_masks, actions, critic_v, current_domain_mask, description_batch,
+        max_length, mu, non_current_domain_mask, rewards, small_actions, unflattened_states, value_batch.
+        batch['states'] is returned as-is (still nested per episode); everything else is flattened.
+        """
+        def flatten(key):
+            # Concatenate the per-episode lists stored under `key` into one list of steps.
+            return [step for episode in batch[key] for step in episode]
+
+        def stacked(key):
+            return torch.stack(flatten(key), dim=0).to(DEVICE)
+
+        unflattened_states = batch['states']
+        description_batch = flatten('description_idx_list')
+        value_batch = flatten('value_list')
+        current_domain_mask = stacked('current_domain_mask')
+        non_current_domain_mask = stacked('non_current_domain_mask')
+        actions = stacked('actions')
+        small_actions = flatten('small_actions')
+        rewards = stacked('rewards')
+        mu = stacked('mu')
+        critic_v = stacked('critic_value')
+
+        # Pad every action mask with zero rows up to the longest small-action sequence in the batch.
+        max_length = max(len(act) for act in small_actions)
+        mask_width = len(self.policy.action_embedder.small_action_dict)
+        action_masks = torch.stack([
+            torch.cat([mask.to(DEVICE), torch.zeros(max_length - len(mask), mask_width).to(DEVICE)], dim=0)
+            for mask in flatten('action_masks')]).to(DEVICE)
+
+        return action_masks, actions, critic_v, current_domain_mask, description_batch, max_length, mu, \
+               non_current_domain_mask, rewards, small_actions, unflattened_states, value_batch
+
+    def compute_vtrace_advantage(self, states, rewards, rho, cs, values):
+        """
+        Compute V-trace targets and advantages, walking backwards through each episode.
+
+        states is a list of episodes (only their lengths are used); rewards, rho, cs and values are
+        indexed by flattened step. Returns (vtraces_flat, advantages_flat) built with torch.Tensor.
+        """
+        vtrace_episodes, advantage_episodes, first = [], [], 0
+        for episode in states:
+            horizon = len(episode)
+            targets, gains = [None] * horizon, [None] * horizon
+            next_target, next_value = 0, 0  # both start at zero for the step after the episode's last one
+            for step in reversed(range(horizon)):
+                k = first + step
+                td_error = rho[k] * (rewards[k] + self.gamma * next_value - values[k])
+                gains[step] = rewards[k] + self.gamma * next_target - values[k]
+                next_target = values[k] + td_error + self.gamma * cs[k] * (next_target - next_value)
+                targets[step], next_value = next_target, values[k]
+            vtrace_episodes.append(targets)
+            advantage_episodes.append(gains)
+            first += horizon
+        def flat(episodes):
+            return torch.Tensor([x for per_episode in episodes for x in per_episode])
+        return flat(vtrace_episodes), flat(advantage_episodes)
+
+    def save(self, directory, addition=""):
+        """Write critic, policy and optimizer state dicts into directory as '<addition>_vtrace.*' files."""
+        # exist_ok avoids the check-then-create race of a separate os.path.exists test.
+        os.makedirs(directory, exist_ok=True)
+        prefix = f'{directory}/{addition}_vtrace'
+        torch.save(self.value.state_dict(), prefix + '.val.mdl')
+        torch.save(self.policy.state_dict(), prefix + '.pol.mdl')
+        torch.save(self.optimizer.state_dict(), prefix + '.optimizer')
+        logging.info("Saved policy, critic and optimizer.")
+
+    @staticmethod
+    def _first_existing_checkpoint(filename, suffix):
+        """
+        Return the first existing checkpoint path for filename and suffix, or None if there is none.
+
+        Tried in order: filename + suffix, filename + '_vtrace' + suffix (the name written by save()),
+        then the same two names joined onto the directory containing this module.
+        """
+        module_dir = os.path.dirname(os.path.abspath(__file__))
+        names = [filename + suffix, filename + '_vtrace' + suffix]
+        candidates = names + [os.path.join(module_dir, name) for name in names]
+        return next((path for path in candidates if os.path.exists(path)), None)
+
+    def _restore(self, filename, suffix, modules, report):
+        """
+        Load the first existing checkpoint for suffix into every module; do nothing if none exists.
+
+        Args:
+            filename: path prefix given by the caller
+            suffix: checkpoint suffix, e.g. '.val.mdl' or '.pol.mdl'
+            modules: modules that each receive the loaded state dict
+            report: callable given the success message (print or logging.info)
+        """
+        checkpoint = self._first_existing_checkpoint(filename, suffix)
+        if checkpoint is None:
+            return
+        state_dict = torch.load(checkpoint, map_location=DEVICE)  # read once, copied into each module
+        for module in modules:
+            module.load_state_dict(state_dict)
+        report('<<dialog policy>> loaded checkpoint from file: {}'.format(checkpoint))
+
+    def load(self, filename):
+        """
+        Restore the critic and then the policy (plus value_helper) from checkpoints named after filename.
+
+        Missing checkpoints are skipped silently; successful loads are announced with print.
+        """
+        self._restore(filename, '.val.mdl', [self.value], print)
+        self._restore(filename, '.pol.mdl', [self.policy, self.value_helper], print)
+
+    def load_policy(self, filename):
+        """
+        Restore policy and value_helper from the same '.pol.mdl' checkpoint, if one exists.
+
+        Unlike load(), a successful load is reported through logging.info.
+        """
+        self._restore(filename, '.pol.mdl', [self.policy, self.value_helper], logging.info)
+
+    def load_value(self, filename):
+        """
+        Restore the critic from a '.val.mdl' checkpoint, if one exists.
+
+        Unlike load(), a successful load is reported through logging.info.
+        """
+        self._restore(filename, '.val.mdl', [self.value], logging.info)
+
+    def load_optimizer_dicts(self, filename):
+        """
+        Restore the optimizer state saved next to the model checkpoints.
+
+        save() writes '<name>_vtrace.optimizer', so that name is tried in addition to the original
+        filename + '.optimizer'. If nothing exists, loading the original name raises as before
+        (FileNotFoundError from torch.load), which __init__ catches and prints.
+        """
+        found = self._first_existing_checkpoint(filename, '.optimizer')
+        optimizer_file = found if found is not None else filename + '.optimizer'
+        # The log text still names the prefix passed in, exactly as it did before.
+        self.optimizer.load_state_dict(torch.load(optimizer_file, map_location=DEVICE))
+        logging.info('<<dialog policy>> loaded optimisers from file: {}'.format(filename))
+
+    def from_pretrained(self):  # not implemented for this policy: calling it always raises
+        raise NotImplementedError
+
+    def init_session(self):
+        """Intentionally a no-op."""
diff --git a/convlab/task/multiwoz/goal_generator.py b/convlab/task/multiwoz/goal_generator.py
index 3a709c08e061a57d0885063bda5a5c97a501097d..b1f8443b720eb803cdb9852d64f2e5c34bc798aa 100755
--- a/convlab/task/multiwoz/goal_generator.py
+++ b/convlab/task/multiwoz/goal_generator.py
@@ -151,7 +151,7 @@ class GoalGenerator:
                  boldify=False,
                  sample_info_from_trainset=True,
                  sample_reqt_from_trainset=False,
-                 seed = 0):
+                 seed = 0, domain_ordering_dist=None):
         """
         Args:
             goal_model_path: path to a goal model
@@ -174,6 +174,8 @@ class GoalGenerator:
         else:
             self._build_goal_model()
             print('Building goal model is done')
+        if domain_ordering_dist is not None:
+            self.domain_ordering_dist = domain_ordering_dist
         np.random.seed(seed)
         random.seed(seed)
         # remove some slot
diff --git a/convlab/util/custom_util.py b/convlab/util/custom_util.py
index f3abff7009a21adacd33aef50bd497634d391ae1..38d8b92a36efd67bdf9166c1c5f9f20734d1ecb5 100644
--- a/convlab/util/custom_util.py
+++ b/convlab/util/custom_util.py
@@ -8,6 +8,7 @@ import zipfile
 import numpy as np
 import torch
 from tensorboardX import SummaryWriter
+from convlab.task.multiwoz.goal_generator import GoalGenerator
 from convlab.util.file_util import cached_path
 from convlab.policy.evaluate_distributed import evaluate_distributed
 from convlab.util.train_util_neo import init_logging_nunu
@@ -18,16 +19,15 @@ from convlab.dst.rule.multiwoz import RuleDST
 from convlab.policy.rule.multiwoz import RulePolicy
 from convlab.evaluator.multiwoz_eval import MultiWozEvaluator
 from convlab.util import load_dataset
-from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal
 
 import shutil
+import signal
 
 
 slot_mapping = {"pricerange": "price range", "post": "postcode", "arriveBy": "arrive by", "leaveAt": "leave at",
                 "Id": "trainid", "ref": "reference"}
 
 
-
 sys.path.append(os.path.dirname(os.path.dirname(
     os.path.dirname(os.path.abspath(__file__)))))
 
@@ -35,6 +35,22 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 device = DEVICE
 
 
+class timeout:
+    """Raise TimeoutError if the with-body runs longer than `seconds`; SIGALRM based (Unix, main thread)."""
+    def __init__(self, seconds=10, error_message='Timeout'):
+        self.seconds, self.error_message = seconds, error_message
+
+    def handle_timeout(self, signum, frame):
+        raise TimeoutError(self.error_message)
+
+    def __enter__(self):
+        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
+        signal.alarm(self.seconds)
+
+    def __exit__(self, type, value, traceback):
+        signal.alarm(0)  # cancel the alarm, then restore the old handler (None falls back to SIG_DFL)
+        signal.signal(signal.SIGALRM, self._previous_handler or signal.SIG_DFL)
+
 class NumpyEncoder(json.JSONEncoder):
     """ Special json encoder for numpy types """
 
@@ -58,7 +74,8 @@ def flatten_acts(dialogue_acts):
     act_list = []
     for act_type in dialogue_acts:
         for act in dialogue_acts[act_type]:
-            act_list.append([act['intent'], act['domain'], act['slot'], act.get('value', "")])
+            act_list.append([act['intent'], act['domain'],
+                            act['slot'], act.get('value', "")])
     return act_list
 
 
@@ -84,17 +101,20 @@ def load_config_file(filepath: str = None) -> dict:
     return conf
 
 
-def save_config(terminal_args, config_file_args, config_save_path):
+def save_config(terminal_args, config_file_args, config_save_path, policy_config=None):
     config_save_path = os.path.join(config_save_path, f'config_saved.json')
-    args_dict = {"args": terminal_args, "config": config_file_args}
+    args_dict = {"args": terminal_args, "config": config_file_args, "policy_config": policy_config}
     json.dump(args_dict, open(config_save_path, 'w'))
 
 
 def set_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
+    torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
 
 
 def init_logging(root_dir, mode):
@@ -138,24 +158,33 @@ def save_best(policy_sys, best_complete_rate, best_success_rate, best_return, co
     return best_complete_rate, best_success_rate, best_return
 
 
-def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path):
+def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_domain_goals=False, allowed_domains=None):
     policy_sys.is_train = False
+
+    goal_generator = GoalGenerator()
+    goals = []
+    for seed in range(1000, 1000 + conf['model']['num_eval_dialogues']):
+        set_seed(seed)
+        goal = create_goals(goal_generator, 1, single_domain_goals, allowed_domains)
+        goals.append(goal[0])
+
     if conf['model']['process_num'] == 1:
         complete_rate, success_rate, success_rate_strict, avg_return, turns, \
             avg_actions, task_success, book_acts, inform_acts, request_acts, \
-                select_acts, offer_acts = evaluate(sess,
+                select_acts, offer_acts, recommend_acts = evaluate(sess,
                                                 num_dialogues=conf['model']['num_eval_dialogues'],
                                                 sys_semantic_to_usr=conf['model'][
                                                     'sys_semantic_to_usr'],
-                                                save_flag=save_eval, save_path=log_save_path)
-        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts
+                                                save_flag=save_eval, save_path=log_save_path, goals=goals)
+
+        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts
     else:
         complete_rate, success_rate, success_rate_strict, avg_return, turns, \
-        avg_actions, task_success, book_acts, inform_acts, request_acts, \
-        select_acts, offer_acts = \
+            avg_actions, task_success, book_acts, inform_acts, request_acts, \
+            select_acts, offer_acts, recommend_acts = \
             evaluate_distributed(sess, list(range(1000, 1000 + conf['model']['num_eval_dialogues'])),
-                                 conf['model']['process_num'])
-        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts
+                                 conf['model']['process_num'], goals)
+        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts
 
         task_success_gathered = {}
         for task_dict in task_success:
@@ -164,24 +193,42 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path):
                     task_success_gathered[key] = []
                 task_success_gathered[key].append(value)
         task_success = task_success_gathered
-    
+
     policy_sys.is_train = True
-    logging.info(f"Complete: {complete_rate}, Success: {success_rate}, Success strict: {success_rate_strict}, "
-                 f"Average Return: {avg_return}, Turns: {turns}, Average Actions: {avg_actions}, "
+    # Standard error of each metric; np.size() (unlike len()) also accepts the scalar averages evaluate() returns.
+    mean_complete, err_complete = np.average(complete_rate), np.std(complete_rate) / np.sqrt(np.size(complete_rate))
+    mean_success, err_success = np.average(success_rate), np.std(success_rate) / np.sqrt(np.size(success_rate))
+    mean_success_strict, err_success_strict = np.average(success_rate_strict), np.std(success_rate_strict) / np.sqrt(np.size(success_rate_strict))
+    mean_return, err_return = np.average(avg_return), np.std(avg_return) / np.sqrt(np.size(avg_return))
+    mean_turns, err_turns = np.average(turns), np.std(turns) / np.sqrt(np.size(turns))
+    mean_actions, err_actions = np.average(avg_actions), np.std(avg_actions) / np.sqrt(np.size(avg_actions))
+
+    logging.info(f"Complete: {mean_complete}+-{round(err_complete, 2)}, "
+                 f"Success: {mean_success}+-{round(err_success, 2)}, "
+                 f"Success strict: {mean_success_strict}+-{round(err_success_strict, 2)}, "
+                 f"Average Return: {mean_return}+-{round(err_return, 2)}, "
+                 f"Turns: {mean_turns}+-{round(err_turns, 2)}, "
+                 f"Average Actions: {mean_actions}+-{round(err_actions, 2)}, "
                  f"Book Actions: {book_acts/total_acts}, Inform Actions: {inform_acts/total_acts}, "
                  f"Request Actions: {request_acts/total_acts}, Select Actions: {select_acts/total_acts}, "
-                 f"Offer Actions: {offer_acts/total_acts}")
+                 f"Offer Actions: {offer_acts/total_acts}, Recommend Actions: {recommend_acts/total_acts}")
 
     for key in task_success:
         logging.info(
             f"{key}: Num: {len(task_success[key])} Success: {np.average(task_success[key]) if len(task_success[key]) > 0 else 0}")
 
-    return {"complete_rate": complete_rate,
-            "success_rate": success_rate,
-            "success_rate_strict": success_rate_strict,
-            "avg_return": avg_return,
-            "turns": turns,
-            "avg_actions": avg_actions}
+    return {"complete_rate": mean_complete,
+            "success_rate": mean_success,
+            "success_rate_strict": mean_success_strict,
+            "avg_return": mean_return,
+            "turns": mean_turns,
+            "avg_actions": mean_actions,
+            "book_acts": book_acts/total_acts,
+            "inform_acts": inform_acts/total_acts,
+            "request_acts": request_acts/total_acts,
+            "select_acts": select_acts/total_acts,
+            "offer_acts": offer_acts/total_acts,
+            "recommend_acts": recommend_acts/total_acts}
 
 
 def env_config(conf, policy_sys, check_book_constraints=True):
@@ -193,12 +240,21 @@ def env_config(conf, policy_sys, check_book_constraints=True):
     policy_usr = conf['policy_usr_activated']
     usr_nlg = conf['usr_nlg_activated']
 
+    # Set up uncertainty thresholding (best effort: skipped when the DST/policy lack the confidence attributes)
+    if dst_sys:
+        try:
+            if dst_sys.return_confidence_scores:
+                policy_sys.vector.setup_uncertain_query(dst_sys.confidence_thresholds)
+        except Exception:  # not a bare except: KeyboardInterrupt / SystemExit must still propagate
+            logging.info('Uncertainty threshold not set.')
+
     simulator = PipelineAgent(nlu_usr, dst_usr, policy_usr, usr_nlg, 'user')
     system_pipeline = PipelineAgent(nlu_sys, dst_sys, policy_sys, sys_nlg,
                                     'sys', return_semantic_acts=conf['model']['sys_semantic_to_usr'])
 
     # assemble
-    evaluator = MultiWozEvaluator(check_book_constraints=check_book_constraints)
+    evaluator = MultiWozEvaluator(
+        check_book_constraints=check_book_constraints)
     env = Environment(sys_nlg, simulator, nlu_sys, dst_sys, evaluator=evaluator,
                       use_semantic_acts=conf['model']['sys_semantic_to_usr'])
     sess = BiSession(system_pipeline, simulator, None, evaluator)
@@ -264,12 +320,7 @@ def create_env(args, policy_sys):
     return env, sess
 
 
-def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False, save_path=None):
-    seed = 0
-    random.seed(seed)
-    np.random.seed(seed)
-
-    # sess = BiSession(agent_sys, simulator, None, evaluator)
+def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False, save_path=None, goals=None):
 
     eval_save = {}
     turn_counter_dict = {}
@@ -278,11 +329,12 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
     task_success = {'All_user_sim': [], 'All_evaluator': [], "All_evaluator_strict": [],
                     'total_return': [], 'turns': [], 'avg_actions': [],
                     'total_booking_acts': [], 'total_inform_acts': [], 'total_request_acts': [],
-                    'total_select_acts': [], 'total_offer_acts': []}
+                    'total_select_acts': [], 'total_offer_acts': [], 'total_recommend_acts': []}
     dial_count = 0
     for seed in range(1000, 1000 + num_dialogues):
         set_seed(seed)
-        sess.init_session()
+        goal = goals.pop() if goals else None  # goals may be None/exhausted; assumes init_session(goal=None) draws its own goal - TODO confirm
+        sess.init_session(goal=goal)
         sys_response = [] if sess.sys_agent.nlg is None else ''
         sys_response = [] if sys_semantic_to_usr else sys_response
         avg_actions = 0
@@ -293,6 +345,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
         request = 0
         select = 0
         offer = 0
+        recommend = 0
         # this 40 represents the max turn of dialogue
         for i in range(40):
             sys_response, user_response, session_over, reward = sess.next_turn(
@@ -315,6 +368,8 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
                     select += 1
                 if intent.lower() == 'offerbook':
                     offer += 1
+                if intent.lower() == 'recommend':
+                    recommend += 1
             avg_actions += len(acts)
             turn_counter += 1
             turns += 1
@@ -345,12 +400,14 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
         task_success['total_return'].append(total_return)
         task_success['turns'].append(turns)
         task_success['avg_actions'].append(avg_actions / turns)
-        
+
         task_success['total_booking_acts'].append(book)
         task_success['total_inform_acts'].append(inform)
         task_success['total_request_acts'].append(request)
         task_success['total_select_acts'].append(select)
         task_success['total_offer_acts'].append(offer)
+        # Recommend acts get their own bucket, one entry per dialogue like the buckets above.
+        task_success['total_recommend_acts'].append(recommend)
 
         # print(agent_sys.agent_saves)
         eval_save['Conversation {}'.format(str(dial_count))] = [
@@ -367,11 +424,11 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
     # save dialogue_info and clear mem
 
     return np.average(task_success['All_user_sim']), np.average(task_success['All_evaluator']), \
-           np.average(task_success['All_evaluator_strict']), np.average(task_success['total_return']), \
-           np.average(task_success['turns']), np.average(task_success['avg_actions']), task_success, \
-           np.average(task_success['total_booking_acts']), np.average(task_success['total_inform_acts']), \
-           np.average(task_success['total_request_acts']), np.average(task_success['total_select_acts']), \
-           np.average(task_success['total_offer_acts'])
+        np.average(task_success['All_evaluator_strict']), np.average(task_success['total_return']), \
+        np.average(task_success['turns']), np.average(task_success['avg_actions']), task_success, \
+        np.average(task_success['total_booking_acts']), np.average(task_success['total_inform_acts']), \
+        np.average(task_success['total_request_acts']), np.average(task_success['total_select_acts']), \
+        np.average(task_success['total_offer_acts']), np.average(task_success['total_recommend_acts'])
 
 
 def model_downloader(download_dir, model_path):
@@ -395,7 +452,8 @@ def get_goal_distribution(dataset_name='multiwoz21'):
         data = data_split[key]
         for dialogue in data:
             goal = dialogue['goal']
-            domains = list(set(goal['inform'].keys()) | set(goal['request'].keys()))
+            domains = list(set(goal['inform'].keys()) |
+                           set(goal['request'].keys()))
             domains.sort()
             domains = "-".join(domains)
 
@@ -417,14 +475,16 @@ def get_goal_distribution(dataset_name='multiwoz21'):
     domain_combinations.sort(key=lambda x: x[1], reverse=True)
     print(domain_combinations)
     print(single_domain_counter)
-    print("Number of combinations:", sum([value for _, value in domain_combinations]))
+    print("Number of combinations:", sum(
+        [value for _, value in domain_combinations]))
 
 
 def unified_format(acts):
     new_acts = {'categorical': []}
     for act in acts:
         intent, domain, slot, value = act
-        new_acts['categorical'].append({"intent": intent, "domain": domain, "slot": slot, "value": value})
+        new_acts['categorical'].append(
+            {"intent": intent, "domain": domain, "slot": slot, "value": value})
 
     return new_acts
 
@@ -438,6 +498,7 @@ def act_dict_to_flat_tuple(acts):
 
 
 def create_goals(goal_generator, num_goals, single_domains=False, allowed_domains=None):
+    from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal
 
     collected_goals = []
     while len(collected_goals) != num_goals:
@@ -450,6 +511,47 @@ def create_goals(goal_generator, num_goals, single_domains=False, allowed_domain
     return collected_goals
 
 
+def build_domains_goal(goal_generator, domains=None):
+    """Sample goals until one covers exactly the requested domains.
+
+    Sampling repeats until a match is drawn, so a domain combination that is
+    never generated makes this loop forever.
+
+    :param goal_generator: generator handed to ``Goal`` for every draw
+    :param domains: iterable of domain names to match exactly, or ``None`` to
+        accept the first sampled goal
+    :return: the first sampled goal with ``set(goal.domain_goals) == set(domains)``
+    """
+    from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal
+    # Normalise once: a set never equals a list/tuple, so comparing against
+    # the raw argument could spin forever when callers pass a list.
+    wanted = None if domains is None else set(domains)
+    while True:
+        goal = Goal(goal_generator)
+        if wanted is None or set(goal.domain_goals) == wanted:
+            return goal
+
+
+def data_goals(num_goals, dataset="multiwoz21", dial_ids_order=0):
+    """Build one user goal per dialogue in the test split of a unified dataset.
+
+    :param num_goals: number of goals the caller expects; only used to warn
+        when the test split yields fewer (every goal is returned regardless)
+    :param dataset: dataset name passed to ``load_dataset``
+    :param dial_ids_order: dialogue-id ordering passed to ``load_dataset``
+    :return: list of goals, one per test dialogue, in dataset order
+    """
+    from convlab.policy.tus.unify.Goal import Goal
+    from convlab.policy.tus.unify.util import create_goal
+    data = load_dataset(dataset, dial_ids_order)
+    collected_goals = [Goal(create_goal(dialog)) for dialog in data["test"]]
+    if len(collected_goals) < num_goals:
+        # Report the count rather than interpolating the whole test split.
+        print(f"# of data goals ({len(collected_goals)}) < num_goals {num_goals}")
+    # reorder goals?
+    return collected_goals
+
+
 def map_class(cls_path: str):
     """
     Map to class via package text path
@@ -492,18 +579,21 @@ def get_config(filepath, args) -> dict:
     vec_name = [model for model in conf['vectorizer_sys']]
     vec_name = vec_name[0] if vec_name else None
     if dst_name and 'setsumbt' in dst_name.lower():
-        if 'get_confidence_scores' in conf['dst_sys'][dst_name]['ini_params']:
-            conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = conf['dst_sys'][dst_name]['ini_params']['get_confidence_scores']
+        if 'return_confidence_scores' in conf['dst_sys'][dst_name]['ini_params']:
+            param = conf['dst_sys'][dst_name]['ini_params']['return_confidence_scores']
+            conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = param
         else:
             conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = False
-        if 'return_mutual_info' in conf['dst_sys'][dst_name]['ini_params']:
-            conf['vectorizer_sys'][vec_name]['ini_params']['use_mutual_info'] = conf['dst_sys'][dst_name]['ini_params']['return_mutual_info']
+        if 'return_belief_state_mutual_info' in conf['dst_sys'][dst_name]['ini_params']:
+            param = conf['dst_sys'][dst_name]['ini_params']['return_belief_state_mutual_info']
+            conf['vectorizer_sys'][vec_name]['ini_params']['use_state_knowledge_uncertainty'] = param
         else:
-            conf['vectorizer_sys'][vec_name]['ini_params']['use_mutual_info'] = False
-        if 'return_entropy' in conf['dst_sys'][dst_name]['ini_params']:
-            conf['vectorizer_sys'][vec_name]['ini_params']['use_entropy'] = conf['dst_sys'][dst_name]['ini_params']['return_entropy']
+            conf['vectorizer_sys'][vec_name]['ini_params']['use_state_knowledge_uncertainty'] = False
+        if 'return_belief_state_entropy' in conf['dst_sys'][dst_name]['ini_params']:
+            param = conf['dst_sys'][dst_name]['ini_params']['return_belief_state_entropy']
+            conf['vectorizer_sys'][vec_name]['ini_params']['use_state_total_uncertainty'] = param
         else:
-            conf['vectorizer_sys'][vec_name]['ini_params']['use_entropy'] = False
+            conf['vectorizer_sys'][vec_name]['ini_params']['use_state_total_uncertainty'] = False
 
     from convlab.nlu import NLU
     from convlab.dst import DST
@@ -532,13 +622,10 @@ def get_config(filepath, args) -> dict:
                 cls_path = infos.get('class_path', '')
                 cls = map_class(cls_path)
                 conf[unit + '_class'] = cls
-                conf[unit + '_activated'] = conf[unit +
-                                                 '_class'](**conf[unit][model]['ini_params'])
+                conf[unit + '_activated'] = conf[unit + '_class'](**conf[unit][model]['ini_params'])
                 print("Loaded " + model + " for " + unit)
     return conf
 
 
 if __name__ == '__main__':
     get_goal_distribution()
-
-
diff --git a/convlab/util/multiwoz/lexicalize.py b/convlab/util/multiwoz/lexicalize.py
index 9eab25e26d82d7f20d92343a2ac64bb807239074..4fd9f262ece3fe38935c26bc449b1e93121db498 100755
--- a/convlab/util/multiwoz/lexicalize.py
+++ b/convlab/util/multiwoz/lexicalize.py
@@ -86,7 +86,8 @@ def lexicalize_da(meta, entities, state, requestable):
                             pair[1] = entities[domain][n][slot_old]
                         elif slot in state[domain]:
                             pair[1] = state[domain][slot]
-                        pair[1] = pair[1] if pair[1] else 'not available'
+                        else:
+                            pair[1] = 'not available'
                     elif slot in state[domain]:
                         pair[1] = state[domain][slot] if state[domain][slot] else 'none'
                     else:
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index aff31be6742be3908d5ff4ab65e141b3427471d9..726079d1c2b04c304bfd7055d48b1b4ae4905856 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -16,17 +16,19 @@ from tqdm import tqdm
 
 class BaseDatabase(ABC):
     """Base class of unified database. Should override the query function."""
+
     def __init__(self):
         """extract data.zip and load the database."""
 
     @abstractmethod
-    def query(self, domain:str, state:dict, topk:int, **kwargs)->list:
+    def query(self, domain: str, state: dict, topk: int, **kwargs) -> list:
         """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
 
+
 def download_unified_datasets(dataset_name, filename, data_dir):
     """
     It downloads the file of unified datasets from HuggingFace's datasets if it doesn't exist in the data directory
-    
+
     :param dataset_name: The name of the dataset
     :param filename: the name of the file you want to download
     :param data_dir: the directory where the file will be downloaded to
@@ -41,22 +43,26 @@ def download_unified_datasets(dataset_name, filename, data_dir):
         shutil.move(cache_path, data_path)
     return data_path
 
+
 def relative_import_module_from_unified_datasets(dataset_name, filename, names2import):
     """
     It downloads a file from the unified datasets repository, imports it as a module, and returns the
     variable(s) you want from that module
-    
+
     :param dataset_name: the name of the dataset, e.g. 'multiwoz21'
     :param filename: the name of the file to download, e.g. 'preprocess.py'
     :param names2import: a string or a list of strings. If it's a string, it's the name of the variable
     to import. If it's a list of strings, it's the names of the variables to import
     :return: the variable(s) that are being imported from the module.
     """
-    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(
+        __file__), f'../../../data/unified_datasets/{dataset_name}'))
     assert filename.endswith('.py')
-    assert isinstance(names2import, str) or (isinstance(names2import, list) and len(names2import) > 0)
+    assert isinstance(names2import, str) or (
+        isinstance(names2import, list) and len(names2import) > 0)
     data_path = download_unified_datasets(dataset_name, filename, data_dir)
-    module_spec = importlib.util.spec_from_file_location(filename[:-3], data_path)
+    module_spec = importlib.util.spec_from_file_location(
+        filename[:-3], data_path)
     module = importlib.util.module_from_spec(module_spec)
     module_spec.loader.exec_module(module)
     if isinstance(names2import, str):
@@ -67,7 +73,8 @@ def relative_import_module_from_unified_datasets(dataset_name, filename, names2i
             variables.append(eval(f'module.{name}'))
         return variables
 
-def load_dataset(dataset_name:str, dial_ids_order=None, split2ratio={}) -> Dict:
+
+def load_dataset(dataset_name: str, dial_ids_order=None, split2ratio={}) -> Dict:
     """load unified dataset from `data/unified_datasets/$dataset_name`
 
     Args:
@@ -79,7 +86,8 @@ def load_dataset(dataset_name:str, dial_ids_order=None, split2ratio={}) -> Dict:
     Returns:
         dataset (dict): keys are data splits and the values are lists of dialogues
     """
-    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(
+        __file__), f'../../../data/unified_datasets/{dataset_name}'))
     data_path = download_unified_datasets(dataset_name, 'data.zip', data_dir)
 
     archive = ZipFile(data_path)
@@ -87,11 +95,13 @@ def load_dataset(dataset_name:str, dial_ids_order=None, split2ratio={}) -> Dict:
         dialogues = json.loads(f.read())
     dataset = {}
     if dial_ids_order is not None:
-        data_path = download_unified_datasets(dataset_name, 'shuffled_dial_ids.json', data_dir)
+        data_path = download_unified_datasets(
+            dataset_name, 'shuffled_dial_ids.json', data_dir)
         dial_ids = json.load(open(data_path))[dial_ids_order]
         for data_split in dial_ids:
             ratio = split2ratio.get(data_split, 1)
-            dataset[data_split] = [dialogues[i] for i in dial_ids[data_split][:round(len(dial_ids[data_split])*ratio)]]
+            dataset[data_split] = [dialogues[i]
+                                   for i in dial_ids[data_split][:round(len(dial_ids[data_split])*ratio)]]
     else:
         for dialogue in dialogues:
             if dialogue['data_split'] not in dataset:
@@ -100,10 +110,12 @@ def load_dataset(dataset_name:str, dial_ids_order=None, split2ratio={}) -> Dict:
                 dataset[dialogue['data_split']].append(dialogue)
         for data_split in dataset:
             if data_split in split2ratio:
-                dataset[data_split] = dataset[data_split][:round(len(dataset[data_split])*split2ratio[data_split])]
+                dataset[data_split] = dataset[data_split][:round(
+                    len(dataset[data_split])*split2ratio[data_split])]
     return dataset
 
-def load_ontology(dataset_name:str) -> Dict:
+
+def load_ontology(dataset_name: str) -> Dict:
     """load unified ontology from `data/unified_datasets/$dataset_name`
 
     Args:
@@ -112,7 +124,8 @@ def load_ontology(dataset_name:str) -> Dict:
     Returns:
         ontology (dict): dataset ontology
     """
-    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(
+        __file__), f'../../../data/unified_datasets/{dataset_name}'))
     data_path = download_unified_datasets(dataset_name, 'data.zip', data_dir)
 
     archive = ZipFile(data_path)
@@ -120,7 +133,8 @@ def load_ontology(dataset_name:str) -> Dict:
         ontology = json.loads(f.read())
     return ontology
 
-def load_database(dataset_name:str):
+
+def load_database(dataset_name: str):
     """load database from `data/unified_datasets/$dataset_name`
 
     Args:
@@ -129,36 +143,40 @@ def load_database(dataset_name:str):
     Returns:
         database: an instance of BaseDatabase
     """
-    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
-    data_path = download_unified_datasets(dataset_name, 'database.py', data_dir)
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(
+        __file__), f'../../../data/unified_datasets/{dataset_name}'))
+    data_path = download_unified_datasets(
+        dataset_name, 'database.py', data_dir)
     module_spec = importlib.util.spec_from_file_location('database', data_path)
     module = importlib.util.module_from_spec(module_spec)
     module_spec.loader.exec_module(module)
-    Database = relative_import_module_from_unified_datasets(dataset_name, 'database.py', 'Database')
+    Database = relative_import_module_from_unified_datasets(
+        dataset_name, 'database.py', 'Database')
     assert issubclass(Database, BaseDatabase)
     database = Database()
     assert isinstance(database, BaseDatabase)
     return database
 
+
 def load_unified_data(
-        dataset, 
-        data_split='all', 
-        speaker='all', 
-        utterance=False, 
-        dialogue_acts=False, 
-        state=False, 
-        db_results=False,
-        use_context=False, 
-        context_window_size=0, 
-        terminated=False, 
-        goal=False, 
-        active_domains=False,
-        split_to_turn=True
-    ):
+    dataset,
+    data_split='all',
+    speaker='all',
+    utterance=False,
+    dialogue_acts=False,
+    state=False,
+    db_results=False,
+    use_context=False,
+    context_window_size=0,
+    terminated=False,
+    goal=False,
+    active_domains=False,
+    split_to_turn=True
+):
     """
     > This function takes in a dataset, and returns a dictionary of data splits, where each data split
     is a list of samples
-    
+
     :param dataset: dataset object from `load_dataset`
     :param data_split: which split of the data to load. Can be 'train', 'validation', 'test', or 'all',
     defaults to all (optional)
@@ -182,7 +200,8 @@ def load_unified_data(
     data_splits = dataset.keys() if data_split == 'all' else [data_split]
     assert speaker in ['user', 'system', 'all']
     assert not use_context or context_window_size > 0
-    info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results']))
+    info_list = list(
+        filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results']))
     info_list += ['utt_idx']
     data_by_split = {}
     for data_split in data_splits:
@@ -194,7 +213,7 @@ def load_unified_data(
                 for ele in info_list:
                     if ele in turn:
                         sample[ele] = turn[ele]
-                
+
                 if use_context or not split_to_turn:
                     sample_copy = deepcopy(sample)
                     context.append(sample_copy)
@@ -207,7 +226,8 @@ def load_unified_data(
                     if active_domains:
                         sample['domains'] = dialogue['domains']
                     if terminated:
-                        sample['terminated'] = turn['utt_idx'] == len(dialogue['turns']) - 1
+                        sample['terminated'] = turn['utt_idx'] == len(
+                            dialogue['turns']) - 1
                     if speaker == 'system' and 'booked' in turn:
                         sample['booked'] = turn['booked']
                     data_by_split[data_split].append(sample)
@@ -221,7 +241,7 @@ def load_nlu_data(dataset, data_split='all', speaker='user', use_context=False,
     """
     It loads the data from the specified dataset, and returns it in a format that is suitable for
     training a NLU model
-    
+
     :param dataset: dataset object from `load_dataset`
     :param data_split: 'train', 'validation', 'test', or 'all', defaults to all (optional)
     :param speaker: 'user' or 'system', defaults to user (optional)
@@ -243,7 +263,7 @@ def load_dst_data(dataset, data_split='all', speaker='user', context_window_size
     """
     It loads the data from the specified dataset, with the specified data split, speaker, context window
     size, suitable for training a DST model
-    
+
     :param dataset: dataset object from `load_dataset`
     :param data_split: 'train', 'validation', 'test', or 'all', defaults to all (optional)
     :param speaker: 'user' or 'system', defaults to user (optional)
@@ -259,11 +279,12 @@ def load_dst_data(dataset, data_split='all', speaker='user', context_window_size
     kwargs.setdefault('state', True)
     return load_unified_data(dataset, **kwargs)
 
+
 def load_policy_data(dataset, data_split='all', speaker='system', context_window_size=1, **kwargs):
     """
     It loads the data from the specified dataset, and returns it in a format that is suitable for
     training a policy
-    
+
     :param dataset: dataset object from `load_dataset`
     :param data_split: 'train', 'validation', 'test', or 'all', defaults to all (optional)
     :param speaker: 'system' or 'user', defaults to system (optional)
@@ -283,11 +304,12 @@ def load_policy_data(dataset, data_split='all', speaker='system', context_window
     kwargs.setdefault('terminated', True)
     return load_unified_data(dataset, **kwargs)
 
+
 def load_nlg_data(dataset, data_split='all', speaker='system', use_context=False, context_window_size=0, **kwargs):
     """
     It loads the data from the specified dataset, and returns it in a format that is suitable for
     training a NLG model
-    
+
     :param dataset: dataset object from `load_dataset`
     :param data_split: 'train', 'validation', 'test', or 'all', defaults to all (optional)
     :param speaker: 'system' or 'user', defaults to system (optional)
@@ -304,11 +326,12 @@ def load_nlg_data(dataset, data_split='all', speaker='system', use_context=False
     kwargs.setdefault('dialogue_acts', True)
     return load_unified_data(dataset, **kwargs)
 
+
 def load_e2e_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs):
     """
     It loads the data from the specified dataset, and returns it in a format that is suitable for
     training an End2End model
-    
+
     :param dataset: dataset object from `load_dataset`
     :param data_split: 'train', 'validation', 'test', or 'all', defaults to all (optional)
     :param speaker: 'system' or 'user', defaults to system (optional)
@@ -327,11 +350,12 @@ def load_e2e_data(dataset, data_split='all', speaker='system', context_window_si
     kwargs.setdefault('dialogue_acts', True)
     return load_unified_data(dataset, **kwargs)
 
+
 def load_rg_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs):
     """
     It loads the data from the dataset, and returns it in a format that is suitable for training a 
     response generation model
-    
+
     :param dataset: dataset object from `load_dataset`
     :param data_split: 'train', 'validation', 'test', or 'all', defaults to all (optional)
     :param speaker: 'system' or 'user', defaults to system (optional)
@@ -347,7 +371,7 @@ def load_rg_data(dataset, data_split='all', speaker='system', context_window_siz
     return load_unified_data(dataset, **kwargs)
 
 
-def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore_values=['yes', 'no']):
+def create_delex_data(dataset, delex_func=lambda d, s, v: f'[({d})-({s})]', ignore_values=['yes', 'no']):
     """add delex_utterance to the dataset according to dialogue acts and belief_state
     delex_func: function that return the placeholder (e.g. "[(domain_name)-(slot_name)]") given (domain, slot, value)
     ignore_values: ignored values when delexicalizing using the categorical acts and states
@@ -357,7 +381,7 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
         It takes a list of strings and placeholders, and a regex pattern. If the pattern matches exactly
         one string, it replaces that string with a placeholder and returns True. Otherwise, it returns
         False
-        
+
         :param texts_placeholders: a list of tuples, each tuple is a string and a boolean. The boolean
         indicates whether the string is a placeholder or not
         :param value_pattern: a regular expression that matches the value to be delexicalized
@@ -379,7 +403,8 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
             searchObj = re.search(value_pattern, substring)
             assert searchObj
             start, end = searchObj.span(1)
-            texts_placeholders[idx:idx+1] = [(substring[0:start], False), (placeholder, True), (substring[end:], False)]
+            texts_placeholders[idx:idx+1] = [
+                (substring[0:start], False), (placeholder, True), (substring[end:], False)]
             return True
         return False
 
@@ -392,7 +417,8 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
                 delex_utt = []
                 last_end = 0
                 # ignore the non-categorical das that do not have span annotation
-                spans = [x for x in turn['dialogue_acts']['non-categorical'] if 'start' in x]
+                spans = [x for x in turn['dialogue_acts']
+                         ['non-categorical'] if 'start' in x]
                 for da in sorted(spans, key=lambda x: x['start']):
                     # from left to right
                     start, end = da['start'], da['end']
@@ -412,7 +438,8 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
                     domain, slot, value = da['domain'], da['slot'], da['value']
                     if value.lower() not in ignore_values:
                         placeholder = delex_func(domain, slot, value)
-                        pattern = re.compile(r'\b({})\b'.format(value), flags=re.I)
+                        pattern = re.compile(
+                            r'\b({})\b'.format(value), flags=re.I)
                         if delex_inplace(delex_utt, pattern):
                             delex_vocab.add(placeholder)
 
@@ -425,13 +452,15 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
                             # has value
                             for value in values.split('|'):
                                 if value.lower() not in ignore_values:
-                                    placeholder = delex_func(domain, slot, value)
-                                    pattern = re.compile(r'\b({})\b'.format(value), flags=re.I)
+                                    placeholder = delex_func(
+                                        domain, slot, value)
+                                    pattern = re.compile(
+                                        r'\b({})\b'.format(value), flags=re.I)
                                     if delex_inplace(delex_utt, pattern):
                                         delex_vocab.add(placeholder)
 
                 turn['delex_utterance'] = ''.join([x[0] for x in delex_utt])
-    
+
     return dataset, sorted(list(delex_vocab))
 
 
@@ -468,7 +497,8 @@ def retrieve_utterances(query_turns, turn_pool, top_k, model_name):
 if __name__ == "__main__":
     dataset = load_dataset('multiwoz21', dial_ids_order=0)
     train_ratio = 0.1
-    dataset['train'] = dataset['train'][:round(len(dataset['train'])*train_ratio)]
+    dataset['train'] = dataset['train'][:round(
+        len(dataset['train'])*train_ratio)]
     print(len(dataset['train']))
     print(dataset.keys())
     print(len(dataset['test']))
@@ -477,7 +507,7 @@ if __name__ == "__main__":
     database = load_database('multiwoz21')
     res = database.query("train", {'train':{'departure':'cambridge', 'destination':'peterborough', 'day':'tuesday', 'arrive by':'11:15'}}, topk=3)
     print(res[0], len(res))
-    
+
     data_by_split = load_nlu_data(dataset, data_split='test', speaker='user')
     query_turns = data_by_split['test'][:10]
     pool_dataset = load_dataset('camrest')
@@ -490,8 +520,10 @@ if __name__ == "__main__":
         return f'[{slot}]'
 
     dataset, delex_vocab = create_delex_data(dataset, delex_slot)
-    json.dump(dataset['test'], open('new_delex_multiwoz21_test.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
-    json.dump(delex_vocab, open('new_delex_vocab.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(dataset['test'], open('new_delex_multiwoz21_test.json',
+              'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(delex_vocab, open('new_delex_vocab.json', 'w',
+              encoding='utf-8'), indent=2, ensure_ascii=False)
     with open('new_delex_cmp.txt', 'w') as f:
         for dialog in dataset['test']:
             for turn in dialog['turns']: