diff --git a/DO.example.advanced b/DO.example.advanced
new file mode 100755
index 0000000000000000000000000000000000000000..d705f5088ce85227ffb0eb85e769ea473ace3139
--- /dev/null
+++ b/DO.example.advanced
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+# Parameters ------------------------------------------------------
+
+#TASK="sim-m"
+#DATA_DIR="data/simulated-dialogue/sim-M"
+#TASK="sim-r"
+#DATA_DIR="data/simulated-dialogue/sim-R"
+TASK="woz2"
+DATA_DIR="data/woz2"
+#TASK="multiwoz21"
+#DATA_DIR="data/MULTIWOZ2.1"
+
+# Project paths etc. ----------------------------------------------
+
+OUT_DIR=results
+mkdir -p ${OUT_DIR}
+
+# Main ------------------------------------------------------------
+
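+# Train first, then run prediction on the dev and test sets.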
+for step in train dev test; do
+    args_add=""
+    if [ "$step" = "train" ]; then
+	args_add="--do_train --predict_type=dummy"
+    elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+	args_add="--do_eval --predict_type=${step}"
+    fi
+
+    python3 run_dst.py \
+	    --task_name=${TASK} \
+	    --data_dir=${DATA_DIR} \
+	    --dataset_config=dataset_config/${TASK}.json \
+	    --model_type="bert" \
+	    --model_name_or_path="bert-base-uncased" \
+	    --do_lower_case \
+	    --learning_rate=1e-4 \
+	    --num_train_epochs=10 \
+	    --max_seq_length=180 \
+	    --per_gpu_train_batch_size=48 \
+	    --per_gpu_eval_batch_size=1 \
+	    --output_dir=${OUT_DIR} \
+	    --save_epochs=2 \
+	    --logging_steps=10 \
+	    --warmup_proportion=0.1 \
+	    --eval_all_checkpoints \
+	    --adam_epsilon=1e-6 \
+	    --label_value_repetitions \
+	    --swap_utterances \
+	    --append_history \
+	    --use_history_labels \
+	    --delexicalize_sys_utts \
+	    --class_aux_feats_inform \
+	    --class_aux_feats_ds \
+	    ${args_add} \
+	    2>&1 | tee ${OUT_DIR}/${step}.log
+
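+    # Compute evaluation metrics from the prediction files written by run_dst.py.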
+    if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+    	python3 ~/tools/trippy_clean/metric_bert_dst.py \
+    		${TASK} \
+		dataset_config/${TASK}.json \
+    		"${OUT_DIR}/pred_res.${step}*json" \
+    		2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
+    fi
+done
diff --git a/DO.example.simple b/DO.example.simple
new file mode 100755
index 0000000000000000000000000000000000000000..43bb891897e2fec9198f6fed631c50b9a1b10f72
--- /dev/null
+++ b/DO.example.simple
@@ -0,0 +1,59 @@
+#!/bin/bash
+
+# Parameters ------------------------------------------------------
+
+#TASK="sim-m"
+#DATA_DIR="data/simulated-dialogue/sim-M"
+#TASK="sim-r"
+#DATA_DIR="data/simulated-dialogue/sim-R"
+TASK="woz2"
+DATA_DIR="data/woz2"
+#TASK="multiwoz21"
+#DATA_DIR="data/MULTIWOZ2.1"
+
+# Project paths etc. ----------------------------------------------
+
+OUT_DIR=results
+mkdir -p ${OUT_DIR}
+
+# Main ------------------------------------------------------------
+
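+# Train first, then run prediction on the dev and test sets.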
+for step in train dev test; do
+    args_add=""
+    if [ "$step" = "train" ]; then
+	args_add="--do_train --predict_type=dummy"
+    elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+	args_add="--do_eval --predict_type=${step}"
+    fi
+
+    python3 run_dst.py \
+	    --task_name=${TASK} \
+	    --data_dir=${DATA_DIR} \
+	    --dataset_config=dataset_config/${TASK}.json \
+	    --model_type="bert" \
+	    --model_name_or_path="bert-base-uncased" \
+	    --do_lower_case \
+	    --learning_rate=1e-4 \
+	    --num_train_epochs=10 \
+	    --max_seq_length=180 \
+	    --per_gpu_train_batch_size=48 \
+	    --per_gpu_eval_batch_size=1 \
+	    --output_dir=${OUT_DIR} \
+	    --save_epochs=2 \
+	    --logging_steps=10 \
+	    --warmup_proportion=0.1 \
+	    --eval_all_checkpoints \
+	    --adam_epsilon=1e-6 \
+	    --label_value_repetitions \
+	    ${args_add} \
+	    2>&1 | tee ${OUT_DIR}/${step}.log
+
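+    # Compute evaluation metrics from the prediction files written by run_dst.py.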
+    if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+    	python3 ~/tools/trippy_clean/metric_bert_dst.py \
+    		${TASK} \
+		dataset_config/${TASK}.json \
+    		"${OUT_DIR}/pred_res.${step}*json" \
+    		2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
+    fi
+done
diff --git a/data_processors.py b/data_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d7670df0b5a0734f05499fab24bf4a137edb652
--- /dev/null
+++ b/data_processors.py
@@ -0,0 +1,94 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+
+import dataset_woz2
+import dataset_sim
+import dataset_multiwoz21
+
+
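+# Base class for dataset processors; reads the class types, slot list and
+# label maps from a dataset config JSON (see dataset_config/).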
+class DataProcessor(object):
+    def __init__(self, dataset_config):
+        with open(dataset_config, "r", encoding='utf-8') as f:
+            raw_config = json.load(f)
+        self.class_types = raw_config['class_types']
+        self.slot_list = raw_config['slots']
+        self.label_maps = raw_config['label_maps']
+
+    def get_train_examples(self, data_dir, **args):
+        raise NotImplementedError()
+
+    def get_dev_examples(self, data_dir, **args):
+        raise NotImplementedError()
+
+    def get_test_examples(self, data_dir, **args):
+        raise NotImplementedError()
+
+
+class Woz2Processor(DataProcessor):
+    def get_train_examples(self, data_dir, args):
+        return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_train_en.json'),
+                                            'train', self.slot_list, self.label_maps, **args)
+
+    def get_dev_examples(self, data_dir, args):
+        return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_validate_en.json'),
+                                            'dev', self.slot_list, self.label_maps, **args)
+
+    def get_test_examples(self, data_dir, args):
+        return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_test_en.json'),
+                                            'test', self.slot_list, self.label_maps, **args)
+
+
+class Multiwoz21Processor(DataProcessor):
+    def get_train_examples(self, data_dir, args):
+        return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'train_dials.json'),
+                                                  os.path.join(data_dir, 'dialogue_acts.json'),
+                                                  'train', self.slot_list, self.label_maps, **args)
+
+    def get_dev_examples(self, data_dir, args):
+        return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'val_dials.json'),
+                                                  os.path.join(data_dir, 'dialogue_acts.json'),
+                                                  'dev', self.slot_list, self.label_maps, **args)
+
+    def get_test_examples(self, data_dir, args):
+        return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'test_dials.json'),
+                                                  os.path.join(data_dir, 'dialogue_acts.json'),
+                                                  'test', self.slot_list, self.label_maps, **args)
+
+
+class SimProcessor(DataProcessor):
+    def get_train_examples(self, data_dir, args):
+        return dataset_sim.create_examples(os.path.join(data_dir, 'train.json'),
+                                           'train', self.slot_list, **args)
+
+    def get_dev_examples(self, data_dir, args):
+        return dataset_sim.create_examples(os.path.join(data_dir, 'dev.json'),
+                                           'dev', self.slot_list, **args)
+
+    def get_test_examples(self, data_dir, args):
+        return dataset_sim.create_examples(os.path.join(data_dir, 'test.json'),
+                                           'test', self.slot_list, **args)
+
+
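+# Registry mapping task names to their dataset processors.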
+PROCESSORS = {"woz2": Woz2Processor,
+              "sim-m": SimProcessor,
+              "sim-r": SimProcessor,
+              "multiwoz21": Multiwoz21Processor}
diff --git a/dataset_config/multiwoz21.json b/dataset_config/multiwoz21.json
new file mode 100644
index 0000000000000000000000000000000000000000..2fff3bb30b60d48d623c9792ebf7a0caf70e2aaf
--- /dev/null
+++ b/dataset_config/multiwoz21.json
@@ -0,0 +1,1297 @@
+{
+  "class_types": [
+    "none",
+    "dontcare",
+    "copy_value",
+    "true",
+    "false",
+    "refer",
+    "inform"
+  ],
+  "slots": [
+    "taxi-leaveAt",
+    "taxi-destination",
+    "taxi-departure",
+    "taxi-arriveBy",
+    "restaurant-book_people",
+    "restaurant-book_day",
+    "restaurant-book_time",
+    "restaurant-food",
+    "restaurant-pricerange",
+    "restaurant-name",
+    "restaurant-area",
+    "hotel-book_people",
+    "hotel-book_day",
+    "hotel-book_stay",
+    "hotel-name",
+    "hotel-area",
+    "hotel-parking",
+    "hotel-pricerange",
+    "hotel-stars",
+    "hotel-internet",
+    "hotel-type",
+    "attraction-type",
+    "attraction-name",
+    "attraction-area",
+    "train-book_people",
+    "train-leaveAt",
+    "train-destination",
+    "train-day",
+    "train-arriveBy",
+    "train-departure"
+  ],
+  "label_maps": {
+    "guest house": [
+      "guest houses"
+    ],
+    "hotel": [
+      "hotels"
+    ],
+    "centre": [
+      "center",
+      "downtown"
+    ],
+    "north": [
+      "northern",
+      "northside",
+      "northend"
+    ],
+    "east": [
+      "eastern",
+      "eastside",
+      "eastend"
+    ],
+    "west": [
+      "western",
+      "westside",
+      "westend"
+    ],
+    "south": [
+      "southern",
+      "southside",
+      "southend"
+    ],
+    "cheap": [
+      "inexpensive",
+      "lower price",
+      "lower range",
+      "cheaply",
+      "cheaper",
+      "cheapest",
+      "very affordable"
+    ],
+    "moderate": [
+      "moderately",
+      "reasonable",
+      "reasonably",
+      "affordable",
+      "mid range",
+      "mid-range",
+      "priced moderately",
+      "decently priced",
+      "mid price",
+      "mid-price",
+      "mid priced",
+      "mid-priced",
+      "middle price",
+      "medium price",
+      "medium priced",
+      "not too expensive",
+      "not too cheap"
+    ],
+    "expensive": [
+      "high end",
+      "high-end",
+      "high class",
+      "high-class",
+      "high scale",
+      "high-scale",
+      "high price",
+      "high priced",
+      "higher price",
+      "fancy",
+      "upscale",
+      "nice",
+      "expensively",
+      "luxury"
+    ],
+    "0": [
+      "zero"
+    ],
+    "1": [
+      "one",
+      "just me",
+      "for me",
+      "myself",
+      "alone",
+      "me"
+    ],
+    "2": [
+      "two"
+    ],
+    "3": [
+      "three"
+    ],
+    "4": [
+      "four"
+    ],
+    "5": [
+      "five"
+    ],
+    "6": [
+      "six"
+    ],
+    "7": [
+      "seven"
+    ],
+    "8": [
+      "eight"
+    ],
+    "9": [
+      "nine"
+    ],
+    "10": [
+      "ten"
+    ],
+    "11": [
+      "eleven"
+    ],
+    "12": [
+      "twelve"
+    ],
+    "architecture": [
+      "architectural",
+      "architecturally",
+      "architect"
+    ],
+    "boat": [
+      "boating",
+      "boats",
+      "camboats"
+    ],
+    "boating": [
+      "boat",
+      "boats",
+      "camboats"
+    ],
+    "camboats": [
+      "boating",
+      "boat",
+      "boats"
+    ],
+    "cinema": [
+      "cinemas",
+      "movie",
+      "films",
+      "film"
+    ],
+    "college": [
+      "colleges"
+    ],
+    "concert": [
+      "concert hall",
+      "concert halls",
+      "concerthall",
+      "concerthalls",
+      "concerts"
+    ],
+    "concerthall": [
+      "concert hall",
+      "concert halls",
+      "concerthalls",
+      "concerts",
+      "concert"
+    ],
+    "entertainment": [
+      "entertaining"
+    ],
+    "gallery": [
+      "museum"
+    ],
+    "gastropubs": [
+      "gastropub"
+    ],
+    "multiple sports": [
+      "multiple sport",
+      "multi sport",
+      "multi sports",
+      "sports",
+      "sporting"
+    ],
+    "museum": [
+      "museums",
+      "gallery",
+      "galleries"
+    ],
+    "night club": [
+      "night clubs",
+      "nightclub",
+      "nightclubs",
+      "club",
+      "clubs"
+    ],
+    "park": [
+      "parks"
+    ],
+    "pool": [
+      "swimming pool",
+      "swimming",
+      "pools",
+      "swimmingpool",
+      "water",
+      "swim"
+    ],
+    "sports": [
+      "multiple sport",
+      "multi sport",
+      "multi sports",
+      "multiple sports",
+      "sporting"
+    ],
+    "swimming pool": [
+      "swimming",
+      "pool",
+      "pools",
+      "swimmingpool",
+      "water",
+      "swim"
+    ],
+    "theater": [
+      "theatre",
+      "theatres",
+      "theaters"
+    ],
+    "theatre": [
+      "theater",
+      "theatres",
+      "theaters"
+    ],
+    "abbey pool and astroturf pitch": [
+      "abbey pool and astroturf",
+      "abbey pool"
+    ],
+    "abbey pool and astroturf": [
+      "abbey pool and astroturf pitch",
+      "abbey pool"
+    ],
+    "abbey pool": [
+      "abbey pool and astroturf pitch",
+      "abbey pool and astroturf"
+    ],
+    "adc theatre": [
+      "adc theater",
+      "adc"
+    ],
+    "adc": [
+      "adc theatre",
+      "adc theater"
+    ],
+    "addenbrookes hospital": [
+      "addenbrooke's hospital"
+    ],
+    "cafe jello gallery": [
+      "cafe jello"
+    ],
+    "cambridge and county folk museum": [
+      "cambridge and country folk museum",
+      "county folk museum"
+    ],
+    "cambridge and country folk museum": [
+      "cambridge and county folk museum",
+      "county folk museum"
+    ],
+    "county folk museum": [
+      "cambridge and county folk museum",
+      "cambridge and country folk museum"
+    ],
+    "cambridge arts theatre": [
+      "cambridge arts theater"
+    ],
+    "cambridge book and print gallery": [
+      "book and print gallery"
+    ],
+    "cambridge contemporary art": [
+      "cambridge contemporary art museum",
+      "contemporary art museum"
+    ],
+    "cambridge contemporary art museum": [
+      "cambridge contemporary art",
+      "contemporary art museum"
+    ],
+    "cambridge corn exchange": [
+      "the cambridge corn exchange"
+    ],
+    "the cambridge corn exchange": [
+      "cambridge corn exchange"
+    ],
+    "cambridge museum of technology": [
+      "museum of technology"
+    ],
+    "cambridge punter": [
+      "the cambridge punter",
+      "cambridge punters"
+    ],
+    "cambridge punters": [
+      "the cambridge punter",
+      "cambridge punter"
+    ],
+    "the cambridge punter": [
+      "cambridge punter",
+      "cambridge punters"
+    ],
+    "cambridge university botanic gardens": [
+      "cambridge university botanical gardens",
+      "cambridge university botanical garden",
+      "cambridge university botanic garden",
+      "cambridge botanic gardens",
+      "cambridge botanical gardens",
+      "cambridge botanic garden",
+      "cambridge botanical garden",
+      "botanic gardens",
+      "botanical gardens",
+      "botanic garden",
+      "botanical garden"
+    ],
+    "cambridge botanic gardens": [
+      "cambridge university botanic gardens",
+      "cambridge university botanical gardens",
+      "cambridge university botanical garden",
+      "cambridge university botanic garden",
+      "cambridge botanical gardens",
+      "cambridge botanic garden",
+      "cambridge botanical garden",
+      "botanic gardens",
+      "botanical gardens",
+      "botanic garden",
+      "botanical garden"
+    ],
+    "botanic gardens": [
+      "cambridge university botanic gardens",
+      "cambridge university botanical gardens",
+      "cambridge university botanical garden",
+      "cambridge university botanic garden",
+      "cambridge botanic gardens",
+      "cambridge botanical gardens",
+      "cambridge botanic garden",
+      "cambridge botanical garden",
+      "botanical gardens",
+      "botanic garden",
+      "botanical garden"
+    ],
+    "cherry hinton village centre": [
+      "cherry hinton village center"
+    ],
+    "cherry hinton village center": [
+      "cherry hinton village centre"
+    ],
+    "cherry hinton hall and grounds": [
+      "cherry hinton hall"
+    ],
+    "cherry hinton hall": [
+      "cherry hinton hall and grounds"
+    ],
+    "cherry hinton water play": [
+      "cherry hinton water play park"
+    ],
+    "cherry hinton water play park": [
+      "cherry hinton water play"
+    ],
+    "christ college": [
+      "christ's college",
+      "christs college"
+    ],
+    "christs college": [
+      "christ college",
+      "christ's college"
+    ],
+    "churchills college": [
+      "churchill's college",
+      "churchill college"
+    ],
+    "cineworld cinema": [
+      "cineworld"
+    ],
+    "clair hall": [
+      "clare hall"
+    ],
+    "clare hall": [
+      "clair hall"
+    ],
+    "the fez club": [
+      "fez club"
+    ],
+    "great saint marys church": [
+      "great saint mary's church",
+      "great saint mary's",
+      "great saint marys"
+    ],
+    "jesus green outdoor pool": [
+      "jesus green"
+    ],
+    "jesus green": [
+      "jesus green outdoor pool"
+    ],
+    "kettles yard": [
+      "kettle's yard"
+    ],
+    "kings college": [
+      "king's college"
+    ],
+    "kings hedges learner pool": [
+      "king's hedges learner pool",
+      "king hedges learner pool"
+    ],
+    "king hedges learner pool": [
+      "king's hedges learner pool",
+      "kings hedges learner pool"
+    ],
+    "little saint marys church": [
+      "little saint mary's church",
+      "little saint mary's",
+      "little saint marys"
+    ],
+    "mumford theatre": [
+      "mumford theater"
+    ],
+    "museum of archaelogy": [
+      "museum of archaeology"
+    ],
+    "museum of archaelogy and anthropology": [
+      "museum of archaeology and anthropology"
+    ],
+    "peoples portraits exhibition": [
+      "people's portraits exhibition at girton college",
+      "peoples portraits exhibition at girton college",
+      "people's portraits exhibition"
+    ],
+    "peoples portraits exhibition at girton college": [
+      "people's portraits exhibition at girton college",
+      "people's portraits exhibition",
+      "peoples portraits exhibition"
+    ],
+    "queens college": [
+      "queens' college",
+      "queen's college"
+    ],
+    "riverboat georgina": [
+      "riverboat"
+    ],
+    "saint barnabas": [
+      "saint barbabas press gallery"
+    ],
+    "saint barnabas press gallery": [
+      "saint barbabas"
+    ],
+    "saint catharines college": [
+      "saint catharine's college",
+      "saint catharine's"
+    ],
+    "saint johns college": [
+      "saint john's college",
+      "st john's college",
+      "st johns college"
+    ],
+    "scott polar": [
+      "scott polar museum"
+    ],
+    "scott polar museum": [
+      "scott polar"
+    ],
+    "scudamores punting co": [
+      "scudamore's punting co",
+      "scudamores punting",
+      "scudamore's punting",
+      "scudamores",
+      "scudamore's",
+      "scudamore"
+    ],
+    "scudamore": [
+      "scudamore's punting co",
+      "scudamores punting co",
+      "scudamores punting",
+      "scudamore's punting",
+      "scudamores",
+      "scudamore's"
+    ],
+    "sheeps green and lammas land park fen causeway": [
+      "sheep's green and lammas land park fen causeway",
+      "sheep's green and lammas land park",
+      "sheeps green and lammas land park",
+      "lammas land park",
+      "sheep's green",
+      "sheeps green"
+    ],
+    "sheeps green and lammas land park": [
+      "sheep's green and lammas land park fen causeway",
+      "sheeps green and lammas land park fen causeway",
+      "sheep's green and lammas land park",
+      "lammas land park",
+      "sheep's green",
+      "sheeps green"
+    ],
+    "lammas land park": [
+      "sheep's green and lammas land park fen causeway",
+      "sheeps green and lammas land park fen causeway",
+      "sheep's green and lammas land park",
+      "sheeps green and lammas land park",
+      "sheep's green",
+      "sheeps green"
+    ],
+    "sheeps green": [
+      "sheep's green and lammas land park fen causeway",
+      "sheeps green and lammas land park fen causeway",
+      "sheep's green and lammas land park",
+      "sheeps green and lammas land park",
+      "lammas land park",
+      "sheep's green"
+    ],
+    "soul tree nightclub": [
+      "soul tree night club",
+      "soul tree",
+      "soultree"
+    ],
+    "soultree": [
+      "soul tree nightclub",
+      "soul tree night club",
+      "soul tree"
+    ],
+    "the man on the moon": [
+      "man on the moon"
+    ],
+    "man on the moon": [
+      "the man on the moon"
+    ],
+    "the junction": [
+      "junction theatre",
+      "junction theater"
+    ],
+    "junction theatre": [
+      "the junction",
+      "junction theater"
+    ],
+    "old schools": [
+      "old school"
+    ],
+    "vue cinema": [
+      "vue"
+    ],
+    "wandlebury country park": [
+      "the wandlebury"
+    ],
+    "the wandlebury": [
+      "wandlebury country park"
+    ],
+    "whipple museum of the history of science": [
+      "whipple museum",
+      "history of science museum"
+    ],
+    "history of science museum": [
+      "whipple museum of the history of science",
+      "whipple museum"
+    ],
+    "williams art and antique": [
+      "william's art and antique"
+    ],
+    "alimentum": [
+      "restaurant alimentum"
+    ],
+    "restaurant alimentum": [
+      "alimentum"
+    ],
+    "bedouin": [
+      "the bedouin"
+    ],
+    "the bedouin": [
+      "bedouin"
+    ],
+    "bloomsbury restaurant": [
+      "bloomsbury"
+    ],
+    "cafe uno": [
+      "caffe uno",
+      "caffee uno"
+    ],
+    "caffe uno": [
+      "cafe uno",
+      "caffee uno"
+    ],
+    "caffee uno": [
+      "cafe uno",
+      "caffe uno"
+    ],
+    "cambridge lodge restaurant": [
+      "cambridge lodge"
+    ],
+    "chiquito": [
+      "chiquito restaurant bar",
+      "chiquito restaurant"
+    ],
+    "chiquito restaurant bar": [
+      "chiquito restaurant",
+      "chiquito"
+    ],
+    "city stop restaurant": [
+      "city stop"
+    ],
+    "cityr": [
+      "cityroomz"
+    ],
+    "citiroomz": [
+      "cityroomz"
+    ],
+    "clowns cafe": [
+      "clown's cafe"
+    ],
+    "cow pizza kitchen and bar": [
+      "the cow pizza kitchen and bar",
+      "cow pizza"
+    ],
+    "the cow pizza kitchen and bar": [
+      "cow pizza kitchen and bar",
+      "cow pizza"
+    ],
+    "darrys cookhouse and wine shop": [
+      "darry's cookhouse and wine shop",
+      "darry's cookhouse",
+      "darrys cookhouse"
+    ],
+    "de luca cucina and bar": [
+      "de luca cucina and bar riverside brasserie",
+      "luca cucina and bar",
+      "de luca cucina",
+      "luca cucina"
+    ],
+    "de luca cucina and bar riverside brasserie": [
+      "de luca cucina and bar",
+      "luca cucina and bar",
+      "de luca cucina",
+      "luca cucina"
+    ],
+    "da vinci pizzeria": [
+      "da vinci pizza",
+      "da vinci"
+    ],
+    "don pasquale pizzeria": [
+      "don pasquale pizza",
+      "don pasquale",
+      "pasquale pizzeria",
+      "pasquale pizza"
+    ],
+    "efes": [
+      "efes restaurant"
+    ],
+    "efes restaurant": [
+      "efes"
+    ],
+    "fitzbillies restaurant": [
+      "fitzbillies"
+    ],
+    "frankie and bennys": [
+      "frankie and benny's"
+    ],
+    "funky": [
+      "funky fun house"
+    ],
+    "funky fun house": [
+      "funky"
+    ],
+    "gardenia": [
+      "the gardenia"
+    ],
+    "the gardenia": [
+      "gardenia"
+    ],
+    "grafton hotel restaurant": [
+      "the grafton hotel",
+      "grafton hotel"
+    ],
+    "the grafton hotel": [
+      "grafton hotel restaurant",
+      "grafton hotel"
+    ],
+    "grafton hotel": [
+      "grafton hotel restaurant",
+      "the grafton hotel"
+    ],
+    "hotel du vin and bistro": [
+      "hotel du vin",
+      "du vin"
+    ],
+    "Kohinoor": [
+      "kohinoor",
+      "the kohinoor"
+    ],
+    "kohinoor": [
+      "the kohinoor"
+    ],
+    "the kohinoor": [
+      "kohinoor"
+    ],
+    "lan hong house": [
+      "lan hong",
+      "ian hong house",
+      "ian hong"
+    ],
+    "ian hong": [
+      "lan hong house",
+      "lan hong",
+      "ian hong house"
+    ],
+    "lovel": [
+      "the lovell lodge",
+      "lovell lodge"
+    ],
+    "lovell lodge": [
+      "lovell"
+    ],
+    "the lovell lodge": [
+      "lovell lodge",
+      "lovell"
+    ],
+    "mahal of cambridge": [
+      "mahal"
+    ],
+    "mahal": [
+      "mahal of cambridge"
+    ],
+    "maharajah tandoori restaurant": [
+      "maharajah tandoori"
+    ],
+    "the maharajah tandoor": [
+      "maharajah tandoori restaurant",
+      "maharajah tandoori"
+    ],
+    "meze bar": [
+      "meze bar restaurant",
+      "the meze bar"
+    ],
+    "meze bar restaurant": [
+      "the meze bar",
+      "meze bar"
+    ],
+    "the meze bar": [
+      "meze bar restaurant",
+      "meze bar"
+    ],
+    "michaelhouse cafe": [
+      "michael house cafe"
+    ],
+    "midsummer house restaurant": [
+      "midsummer house"
+    ],
+    "missing sock": [
+      "the missing sock"
+    ],
+    "the missing sock": [
+      "missing sock"
+    ],
+    "nandos": [
+      "nando's city centre",
+      "nando's city center",
+      "nandos city centre",
+      "nandos city center",
+      "nando's"
+    ],
+    "nandos city centre": [
+      "nando's city centre",
+      "nando's city center",
+      "nandos city center",
+      "nando's",
+      "nandos"
+    ],
+    "oak bistro": [
+      "the oak bistro"
+    ],
+    "the oak bistro": [
+      "oak bistro"
+    ],
+    "one seven": [
+      "restaurant one seven"
+    ],
+    "restaurant one seven": [
+      "one seven"
+    ],
+    "river bar steakhouse and grill": [
+      "the river bar steakhouse and grill",
+      "the river bar steakhouse",
+      "river bar steakhouse"
+    ],
+    "the river bar steakhouse and grill": [
+      "river bar steakhouse and grill",
+      "the river bar steakhouse",
+      "river bar steakhouse"
+    ],
+    "pipasha restaurant": [
+      "pipasha"
+    ],
+    "pizza hut city centre": [
+      "pizza hut city center"
+    ],
+    "pizza hut fenditton": [
+      "pizza hut fen ditton",
+      "pizza express fen ditton"
+    ],
+    "restaurant two two": [
+      "two two",
+      "restaurant 22"
+    ],
+    "saffron brasserie": [
+      "saffron"
+    ],
+    "saint johns chop house": [
+      "saint john's chop house",
+      "st john's chop house",
+      "st johns chop house"
+    ],
+    "sesame restaurant and bar": [
+      "sesame restaurant",
+      "sesame"
+    ],
+    "shanghai": [
+      "shanghai family restaurant"
+    ],
+    "shanghai family restaurant": [
+      "shanghai"
+    ],
+    "sitar": [
+      "sitar tandoori"
+    ],
+    "sitar tandoori": [
+      "sitar"
+    ],
+    "slug and lettuce": [
+      "the slug and lettuce"
+    ],
+    "the slug and lettuce": [
+      "slug and lettuce"
+    ],
+    "st johns chop house": [
+      "saint john's chop house",
+      "st john's chop house",
+      "saint johns chop house"
+    ],
+    "stazione restaurant and coffee bar": [
+      "stazione restaurant",
+      "stazione"
+    ],
+    "thanh binh": [
+      "thanh",
+      "binh"
+    ],
+    "thanh": [
+      "thanh binh",
+      "binh"
+    ],
+    "binh": [
+      "thanh binh",
+      "thanh"
+    ],
+    "the hotpot": [
+      "the hotspot",
+      "hotpot",
+      "hotspot"
+    ],
+    "hotpot": [
+      "the hotpot",
+      "the hotspot",
+      "hotspot"
+    ],
+    "the lucky star": [
+      "lucky star"
+    ],
+    "lucky star": [
+      "the lucky star"
+    ],
+    "the peking restaurant": [
+      "peking restaurant"
+    ],
+    "the varsity restaurant": [
+      "varsity restaurant",
+      "the varsity",
+      "varsity"
+    ],
+    "two two": [
+      "restaurant two two",
+      "restaurant 22"
+    ],
+    "restaurant 22": [
+      "restaurant two two",
+      "two two"
+    ],
+    "zizzi cambridge": [
+      "zizzi"
+    ],
+    "american": [
+      "americas"
+    ],
+    "asian oriental": [
+      "asian",
+      "oriental"
+    ],
+    "australian": [
+      "australasian"
+    ],
+    "barbeque": [
+      "barbecue",
+      "bbq"
+    ],
+    "corsica": [
+      "corsican"
+    ],
+    "indian": [
+      "tandoori"
+    ],
+    "italian": [
+      "pizza",
+      "pizzeria"
+    ],
+    "japanese": [
+      "sushi"
+    ],
+    "latin american": [
+      "latin-american",
+      "latin"
+    ],
+    "malaysian": [
+      "malay"
+    ],
+    "middle eastern": [
+      "middle-eastern"
+    ],
+    "traditional american": [
+      "american"
+    ],
+    "modern american": [
+      "american modern",
+      "american"
+    ],
+    "modern european": [
+      "european modern",
+      "european"
+    ],
+    "north american": [
+      "north-american",
+      "american"
+    ],
+    "portuguese": [
+      "portugese"
+    ],
+    "portugese": [
+      "portuguese"
+    ],
+    "seafood": [
+      "sea food"
+    ],
+    "singaporean": [
+      "singapore"
+    ],
+    "steakhouse": [
+      "steak house",
+      "steak"
+    ],
+    "the americas": [
+      "american",
+      "americas"
+    ],
+    "a and b guest house": [
+      "a & b guest house",
+      "a and b",
+      "a & b"
+    ],
+    "the acorn guest house": [
+      "acorn guest house",
+      "acorn"
+    ],
+    "acorn guest house": [
+      "the acorn guest house",
+      "acorn"
+    ],
+    "alexander bed and breakfast": [
+      "alexander"
+    ],
+    "allenbell": [
+      "the allenbell"
+    ],
+    "the allenbell": [
+      "allenbell"
+    ],
+    "alpha-milton guest house": [
+      "the alpha-milton",
+      "alpha-milton"
+    ],
+    "the alpha-milton": [
+      "alpha-milton guest house",
+      "alpha-milton"
+    ],
+    "arbury lodge guest house": [
+      "arbury lodge",
+      "arbury"
+    ],
+    "archway house": [
+      "archway"
+    ],
+    "ashley hotel": [
+      "the ashley hotel",
+      "ashley"
+    ],
+    "the ashley hotel": [
+      "ashley hotel",
+      "ashley"
+    ],
+    "aylesbray lodge guest house": [
+      "aylesbray lodge",
+      "aylesbray"
+    ],
+    "aylesbray lodge guest": [
+      "aylesbray lodge guest house",
+      "aylesbray lodge",
+      "aylesbray"
+    ],
+    "alesbray lodge guest house": [
+      "aylesbray lodge guest house",
+      "aylesbray lodge",
+      "aylesbray"
+    ],
+    "alyesbray lodge hotel": [
+      "aylesbray lodge guest house",
+      "aylesbray lodge",
+      "aylesbray"
+    ],
+    "bridge guest house": [
+      "bridge house"
+    ],
+    "cambridge belfry": [
+      "the cambridge belfry",
+      "belfry hotel",
+      "belfry"
+    ],
+    "the cambridge belfry": [
+      "cambridge belfry",
+      "belfry hotel",
+      "belfry"
+    ],
+    "belfry hotel": [
+      "the cambridge belfry",
+      "cambridge belfry",
+      "belfry"
+    ],
+    "carolina bed and breakfast": [
+      "carolina"
+    ],
+    "city centre north": [
+      "city centre north bed and breakfast"
+    ],
+    "north b and b": [
+      "city centre north bed and breakfast"
+    ],
+    "city centre north b and b": [
+      "city centre north bed and breakfast"
+    ],
+    "el shaddia guest house": [
+      "el shaddai guest house",
+      "el shaddai",
+      "el shaddia"
+    ],
+    "el shaddai guest house": [
+      "el shaddia guest house",
+      "el shaddai",
+      "el shaddia"
+    ],
+    "express by holiday inn cambridge": [
+      "express by holiday inn",
+      "holiday inn cambridge",
+      "holiday inn"
+    ],
+    "holiday inn": [
+      "express by holiday inn cambridge",
+      "express by holiday inn",
+      "holiday inn cambridge"
+    ],
+    "finches bed and breakfast": [
+      "finches"
+    ],
+    "gonville hotel": [
+      "gonville"
+    ],
+    "hamilton lodge": [
+      "the hamilton lodge",
+      "hamilton"
+    ],
+    "the hamilton lodge": [
+      "hamilton lodge",
+      "hamilton"
+    ],
+    "hobsons house": [
+      "hobson's house",
+      "hobson's"
+    ],
+    "huntingdon marriott hotel": [
+      "huntington marriott hotel",
+      "huntington marriot hotel",
+      "huntingdon marriot hotel",
+      "huntington marriott",
+      "huntingdon marriott",
+      "huntington marriot",
+      "huntingdon marriot",
+      "huntington",
+      "huntingdon"
+    ],
+    "kirkwood": [
+      "kirkwood house"
+    ],
+    "kirkwood house": [
+      "kirkwood"
+    ],
+    "lensfield hotel": [
+      "the lensfield hotel",
+      "lensfield"
+    ],
+    "the lensfield hotel": [
+      "lensfield hotel",
+      "lensfield"
+    ],
+    "leverton house": [
+      "leverton"
+    ],
+    "marriot hotel": [
+      "marriott hotel",
+      "marriott"
+    ],
+    "rosas bed and breakfast": [
+      "rosa's bed and breakfast",
+      "rosa's",
+      "rosas"
+    ],
+    "university arms hotel": [
+      "university arms"
+    ],
+    "warkworth house": [
+      "warkworth hotel",
+      "warkworth"
+    ],
+    "warkworth hotel": [
+      "warkworth house",
+      "warkworth"
+    ],
+    "wartworth": [
+      "warkworth house",
+      "warkworth hotel",
+      "warkworth"
+    ],
+    "worth house": [
+      "the worth house"
+    ],
+    "the worth house": [
+      "worth house"
+    ],
+    "birmingham new street": [
+      "birmingham new street train station"
+    ],
+    "birmingham new street train station": [
+      "birmingham new street"
+    ],
+    "bishops stortford": [
+      "bishops stortford train station"
+    ],
+    "bishops stortford train station": [
+      "bishops stortford"
+    ],
+    "broxbourne": [
+      "broxbourne train station"
+    ],
+    "broxbourne train station": [
+      "broxbourne"
+    ],
+    "cambridge": [
+      "cambridge train station"
+    ],
+    "cambridge train station": [
+      "cambridge"
+    ],
+    "ely": [
+      "ely train station"
+    ],
+    "ely train station": [
+      "ely"
+    ],
+    "kings lynn": [
+      "king's lynn",
+      "king's lynn train station",
+      "kings lynn train station"
+    ],
+    "kings lynn train station": [
+      "kings lynn",
+      "king's lynn",
+      "king's lynn train station"
+    ],
+    "leicester": [
+      "leicester train station"
+    ],
+    "leicester train station": [
+      "leicester"
+    ],
+    "london kings cross": [
+      "kings cross",
+      "king's cross",
+      "london king's cross",
+      "kings cross train station",
+      "king's cross train station",
+      "london king's cross train station",
+      "london kings cross train station"
+    ],
+    "london kings cross train station": [
+      "kings cross",
+      "king's cross",
+      "london king's cross",
+      "london kings cross",
+      "kings cross train station",
+      "king's cross train station",
+      "london king's cross train station"
+    ],
+    "london liverpool": [
+      "liverpool street",
+      "london liverpool street",
+      "london liverpool train station",
+      "liverpool street train station",
+      "london liverpool street train station"
+    ],
+    "london liverpool street": [
+      "london liverpool",
+      "liverpool street",
+      "london liverpool train station",
+      "liverpool street train station",
+      "london liverpool street train station"
+    ],
+    "london liverpool street train station": [
+      "london liverpool",
+      "liverpool street",
+      "london liverpool street",
+      "london liverpool train station",
+      "liverpool street train station"
+    ],
+    "norwich": [
+      "norwich train station"
+    ],
+    "norwich train station": [
+      "norwich"
+    ],
+    "peterborough": [
+      "peterborough train station"
+    ],
+    "peterborough train station": [
+      "peterborough"
+    ],
+    "stansted airport": [
+      "stansted airport train station"
+    ],
+    "stansted airport train station": [
+      "stansted airport"
+    ],
+    "stevenage": [
+      "stevenage train station"
+    ],
+    "stevenage train station": [
+      "stevenage"
+    ]
+  }
+}
diff --git a/dataset_config/sim-m.json b/dataset_config/sim-m.json
new file mode 100644
index 0000000000000000000000000000000000000000..34793d75da2a58876eb98ebf0aa8441069d6a300
--- /dev/null
+++ b/dataset_config/sim-m.json
@@ -0,0 +1,16 @@
+{
+  "class_types": [
+    "none",
+    "dontcare",
+    "copy_value",
+    "inform"
+  ],
+  "slots": [
+    "date",
+    "movie",
+    "time",
+    "num_tickets",
+    "theatre_name"
+  ],
+  "label_maps": {}
+}
diff --git a/dataset_config/sim-r.json b/dataset_config/sim-r.json
new file mode 100644
index 0000000000000000000000000000000000000000..cc635e174e8ef50d33abe5b6ee12f868aca1155a
--- /dev/null
+++ b/dataset_config/sim-r.json
@@ -0,0 +1,20 @@
+{
+  "class_types": [
+    "none",
+    "dontcare",
+    "copy_value",
+    "inform"
+  ],
+  "slots": [
+    "category",
+    "rating",
+    "num_people",
+    "location",
+    "restaurant_name",
+    "time",
+    "date",
+    "price_range",
+    "meal"
+  ],
+  "label_maps": {}
+}
diff --git a/dataset_config/woz2.json b/dataset_config/woz2.json
new file mode 100644
index 0000000000000000000000000000000000000000..361446e7c9936f99643e8e91ba449578fa26c868
--- /dev/null
+++ b/dataset_config/woz2.json
@@ -0,0 +1,212 @@
+{
+  "class_types": [
+    "none",
+    "dontcare",
+    "copy_value",
+    "inform"
+  ],
+  "slots": [
+    "area",
+    "food",
+    "price_range"
+  ],
+  "label_maps": {
+    "center": [
+      "centre",
+      "downtown",
+      "central",
+      "down town",
+      "middle"
+    ],
+    "south": [
+      "southern",
+      "southside"
+    ],
+    "north": [
+      "northern",
+      "uptown",
+      "northside"
+    ],
+    "west": [
+      "western",
+      "westside"
+    ],
+    "east": [
+      "eastern",
+      "eastside"
+    ],
+    "east side": [
+      "eastern",
+      "eastside"
+    ],
+    "cheap": [
+      "low price",
+      "inexpensive",
+      "cheaper",
+      "low priced",
+      "affordable",
+      "nothing too expensive",
+      "without costing a fortune",
+      "cheapest",
+      "good deals",
+      "low prices",
+      "afford",
+      "on a budget",
+      "fair prices",
+      "less expensive",
+      "cheapeast",
+      "not cost an arm and a leg"
+    ],
+    "moderate": [
+      "moderately",
+      "medium priced",
+      "medium price",
+      "fair price",
+      "fair prices",
+      "reasonable",
+      "reasonably priced",
+      "mid price",
+      "fairly priced",
+      "not outrageous",
+      "not too expensive",
+      "on a budget",
+      "mid range",
+      "reasonable priced",
+      "less expensive",
+      "not too pricey",
+      "nothing too expensive",
+      "nothing cheap",
+      "not overpriced",
+      "medium",
+      "inexpensive"
+    ],
+    "expensive": [
+      "high priced",
+      "high end",
+      "high class",
+      "high quality",
+      "fancy",
+      "upscale",
+      "nice",
+      "fine dining",
+      "expensively priced"
+    ],
+    "afghan": [
+      "afghanistan"
+    ],
+    "african": [
+      "africa"
+    ],
+    "asian oriental": [
+      "asian",
+      "oriental"
+    ],
+    "australasian": [
+      "australian asian",
+      "austral asian"
+    ],
+    "australian": [
+      "aussie"
+    ],
+    "barbeque": [
+      "barbecue",
+      "bbq"
+    ],
+    "basque": [
+      "bask"
+    ],
+    "belgian": [
+      "belgium"
+    ],
+    "british": [
+      "cotto"
+    ],
+    "canapes": [
+      "canopy",
+      "canape",
+      "canap"
+    ],
+    "catalan": [
+      "catalonian"
+    ],
+    "corsican": [
+      "corsica"
+    ],
+    "crossover": [
+      "cross over",
+      "over"
+    ],
+    "gastropub": [
+      "gastro pub",
+      "gastro",
+      "gastropubs"
+    ],
+    "hungarian": [
+      "goulash"
+    ],
+    "indian": [
+      "india",
+      "indians",
+      "nirala"
+    ],
+    "international": [
+      "all types of food"
+    ],
+    "italian": [
+      "prezzo"
+    ],
+    "jamaican": [
+      "jamaica"
+    ],
+    "japanese": [
+      "sushi",
+      "beni hana"
+    ],
+    "korean": [
+      "korea"
+    ],
+    "lebanese": [
+      "lebanse"
+    ],
+    "north american": [
+      "american",
+      "hamburger"
+    ],
+    "portuguese": [
+      "portugese"
+    ],
+    "seafood": [
+      "sea food",
+      "shellfish",
+      "fish"
+    ],
+    "singaporean": [
+      "singapore"
+    ],
+    "steakhouse": [
+      "steak house",
+      "steak"
+    ],
+    "thai": [
+      "thailand",
+      "bangkok"
+    ],
+    "traditional": [
+      "old fashioned",
+      "plain"
+    ],
+    "turkish": [
+      "turkey"
+    ],
+    "unusual": [
+      "unique and strange"
+    ],
+    "venetian": [
+      "vanessa"
+    ],
+    "vietnamese": [
+      "vietnam",
+      "thanh binh"
+    ]
+  }
+}
diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py
new file mode 100644
index 0000000000000000000000000000000000000000..9751f5c8a945c8f84004ac2a9eea3edab9d44852
--- /dev/null
+++ b/dataset_multiwoz21.py
@@ -0,0 +1,541 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+
+from utils_dst import (DSTExample, convert_to_unicode)
+
+
+# Required for mapping slot names in dialogue_acts.json file
+# to proper designations.
+ACTS_DICT = {'taxi-depart': 'taxi-departure',
+             'taxi-dest': 'taxi-destination',
+             'taxi-leave': 'taxi-leaveAt',
+             'taxi-arrive': 'taxi-arriveBy',
+             'train-depart': 'train-departure',
+             'train-dest': 'train-destination',
+             'train-leave': 'train-leaveAt',
+             'train-arrive': 'train-arriveBy',
+             'train-people': 'train-book_people',
+             'restaurant-price': 'restaurant-pricerange',
+             'restaurant-people': 'restaurant-book_people',
+             'restaurant-day': 'restaurant-book_day',
+             'restaurant-time': 'restaurant-book_time',
+             'hotel-price': 'hotel-pricerange',
+             'hotel-people': 'hotel-book_people',
+             'hotel-day': 'hotel-book_day',
+             'hotel-stay': 'hotel-book_stay',
+             'booking-people': 'booking-book_people',
+             'booking-day': 'booking-book_day',
+             'booking-stay': 'booking-book_stay',
+             'booking-time': 'booking-book_time',
+}
+
+
+LABEL_MAPS = {} # Loaded from file
+
+
+# Loads dialogue_acts.json and returns a dict of system-informed slot values,
+# keyed by (dialogue file name, turn number, slot).
+def load_acts(input_file):
+    with open(input_file) as f:
+        acts = json.load(f)
+    s_dict = {}
+    for d in acts:
+        for t in acts[d]:
+            # Only process, if turn has annotation
+            if isinstance(acts[d][t], dict):
+                for a in acts[d][t]:
+                    aa = a.lower().split('-')
+                    if aa[1] == 'inform' or aa[1] == 'recommend' or aa[1] == 'select' or aa[1] == 'book':
+                        for i in acts[d][t][a]:
+                            s = i[0].lower()
+                            v = i[1].lower().strip()
+                            if s == 'none' or v == '?' or v == 'none':
+                                continue
+                            slot = aa[0] + '-' + s
+                            if slot in ACTS_DICT:
+                                slot = ACTS_DICT[slot]
+                            key = d + '.json', t, slot
+                            # In case of multiple mentioned values...
+                            # ... Option 1: Keep first informed value
+                            if key not in s_dict:
+                                s_dict[key] = list([v])
+                            # ... Option 2: Keep last informed value
+                            #s_dict[key] = list([v])
+    return s_dict
+
+
+def normalize_time(text):
+    text = re.sub("(\d{1})(a\.?m\.?|p\.?m\.?)", r"\1 \2", text) # am/pm without space
+    text = re.sub("(^| )(\d{1,2}) (a\.?m\.?|p\.?m\.?)", r"\1\2:00 \3", text) # am/pm short to long form
+    text = re.sub("(^| )(at|from|by|until|after) ?(\d{1,2}) ?(\d{2})([^0-9]|$)", r"\1\2 \3:\4\5", text) # Missing separator
+    text = re.sub("(^| )(\d{2})[;.,](\d{2})", r"\1\2:\3", text) # Wrong separator
+    text = re.sub("(^| )(at|from|by|until|after) ?(\d{1,2})([;., ]|$)", r"\1\2 \3:00\4", text) # normalize simple full hour time
+    text = re.sub("(^| )(\d{1}:\d{2})", r"\g<1>0\2", text) # Add missing leading 0
+    # Map 12 hour times to 24 hour times
+    text = re.sub("(\d{2})(:\d{2}) ?p\.?m\.?", lambda x: str(int(x.groups()[0]) + 12 if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups()[1], text)
+    text = re.sub("(^| )24:(\d{2})", r"\g<1>00:\2", text) # Correct times that use 24 as hour
+    return text
+
+
+def normalize_text(text):
+    text = normalize_time(text)
+    text = re.sub("n't", " not", text)
+    text = re.sub("(^| )zero(-| )star([s.,? ]|$)", r"\g<1>0 star\3", text)
+    text = re.sub("(^| )one(-| )star([s.,? ]|$)", r"\g<1>1 star\3", text)
+    text = re.sub("(^| )two(-| )star([s.,? ]|$)", r"\g<1>2 star\3", text)
+    text = re.sub("(^| )three(-| )star([s.,? ]|$)", r"\g<1>3 star\3", text)
+    text = re.sub("(^| )four(-| )star([s.,? ]|$)", r"\g<1>4 star\3", text)
+    text = re.sub("(^| )five(-| )star([s.,? ]|$)", r"\g<1>5 star\3", text)
+    text = re.sub("archaelogy", "archaeology", text) # Systematic typo
+    text = re.sub("guesthouse", "guest house", text) # Normalization
+    text = re.sub("(^| )b ?& ?b([.,? ]|$)", r"\1bed and breakfast\2", text) # Normalization
+    text = re.sub("bed & breakfast", "bed and breakfast", text) # Normalization
+    return text
+
+
+# This should only contain label normalizations. All other mappings should
+# be defined in LABEL_MAPS.
+def normalize_label(slot, value_label):
+    # Normalization of empty slots
+    if value_label == '' or value_label == "not mentioned":
+        return "none"
+
+    # Normalization of time slots
+    if "leaveAt" in slot or "arriveBy" in slot or slot == 'restaurant-book_time':
+        return normalize_time(value_label)
+
+    # Normalization
+    if "type" in slot or "name" in slot or "destination" in slot or "departure" in slot:
+        value_label = re.sub("guesthouse", "guest house", value_label)
+
+    # Map to boolean slots
+    if slot == 'hotel-parking' or slot == 'hotel-internet':
+        if value_label == 'yes' or value_label == 'free':
+            return "true"
+        if value_label == "no":
+            return "false"
+    if slot == 'hotel-type':
+        if value_label == "hotel":
+            return "true"
+        if value_label == "guest house":
+            return "false"
+
+    return value_label
+
+
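+# Find all (start, exclusive end) token positions at which value_label occurs in tok_list.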
+def get_token_pos(tok_list, value_label):
+    find_pos = []
+    found = False
+    label_list = [item for item in map(str.strip, re.split("(\W+)", value_label)) if len(item) > 0]
+    len_label = len(label_list)
+    for i in range(len(tok_list) + 1 - len_label):
+        if tok_list[i:i + len_label] == label_list:
+            find_pos.append((i, i + len_label)) # start, exclusive_end
+            found = True
+    return found, find_pos
+
+
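+# Check whether the value label (or one of its LABEL_MAPS variants) occurs in the user utterance.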
+def check_label_existence(value_label, usr_utt_tok):
+    in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label)
+    # If no hit even though there should be one, check for value label variants
+    if not in_usr and value_label in LABEL_MAPS:
+        for value_label_variant in LABEL_MAPS[value_label]:
+            in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label_variant)
+            if in_usr:
+                break
+    return in_usr, usr_pos
+
+
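+# Return the name of a previously seen slot that already holds value_label (or a
+# variant of it), i.e. the referral target, or 'none' if there is no such slot.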
+def check_slot_referral(value_label, slot, seen_slots):
+    referred_slot = 'none'
+    if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking':
+        return referred_slot
+    for s in seen_slots:
+        # Avoid matches for slots that share values with different meaning.
+        # hotel-internet and -parking are handled separately as Boolean slots.
+        if s == 'hotel-stars' or s == 'hotel-internet' or s == 'hotel-parking':
+            continue
+        if re.match("(hotel|restaurant)-book_people", s) and slot == 'hotel-book_stay':
+            continue
+        if re.match("(hotel|restaurant)-book_people", slot) and s == 'hotel-book_stay':
+            continue
+        if slot != s and (slot not in seen_slots or seen_slots[slot] != value_label):
+            if seen_slots[s] == value_label:
+                referred_slot = s
+                break
+            elif value_label in LABEL_MAPS:
+                for value_label_variant in LABEL_MAPS[value_label]:
+                    if seen_slots[s] == value_label_variant:
+                        referred_slot = s
+                        break
+    return referred_slot
+
+
+def is_in_list(tok, value):
+    found = False
+    tok_list = [item for item in map(str.strip, re.split("(\W+)", tok)) if len(item) > 0]
+    value_list = [item for item in map(str.strip, re.split("(\W+)", value)) if len(item) > 0]
+    tok_len = len(tok_list)
+    value_len = len(value_list)
+    for i in range(tok_len + 1 - value_len):
+        if tok_list[i:i + value_len] == value_list:
+            found = True
+            break
+    return found
+
+
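+# Replace occurrences of informed slot values in the utterance with [UNK] tokens.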
+def delex_utt(utt, values):
+    utt_norm = tokenize(utt)
+    for s, vals in values.items():
+        for v in vals:
+            if v != 'none':
+                v_norm = tokenize(v)
+                v_len = len(v_norm)
+                for i in range(len(utt_norm) + 1 - v_len):
+                    if utt_norm[i:i + v_len] == v_norm:
+                        utt_norm[i:i + v_len] = ['[UNK]'] * v_len
+    return utt_norm
+
+
+# Fuzzy matching to label informed slot values
+def check_slot_inform(value_label, inform_label):
+    result = False
+    informed_value = 'none'
+    vl = ' '.join(tokenize(value_label))
+    for il in inform_label:
+        if vl == il:
+            result = True
+        elif is_in_list(il, vl):
+            result = True
+        elif is_in_list(vl, il):
+            result = True
+        elif il in LABEL_MAPS:
+            for il_variant in LABEL_MAPS[il]:
+                if vl == il_variant:
+                    result = True
+                    break
+                elif is_in_list(il_variant, vl):
+                    result = True
+                    break
+                elif is_in_list(vl, il_variant):
+                    result = True
+                    break
+        elif vl in LABEL_MAPS:
+            for value_label_variant in LABEL_MAPS[vl]:
+                if value_label_variant == il:
+                    result = True
+                    break
+                elif is_in_list(il, value_label_variant):
+                    result = True
+                    break
+                elif is_in_list(value_label_variant, il):
+                    result = True
+                    break
+        if result:
+            informed_value = il
+            break
+    return result, informed_value
+
+
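+# Determine the class type of a slot for the current turn: the label itself for
+# 'none', 'dontcare', 'true' and 'false'; 'copy_value' if the value is found in the
+# user utterance (with token-level span labels); 'inform' if the system informed it;
+# 'refer' if it refers to another slot's value; 'unpointable' otherwise.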
+def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, seen_slots, slot_last_occurrence):
+    usr_utt_tok_label = [0 for _ in usr_utt_tok]
+    informed_value = 'none'
+    referred_slot = 'none'
+    if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false':
+        class_type = value_label
+    else:
+        in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok)
+        if in_usr:
+            class_type = 'copy_value'
+            if slot_last_occurrence:
+                (s, e) = usr_pos[-1]
+                for i in range(s, e):
+                    usr_utt_tok_label[i] = 1
+            else:
+                for (s, e) in usr_pos:
+                    for i in range(s, e):
+                        usr_utt_tok_label[i] = 1
+        else:
+            is_informed, informed_value = check_slot_inform(value_label, inform_label)
+            if is_informed:
+                class_type = 'inform'
+            else:
+                referred_slot = check_slot_referral(value_label, slot, seen_slots)
+                if referred_slot != 'none':
+                    class_type = 'refer'
+                else:
+                    class_type = 'unpointable'
+    return informed_value, referred_slot, usr_utt_tok_label, class_type
+
+
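+# Lowercase, normalize and split an utterance into tokens at non-word characters.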
+def tokenize(utt):
+    utt_lower = convert_to_unicode(utt).lower()
+    utt_lower = normalize_text(utt_lower)
+    utt_tok = [tok for tok in map(str.strip, re.split("(\W+)", utt_lower)) if len(tok) > 0]
+    return utt_tok
+
+
+def create_examples(input_file, acts_file, set_type, slot_list,
+                    label_maps={},
+                    append_history=False,
+                    use_history_labels=False,
+                    swap_utterances=False,
+                    label_value_repetitions=False,
+                    delexicalize_sys_utts=False,
+                    analyze=False):
+    """Read a DST json file into a list of DSTExample."""
+
+    sys_inform_dict = load_acts(acts_file)
+
+    with open(input_file, "r", encoding='utf-8') as reader:
+        input_data = json.load(reader)
+
+    global LABEL_MAPS
+    LABEL_MAPS = label_maps
+
+    examples = []
+    for dialog_id in input_data:
+        entry = input_data[dialog_id]
+        utterances = entry['log']
+
+        # Collects all slot changes throughout the dialog
+        cumulative_labels = {slot: 'none' for slot in slot_list}
+
+        # First system utterance is empty, since multiwoz starts with user input
+        utt_tok_list = [[]]
+        mod_slots_list = [{}]
+
+        # Collect all utterances and their metadata
+        usr_sys_switch = True
+        turn_itr = 0
+        for utt in utterances:
+            # Assert that system and user utterances alternate
+            is_sys_utt = utt['metadata'] != {}
+            if usr_sys_switch == is_sys_utt:
+                print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id))
+                break
+            usr_sys_switch = is_sys_utt
+
+            if is_sys_utt:
+                turn_itr += 1
+
+            # Delexicalize sys utterance
+            if delexicalize_sys_utts and is_sys_utt:
+                inform_dict = {slot: 'none' for slot in slot_list}
+                for slot in slot_list:
+                    if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
+                        inform_dict[slot] = sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]
+                utt_tok_list.append(delex_utt(utt['text'], inform_dict)) # normalize utterances
+            else:
+                utt_tok_list.append(tokenize(utt['text'])) # normalize utterances
+
+            modified_slots = {}
+
+            # If sys utt, extract metadata (identify and collect modified slots)
+            if is_sys_utt:
+                for d in utt['metadata']:
+                    booked = utt['metadata'][d]['book']['booked']
+                    booked_slots = {}
+                    # Check the booked section
+                    if booked != []:
+                        for s in booked[0]:
+                            booked_slots[s] = normalize_label('%s-%s' % (d, s), booked[0][s]) # normalize labels
+                    # Check the semi and the inform slots
+                    for category in ['book', 'semi']:
+                        for s in utt['metadata'][d][category]:
+                            cs = '%s-book_%s' % (d, s) if category == 'book' else '%s-%s' % (d, s)
+                            value_label = normalize_label(cs, utt['metadata'][d][category][s]) # normalize labels
+                            # Prefer the slot value as stored in the booked section
+                            if s in booked_slots:
+                                value_label = booked_slots[s]
+                            # Remember modified slots and entire dialog state
+                            if cs in slot_list and cumulative_labels[cs] != value_label:
+                                modified_slots[cs] = value_label
+                                cumulative_labels[cs] = value_label
+
+            mod_slots_list.append(modified_slots.copy())
+
+        # Form proper (usr, sys) turns
+        turn_itr = 0
+        diag_seen_slots_dict = {}
+        diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
+        diag_state = {slot: 'none' for slot in slot_list}
+        sys_utt_tok = []
+        usr_utt_tok = []
+        hst_utt_tok = []
+        hst_utt_tok_label_dict = {slot: [] for slot in slot_list}
+        for i in range(1, len(utt_tok_list) - 1, 2):
+            sys_utt_tok_label_dict = {}
+            usr_utt_tok_label_dict = {}
+            value_dict = {}
+            inform_dict = {}
+            inform_slot_dict = {}
+            referral_dict = {}
+            class_type_dict = {}
+
+            # Collect turn data
+            if append_history:
+                if swap_utterances:
+                    hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
+                else:
+                    hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
+            sys_utt_tok = utt_tok_list[i - 1]
+            usr_utt_tok = utt_tok_list[i]
+            turn_slots = mod_slots_list[i + 1]
+
+            guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr))
+
+            if analyze:
+                print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok)))
+                print("%15s %2s [" % (dialog_id, turn_itr), end='')
+
+            new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
+            new_diag_state = diag_state.copy()
+            for slot in slot_list:
+                value_label = 'none'
+                if slot in turn_slots:
+                    value_label = turn_slots[slot]
+                    # We keep the original labels so as to not
+                    # overlook unpointable values, as well as to not
+                    # modify any of the original labels for test sets,
+                    # since this would make comparison difficult.
+                    value_dict[slot] = value_label
+                elif label_value_repetitions and slot in diag_seen_slots_dict:
+                    value_label = diag_seen_slots_value_dict[slot]
+
+                # Get dialog act annotations
+                inform_label = list(['none'])
+                if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
+                    inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]])
+                elif (str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1]) in sys_inform_dict:
+                    inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1])]])
+
+                (informed_value,
+                 referred_slot,
+                 usr_utt_tok_label,
+                 class_type) = get_turn_label(value_label,
+                                              inform_label,
+                                              sys_utt_tok,
+                                              usr_utt_tok,
+                                              slot,
+                                              diag_seen_slots_value_dict,
+                                              slot_last_occurrence=True)
+
+                inform_dict[slot] = informed_value
+                if informed_value != 'none':
+                    inform_slot_dict[slot] = 1
+                else:
+                    inform_slot_dict[slot] = 0
+
+                # Generally don't use span prediction on sys utterance (but inform prediction instead).
+                sys_utt_tok_label = [0 for _ in sys_utt_tok]
+
+                # Determine what to do with value repetitions.
+                # If the value is unique among seen slots, tag it; otherwise do not,
+                # since a correct slot assignment can no longer be guaranteed.
+                if label_value_repetitions and slot in diag_seen_slots_dict:
+                    if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1:
+                        class_type = 'none'
+                        usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
+
+                sys_utt_tok_label_dict[slot] = sys_utt_tok_label
+                usr_utt_tok_label_dict[slot] = usr_utt_tok_label
+
+                if append_history:
+                    if use_history_labels:
+                        if swap_utterances:
+                            new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
+                        else:
+                            new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
+                    else:
+                        new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]]
+                    
+                # For now, we map all occurrences of unpointable slot values
+                # to none. However, since the labels will still suggest the
+                # presence of unpointable slot values, the task of the DST is
+                # still to find those values. It is just not possible to do so
+                # via span prediction on the current input.
+                if class_type == 'unpointable':
+                    class_type_dict[slot] = 'none'
+                    referral_dict[slot] = 'none'
+                    if analyze:
+                        if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]:
+                            print("(%s): %s, " % (slot, value_label), end='')
+                elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
+                    # If the slot has been seen before and its class type did not change, label this slot as not present,
+                    # assuming that the slot has not actually been mentioned in this turn.
+                    # Exceptions are copy_value and inform: if a seen slot has been tagged as copy_value or inform,
+                    # there must be evidence in the original labels, therefore consider
+                    # it as mentioned again.
+                    class_type_dict[slot] = 'none'
+                    referral_dict[slot] = 'none'
+                else:
+                    class_type_dict[slot] = class_type
+                    referral_dict[slot] = referred_slot
+                # Remember that this slot was mentioned during this dialog already.
+                if class_type != 'none':
+                    diag_seen_slots_dict[slot] = class_type
+                    diag_seen_slots_value_dict[slot] = value_label
+                    new_diag_state[slot] = class_type
+                    # Unpointable is not a valid class, therefore replace with
+                    # some valid class for now...
+                    if class_type == 'unpointable':
+                        new_diag_state[slot] = 'copy_value'
+
+            if analyze:
+                print("]")
+
+            if swap_utterances:
+                txt_a = usr_utt_tok
+                txt_b = sys_utt_tok
+                txt_a_lbl = usr_utt_tok_label_dict
+                txt_b_lbl = sys_utt_tok_label_dict
+            else:
+                txt_a = sys_utt_tok
+                txt_b = usr_utt_tok
+                txt_a_lbl = sys_utt_tok_label_dict
+                txt_b_lbl = usr_utt_tok_label_dict
+            examples.append(DSTExample(
+                guid=guid,
+                text_a=txt_a,
+                text_b=txt_b,
+                history=hst_utt_tok,
+                text_a_label=txt_a_lbl,
+                text_b_label=txt_b_lbl,
+                history_label=hst_utt_tok_label_dict,
+                values=diag_seen_slots_value_dict.copy(),
+                inform_label=inform_dict,
+                inform_slot_label=inform_slot_dict,
+                refer_label=referral_dict,
+                diag_state=diag_state,
+                class_label=class_type_dict))
+
+            # Update some variables.
+            hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy()
+            diag_state = new_diag_state.copy()
+
+            turn_itr += 1
+
+        if analyze:
+            print("----------------------------------------------------------------------")
+
+    return examples
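+
+
+# A minimal usage sketch for create_examples (illustrative only: the file
+# paths and slot names below are assumptions, not part of this module):
+#
+#     examples = create_examples(
+#         input_file='data/MULTIWOZ2.1/train_dials.json',   # hypothetical path
+#         acts_file='data/MULTIWOZ2.1/dialogue_acts.json',  # hypothetical path
+#         set_type='train',
+#         slot_list=['hotel-name', 'restaurant-food'],      # illustrative subset
+#         append_history=True,
+#         use_history_labels=True,
+#         swap_utterances=True,
+#         label_value_repetitions=True,
+#         delexicalize_sys_utts=True)
+#     print(len(examples), examples[0].guid)  # guid looks like 'train-<dialog_id>-0'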
diff --git a/dataset_sim.py b/dataset_sim.py
new file mode 100644
index 0000000000000000000000000000000000000000..e30ed09cd9c289778dcd0cb78546780c7df28493
--- /dev/null
+++ b/dataset_sim.py
@@ -0,0 +1,230 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from utils_dst import (DSTExample)
+
+
+def dialogue_state_to_sv_dict(sv_list):
+    sv_dict = {}
+    for d in sv_list:
+        sv_dict[d['slot']] = d['value']
+    return sv_dict
+
+
+def get_token_and_slot_label(turn):
+    if 'system_utterance' in turn:
+        sys_utt_tok = turn['system_utterance']['tokens']
+        sys_slot_label = turn['system_utterance']['slots']
+    else:
+        sys_utt_tok = []
+        sys_slot_label = []
+
+    usr_utt_tok = turn['user_utterance']['tokens']
+    usr_slot_label = turn['user_utterance']['slots']
+    return sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label
+
+
+def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok,
+                  sys_slot_label, usr_utt_tok, usr_slot_label, dial_id,
+                  turn_id, slot_last_occurrence=True):
+    """The position of the last occurrence of the slot value will be used."""
+    sys_utt_tok_label = [0 for _ in sys_utt_tok]
+    usr_utt_tok_label = [0 for _ in usr_utt_tok]
+    if slot_type not in cur_ds_dict:
+        class_type = 'none'
+    else:
+        value = cur_ds_dict[slot_type]
+        if value == 'dontcare' and (slot_type not in prev_ds_dict or prev_ds_dict[slot_type] != 'dontcare'):
+            # Only label dontcare at its first occurrence in the dialog
+            class_type = 'dontcare'
+        else: # If not none or dontcare, we have to identify whether
+            # there is a span, or if the value is informed
+            in_usr = False
+            in_sys = False
+            for label_d in usr_slot_label:
+                if label_d['slot'] == slot_type and value == ' '.join(
+                        usr_utt_tok[label_d['start']:label_d['exclusive_end']]):
+
+                    for idx in range(label_d['start'], label_d['exclusive_end']):
+                        usr_utt_tok_label[idx] = 1
+                    in_usr = True
+                    class_type = 'copy_value'
+                    if slot_last_occurrence:
+                        break
+            if not in_usr or not slot_last_occurrence:
+                for label_d in sys_slot_label:
+                    if label_d['slot'] == slot_type and value == ' '.join(
+                            sys_utt_tok[label_d['start']:label_d['exclusive_end']]):
+                        for idx in range(label_d['start'], label_d['exclusive_end']):
+                            sys_utt_tok_label[idx] = 1
+                        in_sys = True
+                        class_type = 'inform'
+                        if slot_last_occurrence:
+                            break
+            if not in_usr and not in_sys:
+                assert sum(usr_utt_tok_label + sys_utt_tok_label) == 0
+                if (slot_type not in prev_ds_dict or value != prev_ds_dict[slot_type]):
+                    raise ValueError('Copy value cannot be found in Dial %s Turn %s' % (str(dial_id), str(turn_id)))
+                else:
+                    class_type = 'none'
+            else:
+                assert sum(usr_utt_tok_label + sys_utt_tok_label) > 0
+    return sys_utt_tok_label, usr_utt_tok_label, class_type
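+
+# Illustrative example (assumed values, not taken from the corpus): if the
+# current dialogue state maps 'time' to '6 pm', the user utterance tokens are
+# ['table', 'for', '6', 'pm'], and usr_slot_label contains
+# {'slot': 'time', 'start': 2, 'exclusive_end': 4}, then the function returns
+# sys_utt_tok_label all zeros, usr_utt_tok_label == [0, 0, 1, 1], and
+# class_type == 'copy_value'.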
+
+
+def delex_utt(utt, values):
+    utt_delex = utt.copy()
+    for v in values:
+        utt_delex[v['start']:v['exclusive_end']] = ['[UNK]'] * (v['exclusive_end'] - v['start'])
+    return utt_delex
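+
+# Illustrative example (assumed values): delex_utt(['table', 'for', '6', 'pm'],
+# [{'slot': 'time', 'start': 2, 'exclusive_end': 4}]) returns
+# ['table', 'for', '[UNK]', '[UNK]'].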
+    
+
+def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id,
+                   delexicalize_sys_utts=False, slot_last_occurrence=True):
+    """Make turn_label a dictionary of slot with value positions or being dontcare / none:
+    Turn label contains:
+      (1) the updates from previous to current dialogue state,
+      (2) values in current dialogue state explicitly mentioned in system or user utterance."""
+    prev_ds_dict = dialogue_state_to_sv_dict(prev_dialogue_state)
+    cur_ds_dict = dialogue_state_to_sv_dict(turn['dialogue_state'])
+
+    (sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label) = get_token_and_slot_label(turn)
+
+    sys_utt_tok_label_dict = {}
+    usr_utt_tok_label_dict = {}
+    inform_label_dict = {}
+    inform_slot_label_dict = {}
+    referral_label_dict = {}
+    class_type_dict = {}
+
+    for slot_type in slot_list:
+        inform_label_dict[slot_type] = 'none'
+        inform_slot_label_dict[slot_type] = 0
+        referral_label_dict[slot_type] = 'none' # Referral is not present in sim data
+        sys_utt_tok_label, usr_utt_tok_label, class_type = get_tok_label(
+            prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, sys_slot_label,
+            usr_utt_tok, usr_slot_label, dial_id, turn_id,
+            slot_last_occurrence=slot_last_occurrence)
+        if sum(sys_utt_tok_label) > 0:
+            inform_label_dict[slot_type] = cur_ds_dict[slot_type]
+            inform_slot_label_dict[slot_type] = 1
+        sys_utt_tok_label = [0 for _ in sys_utt_tok_label] # Don't use token labels for sys utt
+        sys_utt_tok_label_dict[slot_type] = sys_utt_tok_label
+        usr_utt_tok_label_dict[slot_type] = usr_utt_tok_label
+        class_type_dict[slot_type] = class_type
+
+    if delexicalize_sys_utts:
+        sys_utt_tok = delex_utt(sys_utt_tok, sys_slot_label)
+
+    return (sys_utt_tok, sys_utt_tok_label_dict,
+            usr_utt_tok, usr_utt_tok_label_dict,
+            inform_label_dict, inform_slot_label_dict,
+            referral_label_dict, cur_ds_dict, class_type_dict)
+
+
+def create_examples(input_file, set_type, slot_list,
+                    label_maps={},
+                    append_history=False,
+                    use_history_labels=False,
+                    swap_utterances=False,
+                    label_value_repetitions=False,
+                    delexicalize_sys_utts=False,
+                    analyze=False):
+    """Read a DST json file into a list of DSTExample."""
+    with open(input_file, "r", encoding='utf-8') as reader:
+        input_data = json.load(reader)
+
+    examples = []
+    for entry in input_data:
+        dial_id = entry['dialogue_id']
+        prev_ds = []
+        hst = []
+        prev_hst_lbl_dict = {slot: [] for slot in slot_list}
+        prev_ds_lbl_dict = {slot: 'none' for slot in slot_list}
+
+        for turn_id, turn in enumerate(entry['turns']):
+            guid = '%s-%s-%s' % (set_type, dial_id, str(turn_id))
+            ds_lbl_dict = prev_ds_lbl_dict.copy()
+            hst_lbl_dict = prev_hst_lbl_dict.copy()
+            (text_a,
+             text_a_label,
+             text_b,
+             text_b_label,
+             inform_label,
+             inform_slot_label,
+             referral_label,
+             cur_ds_dict,
+             class_label) = get_turn_label(turn,
+                                           prev_ds,
+                                           slot_list,
+                                           dial_id,
+                                           turn_id,
+                                           delexicalize_sys_utts=delexicalize_sys_utts,
+                                           slot_last_occurrence=True)
+
+            if swap_utterances:
+                txt_a = text_b
+                txt_b = text_a
+                txt_a_lbl = text_b_label
+                txt_b_lbl = text_a_label
+            else:
+                txt_a = text_a
+                txt_b = text_b
+                txt_a_lbl = text_a_label
+                txt_b_lbl = text_b_label
+
+            value_dict = {}
+            for slot in slot_list:
+                if slot in cur_ds_dict:
+                    value_dict[slot] = cur_ds_dict[slot]
+                else:
+                    value_dict[slot] = 'none'
+                if class_label[slot] != 'none':
+                    ds_lbl_dict[slot] = class_label[slot]
+                if append_history:
+                    if use_history_labels:
+                        hst_lbl_dict[slot] = txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot]
+                    else:
+                        hst_lbl_dict[slot] = [0 for _ in txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot]]
+
+            examples.append(DSTExample(
+                guid=guid,
+                text_a=txt_a,
+                text_b=txt_b,
+                history=hst,
+                text_a_label=txt_a_lbl,
+                text_b_label=txt_b_lbl,
+                history_label=prev_hst_lbl_dict,
+                values=value_dict,
+                inform_label=inform_label,
+                inform_slot_label=inform_slot_label,
+                refer_label=referral_label,
+                diag_state=prev_ds_lbl_dict,
+                class_label=class_label))
+
+            prev_ds = turn['dialogue_state']
+            prev_ds_lbl_dict = ds_lbl_dict.copy()
+            prev_hst_lbl_dict = hst_lbl_dict.copy()
+
+            if append_history:
+                hst = txt_a + txt_b + hst
+
+    return examples
diff --git a/dataset_woz2.py b/dataset_woz2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c32af05afd55bc99c4f4c84adea711991c12c10e
--- /dev/null
+++ b/dataset_woz2.py
@@ -0,0 +1,278 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+
+from utils_dst import (DSTExample, convert_to_unicode)
+
+
+LABEL_MAPS = {} # Loaded from file
+LABEL_FIX = {'centre': 'center', 'areas': 'area', 'phone number': 'number', 'price range': 'price_range'}
+
+
+def delex_utt(utt, values):
+    utt_norm = utt.copy()
+    for s, v in values.items():
+        if v != 'none':
+            v_norm = tokenize(v)
+            v_len = len(v_norm)
+            for i in range(len(utt_norm) + 1 - v_len):
+                if utt_norm[i:i + v_len] == v_norm:
+                    utt_norm[i:i + v_len] = ['[UNK]'] * v_len
+    return utt_norm
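+
+# Illustrative example (assumed values): with utt == ['in', 'the', 'south'] and
+# values == {'area': 'south'}, every token span matching the tokenized value is
+# masked, giving ['in', 'the', '[UNK]'].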
+
+
+def get_token_pos(tok_list, label):
+    find_pos = []
+    found = False
+    label_list = [item for item in map(str.strip, re.split(r"(\W+)", label)) if len(item) > 0]
+    len_label = len(label_list)
+    for i in range(len(tok_list) + 1 - len_label):
+        if tok_list[i:i + len_label] == label_list:
+            find_pos.append((i, i + len_label))  # start, exclusive_end
+            found = True
+    return found, find_pos
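+
+# Illustrative example (assumed values):
+# get_token_pos(['the', 'grand', 'hotel'], 'grand hotel') returns
+# (True, [(1, 3)]), i.e. a list of (start, exclusive_end) positions.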
+
+
+def check_label_existence(label, usr_utt_tok, sys_utt_tok):
+    in_usr, usr_pos = get_token_pos(usr_utt_tok, label)
+    if not in_usr and label in LABEL_MAPS:
+        for tmp_label in LABEL_MAPS[label]:
+            in_usr, usr_pos = get_token_pos(usr_utt_tok, tmp_label)
+            if in_usr:
+                break
+    in_sys, sys_pos = get_token_pos(sys_utt_tok, label)
+    if not in_sys and label in LABEL_MAPS:
+        for tmp_label in LABEL_MAPS[label]:
+            in_sys, sys_pos = get_token_pos(sys_utt_tok, tmp_label)
+            if in_sys:
+                break
+    return in_usr, usr_pos, in_sys, sys_pos
+
+
+def get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence):
+    usr_utt_tok_label = [0 for _ in usr_utt_tok]
+    if label == 'none' or label == 'dontcare':
+        class_type = label
+    else:
+        in_usr, usr_pos, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
+        if in_usr:
+            class_type = 'copy_value'
+            if slot_last_occurrence:
+                (s, e) = usr_pos[-1]
+                for i in range(s, e):
+                    usr_utt_tok_label[i] = 1
+            else:
+                for (s, e) in usr_pos:
+                    for i in range(s, e):
+                        usr_utt_tok_label[i] = 1
+        elif in_sys:
+            class_type = 'inform'
+        else:
+            class_type = 'unpointable'
+    return usr_utt_tok_label, class_type
+
+
+def tokenize(utt):
+    utt_lower = convert_to_unicode(utt).lower()
+    utt_tok = [tok for tok in map(str.strip, re.split(r"(\W+)", utt_lower)) if len(tok) > 0]
+    return utt_tok
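+
+# Illustrative example (assuming convert_to_unicode passes a plain str through
+# unchanged): tokenize("I'd like Chinese food.") returns
+# ['i', "'", 'd', 'like', 'chinese', 'food', '.'].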
+
+
+def create_examples(input_file, set_type, slot_list,
+                    label_maps={},
+                    append_history=False,
+                    use_history_labels=False,
+                    swap_utterances=False,
+                    label_value_repetitions=False,
+                    delexicalize_sys_utts=False,
+                    analyze=False):
+    """Read a DST json file into a list of DSTExample."""
+    with open(input_file, "r", encoding='utf-8') as reader:
+        input_data = json.load(reader)
+
+    global LABEL_MAPS
+    LABEL_MAPS = label_maps
+
+    examples = []
+    for entry in input_data:
+        diag_seen_slots_dict = {}
+        diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
+        diag_state = {slot: 'none' for slot in slot_list}
+        sys_utt_tok = []
+        sys_utt_tok_delex = []
+        usr_utt_tok = []
+        hst_utt_tok = []
+        hst_utt_tok_label_dict = {slot: [] for slot in slot_list}
+        for turn in entry['dialogue']:
+            sys_utt_tok_label_dict = {}
+            usr_utt_tok_label_dict = {}
+            inform_dict = {slot: 'none' for slot in slot_list}
+            inform_slot_dict = {slot: 0 for slot in slot_list}
+            referral_dict = {}
+            class_type_dict = {}
+
+            # Collect turn data
+            if append_history:
+                if swap_utterances:
+                    if delexicalize_sys_utts:
+                        hst_utt_tok = usr_utt_tok + sys_utt_tok_delex + hst_utt_tok
+                    else:
+                        hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
+                else:
+                    if delexicalize_sys_utts:
+                        hst_utt_tok = sys_utt_tok_delex + usr_utt_tok + hst_utt_tok
+                    else:
+                        hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
+
+            sys_utt_tok = tokenize(turn['system_transcript'])
+            usr_utt_tok = tokenize(turn['transcript'])
+            turn_label = {LABEL_FIX.get(s.strip(), s.strip()): LABEL_FIX.get(v.strip(), v.strip()) for s, v in turn['turn_label']}
+
+            guid = '%s-%s-%s' % (set_type, str(entry['dialogue_idx']), str(turn['turn_idx']))
+
+            # Create delexicalized sys utterances.
+            if delexicalize_sys_utts:
+                delex_dict = {}
+                for slot in slot_list:
+                    delex_dict[slot] = 'none'
+                    label = 'none'
+                    if slot in turn_label:
+                        label = turn_label[slot]
+                    elif label_value_repetitions and slot in diag_seen_slots_dict:
+                        label = diag_seen_slots_value_dict[slot]
+                    if label != 'none' and label != 'dontcare':
+                        _, _, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
+                        if in_sys:
+                            delex_dict[slot] = label
+                sys_utt_tok_delex = delex_utt(sys_utt_tok, delex_dict)
+
+            new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
+            new_diag_state = diag_state.copy()
+            for slot in slot_list:
+                label = 'none'
+                if slot in turn_label:
+                    label = turn_label[slot]
+                elif label_value_repetitions and slot in diag_seen_slots_dict:
+                    label = diag_seen_slots_value_dict[slot]
+
+                (usr_utt_tok_label,
+                 class_type) = get_turn_label(label,
+                                              sys_utt_tok,
+                                              usr_utt_tok,
+                                              slot_last_occurrence=True)
+
+                if class_type == 'inform':
+                    inform_dict[slot] = label
+                    if label != 'none':
+                        inform_slot_dict[slot] = 1
+
+                referral_dict[slot] = 'none' # Referral is not present in woz2 data
+
+                # Generally don't use span prediction on sys utterance (but inform prediction instead).
+                if delexicalize_sys_utts:
+                    sys_utt_tok_label = [0 for _ in sys_utt_tok_delex]
+                else:
+                    sys_utt_tok_label = [0 for _ in sys_utt_tok]
+
+                # Determine what to do with value repetitions.
+                # If the value is unique among seen slots, tag it; otherwise do not,
+                # since a correct slot assignment can no longer be guaranteed.
+                if label_value_repetitions and slot in diag_seen_slots_dict:
+                    if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(label) > 1:
+                        class_type = 'none'
+                        usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
+
+                sys_utt_tok_label_dict[slot] = sys_utt_tok_label
+                usr_utt_tok_label_dict[slot] = usr_utt_tok_label
+
+                if append_history:
+                    if use_history_labels:
+                        if swap_utterances:
+                            new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
+                        else:
+                            new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
+                    else:
+                        new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]]
+
+                # For now, we map all occurrences of unpointable slot values
+                # to none. However, since the labels will still suggest the
+                # presence of unpointable slot values, the task of the DST is
+                # still to find those values. It is just not possible to do so
+                # via span prediction on the current input.
+                if class_type == 'unpointable':
+                    class_type_dict[slot] = 'none'
+                elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
+                    # If the slot has been seen before and its class type did not change, label this slot as not present,
+                    # assuming that the slot has not actually been mentioned in this turn.
+                    # Exceptions are copy_value and inform: if a seen slot has been tagged as copy_value or inform,
+                    # there must be evidence in the original labels, therefore consider
+                    # it as mentioned again.
+                    class_type_dict[slot] = 'none'
+                    referral_dict[slot] = 'none'
+                else:
+                    class_type_dict[slot] = class_type
+                # Remember that this slot was mentioned during this dialog already.
+                if class_type != 'none':
+                    diag_seen_slots_dict[slot] = class_type
+                    diag_seen_slots_value_dict[slot] = label
+                    new_diag_state[slot] = class_type
+                    # Unpointable is not a valid class, therefore replace with
+                    # some valid class for now...
+                    if class_type == 'unpointable':
+                        new_diag_state[slot] = 'copy_value'
+
+            if swap_utterances:
+                txt_a = usr_utt_tok
+                if delexicalize_sys_utts:
+                    txt_b = sys_utt_tok_delex
+                else:
+                    txt_b = sys_utt_tok
+                txt_a_lbl = usr_utt_tok_label_dict
+                txt_b_lbl = sys_utt_tok_label_dict
+            else:
+                if delexicalize_sys_utts:
+                    txt_a = sys_utt_tok_delex
+                else:
+                    txt_a = sys_utt_tok
+                txt_b = usr_utt_tok
+                txt_a_lbl = sys_utt_tok_label_dict
+                txt_b_lbl = usr_utt_tok_label_dict
+            examples.append(DSTExample(
+                guid=guid,
+                text_a=txt_a,
+                text_b=txt_b,
+                history=hst_utt_tok,
+                text_a_label=txt_a_lbl,
+                text_b_label=txt_b_lbl,
+                history_label=hst_utt_tok_label_dict,
+                values=diag_seen_slots_value_dict.copy(),
+                inform_label=inform_dict,
+                inform_slot_label=inform_slot_dict,
+                refer_label=referral_dict,
+                diag_state=diag_state,
+                class_label=class_type_dict))
+
+            # Update some variables.
+            hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy()
+            diag_state = new_diag_state.copy()
+            
+    return examples
+
diff --git a/metric_bert_dst.py b/metric_bert_dst.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2df63af7146e0feef28b90ee9bc11000c8fa337
--- /dev/null
+++ b/metric_bert_dst.py
@@ -0,0 +1,370 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import json
+import sys
+import numpy as np
+import re
+
+
+def load_dataset_config(dataset_config):
+    with open(dataset_config, "r", encoding='utf-8') as f:
+        raw_config = json.load(f)
+    return raw_config['class_types'], raw_config['slots'], raw_config['label_maps']
+
+
+def tokenize(text):
+    if "\u0120" in text:
+        text = re.sub(" ", "", text)
+        text = re.sub("\u0120", " ", text)
+        text = text.strip()
+    return ' '.join([tok for tok in map(str.strip, re.split(r"(\W+)", text)) if len(tok) > 0])
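+
+# Illustrative examples (assumed inputs): tokenize("moderately priced") returns
+# 'moderately priced'; a byte-level BPE space marker is resolved first, so
+# tokenize("\u0120moderately\u0120priced") also returns 'moderately priced'.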
+
+
+def is_in_list(tok, value):
+    found = False
+    tok_list = [item for item in map(str.strip, re.split(r"(\W+)", tok)) if len(item) > 0]
+    value_list = [item for item in map(str.strip, re.split(r"(\W+)", value)) if len(item) > 0]
+    tok_len = len(tok_list)
+    value_len = len(value_list)
+    for i in range(tok_len + 1 - value_len):
+        if tok_list[i:i + value_len] == value_list:
+            found = True
+            break
+    return found
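+
+# Illustrative example (assumed values): is_in_list('the grand hotel', 'grand hotel')
+# returns True, since the value's token sequence occurs inside tok's token sequence.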
+
+
+def check_slot_inform(value_label, inform_label, label_maps):
+    value = inform_label
+    if value_label == inform_label:
+        value = value_label
+    elif is_in_list(inform_label, value_label):
+        value = value_label
+    elif is_in_list(value_label, inform_label):
+        value = value_label
+    elif inform_label in label_maps:
+        for inform_label_variant in label_maps[inform_label]:
+            if value_label == inform_label_variant:
+                value = value_label
+                break
+            elif is_in_list(inform_label_variant, value_label):
+                value = value_label
+                break
+            elif is_in_list(value_label, inform_label_variant):
+                value = value_label
+                break
+    elif value_label in label_maps:
+        for value_label_variant in label_maps[value_label]:
+            if value_label_variant == inform_label:
+                value = value_label
+                break
+            elif is_in_list(inform_label, value_label_variant):
+                value = value_label
+                break
+            elif is_in_list(value_label_variant, inform_label):
+                value = value_label
+                break
+    return value
+
+
+def get_joint_slot_correctness(fp, class_types, label_maps,
+                               key_class_label_id='class_label_id',
+                               key_class_prediction='class_prediction',
+                               key_start_pos='start_pos',
+                               key_start_prediction='start_prediction',
+                               key_end_pos='end_pos',
+                               key_end_prediction='end_prediction',
+                               key_refer_id='refer_id',
+                               key_refer_prediction='refer_prediction',
+                               key_slot_groundtruth='slot_groundtruth',
+                               key_slot_prediction='slot_prediction'):
+    with open(fp) as f:
+        preds = json.load(f)
+        class_correctness = [[] for cl in range(len(class_types) + 1)]
+        confusion_matrix = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))]
+        pos_correctness = []
+        refer_correctness = []
+        val_correctness = []
+        total_correctness = []
+        c_tp = {ct: 0 for ct in range(len(class_types))}
+        c_tn = {ct: 0 for ct in range(len(class_types))}
+        c_fp = {ct: 0 for ct in range(len(class_types))}
+        c_fn = {ct: 0 for ct in range(len(class_types))}
+
+        for pred in preds:
+            guid = pred['guid']  # List: set_type, dialogue_idx, turn_idx
+            turn_gt_class = pred[key_class_label_id]
+            turn_pd_class = pred[key_class_prediction]
+            gt_start_pos = pred[key_start_pos]
+            pd_start_pos = pred[key_start_prediction]
+            gt_end_pos = pred[key_end_pos]
+            pd_end_pos = pred[key_end_prediction]
+            gt_refer = pred[key_refer_id]
+            pd_refer = pred[key_refer_prediction]
+            gt_slot = pred[key_slot_groundtruth]
+            pd_slot = pred[key_slot_prediction]
+
+            gt_slot = tokenize(gt_slot)
+            pd_slot = tokenize(pd_slot)
+
+            # Make sure the true turn labels are contained in the prediction json file!
+            joint_gt_slot = gt_slot
+        
+            if guid[-1] == '0': # First turn, reset the slots
+                joint_pd_slot = 'none'
+
+            # If turn_pd_class or a value to be copied is "none", do not update the dialog state.
+            if turn_pd_class == class_types.index('none'):
+                pass
+            elif turn_pd_class == class_types.index('dontcare'):
+                joint_pd_slot = 'dontcare'
+            elif turn_pd_class == class_types.index('copy_value'):
+                joint_pd_slot = pd_slot
+            elif 'true' in class_types and turn_pd_class == class_types.index('true'):
+                joint_pd_slot = 'true'
+            elif 'false' in class_types and turn_pd_class == class_types.index('false'):
+                joint_pd_slot = 'false'
+            elif 'refer' in class_types and turn_pd_class == class_types.index('refer'):
+                if pd_slot[0:3] == "§§ ":
+                    if pd_slot[3:] != 'none':
+                        joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps)
+                elif pd_slot[0:2] == "§§":
+                    if pd_slot[2:] != 'none':
+                        joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps)
+                elif pd_slot != 'none':
+                    joint_pd_slot = pd_slot
+            elif 'inform' in class_types and turn_pd_class == class_types.index('inform'):
+                if pd_slot[0:3] == "§§ ":
+                    if pd_slot[3:] != 'none':
+                        joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps)
+                elif pd_slot[0:2] == "§§":
+                    if pd_slot[2:] != 'none':
+                        joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps)
+                else:
+                    print("ERROR: Unexpected slot value format. Aborting.")
+                    exit()
+            else:
+                print("ERROR: Unexpected class_type. Aborting.")
+                exit()
+
+            total_correct = True
+
+            # Check the per turn correctness of the class_type prediction
+            if turn_gt_class == turn_pd_class:
+                class_correctness[turn_gt_class].append(1.0)
+                class_correctness[-1].append(1.0)
+                c_tp[turn_gt_class] += 1
+                for cc in range(len(class_types)):
+                    if cc != turn_gt_class:
+                        c_tn[cc] += 1
+                # Only if there is a span do we check its per-turn correctness
+                if turn_gt_class == class_types.index('copy_value'):
+                    if gt_start_pos == pd_start_pos and gt_end_pos == pd_end_pos:
+                        pos_correctness.append(1.0)
+                    else:
+                        pos_correctness.append(0.0)
+                # Only if there is a referral do we check its per-turn correctness
+                if 'refer' in class_types and turn_gt_class == class_types.index('refer'):
+                    if gt_refer == pd_refer:
+                        refer_correctness.append(1.0)
+                        print("  [%s] Correct referral: %s | %s" % (guid, gt_refer, pd_refer))
+                    else:
+                        refer_correctness.append(0.0)
+                        print("  [%s] Incorrect referral: %s | %s" % (guid, gt_refer, pd_refer))
+            else:
+                if turn_gt_class == class_types.index('copy_value'):
+                    pos_correctness.append(0.0)
+                if 'refer' in class_types and turn_gt_class == class_types.index('refer'):
+                    refer_correctness.append(0.0)
+                class_correctness[turn_gt_class].append(0.0)
+                class_correctness[-1].append(0.0)
+                confusion_matrix[turn_gt_class][turn_pd_class].append(1.0)
+                c_fn[turn_gt_class] += 1
+                c_fp[turn_pd_class] += 1
+
+            # Check the joint slot correctness.
+            # If the value label is not none, then we need to have a value prediction.
+            # Even if the class_type is 'none', there can still be a value label,
+            # it might just not be pointable in the current turn. It might however
+            # be referrable and thus predicted correctly.
+            if joint_gt_slot == joint_pd_slot:
+                val_correctness.append(1.0)
+            elif joint_gt_slot != 'none' and joint_gt_slot != 'dontcare' and joint_gt_slot != 'true' and joint_gt_slot != 'false' and joint_gt_slot in label_maps:
+                no_match = True
+                for variant in label_maps[joint_gt_slot]:
+                    if variant == joint_pd_slot:
+                        no_match = False
+                        break
+                if no_match:
+                    val_correctness.append(0.0)
+                    total_correct = False
+                    print("  [%s] Incorrect value (variant): %s (turn class: %s) | %s (turn class: %s)" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class))
+                else:
+                    val_correctness.append(1.0)
+            else:
+                val_correctness.append(0.0)
+                total_correct = False
+                print("  [%s] Incorrect value: %s (turn class: %s) | %s (turn class: %s)" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class))
+
+            total_correctness.append(1.0 if total_correct else 0.0)
+
+        # Account for empty lists (due to no instances of spans or referrals being seen)
+        if pos_correctness == []:
+            pos_correctness.append(1.0)
+        if refer_correctness == []:
+            refer_correctness.append(1.0)
+
+        for ct in range(len(class_types)):
+            if c_tp[ct] + c_fp[ct] > 0:
+                precision = c_tp[ct] / (c_tp[ct] + c_fp[ct])
+            else:
+                precision = 1.0
+            if c_tp[ct] + c_fn[ct] > 0:
+                recall = c_tp[ct] / (c_tp[ct] + c_fn[ct])
+            else:
+                recall = 1.0
+            if precision + recall > 0:
+                f1 = 2 * ((precision * recall) / (precision + recall))
+            else:
+                f1 = 1.0
+            if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0:
+                acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct])
+            else:
+                acc = 1.0
+            print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" %
+                  (class_types[ct], ct, recall, np.sum(class_correctness[ct]), len(class_correctness[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct]))
+        
+        print("Confusion matrix:")
+        for cl in range(len(class_types)):
+            print("    %s" % (cl), end="")
+        print("")
+        for cl_a in range(len(class_types)):
+            print("%s " % (cl_a), end="")
+            for cl_b in range(len(class_types)):
+                if len(class_correctness[cl_a]) > 0:
+                    print("%.2f " % (np.sum(confusion_matrix[cl_a][cl_b]) / len(class_correctness[cl_a])), end="")
+                else:
+                    print("---- ", end="")
+            print("")
+
+        return np.asarray(total_correctness), np.asarray(val_correctness), np.asarray(class_correctness), np.asarray(pos_correctness), np.asarray(refer_correctness), np.asarray(confusion_matrix), c_tp, c_tn, c_fp, c_fn
+
+
+if __name__ == "__main__":
+    acc_list = []
+    acc_list_v = []
+    key_class_label_id = 'class_label_id_%s'
+    key_class_prediction = 'class_prediction_%s'
+    key_start_pos = 'start_pos_%s'
+    key_start_prediction = 'start_prediction_%s'
+    key_end_pos = 'end_pos_%s'
+    key_end_prediction = 'end_prediction_%s'
+    key_refer_id = 'refer_id_%s'
+    key_refer_prediction = 'refer_prediction_%s'
+    key_slot_groundtruth = 'slot_groundtruth_%s'
+    key_slot_prediction = 'slot_prediction_%s'
+
+    dataset = sys.argv[1].lower()
+    dataset_config = sys.argv[2].lower()
+
+    if dataset not in ['woz2', 'sim-m', 'sim-r', 'multiwoz21']:
+        raise ValueError("Task not found: %s" % (dataset))
+
+    class_types, slots, label_maps = load_dataset_config(dataset_config)
+
+    # Prepare label_maps
+    label_maps_tmp = {}
+    for v in label_maps:
+        label_maps_tmp[tokenize(v)] = [tokenize(nv) for nv in label_maps[v]]
+    label_maps = label_maps_tmp
+
+    for fp in sorted(glob.glob(sys.argv[3])):
+        print(fp)
+        goal_correctness = 1.0
+        cls_acc = [[] for cl in range(len(class_types))]
+        cls_conf = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))]
+        c_tp = {ct: 0 for ct in range(len(class_types))}
+        c_tn = {ct: 0 for ct in range(len(class_types))}
+        c_fp = {ct: 0 for ct in range(len(class_types))}
+        c_fn = {ct: 0 for ct in range(len(class_types))}
+        for slot in slots:
+            tot_cor, joint_val_cor, cls_cor, pos_cor, ref_cor, conf_mat, ctp, ctn, cfp, cfn = get_joint_slot_correctness(fp, class_types, label_maps,
+                                                             key_class_label_id=(key_class_label_id % slot),
+                                                             key_class_prediction=(key_class_prediction % slot),
+                                                             key_start_pos=(key_start_pos % slot),
+                                                             key_start_prediction=(key_start_prediction % slot),
+                                                             key_end_pos=(key_end_pos % slot),
+                                                             key_end_prediction=(key_end_prediction % slot),
+                                                             key_refer_id=(key_refer_id % slot),
+                                                             key_refer_prediction=(key_refer_prediction % slot),
+                                                             key_slot_groundtruth=(key_slot_groundtruth % slot),
+                                                             key_slot_prediction=(key_slot_prediction % slot)
+                                                             )
+            print('%s: joint slot acc: %g, joint value acc: %g, turn class acc: %g, turn position acc: %g, turn referral acc: %g' %
+                  (slot, np.mean(tot_cor), np.mean(joint_val_cor), np.mean(cls_cor[-1]), np.mean(pos_cor), np.mean(ref_cor)))
+            goal_correctness *= tot_cor
+            for cl_a in range(len(class_types)):
+                cls_acc[cl_a] += cls_cor[cl_a]
+                for cl_b in range(len(class_types)):
+                    cls_conf[cl_a][cl_b] += list(conf_mat[cl_a][cl_b])
+                c_tp[cl_a] += ctp[cl_a]
+                c_tn[cl_a] += ctn[cl_a]
+                c_fp[cl_a] += cfp[cl_a]
+                c_fn[cl_a] += cfn[cl_a]
+
+        for ct in range(len(class_types)):
+            if c_tp[ct] + c_fp[ct] > 0:
+                precision = c_tp[ct] / (c_tp[ct] + c_fp[ct])
+            else:
+                precision = 1.0
+            if c_tp[ct] + c_fn[ct] > 0:
+                recall = c_tp[ct] / (c_tp[ct] + c_fn[ct])
+            else:
+                recall = 1.0
+            if precision + recall > 0:
+                f1 = 2 * ((precision * recall) / (precision + recall))
+            else:
+                f1 = 1.0
+            if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0:
+                acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct])
+            else:
+                acc = 1.0
+            print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" %
+                  (class_types[ct], ct, recall, np.sum(cls_acc[ct]), len(cls_acc[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct]))
+
+        print("Confusion matrix:")
+        for cl in range(len(class_types)):
+            print("    %s" % (cl), end="")
+        print("")
+        for cl_a in range(len(class_types)):
+            print("%s " % (cl_a), end="")
+            for cl_b in range(len(class_types)):
+                if len(cls_acc[cl_a]) > 0:
+                    print("%.2f " % (np.sum(cls_conf[cl_a][cl_b]) / len(cls_acc[cl_a])), end="")
+                else:
+                    print("---- ", end="")
+            print("")
+
+        acc = np.mean(goal_correctness)
+        acc_list.append((fp, acc))
+
+    acc_list_s = sorted(acc_list, key=lambda tup: tup[1], reverse=True)
+    for (fp, acc) in acc_list_s:
+        print('Joint goal acc: %g, %s' % (acc, fp))
diff --git a/modeling_bert_dst.py b/modeling_bert_dst.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc7206c3f0380417292289f2fe1909777c6cb418
--- /dev/null
+++ b/modeling_bert_dst.py
@@ -0,0 +1,174 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+# Part of this code is based on the source code of Transformers
+# (arXiv:1910.03771)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable)
+from transformers.modeling_bert import (BertModel, BertPreTrainedModel, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
+
+
+@add_start_docstrings(
+    """BERT Model with a classification heads for the DST task. """,
+    BERT_START_DOCSTRING,
+)
+class BertForDST(BertPreTrainedModel):
+    def __init__(self, config):
+        super(BertForDST, self).__init__(config)
+        self.slot_list = config.dst_slot_list
+        self.class_types = config.dst_class_types
+        self.class_labels = config.dst_class_labels
+        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
+        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
+        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
+        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
+        self.class_loss_ratio = config.dst_class_loss_ratio
+
+        # Only use refer loss if refer class is present in dataset.
+        if 'refer' in self.class_types:
+            self.refer_index = self.class_types.index('refer')
+        else:
+            self.refer_index = -1
+
+        self.bert = BertModel(config)
+        self.dropout = nn.Dropout(config.dst_dropout_rate)
+        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
+
+        if self.class_aux_feats_inform:
+            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
+        if self.class_aux_feats_ds:
+            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
+
+        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
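+        # Worked example (illustrative numbers): with 30 slots and both aux
+        # features enabled, aux_dims == 30 * 2 == 60, so each per-slot class and
+        # refer head below takes a vector of size hidden_size + 60 as input.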
+
+        for slot in self.slot_list:
+            self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
+            self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
+            self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
+
+        self.init_weights()
+
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    def forward(self,
+                input_ids,
+                input_mask=None,
+                segment_ids=None,
+                position_ids=None,
+                head_mask=None,
+                start_pos=None,
+                end_pos=None,
+                inform_slot_id=None,
+                refer_id=None,
+                class_label_id=None,
+                diag_state=None):
+        outputs = self.bert(
+            input_ids,
+            attention_mask=input_mask,
+            token_type_ids=segment_ids,
+            position_ids=position_ids,
+            head_mask=head_mask
+        )
+
+        sequence_output = outputs[0]
+        pooled_output = outputs[1]
+
+        sequence_output = self.dropout(sequence_output)
+        pooled_output = self.dropout(pooled_output)
+
+        # TODO: establish proper format in labels already?
+        if inform_slot_id is not None:
+            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
+        if diag_state is not None:
+            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
+        
+        total_loss = 0
+        per_slot_per_example_loss = {}
+        per_slot_class_logits = {}
+        per_slot_start_logits = {}
+        per_slot_end_logits = {}
+        per_slot_refer_logits = {}
+        for slot in self.slot_list:
+            if self.class_aux_feats_inform and self.class_aux_feats_ds:
+                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
+            elif self.class_aux_feats_inform:
+                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
+            elif self.class_aux_feats_ds:
+                pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
+            else:
+                pooled_output_aux = pooled_output
+            class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
+
+            token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
+            start_logits, end_logits = token_logits.split(1, dim=-1)
+            start_logits = start_logits.squeeze(-1)
+            end_logits = end_logits.squeeze(-1)
+
+            refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
+
+            per_slot_class_logits[slot] = class_logits
+            per_slot_start_logits[slot] = start_logits
+            per_slot_end_logits[slot] = end_logits
+            per_slot_refer_logits[slot] = refer_logits
+            
+            # If there are no labels, don't compute loss
+            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
+                # If we are on multi-GPU, splitting can add an extra dimension; squeeze it away
+                if len(start_pos[slot].size()) > 1:
+                    start_pos[slot] = start_pos[slot].squeeze(-1)
+                if len(end_pos[slot].size()) > 1:
+                    end_pos[slot] = end_pos[slot].squeeze(-1)
+                # Sometimes the start/end positions are outside our model inputs; we ignore these terms
+                ignored_index = start_logits.size(1) # This is a single index
+                start_pos[slot].clamp_(0, ignored_index)
+                end_pos[slot].clamp_(0, ignored_index)
+
+                class_loss_fct = CrossEntropyLoss(reduction='none')
+                token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
+                refer_loss_fct = CrossEntropyLoss(reduction='none')
+
+                start_loss = token_loss_fct(start_logits, start_pos[slot])
+                end_loss = token_loss_fct(end_logits, end_pos[slot])
+                token_loss = (start_loss + end_loss) / 2.0
+
+                token_is_pointable = (start_pos[slot] > 0).float()
+                if not self.token_loss_for_nonpointable:
+                    token_loss *= token_is_pointable
+
+                refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
+                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
+                if not self.refer_loss_for_nonpointable:
+                    refer_loss *= token_is_referrable
+
+                class_loss = class_loss_fct(class_logits, class_label_id[slot])
+
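+                # Weight the losses: class_loss_ratio goes to the slot-gate loss; the remainder is split evenly between span and refer losses when a refer class exists, otherwise it all goes to the span loss.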
+                if self.refer_index > -1:
+                    per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
+                else:
+                    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
+
+                total_loss += per_example_loss.sum()
+                per_slot_per_example_loss[slot] = per_example_loss
+
+        # add hidden states and attention if they are here
+        outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]
+
+        return outputs
diff --git a/run_dst.py b/run_dst.py
new file mode 100644
index 0000000000000000000000000000000000000000..bde263007dfaa6bf071a3c9547f7732f5c38183b
--- /dev/null
+++ b/run_dst.py
@@ -0,0 +1,764 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+# Part of this code is based on the source code of Transformers
+# (arXiv:1910.03771)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+import random
+import glob
+import json
+import math
+import re
+
+import numpy as np
+import torch
+from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler)
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm, trange
+
+from tensorboardX import SummaryWriter
+
+from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer)
+from transformers import (AdamW, get_linear_schedule_with_warmup)
+
+from modeling_bert_dst import (BertForDST)
+from data_processors import PROCESSORS
+from utils_dst import (convert_examples_to_features)
+from tensorlistdataset import (TensorListDataset)
+
+logger = logging.getLogger(__name__)
+
+ALL_MODELS = tuple(BertConfig.pretrained_config_archive_map.keys())
+
+MODEL_CLASSES = {
+    'bert': (BertConfig, BertForDST, BertTokenizer),
+}
+
+
+def set_seed(args):
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+
+        
+def to_list(tensor):
+    return tensor.detach().cpu().tolist()
+
+
+def batch_to_device(batch, device):
+    batch_on_device = []
+    for element in batch:
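+        # Dict-valued batch elements hold per-slot tensors; move each one to the device individually.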
+        if isinstance(element, dict):
+            batch_on_device.append({k: v.to(device) for k, v in element.items()})
+        else:
+            batch_on_device.append(element.to(device))
+    return tuple(batch_on_device)
+
+
+def train(args, train_dataset, features, model, tokenizer, processor, continue_from_global_step=0):
+    """ Train the model """
+    if args.local_rank in [-1, 0]:
+        tb_writer = SummaryWriter()
+
+    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
+    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
+    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
+
+    if args.max_steps > 0:
+        t_total = args.max_steps
+        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
+    else:
+        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
+
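+    # --save_epochs takes precedence over --save_steps: convert the epoch interval into an optimizer-step interval.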
+    if args.save_epochs > 0:
+        args.save_steps = t_total // args.num_train_epochs * args.save_epochs
+
+    num_warmup_steps = int(t_total * args.warmup_proportion)
+
+    # Prepare optimizer and schedule (linear warmup and decay)
+    no_decay = ['bias', 'LayerNorm.weight']
+    optimizer_grouped_parameters = [
+        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
+        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
+    ]
+    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)
+    if args.fp16:
+        try:
+            from apex import amp
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
+
+    # multi-gpu training (should be after apex fp16 initialization)
+    model_single_gpu = model
+    if args.n_gpu > 1:
+        model = torch.nn.DataParallel(model_single_gpu)
+
+    # Distributed training (should be after apex fp16 initialization)
+    if args.local_rank != -1:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
+                                                          output_device=args.local_rank,
+                                                          find_unused_parameters=True)
+
+    # Train!
+    logger.info("***** Running training *****")
+    logger.info("  Num examples = %d", len(train_dataset))
+    logger.info("  Num Epochs = %d", args.num_train_epochs)
+    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
+    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
+                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
+    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
+    logger.info("  Total optimization steps = %d", t_total)
+    logger.info("  Warmup steps = %d", num_warmup_steps)
+
+    if continue_from_global_step > 0:
+        logger.info("Fast forwarding to global step %d to resume training from latest checkpoint...", continue_from_global_step)
+    
+    global_step = 0
+    tr_loss, logging_loss = 0.0, 0.0
+    model.zero_grad()
+    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
+    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
+
+    for _ in train_iterator:
+        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
+
+        for step, batch in enumerate(epoch_iterator):
+            # If training is continued from a checkpoint, fast forward
+            # to the state of that checkpoint.
+            if global_step < continue_from_global_step:
+                if (step + 1) % args.gradient_accumulation_steps == 0:
+                    scheduler.step()  # Update learning rate schedule
+                    global_step += 1
+                continue
+
+            model.train()
+            batch = batch_to_device(batch, args.device)
+
+            # This is what is forwarded to the "forward" def.
+            inputs = {'input_ids':       batch[0],
+                      'input_mask':      batch[1], 
+                      'segment_ids':     batch[2],
+                      'start_pos':       batch[3],
+                      'end_pos':         batch[4],
+                      'inform_slot_id':  batch[5],
+                      'refer_id':        batch[6],
+                      'diag_state':      batch[7],
+                      'class_label_id':  batch[8]}
+            outputs = model(**inputs)
+            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
+
+            if args.n_gpu > 1:
+                loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
+            if args.gradient_accumulation_steps > 1:
+                loss = loss / args.gradient_accumulation_steps
+
+            if args.fp16:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+            else:
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+
+            tr_loss += loss.item()
+            if (step + 1) % args.gradient_accumulation_steps == 0:
+                optimizer.step()
+                scheduler.step()  # Update learning rate schedule
+                model.zero_grad()
+                global_step += 1
+
+                # Log metrics
+                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
+                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
+                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
+                    logging_loss = tr_loss
+
+                # Save model checkpoint
+                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
+                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
+                    if not os.path.exists(output_dir):
+                        os.makedirs(output_dir)
+                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+                    model_to_save.save_pretrained(output_dir)
+                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+                    logger.info("Saving model checkpoint to %s", output_dir)
+
+            if args.max_steps > 0 and global_step > args.max_steps:
+                epoch_iterator.close()
+                break
+
+        if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
+            results = evaluate(args, model_single_gpu, tokenizer, processor, prefix=global_step)
+            for key, value in results.items():
+                tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
+
+        if args.max_steps > 0 and global_step > args.max_steps:
+            train_iterator.close()
+            break
+
+    if args.local_rank in [-1, 0]:
+        tb_writer.close()
+
+    return global_step, tr_loss / global_step
+
+
+def evaluate(args, model, tokenizer, processor, prefix=""):
+    dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=True)
+
+    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
+        os.makedirs(args.output_dir)
+
+    args.eval_batch_size = args.per_gpu_eval_batch_size
+    eval_sampler = SequentialSampler(dataset) # Note that DistributedSampler samples randomly
+    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+
+    # Eval!
+    logger.info("***** Running evaluation {} *****".format(prefix))
+    logger.info("  Num examples = %d", len(dataset))
+    logger.info("  Batch size = %d", args.eval_batch_size)
+    all_results = []
+    all_preds = []
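+    # ds carries the predicted dialog state between turns for prediction formatting;
+    # diag_state feeds the model's dialog-state auxiliary features and is reset whenever a new dialog starts.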
+    ds = {slot: 'none' for slot in model.slot_list}
+    with torch.no_grad():
+        diag_state = {slot: torch.tensor([0 for _ in range(args.eval_batch_size)]).to(args.device) for slot in model.slot_list}
+    for batch in tqdm(eval_dataloader, desc="Evaluating"):
+        model.eval()
+        batch = batch_to_device(batch, args.device)
+
+        # Reset dialog state if turn is first in the dialog.
+        turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]]
+        reset_diag_state = np.where(np.array(turn_itrs) == '0')[0]
+        for slot in model.slot_list:
+            for i in reset_diag_state:
+                diag_state[slot][i] = 0
+
+        with torch.no_grad():
+            inputs = {'input_ids':       batch[0],
+                      'input_mask':      batch[1],
+                      'segment_ids':     batch[2],
+                      'start_pos':       batch[3],
+                      'end_pos':         batch[4],
+                      'inform_slot_id':  batch[5],
+                      'refer_id':        batch[6],
+                      'diag_state':      diag_state,
+                      'class_label_id':  batch[8]}
+            unique_ids = [features[i.item()].guid for i in batch[9]]
+            values = [features[i.item()].values for i in batch[9]]
+            input_ids_unmasked = [features[i.item()].input_ids_unmasked for i in batch[9]]
+            inform = [features[i.item()].inform for i in batch[9]]
+            outputs = model(**inputs)
+
+            # Update dialog state for next turn.
+            for slot in model.slot_list:
+                updates = outputs[2][slot].max(1)[1]
+                for i, u in enumerate(updates):
+                    if u != 0:
+                        diag_state[slot][i] = u
+
+        results = eval_metric(model, inputs, outputs[0], outputs[1], outputs[2], outputs[3], outputs[4], outputs[5])
+        preds, ds = predict_and_format(model, tokenizer, inputs, outputs[2], outputs[3], outputs[4], outputs[5], unique_ids, input_ids_unmasked, values, inform, prefix, ds)
+        all_results.append(results)
+        all_preds.append(preds)
+
+    all_preds = [item for sublist in all_preds for item in sublist] # Flatten list
+
+    # Generate final results
+    final_results = {}
+    for k in all_results[0].keys():
+        final_results[k] = torch.stack([r[k] for r in all_results]).mean()
+
+    # Write final predictions (for evaluation with external tool)
+    output_prediction_file = os.path.join(args.output_dir, "pred_res.%s.%s.json" % (args.predict_type, prefix))
+    with open(output_prediction_file, "w") as f:
+        json.dump(all_preds, f, indent=2)
+
+    return final_results
+
+
+def eval_metric(model, features, total_loss, per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits):
+    metric_dict = {}
+    per_slot_correctness = {}
+    for slot in model.slot_list:
+        per_example_loss = per_slot_per_example_loss[slot]
+        class_logits = per_slot_class_logits[slot]
+        start_logits = per_slot_start_logits[slot]
+        end_logits = per_slot_end_logits[slot]
+        refer_logits = per_slot_refer_logits[slot]
+
+        class_label_id = features['class_label_id'][slot]
+        start_pos = features['start_pos'][slot]
+        end_pos = features['end_pos'][slot]
+        refer_id = features['refer_id'][slot]
+
+        _, class_prediction = class_logits.max(1)
+        class_correctness = torch.eq(class_prediction, class_label_id).float()
+        class_accuracy = class_correctness.mean()
+
+        # "is pointable" means whether class label is "copy_value",
+        # i.e., that there is a span to be detected.
+        token_is_pointable = torch.eq(class_label_id, model.class_types.index('copy_value')).float()
+        _, start_prediction = start_logits.max(1)
+        start_correctness = torch.eq(start_prediction, start_pos).float()
+        _, end_prediction = end_logits.max(1)
+        end_correctness = torch.eq(end_prediction, end_pos).float()
+        token_correctness = start_correctness * end_correctness
+        token_accuracy = (token_correctness * token_is_pointable).sum() / token_is_pointable.sum()
+        # NaNs mean that none of the examples in this batch contain spans. -> division by 0
+        # The accuracy therefore is 1 by default. -> replace NaNs
+        if math.isnan(token_accuracy):
+            token_accuracy = torch.tensor(1.0, device=token_accuracy.device)
+
+        token_is_referrable = torch.eq(class_label_id, model.class_types.index('refer') if 'refer' in model.class_types else -1).float()
+        _, refer_prediction = refer_logits.max(1)
+        refer_correctness = torch.eq(refer_prediction, refer_id).float()
+        refer_accuracy = refer_correctness.sum() / token_is_referrable.sum()
+        # NaNs mean that none of the examples in this batch contain referrals. -> division by 0
+        # The accuracy therefore is 1 by default. -> replace NaNs
+        if math.isnan(refer_accuracy) or math.isinf(refer_accuracy):
+            refer_accuracy = torch.tensor(1.0, device=refer_accuracy.device)
+            
+        total_correctness = class_correctness * (token_is_pointable * token_correctness + (1 - token_is_pointable)) * (token_is_referrable * refer_correctness + (1 - token_is_referrable))
+        total_accuracy = total_correctness.mean()
+
+        loss = per_example_loss.mean()
+        metric_dict['eval_accuracy_class_%s' % slot] = class_accuracy
+        metric_dict['eval_accuracy_token_%s' % slot] = token_accuracy
+        metric_dict['eval_accuracy_refer_%s' % slot] = refer_accuracy
+        metric_dict['eval_accuracy_%s' % slot] = total_accuracy
+        metric_dict['eval_loss_%s' % slot] = loss
+        per_slot_correctness[slot] = total_correctness
+
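+    # Joint goal correctness: a turn only counts as correct if every single slot is predicted correctly.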
+    goal_correctness = torch.stack([c for c in per_slot_correctness.values()], 1).prod(1)
+    goal_accuracy = goal_correctness.mean()
+    metric_dict['eval_accuracy_goal'] = goal_accuracy
+    metric_dict['loss'] = total_loss
+    return metric_dict
+
+
+def predict_and_format(model, tokenizer, features, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits, ids, input_ids_unmasked, values, inform, prefix, ds):
+    prediction_list = []
+    dialog_state = ds
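+    # The dialog state is carried over between batches and reset at the first turn of each dialog (turn index 0 in the guid).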
+    for i in range(len(ids)):
+        if int(ids[i].split("-")[2]) == 0:
+            dialog_state = {slot: 'none' for slot in model.slot_list}
+
+        prediction = {}
+        prediction_addendum = {}
+        for slot in model.slot_list:
+            class_logits = per_slot_class_logits[slot][i]
+            start_logits = per_slot_start_logits[slot][i]
+            end_logits = per_slot_end_logits[slot][i]
+            refer_logits = per_slot_refer_logits[slot][i]
+
+            input_ids = features['input_ids'][i].tolist()
+            class_label_id = int(features['class_label_id'][slot][i])
+            start_pos = int(features['start_pos'][slot][i])
+            end_pos = int(features['end_pos'][slot][i])
+            refer_id = int(features['refer_id'][slot][i])
+            
+            class_prediction = int(class_logits.argmax())
+            start_prediction = int(start_logits.argmax())
+            end_prediction = int(end_logits.argmax())
+            refer_prediction = int(refer_logits.argmax())
+
+            prediction['guid'] = ids[i].split("-")
+            prediction['class_prediction_%s' % slot] = class_prediction
+            prediction['class_label_id_%s' % slot] = class_label_id
+            prediction['start_prediction_%s' % slot] = start_prediction
+            prediction['start_pos_%s' % slot] = start_pos
+            prediction['end_prediction_%s' % slot] = end_prediction
+            prediction['end_pos_%s' % slot] = end_pos
+            prediction['refer_prediction_%s' % slot] = refer_prediction
+            prediction['refer_id_%s' % slot] = refer_id
+            prediction['input_ids_%s' % slot] = input_ids
+
+            if class_prediction == model.class_types.index('dontcare'):
+                dialog_state[slot] = 'dontcare'
+            elif class_prediction == model.class_types.index('copy_value'):
+                input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i])
+                dialog_state[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1])
+                dialog_state[slot] = re.sub("(^| )##", "", dialog_state[slot])
+            elif 'true' in model.class_types and class_prediction == model.class_types.index('true'):
+                dialog_state[slot] = 'true'
+            elif 'false' in model.class_types and class_prediction == model.class_types.index('false'):
+                dialog_state[slot] = 'false'
+            elif class_prediction == model.class_types.index('inform'):
+                dialog_state[slot] = '§§' + inform[i][slot]
+            # Referral case is handled below
+
+            prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot]
+            prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot]
+
+        # Referral case. All other slot values need to be seen first in order
+        # to be able to do this correctly.
+        for slot in model.slot_list:
+            class_logits = per_slot_class_logits[slot][i]
+            refer_logits = per_slot_refer_logits[slot][i]
+            
+            class_prediction = int(class_logits.argmax())
+            refer_prediction = int(refer_logits.argmax())
+
+            if 'refer' in model.class_types and class_prediction == model.class_types.index('refer'):
+                # Only slots that have been mentioned before can be referred to.
+                # One can think of a situation where one slot is referred to in the same utterance.
+                # This phenomenon is however currently not properly covered in the training data
+                # label generation process.
+                dialog_state[slot] = dialog_state[model.slot_list[refer_prediction - 1]]
+                prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot] # Value update
+
+        prediction.update(prediction_addendum)
+        prediction_list.append(prediction)
+        
+    return prediction_list, dialog_state
+
+
+def load_and_cache_examples(args, model, tokenizer, processor, evaluate=False):
+    if args.local_rank not in [-1, 0] and not evaluate:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache
+
+    # Load data features from cache or dataset file
+    cached_file = os.path.join(os.path.dirname(args.output_dir), 'cached_{}_features'.format(
+        args.predict_type if evaluate else 'train'))
+    if os.path.exists(cached_file) and not args.overwrite_cache: # and not output_examples:
+        logger.info("Loading features from cached file %s", cached_file)
+        features = torch.load(cached_file)
+    else:
+        logger.info("Creating features from dataset file at %s", args.data_dir)
+        processor_args = {'append_history': args.append_history,
+                          'use_history_labels': args.use_history_labels,
+                          'swap_utterances': args.swap_utterances,
+                          'label_value_repetitions': args.label_value_repetitions,
+                          'delexicalize_sys_utts': args.delexicalize_sys_utts}
+        if evaluate and args.predict_type == "dev":
+            examples = processor.get_dev_examples(args.data_dir, processor_args)
+        elif evaluate and args.predict_type == "test":
+            examples = processor.get_test_examples(args.data_dir, processor_args)
+        else:
+            examples = processor.get_train_examples(args.data_dir, processor_args)
+        features = convert_examples_to_features(examples=examples,
+                                                slot_list=model.slot_list,
+                                                class_types=model.class_types,
+                                                model_type=args.model_type,
+                                                tokenizer=tokenizer,
+                                                max_seq_length=args.max_seq_length,
+                                                slot_value_dropout=(0.0 if evaluate else args.svd))
+        if args.local_rank in [-1, 0]:
+            logger.info("Saving features into cached file %s", cached_file)
+            torch.save(features, cached_file)
+
+    if args.local_rank == 0 and not evaluate:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache
+
+    # Convert to Tensors and build dataset
+    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
+    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
+    f_start_pos = [f.start_pos for f in features]
+    f_end_pos = [f.end_pos for f in features]
+    f_inform_slot_ids = [f.inform_slot for f in features]
+    f_refer_ids = [f.refer_id for f in features]
+    f_diag_state = [f.diag_state for f in features]
+    f_class_label_ids = [f.class_label_id for f in features]
+    all_start_positions = {}
+    all_end_positions = {}
+    all_inform_slot_ids = {}
+    all_refer_ids = {}
+    all_diag_state = {}
+    all_class_label_ids = {}
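+    # Slot-dependent labels are kept as dicts of per-slot tensors; TensorListDataset accepts dict entries directly.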
+    for s in model.slot_list:
+        all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos], dtype=torch.long)
+        all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos], dtype=torch.long)
+        all_inform_slot_ids[s] = torch.tensor([f[s] for f in f_inform_slot_ids], dtype=torch.long)
+        all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids], dtype=torch.long)
+        all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state], dtype=torch.long)
+        all_class_label_ids[s] = torch.tensor([f[s] for f in f_class_label_ids], dtype=torch.long)
+    dataset = TensorListDataset(all_input_ids, all_input_mask, all_segment_ids,
+                                all_start_positions, all_end_positions,
+                                all_inform_slot_ids,
+                                all_refer_ids,
+                                all_diag_state,
+                                all_class_label_ids, all_example_index)
+
+    return dataset, features
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    # Required parameters
+    parser.add_argument("--task_name", default=None, type=str, required=True,
+                        help="Name of the task (e.g., multiwoz21).")
+    parser.add_argument("--data_dir", default=None, type=str, required=True,
+                        help="Task database.")
+    parser.add_argument("--dataset_config", default=None, type=str, required=True,
+                        help="Dataset configuration file.")
+    parser.add_argument("--predict_type", default=None, type=str, required=True,
+                        help="Portion of the data to perform prediction on (e.g., dev, test).")
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--output_dir", default=None, type=str, required=True,
+                        help="The output directory where the model checkpoints and predictions will be written.")
+
+    # Other parameters
+    parser.add_argument("--config_name", default="", type=str,
+                        help="Pretrained config name or path if not the same as model_name")
+    parser.add_argument("--tokenizer_name", default="", type=str,
+                        help="Pretrained tokenizer name or path if not the same as model_name")
+
+    parser.add_argument("--max_seq_length", default=384, type=int,
+                        help="Maximum input length after tokenization. Longer sequences will be truncated, shorter ones padded.")
+    parser.add_argument("--do_train", action='store_true',
+                        help="Whether to run training.")
+    parser.add_argument("--do_eval", action='store_true',
+                        help="Whether to run eval on the <predict_type> set.")
+    parser.add_argument("--evaluate_during_training", action='store_true',
+                        help="Run evaluation during training at each logging step.")
+    parser.add_argument("--do_lower_case", action='store_true',
+                        help="Set this flag if you are using an uncased model.")
+
+    parser.add_argument("--dropout_rate", default=0.3, type=float,
+                        help="Dropout rate for BERT representations.")
+    parser.add_argument("--heads_dropout", default=0.0, type=float,
+                        help="Dropout rate for classification heads.")
+    parser.add_argument("--class_loss_ratio", default=0.8, type=float,
+                        help="The ratio applied on class loss in total loss calculation. "
+                             "Should be a value in [0.0, 1.0]. "
+                             "The ratio applied on token loss is (1-class_loss_ratio)/2. "
+                             "The ratio applied on refer loss is (1-class_loss_ratio)/2.")
+    parser.add_argument("--token_loss_for_nonpointable", action='store_true',
+                        help="Whether the token loss for classes other than copy_value contributes towards the total loss.")
+    parser.add_argument("--refer_loss_for_nonpointable", action='store_true',
+                        help="Whether the refer loss for classes other than refer contributes towards the total loss.")
+
+    parser.add_argument("--append_history", action='store_true',
+                        help="Whether or not to append the dialog history to each turn.")
+    parser.add_argument("--use_history_labels", action='store_true',
+                        help="Whether or not to label the history as well.")
+    parser.add_argument("--swap_utterances", action='store_true',
+                        help="Whether or not to swap the turn utterances (default: sys|usr, swapped: usr|sys).")
+    parser.add_argument("--label_value_repetitions", action='store_true',
+                        help="Whether or not to label values that have been mentioned before.")
+    parser.add_argument("--delexicalize_sys_utts", action='store_true',
+                        help="Whether or not to delexicalize the system utterances.")
+    parser.add_argument("--class_aux_feats_inform", action='store_true',
+                        help="Whether or not to use the identity of informed slots as auxiliary features for class prediction.")
+    parser.add_argument("--class_aux_feats_ds", action='store_true',
+                        help="Whether or not to use the identity of slots in the current dialog state as auxiliary features for class prediction.")
+
+    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for training.")
+    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for evaluation.")
+    parser.add_argument("--learning_rate", default=5e-5, type=float,
+                        help="The initial learning rate for Adam.")
+    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
+                        help="Number of updates steps to accumulate before performing a backward/update pass.")
+    parser.add_argument("--weight_decay", default=0.0, type=float,
+                        help="Weight decay if we apply some.")
+    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
+                        help="Epsilon for Adam optimizer.")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float,
+                        help="Max gradient norm.")
+    parser.add_argument("--num_train_epochs", default=3.0, type=float,
+                        help="Total number of training epochs to perform.")
+    parser.add_argument("--max_steps", default=-1, type=int,
+                        help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.")
+    parser.add_argument("--warmup_proportion", default=0.0, type=float,
+                        help="Linear warmup over warmup_proportion * steps.")
+    parser.add_argument("--svd", default=0.0, type=float,
+                        help="Slot value dropout ratio (default: 0.0)")
+
+    parser.add_argument('--logging_steps', type=int, default=50,
+                        help="Log every X update steps.")
+    parser.add_argument('--save_steps', type=int, default=0,
+                        help="Save checkpoint every X update steps. Overridden by --save_epochs.")
+    parser.add_argument('--save_epochs', type=int, default=0,
+                        help="Save checkpoint every X epochs. Overrides --save_steps.")
+    parser.add_argument("--eval_all_checkpoints", action='store_true',
+                        help="Evaluate all checkpoints starting with the same prefix as model_name and ending with a step number.")
+    parser.add_argument("--no_cuda", action='store_true',
+                        help="Whether not to use CUDA when available")
+    parser.add_argument('--overwrite_output_dir', action='store_true',
+                        help="Overwrite the content of the output directory")
+    parser.add_argument('--overwrite_cache', action='store_true',
+                        help="Overwrite the cached training and evaluation sets")
+    parser.add_argument('--seed', type=int, default=42,
+                        help="random seed for initialization")
+
+    parser.add_argument("--local_rank", type=int, default=-1,
+                        help="local_rank for distributed training on gpus")
+    parser.add_argument('--fp16', action='store_true',
+                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
+    parser.add_argument('--fp16_opt_level', type=str, default='O1',
+                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+                             "See details at https://nvidia.github.io/apex/amp.html")
+
+    args = parser.parse_args()
+
+    assert(args.warmup_proportion >= 0.0 and args.warmup_proportion <= 1.0)
+    assert(args.svd >= 0.0 and args.svd <= 1.0)
+    assert(not args.class_aux_feats_inform or args.per_gpu_eval_batch_size == 1)
+    assert(not args.class_aux_feats_ds or args.per_gpu_eval_batch_size == 1)
+
+    task_name = args.task_name.lower()
+    if task_name not in PROCESSORS:
+        raise ValueError("Task not found: %s" % (task_name))
+
+    processor = PROCESSORS[task_name](args.dataset_config)
+    dst_slot_list = processor.slot_list
+    dst_class_types = processor.class_types
+    dst_class_labels = len(dst_class_types)
+
+    # Setup CUDA, GPU & distributed training
+    if args.local_rank == -1 or args.no_cuda:
+        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+        args.n_gpu = torch.cuda.device_count()
+    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+        torch.cuda.set_device(args.local_rank)
+        device = torch.device("cuda", args.local_rank)
+        torch.distributed.init_process_group(backend='nccl')
+        args.n_gpu = 1
+    args.device = device
+
+    # Setup logging
+    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
+                        datefmt = '%m/%d/%Y %H:%M:%S',
+                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
+
+    # Set seed
+    set_seed(args)
+
+    # Load pretrained model and tokenizer
+    if args.local_rank not in [-1, 0]:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
+
+    args.model_type = args.model_type.lower()
+    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+
+    # Add DST specific parameters to config
+    config.dst_dropout_rate = args.dropout_rate
+    config.dst_heads_dropout_rate = args.heads_dropout
+    config.dst_class_loss_ratio = args.class_loss_ratio
+    config.dst_token_loss_for_nonpointable = args.token_loss_for_nonpointable
+    config.dst_refer_loss_for_nonpointable = args.refer_loss_for_nonpointable
+    config.dst_class_aux_feats_inform = args.class_aux_feats_inform
+    config.dst_class_aux_feats_ds = args.class_aux_feats_ds
+    config.dst_slot_list = dst_slot_list
+    config.dst_class_types = dst_class_types
+    config.dst_class_labels = dst_class_labels
+
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
+
+    logger.info("Updated model config: %s" % config)
+
+    if args.local_rank == 0:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
+
+    model.to(args.device)
+
+    logger.info("Training/evaluation parameters %s", args)
+
+    # Training
+    if args.do_train:
+        # If output files already exist, continue training from the latest checkpoint (unless overwrite_output_dir is set)
+        continue_from_global_step = 0 # If set to 0, start training from the beginning
+        if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
+            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/*/' + WEIGHTS_NAME, recursive=True)))
+            if len(checkpoints) > 0:
+                checkpoint = checkpoints[-1]
+                logger.info("Resuming training from the latest checkpoint: %s", checkpoint)
+                continue_from_global_step = int(checkpoint.split('-')[-1])
+                model = model_class.from_pretrained(checkpoint)
+                model.to(args.device)
+        
+        train_dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=False)
+        global_step, tr_loss = train(args, train_dataset, features, model, tokenizer, processor, continue_from_global_step)
+        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+
+    # Save the trained model and the tokenizer
+    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
+        # Create output directory if needed
+        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
+            os.makedirs(args.output_dir)
+
+        logger.info("Saving model checkpoint to %s", args.output_dir)
+        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+        model_to_save.save_pretrained(args.output_dir)
+        tokenizer.save_pretrained(args.output_dir)
+
+        # Good practice: save your training arguments together with the trained model
+        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
+
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = model_class.from_pretrained(args.output_dir)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        model.to(args.device)
+
+    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
+    results = []
+    if args.do_eval and args.local_rank in [-1, 0]:
+        output_eval_file = os.path.join(args.output_dir, "eval_res.%s.json" % (args.predict_type))
+        checkpoints = [args.output_dir]
+        if args.eval_all_checkpoints:
+            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+            logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
+
+        logger.info("Evaluate the following checkpoints: %s", checkpoints)
+
+        for cItr, checkpoint in enumerate(checkpoints):
+            # Reload the model
+            global_step = checkpoint.split('-')[-1]
+            if cItr == len(checkpoints) - 1:
+                global_step = "final"
+            model = model_class.from_pretrained(checkpoint)
+            model.to(args.device)
+
+            # Evaluate
+            result = evaluate(args, model, tokenizer, processor, prefix=global_step)
+            result_dict = {k: float(v) for k, v in result.items()}
+            result_dict["global_step"] = global_step
+            results.append(result_dict)
+
+            for key in sorted(result_dict.keys()):
+                logger.info("%s = %s", key, str(result_dict[key]))
+
+        with open(output_eval_file, "w") as f:
+            json.dump(results, f, indent=2)
+
+    return results
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tensorlistdataset.py b/tensorlistdataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b27b3cd494ed316647ef727d4b5ee957feccc18
--- /dev/null
+++ b/tensorlistdataset.py
@@ -0,0 +1,57 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from torch.utils.data import Dataset
+
+
+class TensorListDataset(Dataset):
+    r"""Dataset wrapping tensors, tensor dicts and tensor lists.
+
+    Arguments:
+        *data (Tensor or dict or list of Tensors): tensors that all have the same
+        size in the first dimension.
+    """
+
+    def __init__(self, *data):
+        if isinstance(data[0], dict):
+            size = list(data[0].values())[0].size(0)
+        elif isinstance(data[0], list):
+            size = data[0][0].size(0)
+        else:
+            size = data[0].size(0)
+        for element in data:
+            if isinstance(element, dict):
+                assert all(size == tensor.size(0) for name, tensor in element.items()) # dict of tensors
+            elif isinstance(element, list):
+                assert all(size == tensor.size(0) for tensor in element) # list of tensors
+            else:
+                assert size == element.size(0) # tensor
+        self.size = size
+        self.data = data
+
+    def __getitem__(self, index):
+        result = []
+        for element in self.data:
+            if isinstance(element, dict):
+                result.append({k: v[index] for k, v in element.items()})
+            elif isinstance(element, list):
+                result.append([v[index] for v in element])
+            else:
+                result.append(element[index])
+        return tuple(result)
+
+    def __len__(self):
+        return self.size
diff --git a/utils_dst.py b/utils_dst.py
new file mode 100644
index 0000000000000000000000000000000000000000..e91971da86f4f9185f3428e52a09f48b5247343d
--- /dev/null
+++ b/utils_dst.py
@@ -0,0 +1,429 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+# Part of this code is based on the source code of Transformers
+# (arXiv:1910.03771)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import six
+import numpy as np
+import json
+
+logger = logging.getLogger(__name__)
+
+
+class DSTExample(object):
+    """
+    A single training/test example for the DST dataset.
+    """
+
+    def __init__(self,
+                 guid,
+                 text_a,
+                 text_b,
+                 history,
+                 text_a_label=None,
+                 text_b_label=None,
+                 history_label=None,
+                 values=None,
+                 inform_label=None,
+                 inform_slot_label=None,
+                 refer_label=None,
+                 diag_state=None,
+                 class_label=None):
+        self.guid = guid
+        self.text_a = text_a
+        self.text_b = text_b
+        self.history = history
+        self.text_a_label = text_a_label
+        self.text_b_label = text_b_label
+        self.history_label = history_label
+        self.values = values
+        self.inform_label = inform_label
+        self.inform_slot_label = inform_slot_label
+        self.refer_label = refer_label
+        self.diag_state = diag_state
+        self.class_label = class_label
+
+    def __str__(self):
+        return self.__repr__()
+
+    def __repr__(self):
+        s = ""
+        s += "guid: %s" % (self.guid)
+        s += ", text_a: %s" % (self.text_a)
+        s += ", text_b: %s" % (self.text_b)
+        s += ", history: %s" % (self.history)
+        if self.text_a_label:
+            s += ", text_a_label: %s" % (self.text_a_label)
+        if self.text_b_label:
+            s += ", text_b_label: %s" % (self.text_b_label)
+        if self.history_label:
+            s += ", history_label: %s" % (self.history_label)
+        if self.values:
+            s += ", values: %s" % (self.values)
+        if self.inform_label:
+            s += ", inform_label: %s" % (self.inform_label)
+        if self.inform_slot_label:
+            s += ", inform_slot_label: %s" % (self.inform_slot_label)
+        if self.refer_label:
+            s += ", refer_label: %s" % (self.refer_label)
+        if self.diag_state:
+            s += ", diag_state: %s" % (self.diag_state)
+        if self.class_label:
+            s += ", class_label: %s" % (self.class_label)
+        return s
+
+
+class InputFeatures(object):
+    """A single set of features of data."""
+
+    def __init__(self,
+                 input_ids,
+                 input_ids_unmasked,
+                 input_mask,
+                 segment_ids,
+                 start_pos=None,
+                 end_pos=None,
+                 values=None,
+                 inform=None,
+                 inform_slot=None,
+                 refer_id=None,
+                 diag_state=None,
+                 class_label_id=None,
+                 guid="NONE"):
+        self.guid = guid
+        self.input_ids = input_ids
+        self.input_ids_unmasked = input_ids_unmasked
+        self.input_mask = input_mask
+        self.segment_ids = segment_ids
+        self.start_pos = start_pos
+        self.end_pos = end_pos
+        self.values = values
+        self.inform = inform
+        self.inform_slot = inform_slot
+        self.refer_id = refer_id
+        self.diag_state = diag_state
+        self.class_label_id = class_label_id
+
+
+def convert_examples_to_features(examples, slot_list, class_types, model_type, tokenizer, max_seq_length, slot_value_dropout=0.0):
+    """Loads a data file into a list of `InputBatch`s."""
+
+    if model_type == 'bert':
+        model_specs = {'MODEL_TYPE': 'bert',
+                       'CLS_TOKEN': '[CLS]',
+                       'UNK_TOKEN': '[UNK]',
+                       'SEP_TOKEN': '[SEP]',
+                       'TOKEN_CORRECTION': 4}
+    else:
+        logger.error("Unknown model type (%s). Aborting." % (model_type))
+        exit(1)
+
+    def _tokenize_text_and_label(text, text_label_dict, slot, tokenizer, model_specs, slot_value_dropout):
+        joint_text_label = [0 for _ in text_label_dict[slot]] # joint label across all slots
+        for slot_text_label in text_label_dict.values():
+            for idx, label in enumerate(slot_text_label):
+                if label == 1:
+                    joint_text_label[idx] = 1
+
+        text_label = text_label_dict[slot]
+        tokens = []
+        tokens_unmasked = []
+        token_labels = []
+        for token, token_label, joint_label in zip(text, text_label, joint_text_label):
+            token = convert_to_unicode(token)
+            sub_tokens = tokenizer.tokenize(token) # Most time intensive step
+            tokens_unmasked.extend(sub_tokens)
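+            # Slot value dropout: sub-tokens that belong to any labeled slot value are replaced by [UNK] with probability slot_value_dropout (tokens_unmasked keeps the originals).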
+            if slot_value_dropout == 0.0 or joint_label == 0:
+                tokens.extend(sub_tokens)
+            else:
+                rn_list = np.random.random_sample((len(sub_tokens),))
+                for rn, sub_token in zip(rn_list, sub_tokens):
+                    if rn > slot_value_dropout:
+                        tokens.append(sub_token)
+                    else:
+                        tokens.append(model_specs['UNK_TOKEN'])
+            token_labels.extend([token_label for _ in sub_tokens])
+        assert len(tokens) == len(token_labels)
+        assert len(tokens_unmasked) == len(token_labels)
+        return tokens, tokens_unmasked, token_labels
+
+    def _truncate_seq_pair(tokens_a, tokens_b, history, max_length):
+        """Truncates a sequence pair in place to the maximum length.
+        Copied from bert/run_classifier.py
+        """
+        # This is a simple heuristic which will always truncate the longer sequence
+        # one token at a time. This makes more sense than truncating an equal percent
+        # of tokens from each, since if one sequence is very short then each token
+        # that's truncated likely contains more information than a longer sequence.
+        while True:
+            total_length = len(tokens_a) + len(tokens_b) + len(history)
+            if total_length <= max_length:
+                break
+            if len(history) > 0:
+                history.pop()
+            elif len(tokens_a) > len(tokens_b):
+                tokens_a.pop()
+            else:
+                tokens_b.pop()
+    
+    def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, model_specs, guid):
+        # Modifies `tokens_a` and `tokens_b` in place so that the total
+        # length is less than the specified length.
+        # Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT)
+        if len(tokens_a) + len(tokens_b) + len(history) > max_seq_length - model_specs['TOKEN_CORRECTION']:
+            logger.info("Truncate Example %s. Total len=%d." % (guid, len(tokens_a) + len(tokens_b) + len(history)))
+            input_text_too_long = True
+        else:
+            input_text_too_long = False
+        _truncate_seq_pair(tokens_a, tokens_b, history, max_seq_length - model_specs['TOKEN_CORRECTION'])
+        return input_text_too_long
+
+    def _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs):
+        token_label_ids = []
+        token_label_ids.append(0) # [CLS]
+        for token_label in token_labels_a:
+            token_label_ids.append(token_label)
+        token_label_ids.append(0) # [SEP]
+        for token_label in token_labels_b:
+            token_label_ids.append(token_label)
+        token_label_ids.append(0) # [SEP]
+        for token_label in token_labels_history:
+            token_label_ids.append(token_label)
+        token_label_ids.append(0) # [SEP]
+        while len(token_label_ids) < max_seq_length:
+            token_label_ids.append(0) # padding
+        assert len(token_label_ids) == max_seq_length
+        return token_label_ids
+
+    def _get_start_end_pos(class_type, token_label_ids, max_seq_length):
+        if class_type == 'copy_value' and 1 not in token_label_ids:
+            #logger.warn("copy_value label, but token_label not detected. Setting label to 'none'.")
+            class_type = 'none'
+        start_pos = 0
+        end_pos = 0
+        if 1 in token_label_ids:
+            start_pos = token_label_ids.index(1)
+            # Parsing is supposed to find only the first location of the wanted value
+            if 0 not in token_label_ids[start_pos:]:
+                end_pos = len(token_label_ids[start_pos:]) + start_pos - 1
+            else:
+                end_pos = token_label_ids[start_pos:].index(0) + start_pos - 1
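+            # Sanity check: the labeled span must be contiguous.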
+            for i in range(max_seq_length):
+                if i >= start_pos and i <= end_pos:
+                    assert token_label_ids[i] == 1
+        return class_type, start_pos, end_pos
+
+    def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, tokenizer, model_specs):
+        # The convention in BERT is:
+        # (a) For sequence pairs:
+        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
+        #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
+        # (b) For single sequences:
+        #  tokens:   [CLS] the dog is hairy . [SEP]
+        #  type_ids: 0     0   0   0  0     0 0
+        #
+        # Where "type_ids" are used to indicate whether this is the first
+        # sequence or the second sequence. The embedding vectors for `type=0` and
+        # `type=1` were learned during pre-training and are added to the wordpiece
+        # embedding vector (and position vector). This is not *strictly* necessary
+        # since the [SEP] token unambiguously separates the sequences, but it makes
+        # it easier for the model to learn the concept of sequences.
+        #
+        # For classification tasks, the first vector (corresponding to [CLS]) is
+        # used as the "sentence vector". Note that this only makes sense because
+        # the entire model is fine-tuned.
+        tokens = []
+        segment_ids = []
+        tokens.append(model_specs['CLS_TOKEN'])
+        segment_ids.append(0)
+        for token in tokens_a:
+            tokens.append(token)
+            segment_ids.append(0)
+        tokens.append(model_specs['SEP_TOKEN'])
+        segment_ids.append(0)
+        for token in tokens_b:
+            tokens.append(token)
+            segment_ids.append(1)
+        tokens.append(model_specs['SEP_TOKEN'])
+        segment_ids.append(1)
+        for token in history:
+            tokens.append(token)
+            segment_ids.append(1)
+        tokens.append(model_specs['SEP_TOKEN'])
+        segment_ids.append(1)
+        input_ids = tokenizer.convert_tokens_to_ids(tokens)
+        # The mask has 1 for real tokens and 0 for padding tokens. Only real
+        # tokens are attended to.
+        input_mask = [1] * len(input_ids)
+        # Zero-pad up to the sequence length.
+        while len(input_ids) < max_seq_length:
+            input_ids.append(0)
+            input_mask.append(0)
+            segment_ids.append(0)
+        assert len(input_ids) == max_seq_length
+        assert len(input_mask) == max_seq_length
+        assert len(segment_ids) == max_seq_length
+        return tokens, input_ids, input_mask, segment_ids
+
+    total_cnt = 0
+    too_long_cnt = 0
+
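+    # Possible referral targets: index 0 is 'none' (no reference to another slot),
+    # the remaining indices follow slot_list shifted by one.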
+    refer_list = ['none'] + slot_list
+
+    features = []
+    # Convert each example into a single InputFeatures object
+    for (example_index, example) in enumerate(examples):
+        if example_index % 1000 == 0:
+            logger.info("Writing example %d of %d" % (example_index, len(examples)))
+
+        total_cnt += 1
+
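+        # Per-slot label containers for this example; they are filled in the slot
+        # loop below and stored together in one InputFeatures object.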
+        value_dict = {}
+        inform_dict = {}
+        inform_slot_dict = {}
+        refer_id_dict = {}
+        diag_state_dict = {}
+        class_label_id_dict = {}
+        start_pos_dict = {}
+        end_pos_dict = {}
+        for slot in slot_list:
+            tokens_a, tokens_a_unmasked, token_labels_a = _tokenize_text_and_label(
+                example.text_a, example.text_a_label, slot, tokenizer, model_specs, slot_value_dropout)
+            tokens_b, tokens_b_unmasked, token_labels_b = _tokenize_text_and_label(
+                example.text_b, example.text_b_label, slot, tokenizer, model_specs, slot_value_dropout)
+            tokens_history, tokens_history_unmasked, token_labels_history = _tokenize_text_and_label(
+                example.history, example.history_label, slot, tokenizer, model_specs, slot_value_dropout)
+
+            input_text_too_long = _truncate_length_and_warn(
+                tokens_a, tokens_b, tokens_history, max_seq_length, model_specs, example.guid)
+
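+            # If the joint sequence was truncated, trim the label sequences and the
+            # unmasked token lists to the (truncated) token lengths as well.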
+            if input_text_too_long:
+                if example_index < 10:
+                    if len(token_labels_a) > len(tokens_a):
+                        logger.info('    tokens_a truncated labels: %s' % str(token_labels_a[len(tokens_a):]))
+                    if len(token_labels_b) > len(tokens_b):
+                        logger.info('    tokens_b truncated labels: %s' % str(token_labels_b[len(tokens_b):]))
+                    if len(token_labels_history) > len(tokens_history):
+                        logger.info('    tokens_history truncated labels: %s' % str(token_labels_history[len(tokens_history):]))
+
+                token_labels_a = token_labels_a[:len(tokens_a)]
+                token_labels_b = token_labels_b[:len(tokens_b)]
+                token_labels_history = token_labels_history[:len(tokens_history)]
+                tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)]
+                tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)]
+                tokens_history_unmasked = tokens_history_unmasked[:len(tokens_history)]
+
+            assert len(token_labels_a) == len(tokens_a)
+            assert len(token_labels_b) == len(tokens_b)
+            assert len(token_labels_history) == len(tokens_history)
+            assert len(token_labels_a) == len(tokens_a_unmasked)
+            assert len(token_labels_b) == len(tokens_b_unmasked)
+            assert len(token_labels_history) == len(tokens_history_unmasked)
+            token_label_ids = _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs)
+
+            value_dict[slot] = example.values[slot]
+            inform_dict[slot] = example.inform_label[slot]
+
+            class_label_mod, start_pos_dict[slot], end_pos_dict[slot] = _get_start_end_pos(
+                example.class_label[slot], token_label_ids, max_seq_length)
+            if class_label_mod != example.class_label[slot]:
+                example.class_label[slot] = class_label_mod
+            inform_slot_dict[slot] = example.inform_slot_label[slot]
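+            # Map categorical labels to integer ids: referral target, dialog state
+            # and class type are represented as indices into their respective lists.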
+            refer_id_dict[slot] = refer_list.index(example.refer_label[slot])
+            diag_state_dict[slot] = class_types.index(example.diag_state[slot])
+            class_label_id_dict[slot] = class_types.index(example.class_label[slot])
+
+        if input_text_too_long:
+            too_long_cnt += 1
+
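+        # Note: tokens_a/_b/_history (and input_text_too_long) hold the values from
+        # the last slot iteration; the joint transformer input is built only once
+        # per example.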
+        tokens, input_ids, input_mask, segment_ids = _get_transformer_input(tokens_a,
+                                                                            tokens_b,
+                                                                            tokens_history,
+                                                                            max_seq_length,
+                                                                            tokenizer,
+                                                                            model_specs)
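+        # Also keep an unmasked version of the input ids; if slot value dropout
+        # masked any input tokens above, the unmasked ids preserve the originals.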
+        if slot_value_dropout > 0.0:
+            _, input_ids_unmasked, _, _ = _get_transformer_input(tokens_a_unmasked,
+                                                                 tokens_b_unmasked,
+                                                                 tokens_history_unmasked,
+                                                                 max_seq_length,
+                                                                 tokenizer,
+                                                                 model_specs)
+        else:
+            input_ids_unmasked = input_ids
+
+        assert len(input_ids) == len(input_ids_unmasked)
+
+        if example_index < 10:
+            logger.info("*** Example ***")
+            logger.info("guid: %s" % (example.guid))
+            logger.info("tokens: %s" % " ".join(tokens))
+            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
+            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
+            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
+            logger.info("start_pos: %s" % str(start_pos_dict))
+            logger.info("end_pos: %s" % str(end_pos_dict))
+            logger.info("values: %s" % str(value_dict))
+            logger.info("inform: %s" % str(inform_dict))
+            logger.info("inform_slot: %s" % str(inform_slot_dict))
+            logger.info("refer_id: %s" % str(refer_id_dict))
+            logger.info("diag_state: %s" % str(diag_state_dict))
+            logger.info("class_label_id: %s" % str(class_label_id_dict))
+
+        features.append(
+            InputFeatures(
+                guid=example.guid,
+                input_ids=input_ids,
+                input_ids_unmasked=input_ids_unmasked,
+                input_mask=input_mask,
+                segment_ids=segment_ids,
+                start_pos=start_pos_dict,
+                end_pos=end_pos_dict,
+                values=value_dict,
+                inform=inform_dict,
+                inform_slot=inform_slot_dict,
+                refer_id=refer_id_dict,
+                diag_state=diag_state_dict,
+                class_label_id=class_label_id_dict))
+
+    logger.info("========== %d out of %d examples have text too long" % (too_long_cnt, total_cnt))
+
+    return features
+
+
+# From bert.tokenization (TF code)
+def convert_to_unicode(text):
+    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
+    if six.PY3:
+        if isinstance(text, str):
+            return text
+        elif isinstance(text, bytes):
+            return text.decode("utf-8", "ignore")
+        else:
+            raise ValueError("Unsupported string type: %s" % (type(text)))
+    elif six.PY2:
+        if isinstance(text, str):
+            return text.decode("utf-8", "ignore")
+        elif isinstance(text, unicode):
+            return text
+        else:
+            raise ValueError("Unsupported string type: %s" % (type(text)))
+    else:
+        raise ValueError("Not running on Python2 or Python 3?")