diff --git a/DO.example.mtl b/DO.example.mtl
new file mode 100755
index 0000000000000000000000000000000000000000..7963eb61ad23d47e3a1a9edfabbfee4d59dc8e63
--- /dev/null
+++ b/DO.example.mtl
@@ -0,0 +1,74 @@
+#!/bin/bash
+
+# Parameters ------------------------------------------------------
+
+#TASK="sim-m"
+#DATA_DIR="data/simulated-dialogue/sim-M"
+#TASK="sim-r"
+#DATA_DIR="data/simulated-dialogue/sim-R"
+#TASK="woz2"
+#DATA_DIR="data/woz2"
+TASK="multiwoz21"
+DATA_DIR="data/MULTIWOZ2.1"
+
+AUX_TASK="cola"
+AUX_DATA_DIR="data/aux/roberta_base_cased_lower"
+
+# Project paths etc. ----------------------------------------------
+
+OUT_DIR=results
+mkdir -p ${OUT_DIR}
+
+# Main ------------------------------------------------------------
+
+for step in train dev test; do
+    args_add=""
+    if [ "$step" = "train" ]; then
+	args_add="--do_train --predict_type=dummy"
+    elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+	args_add="--do_eval --predict_type=${step}"
+    fi
+
+    python3 run_dst_mtl.py \
+	    --task_name=${TASK} \
+	    --data_dir=${DATA_DIR} \
+	    --dataset_config=dataset_config/${TASK}.json \
+	    --model_type="roberta" \
+	    --model_name_or_path="roberta-base" \
+	    --do_lower_case \
+	    --learning_rate=1e-4 \
+	    --num_train_epochs=10 \
+	    --max_seq_length=180 \
+	    --per_gpu_train_batch_size=48 \
+	    --per_gpu_eval_batch_size=1 \
+	    --output_dir=${OUT_DIR} \
+	    --save_epochs=2 \
+	    --logging_steps=10 \
+	    --warmup_proportion=0.1 \
+	    --eval_all_checkpoints \
+	    --adam_epsilon=1e-6 \
+            --weight_decay=0.01 \
+	    --heads_dropout=0.1 \
+	    --label_value_repetitions \
+            --swap_utterances \
+	    --append_history \
+	    --use_history_labels \
+	    --delexicalize_sys_utts \
+	    --class_aux_feats_inform \
+	    --class_aux_feats_ds \
+	    --mtl_use \
+            --mtl_task_def=dataset_config/aux_task_def.json \
+	    --mtl_train_dataset=${AUX_TASK} \
+	    --mtl_data_dir=${AUX_DATA_DIR} \
+            --mtl_ratio=0.7 \
+	    ${args_add} \
+	    2>&1 | tee ${OUT_DIR}/${step}.log
+    
+    if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+    	python3 metric_bert_dst.py \
+    		${TASK} \
+		dataset_config/${TASK}.json \
+    		"${OUT_DIR}/pred_res.${step}*json" \
+    		2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
+    fi
+done
diff --git a/README.md b/README.md
index 04489ca48fcd2a68113dbca98b0ca86fe0c1797f..b90e4deda21bdced28dc5f97fe6c32340a92e2b2 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,10 @@
 ## *** For readers of "Out-of-Task Training for Dialog State Tracking Models" ***
 
-The code update to cover the content of the paper (https://arxiv.org/abs/2011.09379) will be available here after our presentation at COLING 2020
+The first version of the MTL code is available now. `DO.example.mtl` will train a model with MTL using an auxiliary task. As of now, pre-tokenized data is loaded for the auxiliary tasks. The next update will also include tokenization of the original data.
+
+The paper is available here:
+https://www.aclweb.org/anthology/2020.coling-main.596
+https://arxiv.org/abs/2011.09379
 
 ## Introduction
 
diff --git a/data/aux/bert_base_uncased_lower.tar.gz b/data/aux/bert_base_uncased_lower.tar.gz
new file mode 100644
index 0000000000000000000000000000000000000000..4f91ff72f72ca534f644e0dfda03f4fda09157d4
Binary files /dev/null and b/data/aux/bert_base_uncased_lower.tar.gz differ
diff --git a/data/aux/roberta_base_cased_lower.tar.gz b/data/aux/roberta_base_cased_lower.tar.gz
new file mode 100644
index 0000000000000000000000000000000000000000..c9270c264070c66e4e4153a2be30ba8ad55f39d3
Binary files /dev/null and b/data/aux/roberta_base_cased_lower.tar.gz differ
diff --git a/data_processors.py b/data_processors.py
index 6d7670df0b5a0734f05499fab24bf4a137edb652..88d94ae3c3b8687080571533e5f2b40f3386b3cd 100644
--- a/data_processors.py
+++ b/data_processors.py
@@ -23,6 +23,7 @@ import json
 import dataset_woz2
 import dataset_sim
 import dataset_multiwoz21
+import dataset_aux_task
 
 
 class DataProcessor(object):
@@ -88,7 +89,14 @@ class SimProcessor(DataProcessor):
                                            'test', self.slot_list, **args)
 
 
+class AuxTaskProcessor(object):
+    def get_aux_task_examples(self, data_dir, data_name, max_seq_length):
+        file_path = os.path.join(data_dir, '{}_train.json'.format(data_name))
+        return dataset_aux_task.create_examples(file_path, max_seq_length)
+
+
 PROCESSORS = {"woz2": Woz2Processor,
               "sim-m": SimProcessor,
               "sim-r": SimProcessor,
-              "multiwoz21": Multiwoz21Processor}
+              "multiwoz21": Multiwoz21Processor,
+              "aux_task": AuxTaskProcessor}
diff --git a/dataset_aux_task.py b/dataset_aux_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..40a8667640a30342c0498243730d8c8f000d5774
--- /dev/null
+++ b/dataset_aux_task.py
@@ -0,0 +1,31 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+
+def create_examples(path, maxlen=512):
+    examples = []
+    with open(path, 'r', encoding='utf-8') as reader:
+        cnt = 0
+        for line in reader:
+            sample = json.loads(line)
+            if len(sample['token_id']) > maxlen:
+                continue
+            cnt += 1
+            examples.append(sample)
+        print('Loaded {} samples out of {}'.format(len(examples), cnt))
+    return examples
diff --git a/dataset_config/aux_task_def.json b/dataset_config/aux_task_def.json
new file mode 100644
index 0000000000000000000000000000000000000000..dd92bce417d71a105bfbcece370d0cf45dfb6f42
--- /dev/null
+++ b/dataset_config/aux_task_def.json
@@ -0,0 +1,46 @@
+{
+  "cola": {
+    "n_class": 2,
+    "task_type": "classification"
+  },
+  "mnli": {
+    "n_class": 3,
+    "task_type": "classification"
+  },
+  "mrpc": {
+    "n_class": 2,
+    "task_type": "classification"
+  },
+  "qnli": {
+    "n_class": 2,
+    "task_type": "classification"
+  },
+  "qqp": {
+    "n_class": 2,
+    "task_type": "classification"
+  },
+  "rte": {
+    "n_class": 2,
+    "task_type": "classification"
+  },
+  "sst": {
+    "n_class": 2,
+    "task_type": "classification"
+  },
+  "stsb": {
+    "n_class": 1,
+    "task_type": "regression"
+  },
+  "wnli": {
+    "n_class": 2,
+    "task_type": "classification"
+  },
+  "squad": {
+    "n_class": 2,
+    "task_type": "span"
+  },
+  "squad-v2": {
+    "n_class": 2,
+    "task_type": "span"
+  }
+}
diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py
index e514638415db99af164190fa555e0fe993c69cab..92bdd2c57ad778134af66476dfac1d963df9c81b 100644
--- a/dataset_multiwoz21.py
+++ b/dataset_multiwoz21.py
@@ -203,7 +203,7 @@ def is_in_list(tok, value):
     return found
  
 
-def delex_utt(utt, values):
+def delex_utt(utt, values, unk_token="[UNK]"):
     utt_norm = tokenize(utt)
     for s, vals in values.items():
         for v in vals:
@@ -212,7 +212,7 @@ def delex_utt(utt, values):
                 v_len = len(v_norm)
                 for i in range(len(utt_norm) + 1 - v_len):
                     if utt_norm[i:i + v_len] == v_norm:
-                        utt_norm[i:i + v_len] = ['[UNK]'] * v_len
+                        utt_norm[i:i + v_len] = [unk_token] * v_len
     return utt_norm
 
 
@@ -300,6 +300,7 @@ def create_examples(input_file, acts_file, set_type, slot_list,
                     swap_utterances=False,
                     label_value_repetitions=False,
                     delexicalize_sys_utts=False,
+                    unk_token="[UNK]",
                     analyze=False):
     """Read a DST json file into a list of DSTExample."""
 
