From 45089e7a93a416a67ae4654e2ef096096412b116 Mon Sep 17 00:00:00 2001
From: Michael Heck <michael.heck@hhu.de>
Date: Mon, 19 Dec 2022 16:05:44 +0100
Subject: [PATCH] Added support for ConvLab-3's unified data format. Added
 faster caching. Added transformers 4 support.

---
 .gitattributes                                |  117 ++
 DO.example.recommended => DO.example          |   36 +-
 DO.example.mtl                                |   34 +-
 DO.example.advanced => DO.example.paper       |   40 +-
 DO.example.simple                             |   58 -
 README.md                                     |   38 +-
 .../README.md                                 |    0
 .../dialogue_acts.json.gz                     |  Bin
 .../test_dials.json.gz                        |  Bin
 .../train_dials.json.gz                       |  Bin
 .../val_dials.json.gz                         |  Bin
 data/README.md                                |   50 +
 data/split_multiwoz_data.py                   |   65 +
 data_processors.py                            |   66 +-
 dataset_aux_task.py                           |    2 +-
 dataset_config/unified_multiwoz21.json        | 1282 +++++++++++++++++
 dataset_multiwoz21.py                         |  218 +--
 dataset_multiwoz21_legacy.py                  |  321 +++++
 dataset_sim.py                                |   17 +-
 dataset_unified.py                            |  347 +++++
 dataset_woz2.py                               |   32 +-
 metric_bert_dst.py => metric_dst.py           |   30 +-
 modeling_bert_dst.py                          |  213 ---
 modeling_dst.py                               |  251 ++++
 modeling_roberta_dst.py                       |  236 ---
 run_dst.py                                    |   99 +-
 run_dst_mtl.py                                |  380 +----
 tensorlistdataset.py                          |    2 +-
 utils_dst.py                                  |  200 +--
 29 files changed, 2970 insertions(+), 1164 deletions(-)
 create mode 100644 .gitattributes
 rename DO.example.recommended => DO.example (61%)
 rename DO.example.advanced => DO.example.paper (60%)
 delete mode 100755 DO.example.simple
 rename data/{MULTIWOZ2.1 => MULTIWOZ2.1_legacy}/README.md (100%)
 rename data/{MULTIWOZ2.1 => MULTIWOZ2.1_legacy}/dialogue_acts.json.gz (100%)
 rename data/{MULTIWOZ2.1 => MULTIWOZ2.1_legacy}/test_dials.json.gz (100%)
 rename data/{MULTIWOZ2.1 => MULTIWOZ2.1_legacy}/train_dials.json.gz (100%)
 rename data/{MULTIWOZ2.1 => MULTIWOZ2.1_legacy}/val_dials.json.gz (100%)
 create mode 100644 data/README.md
 create mode 100644 data/split_multiwoz_data.py
 create mode 100644 dataset_config/unified_multiwoz21.json
 create mode 100644 dataset_multiwoz21_legacy.py
 create mode 100644 dataset_unified.py
 rename metric_bert_dst.py => metric_dst.py (94%)
 delete mode 100644 modeling_bert_dst.py
 create mode 100644 modeling_dst.py
 delete mode 100644 modeling_roberta_dst.py

diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..9fd0f15
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,117 @@
+# Store binaries in LFS
+## Custom paths
+data/ filter=lfs diff=lfs merge=lfs -text
+
+## Archive/Compressed
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.cpio filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.iso filter=lfs diff=lfs merge=lfs -text
+*.bz filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.bzip filter=lfs diff=lfs merge=lfs -text
+*.bzip2 filter=lfs diff=lfs merge=lfs -text
+*.cab filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.gzip filter=lfs diff=lfs merge=lfs -text
+*.lz filter=lfs diff=lfs merge=lfs -text
+*.lzma filter=lfs diff=lfs merge=lfs -text
+*.lzo filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.z filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.ace filter=lfs diff=lfs merge=lfs -text
+*.dmg filter=lfs diff=lfs merge=lfs -text
+*.dd filter=lfs diff=lfs merge=lfs -text
+*.apk filter=lfs diff=lfs merge=lfs -text
+*.ear filter=lfs diff=lfs merge=lfs -text
+*.jar filter=lfs diff=lfs merge=lfs -text
+*.deb filter=lfs diff=lfs merge=lfs -text
+*.cue filter=lfs diff=lfs merge=lfs -text
+*.dump filter=lfs diff=lfs merge=lfs -text
+
+## Image
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.psd filter=lfs diff=lfs merge=lfs -text
+*.bmp filter=lfs diff=lfs merge=lfs -text
+*.dng filter=lfs diff=lfs merge=lfs -text
+*.cdr filter=lfs diff=lfs merge=lfs -text
+*.indd filter=lfs diff=lfs merge=lfs -text
+*.tiff filter=lfs diff=lfs merge=lfs -text
+*.tif filter=lfs diff=lfs merge=lfs -text
+*.psp filter=lfs diff=lfs merge=lfs -text
+*.tga filter=lfs diff=lfs merge=lfs -text
+*.eps filter=lfs diff=lfs merge=lfs -text
+*.svg filter=lfs diff=lfs merge=lfs -text
+
+## Documents
+*.pdf filter=lfs diff=lfs merge=lfs -text
+*.doc filter=lfs diff=lfs merge=lfs -text
+*.docx filter=lfs diff=lfs merge=lfs -text
+*.xls filter=lfs diff=lfs merge=lfs -text
+*.xlsx filter=lfs diff=lfs merge=lfs -text
+*.ppt filter=lfs diff=lfs merge=lfs -text
+*.pptx filter=lfs diff=lfs merge=lfs -text
+*.ppz filter=lfs diff=lfs merge=lfs -text
+*.dot filter=lfs diff=lfs merge=lfs -text
+*.dotx filter=lfs diff=lfs merge=lfs -text
+*.lwp filter=lfs diff=lfs merge=lfs -text
+*.odm filter=lfs diff=lfs merge=lfs -text
+*.odt filter=lfs diff=lfs merge=lfs -text
+*.ott filter=lfs diff=lfs merge=lfs -text
+*.ods filter=lfs diff=lfs merge=lfs -text
+*.ots filter=lfs diff=lfs merge=lfs -text
+*.odp filter=lfs diff=lfs merge=lfs -text
+*.otp filter=lfs diff=lfs merge=lfs -text
+*.odg filter=lfs diff=lfs merge=lfs -text
+*.otg filter=lfs diff=lfs merge=lfs -text
+*.wps filter=lfs diff=lfs merge=lfs -text
+*.wpd filter=lfs diff=lfs merge=lfs -text
+*.wpt filter=lfs diff=lfs merge=lfs -text
+*.xps filter=lfs diff=lfs merge=lfs -text
+*.ttf filter=lfs diff=lfs merge=lfs -text
+*.otf filter=lfs diff=lfs merge=lfs -text
+*.dvi filter=lfs diff=lfs merge=lfs -text
+*.pages filter=lfs diff=lfs merge=lfs -text
+*.key filter=lfs diff=lfs merge=lfs -text
+
+## Audio/Video
+*.mpg filter=lfs diff=lfs merge=lfs -text
+*.mpeg filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.avi filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.mkv filter=lfs diff=lfs merge=lfs -text
+*.3gp filter=lfs diff=lfs merge=lfs -text
+*.flv filter=lfs diff=lfs merge=lfs -text
+*.m4v filter=lfs diff=lfs merge=lfs -text
+*.ogg filter=lfs diff=lfs merge=lfs -text
+*.mov filter=lfs diff=lfs merge=lfs -text
+*.wmv filter=lfs diff=lfs merge=lfs -text
+*.webm filter=lfs diff=lfs merge=lfs -text
+
+## VM
+*.vfd filter=lfs diff=lfs merge=lfs -text
+*.vhd filter=lfs diff=lfs merge=lfs -text
+*.vmdk filter=lfs diff=lfs merge=lfs -text
+*.vmsd filter=lfs diff=lfs merge=lfs -text
+*.vmsn filter=lfs diff=lfs merge=lfs -text
+*.vmss filter=lfs diff=lfs merge=lfs -text
+*.dsk filter=lfs diff=lfs merge=lfs -text
+*.vdi filter=lfs diff=lfs merge=lfs -text
+*.cow filter=lfs diff=lfs merge=lfs -text
+*.qcow filter=lfs diff=lfs merge=lfs -text
+*.qcow2 filter=lfs diff=lfs merge=lfs -text
+*.qed filter=lfs diff=lfs merge=lfs -text
+
+## Other
+*.exe filter=lfs diff=lfs merge=lfs -text
+*.sxi filter=lfs diff=lfs merge=lfs -text
+*.dat filter=lfs diff=lfs merge=lfs -text
+*.data filter=lfs diff=lfs merge=lfs -text
diff --git a/DO.example.recommended b/DO.example
similarity index 61%
rename from DO.example.recommended
rename to DO.example
index 12491b3..5b09a29 100755
--- a/DO.example.recommended
+++ b/DO.example
@@ -2,14 +2,30 @@
 
 # Parameters ------------------------------------------------------
 
+# --- Sim-M dataset
 #TASK="sim-m"
 #DATA_DIR="data/simulated-dialogue/sim-M"
+#DATASET_CONFIG="dataset_config/sim-m.json"
+# --- Sim-R dataset
 #TASK="sim-r"
 #DATA_DIR="data/simulated-dialogue/sim-R"
+#DATASET_CONFIG="dataset_config/sim-r.json"
+# --- WOZ 2.0 dataset
 #TASK="woz2"
 #DATA_DIR="data/woz2"
+#DATASET_CONFIG="dataset_config/woz2.json"
+# --- MultiWOZ 2.1 legacy version dataset
+#TASK="multiwoz21_legacy"
+#DATA_DIR="data/MULTIWOZ2.1"
+#DATASET_CONFIG="dataset_config/multiwoz21.json"
+# --- MultiWOZ 2.1 dataset
 TASK="multiwoz21"
-DATA_DIR="data/MULTIWOZ2.1"
+DATA_DIR="data/multiwoz/data/MultiWOZ_2.1"
+DATASET_CONFIG="dataset_config/multiwoz21.json"
+# --- MultiWOZ 2.1 in ConvLab-3's unified data format
+#TASK="unified"
+#DATA_DIR=""
+#DATASET_CONFIG="dataset_config/unified_multiwoz21.json"
 
 # Project paths etc. ----------------------------------------------
 
@@ -29,34 +45,28 @@ for step in train dev test; do
     python3 run_dst.py \
 	    --task_name=${TASK} \
 	    --data_dir=${DATA_DIR} \
-	    --dataset_config=dataset_config/${TASK}.json \
+	    --dataset_config=${DATASET_CONFIG} \
 	    --model_type="roberta" \
 	    --model_name_or_path="roberta-base" \
 	    --do_lower_case \
 	    --learning_rate=1e-4 \
 	    --num_train_epochs=10 \
 	    --max_seq_length=180 \
-	    --per_gpu_train_batch_size=16 \
-	    --per_gpu_eval_batch_size=16 \
+	    --per_gpu_train_batch_size=32 \
+	    --per_gpu_eval_batch_size=32 \
 	    --output_dir=${OUT_DIR} \
 	    --save_epochs=2 \
-	    --logging_steps=10 \
 	    --warmup_proportion=0.1 \
 	    --eval_all_checkpoints \
 	    --adam_epsilon=1e-6 \
             --weight_decay=0.01 \
-	    --label_value_repetitions \
-            --swap_utterances \
-	    --append_history \
-	    --use_history_labels \
 	    ${args_add} \
 	    2>&1 | tee ${OUT_DIR}/${step}.log
     
     if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
-    	python3 metric_bert_dst.py \
-    		${TASK} \
-		dataset_config/${TASK}.json \
-    		"${OUT_DIR}/pred_res.${step}*json" \
+    	python3 metric_dst.py \
+		--dataset_config=${DATASET_CONFIG} \
+    		--file_list="${OUT_DIR}/pred_res.${step}*json" \
     		2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
     fi
 done
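The evaluation step above now calls the renamed `metric_dst.py` with named flags instead of positional arguments. Run standalone, the call looks like the following sketch; the `results` directory is illustrative and stands in for the script's `${OUT_DIR}`:

```bash
# Sketch: standalone evaluation with the renamed metric_dst.py (named flags).
# "results" is an illustrative output directory; substitute your ${OUT_DIR}.
python3 metric_dst.py \
    --dataset_config=dataset_config/multiwoz21.json \
    --file_list="results/pred_res.test*json" \
    2>&1 | tee results/eval_pred_test.log
```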
diff --git a/DO.example.mtl b/DO.example.mtl
index 7963eb6..54e436b 100755
--- a/DO.example.mtl
+++ b/DO.example.mtl
@@ -2,14 +2,30 @@
 
 # Parameters ------------------------------------------------------
 
+# --- Sim-M dataset
 #TASK="sim-m"
 #DATA_DIR="data/simulated-dialogue/sim-M"
+#DATASET_CONFIG="dataset_config/sim-m.json"
+# --- Sim-R dataset
 #TASK="sim-r"
 #DATA_DIR="data/simulated-dialogue/sim-R"
+#DATASET_CONFIG="dataset_config/sim-r.json"
+# --- WOZ 2.0 dataset
 #TASK="woz2"
 #DATA_DIR="data/woz2"
+#DATASET_CONFIG="dataset_config/woz2.json"
+# --- MultiWOZ 2.1 legacy version dataset
+#TASK="multiwoz21_legacy"
+#DATA_DIR="data/MULTIWOZ2.1"
+#DATASET_CONFIG="dataset_config/multiwoz21.json"
+# --- MultiWOZ 2.1 dataset
 TASK="multiwoz21"
-DATA_DIR="data/MULTIWOZ2.1"
+DATA_DIR="data/multiwoz/data/MultiWOZ_2.1"
+DATASET_CONFIG="dataset_config/multiwoz21.json"
+# --- MultiWOZ 2.1 in ConvLab-3's unified data format
+#TASK="unified"
+#DATA_DIR=""
+#DATASET_CONFIG="dataset_config/unified_multiwoz21.json"
 
 AUX_TASK="cola"
 AUX_DATA_DIR="data/aux/roberta_base_cased_lower"
@@ -24,7 +40,7 @@ mkdir -p ${OUT_DIR}
 for step in train dev test; do
     args_add=""
     if [ "$step" = "train" ]; then
-	args_add="--do_train --predict_type=dummy"
+	args_add="--do_train --predict_type=dummy" # INFO: For sim-M, we recommend adding "--svd=0.3"
     elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
 	args_add="--do_eval --predict_type=${step}"
     fi
@@ -32,7 +48,7 @@ for step in train dev test; do
     python3 run_dst_mtl.py \
 	    --task_name=${TASK} \
 	    --data_dir=${DATA_DIR} \
-	    --dataset_config=dataset_config/${TASK}.json \
+	    --dataset_config=${DATASET_CONFIG} \
 	    --model_type="roberta" \
 	    --model_name_or_path="roberta-base" \
 	    --do_lower_case \
@@ -43,16 +59,11 @@ for step in train dev test; do
 	    --per_gpu_eval_batch_size=1 \
 	    --output_dir=${OUT_DIR} \
 	    --save_epochs=2 \
-	    --logging_steps=10 \
 	    --warmup_proportion=0.1 \
 	    --eval_all_checkpoints \
 	    --adam_epsilon=1e-6 \
             --weight_decay=0.01 \
 	    --heads_dropout=0.1 \
-	    --label_value_repetitions \
-            --swap_utterances \
-	    --append_history \
-	    --use_history_labels \
 	    --delexicalize_sys_utts \
 	    --class_aux_feats_inform \
 	    --class_aux_feats_ds \
@@ -65,10 +76,9 @@ for step in train dev test; do
 	    2>&1 | tee ${OUT_DIR}/${step}.log
     
     if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
-    	python3 metric_bert_dst.py \
-    		${TASK} \
-		dataset_config/${TASK}.json \
-    		"${OUT_DIR}/pred_res.${step}*json" \
+    	python3 metric_dst.py \
+		--dataset_config=${DATASET_CONFIG} \
+    		--file_list="${OUT_DIR}/pred_res.${step}*json" \
     		2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
     fi
 done
diff --git a/DO.example.advanced b/DO.example.paper
similarity index 60%
rename from DO.example.advanced
rename to DO.example.paper
index 5c68bbd..a153a9c 100755
--- a/DO.example.advanced
+++ b/DO.example.paper
@@ -2,14 +2,30 @@
 
 # Parameters ------------------------------------------------------
 
+# --- Sim-M dataset
 #TASK="sim-m"
 #DATA_DIR="data/simulated-dialogue/sim-M"
+#DATASET_CONFIG="dataset_config/sim-m.json"
+# --- Sim-R dataset
 #TASK="sim-r"
 #DATA_DIR="data/simulated-dialogue/sim-R"
-TASK="woz2"
-DATA_DIR="data/woz2"
-#TASK="multiwoz21"
+#DATASET_CONFIG="dataset_config/sim-r.json"
+# --- WOZ 2.0 dataset
+#TASK="woz2"
+#DATA_DIR="data/woz2"
+#DATASET_CONFIG="dataset_config/woz2.json"
+# --- MultiWOZ 2.1 legacy version dataset
+#TASK="multiwoz21_legacy"
 #DATA_DIR="data/MULTIWOZ2.1"
+#DATASET_CONFIG="dataset_config/multiwoz21.json"
+# --- MultiWOZ 2.1 dataset
+TASK="multiwoz21"
+DATA_DIR="data/multiwoz/data/MultiWOZ_2.1"
+DATASET_CONFIG="dataset_config/multiwoz21.json"
+# --- MultiWOZ 2.1 in ConvLab-3's unified data format
+#TASK="unified"
+#DATA_DIR=""
+#DATASET_CONFIG="dataset_config/unified_multiwoz21.json"
 
 # Project paths etc. ----------------------------------------------
 
@@ -21,7 +37,7 @@ mkdir -p ${OUT_DIR}
 for step in train dev test; do
     args_add=""
     if [ "$step" = "train" ]; then
-	args_add="--do_train --predict_type=dummy"
+	args_add="--do_train --predict_type=dummy" # INFO: For sim-M, we recommend adding "--svd=0.3"
     elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
 	args_add="--do_eval --predict_type=${step}"
     fi
@@ -29,7 +45,7 @@ for step in train dev test; do
     python3 run_dst.py \
 	    --task_name=${TASK} \
 	    --data_dir=${DATA_DIR} \
-	    --dataset_config=dataset_config/${TASK}.json \
+	    --dataset_config=${DATASET_CONFIG} \
 	    --model_type="bert" \
 	    --model_name_or_path="bert-base-uncased" \
 	    --do_lower_case \
@@ -40,25 +56,19 @@ for step in train dev test; do
 	    --per_gpu_eval_batch_size=1 \
 	    --output_dir=${OUT_DIR} \
 	    --save_epochs=2 \
-	    --logging_steps=10 \
 	    --warmup_proportion=0.1 \
 	    --eval_all_checkpoints \
 	    --adam_epsilon=1e-6 \
-	    --label_value_repetitions \
-            --swap_utterances \
-	    --append_history \
-	    --use_history_labels \
 	    --delexicalize_sys_utts \
 	    --class_aux_feats_inform \
 	    --class_aux_feats_ds \
 	    ${args_add} \
 	    2>&1 | tee ${OUT_DIR}/${step}.log
-    
+
     if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
-    	python3 metric_bert_dst.py \
-    		${TASK} \
-		dataset_config/${TASK}.json \
-    		"${OUT_DIR}/pred_res.${step}*json" \
+    	python3 metric_dst.py \
+		--dataset_config=${DATASET_CONFIG} \
+    		--file_list="${OUT_DIR}/pred_res.${step}*json" \
     		2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
     fi
 done
diff --git a/DO.example.simple b/DO.example.simple
deleted file mode 100755
index 46585fd..0000000
--- a/DO.example.simple
+++ /dev/null
@@ -1,58 +0,0 @@
-#!/bin/bash
-
-# Parameters ------------------------------------------------------
-
-#TASK="sim-m"
-#DATA_DIR="data/simulated-dialogue/sim-M"
-#TASK="sim-r"
-#DATA_DIR="data/simulated-dialogue/sim-R"
-TASK="woz2"
-DATA_DIR="data/woz2"
-#TASK="multiwoz21"
-#DATA_DIR="data/MULTIWOZ2.1"
-
-# Project paths etc. ----------------------------------------------
-
-OUT_DIR=results
-mkdir -p ${OUT_DIR}
-
-# Main ------------------------------------------------------------
-
-for step in train dev test; do
-    args_add=""
-    if [ "$step" = "train" ]; then
-	args_add="--do_train --predict_type=dummy"
-    elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
-	args_add="--do_eval --predict_type=${step}"
-    fi
-
-    python3 run_dst.py \
-	    --task_name=${TASK} \
-	    --data_dir=${DATA_DIR} \
-	    --dataset_config=dataset_config/${TASK}.json \
-	    --model_type="bert" \
-	    --model_name_or_path="bert-base-uncased" \
-	    --do_lower_case \
-	    --learning_rate=1e-4 \
-	    --num_train_epochs=10 \
-	    --max_seq_length=180 \
-	    --per_gpu_train_batch_size=48 \
-	    --per_gpu_eval_batch_size=1 \
-	    --output_dir=${OUT_DIR} \
-	    --save_epochs=2 \
-	    --logging_steps=10 \
-	    --warmup_proportion=0.1 \
-	    --eval_all_checkpoints \
-	    --adam_epsilon=1e-6 \
-	    --label_value_repetitions \
-	    ${args_add} \
-	    2>&1 | tee ${OUT_DIR}/${step}.log
-
-    if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
-    	python3 metric_bert_dst.py \
-    		${TASK} \
-		dataset_config/${TASK}.json \
-    		"${OUT_DIR}/pred_res.${step}*json" \
-    		2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
-    fi
-done
diff --git a/README.md b/README.md
index b9ebed5..83e3c1d 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,23 @@
 ## Introduction
 
-TripPy is a new approach to dialogue state tracking (DST) which makes use of various copy mechanisms to fill slots with values. Our model has no need to maintain a list of candidate values. Instead, all values are extracted from the dialog context on-the-fly.
+TripPy is an approach to dialogue state tracking (DST) that makes use of various copy mechanisms to fill slots with values. Our model has no need to maintain a list of candidate values. Instead, all values are extracted from the dialog context on-the-fly.
 A slot is filled by one of three copy mechanisms:
 1. Span prediction may extract values directly from the user input;
 2. a value may be copied from a system inform memory that keeps track of the system’s inform operations;
 3. a value may be copied over from a different slot that is already contained in the dialog state to resolve coreferences within and across domains.
 Our approach combines the advantages of span-based slot filling methods with memory methods to avoid the use of value picklists altogether. We argue that our strategy simplifies the DST task while at the same time achieving state of the art performance on various popular evaluation sets including MultiWOZ 2.1.
 
+## Recent updates
+
+- 2022.12.19: Added support for ConvLab-3's unified data format. Added faster caching. Added transformers 4 support.
+
+- 2022.02.15: Added support for MultiWOZ versions 2.2, 2.3, 2.4
+
 ## How to run
 
-Two example scripts are provided for how to use TripPy. `DO.example.simple` will train and evaluate a simpler model, whereas `DO.example.advanced` uses the parameters that will result in performance similar to the reported ones. `DO.example.recommended` uses RoBERTa as encoder and the currently recommended set of hyperparameters. For more challenging datasets with longer dialogues, better performance may be achieved by using the maximum sequence length of 512.
+Two example scripts demonstrate how to use TripPy. `DO.example` will train and evaluate a model with the recommended settings; see the list below for the expected performance per dataset. `DO.example.paper` uses the parameters that were used for the experiments in our paper "TripPy: A Triple Copy Strategy for Value Independent Neural Dialog State Tracking", so performance will be similar to the reported numbers. For more challenging datasets with longer dialogues, better performance may be achieved by using the maximum sequence length of 512.
 
-`DO.example.mtl` will train a model with multi-task learning (MTL) using an auxiliary task (See our paper "Out-of-Task Training for Dialog State Tracking Models" for details).
+`DO.example.mtl` will train a model with multi-task learning (MTL) on an auxiliary task, with the parameters that we used in our paper "Out-of-Task Training for Dialog State Tracking Models".
 
 ## Datasets
 
@@ -20,13 +26,25 @@ Supported datasets are:
 - sim-R (https://github.com/google-research-datasets/simulated-dialogue.git)
 - WOZ 2.0 (see data/)
 - MultiWOZ 2.0 (https://github.com/budzianowski/multiwoz.git)
-- MultiWOZ 2.1 (see data/, https://github.com/budzianowski/multiwoz.git)
+- MultiWOZ 2.1 (https://github.com/budzianowski/multiwoz.git)
+- MultiWOZ 2.1 legacy version (see data/)
 - MultiWOZ 2.2 (https://github.com/budzianowski/multiwoz.git)
 - MultiWOZ 2.3 (https://github.com/lexmen318/MultiWOZ-coref.git)
 - MultiWOZ 2.4 (https://github.com/smartyfh/MultiWOZ2.4.git)
+- ConvLab-3's unified data format (currently supported: MultiWOZ 2.1; see https://github.com/ConvLab/ConvLab-3)
+
+See the README file in `data/` for more details on how to obtain and prepare the datasets for use in TripPy.
+
+The corresponding ```--task_name``` values are:
+- 'sim-m', for sim-M
+- 'sim-r', for sim-R
+- 'woz2', for WOZ 2.0
+- 'multiwoz21', for MultiWOZ 2.0-2.4
+- 'multiwoz21_legacy', for MultiWOZ 2.1 legacy version
 
 With a sequence length of 180, you should expect the following average JGA:
 - 53% for MultiWOZ 2.0
+- 56% for MultiWOZ 2.1 legacy version
 - 56% for MultiWOZ 2.1
 - 56% for MultiWOZ 2.2
 - 63% for MultiWOZ 2.3
@@ -35,11 +53,17 @@ With a sequence length of 180, you should expect the following average JGA:
 - 90% for sim-R
 - 92% for WOZ 2.0
 
+## ConvLab-3
+
+TripPy is integrated into ConvLab-3 as a ready-to-use dialogue state tracker. A checkpoint is available on HuggingFace (see the ConvLab-3 repo for more details).
+
+If you want to train your own TripPy model for ConvLab-3 from scratch, you can do so with this code by setting ```--task_name='unified'```. The ```--data_dir``` parameter will be ignored in that case. Pick the file for ```--dataset_config``` according to the dataset you want to train for. For MultiWOZ, this would be 'dataset_config/unified_multiwoz21.json'.
+
 ## Requirements
 
-- torch (tested: 1.4.0)
-- transformers (tested: 2.9.1)
-- tensorboardX (tested: 2.0)
+- torch (tested: 1.8.0)
+- transformers (tested: 4.18.0)
+- tensorboardX (tested: 2.1)
 
 ## Citation
 
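The ConvLab-3 section added to the README above describes training on the unified data format by setting ```--task_name='unified'```. A minimal configuration sketch, mirroring the commented-out block in `DO.example` (the empty `DATA_DIR` reflects that `--data_dir` is ignored for this task):

```bash
# In DO.example, select ConvLab-3's unified data format by setting
# (as in the commented-out block in the script; --data_dir is ignored):
TASK="unified"
DATA_DIR=""
DATASET_CONFIG="dataset_config/unified_multiwoz21.json"
# then run the script as usual:
./DO.example
```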
diff --git a/data/MULTIWOZ2.1/README.md b/data/MULTIWOZ2.1_legacy/README.md
similarity index 100%
rename from data/MULTIWOZ2.1/README.md
rename to data/MULTIWOZ2.1_legacy/README.md
diff --git a/data/MULTIWOZ2.1/dialogue_acts.json.gz b/data/MULTIWOZ2.1_legacy/dialogue_acts.json.gz
similarity index 100%
rename from data/MULTIWOZ2.1/dialogue_acts.json.gz
rename to data/MULTIWOZ2.1_legacy/dialogue_acts.json.gz
diff --git a/data/MULTIWOZ2.1/test_dials.json.gz b/data/MULTIWOZ2.1_legacy/test_dials.json.gz
similarity index 100%
rename from data/MULTIWOZ2.1/test_dials.json.gz
rename to data/MULTIWOZ2.1_legacy/test_dials.json.gz
diff --git a/data/MULTIWOZ2.1/train_dials.json.gz b/data/MULTIWOZ2.1_legacy/train_dials.json.gz
similarity index 100%
rename from data/MULTIWOZ2.1/train_dials.json.gz
rename to data/MULTIWOZ2.1_legacy/train_dials.json.gz
diff --git a/data/MULTIWOZ2.1/val_dials.json.gz b/data/MULTIWOZ2.1_legacy/val_dials.json.gz
similarity index 100%
rename from data/MULTIWOZ2.1/val_dials.json.gz
rename to data/MULTIWOZ2.1_legacy/val_dials.json.gz
diff --git a/data/README.md b/data/README.md
new file mode 100644
index 0000000..1bdf345
--- /dev/null
+++ b/data/README.md
@@ -0,0 +1,50 @@
+## Supported datasets
+
+Datasets should go into the ```data/``` folder.
+
+### sim-M & sim-R:
+
+```
+git clone https://github.com/google-research-datasets/simulated-dialogue.git
+```
+
+### WOZ 2.0
+
+The original URL (http://mi.eng.cam.ac.uk/~nm480/woz_2.0.zip) is no longer active.
+
+We provide the dataset in ```data/woz2```.
+
+### MultiWOZ 2.0, 2.1 & 2.2
+
+```
+git clone https://github.com/budzianowski/multiwoz.git
+unzip multiwoz/data/MultiWOZ_2.0.zip -d multiwoz/data/
+unzip multiwoz/data/MultiWOZ_2.1.zip -d multiwoz/data/
+mv multiwoz/data/MULTIWOZ2\ 2/ multiwoz/data/MultiWOZ_2.0
+python3 multiwoz/data/MultiWOZ_2.2/convert_to_multiwoz_format.py --multiwoz21_data_dir=multiwoz/data/MultiWOZ_2.1 --output_file=multiwoz/data/MultiWOZ_2.2/data.json
+cp multiwoz/data/MultiWOZ_2.1/valListFile.txt multiwoz/data/MultiWOZ_2.2/
+cp multiwoz/data/MultiWOZ_2.1/testListFile.txt multiwoz/data/MultiWOZ_2.2/
+python split_multiwoz_data.py --data_dir multiwoz/data/MultiWOZ_2.0
+python split_multiwoz_data.py --data_dir multiwoz/data/MultiWOZ_2.1
+python split_multiwoz_data.py --data_dir multiwoz/data/MultiWOZ_2.2
+```
+
+### MultiWOZ 2.1 legacy version
+
+With "legacy version" we refer to the mid 2019 version of MultiWOZ 2.1, which can be found at https://doi.org/10.17863/CAM.41572
+
+We used this version when we built TripPy. We provide the exact data that we used in ```data/MULTIWOZ2.1_legacy```.
+
+The dataset has since been updated, and the most recent version of MultiWOZ 2.1 differs slightly from the version we used for the experiments reported in [TripPy: A Triple Copy Strategy for Value Independent Neural Dialog State Tracking](https://www.aclweb.org/anthology/2020.sigdial-1.4/). Our code supports both the new and the legacy version of MultiWOZ 2.1.
+
+### MultiWOZ 2.3
+
+```
+git clone https://github.com/lexmen318/MultiWOZ-coref.git
+```
+
+### MultiWOZ 2.4
+
+```
+git clone https://github.com/smartyfh/MultiWOZ2.4.git
+```
diff --git a/data/split_multiwoz_data.py b/data/split_multiwoz_data.py
new file mode 100644
index 0000000..770da79
--- /dev/null
+++ b/data/split_multiwoz_data.py
@@ -0,0 +1,65 @@
+# coding=utf-8
+#
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import os
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--data_dir", default=None, type=str, required=True, help="Path to a MultiWOZ data directory containing data.json and the val/test list files.")
+    args = parser.parse_args()
+
+    with open(os.path.join(args.data_dir, "data.json")) as f:
+        data = json.load(f)
+
+    val_list_file = os.path.join(args.data_dir, "valListFile.json")
+    if not os.path.isfile(val_list_file):
+        val_list_file = os.path.join(args.data_dir, "valListFile.txt")
+    with open(val_list_file) as f:
+        val_set = f.read().splitlines()
+
+    test_list_file = os.path.join(args.data_dir, "testListFile.json")
+    if not os.path.isfile(test_list_file):
+        test_list_file = os.path.join(args.data_dir, "testListFile.txt")
+    with open(test_list_file) as f:
+        test_set = f.read().splitlines()
+
+    val = {}
+    train = {}
+    test = {}
+
+    for k, v in data.items():
+        if k in val_set:
+            val[k] = v
+        elif k in test_set:
+            test[k] = v
+        else:
+            train[k] = v
+
+    print(len(data), len(train), len(val), len(test))
+
+    with open(os.path.join(args.data_dir, "train_dials.json"), "w+") as f:
+        f.write(json.dumps(train, indent = 4))
+
+    with open(os.path.join(args.data_dir, "val_dials.json"), "w+") as f:
+        f.write(json.dumps(val, indent = 4))
+
+    with open(os.path.join(args.data_dir, "test_dials.json"), "w+") as f:
+        f.write(json.dumps(test, indent = 4))
+
+if __name__ == "__main__":
+    main()
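The splitting helper added above carries no usage note of its own; it is invoked by the commands in `data/README.md` above. As a quick sketch (paths as in that README, run from within `data/`):

```bash
# Sketch: regenerate train/val/test splits from a MultiWOZ data.json
# using the val/test list files next to it (paths as in data/README.md).
python split_multiwoz_data.py --data_dir multiwoz/data/MultiWOZ_2.1
# writes train_dials.json, val_dials.json and test_dials.json into the data dir
```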
diff --git a/data_processors.py b/data_processors.py
index 88d94ae..671bd9a 100644
--- a/data_processors.py
+++ b/data_processors.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 #
-# Copyright 2020 Heinrich Heine University Duesseldorf
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
 #
 # Part of this code is based on the source code of BERT-DST
 # (arXiv:1907.03040)
@@ -23,24 +23,39 @@ import json
 import dataset_woz2
 import dataset_sim
 import dataset_multiwoz21
+import dataset_multiwoz21_legacy
 import dataset_aux_task
+import dataset_unified
 
 
 class DataProcessor(object):
+    dataset_name = ""
+    class_types = []
+    slot_list = []
+    label_maps = {}
+
     def __init__(self, dataset_config):
+        # Load dataset config file.
         with open(dataset_config, "r", encoding='utf-8') as f:
             raw_config = json.load(f)
-        self.class_types = raw_config['class_types']
-        self.slot_list = raw_config['slots']
-        self.label_maps = raw_config['label_maps']
+        self.class_types = raw_config['class_types'] # Required
+        self.slot_list = raw_config['slots'] if 'slots' in raw_config else []
+        self.label_maps = raw_config['label_maps'] if 'label_maps' in raw_config else {}
+        self.dataset_name = raw_config['dataset_name'] if 'dataset_name' in raw_config else ""
+        # If no slot list is provided, generate it from the data.
+        if len(self.slot_list) == 0:
+            self.slot_list = self._get_slot_list()
+
+    def _get_slot_list(self):
+        raise NotImplementedError()
 
-    def get_train_examples(self, data_dir, **args):
+    def get_train_examples(self):
         raise NotImplementedError()
 
-    def get_dev_examples(self, data_dir, **args):
+    def get_dev_examples(self):
         raise NotImplementedError()
 
-    def get_test_examples(self, data_dir, **args):
+    def get_test_examples(self):
         raise NotImplementedError()
 
 
@@ -61,16 +76,30 @@ class Woz2Processor(DataProcessor):
 class Multiwoz21Processor(DataProcessor):
     def get_train_examples(self, data_dir, args):
         return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'train_dials.json'),
+                                                  'train', self.class_types, self.slot_list, self.label_maps, **args)
+
+    def get_dev_examples(self, data_dir, args):
+        return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'val_dials.json'),
+                                                  'dev', self.class_types, self.slot_list, self.label_maps, **args)
+
+    def get_test_examples(self, data_dir, args):
+        return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'test_dials.json'),
+                                                  'test', self.class_types, self.slot_list, self.label_maps, **args)
+
+
+class Multiwoz21LegacyProcessor(DataProcessor):
+    def get_train_examples(self, data_dir, args):
+        return dataset_multiwoz21_legacy.create_examples(os.path.join(data_dir, 'train_dials.json'),
                                                   os.path.join(data_dir, 'dialogue_acts.json'),
                                                   'train', self.slot_list, self.label_maps, **args)
 
     def get_dev_examples(self, data_dir, args):
-        return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'val_dials.json'),
+        return dataset_multiwoz21_legacy.create_examples(os.path.join(data_dir, 'val_dials.json'),
                                                   os.path.join(data_dir, 'dialogue_acts.json'),
                                                   'dev', self.slot_list, self.label_maps, **args)
 
     def get_test_examples(self, data_dir, args):
-        return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'test_dials.json'),
+        return dataset_multiwoz21_legacy.create_examples(os.path.join(data_dir, 'test_dials.json'),
                                                   os.path.join(data_dir, 'dialogue_acts.json'),
                                                   'test', self.slot_list, self.label_maps, **args)
 
@@ -89,6 +118,23 @@ class SimProcessor(DataProcessor):
                                            'test', self.slot_list, **args)
 
 
+class UnifiedDatasetProcessor(DataProcessor):
+    def _get_slot_list(self):
+        return dataset_unified.get_slot_list(self.dataset_name)
+        
+    def get_train_examples(self, data_dir, args):
+        return dataset_unified.create_examples('train', self.dataset_name, self.class_types,
+                                               self.slot_list, self.label_maps, **args)
+
+    def get_dev_examples(self, data_dir, args):
+        return dataset_unified.create_examples('validation', self.dataset_name, self.class_types,
+                                               self.slot_list, self.label_maps, **args)
+
+    def get_test_examples(self, data_dir, args):
+        return dataset_unified.create_examples('test', self.dataset_name, self.class_types,
+                                               self.slot_list, self.label_maps, **args)
+
+
 class AuxTaskProcessor(object):
     def get_aux_task_examples(self, data_dir, data_name, max_seq_length):
         file_path = os.path.join(data_dir, '{}_train.json'.format(data_name))
@@ -99,4 +145,6 @@ PROCESSORS = {"woz2": Woz2Processor,
               "sim-m": SimProcessor,
               "sim-r": SimProcessor,
               "multiwoz21": Multiwoz21Processor,
+              "multiwoz21_legacy": Multiwoz21LegacyProcessor,
+              "unified": UnifiedDatasetProcessor,
               "aux_task": AuxTaskProcessor}
diff --git a/dataset_aux_task.py b/dataset_aux_task.py
index 40a8667..78ba3a9 100644
--- a/dataset_aux_task.py
+++ b/dataset_aux_task.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 #
-# Copyright 2020 Heinrich Heine University Duesseldorf
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/dataset_config/unified_multiwoz21.json b/dataset_config/unified_multiwoz21.json
new file mode 100644
index 0000000..103367a
--- /dev/null
+++ b/dataset_config/unified_multiwoz21.json
@@ -0,0 +1,1282 @@
+{
+  "dataset_name": "multiwoz21",
+  "class_types": [
+    "none",
+    "dontcare",
+    "copy_value",
+    "true",
+    "false",
+    "refer",
+    "inform",
+    "request"
+  ],
+  "slots": [],
+  "label_maps": {
+    "guest house": [
+      "guest houses"
+    ],
+    "hotel": [
+      "hotels"
+    ],
+    "centre": [
+      "center",
+      "downtown"
+    ],
+    "north": [
+      "northern",
+      "northside",
+      "northend"
+    ],
+    "east": [
+      "eastern",
+      "eastside",
+      "eastend"
+    ],
+    "west": [
+      "western",
+      "westside",
+      "westend"
+    ],
+    "south": [
+      "southern",
+      "southside",
+      "southend"
+    ],
+    "cheap": [
+      "inexpensive",
+      "lower price",
+      "lower range",
+      "cheaply",
+      "cheaper",
+      "cheapest",
+      "very affordable"
+    ],
+    "moderate": [
+      "moderately",
+      "reasonable",
+      "reasonably",
+      "affordable",
+      "mid range",
+      "mid-range",
+      "priced moderately",
+      "decently priced",
+      "mid price",
+      "mid-price",
+      "mid priced",
+      "mid-priced",
+      "middle price",
+      "medium price",
+      "medium priced",
+      "not too expensive",
+      "not too cheap"
+    ],
+    "expensive": [
+      "high end",
+      "high-end",
+      "high class",
+      "high-class",
+      "high scale",
+      "high-scale",
+      "high price",
+      "high priced",
+      "higher price",
+      "fancy",
+      "upscale",
+      "nice",
+      "expensively",
+      "luxury"
+    ],
+    "0": [
+      "zero"
+    ],
+    "1": [
+      "one",
+      "just me",
+      "for me",
+      "myself",
+      "alone",
+      "me"
+    ],
+    "2": [
+      "two"
+    ],
+    "3": [
+      "three"
+    ],
+    "4": [
+      "four"
+    ],
+    "5": [
+      "five"
+    ],
+    "6": [
+      "six"
+    ],
+    "7": [
+      "seven"
+    ],
+    "8": [
+      "eight"
+    ],
+    "9": [
+      "nine"
+    ],
+    "10": [
+      "ten"
+    ],
+    "11": [
+      "eleven"
+    ],
+    "12": [
+      "twelve"
+    ],
+    "architecture": [
+      "architectures",
+      "architectural",
+      "architecturally",
+      "architect"
+    ],
+    "boat": [
+      "boating",
+      "boats",
+      "camboats"
+    ],
+    "boating": [
+      "boat",
+      "boats",
+      "camboats"
+    ],
+    "camboats": [
+      "boating",
+      "boat",
+      "boats"
+    ],
+    "cinema": [
+      "cinemas",
+      "movie",
+      "films",
+      "film"
+    ],
+    "college": [
+      "colleges"
+    ],
+    "concert": [
+      "concert hall",
+      "concert halls",
+      "concerthall",
+      "concerthalls",
+      "concerts"
+    ],
+    "concerthall": [
+      "concert hall",
+      "concert halls",
+      "concerthalls",
+      "concerts",
+      "concert"
+    ],
+    "entertainment": [
+      "entertaining"
+    ],
+    "gallery": [
+      "museum",
+      "galleries"
+    ],
+    "gastropubs": [
+      "gastropub"
+    ],
+    "multiple sports": [
+      "multiple sport",
+      "multi sport",
+      "multi sports",
+      "sports",
+      "sporting"
+    ],
+    "museum": [
+      "museums",
+      "gallery",
+      "galleries"
+    ],
+    "night club": [
+      "night clubs",
+      "nightclub",
+      "nightclubs",
+      "club",
+      "clubs"
+    ],
+    "nightclub": [
+      "night club",
+      "night clubs",
+      "nightclubs",
+      "club",
+      "clubs"
+    ],
+    "park": [
+      "parks"
+    ],
+    "pool": [
+      "swimming pool",
+      "swimming pools",
+      "swimming",
+      "pools",
+      "swimmingpool",
+      "swimmingpools",
+      "water",
+      "swim"
+    ],
+    "sports": [
+      "multiple sport",
+      "multi sport",
+      "multi sports",
+      "multiple sports",
+      "sporting"
+    ],
+    "swimming pool": [
+      "swimming",
+      "pool",
+      "pools",
+      "swimmingpool",
+      "swimmingpools",
+      "water",
+      "swim"
+    ],
+    "theater": [
+      "theatre",
+      "theatres",
+      "theaters"
+    ],
+    "theatre": [
+      "theater",
+      "theatres",
+      "theaters"
+    ],
+    "abbey pool and astroturf pitch": [
+      "abbey pool and astroturf",
+      "abbey pool"
+    ],
+    "abbey pool and astroturf": [
+      "abbey pool and astroturf pitch",
+      "abbey pool"
+    ],
+    "abbey pool": [
+      "abbey pool and astroturf pitch",
+      "abbey pool and astroturf"
+    ],
+    "adc theatre": [
+      "adc theater",
+      "adc"
+    ],
+    "adc": [
+      "adc theatre",
+      "adc theater"
+    ],
+    "addenbrookes hospital": [
+      "addenbrooke's hospital"
+    ],
+    "cafe jello gallery": [
+      "cafe jello"
+    ],
+    "cambridge and county folk museum": [
+      "cambridge and country folk museum",
+      "county folk museum"
+    ],
+    "cambridge and country folk museum": [
+      "cambridge and county folk museum",
+      "county folk museum"
+    ],
+    "county folk museum": [
+      "cambridge and county folk museum",
+      "cambridge and country folk museum"
+    ],
+    "cambridge arts theatre": [
+      "cambridge arts theater"
+    ],
+    "cambridge book and print gallery": [
+      "book and print gallery"
+    ],
+    "cambridge contemporary art": [
+      "cambridge contemporary art museum",
+      "contemporary art museum"
+    ],
+    "cambridge contemporary art museum": [
+      "cambridge contemporary art",
+      "contemporary art museum"
+    ],
+    "cambridge corn exchange": [
+      "the cambridge corn exchange"
+    ],
+    "the cambridge corn exchange": [
+      "cambridge corn exchange"
+    ],
+    "cambridge museum of technology": [
+      "museum of technology"
+    ],
+    "cambridge punter": [
+      "the cambridge punter",
+      "cambridge punters"
+    ],
+    "cambridge punters": [
+      "the cambridge punter",
+      "cambridge punter"
+    ],
+    "the cambridge punter": [
+      "cambridge punter",
+      "cambridge punters"
+    ],
+    "cambridge university botanic gardens": [
+      "cambridge university botanical gardens",
+      "cambridge university botanical garden",
+      "cambridge university botanic garden",
+      "cambridge botanic gardens",
+      "cambridge botanical gardens",
+      "cambridge botanic garden",
+      "cambridge botanical garden",
+      "botanic gardens",
+      "botanical gardens",
+      "botanic garden",
+      "botanical garden"
+    ],
+    "cambridge botanic gardens": [
+      "cambridge university botanic gardens",
+      "cambridge university botanical gardens",
+      "cambridge university botanical garden",
+      "cambridge university botanic garden",
+      "cambridge botanical gardens",
+      "cambridge botanic garden",
+      "cambridge botanical garden",
+      "botanic gardens",
+      "botanical gardens",
+      "botanic garden",
+      "botanical garden"
+    ],
+    "botanic gardens": [
+      "cambridge university botanic gardens",
+      "cambridge university botanical gardens",
+      "cambridge university botanical garden",
+      "cambridge university botanic garden",
+      "cambridge botanic gardens",
+      "cambridge botanical gardens",
+      "cambridge botanic garden",
+      "cambridge botanical garden",
+      "botanical gardens",
+      "botanic garden",
+      "botanical garden"
+    ],
+    "cherry hinton village centre": [
+      "cherry hinton village center"
+    ],
+    "cherry hinton village center": [
+      "cherry hinton village centre"
+    ],
+    "cherry hinton hall and grounds": [
+      "cherry hinton hall"
+    ],
+    "cherry hinton hall": [
+      "cherry hinton hall and grounds"
+    ],
+    "cherry hinton water play": [
+      "cherry hinton water play park"
+    ],
+    "cherry hinton water play park": [
+      "cherry hinton water play"
+    ],
+    "christ college": [
+      "christ's college",
+      "christs college"
+    ],
+    "christs college": [
+      "christ college",
+      "christ's college"
+    ],
+    "churchills college": [
+      "churchill's college",
+      "churchill college"
+    ],
+    "cineworld cinema": [
+      "cineworld"
+    ],
+    "clair hall": [
+      "clare hall"
+    ],
+    "clare hall": [
+      "clair hall"
+    ],
+    "the fez club": [
+      "fez club"
+    ],
+    "great saint marys church": [
+      "great saint mary's church",
+      "great saint mary's",
+      "great saint marys"
+    ],
+    "jesus green outdoor pool": [
+      "jesus green"
+    ],
+    "jesus green": [
+      "jesus green outdoor pool"
+    ],
+    "kettles yard": [
+      "kettle's yard"
+    ],
+    "kings college": [
+      "king's college"
+    ],
+    "kings hedges learner pool": [
+      "king's hedges learner pool",
+      "king hedges learner pool"
+    ],
+    "king hedges learner pool": [
+      "king's hedges learner pool",
+      "kings hedges learner pool"
+    ],
+    "little saint marys church": [
+      "little saint mary's church",
+      "little saint mary's",
+      "little saint marys"
+    ],
+    "mumford theatre": [
+      "mumford theater"
+    ],
+    "museum of archaelogy": [
+      "museum of archaeology"
+    ],
+    "museum of archaelogy and anthropology": [
+      "museum of archaeology and anthropology"
+    ],
+    "peoples portraits exhibition": [
+      "people's portraits exhibition at girton college",
+      "peoples portraits exhibition at girton college",
+      "people's portraits exhibition"
+    ],
+    "peoples portraits exhibition at girton college": [
+      "people's portraits exhibition at girton college",
+      "people's portraits exhibition",
+      "peoples portraits exhibition"
+    ],
+    "queens college": [
+      "queens' college",
+      "queen's college"
+    ],
+    "riverboat georgina": [
+      "riverboat"
+    ],
+    "saint barnabas": [
+      "saint barbabas press gallery"
+    ],
+    "saint barnabas press gallery": [
+      "saint barbabas"
+    ],
+    "saint catharines college": [
+      "saint catharine's college",
+      "saint catharine's",
+      "saint catherine's college",
+      "saint catherine's"
+    ],
+    "saint johns college": [
+      "saint john's college",
+      "st john's college",
+      "st johns college"
+    ],
+    "scott polar": [
+      "scott polar museum"
+    ],
+    "scott polar museum": [
+      "scott polar"
+    ],
+    "scudamores punting co": [
+      "scudamore's punting co",
+      "scudamores punting",
+      "scudamore's punting",
+      "scudamores",
+      "scudamore's",
+      "scudamore"
+    ],
+    "scudamore": [
+      "scudamore's punting co",
+      "scudamores punting co",
+      "scudamores punting",
+      "scudamore's punting",
+      "scudamores",
+      "scudamore's"
+    ],
+    "sheeps green and lammas land park fen causeway": [
+      "sheep's green and lammas land park fen causeway",
+      "sheep's green and lammas land park",
+      "sheeps green and lammas land park",
+      "lammas land park",
+      "sheep's green",
+      "sheeps green"
+    ],
+    "sheeps green and lammas land park": [
+      "sheep's green and lammas land park fen causeway",
+      "sheeps green and lammas land park fen causeway",
+      "sheep's green and lammas land park",
+      "lammas land park",
+      "sheep's green",
+      "sheeps green"
+    ],
+    "lammas land park": [
+      "sheep's green and lammas land park fen causeway",
+      "sheeps green and lammas land park fen causeway",
+      "sheep's green and lammas land park",
+      "sheeps green and lammas land park",
+      "sheep's green",
+      "sheeps green"
+    ],
+    "sheeps green": [
+      "sheep's green and lammas land park fen causeway",
+      "sheeps green and lammas land park fen causeway",
+      "sheep's green and lammas land park",
+      "sheeps green and lammas land park",
+      "lammas land park",
+      "sheep's green"
+    ],
+    "soul tree nightclub": [
+      "soul tree night club",
+      "soul tree",
+      "soultree"
+    ],
+    "soultree": [
+      "soul tree nightclub",
+      "soul tree night club",
+      "soul tree"
+    ],
+    "the man on the moon": [
+      "man on the moon"
+    ],
+    "man on the moon": [
+      "the man on the moon"
+    ],
+    "the junction": [
+      "junction theatre",
+      "junction theater"
+    ],
+    "junction theatre": [
+      "the junction",
+      "junction theater"
+    ],
+    "old schools": [
+      "old school"
+    ],
+    "vue cinema": [
+      "vue"
+    ],
+    "wandlebury country park": [
+      "the wandlebury"
+    ],
+    "the wandlebury": [
+      "wandlebury country park"
+    ],
+    "whipple museum of the history of science": [
+      "whipple museum",
+      "history of science museum"
+    ],
+    "history of science museum": [
+      "whipple museum of the history of science",
+      "whipple museum"
+    ],
+    "williams art and antique": [
+      "william's art and antique"
+    ],
+    "alimentum": [
+      "restaurant alimentum"
+    ],
+    "restaurant alimentum": [
+      "alimentum"
+    ],
+    "bedouin": [
+      "the bedouin"
+    ],
+    "the bedouin": [
+      "bedouin"
+    ],
+    "bloomsbury restaurant": [
+      "bloomsbury"
+    ],
+    "cafe uno": [
+      "caffe uno",
+      "caffee uno"
+    ],
+    "caffe uno": [
+      "cafe uno",
+      "caffee uno"
+    ],
+    "caffee uno": [
+      "cafe uno",
+      "caffe uno"
+    ],
+    "cambridge lodge restaurant": [
+      "cambridge lodge"
+    ],
+    "chiquito": [
+      "chiquito restaurant bar",
+      "chiquito restaurant"
+    ],
+    "chiquito restaurant bar": [
+      "chiquito restaurant",
+      "chiquito"
+    ],
+    "city stop restaurant": [
+      "city stop"
+    ],
+    "cityr": [
+      "cityroomz"
+    ],
+    "citiroomz": [
+      "cityroomz"
+    ],
+    "clowns cafe": [
+      "clown's cafe"
+    ],
+    "cow pizza kitchen and bar": [
+      "the cow pizza kitchen and bar",
+      "cow pizza"
+    ],
+    "the cow pizza kitchen and bar": [
+      "cow pizza kitchen and bar",
+      "cow pizza"
+    ],
+    "darrys cookhouse and wine shop": [
+      "darry's cookhouse and wine shop",
+      "darry's cookhouse",
+      "darrys cookhouse"
+    ],
+    "de luca cucina and bar": [
+      "de luca cucina and bar riverside brasserie",
+      "luca cucina and bar",
+      "de luca cucina",
+      "luca cucina"
+    ],
+    "de luca cucina and bar riverside brasserie": [
+      "de luca cucina and bar",
+      "luca cucina and bar",
+      "de luca cucina",
+      "luca cucina"
+    ],
+    "da vinci pizzeria": [
+      "da vinci pizza",
+      "da vinci"
+    ],
+    "don pasquale pizzeria": [
+      "don pasquale pizza",
+      "don pasquale",
+      "pasquale pizzeria",
+      "pasquale pizza"
+    ],
+    "efes": [
+      "efes restaurant"
+    ],
+    "efes restaurant": [
+      "efes"
+    ],
+    "fitzbillies restaurant": [
+      "fitzbillies"
+    ],
+    "frankie and bennys": [
+      "frankie and benny's"
+    ],
+    "funky": [
+      "funky fun house"
+    ],
+    "funky fun house": [
+      "funky"
+    ],
+    "gardenia": [
+      "the gardenia"
+    ],
+    "the gardenia": [
+      "gardenia"
+    ],
+    "grafton hotel restaurant": [
+      "the grafton hotel",
+      "grafton hotel"
+    ],
+    "the grafton hotel": [
+      "grafton hotel restaurant",
+      "grafton hotel"
+    ],
+    "grafton hotel": [
+      "grafton hotel restaurant",
+      "the grafton hotel"
+    ],
+    "hotel du vin and bistro": [
+      "hotel du vin",
+      "du vin"
+    ],
+    "Kohinoor": [
+      "kohinoor",
+      "the kohinoor"
+    ],
+    "kohinoor": [
+      "the kohinoor"
+    ],
+    "the kohinoor": [
+      "kohinoor"
+    ],
+    "lan hong house": [
+      "lan hong",
+      "ian hong house",
+      "ian hong"
+    ],
+    "ian hong": [
+      "lan hong house",
+      "lan hong",
+      "ian hong house"
+    ],
+    "lovel": [
+      "the lovell lodge",
+      "lovell lodge"
+    ],
+    "lovell lodge": [
+      "lovell"
+    ],
+    "the lovell lodge": [
+      "lovell lodge",
+      "lovell"
+    ],
+    "mahal of cambridge": [
+      "mahal"
+    ],
+    "mahal": [
+      "mahal of cambridge"
+    ],
+    "maharajah tandoori restaurant": [
+      "maharajah tandoori"
+    ],
+    "the maharajah tandoor": [
+      "maharajah tandoori restaurant",
+      "maharajah tandoori"
+    ],
+    "meze bar": [
+      "meze bar restaurant",
+      "the meze bar"
+    ],
+    "meze bar restaurant": [
+      "the meze bar",
+      "meze bar"
+    ],
+    "the meze bar": [
+      "meze bar restaurant",
+      "meze bar"
+    ],
+    "michaelhouse cafe": [
+      "michael house cafe"
+    ],
+    "midsummer house restaurant": [
+      "midsummer house"
+    ],
+    "missing sock": [
+      "the missing sock"
+    ],
+    "the missing sock": [
+      "missing sock"
+    ],
+    "nandos": [
+      "nando's city centre",
+      "nando's city center",
+      "nandos city centre",
+      "nandos city center",
+      "nando's"
+    ],
+    "nandos city centre": [
+      "nando's city centre",
+      "nando's city center",
+      "nandos city center",
+      "nando's",
+      "nandos"
+    ],
+    "oak bistro": [
+      "the oak bistro"
+    ],
+    "the oak bistro": [
+      "oak bistro"
+    ],
+    "one seven": [
+      "restaurant one seven"
+    ],
+    "restaurant one seven": [
+      "one seven"
+    ],
+    "river bar steakhouse and grill": [
+      "the river bar steakhouse and grill",
+      "the river bar steakhouse",
+      "river bar steakhouse"
+    ],
+    "the river bar steakhouse and grill": [
+      "river bar steakhouse and grill",
+      "the river bar steakhouse",
+      "river bar steakhouse"
+    ],
+    "pipasha restaurant": [
+      "pipasha"
+    ],
+    "pizza hut city centre": [
+      "pizza hut city center"
+    ],
+    "pizza hut fenditton": [
+      "pizza hut fen ditton",
+      "pizza express fen ditton"
+    ],
+    "restaurant two two": [
+      "two two",
+      "restaurant 22"
+    ],
+    "saffron brasserie": [
+      "saffron"
+    ],
+    "saint johns chop house": [
+      "saint john's chop house",
+      "st john's chop house",
+      "st johns chop house"
+    ],
+    "sesame restaurant and bar": [
+      "sesame restaurant",
+      "sesame"
+    ],
+    "shanghai": [
+      "shanghai family restaurant"
+    ],
+    "shanghai family restaurant": [
+      "shanghai"
+    ],
+    "sitar": [
+      "sitar tandoori"
+    ],
+    "sitar tandoori": [
+      "sitar"
+    ],
+    "slug and lettuce": [
+      "the slug and lettuce"
+    ],
+    "the slug and lettuce": [
+      "slug and lettuce"
+    ],
+    "st johns chop house": [
+      "saint john's chop house",
+      "st john's chop house",
+      "saint johns chop house"
+    ],
+    "stazione restaurant and coffee bar": [
+      "stazione restaurant",
+      "stazione"
+    ],
+    "thanh binh": [
+      "thanh",
+      "binh"
+    ],
+    "thanh": [
+      "thanh binh",
+      "binh"
+    ],
+    "binh": [
+      "thanh binh",
+      "thanh"
+    ],
+    "the hotpot": [
+      "the hotspot",
+      "hotpot",
+      "hotspot"
+    ],
+    "hotpot": [
+      "the hotpot",
+      "the hotpot",
+      "hotspot"
+    ],
+    "the lucky star": [
+      "lucky star"
+    ],
+    "lucky star": [
+      "the lucky star"
+    ],
+    "the peking restaurant: ": [
+      "peking restaurant"
+    ],
+    "the varsity restaurant": [
+      "varsity restaurant",
+      "the varsity",
+      "varsity"
+    ],
+    "two two": [
+      "restaurant two two",
+      "restaurant 22"
+    ],
+    "restaurant 22": [
+      "restaurant two two",
+      "two two"
+    ],
+    "zizzi cambridge": [
+      "zizzi"
+    ],
+    "american": [
+      "americas"
+    ],
+    "asian oriental": [
+      "asian",
+      "oriental"
+    ],
+    "australian": [
+      "australasian"
+    ],
+    "barbeque": [
+      "barbecue",
+      "bbq"
+    ],
+    "corsica": [
+      "corsican"
+    ],
+    "indian": [
+      "tandoori"
+    ],
+    "italian": [
+      "pizza",
+      "pizzeria"
+    ],
+    "japanese": [
+      "sushi"
+    ],
+    "latin american": [
+      "latin-american",
+      "latin"
+    ],
+    "malaysian": [
+      "malay"
+    ],
+    "middle eastern": [
+      "middle-eastern"
+    ],
+    "traditional american": [
+      "american"
+    ],
+    "modern american": [
+      "american modern",
+      "american"
+    ],
+    "modern european": [
+      "european modern",
+      "european"
+    ],
+    "north american": [
+      "north-american",
+      "american"
+    ],
+    "portuguese": [
+      "portugese"
+    ],
+    "portugese": [
+      "portuguese"
+    ],
+    "seafood": [
+      "sea food"
+    ],
+    "singaporean": [
+      "singapore"
+    ],
+    "steakhouse": [
+      "steak house",
+      "steak"
+    ],
+    "the americas": [
+      "american",
+      "americas"
+    ],
+    "a and b guest house": [
+      "a & b guest house",
+      "a and b",
+      "a & b"
+    ],
+    "the acorn guest house": [
+      "acorn guest house",
+      "acorn"
+    ],
+    "acorn guest house": [
+      "the acorn guest house",
+      "acorn"
+    ],
+    "alexander bed and breakfast": [
+      "alexander"
+    ],
+    "allenbell": [
+      "the allenbell"
+    ],
+    "the allenbell": [
+      "allenbell"
+    ],
+    "alpha-milton guest house": [
+      "the alpha-milton",
+      "alpha-milton"
+    ],
+    "the alpha-milton": [
+      "alpha-milton guest house",
+      "alpha-milton"
+    ],
+    "arbury lodge guest house": [
+      "arbury lodge",
+      "arbury"
+    ],
+    "archway house": [
+      "archway"
+    ],
+    "ashley hotel": [
+      "the ashley hotel",
+      "ashley"
+    ],
+    "the ashley hotel": [
+      "ashley hotel",
+      "ashley"
+    ],
+    "aylesbray lodge guest house": [
+      "aylesbray lodge",
+      "aylesbray"
+    ],
+    "aylesbray lodge guest": [
+      "aylesbray lodge guest house",
+      "aylesbray lodge",
+      "aylesbray"
+    ],
+    "alesbray lodge guest house": [
+      "aylesbray lodge guest house",
+      "aylesbray lodge",
+      "aylesbray"
+    ],
+    "alyesbray lodge hotel": [
+      "aylesbray lodge guest house",
+      "aylesbray lodge",
+      "aylesbray"
+    ],
+    "bridge guest house": [
+      "bridge house"
+    ],
+    "cambridge belfry": [
+      "the cambridge belfry",
+      "belfry hotel",
+      "belfry"
+    ],
+    "the cambridge belfry": [
+      "cambridge belfry",
+      "belfry hotel",
+      "belfry"
+    ],
+    "belfry hotel": [
+      "the cambridge belfry",
+      "cambridge belfry",
+      "belfry"
+    ],
+    "carolina bed and breakfast": [
+      "carolina"
+    ],
+    "city centre north": [
+      "city centre north bed and breakfast"
+    ],
+    "north b and b": [
+      "city centre north bed and breakfast"
+    ],
+    "city centre north b and b": [
+      "city centre north bed and breakfast"
+    ],
+    "el shaddia guest house": [
+      "el shaddai guest house",
+      "el shaddai",
+      "el shaddia"
+    ],
+    "el shaddai guest house": [
+      "el shaddia guest house",
+      "el shaddai",
+      "el shaddia"
+    ],
+    "express by holiday inn cambridge": [
+      "express by holiday inn",
+      "holiday inn cambridge",
+      "holiday inn"
+    ],
+    "holiday inn": [
+      "express by holiday inn cambridge",
+      "express by holiday inn",
+      "holiday inn cambridge"
+    ],
+    "finches bed and breakfast": [
+      "finches"
+    ],
+    "gonville hotel": [
+      "gonville"
+    ],
+    "hamilton lodge": [
+      "the hamilton lodge",
+      "hamilton"
+    ],
+    "the hamilton lodge": [
+      "hamilton lodge",
+      "hamilton"
+    ],
+    "hobsons house": [
+      "hobson's house",
+      "hobson's"
+    ],
+    "huntingdon marriott hotel": [
+      "huntington marriott hotel",
+      "huntington marriot hotel",
+      "huntingdon marriot hotel",
+      "huntington marriott",
+      "huntingdon marriott",
+      "huntington marriot",
+      "huntingdon marriot",
+      "huntington",
+      "huntingdon"
+    ],
+    "kirkwood": [
+      "kirkwood house"
+    ],
+    "kirkwood house": [
+      "kirkwood"
+    ],
+    "lensfield hotel": [
+      "the lensfield hotel",
+      "lensfield"
+    ],
+    "the lensfield hotel": [
+      "lensfield hotel",
+      "lensfield"
+    ],
+    "leverton house": [
+      "leverton"
+    ],
+    "marriot hotel": [
+      "marriott hotel",
+      "marriott"
+    ],
+    "rosas bed and breakfast": [
+      "rosa's bed and breakfast",
+      "rosa's",
+      "rosas"
+    ],
+    "university arms hotel": [
+      "university arms"
+    ],
+    "warkworth house": [
+      "warkworth hotel",
+      "warkworth"
+    ],
+    "warkworth hotel": [
+      "warkworth house",
+      "warkworth"
+    ],
+    "wartworth": [
+      "warkworth house",
+      "warkworth hotel",
+      "warkworth"
+    ],
+    "worth house": [
+      "the worth house"
+    ],
+    "the worth house": [
+      "worth house"
+    ],
+    "birmingham new street": [
+      "birmingham new street train station"
+    ],
+    "birmingham new street train station": [
+      "birmingham new street"
+    ],
+    "bishops stortford": [
+      "bishops stortford train station"
+    ],
+    "bishops stortford train station": [
+      "bishops stortford"
+    ],
+    "broxbourne": [
+      "broxbourne train station"
+    ],
+    "broxbourne train station": [
+      "broxbourne"
+    ],
+    "cambridge": [
+      "cambridge train station"
+    ],
+    "cambridge train station": [
+      "cambridge"
+    ],
+    "ely": [
+      "ely train station"
+    ],
+    "ely train station": [
+      "ely"
+    ],
+    "kings lynn": [
+      "king's lynn",
+      "king's lynn train station",
+      "kings lynn train station"
+    ],
+    "kings lynn train station": [
+      "kings lynn",
+      "king's lynn",
+      "king's lynn train station"
+    ],
+    "leicester": [
+      "leicester train station"
+    ],
+    "leicester train station": [
+      "leicester"
+    ],
+    "london kings cross": [
+      "kings cross",
+      "king's cross",
+      "london king's cross",
+      "kings cross train station",
+      "king's cross train station",
+      "london king's cross train station",
+      "london kings cross train station"
+    ],
+    "london kings cross train station": [
+      "kings cross",
+      "king's cross",
+      "london king's cross",
+      "london kings cross",
+      "kings cross train station",
+      "king's cross train station",
+      "london king's cross train station"
+    ],
+    "london liverpool": [
+      "liverpool street",
+      "london liverpool street",
+      "london liverpool train station",
+      "liverpool street train station",
+      "london liverpool street train station"
+    ],
+    "london liverpool street": [
+      "london liverpool",
+      "liverpool street",
+      "london liverpool train station",
+      "liverpool street train station",
+      "london liverpool street train station"
+    ],
+    "london liverpool street train station": [
+      "london liverpool",
+      "liverpool street",
+      "london liverpool street",
+      "london liverpool train station",
+      "liverpool street train station"
+    ],
+    "norwich": [
+      "norwich train station"
+    ],
+    "norwich train station": [
+      "norwich"
+    ],
+    "peterborough": [
+      "peterborough train station"
+    ],
+    "peterborough train station": [
+      "peterborough"
+    ],
+    "stansted airport": [
+      "stansted airport train station"
+    ],
+    "stansted airport train station": [
+      "stansted airport"
+    ],
+    "stevenage": [
+      "stevenage train station"
+    ],
+    "stevenage train station": [
+      "stevenage"
+    ]
+  }
+}
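The label_maps above list surface-form variants of slot values and are no longer read into a module-level global; as the dataset_multiwoz21.py changes below show, they are passed explicitly to the matching helpers. A minimal usage sketch, assuming the patched dataset_multiwoz21.py is importable (the utterance and value are made up for illustration):

    from dataset_multiwoz21 import tokenize, check_label_existence

    label_maps = {"the lucky star": ["lucky star"]}
    usr_utt_tok = tokenize("i want to eat at lucky star tonight")
    # "the lucky star" does not occur verbatim, but its variant does.
    in_usr, usr_pos = check_label_existence("the lucky star", usr_utt_tok, label_maps)
    print(in_usr, usr_pos)  # True and the token positions of the matched variant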
diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py
index 0f04afd..23207e0 100644
--- a/dataset_multiwoz21.py
+++ b/dataset_multiwoz21.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 #
-# Copyright 2020 Heinrich Heine University Duesseldorf
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
 #
 # Part of this code is based on the source code of BERT-DST
 # (arXiv:1907.03040)
@@ -19,6 +19,7 @@
 
 import json
 import re
+from tqdm import tqdm
 
 from utils_dst import (DSTExample, convert_to_unicode)
 
@@ -64,54 +65,6 @@ ACTS_DICT = {'taxi-depart': 'taxi-departure',
 }
 
 
-LABEL_MAPS = {} # Loaded from file
-
-
-# Loads the dialogue_acts.json and returns a list
-# of slot-value pairs.
-def load_acts(input_file):
-    with open(input_file) as f:
-        acts = json.load(f)
-    s_dict = {}
-    for d in acts:
-        for t in acts[d]:
-            # Only process, if turn has annotation
-            if isinstance(acts[d][t], dict):
-                is_22_format = False
-                if 'dialog_act' in acts[d][t]:
-                    is_22_format = True
-                    acts_list = acts[d][t]['dialog_act']
-                    if int(t) % 2 == 0:
-                        continue
-                else:
-                    acts_list = acts[d][t]
-                for a in acts_list:
-                    aa = a.lower().split('-')
-                    if aa[1] in ['inform', 'recommend', 'select', 'book']:
-                        for i in acts_list[a]:
-                            s = i[0].lower()
-                            v = i[1].lower().strip()
-                            if s == 'none' or v == '?' or v == 'none':
-                                continue
-                            slot = aa[0] + '-' + s
-                            if slot in ACTS_DICT:
-                                slot = ACTS_DICT[slot]
-                            if is_22_format:
-                                t_key = str(int(int(t) / 2 + 1))
-                                d_key = d
-                            else:
-                                t_key = t
-                                d_key = d + '.json'
-                            key = d_key, t_key, slot
-                            # In case of multiple mentioned values...
-                            # ... Option 1: Keep first informed value
-                            if key not in s_dict:
-                                s_dict[key] = list([v])
-                            # ... Option 2: Keep last informed value
-                            #s_dict[key] = list([v])
-    return s_dict
-
-
 def normalize_time(text):
     text = re.sub("(\d{1})(a\.?m\.?|p\.?m\.?)", r"\1 \2", text) # am/pm without space
     text = re.sub("(^| )(\d{1,2}) (a\.?m\.?|p\.?m\.?)", r"\1\2:00 \3", text) # am/pm short to long form
@@ -143,8 +96,7 @@ def normalize_text(text):
     return text
 
 
-# This should only contain label normalizations. All other mappings should
-# be defined in LABEL_MAPS.
+# This should only contain label normalizations, no label mappings.
 def normalize_label(slot, value_label):
     # Normalization of capitalization
     if isinstance(value_label, str):
@@ -166,7 +118,7 @@ def normalize_label(slot, value_label):
         return "dontcare"
 
     # Normalization of time slots
-    if "leaveAt" in slot or "arriveBy" in slot or slot == 'restaurant-book_time':
+    if "leave" in slot or "arrive" in slot or "time" in slot:
         return normalize_time(value_label)
 
     # Normalization
@@ -203,18 +155,18 @@ def get_token_pos(tok_list, value_label):
     return found, find_pos
 
 
-def check_label_existence(value_label, usr_utt_tok):
+def check_label_existence(value_label, usr_utt_tok, label_maps={}):
     in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label)
     # If no hit even though there should be one, check for value label variants
-    if not in_usr and value_label in LABEL_MAPS:
-        for value_label_variant in LABEL_MAPS[value_label]:
+    if not in_usr and value_label in label_maps:
+        for value_label_variant in label_maps[value_label]:
             in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label_variant)
             if in_usr:
                 break
     return in_usr, usr_pos
 
 
-def check_slot_referral(value_label, slot, seen_slots):
+def check_slot_referral(value_label, slot, seen_slots, label_maps={}):
     referred_slot = 'none'
     if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking':
         return referred_slot
@@ -231,8 +183,8 @@ def check_slot_referral(value_label, slot, seen_slots):
             if seen_slots[s] == value_label:
                 referred_slot = s
                 break
-            elif value_label in LABEL_MAPS:
-                for value_label_variant in LABEL_MAPS[value_label]:
+            elif value_label in label_maps:
+                for value_label_variant in label_maps[value_label]:
                     if seen_slots[s] == value_label_variant:
                         referred_slot = s
                         break
@@ -266,7 +218,7 @@ def delex_utt(utt, values, unk_token="[UNK]"):
 
 
 # Fuzzy matching to label informed slot values
-def check_slot_inform(value_label, inform_label):
+def check_slot_inform(value_label, inform_label, label_maps={}):
     result = False
     informed_value = 'none'
     vl = ' '.join(tokenize(value_label))
@@ -277,8 +229,8 @@ def check_slot_inform(value_label, inform_label):
             result = True
         elif is_in_list(vl, il):
             result = True
-        elif il in LABEL_MAPS:
-            for il_variant in LABEL_MAPS[il]:
+        elif il in label_maps:
+            for il_variant in label_maps[il]:
                 if vl == il_variant:
                     result = True
                     break
@@ -288,8 +240,8 @@ def check_slot_inform(value_label, inform_label):
                 elif is_in_list(vl, il_variant):
                     result = True
                     break
-        elif vl in LABEL_MAPS:
-            for value_label_variant in LABEL_MAPS[vl]:
+        elif vl in label_maps:
+            for value_label_variant in label_maps[vl]:
                 if value_label_variant == il:
                     result = True
                     break
@@ -305,15 +257,15 @@ def check_slot_inform(value_label, inform_label):
     return result, informed_value
 
 
-def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, seen_slots, slot_last_occurrence):
+def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, seen_slots, slot_last_occurrence, label_maps={}):
     usr_utt_tok_label = [0 for _ in usr_utt_tok]
     informed_value = 'none'
     referred_slot = 'none'
     if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false':
         class_type = value_label
     else:
-        in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok)
-        is_informed, informed_value = check_slot_inform(value_label, inform_label)
+        in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok, label_maps)
+        is_informed, informed_value = check_slot_inform(value_label, inform_label, label_maps)
         if in_usr:
             class_type = 'copy_value'
             if slot_last_occurrence:
@@ -327,7 +279,7 @@ def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, se
         elif is_informed:
             class_type = 'inform'
         else:
-            referred_slot = check_slot_referral(value_label, slot, seen_slots)
+            referred_slot = check_slot_referral(value_label, slot, seen_slots, label_maps)
             if referred_slot != 'none':
                 class_type = 'refer'
             else:
@@ -335,6 +287,21 @@ def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, se
     return informed_value, referred_slot, usr_utt_tok_label, class_type
 
 
+# Requestable slots, general acts and domain indicator slots
+def is_request(slot, user_act, turn_domains):
+    if slot in user_act:
+        if isinstance(user_act[slot], list):
+            for act in user_act[slot]:
+                if act['intent'] in ['request', 'bye', 'thank', 'greet']:
+                    return True
+        elif user_act[slot]['intent'] in ['request', 'bye', 'thank', 'greet']:
+            return True
+    do, sl = slot.split('-')
+    if sl == 'none' and do in turn_domains:
+        return True
+    return False
+
+
 def tokenize(utt):
     utt_lower = convert_to_unicode(utt).lower()
     utt_lower = normalize_text(utt_lower)
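For orientation, a small illustrative call of is_request(); the user_act and turn_domains values are made up, but follow the structures built in create_examples below:

    from dataset_multiwoz21 import is_request

    user_act = {"hotel-parking": {"domain": "hotel", "intent": "request",
                                  "slot": "parking", "value": "none"}}
    turn_domains = ["hotel"]
    print(is_request("hotel-parking", user_act, turn_domains))  # True: requested slot
    print(is_request("hotel-none", user_act, turn_domains))     # True: domain indicator slot
    print(is_request("train-none", user_act, turn_domains))     # False: domain not active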
@@ -346,27 +313,22 @@ def utt_to_token(utt):
     return [tok for tok in map(lambda x: re.sub(" ", "", x), re.split("(\W+)", utt)) if len(tok) > 0]
 
 
-def create_examples(input_file, acts_file, set_type, slot_list,
+def create_examples(input_file, set_type, class_types, slot_list,
                     label_maps={},
-                    append_history=False,
-                    use_history_labels=False,
+                    no_append_history=False,
+                    no_use_history_labels=False,
+                    no_label_value_repetitions=False,
                     swap_utterances=False,
-                    label_value_repetitions=False,
                     delexicalize_sys_utts=False,
                     unk_token="[UNK]",
                     analyze=False):
     """Read a DST json file into a list of DSTExample."""
 
-    sys_inform_dict = load_acts(acts_file)
-
     with open(input_file, "r", encoding='utf-8') as reader:
         input_data = json.load(reader)
 
-    global LABEL_MAPS
-    LABEL_MAPS = label_maps
-
     examples = []
-    for dialog_id in input_data:
+    for d_itr, dialog_id in enumerate(tqdm(input_data)):
         entry = input_data[dialog_id]
         utterances = entry['log']
 
@@ -376,6 +338,9 @@ def create_examples(input_file, acts_file, set_type, slot_list,
         # First system utterance is empty, since multiwoz starts with user input
         utt_tok_list = [[]]
         mod_slots_list = [{}]
+        inform_dict_list = [{}]
+        user_act_dict_list = [{}]
+        mod_domains_list = [{}]
 
         # Collect all utterances and their metadata
         usr_sys_switch = True
@@ -391,17 +356,46 @@ def create_examples(input_file, acts_file, set_type, slot_list,
             if is_sys_utt:
                 turn_itr += 1
 
-            # Delexicalize sys utterance
-            if delexicalize_sys_utts and is_sys_utt:
-                inform_dict = {slot: 'none' for slot in slot_list}
-                for slot in slot_list:
-                    if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
-                        inform_dict[slot] = sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]
-                utt_tok_list.append(delex_utt(utt['text'], inform_dict, unk_token)) # normalize utterances
-            else:
-                utt_tok_list.append(tokenize(utt['text'])) # normalize utterances
-
+            # Extract dialog_act information for sys and usr utts.
+            inform_dict = {}
+            user_act_dict = {}
             modified_slots = {}
+            modified_domains = set()
+            if 'dialog_act' in utt:
+                for a in utt['dialog_act']:
+                    aa = a.lower().split('-')
+                    for i in utt['dialog_act'][a]:
+                        s = i[0].lower()
+                        # Some special intents are modeled as slots in TripPy
+                        if aa[0] == 'general':
+                            cs = "%s-%s" % (aa[0], aa[1])
+                        else:
+                            cs = "%s-%s" % (aa[0], s)
+                        if cs in ACTS_DICT:
+                            cs = ACTS_DICT[cs]
+                        v = normalize_label(cs, i[1].lower().strip())
+                        if cs in ['hotel-internet', 'hotel-parking']:
+                            v = 'true'
+                        modified_domains.add(aa[0]) # Remember domains
+                        if is_sys_utt and aa[1] in ['inform', 'recommend', 'select', 'book'] and v != 'none':
+                            if cs not in inform_dict:
+                                inform_dict[cs] = []
+                            inform_dict[cs].append(v)
+                        elif not is_sys_utt:
+                            if cs not in user_act_dict:
+                                user_act_dict[cs] = []
+                            user_act_dict[cs].append({'domain': aa[0], 'intent': aa[1], 'slot': s, 'value': v})
+                # INFO: Since the model has no mechanism to predict
+                # one among several informed value candidates, we
+                # keep only one informed value. For fairness, we
+                # apply a global rule:
+                for e in inform_dict:
+                    # ... Option 1: Always keep first informed value
+                    inform_dict[e] = list([inform_dict[e][0]])
+                    # ... Option 2: Always keep last informed value
+                    #inform_dict[e] = list([inform_dict[e][-1]])
+            else:
+                print("WARN: dialogue %s is missing dialog_act information." % dialog_id)
 
             # If sys utt, extract metadata (identify and collect modified slots)
             if is_sys_utt:
@@ -424,8 +418,20 @@ def create_examples(input_file, acts_file, set_type, slot_list,
                             if cs in slot_list and cumulative_labels[cs] != value_label:
                                 modified_slots[cs] = value_label
                                 cumulative_labels[cs] = value_label
+                                modified_domains.add(cs.split("-")[0]) # Remember domains
 
+            # Delexicalize sys utterance
+            if delexicalize_sys_utts and is_sys_utt:
+                utt_tok_list.append(delex_utt(utt['text'], inform_dict, unk_token)) # normalizes utterances
+            else:
+                utt_tok_list.append(tokenize(utt['text'])) # normalizes utterances
+
+            inform_dict_list.append(inform_dict.copy())
+            user_act_dict_list.append(user_act_dict.copy())
             mod_slots_list.append(modified_slots.copy())
+            modified_domains = list(modified_domains)
+            modified_domains.sort()
+            mod_domains_list.append(modified_domains)
 
         # Form proper (usr, sys) turns
         turn_itr = 0
@@ -446,14 +452,17 @@ def create_examples(input_file, acts_file, set_type, slot_list,
             class_type_dict = {}
 
             # Collect turn data
-            if append_history:
-                if swap_utterances:
+            if not no_append_history:
+                if not swap_utterances:
                     hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
                 else:
                     hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
             sys_utt_tok = utt_tok_list[i - 1]
             usr_utt_tok = utt_tok_list[i]
             turn_slots = mod_slots_list[i + 1]
+            inform_mem = inform_dict_list[i - 1]
+            user_act = user_act_dict_list[i] 
+            turn_domains = mod_domains_list[i + 1]
 
             guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr))
 
@@ -472,17 +481,18 @@ def create_examples(input_file, acts_file, set_type, slot_list,
                     # modify any of the original labels for test sets,
                     # since this would make comparison difficult.
                     value_dict[slot] = value_label
-                elif label_value_repetitions and slot in diag_seen_slots_dict:
+                elif not no_label_value_repetitions and slot in diag_seen_slots_dict:
                     value_label = diag_seen_slots_value_dict[slot]
 
                 # Get dialog act annotations
                 inform_label = list(['none'])
                 inform_slot_dict[slot] = 0
-                if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
-                    inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]])
+                booking_slot = 'booking-' + slot.split('-')[1]
+                if slot in inform_mem:
+                    inform_label = inform_mem[slot]
                     inform_slot_dict[slot] = 1
-                elif (str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1]) in sys_inform_dict:
-                    inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1])]])
+                elif booking_slot in inform_mem:
+                    inform_label = inform_mem[booking_slot]
                     inform_slot_dict[slot] = 1
 
                 (informed_value,
@@ -494,17 +504,25 @@ def create_examples(input_file, acts_file, set_type, slot_list,
                                               usr_utt_tok,
                                               slot,
                                               diag_seen_slots_value_dict,
-                                              slot_last_occurrence=True)
+                                              slot_last_occurrence=True,
+                                              label_maps=label_maps)
 
                 inform_dict[slot] = informed_value
 
+                # Requestable slots, domain indicator slots and general slots
+                # should have class_type 'request', if they ought to be predicted.
+                # Give other class_types preference.
+                if 'request' in class_types:
+                    if class_type in ['none', 'unpointable'] and is_request(slot, user_act, turn_domains):
+                        class_type = 'request'
+
                 # Generally don't use span prediction on sys utterance (but inform prediction instead).
                 sys_utt_tok_label = [0 for _ in sys_utt_tok]
 
                 # Determine what to do with value repetitions.
                 # If value is unique in seen slots, then tag it, otherwise not,
                 # since correct slot assignment can not be guaranteed anymore.
-                if label_value_repetitions and slot in diag_seen_slots_dict:
+                if not no_label_value_repetitions and slot in diag_seen_slots_dict:
                     if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1:
                         class_type = 'none'
                         usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
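The booking fallback used above when retrieving inform labels can be illustrated with a made-up inform memory: an act annotated on the generic booking domain still counts as informing the corresponding slot of the active domain.

    inform_mem = {"booking-book_day": ["saturday"]}
    slot = "hotel-book_day"
    booking_slot = "booking-" + slot.split("-")[1]
    inform_label = inform_mem.get(slot, inform_mem.get(booking_slot, ["none"]))
    print(inform_label)  # ['saturday']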
@@ -512,9 +530,9 @@ def create_examples(input_file, acts_file, set_type, slot_list,
                 sys_utt_tok_label_dict[slot] = sys_utt_tok_label
                 usr_utt_tok_label_dict[slot] = usr_utt_tok_label
 
-                if append_history:
-                    if use_history_labels:
-                        if swap_utterances:
+                if not no_append_history:
+                    if not no_use_history_labels:
+                        if not swap_utterances:
                             new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
                         else:
                             new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
@@ -556,7 +574,7 @@ def create_examples(input_file, acts_file, set_type, slot_list,
             if analyze:
                 print("]")
 
-            if swap_utterances:
+            if not swap_utterances:
                 txt_a = usr_utt_tok
                 txt_b = sys_utt_tok
                 txt_a_lbl = usr_utt_tok_label_dict
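With the acts_file argument gone and label_maps passed explicitly, a call to the refactored create_examples looks roughly as follows. This is a sketch only: the dialogs path is hypothetical, the input file is assumed to carry an inline 'dialog_act' field per log entry, and the config is assumed to provide the usual 'class_types', 'slots' and 'label_maps' keys.

    import json
    from dataset_multiwoz21 import create_examples

    with open("dataset_config/multiwoz21.json") as f:
        cfg = json.load(f)

    examples = create_examples(
        input_file="data/multiwoz/train_dials.json",  # hypothetical path
        set_type="train",
        class_types=cfg["class_types"],
        slot_list=cfg["slots"],
        label_maps=cfg["label_maps"],
        delexicalize_sys_utts=True)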
diff --git a/dataset_multiwoz21_legacy.py b/dataset_multiwoz21_legacy.py
new file mode 100644
index 0000000..b63e680
--- /dev/null
+++ b/dataset_multiwoz21_legacy.py
@@ -0,0 +1,321 @@
+# coding=utf-8
+#
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+from tqdm import tqdm
+
+from utils_dst import (DSTExample)
+
+from dataset_multiwoz21 import (ACTS_DICT, is_request,
+                                tokenize, normalize_label,
+                                get_turn_label, delex_utt)
+
+
+# Loads the dialogue_acts.json and returns a list
+# of slot-value pairs.
+def load_acts(input_file):
+    with open(input_file) as f:
+        acts = json.load(f)
+    s_dict = {}
+    for d in acts:
+        for t in acts[d]:
+            # Only process, if turn has annotation
+            if isinstance(acts[d][t], dict):
+                is_22_format = False
+                if 'dialog_act' in acts[d][t]:
+                    is_22_format = True
+                    acts_list = acts[d][t]['dialog_act']
+                    if int(t) % 2 == 0:
+                        continue
+                else:
+                    acts_list = acts[d][t]
+                for a in acts_list:
+                    aa = a.lower().split('-')
+                    if aa[1] in ['inform', 'recommend', 'select', 'book']:
+                        for i in acts_list[a]:
+                            s = i[0].lower()
+                            v = i[1].lower().strip()
+                            if s == 'none' or v == '?' or v == 'none':
+                                continue
+                            slot = aa[0] + '-' + s
+                            if slot in ACTS_DICT:
+                                slot = ACTS_DICT[slot]
+                            if is_22_format:
+                                t_key = str(int(int(t) / 2 + 1))
+                                d_key = d
+                            else:
+                                t_key = t
+                                d_key = d + '.json'
+                            key = d_key, t_key, slot
+                            # INFO: Since the model has no mechanism to predict
+                            # one among several informed value candidates, we
+                            # keep only one informed value. For fairness, we
+                            # apply a global rule:
+                            # ... Option 1: Keep first informed value
+                            if key not in s_dict:
+                                s_dict[key] = list([v])
+                            # ... Option 2: Keep last informed value
+                            #s_dict[key] = list([v])
+    return s_dict
+
+
+def create_examples(input_file, acts_file, set_type, slot_list,
+                    label_maps={},
+                    no_append_history=False,
+                    no_use_history_labels=False,
+                    no_label_value_repetitions=False,
+                    swap_utterances=False,
+                    delexicalize_sys_utts=False,
+                    unk_token="[UNK]",
+                    analyze=False):
+    """Read a DST json file into a list of DSTExample."""
+
+    sys_inform_dict = load_acts(acts_file)
+
+    with open(input_file, "r", encoding='utf-8') as reader:
+        input_data = json.load(reader)
+
+    examples = []
+    for d_itr, dialog_id in enumerate(tqdm(input_data)):
+        entry = input_data[dialog_id]
+        utterances = entry['log']
+
+        # Collects all slot changes throughout the dialog
+        cumulative_labels = {slot: 'none' for slot in slot_list}
+
+        # First system utterance is empty, since multiwoz starts with user input
+        utt_tok_list = [[]]
+        mod_slots_list = [{}]
+
+        # Collect all utterances and their metadata
+        usr_sys_switch = True
+        turn_itr = 0
+        for utt in utterances:
+            # Assert that system and user utterances alternate
+            is_sys_utt = utt['metadata'] != {}
+            if usr_sys_switch == is_sys_utt:
+                print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id))
+                break
+            usr_sys_switch = is_sys_utt
+
+            if is_sys_utt:
+                turn_itr += 1
+
+            # Delexicalize sys utterance
+            if delexicalize_sys_utts and is_sys_utt:
+                inform_dict = {slot: 'none' for slot in slot_list}
+                for slot in slot_list:
+                    if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
+                        inform_dict[slot] = sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]
+                utt_tok_list.append(delex_utt(utt['text'], inform_dict, unk_token)) # normalize utterances
+            else:
+                utt_tok_list.append(tokenize(utt['text'])) # normalize utterances
+
+            modified_slots = {}
+
+            # If sys utt, extract metadata (identify and collect modified slots)
+            if is_sys_utt:
+                for d in utt['metadata']:
+                    booked = utt['metadata'][d]['book']['booked']
+                    booked_slots = {}
+                    # Check the booked section
+                    if booked != []:
+                        for s in booked[0]:
+                            booked_slots[s] = normalize_label('%s-%s' % (d, s), booked[0][s]) # normalize labels
+                    # Check the semi and the inform slots
+                    for category in ['book', 'semi']:
+                        for s in utt['metadata'][d][category]:
+                            cs = '%s-book_%s' % (d, s) if category == 'book' else '%s-%s' % (d, s)
+                            value_label = normalize_label(cs, utt['metadata'][d][category][s]) # normalize labels
+                            # Prefer the slot value as stored in the booked section
+                            if s in booked_slots:
+                                value_label = booked_slots[s]
+                            # Remember modified slots and entire dialog state
+                            if cs in slot_list and cumulative_labels[cs] != value_label:
+                                modified_slots[cs] = value_label
+                                cumulative_labels[cs] = value_label
+
+            mod_slots_list.append(modified_slots.copy())
+
+        # Form proper (usr, sys) turns
+        turn_itr = 0
+        diag_seen_slots_dict = {}
+        diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
+        diag_state = {slot: 'none' for slot in slot_list}
+        sys_utt_tok = []
+        usr_utt_tok = []
+        hst_utt_tok = []
+        hst_utt_tok_label_dict = {slot: [] for slot in slot_list}
+        for i in range(1, len(utt_tok_list) - 1, 2):
+            sys_utt_tok_label_dict = {}
+            usr_utt_tok_label_dict = {}
+            value_dict = {}
+            inform_dict = {}
+            inform_slot_dict = {}
+            referral_dict = {}
+            class_type_dict = {}
+
+            # Collect turn data
+            if not no_append_history:
+                if not swap_utterances:
+                    hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
+                else:
+                    hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
+            sys_utt_tok = utt_tok_list[i - 1]
+            usr_utt_tok = utt_tok_list[i]
+            turn_slots = mod_slots_list[i + 1]
+
+            guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr))
+
+            if analyze:
+                print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok)))
+                print("%15s %2s [" % (dialog_id, turn_itr), end='')
+
+            new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
+            new_diag_state = diag_state.copy()
+            for slot in slot_list:
+                value_label = 'none'
+                if slot in turn_slots:
+                    value_label = turn_slots[slot]
+                    # We keep the original labels so as to not
+                    # overlook unpointable values, as well as to not
+                    # modify any of the original labels for test sets,
+                    # since this would make comparison difficult.
+                    value_dict[slot] = value_label
+                elif not no_label_value_repetitions and slot in diag_seen_slots_dict:
+                    value_label = diag_seen_slots_value_dict[slot]
+
+                # Get dialog act annotations
+                inform_label = list(['none'])
+                inform_slot_dict[slot] = 0
+                if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
+                    inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]])
+                    inform_slot_dict[slot] = 1
+                elif (str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1]) in sys_inform_dict:
+                    inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1])]])
+                    inform_slot_dict[slot] = 1
+
+                (informed_value,
+                 referred_slot,
+                 usr_utt_tok_label,
+                 class_type) = get_turn_label(value_label,
+                                              inform_label,
+                                              sys_utt_tok,
+                                              usr_utt_tok,
+                                              slot,
+                                              diag_seen_slots_value_dict,
+                                              slot_last_occurrence=True,
+                                              label_maps=label_maps)
+
+                inform_dict[slot] = informed_value
+
+                # Generally don't use span prediction on sys utterance (but inform prediction instead).
+                sys_utt_tok_label = [0 for _ in sys_utt_tok]
+
+                # Determine what to do with value repetitions.
+                # If value is unique in seen slots, then tag it, otherwise not,
+                # since correct slot assignment can not be guaranteed anymore.
+                if not no_label_value_repetitions and slot in diag_seen_slots_dict:
+                    if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1:
+                        class_type = 'none'
+                        usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
+
+                sys_utt_tok_label_dict[slot] = sys_utt_tok_label
+                usr_utt_tok_label_dict[slot] = usr_utt_tok_label
+
+                if not no_append_history:
+                    if not no_use_history_labels:
+                        if not swap_utterances:
+                            new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
+                        else:
+                            new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
+                    else:
+                        new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]]
+                    
+                # For now, we map all occurrences of unpointable slot values
+                # to none. However, since the labels will still suggest
+                # a presence of unpointable slot values, the task of the
+                # DST is still to find those values. It is just not
+                # possible to do that via span prediction on the current input.
+                if class_type == 'unpointable':
+                    class_type_dict[slot] = 'none'
+                    referral_dict[slot] = 'none'
+                    if analyze:
+                        if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]:
+                            print("(%s): %s, " % (slot, value_label), end='')
+                elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
+                    # If the slot has been seen before and its class type did not change, label this slot as not present,
+                    # assuming that the slot has not actually been mentioned in this turn.
+                    # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
+                    # this must mean there is evidence in the original labels, therefore consider
+                    # them as mentioned again.
+                    class_type_dict[slot] = 'none'
+                    referral_dict[slot] = 'none'
+                else:
+                    class_type_dict[slot] = class_type
+                    referral_dict[slot] = referred_slot
+                # Remember that this slot was mentioned during this dialog already.
+                if class_type != 'none':
+                    diag_seen_slots_dict[slot] = class_type
+                    diag_seen_slots_value_dict[slot] = value_label
+                    new_diag_state[slot] = class_type
+                    # Unpointable is not a valid class, therefore replace with
+                    # some valid class for now...
+                    if class_type == 'unpointable':
+                        new_diag_state[slot] = 'copy_value'
+
+            if analyze:
+                print("]")
+
+            if not swap_utterances:
+                txt_a = usr_utt_tok
+                txt_b = sys_utt_tok
+                txt_a_lbl = usr_utt_tok_label_dict
+                txt_b_lbl = sys_utt_tok_label_dict
+            else:
+                txt_a = sys_utt_tok
+                txt_b = usr_utt_tok
+                txt_a_lbl = sys_utt_tok_label_dict
+                txt_b_lbl = usr_utt_tok_label_dict
+            examples.append(DSTExample(
+                guid=guid,
+                text_a=txt_a,
+                text_b=txt_b,
+                history=hst_utt_tok,
+                text_a_label=txt_a_lbl,
+                text_b_label=txt_b_lbl,
+                history_label=hst_utt_tok_label_dict,
+                values=diag_seen_slots_value_dict.copy(),
+                inform_label=inform_dict,
+                inform_slot_label=inform_slot_dict,
+                refer_label=referral_dict,
+                diag_state=diag_state,
+                class_label=class_type_dict))
+
+            # Update some variables.
+            hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy()
+            diag_state = new_diag_state.copy()
+
+            turn_itr += 1
+
+        if analyze:
+            print("----------------------------------------------------------------------")
+
+    return examples
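The legacy pipeline keeps system-inform information in a separate dictionary returned by load_acts(), keyed by (dialogue id, turn number, slot). A made-up entry shows the lookup that the per-slot loop above performs:

    sys_inform_dict = {("PMUL1234.json", "1", "restaurant-name"): ["curry prince"]}
    key = ("PMUL1234.json", "1", "restaurant-name")
    inform_label = sys_inform_dict.get(key, ["none"])
    print(inform_label)  # ['curry prince']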
diff --git a/dataset_sim.py b/dataset_sim.py
index 43fe298..7d3b047 100644
--- a/dataset_sim.py
+++ b/dataset_sim.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 #
-# Copyright 2020 Heinrich Heine University Duesseldorf
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
 #
 # Part of this code is based on the source code of BERT-DST
 # (arXiv:1907.03040)
@@ -166,11 +166,10 @@ def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, sys_i
 
 
 def create_examples(input_file, set_type, slot_list,
-                    label_maps={},
-                    append_history=False,
-                    use_history_labels=False,
+                    no_append_history=False,
+                    no_use_history_labels=False,
+                    no_label_value_repetitions=False,
                     swap_utterances=False,
-                    label_value_repetitions=False,
                     delexicalize_sys_utts=False,
                     unk_token="[UNK]",
                     analyze=False):
@@ -211,7 +210,7 @@ def create_examples(input_file, set_type, slot_list,
                                            unk_token=unk_token,
                                            slot_last_occurrence=True)
 
-            if swap_utterances:
+            if not swap_utterances:
                 txt_a = text_b
                 txt_b = text_a
                 txt_a_lbl = text_b_label
@@ -230,8 +229,8 @@ def create_examples(input_file, set_type, slot_list,
                     value_dict[slot] = 'none'
                 if class_label[slot] != 'none':
                     ds_lbl_dict[slot] = class_label[slot]
-                if append_history:
-                    if use_history_labels:
+                if not no_append_history:
+                    if not no_use_history_labels:
                         hst_lbl_dict[slot] = txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot]
                     else:
                         hst_lbl_dict[slot] = [0 for _ in txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot]]
@@ -255,7 +254,7 @@ def create_examples(input_file, set_type, slot_list,
             prev_ds_lbl_dict = ds_lbl_dict.copy()
             prev_hst_lbl_dict = hst_lbl_dict.copy()
 
-            if append_history:
+            if not no_append_history:
                 hst = txt_a + txt_b + hst
 
     return examples
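Across all data processors the history and repetition switches are now negative flags, and the sense of swap_utterances is inverted as well (the branch formerly taken when the flag was set is now the default). The old and new keyword arguments relate as in this sketch, which uses names only and needs no TripPy code to run:

    old = dict(append_history=True, use_history_labels=True,
               label_value_repetitions=True)
    new = dict(no_append_history=not old["append_history"],
               no_use_history_labels=not old["use_history_labels"],
               no_label_value_repetitions=not old["label_value_repetitions"])
    print(new)  # {'no_append_history': False, 'no_use_history_labels': False, 'no_label_value_repetitions': False}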
diff --git a/dataset_unified.py b/dataset_unified.py
new file mode 100644
index 0000000..a0217cc
--- /dev/null
+++ b/dataset_unified.py
@@ -0,0 +1,347 @@
+# coding=utf-8
+#
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+from tqdm import tqdm
+
+from utils_dst import (DSTExample)
+
+try:
+    from convlab.util import (load_dataset, load_ontology, load_dst_data)
+except ModuleNotFoundError as e:
+    print(e)
+    print("Ignore this error if you don't intend to use the data processor for ConvLab3's unified data format.")
+    print("Otherwise, make sure you have ConvLab3 installed and added to your PYTHONPATH.")
+
+
+def get_ontology_slots(ontology):
+    domains = [domain for domain in ontology['domains']]
+    ontology_slots = dict()
+    for domain in domains:
+        if domain not in ontology_slots:
+            ontology_slots[domain] = set()
+        for slot in ontology['domains'][domain]['slots']:
+            ontology_slots[domain].add(slot)
+        ontology_slots[domain] = list(ontology_slots[domain])
+        ontology_slots[domain].sort()
+    return ontology_slots
+
+    
+def get_slot_list(dataset_name):
+    slot_list = []
+    ontology = load_ontology(dataset_name)
+    dataset_slot_list = get_ontology_slots(ontology)
+    for domain in dataset_slot_list:
+        for slot in dataset_slot_list[domain]:
+            slot_list.append("%s-%s" % (domain, slot))
+        slot_list.append("%s-none" % (domain)) # none slot indicates domain activation in ConvLab3
+    # Some special intents are modeled as 'request' slots in TripPy
+    if 'bye' in ontology['intents']:
+        slot_list.append("general-bye")
+    if 'thank' in ontology['intents']:
+        slot_list.append("general-thank")
+    if 'greet' in ontology['intents']:
+        slot_list.append("general-greet")
+    return slot_list
+
+
+def create_examples(set_type, dataset_name="multiwoz21", class_types=[], slot_list=[], label_maps={},
+                    no_append_history=False,
+                    no_use_history_labels=False,
+                    no_label_value_repetitions=False,
+                    swap_utterances=False,
+                    delexicalize_sys_utts=False,
+                    unk_token="[UNK]",
+                    analyze=False):
+    """Read a DST json file into a list of DSTExample."""
+
+    # TODO: Make sure normalization etc. will be compatible with or suitable for SGD and
+    # other datasets as well.
+    if dataset_name == "multiwoz21":
+        from dataset_multiwoz21 import (tokenize, normalize_label,
+                                        get_turn_label, delex_utt,
+                                        is_request)
+    else:
+        raise ValueError("Unknown dataset_name.")
+
+    dataset_args = {"dataset_name": dataset_name}
+    dataset_dict = load_dataset(**dataset_args)
+
+    if slot_list == []:
+        slot_list = get_slot_list(dataset_name)
+
+    data = load_dst_data(dataset_dict, data_split=set_type, speaker='all', dialogue_acts=True, split_to_turn=False)
+
+    examples = []
+    for d_itr, entry in enumerate(tqdm(data[set_type])):
+        dialog_id = entry['dialogue_id']
+        #dialog_id = entry['original_id']
+        original_id = entry['original_id']
+        domains = entry['domains']
+        turns = entry['turns']
+
+        # Collects all slot changes throughout the dialog
+        cumulative_labels = {slot: 'none' for slot in slot_list}
+
+        # First system utterance is empty, since multiwoz starts with user input
+        utt_tok_list = [[]]
+        mod_slots_list = [{}]
+        inform_dict_list = [{}]
+        user_act_dict_list = [{}]
+        mod_domains_list = [{}]
+
+        # Collect all utterances and their metadata
+        usr_sys_switch = True
+        for turn in turns:
+            utterance = turn['utterance']
+            state = turn['state'] if 'state' in turn else {}
+            acts = [item for sublist in list(turn['dialogue_acts'].values()) for item in sublist] # flatten list
+
+            # Assert that system and user utterances alternate
+            is_sys_utt = turn['speaker'] in ['sys', 'system']
+            if usr_sys_switch == is_sys_utt:
+                print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id))
+                break
+            usr_sys_switch = is_sys_utt
+
+            # Extract metadata: identify modified slots and values informed by the system
+            inform_dict = {}
+            user_act_dict = {}
+            modified_slots = {}
+            modified_domains = set()
+            for act in acts:
+                slot = "%s-%s" % (act['domain'], act['slot'] if act['slot'] != '' else 'none')
+                if act['intent'] in ['bye', 'thank', 'greet']:
+                    slot = "general-%s" % (act['intent'])
+                value_label = act['value'] if 'value' in act else 'yes' if act['slot'] != '' else 'none'
+                value_label = normalize_label(slot, value_label)
+                modified_domains.add(act['domain']) # Remember domains
+                if is_sys_utt and act['intent'] in ['inform', 'recommend', 'select', 'book'] and value_label != 'none':
+                    if slot not in inform_dict:
+                        inform_dict[slot] = []
+                    inform_dict[slot].append(value_label)
+                elif not is_sys_utt:
+                    if slot not in user_act_dict:
+                        user_act_dict[slot] = []
+                    user_act_dict[slot].append(act)
+            # INFO: Since the model has no mechanism to predict
+            # one among several informed value candidates, we
+            # keep only one informed value. For fairness, we
+            # apply a global rule:
+            for e in inform_dict:
+                # ... Option 1: Always keep first informed value
+                inform_dict[e] = list([inform_dict[e][0]])
+                # ... Option 2: Always keep last informed value
+                #inform_dict[e] = list([inform_dict[e][-1]])
+            for d in state:
+                for s in state[d]:
+                    slot = "%s-%s" % (d, s)
+                    value_label = normalize_label(slot, state[d][s])
+                    # Remember modified slots and entire dialog state
+                    if slot in slot_list and cumulative_labels[slot] != value_label:
+                        modified_slots[slot] = value_label
+                        cumulative_labels[slot] = value_label
+                        modified_domains.add(d) # Remember domains
+
+            # Delexicalize sys utterance
+            if delexicalize_sys_utts and is_sys_utt:
+                utt_tok_list.append(delex_utt(utterance, inform_dict, unk_token)) # normalizes utterances
+            else:
+                utt_tok_list.append(tokenize(utterance)) # normalizes utterances
+
+            inform_dict_list.append(inform_dict.copy())
+            user_act_dict_list.append(user_act_dict.copy())
+            mod_slots_list.append(modified_slots.copy())
+            modified_domains = list(modified_domains)
+            modified_domains.sort()
+            mod_domains_list.append(modified_domains)
+
+        # Form proper (usr, sys) turns
+        turn_itr = 0
+        diag_seen_slots_dict = {}
+        diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
+        diag_state = {slot: 'none' for slot in slot_list}
+        sys_utt_tok = []
+        usr_utt_tok = []
+        hst_utt_tok = []
+        hst_utt_tok_label_dict = {slot: [] for slot in slot_list}
+        for i in range(1, len(utt_tok_list) - 1, 2):
+            sys_utt_tok_label_dict = {}
+            usr_utt_tok_label_dict = {}
+            value_dict = {}
+            inform_dict = {}
+            inform_slot_dict = {}
+            referral_dict = {}
+            class_type_dict = {}
+
+            # Collect turn data
+            if not no_append_history:
+                if not swap_utterances:
+                    hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
+                else:
+                    hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
+            sys_utt_tok = utt_tok_list[i - 1]
+            usr_utt_tok = utt_tok_list[i]
+            turn_slots = mod_slots_list[i]
+            inform_mem = inform_dict_list[i - 1]
+            user_act = user_act_dict_list[i] 
+            turn_domains = mod_domains_list[i]
+
+            guid = '%s-%s' % (dialog_id, turn_itr)
+
+            if analyze:
+                print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok)))
+                print("%15s %2s [" % (dialog_id, turn_itr), end='')
+
+            new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
+            new_diag_state = diag_state.copy()
+            for slot in slot_list:
+                value_label = 'none'
+                if slot in turn_slots:
+                    value_label = turn_slots[slot]
+                    # We keep the original labels so as to not
+                    # overlook unpointable values, as well as to not
+                    # modify any of the original labels for test sets,
+                    # since this would make comparison difficult.
+                    value_dict[slot] = value_label
+                elif not no_label_value_repetitions and slot in diag_seen_slots_dict:
+                    value_label = diag_seen_slots_value_dict[slot]
+
+                # Get dialog act annotations
+                inform_label = list(['none'])
+                inform_slot_dict[slot] = 0
+                if slot in inform_mem:
+                    inform_label = inform_mem[slot]
+                    inform_slot_dict[slot] = 1
+
+                (informed_value,
+                 referred_slot,
+                 usr_utt_tok_label,
+                 class_type) = get_turn_label(value_label,
+                                              inform_label,
+                                              sys_utt_tok,
+                                              usr_utt_tok,
+                                              slot,
+                                              diag_seen_slots_value_dict,
+                                              slot_last_occurrence=True,
+                                              label_maps=label_maps)
+
+                inform_dict[slot] = informed_value
+
+                # Requestable slots, domain indicator slots and general slots
+                # should have class_type 'request', if they ought to be predicted.
+                # Give other class_types preference.
+                if 'request' in class_types:
+                    if class_type in ['none', 'unpointable'] and is_request(slot, user_act, turn_domains):
+                        class_type = 'request'
+
+                # Generally don't use span prediction on sys utterance (but inform prediction instead).
+                sys_utt_tok_label = [0 for _ in sys_utt_tok]
+
+                # Determine what to do with value repetitions.
+                # If value is unique in seen slots, then tag it, otherwise not,
+                # since correct slot assignment can not be guaranteed anymore.
+                if not no_label_value_repetitions and slot in diag_seen_slots_dict:
+                    if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1:
+                        class_type = 'none'
+                        usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
+
+                sys_utt_tok_label_dict[slot] = sys_utt_tok_label
+                usr_utt_tok_label_dict[slot] = usr_utt_tok_label
+
+                if not no_append_history:
+                    if not no_use_history_labels:
+                        if not swap_utterances:
+                            new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
+                        else:
+                            new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
+                    else:
+                        new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]]
+                    
+                # For now, we map all occurrences of unpointable slot values
+                # to none. However, since the labels will still suggest
+                # a presence of unpointable slot values, the task of the
+                # DST is still to find those values. It is just not
+                # possible to do that via span prediction on the current input.
+                if class_type == 'unpointable':
+                    class_type_dict[slot] = 'none'
+                    referral_dict[slot] = 'none'
+                    if analyze:
+                        if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]:
+                            print("(%s): %s, " % (slot, value_label), end='')
+                elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
+                    # If the slot has been seen before and its class type did not change, label this slot as not present,
+                    # assuming that the slot has not actually been mentioned in this turn.
+                    # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
+                    # this must mean there is evidence in the original labels; therefore consider
+                    # it as mentioned again.
+                    class_type_dict[slot] = 'none'
+                    referral_dict[slot] = 'none'
+                else:
+                    class_type_dict[slot] = class_type
+                    referral_dict[slot] = referred_slot
+                # Remember that this slot was mentioned during this dialog already.
+                if class_type != 'none':
+                    diag_seen_slots_dict[slot] = class_type
+                    diag_seen_slots_value_dict[slot] = value_label
+                    new_diag_state[slot] = class_type
+                    # Unpointable is not a valid class, therefore replace with
+                    # some valid class for now...
+                    if class_type == 'unpointable':
+                        new_diag_state[slot] = 'copy_value'
+
+            if analyze:
+                print("]")
+
+            if not swap_utterances:
+                txt_a = usr_utt_tok
+                txt_b = sys_utt_tok
+                txt_a_lbl = usr_utt_tok_label_dict
+                txt_b_lbl = sys_utt_tok_label_dict
+            else:
+                txt_a = sys_utt_tok
+                txt_b = usr_utt_tok
+                txt_a_lbl = sys_utt_tok_label_dict
+                txt_b_lbl = usr_utt_tok_label_dict
+            examples.append(DSTExample(
+                guid=guid,
+                text_a=txt_a,
+                text_b=txt_b,
+                history=hst_utt_tok,
+                text_a_label=txt_a_lbl,
+                text_b_label=txt_b_lbl,
+                history_label=hst_utt_tok_label_dict,
+                values=diag_seen_slots_value_dict.copy(),
+                inform_label=inform_dict,
+                inform_slot_label=inform_slot_dict,
+                refer_label=referral_dict,
+                diag_state=diag_state,
+                class_label=class_type_dict))
+
+            # Carry the updated history labels and dialog state over to the next turn.
+            hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy()
+            diag_state = new_diag_state.copy()
+
+            turn_itr += 1
+
+        if analyze:
+            print("----------------------------------------------------------------------")
+
+    return examples
diff --git a/dataset_woz2.py b/dataset_woz2.py
index d13b317..6d88255 100644
--- a/dataset_woz2.py
+++ b/dataset_woz2.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 #
-# Copyright 2020 Heinrich Heine University Duesseldorf
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
 #
 # Part of this code is based on the source code of BERT-DST
 # (arXiv:1907.03040)
@@ -103,10 +103,11 @@ def utt_to_token(utt):
 
 def create_examples(input_file, set_type, slot_list,
                     label_maps={},
-                    append_history=False,
-                    use_history_labels=False,
+                    asr=False,
+                    no_append_history=False,
+                    no_use_history_labels=False,
+                    no_label_value_repetitions=False,
                     swap_utterances=False,
-                    label_value_repetitions=False,
                     delexicalize_sys_utts=False,
                     unk_token="[UNK]",
                     analyze=False):
@@ -136,8 +137,8 @@ def create_examples(input_file, set_type, slot_list,
             class_type_dict = {}
 
             # Collect turn data
-            if append_history:
-                if swap_utterances:
+            if not no_append_history:
+                if not swap_utterances:
                     if delexicalize_sys_utts:
                         hst_utt_tok = usr_utt_tok + sys_utt_tok_delex + hst_utt_tok
                     else:
@@ -149,7 +150,10 @@ def create_examples(input_file, set_type, slot_list,
                         hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
 
             sys_utt_tok = tokenize(turn['system_transcript'])
-            usr_utt_tok = tokenize(turn['transcript'])
+            if asr:
+                usr_utt_tok = tokenize(turn['asr'][:1][0][0])
+            else:
+                usr_utt_tok = tokenize(turn['transcript'])
             turn_label = {LABEL_FIX.get(s.strip(), s.strip()): LABEL_FIX.get(v.strip(), v.strip()) for s, v in turn['turn_label']}
 
             guid = '%s-%s-%s' % (set_type, str(entry['dialogue_idx']), str(turn['turn_idx']))
@@ -162,7 +166,7 @@ def create_examples(input_file, set_type, slot_list,
                     label = 'none'
                     if slot in turn_label:
                         label = turn_label[slot]
-                    elif label_value_repetitions and slot in diag_seen_slots_dict:
+                    elif not no_label_value_repetitions and slot in diag_seen_slots_dict:
                         label = diag_seen_slots_value_dict[slot]
                     if label != 'none' and label != 'dontcare':
                         _, _, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
@@ -176,7 +180,7 @@ def create_examples(input_file, set_type, slot_list,
                 label = 'none'
                 if slot in turn_label:
                     label = turn_label[slot]
-                elif label_value_repetitions and slot in diag_seen_slots_dict:
+                elif not no_label_value_repetitions and slot in diag_seen_slots_dict:
                     label = diag_seen_slots_value_dict[slot]
 
                 (usr_utt_tok_label,
@@ -202,7 +206,7 @@ def create_examples(input_file, set_type, slot_list,
                 # Determine what to do with value repetitions.
                 # If value is unique in seen slots, then tag it, otherwise not,
                 # since correct slot assignment can not be guaranteed anymore.
-                if label_value_repetitions and slot in diag_seen_slots_dict:
+                if not no_label_value_repetitions and slot in diag_seen_slots_dict:
                     if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(label) > 1:
                         class_type = 'none'
                         usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
@@ -210,9 +214,9 @@ def create_examples(input_file, set_type, slot_list,
                 sys_utt_tok_label_dict[slot] = sys_utt_tok_label
                 usr_utt_tok_label_dict[slot] = usr_utt_tok_label
 
-                if append_history:
-                    if use_history_labels:
-                        if swap_utterances:
+                if not no_append_history:
+                    if not no_use_history_labels:
+                        if not swap_utterances:
                             new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
                         else:
                             new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
@@ -246,7 +250,7 @@ def create_examples(input_file, set_type, slot_list,
                     if class_type == 'unpointable':
                         new_diag_state[slot] = 'copy_value'
 
-            if swap_utterances:
+            if not swap_utterances:
                 txt_a = usr_utt_tok
                 if delexicalize_sys_utts:
                     txt_b = sys_utt_tok_delex
diff --git a/metric_bert_dst.py b/metric_dst.py
similarity index 94%
rename from metric_bert_dst.py
rename to metric_dst.py
index e7818c1..7c89db5 100644
--- a/metric_bert_dst.py
+++ b/metric_dst.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 #
-# Copyright 2020 Heinrich Heine University Duesseldorf
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
 #
 # Part of this code is based on the source code of BERT-DST
 # (arXiv:1907.03040)
@@ -22,6 +22,7 @@ import json
 import sys
 import numpy as np
 import re
+import argparse
 
 
 def load_dataset_config(dataset_config):
@@ -160,6 +161,8 @@ def get_joint_slot_correctness(fp, class_types, label_maps,
                 else:
                     print("ERROR: Unexpected slot value format. Aborting.")
                     exit()
+            elif 'request' in class_types and turn_pd_class == class_types.index('request'):
+                pass
             else:
                 print("ERROR: Unexpected class_type. Aborting.")
                 exit()
@@ -281,13 +284,17 @@ if __name__ == "__main__":
     key_slot_groundtruth = 'slot_groundtruth_%s'
     key_slot_prediction = 'slot_prediction_%s'
 
-    dataset = sys.argv[1].lower()
-    dataset_config = sys.argv[2].lower()
+    parser = argparse.ArgumentParser()
+
+    # Required parameters
+    parser.add_argument("--dataset_config", default=None, type=str, required=True,
+                        help="Dataset configuration file.")
+    parser.add_argument("--file_list", default=None, type=str, required=True,
+                        help="Glob pattern matching the prediction files to evaluate.")
 
-    if dataset not in ['woz2', 'sim-m', 'sim-r', 'multiwoz21']:
-        raise ValueError("Task not found: %s" % (dataset))
+    args = parser.parse_args()
 
-    class_types, slots, label_maps = load_dataset_config(dataset_config)
+    class_types, slots, label_maps = load_dataset_config(args.dataset_config)
 
     # Prepare label_maps
     label_maps_tmp = {}
@@ -295,7 +302,16 @@ if __name__ == "__main__":
         label_maps_tmp[tokenize(v)] = [tokenize(nv) for nv in label_maps[v]]
     label_maps = label_maps_tmp
 
-    for fp in sorted(glob.glob(sys.argv[3])):
+    for fp in sorted(glob.glob(args.file_list)):
+        # Infer slot list from data if not provided.
+        if len(slots) == 0:
+            with open(fp) as f:
+                preds = json.load(f)
+                for e in preds[0]:
+                    slot = re.match("^slot_groundtruth_(.*)$", e)
+                    slot = slot[1] if slot else None
+                    if slot and slot not in slots:
+                        slots.append(slot)
         print(fp)
         goal_correctness = 1.0
         cls_acc = [[] for cl in range(len(class_types))]
diff --git a/modeling_bert_dst.py b/modeling_bert_dst.py
deleted file mode 100644
index e0ccc7d..0000000
--- a/modeling_bert_dst.py
+++ /dev/null
@@ -1,213 +0,0 @@
-# coding=utf-8
-#
-# Copyright 2020 Heinrich Heine University Duesseldorf
-#
-# Part of this code is based on the source code of BERT-DST
-# (arXiv:1907.03040)
-# Part of this code is based on the source code of Transformers
-# (arXiv:1910.03771)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss
-
-from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable)
-from transformers.modeling_bert import (BertModel, BertPreTrainedModel, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
-
-
-@add_start_docstrings(
-    """BERT Model with a classification heads for the DST task. """,
-    BERT_START_DOCSTRING,
-)
-class BertForDST(BertPreTrainedModel):
-    def __init__(self, config):
-        super(BertForDST, self).__init__(config)
-        self.slot_list = config.dst_slot_list
-        self.class_types = config.dst_class_types
-        self.class_labels = config.dst_class_labels
-        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
-        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
-        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
-        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
-        self.class_loss_ratio = config.dst_class_loss_ratio
-
-        # Only use refer loss if refer class is present in dataset.
-        if 'refer' in self.class_types:
-            self.refer_index = self.class_types.index('refer')
-        else:
-            self.refer_index = -1
-
-        self.bert = BertModel(config)
-        self.dropout = nn.Dropout(config.dst_dropout_rate)
-        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
-
-        if self.class_aux_feats_inform:
-            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-        if self.class_aux_feats_ds:
-            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-
-        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
-
-        for slot in self.slot_list:
-            self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
-            self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
-            self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
-
-        # Head for aux task
-        if hasattr(config, "aux_task_def"):
-            self.add_module("aux_out_projection", nn.Linear(config.hidden_size, int(config.aux_task_def['n_class'])))
-
-        self.init_weights()
-
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
-    def forward(self,
-                input_ids,
-                input_mask=None,
-                segment_ids=None,
-                position_ids=None,
-                head_mask=None,
-                start_pos=None,
-                end_pos=None,
-                inform_slot_id=None,
-                refer_id=None,
-                class_label_id=None,
-                diag_state=None,
-                aux_task_def=None):
-        outputs = self.bert(
-            input_ids,
-            attention_mask=input_mask,
-            token_type_ids=segment_ids,
-            position_ids=position_ids,
-            head_mask=head_mask
-        )
-
-        sequence_output = outputs[0]
-        pooled_output = outputs[1]
-
-        sequence_output = self.dropout(sequence_output)
-        pooled_output = self.dropout(pooled_output)
-
-        if aux_task_def is not None:
-            if aux_task_def['task_type'] == "classification":
-                aux_logits = getattr(self, 'aux_out_projection')(pooled_output)
-                aux_logits = self.dropout_heads(aux_logits)
-                aux_loss_fct = CrossEntropyLoss()
-                aux_loss = aux_loss_fct(aux_logits, class_label_id)
-                # add hidden states and attention if they are here
-                return (aux_loss,) + outputs[2:]
-            elif aux_task_def['task_type'] == "span":
-                aux_logits = getattr(self, 'aux_out_projection')(sequence_output)
-                aux_start_logits, aux_end_logits = aux_logits.split(1, dim=-1)
-                aux_start_logits = self.dropout_heads(aux_start_logits)
-                aux_end_logits = self.dropout_heads(aux_end_logits)
-                aux_start_logits = aux_start_logits.squeeze(-1)
-                aux_end_logits = aux_end_logits.squeeze(-1)
-
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos.size()) > 1:
-                    start_pos = start_pos.squeeze(-1)
-                if len(end_pos.size()) > 1:
-                    end_pos = end_pos.squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = aux_start_logits.size(1) # This is a single index
-                start_pos.clamp_(0, ignored_index)
-                end_pos.clamp_(0, ignored_index)
-            
-                aux_token_loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-                aux_start_loss = aux_token_loss_fct(torch.cat((aux_start_logits, aux_end_logits), 1), start_pos)
-                aux_end_loss = aux_token_loss_fct(torch.cat((aux_end_logits, aux_start_logits), 1), end_pos)
-                aux_loss = (aux_start_loss + aux_end_loss) / 2.0
-                return (aux_loss,) + outputs[2:]
-            else:
-                raise Exception("Unknown task_type")
-
-        # TODO: establish proper format in labels already?
-        if inform_slot_id is not None:
-            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
-        if diag_state is not None:
-            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
-        
-        total_loss = 0
-        per_slot_per_example_loss = {}
-        per_slot_class_logits = {}
-        per_slot_start_logits = {}
-        per_slot_end_logits = {}
-        per_slot_refer_logits = {}
-        for slot in self.slot_list:
-            if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-            elif self.class_aux_feats_inform:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-            elif self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-            else:
-                pooled_output_aux = pooled_output
-            class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
-
-            token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
-            start_logits, end_logits = token_logits.split(1, dim=-1)
-            start_logits = start_logits.squeeze(-1)
-            end_logits = end_logits.squeeze(-1)
-
-            refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
-
-            per_slot_class_logits[slot] = class_logits
-            per_slot_start_logits[slot] = start_logits
-            per_slot_end_logits[slot] = end_logits
-            per_slot_refer_logits[slot] = refer_logits
-            
-            # If there are no labels, don't compute loss
-            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos[slot].size()) > 1:
-                    start_pos[slot] = start_pos[slot].squeeze(-1)
-                if len(end_pos[slot].size()) > 1:
-                    end_pos[slot] = end_pos[slot].squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = start_logits.size(1) # This is a single index
-                start_pos[slot].clamp_(0, ignored_index)
-                end_pos[slot].clamp_(0, ignored_index)
-
-                class_loss_fct = CrossEntropyLoss(reduction='none')
-                token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
-                refer_loss_fct = CrossEntropyLoss(reduction='none')
-
-                start_loss = token_loss_fct(start_logits, start_pos[slot])
-                end_loss = token_loss_fct(end_logits, end_pos[slot])
-                token_loss = (start_loss + end_loss) / 2.0
-
-                token_is_pointable = (start_pos[slot] > 0).float()
-                if not self.token_loss_for_nonpointable:
-                    token_loss *= token_is_pointable
-
-                refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
-                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
-                if not self.refer_loss_for_nonpointable:
-                    refer_loss *= token_is_referrable
-
-                class_loss = class_loss_fct(class_logits, class_label_id[slot])
-
-                if self.refer_index > -1:
-                    per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
-                else:
-                    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
-
-                total_loss += per_example_loss.sum()
-                per_slot_per_example_loss[slot] = per_example_loss
-
-        # add hidden states and attention if they are here
-        outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]
-
-        return outputs
diff --git a/modeling_dst.py b/modeling_dst.py
new file mode 100644
index 0000000..6e74fbc
--- /dev/null
+++ b/modeling_dst.py
@@ -0,0 +1,251 @@
+# coding=utf-8
+#
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+# Part of this code is based on the source code of Transformers
+# (arXiv:1910.03771)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from transformers import (BertModel, BertPreTrainedModel,
+                          RobertaModel, RobertaPreTrainedModel,
+                          ElectraModel, ElectraPreTrainedModel)
+
+PARENT_CLASSES = {
+    'bert': BertPreTrainedModel,
+    'roberta': RobertaPreTrainedModel,
+    'electra': ElectraPreTrainedModel
+}
+
+MODEL_CLASSES = {
+    BertPreTrainedModel: BertModel,
+    RobertaPreTrainedModel: RobertaModel,
+    ElectraPreTrainedModel: ElectraModel
+}
+
+
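+# ELECTRA models do not provide a pooled output, so a BERT-style pooler over
+# the first token is added and used whenever model_type == "electra".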
+class ElectraPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+        
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+def TransformerForDST(parent_name):
+    if parent_name not in PARENT_CLASSES:
+        raise ValueError("Unknown model %s" % (parent_name))
+
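+    # Class factory: the inner class inherits from the selected pretrained parent
+    # (BERT, RoBERTa or ELECTRA) and attaches the shared DST heads on top of it.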
+    class TransformerForDST(PARENT_CLASSES[parent_name]):
+        def __init__(self, config):
+            assert config.model_type in PARENT_CLASSES
+            assert self.__class__.__bases__[0] in MODEL_CLASSES
+            super(TransformerForDST, self).__init__(config)
+            self.model_type = config.model_type
+            self.slot_list = config.dst_slot_list
+            self.class_types = config.dst_class_types
+            self.class_labels = config.dst_class_labels
+            self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
+            self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
+            self.class_aux_feats_inform = config.dst_class_aux_feats_inform
+            self.class_aux_feats_ds = config.dst_class_aux_feats_ds
+            self.class_loss_ratio = config.dst_class_loss_ratio
+
+            # Only use refer loss if refer class is present in dataset.
+            if 'refer' in self.class_types:
+                self.refer_index = self.class_types.index('refer')
+            else:
+                self.refer_index = -1
+
+            # Make sure this module has the same name as in the pretrained checkpoint you want to load!
+            self.add_module(self.model_type, MODEL_CLASSES[self.__class__.__bases__[0]](config))
+            if self.model_type == "electra":
+                self.pooler = ElectraPooler(config)
+            
+            self.dropout = nn.Dropout(config.dst_dropout_rate)
+            self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
+
+            if self.class_aux_feats_inform:
+                self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
+            if self.class_aux_feats_ds:
+                self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
+
+            aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second factor is 0, 1 or 2
+
+            for slot in self.slot_list:
+                self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
+                self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
+                self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
+
+            # Head for aux task
+            if hasattr(config, "aux_task_def") and config.aux_task_def is not None:
+                self.add_module("aux_out_projection", nn.Linear(config.hidden_size, int(config.aux_task_def['n_class'])))
+
+            self.init_weights()
+
+        def forward(self,
+                    input_ids,
+                    input_mask=None,
+                    segment_ids=None,
+                    position_ids=None,
+                    head_mask=None,
+                    start_pos=None,
+                    end_pos=None,
+                    inform_slot_id=None,
+                    refer_id=None,
+                    class_label_id=None,
+                    diag_state=None,
+                    aux_task_def=None):
+            outputs = getattr(self, self.model_type)(
+                input_ids,
+                attention_mask=input_mask,
+                token_type_ids=segment_ids,
+                position_ids=position_ids,
+                head_mask=head_mask
+            )
+
+            sequence_output = outputs[0]
+            if self.model_type == "electra":
+                pooled_output = self.pooler(sequence_output)
+            else:
+                pooled_output = outputs[1]
+
+            sequence_output = self.dropout(sequence_output)
+            pooled_output = self.dropout(pooled_output)
+
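+            # If an auxiliary task definition is given, only the auxiliary head is
+            # used and its loss is returned; the DST heads below are skipped.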
+            if aux_task_def is not None:
+                if aux_task_def['task_type'] == "classification":
+                    aux_logits = getattr(self, 'aux_out_projection')(pooled_output)
+                    aux_logits = self.dropout_heads(aux_logits)
+                    aux_loss_fct = CrossEntropyLoss()
+                    aux_loss = aux_loss_fct(aux_logits, class_label_id)
+                    # add hidden states and attention if they are here
+                    return (aux_loss,) + outputs[2:]
+                elif aux_task_def['task_type'] == "span":
+                    aux_logits = getattr(self, 'aux_out_projection')(sequence_output)
+                    aux_start_logits, aux_end_logits = aux_logits.split(1, dim=-1)
+                    aux_start_logits = self.dropout_heads(aux_start_logits)
+                    aux_end_logits = self.dropout_heads(aux_end_logits)
+                    aux_start_logits = aux_start_logits.squeeze(-1)
+                    aux_end_logits = aux_end_logits.squeeze(-1)
+
+                    # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it
+                    if len(start_pos.size()) > 1:
+                        start_pos = start_pos.squeeze(-1)
+                    if len(end_pos.size()) > 1:
+                        end_pos = end_pos.squeeze(-1)
+                    # sometimes the start/end positions are outside our model inputs, we ignore these terms
+                    ignored_index = aux_start_logits.size(1) # This is a single index
+                    start_pos.clamp_(0, ignored_index)
+                    end_pos.clamp_(0, ignored_index)
+
+                    aux_token_loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+                    aux_start_loss = aux_token_loss_fct(torch.cat((aux_start_logits, aux_end_logits), 1), start_pos)
+                    aux_end_loss = aux_token_loss_fct(torch.cat((aux_end_logits, aux_start_logits), 1), end_pos)
+                    aux_loss = (aux_start_loss + aux_end_loss) / 2.0
+                    return (aux_loss,) + outputs[2:]
+                else:
+                    raise ValueError("Unknown task_type %s" % (aux_task_def['task_type'] if 'task_type' in aux_task_def else None))
+
+            if inform_slot_id is not None:
+                inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
+            if diag_state is not None:
+                diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
+
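+            # Each slot has its own classification, span (start/end) and refer heads;
+            # collect their logits and accumulate the loss per slot.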
+            total_loss = 0
+            per_slot_per_example_loss = {}
+            per_slot_class_logits = {}
+            per_slot_start_logits = {}
+            per_slot_end_logits = {}
+            per_slot_refer_logits = {}
+            for slot in self.slot_list:
+                if self.class_aux_feats_inform and self.class_aux_feats_ds:
+                    pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
+                elif self.class_aux_feats_inform:
+                    pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
+                elif self.class_aux_feats_ds:
+                    pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
+                else:
+                    pooled_output_aux = pooled_output
+                class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
+
+                token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
+                start_logits, end_logits = token_logits.split(1, dim=-1)
+                start_logits = start_logits.squeeze(-1)
+                end_logits = end_logits.squeeze(-1)
+
+                refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
+
+                per_slot_class_logits[slot] = class_logits
+                per_slot_start_logits[slot] = start_logits
+                per_slot_end_logits[slot] = end_logits
+                per_slot_refer_logits[slot] = refer_logits
+
+                # If there are no labels, don't compute loss
+                if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
+                    # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it
+                    if len(start_pos[slot].size()) > 1:
+                        start_pos[slot] = start_pos[slot].squeeze(-1)
+                    if len(end_pos[slot].size()) > 1:
+                        end_pos[slot] = end_pos[slot].squeeze(-1)
+                    # sometimes the start/end positions are outside our model inputs, we ignore these terms
+                    ignored_index = start_logits.size(1) # This is a single index
+                    start_pos[slot].clamp_(0, ignored_index)
+                    end_pos[slot].clamp_(0, ignored_index)
+
+                    class_loss_fct = CrossEntropyLoss(reduction='none')
+                    token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
+                    refer_loss_fct = CrossEntropyLoss(reduction='none')
+
+                    start_loss = token_loss_fct(start_logits, start_pos[slot])
+                    end_loss = token_loss_fct(end_logits, end_pos[slot])
+                    token_loss = (start_loss + end_loss) / 2.0
+
+                    token_is_pointable = (start_pos[slot] > 0).float()
+                    if not self.token_loss_for_nonpointable:
+                        token_loss *= token_is_pointable
+
+                    refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
+                    token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
+                    if not self.refer_loss_for_nonpointable:
+                        refer_loss *= token_is_referrable
+
+                    class_loss = class_loss_fct(class_logits, class_label_id[slot])
+
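+                    # Weight the heads: class_loss_ratio goes to the class loss; the
+                    # remainder is split evenly between span and refer losses, or given
+                    # entirely to the span loss if no refer class exists.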
+                    if self.refer_index > -1:
+                        per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
+                    else:
+                        per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
+
+                    total_loss += per_example_loss.sum()
+                    per_slot_per_example_loss[slot] = per_example_loss
+
+            # add hidden states and attention if they are here
+            outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]
+
+            return outputs
+
+    return TransformerForDST
diff --git a/modeling_roberta_dst.py b/modeling_roberta_dst.py
deleted file mode 100644
index 062fccb..0000000
--- a/modeling_roberta_dst.py
+++ /dev/null
@@ -1,236 +0,0 @@
-# coding=utf-8
-#
-# Copyright 2020 Heinrich Heine University Duesseldorf
-#
-# Part of this code is based on the source code of BERT-DST
-# (arXiv:1907.03040)
-# Part of this code is based on the source code of Transformers
-# (arXiv:1910.03771)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss
-
-from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable)
-from transformers.modeling_utils import (PreTrainedModel)
-from transformers.modeling_roberta import (RobertaModel, RobertaConfig, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                           ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING, BertLayerNorm)
-
-
-class RobertaPreTrainedModel(PreTrainedModel):
-    """ An abstract class to handle weights initialization and
-        a simple interface for dowloading and loading pretrained models.
-    """
-    config_class = RobertaConfig
-    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
-    base_model_prefix = "roberta"
-    
-    def _init_weights(self, module):
-        """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-        elif isinstance(module, BertLayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
-
-
-@add_start_docstrings(
-    """RoBERTa Model with classification heads for the DST task. """,
-    ROBERTA_START_DOCSTRING,
-)
-class RobertaForDST(RobertaPreTrainedModel):
-    def __init__(self, config):
-        super(RobertaForDST, self).__init__(config)
-        self.slot_list = config.dst_slot_list
-        self.class_types = config.dst_class_types
-        self.class_labels = config.dst_class_labels
-        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
-        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
-        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
-        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
-        self.class_loss_ratio = config.dst_class_loss_ratio
-
-        # Only use refer loss if refer class is present in dataset.
-        if 'refer' in self.class_types:
-            self.refer_index = self.class_types.index('refer')
-        else:
-            self.refer_index = -1
-
-        self.roberta = RobertaModel(config)
-        self.dropout = nn.Dropout(config.dst_dropout_rate)
-        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
-
-        if self.class_aux_feats_inform:
-            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-        if self.class_aux_feats_ds:
-            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
-
-        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
-
-        for slot in self.slot_list:
-            self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
-            self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
-            self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
-
-        # Head for aux task
-        if hasattr(config, "aux_task_def"):
-            self.add_module("aux_out_projection", nn.Linear(config.hidden_size, int(config.aux_task_def['n_class'])))
-
-        self.init_weights()
-
-    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
-    def forward(self,
-                input_ids,
-                input_mask=None,
-                segment_ids=None,
-                position_ids=None,
-                head_mask=None,
-                start_pos=None,
-                end_pos=None,
-                inform_slot_id=None,
-                refer_id=None,
-                class_label_id=None,
-                diag_state=None,
-                aux_task_def=None):
-        outputs = self.roberta(
-            input_ids,
-            attention_mask=input_mask,
-            token_type_ids=segment_ids,
-            position_ids=position_ids,
-            head_mask=head_mask
-        )
-
-        sequence_output = outputs[0]
-        pooled_output = outputs[1]
-
-        sequence_output = self.dropout(sequence_output)
-        pooled_output = self.dropout(pooled_output)
-
-        if aux_task_def is not None:
-            if aux_task_def['task_type'] == "classification":
-                aux_logits = getattr(self, 'aux_out_projection')(pooled_output)
-                aux_logits = self.dropout_heads(aux_logits)
-                aux_loss_fct = CrossEntropyLoss()
-                aux_loss = aux_loss_fct(aux_logits, class_label_id)
-                # add hidden states and attention if they are here
-                return (aux_loss,) + outputs[2:]
-            elif aux_task_def['task_type'] == "span":
-                aux_logits = getattr(self, 'aux_out_projection')(sequence_output)
-                aux_start_logits, aux_end_logits = aux_logits.split(1, dim=-1)
-                aux_start_logits = self.dropout_heads(aux_start_logits)
-                aux_end_logits = self.dropout_heads(aux_end_logits)
-                aux_start_logits = aux_start_logits.squeeze(-1)
-                aux_end_logits = aux_end_logits.squeeze(-1)
-
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos.size()) > 1:
-                    start_pos = start_pos.squeeze(-1)
-                if len(end_pos.size()) > 1:
-                    end_pos = end_pos.squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = aux_start_logits.size(1) # This is a single index
-                start_pos.clamp_(0, ignored_index)
-                end_pos.clamp_(0, ignored_index)
-            
-                aux_token_loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-                aux_start_loss = aux_token_loss_fct(torch.cat((aux_start_logits, aux_end_logits), 1), start_pos)
-                aux_end_loss = aux_token_loss_fct(torch.cat((aux_end_logits, aux_start_logits), 1), end_pos)
-                aux_loss = (aux_start_loss + aux_end_loss) / 2.0
-                return (aux_loss,) + outputs[2:]
-            else:
-                raise Exception("Unknown task_type")
-
-        # TODO: establish proper format in labels already?
-        if inform_slot_id is not None:
-            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
-        if diag_state is not None:
-            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
-        
-        total_loss = 0
-        per_slot_per_example_loss = {}
-        per_slot_class_logits = {}
-        per_slot_start_logits = {}
-        per_slot_end_logits = {}
-        per_slot_refer_logits = {}
-        for slot in self.slot_list:
-            if self.class_aux_feats_inform and self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
-            elif self.class_aux_feats_inform:
-                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
-            elif self.class_aux_feats_ds:
-                pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
-            else:
-                pooled_output_aux = pooled_output
-            class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
-
-            token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
-            start_logits, end_logits = token_logits.split(1, dim=-1)
-            start_logits = start_logits.squeeze(-1)
-            end_logits = end_logits.squeeze(-1)
-
-            refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
-
-            per_slot_class_logits[slot] = class_logits
-            per_slot_start_logits[slot] = start_logits
-            per_slot_end_logits[slot] = end_logits
-            per_slot_refer_logits[slot] = refer_logits
-            
-            # If there are no labels, don't compute loss
-            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
-                # If we are on multi-GPU, split add a dimension
-                if len(start_pos[slot].size()) > 1:
-                    start_pos[slot] = start_pos[slot].squeeze(-1)
-                if len(end_pos[slot].size()) > 1:
-                    end_pos[slot] = end_pos[slot].squeeze(-1)
-                # sometimes the start/end positions are outside our model inputs, we ignore these terms
-                ignored_index = start_logits.size(1) # This is a single index
-                start_pos[slot].clamp_(0, ignored_index)
-                end_pos[slot].clamp_(0, ignored_index)
-
-                class_loss_fct = CrossEntropyLoss(reduction='none')
-                token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
-                refer_loss_fct = CrossEntropyLoss(reduction='none')
-
-                start_loss = token_loss_fct(start_logits, start_pos[slot])
-                end_loss = token_loss_fct(end_logits, end_pos[slot])
-                token_loss = (start_loss + end_loss) / 2.0
-
-                token_is_pointable = (start_pos[slot] > 0).float()
-                if not self.token_loss_for_nonpointable:
-                    token_loss *= token_is_pointable
-
-                refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
-                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
-                if not self.refer_loss_for_nonpointable:
-                    refer_loss *= token_is_referrable
-
-                class_loss = class_loss_fct(class_logits, class_label_id[slot])
-
-                if self.refer_index > -1:
-                    per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
-                else:
-                    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
-
-                total_loss += per_example_loss.sum()
-                per_slot_per_example_loss[slot] = per_example_loss
-
-        # add hidden states and attention if they are here
-        outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]
-
-        return outputs
diff --git a/run_dst.py b/run_dst.py
index 7809bb1..c52cd64 100644
--- a/run_dst.py
+++ b/run_dst.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 #
-# Copyright 2020 Heinrich Heine University Duesseldorf
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
 #
 # Part of this code is based on the source code of BERT-DST
 # (arXiv:1907.03040)
@@ -32,28 +32,32 @@ import numpy as np
 import torch
 from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler)
 from torch.utils.data.distributed import DistributedSampler
+from torch.optim import (AdamW)
 from tqdm import tqdm, trange
 
 from tensorboardX import SummaryWriter
 
-from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer,
-                          RobertaConfig, RobertaTokenizer)
-from transformers import (AdamW, get_linear_schedule_with_warmup)
+from transformers import (WEIGHTS_NAME,
+                          BertConfig, BertTokenizer, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          RobertaConfig, RobertaTokenizer, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          ElectraConfig, ElectraTokenizer, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          get_linear_schedule_with_warmup)
 
-from modeling_bert_dst import (BertForDST)
-from modeling_roberta_dst import (RobertaForDST)
+from modeling_dst import (TransformerForDST)
 from data_processors import PROCESSORS
-from utils_dst import (convert_examples_to_features)
+from utils_dst import (print_header, convert_examples_to_features)
 from tensorlistdataset import (TensorListDataset)
 
 logger = logging.getLogger(__name__)
 
-ALL_MODELS = tuple(BertConfig.pretrained_config_archive_map.keys())
-ALL_MODELS += tuple(RobertaConfig.pretrained_config_archive_map.keys())
+ALL_MODELS = tuple(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
+ALL_MODELS += tuple(ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP)
+ALL_MODELS += tuple(ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP)
 
 MODEL_CLASSES = {
-    'bert': (BertConfig, BertForDST, BertTokenizer),
-    'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer),
+    'bert': (BertConfig, TransformerForDST('bert'), BertTokenizer),
+    'roberta': (RobertaConfig, TransformerForDST('roberta'), RobertaTokenizer),
+    'electra': (ElectraConfig, TransformerForDST('electra'), ElectraTokenizer),
 }
 
 
@@ -64,10 +68,6 @@ def set_seed(args):
     if args.n_gpu > 0:
         torch.cuda.manual_seed_all(args.seed)
 
-        
-def to_list(tensor):
-    return tensor.detach().cpu().tolist()
-
 
 def batch_to_device(batch, device):
     batch_on_device = []
@@ -195,7 +195,7 @@ def train(args, train_dataset, features, model, tokenizer, processor, continue_f
 
                 # Log metrics
                 if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
-                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
+                    tb_writer.add_scalar('lr', scheduler.get_last_lr()[0], global_step)
                     tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
                     logging_loss = tr_loss
 
@@ -252,7 +252,7 @@ def evaluate(args, model, tokenizer, processor, prefix=""):
         batch = batch_to_device(batch, args.device)
 
         # Reset dialog state if turn is first in the dialog.
-        turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]]
+        turn_itrs = [features[i.item()].guid.split('-')[-1] for i in batch[9]]
         reset_diag_state = np.where(np.array(turn_itrs) == '0')[0]
         for slot in model.slot_list:
             for i in reset_diag_state:
@@ -365,18 +365,24 @@ def predict_and_format(model, tokenizer, features, per_slot_class_logits, per_sl
     prediction_list = []
     dialog_state = ds
     for i in range(len(ids)):
-        if int(ids[i].split("-")[2]) == 0:
+        if int(ids[i].split("-")[-1]) == 0:
             dialog_state = {slot: 'none' for slot in model.slot_list}
 
+        input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i])
+
         prediction = {}
         prediction_addendum = {}
+
+        prediction['guid'] = ids[i].split("-")
+        input_ids = features['input_ids'][i].tolist()
+        prediction['input_ids'] = input_ids
+
         for slot in model.slot_list:
             class_logits = per_slot_class_logits[slot][i]
             start_logits = per_slot_start_logits[slot][i]
             end_logits = per_slot_end_logits[slot][i]
             refer_logits = per_slot_refer_logits[slot][i]
 
-            input_ids = features['input_ids'][i].tolist()
             class_label_id = int(features['class_label_id'][slot][i])
             start_pos = int(features['start_pos'][slot][i])
             end_pos = int(features['end_pos'][slot][i])
@@ -387,7 +393,6 @@ def predict_and_format(model, tokenizer, features, per_slot_class_logits, per_sl
             end_prediction = int(end_logits.argmax())
             refer_prediction = int(refer_logits.argmax())
 
-            prediction['guid'] = ids[i].split("-")
             prediction['class_prediction_%s' % slot] = class_prediction
             prediction['class_label_id_%s' % slot] = class_label_id
             prediction['start_prediction_%s' % slot] = start_prediction
@@ -396,12 +401,10 @@ def predict_and_format(model, tokenizer, features, per_slot_class_logits, per_sl
             prediction['end_pos_%s' % slot] = end_pos
             prediction['refer_prediction_%s' % slot] = refer_prediction
             prediction['refer_id_%s' % slot] = refer_id
-            prediction['input_ids_%s' % slot] = input_ids
 
             if class_prediction == model.class_types.index('dontcare'):
                 dialog_state[slot] = 'dontcare'
             elif class_prediction == model.class_types.index('copy_value'):
-                input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i])
                 dialog_state[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1])
                 dialog_state[slot] = re.sub("(^| )##", "", dialog_state[slot])
             elif 'true' in model.class_types and class_prediction == model.class_types.index('true'):
@@ -410,6 +413,9 @@ def predict_and_format(model, tokenizer, features, per_slot_class_logits, per_sl
                 dialog_state[slot] = 'false'
             elif class_prediction == model.class_types.index('inform'):
                 dialog_state[slot] = '§§' + inform[i][slot]
+            #elif 'request' in model.class_types and class_prediction == model.class_types.index('request'):
+                # Don't carry over requested slots
+                #pass
             # Referral case is handled below
 
             prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot]
@@ -449,11 +455,14 @@ def load_and_cache_examples(args, model, tokenizer, processor, evaluate=False):
         logger.info("Loading features from cached file %s", cached_file)
         features = torch.load(cached_file)
     else:
-        logger.info("Creating features from dataset file at %s", args.data_dir)
-        processor_args = {'append_history': args.append_history,
-                          'use_history_labels': args.use_history_labels,
+        if args.task_name == "unified":
+            logger.info("Creating features from unified data format")
+        else:
+            logger.info("Creating features from dataset file at %s", args.data_dir)
+        processor_args = {'no_append_history': args.no_append_history,
+                          'no_use_history_labels': args.no_use_history_labels,
+                          'no_label_value_repetitions': args.no_label_value_repetitions,
                           'swap_utterances': args.swap_utterances,
-                          'label_value_repetitions': args.label_value_repetitions,
                           'delexicalize_sys_utts': args.delexicalize_sys_utts,
                           'unk_token': '<unk>' if args.model_type == 'roberta' else '[UNK]'}
         if evaluate and args.predict_type == "dev":
@@ -535,7 +544,7 @@ def main():
     parser.add_argument("--tokenizer_name", default="", type=str,
                         help="Pretrained tokenizer name or path if not the same as model_name")
 
-    parser.add_argument("--max_seq_length", default=384, type=int,
+    parser.add_argument("--max_seq_length", default=180, type=int,
                         help="Maximum input length after tokenization. Longer sequences will be truncated, shorter ones padded.")
     parser.add_argument("--do_train", action='store_true',
                         help="Whether to run training.")
@@ -560,14 +569,14 @@ def main():
     parser.add_argument("--refer_loss_for_nonpointable", action='store_true',
                         help="Whether the refer loss for classes other than refer contribute towards total loss.")
 
-    parser.add_argument("--append_history", action='store_true',
+    parser.add_argument("--no_append_history", action='store_true',
                         help="Whether or not to append the dialog history to each turn.")
-    parser.add_argument("--use_history_labels", action='store_true',
+    parser.add_argument("--no_use_history_labels", action='store_true',
                         help="Whether or not to label the history as well.")
-    parser.add_argument("--swap_utterances", action='store_true',
-                        help="Whether or not to swap the turn utterances (default: sys|usr, swapped: usr|sys).")
-    parser.add_argument("--label_value_repetitions", action='store_true',
+    parser.add_argument("--no_label_value_repetitions", action='store_true',
                         help="Whether or not to label values that have been mentioned before.")
+    parser.add_argument("--swap_utterances", action='store_true',
+                        help="Whether or not to swap the turn utterances (default: usr|sys, swapped: sys|usr).")
     parser.add_argument("--delexicalize_sys_utts", action='store_true',
                         help="Whether or not to delexicalize the system utterances.")
     parser.add_argument("--class_aux_feats_inform", action='store_true',
@@ -592,13 +601,13 @@ def main():
     parser.add_argument("--num_train_epochs", default=3.0, type=float,
                         help="Total number of training epochs to perform.")
     parser.add_argument("--max_steps", default=-1, type=int,
-                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
+                        help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.")
     parser.add_argument("--warmup_proportion", default=0.0, type=float,
                         help="Linear warmup over warmup_proportion * steps.")
     parser.add_argument("--svd", default=0.0, type=float,
                         help="Slot value dropout ratio (default: 0.0)")
 
-    parser.add_argument('--logging_steps', type=int, default=50,
+    parser.add_argument('--logging_steps', type=int, default=10,
                         help="Log every X updates steps.")
     parser.add_argument('--save_steps', type=int, default=0,
                         help="Save checkpoint every X updates steps. Overwritten by --save_epochs.")
@@ -622,14 +631,16 @@ def main():
     parser.add_argument('--fp16_opt_level', type=str, default='O1',
                         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                              "See details at https://nvidia.github.io/apex/amp.html")
+    parser.add_argument('--local_files_only', action='store_true',
+                        help="Whether to only load local model files (useful when working offline).")
 
     args = parser.parse_args()
 
-    assert(args.warmup_proportion >= 0.0 and args.warmup_proportion <= 1.0)
-    assert(args.svd >= 0.0 and args.svd <= 1.0)
-    assert(args.class_aux_feats_ds is False or args.per_gpu_eval_batch_size == 1)
-    assert(not args.class_aux_feats_inform or args.per_gpu_eval_batch_size == 1)
-    assert(not args.class_aux_feats_ds or args.per_gpu_eval_batch_size == 1)
+    assert args.warmup_proportion >= 0.0 and args.warmup_proportion <= 1.0
+    assert args.svd >= 0.0 and args.svd <= 1.0
+    assert args.class_aux_feats_ds is False or args.per_gpu_eval_batch_size == 1
+    assert not args.class_aux_feats_inform or args.per_gpu_eval_batch_size == 1
+    assert not args.class_aux_feats_ds or args.per_gpu_eval_batch_size == 1
 
     task_name = args.task_name.lower()
     if task_name not in PROCESSORS:
@@ -651,10 +662,11 @@ def main():
         args.n_gpu = 1
     args.device = device
 
-    # Setup logging
+    # Set up logging and print the header
     logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                         datefmt = '%m/%d/%Y %H:%M:%S',
                         level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+    print_header()
     logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                    args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
 
@@ -667,9 +679,10 @@ def main():
 
     args.model_type = args.model_type.lower()
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, local_files_only=args.local_files_only)
 
     # Add DST specific parameters to config
+    config.dst_max_seq_length = args.max_seq_length
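+    # Keeping the maximum sequence length in the config means it is saved together with
+    # the model checkpoint (presumably so later runs do not have to guess it).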
     config.dst_dropout_rate = args.dropout_rate
     config.dst_heads_dropout_rate = args.heads_dropout
     config.dst_class_loss_ratio = args.class_loss_ratio
@@ -681,8 +694,8 @@ def main():
     config.dst_class_types = dst_class_types
     config.dst_class_labels = dst_class_labels
 
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
-    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case, local_files_only=args.local_files_only)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config, local_files_only=args.local_files_only)
 
     logger.info("Updated model config: %s" % config)
 
diff --git a/run_dst_mtl.py b/run_dst_mtl.py
index 8e50306..125d7db 100644
--- a/run_dst_mtl.py
+++ b/run_dst_mtl.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 #
-# Copyright 2020 Heinrich Heine University Duesseldorf
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
 #
 # Part of this code is based on the source code of BERT-DST
 # (arXiv:1907.03040)
@@ -33,56 +33,39 @@ import numpy as np
 import torch
 from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler)
 from torch.utils.data.distributed import DistributedSampler
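+# AdamW is imported from torch.optim here; newer transformers releases deprecate their
+# own AdamW implementation in favour of the PyTorch one.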
+from torch.optim import (AdamW)
 from tqdm import tqdm, trange
 from collections import deque
 
 from tensorboardX import SummaryWriter
 
-from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer,
-                          RobertaConfig, RobertaTokenizer)
-from transformers import (AdamW, get_linear_schedule_with_warmup)
+from transformers import (WEIGHTS_NAME,
+                          BertConfig, BertTokenizer, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          RobertaConfig, RobertaTokenizer, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          ElectraConfig, ElectraTokenizer, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          get_linear_schedule_with_warmup)
 
-from modeling_bert_dst import (BertForDST)
-from modeling_roberta_dst import (RobertaForDST)
+from modeling_dst import (TransformerForDST)
 from data_processors import PROCESSORS
-from utils_dst import (convert_examples_to_features, convert_aux_examples_to_features)
+from run_dst import (evaluate, load_and_cache_examples, set_seed, batch_to_device)
+from utils_dst import (print_header, convert_examples_to_features, convert_aux_examples_to_features)
 from tensorlistdataset import (TensorListDataset)
 
 logger = logging.getLogger(__name__)
 
-ALL_MODELS = tuple(BertConfig.pretrained_config_archive_map.keys())
-ALL_MODELS += tuple(RobertaConfig.pretrained_config_archive_map.keys())
+ALL_MODELS = tuple(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
+ALL_MODELS += tuple(ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP)
+ALL_MODELS += tuple(ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP)
 
 MODEL_CLASSES = {
-    'bert': (BertConfig, BertForDST, BertTokenizer),
-    'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer),
+    'bert': (BertConfig, TransformerForDST('bert'), BertTokenizer),
+    'roberta': (RobertaConfig, TransformerForDST('roberta'), RobertaTokenizer),
+    'electra': (ElectraConfig, TransformerForDST('electra'), ElectraTokenizer),
 }
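+# TransformerForDST('<backbone>') takes the place of the former per-model classes; it
+# appears to act as a factory that returns a DST model class for the requested backbone.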
 
 
-def set_seed(args):
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-        
-def to_list(tensor):
-    return tensor.detach().cpu().tolist()
-
-
-def batch_to_device(batch, device):
-    batch_on_device = []
-    for element in batch:
-        if isinstance(element, dict):
-            batch_on_device.append({k: v.to(device) for k, v in element.items()})
-        else:
-            batch_on_device.append(element.to(device))
-    return tuple(batch_on_device)
-
-
-def train(args, train_dataset, aux_dataset, aux_task_def, features, model, tokenizer, processor, continue_from_global_step=0):
-    assert(not args.mtl_use or args.gradient_accumulation_steps == 1)
+def train_mtl(args, train_dataset, aux_dataset, aux_task_def, features, model, tokenizer, processor, continue_from_global_step=0):
+    assert not args.mtl_use or args.gradient_accumulation_steps == 1
 
     """ Train the model """
     if args.local_rank in [-1, 0]:
@@ -278,7 +261,7 @@ def train(args, train_dataset, aux_dataset, aux_task_def, features, model, token
 
                 # Log metrics
                 if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
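+                    # get_last_lr() replaces get_lr(), which newer PyTorch schedulers
+                    # reserve for internal use (calling it directly triggers a warning).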
-                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
+                    tb_writer.add_scalar('lr', scheduler.get_last_lr()[0], global_step)
                     tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
                     logging_loss = tr_loss
 
@@ -317,288 +300,6 @@ def train(args, train_dataset, aux_dataset, aux_task_def, features, model, token
     return global_step, tr_loss / global_step, tr_aux_loss
 
 
-def evaluate(args, model, tokenizer, processor, prefix=""):
-    dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=True)
-
-    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
-        os.makedirs(args.output_dir)
-
-    args.eval_batch_size = args.per_gpu_eval_batch_size
-    eval_sampler = SequentialSampler(dataset) # Note that DistributedSampler samples randomly
-    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
-
-    # Eval!
-    logger.info("***** Running evaluation {} *****".format(prefix))
-    logger.info("  Num examples = %d", len(dataset))
-    logger.info("  Batch size = %d", args.eval_batch_size)
-    all_results = []
-    all_preds = []
-    ds = {slot: 'none' for slot in model.slot_list}
-    with torch.no_grad():
-        diag_state = {slot: torch.tensor([0 for _ in range(args.eval_batch_size)]).to(args.device) for slot in model.slot_list}
-    for batch in tqdm(eval_dataloader, desc="Evaluating"):
-        model.eval()
-        batch = batch_to_device(batch, args.device)
-
-        # Reset dialog state if turn is first in the dialog.
-        turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]]
-        reset_diag_state = np.where(np.array(turn_itrs) == '0')[0]
-        for slot in model.slot_list:
-            for i in reset_diag_state:
-                diag_state[slot][i] = 0
-
-        with torch.no_grad():
-            inputs = {'input_ids':       batch[0],
-                      'input_mask':      batch[1],
-                      'segment_ids':     batch[2],
-                      'start_pos':       batch[3],
-                      'end_pos':         batch[4],
-                      'inform_slot_id':  batch[5],
-                      'refer_id':        batch[6],
-                      'diag_state':      diag_state,
-                      'class_label_id':  batch[8]}
-            unique_ids = [features[i.item()].guid for i in batch[9]]
-            values = [features[i.item()].values for i in batch[9]]
-            input_ids_unmasked = [features[i.item()].input_ids_unmasked for i in batch[9]]
-            inform = [features[i.item()].inform for i in batch[9]]
-            outputs = model(**inputs)
-
-            # Update dialog state for next turn.
-            for slot in model.slot_list:
-                updates = outputs[2][slot].max(1)[1]
-                for i, u in enumerate(updates):
-                    if u != 0:
-                        diag_state[slot][i] = u
-
-        results = eval_metric(model, inputs, outputs[0], outputs[1], outputs[2], outputs[3], outputs[4], outputs[5])
-        preds, ds = predict_and_format(model, tokenizer, inputs, outputs[2], outputs[3], outputs[4], outputs[5], unique_ids, input_ids_unmasked, values, inform, prefix, ds)
-        all_results.append(results)
-        all_preds.append(preds)
-
-    all_preds = [item for sublist in all_preds for item in sublist] # Flatten list
-
-    # Generate final results
-    final_results = {}
-    for k in all_results[0].keys():
-        final_results[k] = torch.stack([r[k] for r in all_results]).mean()
-
-    # Write final predictions (for evaluation with external tool)
-    output_prediction_file = os.path.join(args.output_dir, "pred_res.%s.%s.json" % (args.predict_type, prefix))
-    with open(output_prediction_file, "w") as f:
-        json.dump(all_preds, f, indent=2)
-
-    return final_results
-
-
-def eval_metric(model, features, total_loss, per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits):
-    metric_dict = {}
-    per_slot_correctness = {}
-    for slot in model.slot_list:
-        per_example_loss = per_slot_per_example_loss[slot]
-        class_logits = per_slot_class_logits[slot]
-        start_logits = per_slot_start_logits[slot]
-        end_logits = per_slot_end_logits[slot]
-        refer_logits = per_slot_refer_logits[slot]
-
-        class_label_id = features['class_label_id'][slot]
-        start_pos = features['start_pos'][slot]
-        end_pos = features['end_pos'][slot]
-        refer_id = features['refer_id'][slot]
-
-        _, class_prediction = class_logits.max(1)
-        class_correctness = torch.eq(class_prediction, class_label_id).float()
-        class_accuracy = class_correctness.mean()
-
-        # "is pointable" means whether class label is "copy_value",
-        # i.e., that there is a span to be detected.
-        token_is_pointable = torch.eq(class_label_id, model.class_types.index('copy_value')).float()
-        _, start_prediction = start_logits.max(1)
-        start_correctness = torch.eq(start_prediction, start_pos).float()
-        _, end_prediction = end_logits.max(1)
-        end_correctness = torch.eq(end_prediction, end_pos).float()
-        token_correctness = start_correctness * end_correctness
-        token_accuracy = (token_correctness * token_is_pointable).sum() / token_is_pointable.sum()
-        # NaNs mean that none of the examples in this batch contain spans. -> division by 0
-        # The accuracy therefore is 1 by default. -> replace NaNs
-        if math.isnan(token_accuracy):
-            token_accuracy = torch.tensor(1.0, device=token_accuracy.device)
-
-        token_is_referrable = torch.eq(class_label_id, model.class_types.index('refer') if 'refer' in model.class_types else -1).float()
-        _, refer_prediction = refer_logits.max(1)
-        refer_correctness = torch.eq(refer_prediction, refer_id).float()
-        refer_accuracy = refer_correctness.sum() / token_is_referrable.sum()
-        # NaNs mean that none of the examples in this batch contain referrals. -> division by 0
-        # The accuracy therefore is 1 by default. -> replace NaNs
-        if math.isnan(refer_accuracy) or math.isinf(refer_accuracy):
-            refer_accuracy = torch.tensor(1.0, device=refer_accuracy.device)
-            
-        total_correctness = class_correctness * (token_is_pointable * token_correctness + (1 - token_is_pointable)) * (token_is_referrable * refer_correctness + (1 - token_is_referrable))
-        total_accuracy = total_correctness.mean()
-
-        loss = per_example_loss.mean()
-        metric_dict['eval_accuracy_class_%s' % slot] = class_accuracy
-        metric_dict['eval_accuracy_token_%s' % slot] = token_accuracy
-        metric_dict['eval_accuracy_refer_%s' % slot] = refer_accuracy
-        metric_dict['eval_accuracy_%s' % slot] = total_accuracy
-        metric_dict['eval_loss_%s' % slot] = loss
-        per_slot_correctness[slot] = total_correctness
-
-    goal_correctness = torch.stack([c for c in per_slot_correctness.values()], 1).prod(1)
-    goal_accuracy = goal_correctness.mean()
-    metric_dict['eval_accuracy_goal'] = goal_accuracy
-    metric_dict['loss'] = total_loss
-    return metric_dict
-
-
-def predict_and_format(model, tokenizer, features, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits, ids, input_ids_unmasked, values, inform, prefix, ds):
-    prediction_list = []
-    dialog_state = ds
-    for i in range(len(ids)):
-        if int(ids[i].split("-")[2]) == 0:
-            dialog_state = {slot: 'none' for slot in model.slot_list}
-
-        prediction = {}
-        prediction_addendum = {}
-        for slot in model.slot_list:
-            class_logits = per_slot_class_logits[slot][i]
-            start_logits = per_slot_start_logits[slot][i]
-            end_logits = per_slot_end_logits[slot][i]
-            refer_logits = per_slot_refer_logits[slot][i]
-
-            input_ids = features['input_ids'][i].tolist()
-            class_label_id = int(features['class_label_id'][slot][i])
-            start_pos = int(features['start_pos'][slot][i])
-            end_pos = int(features['end_pos'][slot][i])
-            refer_id = int(features['refer_id'][slot][i])
-            
-            class_prediction = int(class_logits.argmax())
-            start_prediction = int(start_logits.argmax())
-            end_prediction = int(end_logits.argmax())
-            refer_prediction = int(refer_logits.argmax())
-
-            prediction['guid'] = ids[i].split("-")
-            prediction['class_prediction_%s' % slot] = class_prediction
-            prediction['class_label_id_%s' % slot] = class_label_id
-            prediction['start_prediction_%s' % slot] = start_prediction
-            prediction['start_pos_%s' % slot] = start_pos
-            prediction['end_prediction_%s' % slot] = end_prediction
-            prediction['end_pos_%s' % slot] = end_pos
-            prediction['refer_prediction_%s' % slot] = refer_prediction
-            prediction['refer_id_%s' % slot] = refer_id
-            prediction['input_ids_%s' % slot] = input_ids
-
-            if class_prediction == model.class_types.index('dontcare'):
-                dialog_state[slot] = 'dontcare'
-            elif class_prediction == model.class_types.index('copy_value'):
-                input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i])
-                dialog_state[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1])
-                dialog_state[slot] = re.sub("(^| )##", "", dialog_state[slot])
-            elif 'true' in model.class_types and class_prediction == model.class_types.index('true'):
-                dialog_state[slot] = 'true'
-            elif 'false' in model.class_types and class_prediction == model.class_types.index('false'):
-                dialog_state[slot] = 'false'
-            elif class_prediction == model.class_types.index('inform'):
-                dialog_state[slot] = '§§' + inform[i][slot]
-            # Referral case is handled below
-
-            prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot]
-            prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot]
-
-        # Referral case. All other slot values need to be seen first in order
-        # to be able to do this correctly.
-        for slot in model.slot_list:
-            class_logits = per_slot_class_logits[slot][i]
-            refer_logits = per_slot_refer_logits[slot][i]
-            
-            class_prediction = int(class_logits.argmax())
-            refer_prediction = int(refer_logits.argmax())
-
-            if 'refer' in model.class_types and class_prediction == model.class_types.index('refer'):
-                # Only slots that have been mentioned before can be referred to.
-                # One can think of a situation where one slot is referred to in the same utterance.
-                # This phenomenon is however currently not properly covered in the training data
-                # label generation process.
-                dialog_state[slot] = dialog_state[model.slot_list[refer_prediction - 1]]
-                prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot] # Value update
-
-        prediction.update(prediction_addendum)
-        prediction_list.append(prediction)
-        
-    return prediction_list, dialog_state
-
-
-def load_and_cache_examples(args, model, tokenizer, processor, evaluate=False):
-    if args.local_rank not in [-1, 0] and not evaluate:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
-
-    # Load data features from cache or dataset file
-    cached_file = os.path.join(os.path.dirname(args.output_dir), 'cached_{}_features'.format(
-        args.predict_type if evaluate else 'train'))
-    if os.path.exists(cached_file) and not args.overwrite_cache: # and not output_examples:
-        logger.info("Loading features from cached file %s", cached_file)
-        features = torch.load(cached_file)
-    else:
-        logger.info("Creating features from dataset file at %s", args.data_dir)
-        processor_args = {'append_history': args.append_history,
-                          'use_history_labels': args.use_history_labels,
-                          'swap_utterances': args.swap_utterances,
-                          'label_value_repetitions': args.label_value_repetitions,
-                          'delexicalize_sys_utts': args.delexicalize_sys_utts,
-                          'unk_token': '<unk>' if args.model_type == 'roberta' else '[UNK]'}
-        if evaluate and args.predict_type == "dev":
-            examples = processor.get_dev_examples(args.data_dir, processor_args)
-        elif evaluate and args.predict_type == "test":
-            examples = processor.get_test_examples(args.data_dir, processor_args)
-        else:
-            examples = processor.get_train_examples(args.data_dir, processor_args)
-        features = convert_examples_to_features(examples=examples,
-                                                slot_list=model.slot_list,
-                                                class_types=model.class_types,
-                                                model_type=args.model_type,
-                                                tokenizer=tokenizer,
-                                                max_seq_length=args.max_seq_length,
-                                                slot_value_dropout=(0.0 if evaluate else args.svd))
-        if args.local_rank in [-1, 0]:
-            logger.info("Saving features into cached file %s", cached_file)
-            torch.save(features, cached_file)
-
-    if args.local_rank == 0 and not evaluate:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
-
-    # Convert to Tensors and build dataset
-    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
-    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
-    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
-    f_start_pos = [f.start_pos for f in features]
-    f_end_pos = [f.end_pos for f in features]
-    f_inform_slot_ids = [f.inform_slot for f in features]
-    f_refer_ids = [f.refer_id for f in features]
-    f_diag_state = [f.diag_state for f in features]
-    f_class_label_ids = [f.class_label_id for f in features]
-    all_start_positions = {}
-    all_end_positions = {}
-    all_inform_slot_ids = {}
-    all_refer_ids = {}
-    all_diag_state = {}
-    all_class_label_ids = {}
-    for s in model.slot_list:
-        all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos], dtype=torch.long)
-        all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos], dtype=torch.long)
-        all_inform_slot_ids[s] = torch.tensor([f[s] for f in f_inform_slot_ids], dtype=torch.long)
-        all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids], dtype=torch.long)
-        all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state], dtype=torch.long)
-        all_class_label_ids[s] = torch.tensor([f[s] for f in f_class_label_ids], dtype=torch.long)
-    dataset = TensorListDataset(all_input_ids, all_input_mask, all_segment_ids,
-                                all_start_positions, all_end_positions,
-                                all_inform_slot_ids,
-                                all_refer_ids,
-                                all_diag_state,
-                                all_class_label_ids, all_example_index)
-
-    return dataset, features
-
-
 def load_and_cache_aux_examples(args, model, tokenizer, aux_task_def=None):
     if aux_task_def is None:
         return None, None
@@ -689,14 +390,14 @@ def main():
     parser.add_argument("--refer_loss_for_nonpointable", action='store_true',
                         help="Whether the refer loss for classes other than refer contribute towards total loss.")
 
-    parser.add_argument("--append_history", action='store_true',
+    parser.add_argument("--no_append_history", action='store_true',
                         help="Whether or not to append the dialog history to each turn.")
-    parser.add_argument("--use_history_labels", action='store_true',
+    parser.add_argument("--no_use_history_labels", action='store_true',
                         help="Whether or not to label the history as well.")
-    parser.add_argument("--swap_utterances", action='store_true',
-                        help="Whether or not to swap the turn utterances (default: sys|usr, swapped: usr|sys).")
-    parser.add_argument("--label_value_repetitions", action='store_true',
+    parser.add_argument("--no_label_value_repetitions", action='store_true',
                         help="Whether or not to label values that have been mentioned before.")
+    parser.add_argument("--swap_utterances", action='store_true',
+                        help="Whether or not to swap the turn utterances (default: usr|sys, swapped: sys|usr).")
     parser.add_argument("--delexicalize_sys_utts", action='store_true',
                         help="Whether or not to delexicalize the system utterances.")
     parser.add_argument("--class_aux_feats_inform", action='store_true',
@@ -727,7 +428,7 @@ def main():
     parser.add_argument("--svd", default=0.0, type=float,
                         help="Slot value dropout ratio (default: 0.0)")
 
-    parser.add_argument('--logging_steps', type=int, default=50,
+    parser.add_argument('--logging_steps', type=int, default=10,
                         help="Log every X updates steps.")
     parser.add_argument('--save_steps', type=int, default=0,
                         help="Save checkpoint every X updates steps. Overwritten by --save_epochs.")
@@ -751,6 +452,8 @@ def main():
     parser.add_argument('--fp16_opt_level', type=str, default='O1',
                         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                              "See details at https://nvidia.github.io/apex/amp.html")
+    parser.add_argument('--local_files_only', action='store_true',
+                        help="Whether to only load local model files (useful when working offline).")
 
     parser.add_argument('--mtl_use', action='store_true', help="")
     parser.add_argument('--mtl_task_def', type=str, default="aux_task_def.json", help="")
@@ -762,14 +465,14 @@ def main():
 
     args = parser.parse_args()
 
-    assert(args.warmup_proportion >= 0.0 and args.warmup_proportion <= 1.0)
-    assert(args.svd >= 0.0 and args.svd <= 1.0)
-    assert(args.class_aux_feats_ds is False or args.per_gpu_eval_batch_size == 1)
-    assert(not args.class_aux_feats_inform or args.per_gpu_eval_batch_size == 1)
-    assert(not args.class_aux_feats_ds or args.per_gpu_eval_batch_size == 1)
+    assert args.warmup_proportion >= 0.0 and args.warmup_proportion <= 1.0
+    assert args.svd >= 0.0 and args.svd <= 1.0
+    assert args.class_aux_feats_ds is False or args.per_gpu_eval_batch_size == 1
+    assert not args.class_aux_feats_inform or args.per_gpu_eval_batch_size == 1
+    assert not args.class_aux_feats_ds or args.per_gpu_eval_batch_size == 1
 
-    assert(not args.mtl_use or args.gradient_accumulation_steps == 1)
-    assert(args.mtl_ratio >= 0.0 and args.mtl_ratio <= 1.0)
+    assert not args.mtl_use or args.gradient_accumulation_steps == 1
+    assert args.mtl_ratio >= 0.0 and args.mtl_ratio <= 1.0
 
     task_name = args.task_name.lower()
     if task_name not in PROCESSORS:
@@ -799,10 +502,11 @@ def main():
         args.n_gpu = 1
     args.device = device
 
-    # Setup logging
+    # Set up logging and print the header
     logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                         datefmt = '%m/%d/%Y %H:%M:%S',
                         level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+    print_header()
     logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                    args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
 
@@ -815,9 +519,10 @@ def main():
 
     args.model_type = args.model_type.lower()
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, local_files_only=args.local_files_only)
 
     # Add DST specific parameters to config
+    config.dst_max_seq_length = args.max_seq_length
     config.dst_dropout_rate = args.dropout_rate
     config.dst_heads_dropout_rate = args.heads_dropout
     config.dst_class_loss_ratio = args.class_loss_ratio
@@ -828,10 +533,11 @@ def main():
     config.dst_slot_list = dst_slot_list
     config.dst_class_types = dst_class_types
     config.dst_class_labels = dst_class_labels
-    config.aux_task_def = aux_task_def
+    if aux_task_def is not None:
+        config.aux_task_def = aux_task_def
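+    # The aux task definition is only attached when MTL provides one, presumably to keep
+    # non-MTL configs free of a None-valued field.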
 
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
-    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case, local_files_only=args.local_files_only)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config, local_files_only=args.local_files_only)
 
     logger.info("Updated model config: %s" % config)
 
@@ -857,7 +563,7 @@ def main():
         
         train_dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=False)
         aux_dataset, _ = load_and_cache_aux_examples(args, model, tokenizer, aux_task_def)
-        global_step, tr_loss, aux_loss = train(args, train_dataset, aux_dataset, aux_task_def, features, model, tokenizer, processor, continue_from_global_step)
+        global_step, tr_loss, aux_loss = train_mtl(args, train_dataset, aux_dataset, aux_task_def, features, model, tokenizer, processor, continue_from_global_step)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
 
     # Save the trained model and the tokenizer
diff --git a/tensorlistdataset.py b/tensorlistdataset.py
index 5b27b3c..238a02e 100644
--- a/tensorlistdataset.py
+++ b/tensorlistdataset.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 #
-# Copyright 2020 Heinrich Heine University Duesseldorf
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/utils_dst.py b/utils_dst.py
index 8f91df8..917b2fc 100644
--- a/utils_dst.py
+++ b/utils_dst.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 #
-# Copyright 2020 Heinrich Heine University Duesseldorf
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
 #
 # Part of this code is based on the source code of BERT-DST
 # (arXiv:1907.03040)
@@ -22,11 +22,22 @@
 import logging
 import six
 import numpy as np
-import json
 
 logger = logging.getLogger(__name__)
 
 
+def print_header():
+    logger.info(" _________  ________  ___  ________  ________  ___    ___ ")
+    logger.info("|\___   ___\\\   __  \|\  \|\   __  \|\   __  \|\  \  /  /|")
+    logger.info("\|___ \  \_\ \  \|\  \ \  \ \  \|\  \ \  \|\  \ \  \/  / /")
+    logger.info("     \ \  \ \ \   _  _\ \  \ \   ____\ \   ____\ \    / / ")
+    logger.info("      \ \  \ \ \  \\\  \\\ \  \ \  \___|\ \  \___|\/  /  /  ")
+    logger.info("       \ \__\ \ \__\\\ _\\\ \__\ \__\    \ \__\ __/  / /    ")
+    logger.info("        \|__|  \|__|\|__|\|__|\|__|     \|__||\___/ /     ")
+    logger.info("          (c) 2022 Heinrich Heine University \|___|/      ")
+    logger.info("")
+
+
 class DSTExample(object):
     """
     A single training/test example for the DST dataset.
@@ -70,23 +81,23 @@ class DSTExample(object):
         s += ", text_b: %s" % (self.text_b)
         s += ", history: %s" % (self.history)
         if self.text_a_label:
-            s += ", text_a_label: %d" % (self.text_a_label)
+            s += ", text_a_label: %s" % (self.text_a_label)
         if self.text_b_label:
-            s += ", text_b_label: %d" % (self.text_b_label)
+            s += ", text_b_label: %s" % (self.text_b_label)
         if self.history_label:
-            s += ", history_label: %d" % (self.history_label)
+            s += ", history_label: %s" % (self.history_label)
         if self.values:
-            s += ", values: %d" % (self.values)
+            s += ", values: %s" % (self.values)
         if self.inform_label:
-            s += ", inform_label: %d" % (self.inform_label)
+            s += ", inform_label: %s" % (self.inform_label)
         if self.inform_slot_label:
-            s += ", inform_slot_label: %d" % (self.inform_slot_label)
+            s += ", inform_slot_label: %s" % (self.inform_slot_label)
         if self.refer_label:
-            s += ", refer_label: %d" % (self.refer_label)
+            s += ", refer_label: %s" % (self.refer_label)
         if self.diag_state:
-            s += ", diag_state: %d" % (self.diag_state)
+            s += ", diag_state: %s" % (self.diag_state)
         if self.class_label:
-            s += ", class_label: %d" % (self.class_label)
+            s += ", class_label: %s" % (self.class_label)
         return s
 
 
@@ -147,50 +158,59 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
 
     if model_type == 'bert':
         model_specs = {'MODEL_TYPE': 'bert',
-                       'CLS_TOKEN': '[CLS]',
-                       'UNK_TOKEN': '[UNK]',
-                       'SEP_TOKEN': '[SEP]',
+                       'TOKEN_CORRECTION': 4}
+    elif model_type == 'electra':
+        model_specs = {'MODEL_TYPE': 'electra',
                        'TOKEN_CORRECTION': 4}
     elif model_type == 'roberta':
         model_specs = {'MODEL_TYPE': 'roberta',
-                       'CLS_TOKEN': '<s>',
-                       'UNK_TOKEN': '<unk>',
-                       'SEP_TOKEN': '</s>',
                        'TOKEN_CORRECTION': 6}
     else:
         logger.error("Unknown model type (%s). Aborting." % (model_type))
         exit(1)
 
-    def _tokenize_text_and_label(text, text_label_dict, slot, tokenizer, model_specs, slot_value_dropout):
-        joint_text_label = [0 for _ in text_label_dict[slot]] # joint all slots' label
+    def _tokenize_text(text, text_label_dict, tokenizer, model_specs, slot_value_dropout):
+        # Join text labels across slots (used for slot value dropout)
+        joint_text_label = [0 for _ in range(len(text))]
         for slot_text_label in text_label_dict.values():
             for idx, label in enumerate(slot_text_label):
                 if label == 1:
                     joint_text_label[idx] = 1
 
-        text_label = text_label_dict[slot]
+        token_to_subtoken = []
         tokens = []
         tokens_unmasked = []
-        token_labels = []
-        for token, token_label, joint_label in zip(text, text_label, joint_text_label):
+        for token, joint_label in zip(text, joint_text_label):
             token = convert_to_unicode(token)
             if model_specs['MODEL_TYPE'] == 'roberta':
-                token = ' ' + token
+                # The tokenizer behaviour seems to have changed in newer versions,
+                # which makes this special handling necessary.
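+                # Prepending a space makes the byte-level BPE treat the token as
+                # word-initial (Ġ-prefixed) rather than as a continuation of the previous word.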
+                if token != tokenizer.unk_token:
+                    token = ' ' + token
             sub_tokens = tokenizer.tokenize(token) # Most time intensive step
             tokens_unmasked.extend(sub_tokens)
             if slot_value_dropout == 0.0 or joint_label == 0:
+                token_to_subtoken.append([token, sub_tokens])
                 tokens.extend(sub_tokens)
             else:
                 rn_list = np.random.random_sample((len(sub_tokens),))
+                element = [token, []]
                 for rn, sub_token in zip(rn_list, sub_tokens):
                     if rn > slot_value_dropout:
+                        element[1].append(sub_token)
                         tokens.append(sub_token)
                     else:
-                        tokens.append(model_specs['UNK_TOKEN'])
+                        element[1].append(tokenizer.unk_token)
+                        tokens.append(tokenizer.unk_token)
+                token_to_subtoken.append(element)
+        return tokens, tokens_unmasked, token_to_subtoken
+
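+    # With tokenization cached in token_to_subtoken, per-slot span labels can be expanded
+    # to sub-token granularity without re-running the tokenizer for every slot.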
+    def _label_tokenized_text(tokens, text_label_dict, slot):
+        token_labels = []
+        for element, token_label in zip(tokens, text_label_dict[slot]):
+            token, sub_tokens = element
             token_labels.extend([token_label for _ in sub_tokens])
-        assert len(tokens) == len(token_labels)
-        assert len(tokens_unmasked) == len(token_labels)
-        return tokens, tokens_unmasked, token_labels
+        return token_labels
 
     def _truncate_seq_pair(tokens_a, tokens_b, history, max_length):
         """Truncates a sequence pair in place to the maximum length.
@@ -284,38 +304,36 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
         # the entire model is fine-tuned.
         tokens = []
         segment_ids = []
-        tokens.append(model_specs['CLS_TOKEN'])
+        tokens.append(tokenizer.cls_token)
         segment_ids.append(0)
         for token in tokens_a:
             tokens.append(token)
             segment_ids.append(0)
-        tokens.append(model_specs['SEP_TOKEN'])
+        tokens.append(tokenizer.sep_token)
         segment_ids.append(0)
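+        # RoBERTa does not use segment (token type) ids, so every position keeps segment
+        # id 0 and the sequences are separated by a double sep token; BERT/ELECTRA mark
+        # the second sequence with segment id 1.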
         if model_specs['MODEL_TYPE'] == 'roberta':
-            tokens.append(model_specs['SEP_TOKEN'])
+            tokens.append(tokenizer.sep_token)
             segment_ids.append(0)
-        if model_specs['MODEL_TYPE'] != 'roberta':
             for token in tokens_b:
                 tokens.append(token)
-                segment_ids.append(1)
-            tokens.append(model_specs['SEP_TOKEN'])
-            segment_ids.append(1)
+                segment_ids.append(0)
+            tokens.append(tokenizer.sep_token)
+            segment_ids.append(0)
+            tokens.append(tokenizer.sep_token)
+            segment_ids.append(0)
         else:
             for token in tokens_b:
                 tokens.append(token)
-                segment_ids.append(0)
-            tokens.append(model_specs['SEP_TOKEN'])
-            segment_ids.append(0)
-            if model_specs['MODEL_TYPE'] == 'roberta':
-                tokens.append(model_specs['SEP_TOKEN'])
-                segment_ids.append(0)
+                segment_ids.append(1)
+            tokens.append(tokenizer.sep_token)
+            segment_ids.append(1)
         for token in history:
             tokens.append(token)
             if model_specs['MODEL_TYPE'] == 'roberta':
                 segment_ids.append(0)
             else:
                 segment_ids.append(1)
-        tokens.append(model_specs['SEP_TOKEN'])
+        tokens.append(tokenizer.sep_token)
         if model_specs['MODEL_TYPE'] == 'roberta':
             segment_ids.append(0)
         else:
@@ -336,7 +354,7 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
         assert len(input_mask) == max_seq_length
         assert len(segment_ids) == max_seq_length
         return tokens, input_ids, input_mask, segment_ids
-    
+
     total_cnt = 0
     too_long_cnt = 0
 
@@ -350,6 +368,49 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
 
         total_cnt += 1
 
+        tokens_a, tokens_a_unmasked, token_to_subtoken_a = _tokenize_text(
+            example.text_a, example.text_a_label, tokenizer, model_specs, slot_value_dropout)
+        tokens_b, tokens_b_unmasked, token_to_subtoken_b = _tokenize_text(
+            example.text_b, example.text_b_label, tokenizer, model_specs, slot_value_dropout)
+        tokens_history, tokens_history_unmasked, token_to_subtoken_history = _tokenize_text(
+            example.history, example.history_label, tokenizer, model_specs, slot_value_dropout)
+
+        token_labels_a_dict = {}
+        token_labels_b_dict = {}
+        token_labels_history_dict = {}
+        for slot in slot_list:
+            token_labels_a_dict[slot] = _label_tokenized_text(token_to_subtoken_a, example.text_a_label, slot)
+            token_labels_b_dict[slot] = _label_tokenized_text(token_to_subtoken_b, example.text_b_label, slot)
+            token_labels_history_dict[slot] = _label_tokenized_text(token_to_subtoken_history, example.history_label, slot)
+
+        input_text_too_long = _truncate_length_and_warn(
+            tokens_a, tokens_b, tokens_history, max_seq_length, model_specs, example.guid)
+
+        tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)]
+        tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)]
+        tokens_history_unmasked = tokens_history_unmasked[:len(tokens_history)]
+
+        if input_text_too_long:
+            too_long_cnt += 1
+
+        tokens, input_ids, input_mask, segment_ids = _get_transformer_input(tokens_a,
+                                                                            tokens_b,
+                                                                            tokens_history,
+                                                                            max_seq_length,
+                                                                            tokenizer,
+                                                                            model_specs)
+        if slot_value_dropout > 0.0:
+            _, input_ids_unmasked, _, _ = _get_transformer_input(tokens_a_unmasked,
+                                                                 tokens_b_unmasked,
+                                                                 tokens_history_unmasked,
+                                                                 max_seq_length,
+                                                                 tokenizer,
+                                                                 model_specs)
+        else:
+            input_ids_unmasked = input_ids
+
+        assert len(input_ids) == len(input_ids_unmasked)
+
         value_dict = {}
         inform_dict = {}
         inform_slot_dict = {}
@@ -359,37 +420,19 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
         start_pos_dict = {}
         end_pos_dict = {}
         for slot in slot_list:
-            tokens_a, tokens_a_unmasked, token_labels_a = _tokenize_text_and_label(
-                example.text_a, example.text_a_label, slot, tokenizer, model_specs, slot_value_dropout)
-            tokens_b, tokens_b_unmasked, token_labels_b = _tokenize_text_and_label(
-                example.text_b, example.text_b_label, slot, tokenizer, model_specs, slot_value_dropout)
-            tokens_history, tokens_history_unmasked, token_labels_history = _tokenize_text_and_label(
-                example.history, example.history_label, slot, tokenizer, model_specs, slot_value_dropout)
-
-            input_text_too_long = _truncate_length_and_warn(
-                tokens_a, tokens_b, tokens_history, max_seq_length, model_specs, example.guid)
-
-            if input_text_too_long:
-                if example_index < 10:
-                    if len(token_labels_a) > len(tokens_a):
-                        logger.info('    tokens_a truncated labels: %s' % str(token_labels_a[len(tokens_a):]))
-                    if len(token_labels_b) > len(tokens_b):
-                        logger.info('    tokens_b truncated labels: %s' % str(token_labels_b[len(tokens_b):]))
-                    if len(token_labels_history) > len(tokens_history):
-                        logger.info('    tokens_history truncated labels: %s' % str(token_labels_history[len(tokens_history):]))
-
-                token_labels_a = token_labels_a[:len(tokens_a)]
-                token_labels_b = token_labels_b[:len(tokens_b)]
-                token_labels_history = token_labels_history[:len(tokens_history)]
-                tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)]
-                tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)]
-                tokens_history_unmasked = tokens_history_unmasked[:len(tokens_history)]
+            token_labels_a = token_labels_a_dict[slot]
+            token_labels_b = token_labels_b_dict[slot]
+            token_labels_history = token_labels_history_dict[slot]
+
+            token_labels_a = token_labels_a[:len(tokens_a)]
+            token_labels_b = token_labels_b[:len(tokens_b)]
+            token_labels_history = token_labels_history[:len(tokens_history)]
 
             assert len(token_labels_a) == len(tokens_a)
             assert len(token_labels_b) == len(tokens_b)
-            assert len(token_labels_history) == len(tokens_history)
             assert len(token_labels_a) == len(tokens_a_unmasked)
             assert len(token_labels_b) == len(tokens_b_unmasked)
+            assert len(token_labels_history) == len(tokens_history)
             assert len(token_labels_history) == len(tokens_history_unmasked)
             token_label_ids = _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs)
 
@@ -405,27 +448,6 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
             diag_state_dict[slot] = class_types.index(example.diag_state[slot])
             class_label_id_dict[slot] = class_types.index(example.class_label[slot])
 
-        if input_text_too_long:
-            too_long_cnt += 1
-            
-        tokens, input_ids, input_mask, segment_ids = _get_transformer_input(tokens_a,
-                                                                            tokens_b,
-                                                                            tokens_history,
-                                                                            max_seq_length,
-                                                                            tokenizer,
-                                                                            model_specs)
-        if slot_value_dropout > 0.0:
-            _, input_ids_unmasked, _, _ = _get_transformer_input(tokens_a_unmasked,
-                                                                 tokens_b_unmasked,
-                                                                 tokens_history_unmasked,
-                                                                 max_seq_length,
-                                                                 tokenizer,
-                                                                 model_specs)
-        else:
-            input_ids_unmasked = input_ids
-
-        assert(len(input_ids) == len(input_ids_unmasked))
-
         if example_index < 10:
             logger.info("*** Example ***")
             logger.info("guid: %s" % (example.guid))
-- 
GitLab