diff --git a/data_processors.py b/data_processors.py
index 58cf927f25b37829f0fec05db9e6d3e8ed8787db..a842f4b44ed1a56caebfb61a1cc5315578510921 100644
--- a/data_processors.py
+++ b/data_processors.py
@@ -25,6 +25,7 @@ import dataset_sim
 import dataset_multiwoz21
 import dataset_multiwoz21_legacy
 import dataset_unified
+import dataset_sgd
 
 
 class DataProcessor(object):
@@ -37,20 +38,24 @@ class DataProcessor(object):
     label_maps = {}
     value_list = {'train': {}, 'dev': {}, 'test': {}}
 
-    def __init__(self, dataset_config, data_dir):
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
         self.data_dir = data_dir
         # Load dataset config file.
         with open(dataset_config, "r", encoding='utf-8') as f:
             raw_config = json.load(f)
         self.dataset_name = raw_config['dataset_name'] if 'dataset_name' in raw_config else ""
         self.class_types = raw_config['class_types'] # Required
-        self.slot_list = raw_config['slots'] if 'slots' in raw_config else {}
-        self.noncategorical = raw_config['noncategorical'] if 'noncategorical' in raw_config else []
-        self.boolean = raw_config['boolean'] if 'boolean' in raw_config else []
+        self.slot_list = raw_config['slots'] if 'slots' in raw_config else None
+        self.noncategorical = raw_config['noncategorical'] if 'noncategorical' in raw_config else None
+        self.boolean = raw_config['boolean'] if 'boolean' in raw_config else None
         self.label_maps = raw_config['label_maps'] if 'label_maps' in raw_config else {}
-        # If not slot list is provided, generate from data.
+        # If the slot, noncategorical or boolean lists are not provided, generate them from data.
-        if len(self.slot_list) == 0:
-            self.slot_list = self._get_slot_list()
+        if self.slot_list is None:
+            self.slot_list = self._get_slot_list(predict_type)
+        if self.noncategorical is None:
+            self.noncategorical = self._get_noncategorical(predict_type)
+        if self.boolean is None:
+            self.boolean = self._get_boolean(predict_type)
 
     def _add_dummy_value_to_value_list(self):
         for dset in self.value_list:
@@ -77,7 +82,13 @@ class DataProcessor(object):
                         self.value_list['train'][s][v] += new_value_list[s][v]
         self._add_dummy_value_to_value_list()
 
-    def _get_slot_list(self):
+    def _get_slot_list(self, predict_type):
+        raise NotImplementedError()
+
+    def _get_noncategorical(self, predict_type):
+        raise NotImplementedError()
+
+    def _get_boolean(self, predict_type):
         raise NotImplementedError()
 
     def prediction_normalization(self, slot, value):
@@ -94,8 +105,8 @@ class DataProcessor(object):
 
 
 class Woz2Processor(DataProcessor):
-    def __init__(self, dataset_config, data_dir):
-        super(Woz2Processor, self).__init__(dataset_config, data_dir)
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(Woz2Processor, self).__init__(dataset_config, data_dir, predict_type)
         self.value_list['train'] = dataset_woz2.get_value_list(os.path.join(self.data_dir, 'woz_train_en.json'),
                                                                self.slot_list)
         self.value_list['dev'] = dataset_woz2.get_value_list(os.path.join(self.data_dir, 'woz_validate_en.json'),
@@ -117,8 +128,8 @@ class Woz2Processor(DataProcessor):
 
 
 class Multiwoz21Processor(DataProcessor):
-    def __init__(self, dataset_config, data_dir):
-        super(Multiwoz21Processor, self).__init__(dataset_config, data_dir)
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(Multiwoz21Processor, self).__init__(dataset_config, data_dir, predict_type)
         self.value_list['train'] = dataset_multiwoz21.get_value_list(os.path.join(self.data_dir, 'train_dials.json'),
                                                                      self.slot_list)
         self.value_list['dev'] = dataset_multiwoz21.get_value_list(os.path.join(self.data_dir, 'val_dials.json'),
@@ -144,8 +155,8 @@ class Multiwoz21Processor(DataProcessor):
 
     
 class Multiwoz21LegacyProcessor(DataProcessor):
-    def __init__(self, dataset_config, data_dir):
-        super(Multiwoz21LegacyProcessor, self).__init__(dataset_config, data_dir)
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(Multiwoz21LegacyProcessor, self).__init__(dataset_config, data_dir, predict_type)
         self.value_list['train'] = dataset_multiwoz21_legacy.get_value_list(os.path.join(self.data_dir, 'train_dials.json'),
                                                                             self.slot_list)
         self.value_list['dev'] = dataset_multiwoz21_legacy.get_value_list(os.path.join(self.data_dir, 'val_dials.json'),
@@ -174,8 +185,8 @@ class Multiwoz21LegacyProcessor(DataProcessor):
 
 
 class SimProcessor(DataProcessor):
-    def __init__(self, dataset_config, data_dir):
-        super(SimProcessor, self).__init__(dataset_config, data_dir)
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(SimProcessor, self).__init__(dataset_config, data_dir, predict_type)
         self.value_list['train'] = dataset_sim.get_value_list(os.path.join(self.data_dir, 'train.json'),
                                                               self.slot_list)
         self.value_list['dev'] = dataset_sim.get_value_list(os.path.join(self.data_dir, 'dev.json'),
@@ -197,8 +208,8 @@ class SimProcessor(DataProcessor):
 
 
 class UnifiedDatasetProcessor(DataProcessor):
-    def __init__(self, dataset_config, data_dir):
-        super(UnifiedDatasetProcessor, self).__init__(dataset_config, data_dir)
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(UnifiedDatasetProcessor, self).__init__(dataset_config, data_dir, predict_type)
         self.value_list['train'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list)
         self.value_list['dev'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list)
         self.value_list['test'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list)
@@ -207,7 +218,7 @@ class UnifiedDatasetProcessor(DataProcessor):
     def prediction_normalization(self, slot, value):
         return dataset_unified.prediction_normalization(self.dataset_name, slot, value)
 
-    def _get_slot_list(self):
+    def _get_slot_list(self, predict_type):
         return dataset_unified.get_slot_list(self.dataset_name)
         
     def get_train_examples(self, args):
@@ -222,10 +233,47 @@ class UnifiedDatasetProcessor(DataProcessor):
         return dataset_unified.create_examples('test', self.dataset_name, self.class_types,
                                                self.slot_list, self.label_maps, **args)
 
-    
+
+class SgdProcessor(DataProcessor):
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(SgdProcessor, self).__init__(dataset_config, data_dir, predict_type)
+        self.value_list['train'] = dataset_sgd.get_value_list(os.path.join(self.data_dir, 'train'), self.slot_list)
+        self.value_list['dev'] = dataset_sgd.get_value_list(os.path.join(self.data_dir, 'dev'), self.slot_list)
+        self.value_list['test'] = dataset_sgd.get_value_list(os.path.join(self.data_dir, 'test'), self.slot_list)
+        self._add_dummy_value_to_value_list()
+
+    def prediction_normalization(self, slot, value):
+        return dataset_sgd.prediction_normalization(slot, value)
+
+    def _get_slot_list(self, predict_type='train'):
+        # Schema files are expected under <data_dir>/<split>/schema.json.
+        return dataset_sgd.get_slot_list(os.path.join(self.data_dir, predict_type, 'schema.json'))
+
+    def _get_noncategorical(self, predict_type='train'):
+        return dataset_sgd.get_noncategorical(os.path.join(self.data_dir, predict_type, 'schema.json'))
+
+    def _get_boolean(self, predict_type='train'):
+        return dataset_sgd.get_boolean(os.path.join(self.data_dir, predict_type, 'schema.json'))
+
+    def get_train_examples(self, args):
+        return dataset_sgd.create_examples(os.path.join(self.data_dir, 'train'),
+                                           'train', self.class_types, self.slot_list, self.label_maps, **args)
+
+    def get_dev_examples(self, args):
+        return dataset_sgd.create_examples(os.path.join(self.data_dir, 'dev'),
+                                           'dev', self.class_types, self.slot_list, self.label_maps, **args)
+
+    def get_test_examples(self, args):
+        return dataset_sgd.create_examples(os.path.join(self.data_dir, 'test'),
+                                           'test', self.class_types, self.slot_list, self.label_maps, **args)
+
+
 PROCESSORS = {"woz2": Woz2Processor,
               "sim-m": SimProcessor,
               "sim-r": SimProcessor,
               "multiwoz21": Multiwoz21Processor,
               "multiwoz21_legacy": Multiwoz21LegacyProcessor,
-              "unified": UnifiedDatasetProcessor}
+              "unified": UnifiedDatasetProcessor,
+              "sgd": SgdProcessor}
diff --git a/dataset_config/sgd.json b/dataset_config/sgd.json
new file mode 100644
index 0000000000000000000000000000000000000000..88a99677e8d9b722b8efd5d7166675208a38bf5a
--- /dev/null
+++ b/dataset_config/sgd.json
@@ -0,0 +1,167 @@
+{
+  "class_types": [
+    "none",
+    "dontcare",
+    "copy_value",
+    "true",
+    "false",
+    "refer",
+    "inform"
+  ],
+  "label_maps": {
+    "inexpensive": [
+      "cheap",
+      "lower price",
+      "lower range",
+      "cheaply",
+      "cheaper",
+      "cheapest",
+      "very affordable",
+      "low cost",
+      "low priced",
+      "low-cost",
+      "budget",
+      "bargain priced"
+    ],
+    "moderate": [
+      "moderately",
+      "reasonable",
+      "reasonably",
+      "affordable",
+      "afforadable",
+      "mid range",
+      "mid-range",
+      "priced moderately",
+      "decently priced",
+      "mid price",
+      "mid-price",
+      "mid priced",
+      "mid-priced",
+      "middle price",
+      "medium price",
+      "medium priced",
+      "not too expensive",
+      "not too cheap",
+      "not too costly",
+      "not very costly",
+      "economical",
+      "intermediate",
+      "average"
+    ],
+    "expensive": [
+      "high end",
+      "high-end",
+      "high class",
+      "high-class",
+      "high scale",
+      "high-scale",
+      "high price",
+      "high priced",
+      "higher price",
+      "above average",
+      "fancy",
+      "upscale",
+      "expensively",
+      "luxury",
+      "pricey",
+      "costly"
+    ],
+    "very expensive": [
+      "very fancy",
+      "lavish",
+      "extravagant"
+    ],
+    "0": [
+      "zero"
+    ],
+    "1": [
+      "one"
+    ],
+    "2": [
+      "two"
+    ],
+    "3": [
+      "three"
+    ],
+    "4": [
+      "four"
+    ],
+    "5": [
+      "five"
+    ],
+    "6": [
+      "six"
+    ],
+    "7": [
+      "seven"
+    ],
+    "8": [
+      "eight"
+    ],
+    "9": [
+      "nine"
+    ],
+    "10": [
+      "ten"
+    ],
+    "kitchen speaker": [
+      "kitchen"
+    ],
+    "bedroom speaker": [
+      "bedroom"
+    ],
+    "music": [
+      "concert"
+    ],
+    "sports": [
+      "match",
+      "matches",
+      "game",
+      "games"
+    ],
+    "standard": [
+      "medium-sized",
+      "intermediate"
+    ],
+    "compact": [
+      "small"
+    ],
+    "tv": [
+      "television",
+      "display"
+    ],
+    "full-size": [
+      "large",
+      "spacious"
+    ],
+    "park": [
+      "gardens",
+      "garden"
+    ],
+    "nature preserve": [
+      "natural spot",
+      "wildlife spot"
+    ],
+    "historical landmark": [
+      "historical spot"
+    ],
+    "tourist attraction": [
+      "place of interest"
+    ],
+    "theme park": [
+      "amusement park"
+    ],
+    "sports venue": [
+      "playground"
+    ],
+    "place of worship": [
+      "religious spot"
+    ],
+    "shopping area": [
+      "mall"
+    ],
+    "performing arts venue": [
+      "performance venue"
+    ]
+  }
+}
diff --git a/dataset_config/sim-m.json b/dataset_config/sim-m.json
index c11444504aadd5d5517e43cac22c69211566b674..fd9f553c61aa81a63582de26d4d17416458b228f 100644
--- a/dataset_config/sim-m.json
+++ b/dataset_config/sim-m.json
@@ -14,7 +14,5 @@
   },
   "noncategorical": [
     "movie"
-  ],
-  "boolean": [],
-  "label_maps": {}
+  ]
 }
diff --git a/dataset_config/sim-r.json b/dataset_config/sim-r.json
index d7400e51382aef920b3b6ec44ff8da31b37ee6f5..1e41a9692e99197efd55bd6613f79ec615b68a35 100644
--- a/dataset_config/sim-r.json
+++ b/dataset_config/sim-r.json
@@ -18,7 +18,5 @@
   },
   "noncategorical": [
     "restaurant_name"
-  ],
-  "boolean": [],
-  "label_maps": {}
+  ]
 }
diff --git a/dataset_config/unified_multiwoz21.json b/dataset_config/unified_multiwoz21.json
index 77140bcee127fbb3d3e43a0c078ffd066af59728..3267be101f41407ca05e1dfbd71bf752c0b9c3f0 100644
--- a/dataset_config/unified_multiwoz21.json
+++ b/dataset_config/unified_multiwoz21.json
@@ -10,7 +10,6 @@
     "inform",
     "request"
   ],
-  "slots": [],
   "noncategorical": [
     "taxi-leaveAt",
     "taxi-destination",
diff --git a/dataset_config/unified_sgd.json b/dataset_config/unified_sgd.json
new file mode 100644
index 0000000000000000000000000000000000000000..1747d95ff6826b2eb8e39c13300e3b6aa7643fe5
--- /dev/null
+++ b/dataset_config/unified_sgd.json
@@ -0,0 +1,169 @@
+{
+  "dataset_name": "sgd",
+  "class_types": [
+    "none",
+    "dontcare",
+    "copy_value",
+    "true",
+    "false",
+    "refer",
+    "inform",
+    "request"
+  ],
+  "label_maps": {
+    "inexpensive": [
+      "cheap",
+      "lower price",
+      "lower range",
+      "cheaply",
+      "cheaper",
+      "cheapest",
+      "very affordable",
+      "low cost",
+      "low priced",
+      "low-cost",
+      "budget",
+      "bargain priced"
+    ],
+    "moderate": [
+      "moderately",
+      "reasonable",
+      "reasonably",
+      "affordable",
+      "afforadable",
+      "mid range",
+      "mid-range",
+      "priced moderately",
+      "decently priced",
+      "mid price",
+      "mid-price",
+      "mid priced",
+      "mid-priced",
+      "middle price",
+      "medium price",
+      "medium priced",
+      "not too expensive",
+      "not too cheap",
+      "not too costly",
+      "not very costly",
+      "economical",
+      "intermediate",
+      "average"
+    ],
+    "expensive": [
+      "high end",
+      "high-end",
+      "high class",
+      "high-class",
+      "high scale",
+      "high-scale",
+      "high price",
+      "high priced",
+      "higher price",
+      "above average",
+      "fancy",
+      "upscale",
+      "expensively",
+      "luxury",
+      "pricey",
+      "costly"
+    ],
+    "very expensive": [
+      "very fancy",
+      "lavish",
+      "extravagant"
+    ],
+    "0": [
+      "zero"
+    ],
+    "1": [
+      "one"
+    ],
+    "2": [
+      "two"
+    ],
+    "3": [
+      "three"
+    ],
+    "4": [
+      "four"
+    ],
+    "5": [
+      "five"
+    ],
+    "6": [
+      "six"
+    ],
+    "7": [
+      "seven"
+    ],
+    "8": [
+      "eight"
+    ],
+    "9": [
+      "nine"
+    ],
+    "10": [
+      "ten"
+    ],
+    "kitchen speaker": [
+      "kitchen"
+    ],
+    "bedroom speaker": [
+      "bedroom"
+    ],
+    "music": [
+      "concert"
+    ],
+    "sports": [
+      "match",
+      "matches",
+      "game",
+      "games"
+    ],
+    "standard": [
+      "medium-sized",
+      "intermediate"
+    ],
+    "compact": [
+      "small"
+    ],
+    "tv": [
+      "television",
+      "display"
+    ],
+    "full-size": [
+      "large",
+      "spacious"
+    ],
+    "park": [
+      "gardens",
+      "garden"
+    ],
+    "nature preserve": [
+      "natural spot",
+      "wildlife spot"
+    ],
+    "historical landmark": [
+      "historical spot"
+    ],
+    "tourist attraction": [
+      "place of interest"
+    ],
+    "theme park": [
+      "amusement park"
+    ],
+    "sports venue": [
+      "playground"
+    ],
+    "place of worship": [
+      "religious spot"
+    ],
+    "shopping area": [
+      "mall"
+    ],
+    "performing arts venue": [
+      "performance venue"
+    ]
+  }
+}
diff --git a/dataset_sgd.py b/dataset_sgd.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea70b598b03648a7b9f72181d6136d0d834ea662
--- /dev/null
+++ b/dataset_sgd.py
@@ -0,0 +1,587 @@
+# coding=utf-8
+#
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+import os
+import glob
+from tqdm import tqdm
+
+from utils_dst import (DSTExample, convert_to_unicode)
+
+
+# TODO: check what's actually needed
+def prediction_normalization(slot, value):
+    #def _normalize_value(text):
+    #    text = re.sub(" ?' ?s", "s", text)
+    #    return text
+
+    #value = _normalize_value(value)
+
+    return value
+
+
+def get_slot_list(input_file):
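+    """Build {"<service>-<slot>": "<description>"} from an SGD schema.json,
+    which holds a list of service entries of the form:
+      {"service_name": "Restaurants_1",
+       "slots": [{"name": "city", "description": "...",
+                  "is_categorical": false, "possible_values": [...]}, ...],
+       ...}
+    """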
+    slot_list = {}
+    with open(input_file, "r", encoding='utf-8') as reader:
+        input_data = json.load(reader)
+    for service in input_data:
+        for slot in service['slots']:
+            s = "%s-%s" % (service['service_name'], slot['name'])
+            slot_list[s] = slot['description'].lower()
+    return slot_list
+
+
+def get_noncategorical(input_file):
+    noncategorical = []
+    with open(input_file, "r", encoding='utf-8') as reader:
+        input_data = json.load(reader)
+    for service in input_data:
+        for slot in service['slots']:
+            s = "%s-%s" % (service['service_name'], slot['name'])
+            if not slot['is_categorical']:
+                noncategorical.append(s)
+    return noncategorical
+
+
+def get_boolean(input_file):
+    boolean = []
+    with open(input_file, "r", encoding='utf-8') as reader:
+        input_data = json.load(reader)
+    for service in input_data:
+        for slot in service['slots']:
+            s = "%s-%s" % (service['service_name'], slot['name'])
+            if len(slot['possible_values']) == 2 and \
+               "True" in slot['possible_values'] and \
+               "False" in slot['possible_values']:
+                boolean.append(s)
+    return boolean
+
+
+# This should only contain label normalizations, no label mappings.
+def normalize_label(slot, value_label):
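+    # e.g. ["6 pm", "18:00"] -> "6 pm"; [] -> "none"; "dont care" -> "dontcare"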
+    if isinstance(value_label, list):
+        # TODO: Workaround: keep only the first variant. Note that MultiWOZ 2.2
+        # supports variants directly in the labels.
+        value_label = value_label[0] if len(value_label) > 0 else ""
+    # Normalization of capitalization
+    if isinstance(value_label, str):
+        value_label = value_label.lower().strip()
+
+    # Normalization of empty slots # TODO: needed?
+    if value_label == '':
+        return "none"
+
+    # Normalization of 'dontcare'
+    if value_label == 'dont care':
+        return "dontcare"
+
+    # Map to boolean slots
+
+    return value_label
+
+
+def get_token_pos(tok_list, value_label):
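+    # Find all (start, exclusive_end) token spans of value_label in tok_list,
+    # e.g. tok_list=['to', 'san', 'jose'], value_label="san jose" -> (True, [(1, 3)])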
+    find_pos = []
+    found = False
+    label_list = [item for item in map(str.strip, re.split(r"(\W+)", value_label)) if len(item) > 0]
+    len_label = len(label_list)
+    for i in range(len(tok_list) + 1 - len_label):
+        if tok_list[i:i + len_label] == label_list:
+            find_pos.append((i, i + len_label)) # start, exclusive_end
+            found = True
+    return found, find_pos
+
+
+def check_label_existence(value_label, usr_utt_tok, label_maps={}):
+    in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label)
+    # If no hit even though there should be one, check for value label variants
+    if not in_usr and value_label in label_maps:
+        for value_label_variant in label_maps[value_label]:
+            in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label_variant)
+            if in_usr:
+                break
+    return in_usr, usr_pos
+
+
+def check_slot_referral(value_label, slot, seen_slots, label_maps={}):
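+    # NOTE: The slot-specific exclusions below are inherited from the MultiWOZ
+    # implementation; SGD slot names ("Service_N-slot") never match them.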
+    referred_slot = 'none'
+    if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking':
+        return referred_slot
+    for s in seen_slots:
+        # Avoid matches for slots that share values with different meaning.
+        # hotel-internet and -parking are handled separately as Boolean slots.
+        if s == 'hotel-stars' or s == 'hotel-internet' or s == 'hotel-parking':
+            continue
+        if re.match("(hotel|restaurant)-book_people", s) and slot == 'hotel-book_stay':
+            continue
+        if re.match("(hotel|restaurant)-book_people", slot) and s == 'hotel-book_stay':
+            continue
+        if slot != s and (slot not in seen_slots or seen_slots[slot] != value_label):
+            if seen_slots[s] == value_label:
+                referred_slot = s
+                break
+            elif value_label in label_maps:
+                for value_label_variant in label_maps[value_label]:
+                    if seen_slots[s] == value_label_variant:
+                        referred_slot = s
+                        break
+    return referred_slot
+
+
+def is_in_list(tok, value):
+    found = False
+    tok_list = [item for item in map(str.strip, re.split(r"(\W+)", tok)) if len(item) > 0]
+    value_list = [item for item in map(str.strip, re.split(r"(\W+)", value)) if len(item) > 0]
+    tok_len = len(tok_list)
+    value_len = len(value_list)
+    for i in range(tok_len + 1 - value_len):
+        if tok_list[i:i + value_len] == value_list:
+            found = True
+            break
+    return found
+
+
+def delex_utt(utt, values, unk_token="[UNK]"):
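+    # e.g. delex_utt("leaving from san jose", {"Buses_1-from_location": ["san jose"]})
+    # -> ['leaving', 'from', '[UNK]', '[UNK]'] (slot name purely illustrative)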
+    utt_norm = tokenize(utt)
+    for s, vals in values.items():
+        for v in vals:
+            if v != 'none':
+                v_norm = tokenize(v)
+                v_len = len(v_norm)
+                for i in range(len(utt_norm) + 1 - v_len):
+                    if utt_norm[i:i + v_len] == v_norm:
+                        utt_norm[i:i + v_len] = [unk_token] * v_len
+    return utt_norm
+
+
+# Fuzzy matching to label informed slot values
+def check_slot_inform(value_label, inform_label, label_maps={}):
+    result = False
+    informed_value = 'none'
+    vl = ' '.join(tokenize(value_label))
+    for il in inform_label:
+        if vl == il:
+            result = True
+        elif is_in_list(il, vl):
+            result = True
+        elif is_in_list(vl, il):
+            result = True
+        elif il in label_maps:
+            for il_variant in label_maps[il]:
+                if vl == il_variant:
+                    result = True
+                    break
+                elif is_in_list(il_variant, vl):
+                    result = True
+                    break
+                elif is_in_list(vl, il_variant):
+                    result = True
+                    break
+        elif vl in label_maps:
+            for value_label_variant in label_maps[vl]:
+                if value_label_variant == il:
+                    result = True
+                    break
+                elif is_in_list(il, value_label_variant):
+                    result = True
+                    break
+                elif is_in_list(value_label_variant, il):
+                    result = True
+                    break
+        if result:
+            informed_value = il
+            break
+    return result, informed_value
+
+
+def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, seen_slots, slot_last_occurrence, label_maps={}):
+    usr_utt_tok_label = [0 for _ in usr_utt_tok]
+    informed_value = 'none'
+    referred_slot = 'none'
+    if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false':
+        class_type = value_label
+    else:
+        in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok, label_maps)
+        is_informed, informed_value = check_slot_inform(value_label, inform_label, label_maps)
+        if in_usr:
+            class_type = 'copy_value'
+            if slot_last_occurrence:
+                (s, e) = usr_pos[-1]
+                for i in range(s, e):
+                    usr_utt_tok_label[i] = 1
+            else:
+                for (s, e) in usr_pos:
+                    for i in range(s, e):
+                        usr_utt_tok_label[i] = 1
+        elif is_informed:
+            class_type = 'inform'
+        else:
+            referred_slot = check_slot_referral(value_label, slot, seen_slots, label_maps)
+            if referred_slot != 'none':
+                class_type = 'refer'
+            else:
+                class_type = 'unpointable'
+    return informed_value, referred_slot, usr_utt_tok_label, class_type
+
+
+# Requestable slots, general acts and domain indicator slots
+def is_request(slot, user_act, turn_domains):
+    if slot in user_act:
+        if isinstance(user_act[slot], list):
+            for act in user_act[slot]:
+                if act['intent'] in ['REQUEST', 'GOODBYE', 'THANK_YOU']:
+                    return True
+        elif user_act[slot]['intent'] in ['REQUEST', 'GOODBYE', 'THANK_YOU']:
+            return True
+    do, sl = slot.split('-')
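+    # Domain indicator slots are named "<domain>-none" and count as requested
+    # whenever their domain is active in the current turn.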
+    if sl == 'none' and do in turn_domains:
+        return True
+    return False
+
+
+def tokenize(utt):
+    utt_lower = convert_to_unicode(utt).lower()
+    utt_tok = utt_to_token(utt_lower)
+    return utt_tok
+
+
+def utt_to_token(utt):
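+    # Split on runs of non-word characters, keeping them as tokens but dropping
+    # whitespace, e.g. "5:30 pm." -> ['5', ':', '30', 'pm', '.']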
+    return [tok for tok in map(lambda x: re.sub(" ", "", x), re.split(r"(\W+)", utt)) if len(tok) > 0]
+
+
+def create_examples(input_file, set_type, class_types, slot_list,
+                    label_maps={},
+                    no_label_value_repetitions=False,
+                    swap_utterances=False,
+                    delexicalize_sys_utts=False,
+                    unk_token="[UNK]",
+                    boolean_slots=True,
+                    analyze=False):
+    """Read a DST json file into a list of DSTExample."""
+
+    input_files = glob.glob(os.path.join(input_file, 'dialogues_*.json'))
+
+    examples = []
+    for input_file in input_files:
+        with open(input_file, "r", encoding='utf-8') as reader:
+            input_data = json.load(reader)
+
+        for d_itr, dialog in enumerate(tqdm(input_data)):
+            dialog_id = dialog['dialogue_id']
+            domains = dialog['services']
+            utterances = dialog['turns']
+
+            # Collects all slot changes throughout the dialog
+            cumulative_labels = {slot: 'none' for slot in slot_list}
+
+            # First system utterance is empty, since SGD dialogs start with user input
+            utt_tok_list = [[]]
+            mod_slots_list = [{}]
+            inform_dict_list = [{}]
+            user_act_dict_list = [{}]
+            mod_domains_list = [[]]
+
+            # Collect all utterances and their metadata
+            usr_sys_switch = True
+            turn_itr = 0
+            for utt in utterances:
+                # Assert that system and user utterances alternate
+                is_sys_utt = utt['speaker'] == "SYSTEM"
+                if usr_sys_switch == is_sys_utt:
+                    print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id))
+                    break
+                usr_sys_switch = is_sys_utt
+
+                if is_sys_utt:
+                    turn_itr += 1
+
+                # Extract dialog_act information for sys and usr utts.
+                inform_dict = {}
+                user_act_dict = {}
+                modified_slots = {}
+                modified_domains = []
+                for frame in utt['frames']:
+                    actions = frame['actions']
+                    service = frame['service']
+                    spans = frame['slots']
+                    modified_domains.append(service) # Remember domains
+                    for action in actions:
+                        act = action['act']
+                        slot = action['slot']
+                        cs = "%s-%s" % (service, slot)
+                        values = action['values'] # this is a list
+                        #canonical_values = action['canonical_values']
+                        values = normalize_label(cs, values)
+                        if is_sys_utt and act in ['INFORM', 'CONFIRM', 'OFFER']:
+                            if cs not in inform_dict:
+                                inform_dict[cs] = []
+                            inform_dict[cs].append(values)
+                        elif not is_sys_utt:
+                            if cs not in user_act_dict:
+                                user_act_dict[cs] = []
+                            user_act_dict[cs].append({'domain': service,
+                                                      'intent': act,
+                                                      'slot': slot,
+                                                      'value': values})
+                            if act in ['INFORM']:
+                                modified_slots[cs] = values
+                    if not is_sys_utt:
+                        state = frame['state']
+                        #active_intent = state['active_intent']
+                        #requested_slots = state['requested_slots']
+                        for slot in state['slot_values']:
+                            cs = "%s-%s" % (service, slot)
+                            values = frame['state']['slot_values'][slot] # this is a list
+                            values = normalize_label(cs, values)
+                            # Remember modified slots and entire dialog state
+                            if cs in slot_list and cumulative_labels[cs] != values:
+                                modified_slots[cs] = values
+                                cumulative_labels[cs] = values
+                # INFO: Since the model has no mechanism to predict
+                # one among several informed value candidates, we
+                # keep only one informed value. For fairness, we
+                # apply a global rule:
+                for e in inform_dict:
+                    # ... Option 1: Always keep first informed value
+                    inform_dict[e] = list([inform_dict[e][0]])
+                    # ... Option 2: Always keep last informed value
+                    #inform_dict[e] = list([inform_dict[e][-1]])
+
+                utterance = utt['utterance']
+
+                # Delexicalize sys utterance
+                if delexicalize_sys_utts and is_sys_utt:
+                    utt_tok_list.append(delex_utt(utterance, inform_dict, unk_token)) # normalizes utterances
+                else:
+                    utt_tok_list.append(tokenize(utterance)) # normalizes utterances
+
+                inform_dict_list.append(inform_dict.copy())
+                user_act_dict_list.append(user_act_dict.copy())
+                mod_slots_list.append(modified_slots.copy())
+                modified_domains.sort()
+                mod_domains_list.append(modified_domains)
+
+            # Form proper (usr, sys) turns
+            turn_itr = 0
+            diag_seen_slots_dict = {}
+            diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
+            diag_state = {slot: 'none' for slot in slot_list}
+            sys_utt_tok = []
+            usr_utt_tok = []
+            for i in range(1, len(utt_tok_list) - 1, 2):
+                sys_utt_tok_label_dict = {}
+                usr_utt_tok_label_dict = {}
+                value_dict = {}
+                inform_dict = {}
+                inform_slot_dict = {}
+                referral_dict = {}
+                class_type_dict = {}
+                updated_slots = {slot: 0 for slot in slot_list}
+
+                # Collect turn data
+                sys_utt_tok = utt_tok_list[i - 1]
+                usr_utt_tok = utt_tok_list[i]
+                turn_slots = mod_slots_list[i]
+                inform_mem = {}
+                for inform_itr in range(0, i, 2):
+                    inform_mem.update(inform_dict_list[inform_itr])
+                #inform_mem = inform_dict_list[i - 1]
+                user_act = user_act_dict_list[i]
+                turn_domains = mod_domains_list[i]
+
+                guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr))
+
+                if analyze:
+                    print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok)))
+                    print("%15s %2s [" % (dialog_id, turn_itr), end='')
+
+                new_diag_state = diag_state.copy()
+
+                for slot in slot_list:
+                    value_label = 'none'
+                    if slot in turn_slots:
+                        value_label = turn_slots[slot]
+                        # We keep the original labels so as to not
+                        # overlook unpointable values, as well as to not
+                        # modify any of the original labels for test sets,
+                        # since this would make comparison difficult.
+                        value_dict[slot] = value_label
+                    elif not no_label_value_repetitions and slot in diag_seen_slots_dict:
+                        value_label = diag_seen_slots_value_dict[slot]
+
+                    # Get dialog act annotations
+                    inform_label = list(['none'])
+                    inform_slot_dict[slot] = 0
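+                    # 'booking-*' slots are a MultiWOZ convention kept for compatibility;
+                    # no SGD service is named 'booking', so this lookup is a no-op here.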
+                    booking_slot = 'booking-' + slot.split('-')[1]
+                    if slot in inform_mem:
+                        inform_label = inform_mem[slot]
+                        inform_slot_dict[slot] = 1
+                    elif booking_slot in inform_mem:
+                        inform_label = inform_mem[booking_slot]
+                        inform_slot_dict[slot] = 1
+
+                    (informed_value,
+                     referred_slot,
+                     usr_utt_tok_label,
+                     class_type) = get_turn_label(value_label,
+                                                  inform_label,
+                                                  sys_utt_tok,
+                                                  usr_utt_tok,
+                                                  slot,
+                                                  diag_seen_slots_value_dict,
+                                                  slot_last_occurrence=True,
+                                                  label_maps=label_maps)
+
+                    inform_dict[slot] = informed_value
+
+                    # Requestable slots, domain indicator slots and general slots
+                    # should have class_type 'request', if they ought to be predicted.
+                    # Give other class_types preference.
+                    if 'request' in class_types:
+                        if class_type in ['none', 'unpointable'] and is_request(slot, user_act, turn_domains):
+                            class_type = 'request'
+
+                    # Generally don't use span prediction on sys utterance (but inform prediction instead).
+                    sys_utt_tok_label = [0 for _ in sys_utt_tok]
+
+                    # Determine what to do with value repetitions.
+                    # If the value is unique among seen slots, tag it; otherwise don't,
+                    # since correct slot assignment cannot be guaranteed anymore.
+                    if not no_label_value_repetitions and slot in diag_seen_slots_dict:
+                        if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1:
+                            class_type = 'none'
+                            usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
+
+                    sys_utt_tok_label_dict[slot] = sys_utt_tok_label
+                    usr_utt_tok_label_dict[slot] = usr_utt_tok_label
+
+                    if diag_seen_slots_value_dict[slot] != value_label:
+                        updated_slots[slot] = 1
+
+                    # For now, we map all occurrences of unpointable slot values
+                    # to none. However, since the labels will still suggest
+                    # a presence of unpointable slot values, the task of the
+                    # DST is still to find those values. It is just not
+                    # possible to do that via span prediction on the current input.
+                    if class_type == 'unpointable':
+                        class_type_dict[slot] = 'none'
+                        referral_dict[slot] = 'none'
+                        if analyze:
+                            if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]:
+                                print("(%s): %s, " % (slot, value_label), end='')
+                    elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
+                        # If a slot has been seen before and its class type did not change, label this slot as not present,
+                        # assuming that the slot has not actually been mentioned in this turn.
+                        # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
+                        # this must mean there is evidence in the original labels, therefore consider
+                        # them as mentioned again.
+                        class_type_dict[slot] = 'none'
+                        referral_dict[slot] = 'none'
+                    else:
+                        class_type_dict[slot] = class_type
+                        referral_dict[slot] = referred_slot
+                    # Remember that this slot was mentioned during this dialog already.
+                    if class_type != 'none':
+                        diag_seen_slots_dict[slot] = class_type
+                        diag_seen_slots_value_dict[slot] = value_label
+                        new_diag_state[slot] = class_type
+                        # Unpointable is not a valid class, therefore replace with
+                        # some valid class for now...
+                        if class_type == 'unpointable':
+                            new_diag_state[slot] = 'copy_value'
+
+                if analyze:
+                    print("]")
+
+                if not swap_utterances:
+                    txt_a = usr_utt_tok
+                    txt_b = sys_utt_tok
+                    txt_a_lbl = usr_utt_tok_label_dict
+                    txt_b_lbl = sys_utt_tok_label_dict
+                else:
+                    txt_a = sys_utt_tok
+                    txt_b = usr_utt_tok
+                    txt_a_lbl = sys_utt_tok_label_dict
+                    txt_b_lbl = usr_utt_tok_label_dict
+                examples.append(DSTExample(
+                    guid=guid,
+                    text_a=txt_a,
+                    text_b=txt_b,
+                    text_a_label=txt_a_lbl,
+                    text_b_label=txt_b_lbl,
+                    values=diag_seen_slots_value_dict.copy(),
+                    inform_label=inform_dict,
+                    inform_slot_label=inform_slot_dict,
+                    refer_label=referral_dict,
+                    diag_state=diag_state,
+                    slot_update=updated_slots,
+                    class_label=class_type_dict))
+
+                # Update some variables.
+                diag_state = new_diag_state.copy()
+
+                turn_itr += 1
+
+            if analyze:
+                print("----------------------------------------------------------------------")
+
+    return examples
+
+
+def get_value_list(input_file, slot_list, boolean_slots=True):
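+    """Count each labeled value per slot over all dialogues_*.json files in
+    the input_file directory, returning {slot: {value: count}}."""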
+    def add_to_list(value_label, cs, value_list, slot_list, exclude):
+        if cs in slot_list and value_label not in exclude:
+            if value_label not in value_list[cs]:
+                value_list[cs][value_label] = 0
+            value_list[cs][value_label] += 1
+        
+    exclude = ['none', 'dontcare']
+    if not boolean_slots:
+        exclude += ['true', 'false']
+    value_list = {slot: {} for slot in slot_list}
+    input_files = glob.glob(os.path.join(input_file, 'dialogues_*.json'))
+    for input_file in input_files:
+        with open(input_file, "r", encoding='utf-8') as reader:
+            input_data = json.load(reader)
+        for dialog in input_data:
+            for utt in dialog['turns']:
+                is_sys_utt = utt['speaker'] == "SYSTEM"
+                for frame in utt['frames']:
+                    for action in frame['actions']:
+                        cs = "%s-%s" % (frame['service'], action['slot'])
+                        value_label = normalize_label(cs, action['values'])
+                        if is_sys_utt and action['act'] in ['INFORM', 'CONFIRM', 'OFFER']:
+                            add_to_list(value_label, cs, value_list, slot_list, exclude)
+                        elif not is_sys_utt and action['act'] in ['INFORM']:
+                            add_to_list(value_label, cs, value_list, slot_list, exclude)
+                    if not is_sys_utt:
+                        for slot in frame['state']['slot_values']:
+                            cs = "%s-%s" % (frame['service'], slot)
+                            value_label = normalize_label(cs, frame['state']['slot_values'][slot])
+                            add_to_list(value_label, cs, value_list, slot_list, exclude)
+    return value_list
+
diff --git a/dst_train.py b/dst_train.py
index b01ba7bd773b6dc718cefe968be1d4539e357059..e485a965124e49b168c7bac610171a7c16a535c7 100644
--- a/dst_train.py
+++ b/dst_train.py
@@ -41,7 +41,7 @@ from utils_run import (set_seed, to_device, from_device,
 logger = logging.getLogger(__name__)
 
 
-def train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer, processor):
+def train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer, processor, continue_from_global_step=0):
     """ Train the model """
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()
@@ -103,6 +103,9 @@ def train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer,
     logger.info("  Total optimization steps = %d", t_total)
     logger.info("  Warmup steps = %d", num_warmup_steps)
 
+    if continue_from_global_step > 0:
+        logger.info("Fast forwarding to global step %d to resume training from latest checkpoint...", continue_from_global_step)
+
     global_step = 0
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
@@ -111,11 +114,20 @@ def train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer,
 
     for e_itr, _ in enumerate(train_iterator):
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
-        train_dataset.dropout_input()
-        train_dataset.encode_slots()
-        train_dataset.encode_slot_values()
+        if global_step >= continue_from_global_step:
+            train_dataset.dropout_input()
+            train_dataset.encode_slots()
+            train_dataset.encode_slot_values()
 
         for step, batch in enumerate(epoch_iterator):
+            # If training is continued from a checkpoint, fast forward
+            # to the state of that checkpoint.
+            if global_step < continue_from_global_step:
+                if (step + 1) % args.gradient_accumulation_steps == 0:
+                    scheduler.step()  # Update learning rate schedule
+                    global_step += 1
+                continue
+
             model.train()
 
             # Add tokenized or encoded slot descriptions and encoded values to batch.
@@ -327,7 +339,7 @@ def eval_metric(args, model, tokenizer, batch, outputs, threshold=0.0, dae=False
         per_example_tp_loss = per_slot_per_example_tp_loss[slot]
         class_logits = per_slot_class_logits[slot]
         start_logits = per_slot_start_logits[slot]
-        value_logits = per_slot_value_logits[slot]
+        #value_logits = per_slot_value_logits[slot]
         refer_logits = per_slot_refer_logits[slot]
 
         mean = []
diff --git a/modeling_dst.py b/modeling_dst.py
index 72bc8d4d2685737a9d7dd7736da390a42d6e7e12..de78b016cbe7566b5e591b216e9718f0d18d471e 100644
--- a/modeling_dst.py
+++ b/modeling_dst.py
@@ -328,8 +328,9 @@ def TransformerForDST(parent_name):
 
                 per_slot_class_logits[slot] = class_logits
                 per_slot_token_weights[slot] = token_weights
-                per_slot_value_weights[slot] = value_weights
                 per_slot_refer_logits[slot] = refer_logits
+                if self.value_matching_weight > 0.0:
+                    per_slot_value_weights[slot] = value_weights
 
                 # If there are no labels, don't compute loss
                 if class_label_id is not None and token_pos is not None and refer_id is not None:
diff --git a/run_dst.py b/run_dst.py
index 99eacc722602cae18161b46adb841085dbebe773..c9c1310d68ae7a9cf1777a1def924b0688b73c95 100644
--- a/run_dst.py
+++ b/run_dst.py
@@ -217,7 +217,7 @@ def main():
     if task_name not in PROCESSORS:
         raise ValueError("Task not found: %s" % (task_name))
 
-    processor = PROCESSORS[task_name](args.dataset_config, args.data_dir)
+    processor = PROCESSORS[task_name](args.dataset_config, args.data_dir, 'train' if not args.do_eval else args.predict_type)
     slot_list = processor.slot_list
     noncategorical = processor.noncategorical
     class_types = processor.class_types
@@ -300,7 +300,7 @@ def main():
         if args.training_phase in [-1, 0, 1]:
             train_dataset = load_and_cache_examples(args, model, tokenizer, processor, dset="train", evaluate=False)
             dev_dataset = None
-            if args.do_train and args.evaluate_during_training:
+            if args.evaluate_during_training:
                 dev_dataset = load_and_cache_examples(args, model, tokenizer, processor, dset=args.predict_type, evaluate=True)
 
         # Step 1: Pretrain attention layer for random sequence tagging.
@@ -319,7 +319,8 @@ def main():
                 model = model_class.from_pretrained(proto_checkpoint)
                 model.to(args.device)
                 train_dataset.update_model(model)
-                dev_dataset.update_model(model)
+                if dev_dataset is not None:
+                    dev_dataset.update_model(model)
                 max_tag_goal = 0.0
                 max_tag_thresh = 0.0
                 max_dae = True # default should be true
@@ -356,7 +357,8 @@ def main():
                 if args.do_train and args.evaluate_during_training:
                     dev_dataset = load_and_cache_examples(args, model, tokenizer, processor, dset=args.predict_type, evaluate=True)
                 train_dataset.compute_vectors()
-                dev_dataset.compute_vectors()
+                if dev_dataset is not None:
+                    dev_dataset.compute_vectors()
                 global_step, tr_loss = train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer, processor)
                 logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
             else:
@@ -364,13 +366,23 @@ def main():
 
         # Train full model with original training.
         if args.training_phase == -1:
-            if len(checkpoints) == 0:
-                train_dataset.compute_vectors()
+            # If checkpoints already exist in the output dir, resume training from the latest one (unless overwrite_output_dir is set)
+            continue_from_global_step = 0
+            if len(checkpoints) > 0:
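+                # last_checkpoint.txt is assumed to contain the name of the
+                # latest checkpoint, e.g. "checkpoint-1000".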
+                with open(os.path.join(args.output_dir, 'last_checkpoint.txt'), 'r') as f:
+                    continue_from_global_step = int((f.readline()).split('-')[-1])
+                checkpoint = os.path.join(args.output_dir, 'checkpoint-%s' % continue_from_global_step)
+                logger.warning(" Resuming training from the latest checkpoint: %s", checkpoint)
+                model = model_class.from_pretrained(checkpoint)
+                model.to(args.device)
+                train_dataset.update_model(model)
+                if dev_dataset is not None:
+                    dev_dataset.update_model(model)
+            train_dataset.compute_vectors()
+            if dev_dataset is not None:
                 dev_dataset.compute_vectors()
-                global_step, tr_loss = train(args, train_dataset, dev_dataset, None, model, tokenizer, processor)
-                logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
-            else:
-                logger.warning(" Preconditions for training not fulfilled! Skipping.")
+            global_step, tr_loss = train(args, train_dataset, dev_dataset, None, model, tokenizer, processor, continue_from_global_step)
+            logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
 
     # Save the trained model and the tokenizer
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
diff --git a/utils_dst.py b/utils_dst.py
index 00728dd0488a6f22dc4e315eb6bbabcc4a0804b4..5d6f4129b5b6e818cfe7b7cef845d57dba6cd58d 100644
--- a/utils_dst.py
+++ b/utils_dst.py
@@ -582,7 +582,7 @@ class TrippyDataset(Dataset):
 
     def __getitem__(self, index):
         result = {}
-        # Static elements. Copy, because they will be modified below.
+        # Static elements. Copy, because they will be modified below. # TODO: make more efficient
         for key, element in self.data.items():
             if isinstance(element, dict):
                 result[key] = {k: v[index].detach().clone() for k, v in element.items()}
@@ -648,16 +648,18 @@ class TrippyDataset(Dataset):
                 # TODO: Test subsampling negative samples.
 
                 # For attention based value matching
-                result['value_labels'][slot] = torch.zeros((len(self.encoded_slot_values[slot])), dtype=torch.float)
-                result['dropout_value_feat'][slot] = torch.zeros((1, self.model.config.hidden_size), dtype=torch.float)
-                # Only train value matching, if value is extractable
-                if token_is_pointable and pos_value in self.encoded_slot_values[slot]:
-                    result['value_labels'][slot][list(self.encoded_slot_values[slot].keys()).index(pos_value)] = 1.0
-                    # In case of dropout, forward new representation as target for value matching instead.
-                    if self.dropout_value_seq is not None:
-                        if result['example_id'].item() in self.dropout_value_seq[slot]:
-                            dropout_value_seq = tuple(self.dropout_value_seq[slot][result['example_id'].item()])
-                            result['dropout_value_feat'][slot] = self.encoded_dropout_slot_values[slot][dropout_value_seq]
+                if self.encoded_slot_values is not None:
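+                    # Slot values are only encoded when attention-based value
+                    # matching is enabled; otherwise these targets are skipped.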
+                    result['value_labels'][slot] = torch.zeros((len(self.encoded_slot_values[slot])), dtype=torch.float)
+                    result['dropout_value_feat'][slot] = torch.zeros((1, self.model.config.hidden_size), dtype=torch.float)
+                    # Only train value matching, if value is extractable
+                    if token_is_pointable and pos_value in self.encoded_slot_values[slot]:
+                        result['value_labels'][slot][list(self.encoded_slot_values[slot].keys()).index(pos_value)] = 1.0
+                        # In case of dropout, forward new representation as target for value matching instead.
+                        if self.dropout_value_seq is not None:
+                            if result['example_id'].item() in self.dropout_value_seq[slot]:
+                                dropout_value_seq = tuple(self.dropout_value_seq[slot][result['example_id'].item()])
+                                result['dropout_value_feat'][slot] = self.encoded_dropout_slot_values[slot][dropout_value_seq]
+
         return result
 
     def _encode_text(self, text, input_ids, input_mask, mode="represent", train=False):