@@ -343,7 +344,7 @@ def create_examples(input_file, acts_file, set_type, slot_list,
                 for slot in slot_list:
                     if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
                         inform_dict[slot] = sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]
-                utt_tok_list.append(delex_utt(utt['text'], inform_dict)) # normalize utterances
+                utt_tok_list.append(delex_utt(utt['text'], inform_dict, unk_token)) # normalize utterances
             else:
                 utt_tok_list.append(tokenize(utt['text'])) # normalize utterances
 
diff --git a/dataset_sim.py b/dataset_sim.py
index 575a29f0da6746efb963283de479e8119de39cba..e8caa3b1e486c581e3bfcfc205be4b384d71a305 100644
--- a/dataset_sim.py
+++ b/dataset_sim.py
@@ -114,15 +114,15 @@ def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok,
     return sys_utt_tok_label, usr_utt_tok_label, class_type
 
 
-def delex_utt(utt, values):
+def delex_utt(utt, values, unk_token="[UNK]"):
     utt_delex = utt.copy()
     for v in values:
-        utt_delex[v['start']:v['exclusive_end']] = ['[UNK]'] * (v['exclusive_end'] - v['start'])
+        utt_delex[v['start']:v['exclusive_end']] = [unk_token] * (v['exclusive_end'] - v['start'])
     return utt_delex
     
 
 def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, sys_inform_dict,
-                   delexicalize_sys_utts=False, slot_last_occurrence=True):
+                   delexicalize_sys_utts=False, unk_token="[UNK]", slot_last_occurrence=True):
     """Make turn_label a dictionary of slot with value positions or being dontcare / none:
     Turn label contains:
       (1) the updates from previous to current dialogue state,
@@ -157,7 +157,7 @@ def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, sys_i
         class_type_dict[slot_type] = class_type
 
     if delexicalize_sys_utts:
-        sys_utt_tok = delex_utt(sys_utt_tok, sys_slot_label)
+        sys_utt_tok = delex_utt(sys_utt_tok, sys_slot_label, unk_token)
 
     return (sys_utt_tok, sys_utt_tok_label_dict,
             usr_utt_tok, usr_utt_tok_label_dict,
@@ -172,6 +172,7 @@ def create_examples(input_file, set_type, slot_list,
                     swap_utterances=False,
                     label_value_repetitions=False,
                     delexicalize_sys_utts=False,
+                    unk_token="[UNK]",
                     analyze=False):
     """Read a DST json file into a list of DSTExample."""
 
@@ -207,6 +208,7 @@ def create_examples(input_file, set_type, slot_list,
                                            turn_id,
                                            sys_inform_dict,
                                            delexicalize_sys_utts=delexicalize_sys_utts,
+                                           unk_token=unk_token,
                                            slot_last_occurrence=True)
 
             if swap_utterances:
diff --git a/dataset_woz2.py b/dataset_woz2.py
index ef0a1efbaf98888f3103602e9210148ab835b03c..3c605d8bd7d4c3a95db1965c846e5ea11bbdfc78 100644
--- a/dataset_woz2.py
+++ b/dataset_woz2.py
@@ -27,7 +27,7 @@ LABEL_MAPS = {} # Loaded from file
 LABEL_FIX = {'centre': 'center', 'areas': 'area', 'phone number': 'number', 'price range': 'price_range'}
 
 
-def delex_utt(utt, values):
+def delex_utt(utt, values, unk_token="[UNK]"):
     utt_norm = utt.copy()
     for s, v in values.items():
         if v != 'none':
@@ -35,7 +35,7 @@ def delex_utt(utt, values):
             v_len = len(v_norm)
             for i in range(len(utt_norm) + 1 - v_len):
                 if utt_norm[i:i + v_len] == v_norm:
-                    utt_norm[i:i + v_len] = ['[UNK]'] * v_len
+                    utt_norm[i:i + v_len] = [unk_token] * v_len
     return utt_norm
 
 
@@ -104,6 +104,7 @@ def create_examples(input_file, set_type, slot_list,
                     swap_utterances=False,
                     label_value_repetitions=False,
                     delexicalize_sys_utts=False,
+                    unk_token="[UNK]",
                     analyze=False):
     """Read a DST json file into a list of DSTExample."""
     with open(input_file, "r", encoding='utf-8') as reader:
@@ -163,7 +164,7 @@ def create_examples(input_file, set_type, slot_list,
                         _, _, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
                         if in_sys:
                             delex_dict[slot] = label
-                sys_utt_tok_delex = delex_utt(sys_utt_tok, delex_dict)
+                sys_utt_tok_delex = delex_utt(sys_utt_tok, delex_dict, unk_token)
 
             new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
             new_diag_state = diag_state.copy()
@@ -177,9 +178,9 @@ def create_examples(input_file, set_type, slot_list,
                 (usr_utt_tok_label,
                  class_type,
                  is_informed) = get_turn_label(label,
-                                              sys_utt_tok,
-                                              usr_utt_tok,
-                                              slot_last_occurrence=True)
+                                               sys_utt_tok,
+                                               usr_utt_tok,
+                                               slot_last_occurrence=True)
 
                 if class_type == 'inform':
                     inform_dict[slot] = label
diff --git a/modeling_bert_dst.py b/modeling_bert_dst.py
index dc7206c3f0380417292289f2fe1909777c6cb418..e0ccc7d0539e63fe6397ad2c5fa41ac7401a4332 100644
--- a/modeling_bert_dst.py
+++ b/modeling_bert_dst.py
@@ -65,6 +65,10 @@ class BertForDST(BertPreTrainedModel):
             self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
             self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
 
+        # Head for aux task
+        if hasattr(config, "aux_task_def"):
+            self.add_module("aux_out_projection", nn.Linear(config.hidden_size, int(config.aux_task_def['n_class'])))
+
         self.init_weights()
 
     @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@@ -79,7 +83,8 @@ class BertForDST(BertPreTrainedModel):
                 inform_slot_id=None,
                 refer_id=None,
                 class_label_id=None,
-                diag_state=None):
+                diag_state=None,
+                aux_task_def=None):
         outputs = self.bert(
             input_ids,
             attention_mask=input_mask,
@@ -94,6 +99,40 @@ class BertForDST(BertPreTrainedModel):
         sequence_output = self.dropout(sequence_output)
         pooled_output = self.dropout(pooled_output)
 
+        if aux_task_def is not None:
+            if aux_task_def['task_type'] == "classification":
+                aux_logits = getattr(self, 'aux_out_projection')(pooled_output)
+                aux_logits = self.dropout_heads(aux_logits)
+                aux_loss_fct = CrossEntropyLoss()
+                aux_loss = aux_loss_fct(aux_logits, class_label_id)
+                # add hidden states and attention if they are here
+                return (aux_loss,) + outputs[2:]
+            elif aux_task_def['task_type'] == "span":
+                aux_logits = getattr(self, 'aux_out_projection')(sequence_output)
+                aux_start_logits, aux_end_logits = aux_logits.split(1, dim=-1)
+                aux_start_logits = self.dropout_heads(aux_start_logits)
+                aux_end_logits = self.dropout_heads(aux_end_logits)
+                aux_start_logits = aux_start_logits.squeeze(-1)
+                aux_end_logits = aux_end_logits.squeeze(-1)
+
+                # If we are on multi-GPU, split add a dimension
+                if len(start_pos.size()) > 1:
+                    start_pos = start_pos.squeeze(-1)
+                if len(end_pos.size()) > 1:
+                    end_pos = end_pos.squeeze(-1)
+                # sometimes the start/end positions are outside our model inputs, we ignore these terms
+                ignored_index = aux_start_logits.size(1) # This is a single index
+                start_pos.clamp_(0, ignored_index)
+                end_pos.clamp_(0, ignored_index)
+            
+                aux_token_loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+                aux_start_loss = aux_token_loss_fct(torch.cat((aux_start_logits, aux_end_logits), 1), start_pos)
+                aux_end_loss = aux_token_loss_fct(torch.cat((aux_end_logits, aux_start_logits), 1), end_pos)
+                aux_loss = (aux_start_loss + aux_end_loss) / 2.0
+                return (aux_loss,) + outputs[2:]
+            else:
+                raise Exception("Unknown task_type")
+
         # TODO: establish proper format in labels already?
         if inform_slot_id is not None:
             inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
diff --git a/modeling_roberta_dst.py b/modeling_roberta_dst.py
new file mode 100644
index 0000000000000000000000000000000000000000..062fccb97831f65f67a4c3a4cdeb3685688fcd51
--- /dev/null
+++ b/modeling_roberta_dst.py
@@ -0,0 +1,236 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+# Part of this code is based on the source code of Transformers
+# (arXiv:1910.03771)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable)
+from transformers.modeling_utils import (PreTrainedModel)
+from transformers.modeling_roberta import (RobertaModel, RobertaConfig, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                           ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING, BertLayerNorm)
+
+
+class RobertaPreTrainedModel(PreTrainedModel):
+    """ An abstract class to handle weights initialization and
+        a simple interface for dowloading and loading pretrained models.
+    """
+    config_class = RobertaConfig
+    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "roberta"
+    
+    def _init_weights(self, module):
+        """ Initialize the weights """
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, BertLayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+@add_start_docstrings(
+    """RoBERTa Model with classification heads for the DST task. """,
+    ROBERTA_START_DOCSTRING,
+)
+class RobertaForDST(RobertaPreTrainedModel):
+    def __init__(self, config):
+        super(RobertaForDST, self).__init__(config)
+        self.slot_list = config.dst_slot_list
+        self.class_types = config.dst_class_types
+        self.class_labels = config.dst_class_labels
+        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
+        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
+        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
+        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
+        self.class_loss_ratio = config.dst_class_loss_ratio
+
+        # Only use refer loss if refer class is present in dataset.
+        if 'refer' in self.class_types:
+            self.refer_index = self.class_types.index('refer')
+        else:
+            self.refer_index = -1
+
+        self.roberta = RobertaModel(config)
+        self.dropout = nn.Dropout(config.dst_dropout_rate)
+        self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
+
+        if self.class_aux_feats_inform:
+            self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
+        if self.class_aux_feats_ds:
+            self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
+
+        aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
+
+        for slot in self.slot_list:
+            self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
+            self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
+            self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
+
+        # Head for aux task
+        if hasattr(config, "aux_task_def"):
+            self.add_module("aux_out_projection", nn.Linear(config.hidden_size, int(config.aux_task_def['n_class'])))
+
+        self.init_weights()
+
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    def forward(self,
+                input_ids,
+                input_mask=None,
+                segment_ids=None,
+                position_ids=None,
+                head_mask=None,
+                start_pos=None,
+                end_pos=None,
+                inform_slot_id=None,
+                refer_id=None,
+                class_label_id=None,
+                diag_state=None,
+                aux_task_def=None):
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=input_mask,
+            token_type_ids=segment_ids,
+            position_ids=position_ids,
+            head_mask=head_mask
+        )
+
+        sequence_output = outputs[0]
+        pooled_output = outputs[1]
+
+        sequence_output = self.dropout(sequence_output)
+        pooled_output = self.dropout(pooled_output)
+
+        if aux_task_def is not None:
+            if aux_task_def['task_type'] == "classification":
+                aux_logits = getattr(self, 'aux_out_projection')(pooled_output)
+                aux_logits = self.dropout_heads(aux_logits)
+                aux_loss_fct = CrossEntropyLoss()
+                aux_loss = aux_loss_fct(aux_logits, class_label_id)
+                # add hidden states and attention if they are here
+                return (aux_loss,) + outputs[2:]
+            elif aux_task_def['task_type'] == "span":
+                aux_logits = getattr(self, 'aux_out_projection')(sequence_output)
+                aux_start_logits, aux_end_logits = aux_logits.split(1, dim=-1)
+                aux_start_logits = self.dropout_heads(aux_start_logits)
+                aux_end_logits = self.dropout_heads(aux_end_logits)
+                aux_start_logits = aux_start_logits.squeeze(-1)
+                aux_end_logits = aux_end_logits.squeeze(-1)
+
+                # If we are on multi-GPU, split add a dimension
+                if len(start_pos.size()) > 1:
+                    start_pos = start_pos.squeeze(-1)
+                if len(end_pos.size()) > 1:
+                    end_pos = end_pos.squeeze(-1)
+                # sometimes the start/end positions are outside our model inputs, we ignore these terms
+                ignored_index = aux_start_logits.size(1) # This is a single index
+                start_pos.clamp_(0, ignored_index)
+                end_pos.clamp_(0, ignored_index)
+            
+                aux_token_loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+                aux_start_loss = aux_token_loss_fct(torch.cat((aux_start_logits, aux_end_logits), 1), start_pos)
+                aux_end_loss = aux_token_loss_fct(torch.cat((aux_end_logits, aux_start_logits), 1), end_pos)
+                aux_loss = (aux_start_loss + aux_end_loss) / 2.0
+                return (aux_loss,) + outputs[2:]
+            else:
+                raise Exception("Unknown task_type")
+
+        # TODO: establish proper format in labels already?
+        if inform_slot_id is not None:
+            inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
+        if diag_state is not None:
+            diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
+        
+        total_loss = 0
+        per_slot_per_example_loss = {}
+        per_slot_class_logits = {}
+        per_slot_start_logits = {}
+        per_slot_end_logits = {}
+        per_slot_refer_logits = {}
+        for slot in self.slot_list:
+            if self.class_aux_feats_inform and self.class_aux_feats_ds:
+                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
+            elif self.class_aux_feats_inform:
+                pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
+            elif self.class_aux_feats_ds:
+                pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
+            else:
+                pooled_output_aux = pooled_output
+            class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
+
+            token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
+            start_logits, end_logits = token_logits.split(1, dim=-1)
+            start_logits = start_logits.squeeze(-1)
+            end_logits = end_logits.squeeze(-1)
+
+            refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
+
+            per_slot_class_logits[slot] = class_logits
+            per_slot_start_logits[slot] = start_logits
+            per_slot_end_logits[slot] = end_logits
+            per_slot_refer_logits[slot] = refer_logits
+            
+            # If there are no labels, don't compute loss
+            if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
+                # If we are on multi-GPU, split add a dimension
+                if len(start_pos[slot].size()) > 1:
+                    start_pos[slot] = start_pos[slot].squeeze(-1)
+                if len(end_pos[slot].size()) > 1:
+                    end_pos[slot] = end_pos[slot].squeeze(-1)
+                # sometimes the start/end positions are outside our model inputs, we ignore these terms
+                ignored_index = start_logits.size(1) # This is a single index
+                start_pos[slot].clamp_(0, ignored_index)
+                end_pos[slot].clamp_(0, ignored_index)
+
+                class_loss_fct = CrossEntropyLoss(reduction='none')
+                token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
+                refer_loss_fct = CrossEntropyLoss(reduction='none')
+
+                start_loss = token_loss_fct(start_logits, start_pos[slot])
+                end_loss = token_loss_fct(end_logits, end_pos[slot])
+                token_loss = (start_loss + end_loss) / 2.0
+
+                token_is_pointable = (start_pos[slot] > 0).float()
+                if not self.token_loss_for_nonpointable:
+                    token_loss *= token_is_pointable
+
+                refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
+                token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
+                if not self.refer_loss_for_nonpointable:
+                    refer_loss *= token_is_referrable
+
+                class_loss = class_loss_fct(class_logits, class_label_id[slot])
+
+                if self.refer_index > -1:
+                    per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
+                else:
+                    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
+
+                total_loss += per_example_loss.sum()
+                per_slot_per_example_loss[slot] = per_example_loss
+
+        # add hidden states and attention if they are here
+        outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]
+
+        return outputs
diff --git a/run_dst.py b/run_dst.py
index bde263007dfaa6bf071a3c9547f7732f5c38183b..7809bb1192db14b9cc25f04bd13b028a61e878f7 100644
--- a/run_dst.py
+++ b/run_dst.py
@@ -36,10 +36,12 @@ from tqdm import tqdm, trange
 
 from tensorboardX import SummaryWriter
 
-from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer)
+from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer,
+                          RobertaConfig, RobertaTokenizer)
 from transformers import (AdamW, get_linear_schedule_with_warmup)
 
 from modeling_bert_dst import (BertForDST)
+from modeling_roberta_dst import (RobertaForDST)
 from data_processors import PROCESSORS
 from utils_dst import (convert_examples_to_features)
 from tensorlistdataset import (TensorListDataset)
@@ -47,9 +49,11 @@ from tensorlistdataset import (TensorListDataset)
 logger = logging.getLogger(__name__)
 
 ALL_MODELS = tuple(BertConfig.pretrained_config_archive_map.keys())
+ALL_MODELS += tuple(RobertaConfig.pretrained_config_archive_map.keys())
 
 MODEL_CLASSES = {
     'bert': (BertConfig, BertForDST, BertTokenizer),
+    'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer),
 }
 
 
@@ -450,7 +454,8 @@ def load_and_cache_examples(args, model, tokenizer, processor, evaluate=False):
                           'use_history_labels': args.use_history_labels,
                           'swap_utterances': args.swap_utterances,
                           'label_value_repetitions': args.label_value_repetitions,
-                          'delexicalize_sys_utts': args.delexicalize_sys_utts}
+                          'delexicalize_sys_utts': args.delexicalize_sys_utts,
+                          'unk_token': '<unk>' if args.model_type == 'roberta' else '[UNK]'}
         if evaluate and args.predict_type == "dev":
             examples = processor.get_dev_examples(args.data_dir, processor_args)
         elif evaluate and args.predict_type == "test":
diff --git a/run_dst_mtl.py b/run_dst_mtl.py
new file mode 100644
index 0000000000000000000000000000000000000000..c077c7bd253c7de5f36117e10fe36dca6f064b5e
--- /dev/null
+++ b/run_dst_mtl.py
@@ -0,0 +1,920 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+# Part of this code is based on the source code of Transformers
+# (arXiv:1910.03771)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+import random
+import glob
+import json
+import math
+import re
+import gc
+
+import numpy as np
+import torch
+from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler)
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm, trange
+from collections import deque
+
+from tensorboardX import SummaryWriter
+
+from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer,
+                          RobertaConfig, RobertaTokenizer)
+from transformers import (AdamW, get_linear_schedule_with_warmup)
+
+from modeling_bert_dst import (BertForDST)
+from modeling_roberta_dst import (RobertaForDST)
+from data_processors import PROCESSORS
+from utils_dst import (convert_examples_to_features, convert_aux_examples_to_features)
+from tensorlistdataset import (TensorListDataset)
+
+logger = logging.getLogger(__name__)
+
+ALL_MODELS = tuple(BertConfig.pretrained_config_archive_map.keys())
+ALL_MODELS += tuple(RobertaConfig.pretrained_config_archive_map.keys())
+
+MODEL_CLASSES = {
+    'bert': (BertConfig, BertForDST, BertTokenizer),
+    'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer),
+}
+
+
+def set_seed(args):
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+
+        
+def to_list(tensor):
+    return tensor.detach().cpu().tolist()
+
+
+def batch_to_device(batch, device):
+    batch_on_device = []
+    for element in batch:
+        if isinstance(element, dict):
+            batch_on_device.append({k: v.to(device) for k, v in element.items()})
+        else:
+            batch_on_device.append(element.to(device))
+    return tuple(batch_on_device)
+
+
+def train(args, train_dataset, aux_dataset, aux_task_def, features, model, tokenizer, processor, continue_from_global_step=0):
+    assert(not args.mtl_use or args.gradient_accumulation_steps == 1)
+
+    """ Train the model """
+    if args.local_rank in [-1, 0]:
+        tb_writer = SummaryWriter()
+
+    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
+    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
+    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, pin_memory=True)
+    aux_sampler = RandomSampler(aux_dataset) if args.local_rank == -1 else DistributedSampler(aux_dataset)
+    aux_dataloader = DataLoader(aux_dataset, sampler=aux_sampler, batch_size=args.train_batch_size, pin_memory=True)
+
+    if args.max_steps > 0:
+        t_total = args.max_steps
+        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
+    else:
+        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
+
+    if args.save_epochs > 0:
+        args.save_steps = t_total // args.num_train_epochs * args.save_epochs
+
+    num_warmup_steps = int(t_total * args.warmup_proportion)
+    mtl_last_step = int(t_total * args.mtl_ratio)
+
+    # Prepare optimizer and schedule (linear warmup and decay)
+    no_decay = ['bias', 'LayerNorm.weight']
+    optimizer_grouped_parameters = [
+        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
+        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
+    ]
+    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)
+    if args.fp16:
+        try:
+            from apex import amp
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
+
+    # multi-gpu training (should be after apex fp16 initialization)
+    model_single_gpu = model
+    if args.n_gpu > 1:
+        model = torch.nn.DataParallel(model_single_gpu)
+
+    # Distributed training (should be after apex fp16 initialization)
+    if args.local_rank != -1:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
+                                                          output_device=args.local_rank,
+                                                          find_unused_parameters=True)
+
+    # Train!
+    logger.info("***** Running training *****")
+    logger.info("  Num examples = %d", len(train_dataset))
+    logger.info("  Num Epochs = %d", args.num_train_epochs)
+    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
+    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
+                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
+    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
+    logger.info("  Total optimization steps = %d", t_total)
+    logger.info("  Warmup steps = %d", num_warmup_steps)
+
+    if continue_from_global_step > 0:
+        logger.info("Fast forwarding to global step %d to resume training from latest checkpoint...", continue_from_global_step)
+    
+    global_step = 0
+    tr_loss, logging_loss = 0.0, 0.0
+    model.zero_grad()
+    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
+    set_seed(args)  # Added here for reproductibility (even between python 2 and 3)
+
+    # Iterators for aux tasks
+    aux_loss_diff = 0.0
+    aux_logged_steps = 0
+    tr_aux_loss = 0.0
+    logging_aux_loss = 0.0
+    loss_diff_queue = deque([], args.mtl_diff_window)
+    aux_iterator_dict = iter(aux_dataloader)
+
+    for _ in train_iterator:
+        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
+
+        for step, batch in enumerate(epoch_iterator):
+            # If training is continued from a checkpoint, fast forward
+            # to the state of that checkpoint.
+            if global_step < continue_from_global_step:
+                if (step + 1) % args.gradient_accumulation_steps == 0:
+                    scheduler.step()  # Update learning rate schedule
+                    global_step += 1
+                continue
+
+            #model.train()
+            batch = batch_to_device(batch, args.device)
+
+            # This is what is forwarded to the "forward" def.
+            inputs = {'input_ids':       batch[0],
+                      'input_mask':      batch[1], 
+                      'segment_ids':     batch[2],
+                      'start_pos':       batch[3],
+                      'end_pos':         batch[4],
+                      'inform_slot_id':  batch[5],
+                      'refer_id':        batch[6],
+                      'diag_state':      batch[7],
+                      'class_label_id':  batch[8]}
+
+            # MTL (optional)
+            if args.mtl_use and global_step < mtl_last_step:
+                if args.mtl_print_loss_diff:
+                    model.eval()
+                    with torch.no_grad():
+                        pre_aux_loss = model(**inputs)[0]
+
+                try:
+                    aux_batch = batch_to_device(next(aux_iterator_dict), args.device)
+                except StopIteration:
+                    logger.info("Resetting iterator for aux task")
+                    aux_iterator_dict = iter(aux_dataloader)
+                    aux_batch = batch_to_device(next(aux_iterator_dict), args.device)
+
+                aux_inputs = {'input_ids':       aux_batch[0],
+                              'input_mask':      aux_batch[1],
+                              'segment_ids':     aux_batch[2],
+                              'start_pos':       aux_batch[3],
+                              'end_pos':         aux_batch[4],
+                              'class_label_id':  aux_batch[5],
+                              'aux_task_def':    aux_task_def}
+                model.train()
+                aux_outputs = model(**aux_inputs)
+                aux_loss = aux_outputs[0]
+
+                if args.n_gpu > 1:
+                    aux_loss = aux_loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
+
+                if args.fp16:
+                    with amp.scale_loss(aux_loss, optimizer) as scaled_loss:
+                        scaled_loss.backward()
+                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+                else:
+                    aux_loss.backward()
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+
+                tr_aux_loss += aux_loss.item()
+                aux_logged_steps += 1
+
+                optimizer.step()
+
+                model.zero_grad()
+                if args.mtl_print_loss_diff:
+                    model.eval()
+                    with torch.no_grad():
+                        post_aux_loss = model(**inputs)[0]
+                    aux_loss_diff = pre_aux_loss - post_aux_loss
+                else:
+                    post_aux_loss = 0.0 # TODO: move somewhere else...
+
+                pre_aux_loss = post_aux_loss
+
+                loss_diff_queue.append(aux_loss_diff)
+
+            # Normal training
+            model.train()
+
+            outputs = model(**inputs)
+            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
+
+            if args.n_gpu > 1:
+                loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
+            if args.gradient_accumulation_steps > 1:
+                loss = loss / args.gradient_accumulation_steps
+
+            if args.fp16:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+            else:
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+
+            tr_loss += loss.item()
+            if (step + 1) % args.gradient_accumulation_steps == 0:
+                optimizer.step()
+                scheduler.step()  # Update learning rate schedule
+                model.zero_grad()
+                global_step += 1
+
+                # Aux losses
+                if args.mtl_use and args.local_rank in [-1, 0] and args.logging_steps > 0:
+                    tb_writer.add_scalar('loss_diff', aux_loss_diff, global_step)
+                    # TODO: make nicer
+                    if len(loss_diff_queue) > 0:
+                        tb_writer.add_scalar('loss_diff_mean', sum(loss_diff_queue) / len(loss_diff_queue), global_step)
+                    if aux_logged_steps > 0 and tr_aux_loss != logging_aux_loss:
+                        tb_writer.add_scalar('aux_loss', tr_aux_loss - logging_aux_loss, global_step)
+                        logging_aux_loss = tr_aux_loss
+
+                # Log metrics
+                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
+                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
+                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
+                    logging_loss = tr_loss
+
+                # Save model checkpoint
+                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
+                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
+                    if not os.path.exists(output_dir):
+                        os.makedirs(output_dir)
+                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+                    model_to_save.save_pretrained(output_dir)
+                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+                    logger.info("Saving model checkpoint to %s", output_dir)
+
+            if args.max_steps > 0 and global_step > args.max_steps:
+                epoch_iterator.close()
+                break
+
+        if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
+            results = evaluate(args, model_single_gpu, tokenizer, processor, prefix=global_step)
+            for key, value in results.items():
+                tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
+
+        # To prevent GPU memory to overflow
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        if args.max_steps > 0 and global_step > args.max_steps:
+            train_iterator.close()
+            break
+
+    if args.local_rank in [-1, 0]:
+        tb_writer.close()
+
+    tr_aux_loss /= aux_logged_steps
+
+    return global_step, tr_loss / global_step, tr_aux_loss
+
+
+def evaluate(args, model, tokenizer, processor, prefix=""):
+    dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=True)
+
+    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
+        os.makedirs(args.output_dir)
+
+    args.eval_batch_size = args.per_gpu_eval_batch_size
+    eval_sampler = SequentialSampler(dataset) # Note that DistributedSampler samples randomly
+    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+
+    # Eval!
+    logger.info("***** Running evaluation {} *****".format(prefix))
+    logger.info("  Num examples = %d", len(dataset))
+    logger.info("  Batch size = %d", args.eval_batch_size)
+    all_results = []
+    all_preds = []
+    ds = {slot: 'none' for slot in model.slot_list}
+    with torch.no_grad():
+        diag_state = {slot: torch.tensor([0 for _ in range(args.eval_batch_size)]).to(args.device) for slot in model.slot_list}
+    for batch in tqdm(eval_dataloader, desc="Evaluating"):
+        model.eval()
+        batch = batch_to_device(batch, args.device)
+
+        # Reset dialog state if turn is first in the dialog.
+        turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]]
+        reset_diag_state = np.where(np.array(turn_itrs) == '0')[0]
+        for slot in model.slot_list:
+            for i in reset_diag_state:
+                diag_state[slot][i] = 0
+
+        with torch.no_grad():
+            inputs = {'input_ids':       batch[0],
+                      'input_mask':      batch[1],
+                      'segment_ids':     batch[2],
+                      'start_pos':       batch[3],
+                      'end_pos':         batch[4],
+                      'inform_slot_id':  batch[5],
+                      'refer_id':        batch[6],
+                      'diag_state':      diag_state,
+                      'class_label_id':  batch[8]}
+            unique_ids = [features[i.item()].guid for i in batch[9]]
+            values = [features[i.item()].values for i in batch[9]]
+            input_ids_unmasked = [features[i.item()].input_ids_unmasked for i in batch[9]]
+            inform = [features[i.item()].inform for i in batch[9]]
+            outputs = model(**inputs)
+
+            # Update dialog state for next turn.
+            for slot in model.slot_list:
+                updates = outputs[2][slot].max(1)[1]
+                for i, u in enumerate(updates):
+                    if u != 0:
+                        diag_state[slot][i] = u
+
+        results = eval_metric(model, inputs, outputs[0], outputs[1], outputs[2], outputs[3], outputs[4], outputs[5])
+        preds, ds = predict_and_format(model, tokenizer, inputs, outputs[2], outputs[3], outputs[4], outputs[5], unique_ids, input_ids_unmasked, values, inform, prefix, ds)
+        all_results.append(results)
+        all_preds.append(preds)
+
+    all_preds = [item for sublist in all_preds for item in sublist] # Flatten list
+
+    # Generate final results
+    final_results = {}
+    for k in all_results[0].keys():
+        final_results[k] = torch.stack([r[k] for r in all_results]).mean()
+
+    # Write final predictions (for evaluation with external tool)
+    output_prediction_file = os.path.join(args.output_dir, "pred_res.%s.%s.json" % (args.predict_type, prefix))
+    with open(output_prediction_file, "w") as f:
+        json.dump(all_preds, f, indent=2)
+
+    return final_results
+
+
+def eval_metric(model, features, total_loss, per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits):
+    metric_dict = {}
+    per_slot_correctness = {}
+    for slot in model.slot_list:
+        per_example_loss = per_slot_per_example_loss[slot]
+        class_logits = per_slot_class_logits[slot]
+        start_logits = per_slot_start_logits[slot]
+        end_logits = per_slot_end_logits[slot]
+        refer_logits = per_slot_refer_logits[slot]
+
+        class_label_id = features['class_label_id'][slot]
+        start_pos = features['start_pos'][slot]
+        end_pos = features['end_pos'][slot]
+        refer_id = features['refer_id'][slot]
+
+        _, class_prediction = class_logits.max(1)
+        class_correctness = torch.eq(class_prediction, class_label_id).float()
+        class_accuracy = class_correctness.mean()
+
+        # "is pointable" means whether class label is "copy_value",
+        # i.e., that there is a span to be detected.
+        token_is_pointable = torch.eq(class_label_id, model.class_types.index('copy_value')).float()
+        _, start_prediction = start_logits.max(1)
+        start_correctness = torch.eq(start_prediction, start_pos).float()
+        _, end_prediction = end_logits.max(1)
+        end_correctness = torch.eq(end_prediction, end_pos).float()
+        token_correctness = start_correctness * end_correctness
+        token_accuracy = (token_correctness * token_is_pointable).sum() / token_is_pointable.sum()
+        # NaNs mean that none of the examples in this batch contain spans. -> division by 0
+        # The accuracy therefore is 1 by default. -> replace NaNs
+        if math.isnan(token_accuracy):
+            token_accuracy = torch.tensor(1.0, device=token_accuracy.device)
+
+        token_is_referrable = torch.eq(class_label_id, model.class_types.index('refer') if 'refer' in model.class_types else -1).float()
+        _, refer_prediction = refer_logits.max(1)
+        refer_correctness = torch.eq(refer_prediction, refer_id).float()
+        refer_accuracy = refer_correctness.sum() / token_is_referrable.sum()
+        # NaNs mean that none of the examples in this batch contain referrals. -> division by 0
+        # The accuracy therefore is 1 by default. -> replace NaNs
+        if math.isnan(refer_accuracy) or math.isinf(refer_accuracy):
+            refer_accuracy = torch.tensor(1.0, device=refer_accuracy.device)
+            
+        total_correctness = class_correctness * (token_is_pointable * token_correctness + (1 - token_is_pointable)) * (token_is_referrable * refer_correctness + (1 - token_is_referrable))
+        total_accuracy = total_correctness.mean()
+
+        loss = per_example_loss.mean()
+        metric_dict['eval_accuracy_class_%s' % slot] = class_accuracy
+        metric_dict['eval_accuracy_token_%s' % slot] = token_accuracy
+        metric_dict['eval_accuracy_refer_%s' % slot] = refer_accuracy
+        metric_dict['eval_accuracy_%s' % slot] = total_accuracy
+        metric_dict['eval_loss_%s' % slot] = loss
+        per_slot_correctness[slot] = total_correctness
+
+    goal_correctness = torch.stack([c for c in per_slot_correctness.values()], 1).prod(1)
+    goal_accuracy = goal_correctness.mean()
+    metric_dict['eval_accuracy_goal'] = goal_accuracy
+    metric_dict['loss'] = total_loss
+    return metric_dict
+
+
+def predict_and_format(model, tokenizer, features, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits, ids, input_ids_unmasked, values, inform, prefix, ds):
+    prediction_list = []
+    dialog_state = ds
+    for i in range(len(ids)):
+        if int(ids[i].split("-")[2]) == 0:
+            dialog_state = {slot: 'none' for slot in model.slot_list}
+
+        prediction = {}
+        prediction_addendum = {}
+        for slot in model.slot_list:
+            class_logits = per_slot_class_logits[slot][i]
+            start_logits = per_slot_start_logits[slot][i]
+            end_logits = per_slot_end_logits[slot][i]
+            refer_logits = per_slot_refer_logits[slot][i]
+
+            input_ids = features['input_ids'][i].tolist()
+            class_label_id = int(features['class_label_id'][slot][i])
+            start_pos = int(features['start_pos'][slot][i])
+            end_pos = int(features['end_pos'][slot][i])
+            refer_id = int(features['refer_id'][slot][i])
+            
+            class_prediction = int(class_logits.argmax())
+            start_prediction = int(start_logits.argmax())
+            end_prediction = int(end_logits.argmax())
+            refer_prediction = int(refer_logits.argmax())
+
+            prediction['guid'] = ids[i].split("-")
+            prediction['class_prediction_%s' % slot] = class_prediction
+            prediction['class_label_id_%s' % slot] = class_label_id
+            prediction['start_prediction_%s' % slot] = start_prediction
+            prediction['start_pos_%s' % slot] = start_pos
+            prediction['end_prediction_%s' % slot] = end_prediction
+            prediction['end_pos_%s' % slot] = end_pos
+            prediction['refer_prediction_%s' % slot] = refer_prediction
+            prediction['refer_id_%s' % slot] = refer_id
+            prediction['input_ids_%s' % slot] = input_ids
+
+            if class_prediction == model.class_types.index('dontcare'):
+                dialog_state[slot] = 'dontcare'
+            elif class_prediction == model.class_types.index('copy_value'):
+                input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i])
+                dialog_state[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1])
+                dialog_state[slot] = re.sub("(^| )##", "", dialog_state[slot])
+            elif 'true' in model.class_types and class_prediction == model.class_types.index('true'):
+                dialog_state[slot] = 'true'
+            elif 'false' in model.class_types and class_prediction == model.class_types.index('false'):
+                dialog_state[slot] = 'false'
+            elif class_prediction == model.class_types.index('inform'):
+                dialog_state[slot] = 'ยงยง' + inform[i][slot]
+            # Referral case is handled below
+
+            prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot]
+            prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot]
+
+        # Referral case. All other slot values need to be seen first in order
+        # to be able to do this correctly.
+        for slot in model.slot_list:
+            class_logits = per_slot_class_logits[slot][i]
+            refer_logits = per_slot_refer_logits[slot][i]
+            
+            class_prediction = int(class_logits.argmax())
+            refer_prediction = int(refer_logits.argmax())
+
+            if 'refer' in model.class_types and class_prediction == model.class_types.index('refer'):
+                # Only slots that have been mentioned before can be referred to.
+                # One can think of a situation where one slot is referred to in the same utterance.
+                # This phenomenon is however currently not properly covered in the training data
+                # label generation process.
+                dialog_state[slot] = dialog_state[model.slot_list[refer_prediction - 1]]
+                prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot] # Value update
+
+        prediction.update(prediction_addendum)
+        prediction_list.append(prediction)
+        
+    return prediction_list, dialog_state
+
+
+def load_and_cache_examples(args, model, tokenizer, processor, evaluate=False):
+    if args.local_rank not in [-1, 0] and not evaluate:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
+
+    # Load data features from cache or dataset file
+    cached_file = os.path.join(os.path.dirname(args.output_dir), 'cached_{}_features'.format(
+        args.predict_type if evaluate else 'train'))
+    if os.path.exists(cached_file) and not args.overwrite_cache: # and not output_examples:
+        logger.info("Loading features from cached file %s", cached_file)
+        features = torch.load(cached_file)
+    else:
+        logger.info("Creating features from dataset file at %s", args.data_dir)
+        processor_args = {'append_history': args.append_history,
+                          'use_history_labels': args.use_history_labels,
+                          'swap_utterances': args.swap_utterances,
+                          'label_value_repetitions': args.label_value_repetitions,
+                          'delexicalize_sys_utts': args.delexicalize_sys_utts,
+                          'unk_token': '<unk>' if args.model_type == 'roberta' else '[UNK]'}
+        if evaluate and args.predict_type == "dev":
+            examples = processor.get_dev_examples(args.data_dir, processor_args)
+        elif evaluate and args.predict_type == "test":
+            examples = processor.get_test_examples(args.data_dir, processor_args)
+        else:
+            examples = processor.get_train_examples(args.data_dir, processor_args)
+        features = convert_examples_to_features(examples=examples,
+                                                slot_list=model.slot_list,
+                                                class_types=model.class_types,
+                                                model_type=args.model_type,
+                                                tokenizer=tokenizer,
+                                                max_seq_length=args.max_seq_length,
+                                                slot_value_dropout=(0.0 if evaluate else args.svd))
+        if args.local_rank in [-1, 0]:
+            logger.info("Saving features into cached file %s", cached_file)
+            torch.save(features, cached_file)
+
+    if args.local_rank == 0 and not evaluate:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
+
+    # Convert to Tensors and build dataset
+    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
+    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
+    f_start_pos = [f.start_pos for f in features]
+    f_end_pos = [f.end_pos for f in features]
+    f_inform_slot_ids = [f.inform_slot for f in features]
+    f_refer_ids = [f.refer_id for f in features]
+    f_diag_state = [f.diag_state for f in features]
+    f_class_label_ids = [f.class_label_id for f in features]
+    all_start_positions = {}
+    all_end_positions = {}
+    all_inform_slot_ids = {}
+    all_refer_ids = {}
+    all_diag_state = {}
+    all_class_label_ids = {}
+    for s in model.slot_list:
+        all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos], dtype=torch.long)
+        all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos], dtype=torch.long)
+        all_inform_slot_ids[s] = torch.tensor([f[s] for f in f_inform_slot_ids], dtype=torch.long)
+        all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids], dtype=torch.long)
+        all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state], dtype=torch.long)
+        all_class_label_ids[s] = torch.tensor([f[s] for f in f_class_label_ids], dtype=torch.long)
+    dataset = TensorListDataset(all_input_ids, all_input_mask, all_segment_ids,
+                                all_start_positions, all_end_positions,
+                                all_inform_slot_ids,
+                                all_refer_ids,
+                                all_diag_state,
+                                all_class_label_ids, all_example_index)
+
+    return dataset, features
+
+
+def load_and_cache_aux_examples(args, model, tokenizer, aux_task_def=None):
+    if aux_task_def is None:
+        return None, None
+
+    if args.local_rank not in [-1, 0]:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
+
+    # Load data features from cache or dataset file
+    cached_file = os.path.join(os.path.dirname(args.output_dir), 'cached_aux_features')
+    if os.path.exists(cached_file) and not args.overwrite_cache: # and not output_examples:
+        logger.info("Loading features from cached file %s", cached_file)
+        features = torch.load(cached_file)
+    else:
+        processor = PROCESSORS['aux_task']()
+        logger.info("Creating features from aux task dataset file at %s", args.mtl_data_dir)
+        examples = processor.get_aux_task_examples(data_dir=args.mtl_data_dir,
+                                                   data_name=args.mtl_train_dataset,
+                                                   max_seq_length=args.max_seq_length)
+
+        features = convert_aux_examples_to_features(examples=examples, aux_task_def=aux_task_def, max_seq_length=args.max_seq_length)
+
+        if args.local_rank in [-1, 0]:
+            logger.info("Saving features into cached file %s", cached_file)
+            torch.save(features, cached_file)
+
+    if args.local_rank == 0:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
+
+    # Convert to Tensors and build dataset
+    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
+    all_start_positions = torch.tensor([f.start_pos for f in features], dtype=torch.long)
+    all_end_positions = torch.tensor([f.end_pos for f in features], dtype=torch.long)
+    all_label = torch.tensor([f.label for f in features], dtype=torch.long)
+    dataset = TensorListDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions, all_label)
+
+    return dataset, features
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    # Required parameters
+    parser.add_argument("--task_name", default=None, type=str, required=True,
+                        help="Name of the task (e.g., multiwoz21).")
+    parser.add_argument("--data_dir", default=None, type=str, required=True,
+                        help="Task database.")
+    parser.add_argument("--dataset_config", default=None, type=str, required=True,
+                        help="Dataset configuration file.")
+    parser.add_argument("--predict_type", default=None, type=str, required=True,
+                        help="Portion of the data to perform prediction on (e.g., dev, test).")
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--output_dir", default=None, type=str, required=True,
+                        help="The output directory where the model checkpoints and predictions will be written.")
+
+    # Other parameters
+    parser.add_argument("--config_name", default="", type=str,
+                        help="Pretrained config name or path if not the same as model_name")
+    parser.add_argument("--tokenizer_name", default="", type=str,
+                        help="Pretrained tokenizer name or path if not the same as model_name")
+
+    parser.add_argument("--max_seq_length", default=384, type=int,
+                        help="Maximum input length after tokenization. Longer sequences will be truncated, shorter ones padded.")
+    parser.add_argument("--do_train", action='store_true',
+                        help="Whether to run training.")
+    parser.add_argument("--do_eval", action='store_true',
+                        help="Whether to run eval on the <predict_type> set.")
+    parser.add_argument("--evaluate_during_training", action='store_true',
+                        help="Rul evaluation during training at each logging step.")
+    parser.add_argument("--do_lower_case", action='store_true',
+                        help="Set this flag if you are using an uncased model.")
+
+    parser.add_argument("--dropout_rate", default=0.3, type=float,
+                        help="Dropout rate for BERT representations.")
+    parser.add_argument("--heads_dropout", default=0.0, type=float,
+                        help="Dropout rate for classification heads.")
+    parser.add_argument("--class_loss_ratio", default=0.8, type=float,
+                        help="The ratio applied on class loss in total loss calculation. "
+                             "Should be a value in [0.0, 1.0]. "
+                             "The ratio applied on token loss is (1-class_loss_ratio)/2. "
+                             "The ratio applied on refer loss is (1-class_loss_ratio)/2.")
+    parser.add_argument("--token_loss_for_nonpointable", action='store_true',
+                        help="Whether the token loss for classes other than copy_value contribute towards total loss.")
+    parser.add_argument("--refer_loss_for_nonpointable", action='store_true',
+                        help="Whether the refer loss for classes other than refer contribute towards total loss.")
+
+    parser.add_argument("--append_history", action='store_true',
+                        help="Whether or not to append the dialog history to each turn.")
+    parser.add_argument("--use_history_labels", action='store_true',
+                        help="Whether or not to label the history as well.")
+    parser.add_argument("--swap_utterances", action='store_true',
+                        help="Whether or not to swap the turn utterances (default: sys|usr, swapped: usr|sys).")
+    parser.add_argument("--label_value_repetitions", action='store_true',
+                        help="Whether or not to label values that have been mentioned before.")
+    parser.add_argument("--delexicalize_sys_utts", action='store_true',
+                        help="Whether or not to delexicalize the system utterances.")
+    parser.add_argument("--class_aux_feats_inform", action='store_true',
+                        help="Whether or not to use the identity of informed slots as auxiliary featurs for class prediction.")
+    parser.add_argument("--class_aux_feats_ds", action='store_true',
+                        help="Whether or not to use the identity of slots in the current dialog state as auxiliary featurs for class prediction.")
+
+    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for training.")
+    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for evaluation.")
+    parser.add_argument("--learning_rate", default=5e-5, type=float,
+                        help="The initial learning rate for Adam.")
+    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
+                        help="Number of updates steps to accumulate before performing a backward/update pass.")
+    parser.add_argument("--weight_decay", default=0.0, type=float,
+                        help="Weight deay if we apply some.")
+    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
+                        help="Epsilon for Adam optimizer.")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float,
+                        help="Max gradient norm.")
+    parser.add_argument("--num_train_epochs", default=3.0, type=float,
+                        help="Total number of training epochs to perform.")
+    parser.add_argument("--max_steps", default=-1, type=int,
+                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
+    parser.add_argument("--warmup_proportion", default=0.0, type=float,
+                        help="Linear warmup over warmup_proportion * steps.")
+    parser.add_argument("--svd", default=0.0, type=float,
+                        help="Slot value dropout ratio (default: 0.0)")
+
+    parser.add_argument('--logging_steps', type=int, default=50,
+                        help="Log every X updates steps.")
+    parser.add_argument('--save_steps', type=int, default=0,
+                        help="Save checkpoint every X updates steps. Overwritten by --save_epochs.")
+    parser.add_argument('--save_epochs', type=int, default=0,
+                        help="Save checkpoint every X epochs. Overrides --save_steps.")
+    parser.add_argument("--eval_all_checkpoints", action='store_true',
+                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
+    parser.add_argument("--no_cuda", action='store_true',
+                        help="Whether not to use CUDA when available")
+    parser.add_argument('--overwrite_output_dir', action='store_true',
+                        help="Overwrite the content of the output directory")
+    parser.add_argument('--overwrite_cache', action='store_true',
+                        help="Overwrite the cached training and evaluation sets")
+    parser.add_argument('--seed', type=int, default=42,
+                        help="random seed for initialization")
+
+    parser.add_argument("--local_rank", type=int, default=-1,
+                        help="local_rank for distributed training on gpus")
+    parser.add_argument('--fp16', action='store_true',
+                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
+    parser.add_argument('--fp16_opt_level', type=str, default='O1',
+                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+                             "See details at https://nvidia.github.io/apex/amp.html")
+
+    # TODO
+    parser.add_argument('--mtl_use', action='store_true', help="")
+    parser.add_argument('--mtl_task_def', type=str, default="/home/heckmi_hhu/tools/trippy/aux_task_def.json", help="") # TODO
+    parser.add_argument('--mtl_train_dataset', type=str, default="", help="cola|mnli|mrpc|qnli|qqp|rte|sst|wnli|squad|squad-v2")
+    parser.add_argument("--mtl_data_dir", type=str, default="/home/heckmi_hhu/data/glue/canonical_data/bert_base_uncased_lower", help="") # TODO
+    parser.add_argument("--mtl_ratio", type=float, default=1.0, help="") # TODO
+    parser.add_argument("--mtl_diff_window", type=int, default=10)
+    parser.add_argument('--mtl_print_loss_diff', action='store_true', help="")
+
+    args = parser.parse_args()
+
+    assert(args.warmup_proportion >= 0.0 and args.warmup_proportion <= 1.0)
+    assert(args.svd >= 0.0 and args.svd <= 1.0)
+    assert(args.class_aux_feats_ds is False or args.per_gpu_eval_batch_size == 1)
+    assert(not args.class_aux_feats_inform or args.per_gpu_eval_batch_size == 1)
+    assert(not args.class_aux_feats_ds or args.per_gpu_eval_batch_size == 1)
+
+    assert(not args.mtl_use or args.gradient_accumulation_steps == 1)
+    assert(args.mtl_ratio >= 0.0 and args.mtl_ratio <= 1.0)
+
+    task_name = args.task_name.lower()
+    if task_name not in PROCESSORS:
+        raise ValueError("Task not found: %s" % (task_name))
+
+    # Load the MTL task definitions
+    if args.mtl_use:
+        with open(args.mtl_task_def, "r", encoding='utf-8') as reader:
+            aux_task_defs = json.load(reader)
+        aux_task_def = aux_task_defs[args.mtl_train_dataset]
+    else:
+        aux_task_def = None
+
+    processor = PROCESSORS[task_name](args.dataset_config)
+    dst_slot_list = processor.slot_list
+    dst_class_types = processor.class_types
+    dst_class_labels = len(dst_class_types)
+
+    # Setup CUDA, GPU & distributed training
+    if args.local_rank == -1 or args.no_cuda:
+        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+        args.n_gpu = torch.cuda.device_count()
+    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
+        torch.cuda.set_device(args.local_rank)
+        device = torch.device("cuda", args.local_rank)
+        torch.distributed.init_process_group(backend='nccl')
+        args.n_gpu = 1
+    args.device = device
+
+    # Setup logging
+    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
+                        datefmt = '%m/%d/%Y %H:%M:%S',
+                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
+
+    # Set seed
+    set_seed(args)
+
+    # Load pretrained model and tokenizer
+    if args.local_rank not in [-1, 0]:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
+
+    args.model_type = args.model_type.lower()
+    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+
+    # Add DST specific parameters to config
+    config.dst_dropout_rate = args.dropout_rate
+    config.dst_heads_dropout_rate = args.heads_dropout
+    config.dst_class_loss_ratio = args.class_loss_ratio
+    config.dst_token_loss_for_nonpointable = args.token_loss_for_nonpointable
+    config.dst_refer_loss_for_nonpointable = args.refer_loss_for_nonpointable
+    config.dst_class_aux_feats_inform = args.class_aux_feats_inform
+    config.dst_class_aux_feats_ds = args.class_aux_feats_ds
+    config.dst_slot_list = dst_slot_list
+    config.dst_class_types = dst_class_types
+    config.dst_class_labels = dst_class_labels
+    config.aux_task_def = aux_task_def
+
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
+
+    logger.info("Updated model config: %s" % config)
+
+    if args.local_rank == 0:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
+
+    model.to(args.device)
+
+    logger.info("Training/evaluation parameters %s", args)
+
+    # Training
+    if args.do_train:
+        # If output files already exists, assume to continue training from latest checkpoint (unless overwrite_output_dir is set)
+        continue_from_global_step = 0 # If set to 0, start training from the beginning
+        if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
+            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/*/' + WEIGHTS_NAME, recursive=True)))
+            if len(checkpoints) > 0:
+                checkpoint = checkpoints[-1]
+                logger.info("Resuming training from the latest checkpoint: %s", checkpoint)
+                continue_from_global_step = int(checkpoint.split('-')[-1])
+                model = model_class.from_pretrained(checkpoint)
+                model.to(args.device)
+        
+        train_dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=False)
+        aux_dataset, _ = load_and_cache_aux_examples(args, model, tokenizer, aux_task_def)
+        global_step, tr_loss, aux_loss = train(args, train_dataset, aux_dataset, aux_task_def, features, model, tokenizer, processor, continue_from_global_step)
+        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+
+    # Save the trained model and the tokenizer
+    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
+        # Create output directory if needed
+        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
+            os.makedirs(args.output_dir)
+
+        logger.info("Saving model checkpoint to %s", args.output_dir)
+        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+        model_to_save.save_pretrained(args.output_dir)
+        tokenizer.save_pretrained(args.output_dir)
+
+        # Good practice: save your training arguments together with the trained model
+        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
+
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = model_class.from_pretrained(args.output_dir)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        model.to(args.device)
+
+    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
+    results = []
+    if args.do_eval and args.local_rank in [-1, 0]:
+        output_eval_file = os.path.join(args.output_dir, "eval_res.%s.json" % (args.predict_type))
+        checkpoints = [args.output_dir]
+        if args.eval_all_checkpoints:
+            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+            logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
+
+        logger.info("Evaluate the following checkpoints: %s", checkpoints)
+
+        for cItr, checkpoint in enumerate(checkpoints):
+            # Reload the model
+            global_step = checkpoint.split('-')[-1]
+            if cItr == len(checkpoints) - 1:
+                global_step = "final"
+            model = model_class.from_pretrained(checkpoint)
+            model.to(args.device)
+
+            # Evaluate
+            result = evaluate(args, model, tokenizer, processor, prefix=global_step)
+            result_dict = {k: float(v) for k, v in result.items()}
+            result_dict["global_step"] = global_step
+            results.append(result_dict)
+
+            for key in sorted(result_dict.keys()):
+                logger.info("%s = %s", key, str(result_dict[key]))
+
+        with open(output_eval_file, "w") as f:
+            json.dump(results, f, indent=2)
+
+    return results
+
+
+if __name__ == "__main__":
+    main()
diff --git a/utils_dst.py b/utils_dst.py
index e91971da86f4f9185f3428e52a09f48b5247343d..8f91df8c8a824b13cadf87bb4fe551f1b7ec026d 100644
--- a/utils_dst.py
+++ b/utils_dst.py
@@ -122,6 +122,26 @@ class InputFeatures(object):
         self.class_label_id = class_label_id
 
 
+class AuxInputFeatures(object):
+    """A single set of features of data."""
+
+    def __init__(self,
+                 input_ids,
+                 input_mask,
+                 segment_ids,
+                 start_pos=None,
+                 end_pos=None,
+                 label=None,
+                 uid="NONE"):
+        self.uid = uid
+        self.input_ids = input_ids
+        self.input_mask = input_mask
+        self.segment_ids = segment_ids
+        self.start_pos = start_pos
+        self.end_pos = end_pos
+        self.label = label
+
+
 def convert_examples_to_features(examples, slot_list, class_types, model_type, tokenizer, max_seq_length, slot_value_dropout=0.0):
     """Loads a data file into a list of `InputBatch`s."""
 
@@ -131,6 +151,12 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
                        'UNK_TOKEN': '[UNK]',
                        'SEP_TOKEN': '[SEP]',
                        'TOKEN_CORRECTION': 4}
+    elif model_type == 'roberta':
+        model_specs = {'MODEL_TYPE': 'roberta',
+                       'CLS_TOKEN': '<s>',
+                       'UNK_TOKEN': '<unk>',
+                       'SEP_TOKEN': '</s>',
+                       'TOKEN_CORRECTION': 6}
     else:
         logger.error("Unknown model type (%s). Aborting." % (model_type))
         exit(1)
@@ -148,6 +174,8 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
         token_labels = []
         for token, token_label, joint_label in zip(text, text_label, joint_text_label):
             token = convert_to_unicode(token)
+            if model_specs['MODEL_TYPE'] == 'roberta':
+                token = ' ' + token
             sub_tokens = tokenizer.tokenize(token) # Most time intensive step
             tokens_unmasked.extend(sub_tokens)
             if slot_value_dropout == 0.0 or joint_label == 0:
@@ -187,6 +215,7 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
         # Modifies `tokens_a` and `tokens_b` in place so that the total
         # length is less than the specified length.
         # Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT)
+        # Account for <s>, </s></s>, </s></s>, </s> with "- 6" (RoBERTa)
         if len(tokens_a) + len(tokens_b) + len(history) > max_seq_length - model_specs['TOKEN_CORRECTION']:
             logger.info("Truncate Example %s. Total len=%d." % (guid, len(tokens_a) + len(tokens_b) + len(history)))
             input_text_too_long = True
@@ -197,16 +226,20 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
 
     def _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs):
         token_label_ids = []
-        token_label_ids.append(0) # [CLS]
+        token_label_ids.append(0) # [CLS]/<s>
         for token_label in token_labels_a:
             token_label_ids.append(token_label)
-        token_label_ids.append(0) # [SEP]
+        token_label_ids.append(0) # [SEP]/</s></s>
+        if model_specs['MODEL_TYPE'] == 'roberta':
+            token_label_ids.append(0)
         for token_label in token_labels_b:
             token_label_ids.append(token_label)
-        token_label_ids.append(0) # [SEP]
+        token_label_ids.append(0) # [SEP]/</s></s>
+        if model_specs['MODEL_TYPE'] == 'roberta':
+            token_label_ids.append(0)
         for token_label in token_labels_history:
             token_label_ids.append(token_label)
-        token_label_ids.append(0) # [SEP]
+        token_label_ids.append(0) # [SEP]/</s>
         while len(token_label_ids) < max_seq_length:
             token_label_ids.append(0) # padding
         assert len(token_label_ids) == max_seq_length
@@ -258,23 +291,45 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
             segment_ids.append(0)
         tokens.append(model_specs['SEP_TOKEN'])
         segment_ids.append(0)
-        for token in tokens_b:
-            tokens.append(token)
+        if model_specs['MODEL_TYPE'] == 'roberta':
+            tokens.append(model_specs['SEP_TOKEN'])
+            segment_ids.append(0)
+        if model_specs['MODEL_TYPE'] != 'roberta':
+            for token in tokens_b:
+                tokens.append(token)
+                segment_ids.append(1)
+            tokens.append(model_specs['SEP_TOKEN'])
             segment_ids.append(1)
-        tokens.append(model_specs['SEP_TOKEN'])
-        segment_ids.append(1)
+        else:
+            for token in tokens_b:
+                tokens.append(token)
+                segment_ids.append(0)
+            tokens.append(model_specs['SEP_TOKEN'])
+            segment_ids.append(0)
+            if model_specs['MODEL_TYPE'] == 'roberta':
+                tokens.append(model_specs['SEP_TOKEN'])
+                segment_ids.append(0)
         for token in history:
             tokens.append(token)
-            segment_ids.append(1)
+            if model_specs['MODEL_TYPE'] == 'roberta':
+                segment_ids.append(0)
+            else:
+                segment_ids.append(1)
         tokens.append(model_specs['SEP_TOKEN'])
-        segment_ids.append(1)
+        if model_specs['MODEL_TYPE'] == 'roberta':
+            segment_ids.append(0)
+        else:
+            segment_ids.append(1)
         input_ids = tokenizer.convert_tokens_to_ids(tokens)
         # The mask has 1 for real tokens and 0 for padding tokens. Only real
         # tokens are attended to.
         input_mask = [1] * len(input_ids)
         # Zero-pad up to the sequence length.
         while len(input_ids) < max_seq_length:
-            input_ids.append(0)
+            if model_specs['MODEL_TYPE'] == 'roberta':
+                input_ids.append(1)
+            else:
+                input_ids.append(0)
             input_mask.append(0)
             segment_ids.append(0)
         assert len(input_ids) == max_seq_length
@@ -408,6 +463,71 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
     return features
 
 
+# TODO: Don't start with pre-tokenized data, instead do tokenization here
+def convert_aux_examples_to_features(examples, aux_task_def, max_seq_length):
+    """Loads a data file into a list of AuxInputFeatures."""
+
+    def _get_transformer_input(tokens, type_id, max_seq_length):
+        assert len(tokens) == len(type_id)
+        input_ids = tokens
+        segment_ids = type_id
+        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
+        input_mask = [1] * len(type_id)
+        # Zero-pad up to the sequence length.
+        while len(input_ids) < max_seq_length:
+            input_ids.append(0)
+            input_mask.append(0)
+            segment_ids.append(0)
+        assert len(input_ids) == max_seq_length
+        assert len(input_mask) == max_seq_length
+        assert len(segment_ids) == max_seq_length
+        return input_ids, input_mask, segment_ids
+    
+    features = []
+    # Convert single example
+    for (example_index, example) in enumerate(examples):
+        if example_index % 1000 == 0:
+            logger.info("Writing example %d of %d" % (example_index, len(examples)))
+
+        uid = example['uid']
+        label = example['label']
+        tokens = example['token_id']
+        type_id = example['type_id']
+
+        start = 0
+        end = 0
+        if aux_task_def['task_type'] == "span":
+            start = example['start_position']
+            end = example['end_position']
+
+        # TODO: implement truncation
+        assert len(tokens) <= max_seq_length
+
+        input_ids, input_mask, segment_ids = _get_transformer_input(tokens, type_id, max_seq_length)
+
+        if example_index < 10:
+            logger.info("*** Example ***")
+            logger.info("uid: %s" % (uid))
+            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
+            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
+            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
+            logger.info("start_pos: %s" % str(start))
+            logger.info("end_pos: %s" % str(end))
+            logger.info("label: %s" % str(label))
+
+        features.append(
+            AuxInputFeatures(
+                uid=uid,
+                input_ids=input_ids,
+                input_mask=input_mask,
+                segment_ids=segment_ids,
+                start_pos=start,
+                end_pos=end,
+                label=label))
+
+    return features
+
+
 # From bert.tokenization (TF code)
 def convert_to_unicode(text):
     """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""