diff --git a/data_processors.py b/data_processors.py
index 58cf927f25b37829f0fec05db9e6d3e8ed8787db..a842f4b44ed1a56caebfb61a1cc5315578510921 100644
--- a/data_processors.py
+++ b/data_processors.py
@@ -25,6 +25,7 @@ import dataset_sim
 import dataset_multiwoz21
 import dataset_multiwoz21_legacy
 import dataset_unified
+import dataset_sgd
 
 
 class DataProcessor(object):
@@ -37,20 +38,24 @@ class DataProcessor(object):
     label_maps = {}
     value_list = {'train': {}, 'dev': {}, 'test': {}}
 
-    def __init__(self, dataset_config, data_dir):
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
         self.data_dir = data_dir
         # Load dataset config file.
         with open(dataset_config, "r", encoding='utf-8') as f:
             raw_config = json.load(f)
         self.dataset_name = raw_config['dataset_name'] if 'dataset_name' in raw_config else ""
         self.class_types = raw_config['class_types']  # Required
-        self.slot_list = raw_config['slots'] if 'slots' in raw_config else {}
-        self.noncategorical = raw_config['noncategorical'] if 'noncategorical' in raw_config else []
-        self.boolean = raw_config['boolean'] if 'boolean' in raw_config else []
+        self.slot_list = raw_config['slots'] if 'slots' in raw_config else None
+        self.noncategorical = raw_config['noncategorical'] if 'noncategorical' in raw_config else None
+        self.boolean = raw_config['boolean'] if 'boolean' in raw_config else None
         self.label_maps = raw_config['label_maps'] if 'label_maps' in raw_config else {}
         # If not slot list is provided, generate from data.
-        if len(self.slot_list) == 0:
-            self.slot_list = self._get_slot_list()
+        if self.slot_list is None:
+            self.slot_list = self._get_slot_list(predict_type)
+        if self.noncategorical is None:
+            self.noncategorical = self._get_noncategorical(predict_type)
+        if self.boolean is None:
+            self.boolean = self._get_boolean(predict_type)
 
     def _add_dummy_value_to_value_list(self):
         for dset in self.value_list:
@@ -77,7 +82,13 @@ class DataProcessor(object):
                         self.value_list['train'][s][v] += new_value_list[s][v]
         self._add_dummy_value_to_value_list()
 
-    def _get_slot_list(self):
+    def _get_slot_list(self, predict_type):
+        raise NotImplementedError()
+
+    def _get_noncategorical(self, predict_type):
+        raise NotImplementedError()
+
+    def _get_boolean(self, predict_type):
         raise NotImplementedError()
 
     def prediction_normalization(self, slot, value):
@@ -94,8 +105,8 @@ class DataProcessor(object):
 
 
 class Woz2Processor(DataProcessor):
-    def __init__(self, dataset_config, data_dir):
-        super(Woz2Processor, self).__init__(dataset_config, data_dir)
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(Woz2Processor, self).__init__(dataset_config, data_dir, predict_type)
         self.value_list['train'] = dataset_woz2.get_value_list(os.path.join(self.data_dir, 'woz_train_en.json'),
                                                                self.slot_list)
         self.value_list['dev'] = dataset_woz2.get_value_list(os.path.join(self.data_dir, 'woz_validate_en.json'),
@@ -117,8 +128,8 @@ class Woz2Processor(DataProcessor):
 
 
 class Multiwoz21Processor(DataProcessor):
-    def __init__(self, dataset_config, data_dir):
-        super(Multiwoz21Processor, self).__init__(dataset_config, data_dir)
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(Multiwoz21Processor, self).__init__(dataset_config, data_dir, predict_type)
         self.value_list['train'] = dataset_multiwoz21.get_value_list(os.path.join(self.data_dir, 'train_dials.json'),
                                                                      self.slot_list)
         self.value_list['dev'] = dataset_multiwoz21.get_value_list(os.path.join(self.data_dir, 'val_dials.json'),
@@ -144,8 +155,8 @@ class Multiwoz21Processor(DataProcessor):
 
 
 class Multiwoz21LegacyProcessor(DataProcessor):
-    def __init__(self, dataset_config, data_dir):
-        super(Multiwoz21LegacyProcessor, self).__init__(dataset_config, data_dir)
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(Multiwoz21LegacyProcessor, self).__init__(dataset_config, data_dir, predict_type)
         self.value_list['train'] = dataset_multiwoz21_legacy.get_value_list(os.path.join(self.data_dir, 'train_dials.json'),
                                                                             self.slot_list)
         self.value_list['dev'] = dataset_multiwoz21_legacy.get_value_list(os.path.join(self.data_dir, 'val_dials.json'),
@@ -174,8 +185,8 @@ class Multiwoz21LegacyProcessor(DataProcessor):
 
 
 class SimProcessor(DataProcessor):
-    def __init__(self, dataset_config, data_dir):
-        super(SimProcessor, self).__init__(dataset_config, data_dir)
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(SimProcessor, self).__init__(dataset_config, data_dir, predict_type)
         self.value_list['train'] = dataset_sim.get_value_list(os.path.join(self.data_dir, 'train.json'),
                                                               self.slot_list)
         self.value_list['dev'] = dataset_sim.get_value_list(os.path.join(self.data_dir, 'dev.json'),
@@ -197,8 +208,8 @@ class SimProcessor(DataProcessor):
 
 
 class UnifiedDatasetProcessor(DataProcessor):
-    def __init__(self, dataset_config, data_dir):
-        super(UnifiedDatasetProcessor, self).__init__(dataset_config, data_dir)
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(UnifiedDatasetProcessor, self).__init__(dataset_config, data_dir, predict_type)
         self.value_list['train'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list)
         self.value_list['dev'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list)
         self.value_list['test'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list)
@@ -207,7 +218,7 @@ class UnifiedDatasetProcessor(DataProcessor):
     def prediction_normalization(self, slot, value):
         return dataset_unified.prediction_normalization(self.dataset_name, slot, value)
 
-    def _get_slot_list(self):
+    def _get_slot_list(self, predict_type):
         return dataset_unified.get_slot_list(self.dataset_name)
 
     def get_train_examples(self, args):
@@ -222,10 +233,47 @@ class UnifiedDatasetProcessor(DataProcessor):
         return dataset_unified.create_examples('test', self.dataset_name,
                                                self.class_types, self.slot_list, self.label_maps, **args)
 
-
+
+class SgdProcessor(DataProcessor):
+    def __init__(self, dataset_config, data_dir, predict_type='train'):
+        super(SgdProcessor, self).__init__(dataset_config, data_dir, predict_type)
+        self.value_list['train'] = dataset_sgd.get_value_list(os.path.join(self.data_dir, 'train'), self.slot_list)
+        self.value_list['dev'] = dataset_sgd.get_value_list(os.path.join(self.data_dir, 'dev'), self.slot_list)
+        self.value_list['test'] = dataset_sgd.get_value_list(os.path.join(self.data_dir, 'test'), self.slot_list)
+        self._add_dummy_value_to_value_list()
+
+    def prediction_normalization(self, slot, value):
+        return dataset_sgd.prediction_normalization(slot, value)
+
+    def _get_slot_list(self, predict_type='train'):
+        data_dir = "/gpfs/project/heckmi/data/dstc8-schema-guided-dialogue"  # TODO
+        return dataset_sgd.get_slot_list(os.path.join(data_dir, predict_type, 'schema.json'))
+
+    def _get_noncategorical(self, predict_type='train'):
+        data_dir = "/gpfs/project/heckmi/data/dstc8-schema-guided-dialogue"  # TODO
+        return dataset_sgd.get_noncategorical(os.path.join(data_dir, predict_type, 'schema.json'))
+
+    def _get_boolean(self, predict_type='train'):
+        data_dir = "/gpfs/project/heckmi/data/dstc8-schema-guided-dialogue"  # TODO
+        return dataset_sgd.get_boolean(os.path.join(data_dir, predict_type, 'schema.json'))
+
+    def get_train_examples(self, args):
+        return dataset_sgd.create_examples(os.path.join(self.data_dir, 'train'),
+                                           'train', self.class_types, self.slot_list, self.label_maps, **args)
+
+    def get_dev_examples(self, args):
+        return dataset_sgd.create_examples(os.path.join(self.data_dir, 'dev'),
+                                           'dev', self.class_types, self.slot_list, self.label_maps, **args)
+
+    def get_test_examples(self, args):
+        return dataset_sgd.create_examples(os.path.join(self.data_dir, 'test'),
+                                           'test', self.class_types, self.slot_list, self.label_maps, **args)
+
+
 PROCESSORS = {"woz2": Woz2Processor,
               "sim-m": SimProcessor,
               "sim-r": SimProcessor,
               "multiwoz21": Multiwoz21Processor,
               "multiwoz21_legacy": Multiwoz21LegacyProcessor,
-              "unified": UnifiedDatasetProcessor}
+              "unified": UnifiedDatasetProcessor,
+              "sgd": SgdProcessor}
diff --git a/dataset_config/sgd.json b/dataset_config/sgd.json
new file mode 100644
index 0000000000000000000000000000000000000000..88a99677e8d9b722b8efd5d7166675208a38bf5a
--- /dev/null
+++ b/dataset_config/sgd.json
@@ -0,0 +1,167 @@
+{
+  "class_types": [
+    "none",
+    "dontcare",
+    "copy_value",
+    "true",
+    "false",
+    "refer",
+    "inform"
+  ],
+  "label_maps": {
+    "inexpensive": [
+      "cheap",
+      "lower price",
+      "lower range",
+      "cheaply",
+      "cheaper",
+      "cheapest",
+      "very affordable",
+      "low cost",
+      "low priced",
+      "low-cost",
+      "budget",
+      "bargain priced"
+    ],
+    "moderate": [
+      "moderately",
+      "reasonable",
+      "reasonably",
+      "affordable",
+      "afforadable",
+      "mid range",
+      "mid-range",
+      "priced moderately",
+      "decently priced",
+      "mid price",
+      "mid-price",
+      "mid priced",
+      "mid-priced",
+      "middle price",
+      "medium price",
+      "medium priced",
+      "not too expensive",
+      "not too cheap",
+      "not too costly",
+      "not very costly",
+      "economical",
+      "intermediate",
+      "average"
+    ],
+    "expensive": [
+      "high end",
+      "high-end",
+      "high class",
+      "high-class",
+      "high scale",
+      "high-scale",
+      "high price",
+      "high priced",
+      "higher price",
+      "above average",
+      "fancy",
+      "upscale",
+      "expensively",
+      "luxury",
+      "pricey",
+      "costly"
+    ],
+    "very expensive": [
+      "very fancy",
+      "lavish",
+      "extravagant"
+    ],
+    "0": [
+      "zero"
+    ],
+    "1": [
+      "one"
+    ],
+    "2": [
+      "two"
+    ],
+    "3": [
+      "three"
+    ],
+    "4": [
+      "four"
+    ],
+    "5": [
+      "five"
+    ],
+    "6": [
+      "six"
+    ],
+    "7": [
+      "seven"
+    ],
+    "8": [
+      "eight"
+    ],
+    "9": [
+      "nine"
+    ],
+    "10": [
+      "ten"
+    ],
+    "kitchen speaker": [
+      "kitchen"
+    ],
+    "bedroom speaker": [
+      "bedroom"
+    ],
+    "music": [
+      "concert"
+    ],
+    "sports": [
+      "match",
+      "matches",
+      "game",
+      "games"
+    ],
+    "standard": [
+      "medium-sized",
+      "intermediate"
+    ],
+    "compact": [
+      "small"
+    ],
+    "tv": [
+      "television",
+      "display"
+    ],
+    "full-size": [
+      "large",
+      "spacious"
+    ],
+    "park": [
+      "gardens",
+      "garden"
+    ],
+    "nature preserve": [
+      "natural spot",
+      "wildlife spot"
+    ],
+    "historical landmark": [
+      "historical spot"
+    ],
+    "tourist attraction": [
+      "place of interest"
+    ],
+    "theme park": [
+      "amusement park"
+    ],
+    "sports venue": [
+      "playground"
+    ],
+    "place of worship": [
+      "religious spot"
+    ],
+    "shopping area": [
+      "mall"
+    ],
+    "performing arts venue": [
+      "performance venue"
+    ]
+  }
+}
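With the processor registration and the config above, SGD becomes selectable like any other task. A minimal usage sketch of the new predict_type plumbing (the data path below is a placeholder, and get_test_examples is shown with default arguments only):

    from data_processors import PROCESSORS

    # For SGD, slot_list/noncategorical/boolean are derived from the split's
    # schema.json, so evaluation needs the split name at construction time.
    processor = PROCESSORS["sgd"]("dataset_config/sgd.json",
                                  "/path/to/dstc8-schema-guided-dialogue",
                                  predict_type="test")
    examples = processor.get_test_examples({})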
diff --git a/dataset_config/sim-m.json b/dataset_config/sim-m.json
index c11444504aadd5d5517e43cac22c69211566b674..fd9f553c61aa81a63582de26d4d17416458b228f 100644
--- a/dataset_config/sim-m.json
+++ b/dataset_config/sim-m.json
@@ -14,7 +14,5 @@
   },
   "noncategorical": [
     "movie"
-  ],
-  "boolean": [],
-  "label_maps": {}
+  ]
 }
diff --git a/dataset_config/sim-r.json b/dataset_config/sim-r.json
index d7400e51382aef920b3b6ec44ff8da31b37ee6f5..1e41a9692e99197efd55bd6613f79ec615b68a35 100644
--- a/dataset_config/sim-r.json
+++ b/dataset_config/sim-r.json
@@ -18,7 +18,5 @@
   },
   "noncategorical": [
     "restaurant_name"
-  ],
-  "boolean": [],
-  "label_maps": {}
+  ]
 }
diff --git a/dataset_config/unified_multiwoz21.json b/dataset_config/unified_multiwoz21.json
index 77140bcee127fbb3d3e43a0c078ffd066af59728..3267be101f41407ca05e1dfbd71bf752c0b9c3f0 100644
--- a/dataset_config/unified_multiwoz21.json
+++ b/dataset_config/unified_multiwoz21.json
@@ -10,7 +10,6 @@
     "inform",
     "request"
   ],
-  "slots": [],
   "noncategorical": [
     "taxi-leaveAt",
     "taxi-destination",
diff --git a/dataset_config/unified_sgd.json b/dataset_config/unified_sgd.json
new file mode 100644
index 0000000000000000000000000000000000000000..1747d95ff6826b2eb8e39c13300e3b6aa7643fe5
--- /dev/null
+++ b/dataset_config/unified_sgd.json
@@ -0,0 +1,169 @@
+{
+  "dataset_name": "sgd",
+  "class_types": [
+    "none",
+    "dontcare",
+    "copy_value",
+    "true",
+    "false",
+    "refer",
+    "inform",
+    "request"
+  ],
+  "label_maps": {
+    "inexpensive": [
+      "cheap",
+      "lower price",
+      "lower range",
+      "cheaply",
+      "cheaper",
+      "cheapest",
+      "very affordable",
+      "low cost",
+      "low priced",
+      "low-cost",
+      "budget",
+      "bargain priced"
+    ],
+    "moderate": [
+      "moderately",
+      "reasonable",
+      "reasonably",
+      "affordable",
+      "afforadable",
+      "mid range",
+      "mid-range",
+      "priced moderately",
+      "decently priced",
+      "mid price",
+      "mid-price",
+      "mid priced",
+      "mid-priced",
+      "middle price",
+      "medium price",
+      "medium priced",
+      "not too expensive",
+      "not too cheap",
+      "not too costly",
+      "not very costly",
+      "economical",
+      "intermediate",
+      "average"
+    ],
+    "expensive": [
+      "high end",
+      "high-end",
+      "high class",
+      "high-class",
+      "high scale",
+      "high-scale",
+      "high price",
+      "high priced",
+      "higher price",
+      "above average",
+      "fancy",
+      "upscale",
+      "expensively",
+      "luxury",
+      "pricey",
+      "costly"
+    ],
+    "very expensive": [
+      "very fancy",
+      "lavish",
+      "extravagant"
+    ],
+    "0": [
+      "zero"
+    ],
+    "1": [
+      "one"
+    ],
+    "2": [
+      "two"
+    ],
+    "3": [
+      "three"
+    ],
+    "4": [
+      "four"
+    ],
+    "5": [
+      "five"
+    ],
+    "6": [
+      "six"
+    ],
+    "7": [
+      "seven"
+    ],
+    "8": [
+      "eight"
+    ],
+    "9": [
+      "nine"
+    ],
+    "10": [
+      "ten"
+    ],
+    "kitchen speaker": [
+      "kitchen"
+    ],
+    "bedroom speaker": [
+      "bedroom"
+    ],
+    "music": [
+      "concert"
+    ],
+    "sports": [
+      "match",
+      "matches",
+      "game",
+      "games"
+    ],
+    "standard": [
+      "medium-sized",
+      "intermediate"
+    ],
+    "compact": [
+      "small"
+    ],
+    "tv": [
+      "television",
+      "display"
+    ],
+    "full-size": [
+      "large",
+      "spacious"
+    ],
+    "park": [
+      "gardens",
+      "garden"
+    ],
+    "nature preserve": [
+      "natural spot",
+      "wildlife spot"
+    ],
+    "historical landmark": [
+      "historical spot"
+    ],
+    "tourist attraction": [
+      "place of interest"
+    ],
+    "theme park": [
+      "amusement park"
+    ],
+    "sports venue": [
+      "playground"
+    ],
+    "place of worship": [
+      "religious spot"
+    ],
+    "shopping area": [
+      "mall"
+    ],
+    "performing arts venue": [
+      "performance venue"
+    ]
+  }
+}
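Both SGD configs intentionally omit the "slots", "noncategorical" and "boolean" keys, which now signals "derive these from the schema at runtime". A quick sketch of how such a config is consumed (see DataProcessor.__init__ above; the asserted value is one of the entries listed in the file):

    import json

    with open("dataset_config/sgd.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    assert "cheap" in cfg["label_maps"]["inexpensive"]
    assert "slots" not in cfg  # triggers _get_slot_list(predict_type) in the processor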
diff --git a/dataset_sgd.py b/dataset_sgd.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea70b598b03648a7b9f72181d6136d0d834ea662
--- /dev/null
+++ b/dataset_sgd.py
@@ -0,0 +1,587 @@
+# coding=utf-8
+#
+# Copyright 2020-2022 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+import os
+import glob
+from tqdm import tqdm
+
+from utils_dst import (DSTExample, convert_to_unicode)
+
+
+# TODO: check what's actually needed
+def prediction_normalization(slot, value):
+    #def _normalize_value(text):
+    #    text = re.sub(" ?' ?s", "s", text)
+    #    return text
+
+    #value = _normalize_value(value)
+
+    return value
+
+
+def get_slot_list(input_file):
+    slot_list = {}
+    with open(input_file, "r", encoding='utf-8') as reader:
+        input_data = json.load(reader)
+    for service in input_data:
+        for slot in service['slots']:
+            s = "%s-%s" % (service['service_name'], slot['name'])
+            slot_list[s] = slot['description'].lower()
+    return slot_list
+
+
+def get_noncategorical(input_file):
+    noncategorical = []
+    with open(input_file, "r", encoding='utf-8') as reader:
+        input_data = json.load(reader)
+    for service in input_data:
+        for slot in service['slots']:
+            s = "%s-%s" % (service['service_name'], slot['name'])
+            if not slot['is_categorical']:
+                noncategorical.append(s)
+    return noncategorical
+
+
+def get_boolean(input_file):
+    boolean = []
+    with open(input_file, "r", encoding='utf-8') as reader:
+        input_data = json.load(reader)
+    for service in input_data:
+        for slot in service['slots']:
+            s = "%s-%s" % (service['service_name'], slot['name'])
+            if len(slot['possible_values']) == 2 and \
+               "True" in slot['possible_values'] and \
+               "False" in slot['possible_values']:
+                boolean.append(s)
+    return boolean
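+
+# A minimal, hypothetical sketch of the three schema readers above
+# (toy schema invented for illustration):
+#
+#     schema = [{"service_name": "Hotels_1",
+#                "slots": [{"name": "price_range", "description": "Price range",
+#                           "is_categorical": True, "possible_values": ["cheap", "moderate"]},
+#                          {"name": "has_wifi", "description": "Whether wifi is offered",
+#                           "is_categorical": True, "possible_values": ["True", "False"]},
+#                          {"name": "hotel_name", "description": "Name of the hotel",
+#                           "is_categorical": False, "possible_values": []}]}]
+#
+#     get_slot_list(schema_file)       -> {"Hotels_1-price_range": "price range", ...}
+#     get_noncategorical(schema_file)  -> ["Hotels_1-hotel_name"]
+#     get_boolean(schema_file)         -> ["Hotels_1-has_wifi"]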
+
+
+# This should only contain label normalizations, no label mappings.
+def normalize_label(slot, value_label):
+    if isinstance(value_label, list):
+        if len(value_label) > 1:
+            value_label = value_label[0]  # TODO: Workaround. Note that Multiwoz 2.2 supports variants directly in the labels.
+        elif len(value_label) == 1:
+            value_label = value_label[0]
+        elif len(value_label) == 0:
+            value_label = ""
+
+    # Normalization of capitalization
+    if isinstance(value_label, str):
+        value_label = value_label.lower().strip()
+
+    # Normalization of empty slots  # TODO: needed?
+    if value_label == '':
+        return "none"
+
+    # Normalization of 'dontcare'
+    if value_label == 'dont care':
+        return "dontcare"
+
+    # Map to boolean slots
+
+    return value_label
+
+
+def get_token_pos(tok_list, value_label):
+    find_pos = []
+    found = False
+    label_list = [item for item in map(str.strip, re.split(r"(\W+)", value_label)) if len(item) > 0]
+    len_label = len(label_list)
+    for i in range(len(tok_list) + 1 - len_label):
+        if tok_list[i:i + len_label] == label_list:
+            find_pos.append((i, i + len_label))  # start, exclusive_end
+            found = True
+    return found, find_pos
+
+
+def check_label_existence(value_label, usr_utt_tok, label_maps={}):
+    in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label)
+    # If no hit even though there should be one, check for value label variants
+    if not in_usr and value_label in label_maps:
+        for value_label_variant in label_maps[value_label]:
+            in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label_variant)
+            if in_usr:
+                break
+    return in_usr, usr_pos
+
+
+def check_slot_referral(value_label, slot, seen_slots, label_maps={}):
+    referred_slot = 'none'
+    if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking':
+        return referred_slot
+    for s in seen_slots:
+        # Avoid matches for slots that share values with different meaning.
+        # hotel-internet and -parking are handled separately as Boolean slots.
+        if s == 'hotel-stars' or s == 'hotel-internet' or s == 'hotel-parking':
+            continue
+        if re.match("(hotel|restaurant)-book_people", s) and slot == 'hotel-book_stay':
+            continue
+        if re.match("(hotel|restaurant)-book_people", slot) and s == 'hotel-book_stay':
+            continue
+        if slot != s and (slot not in seen_slots or seen_slots[slot] != value_label):
+            if seen_slots[s] == value_label:
+                referred_slot = s
+                break
+            elif value_label in label_maps:
+                for value_label_variant in label_maps[value_label]:
+                    if seen_slots[s] == value_label_variant:
+                        referred_slot = s
+                        break
+    return referred_slot
+
+
+def is_in_list(tok, value):
+    found = False
+    tok_list = [item for item in map(str.strip, re.split(r"(\W+)", tok)) if len(item) > 0]
+    value_list = [item for item in map(str.strip, re.split(r"(\W+)", value)) if len(item) > 0]
+    tok_len = len(tok_list)
+    value_len = len(value_list)
+    for i in range(tok_len + 1 - value_len):
+        if tok_list[i:i + value_len] == value_list:
+            found = True
+            break
+    return found
+
+
+def delex_utt(utt, values, unk_token="[UNK]"):
+    utt_norm = tokenize(utt)
+    for s, vals in values.items():
+        for v in vals:
+            if v != 'none':
+                v_norm = tokenize(v)
+                v_len = len(v_norm)
+                for i in range(len(utt_norm) + 1 - v_len):
+                    if utt_norm[i:i + v_len] == v_norm:
+                        utt_norm[i:i + v_len] = [unk_token] * v_len
+    return utt_norm
+
+
+# Fuzzy matching to label informed slot values
+def check_slot_inform(value_label, inform_label, label_maps={}):
+    result = False
+    informed_value = 'none'
+    vl = ' '.join(tokenize(value_label))
+    for il in inform_label:
+        if vl == il:
+            result = True
+        elif is_in_list(il, vl):
+            result = True
+        elif is_in_list(vl, il):
+            result = True
+        elif il in label_maps:
+            for il_variant in label_maps[il]:
+                if vl == il_variant:
+                    result = True
+                    break
+                elif is_in_list(il_variant, vl):
+                    result = True
+                    break
+                elif is_in_list(vl, il_variant):
+                    result = True
+                    break
+        elif vl in label_maps:
+            for value_label_variant in label_maps[vl]:
+                if value_label_variant == il:
+                    result = True
+                    break
+                elif is_in_list(il, value_label_variant):
+                    result = True
+                    break
+                elif is_in_list(value_label_variant, il):
+                    result = True
+                    break
+        if result:
+            informed_value = il
+            break
+    return result, informed_value
+
+
+def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, seen_slots, slot_last_occurrence, label_maps={}):
+    usr_utt_tok_label = [0 for _ in usr_utt_tok]
+    informed_value = 'none'
+    referred_slot = 'none'
+    if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false':
+        class_type = value_label
+    else:
+        in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok, label_maps)
+        is_informed, informed_value = check_slot_inform(value_label, inform_label, label_maps)
+        if in_usr:
+            class_type = 'copy_value'
+            if slot_last_occurrence:
+                (s, e) = usr_pos[-1]
+                for i in range(s, e):
+                    usr_utt_tok_label[i] = 1
+            else:
+                for (s, e) in usr_pos:
+                    for i in range(s, e):
+                        usr_utt_tok_label[i] = 1
+        elif is_informed:
+            class_type = 'inform'
+        else:
+            referred_slot = check_slot_referral(value_label, slot, seen_slots, label_maps)
+            if referred_slot != 'none':
+                class_type = 'refer'
+            else:
+                class_type = 'unpointable'
+    return informed_value, referred_slot, usr_utt_tok_label, class_type
+
+
+# Requestable slots, general acts and domain indicator slots
+def is_request(slot, user_act, turn_domains):
+    if slot in user_act:
+        if isinstance(user_act[slot], list):
+            for act in user_act[slot]:
+                if act['intent'] in ['REQUEST', 'GOODBYE', 'THANK_YOU']:
+                    return True
+        elif user_act[slot]['intent'] in ['REQUEST', 'GOODBYE', 'THANK_YOU']:
+            return True
+    do, sl = slot.split('-')
+    if sl == 'none' and do in turn_domains:
+        return True
+    return False
+
+
+def tokenize(utt):
+    utt_lower = convert_to_unicode(utt).lower()
+    utt_tok = utt_to_token(utt_lower)
+    return utt_tok
+
+
+def utt_to_token(utt):
+    return [tok for tok in map(lambda x: re.sub(" ", "", x), re.split(r"(\W+)", utt)) if len(tok) > 0]
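+
+# A minimal, hypothetical usage sketch of the span-labeling helpers above
+# (the utterance and values are invented):
+#
+#     toks = tokenize("I need a cheap hotel for two nights")
+#     # -> ['i', 'need', 'a', 'cheap', 'hotel', 'for', 'two', 'nights']
+#     found, pos = get_token_pos(toks, "cheap")
+#     # -> found is True, pos == [(3, 4)]  (start, exclusive end)
+#     check_label_existence("inexpensive", toks, {"inexpensive": ["cheap"]})
+#     # -> also matches, via the label_maps variant list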
+
+
+def create_examples(input_file, set_type, class_types, slot_list,
+                    label_maps={},
+                    no_label_value_repetitions=False,
+                    swap_utterances=False,
+                    delexicalize_sys_utts=False,
+                    unk_token="[UNK]",
+                    boolean_slots=True,
+                    analyze=False):
+    """Read a DST json file into a list of DSTExample."""
+
+    input_files = glob.glob(os.path.join(input_file, 'dialogues_*.json'))
+
+    examples = []
+    for input_file in input_files:
+        with open(input_file, "r", encoding='utf-8') as reader:
+            input_data = json.load(reader)
+
+        for d_itr, dialog in enumerate(tqdm(input_data)):
+            dialog_id = dialog['dialogue_id']
+            domains = dialog['services']
+            utterances = dialog['turns']
+
+            # Collects all slot changes throughout the dialog
+            cumulative_labels = {slot: 'none' for slot in slot_list}
+
+            # First system utterance is empty, since the dialog starts with user input
+            utt_tok_list = [[]]
+            mod_slots_list = [{}]
+            inform_dict_list = [{}]
+            user_act_dict_list = [{}]
+            mod_domains_list = [{}]
+
+            # Collect all utterances and their metadata
+            usr_sys_switch = True
+            turn_itr = 0
+            for utt in utterances:
+                # Assert that system and user utterances alternate
+                is_sys_utt = utt['speaker'] == "SYSTEM"
+                if usr_sys_switch == is_sys_utt:
+                    print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id))
+                    break
+                usr_sys_switch = is_sys_utt
+
+                if is_sys_utt:
+                    turn_itr += 1
+
+                # Extract dialog_act information for sys and usr utts.
+                inform_dict = {}
+                user_act_dict = {}
+                modified_slots = {}
+                modified_domains = []
+                for frame in utt['frames']:
+                    actions = frame['actions']
+                    service = frame['service']
+                    spans = frame['slots']
+                    modified_domains.append(service)  # Remember domains
+                    for action in actions:
+                        act = action['act']
+                        slot = action['slot']
+                        cs = "%s-%s" % (service, slot)
+                        values = action['values']  # this is a list
+                        #canonical_values = action['canonical_values']
+                        values = normalize_label(cs, values)
+                        if is_sys_utt and act in ['INFORM', 'CONFIRM', 'OFFER']:
+                            if cs not in inform_dict:
+                                inform_dict[cs] = []
+                            inform_dict[cs].append(values)
+                        elif not is_sys_utt:
+                            if cs not in user_act_dict:
+                                user_act_dict[cs] = []
+                            user_act_dict[cs].append({'domain': service,
+                                                      'intent': act,
+                                                      'slot': slot,
+                                                      'value': values})
+                            if act in ['INFORM']:
+                                modified_slots[cs] = values
+                    if not is_sys_utt:
+                        state = frame['state']
+                        #active_intent = state['active_intent']
+                        #requested_slots = state['requested_slots']
+                        for slot in state['slot_values']:
+                            cs = "%s-%s" % (service, slot)
+                            values = frame['state']['slot_values'][slot]  # this is a list
+                            values = normalize_label(cs, values)
+                            # Remember modified slots and entire dialog state
+                            if cs in slot_list and cumulative_labels[cs] != values:
+                                modified_slots[cs] = values
+                                cumulative_labels[cs] = values
+                # INFO: Since the model has no mechanism to predict
+                # one among several informed value candidates, we
+                # keep only one informed value. For fairness, we
+                # apply a global rule:
+                for e in inform_dict:
+                    # ... Option 1: Always keep first informed value
+                    inform_dict[e] = list([inform_dict[e][0]])
+                    # ... Option 2: Always keep last informed value
+                    #inform_dict[e] = list([inform_dict[e][-1]])
+
+                utterance = utt['utterance']
+
+                # Delexicalize sys utterance
+                if delexicalize_sys_utts and is_sys_utt:
+                    utt_tok_list.append(delex_utt(utterance, inform_dict, unk_token))  # normalizes utterances
+                else:
+                    utt_tok_list.append(tokenize(utterance))  # normalizes utterances
+
+                inform_dict_list.append(inform_dict.copy())
+                user_act_dict_list.append(user_act_dict.copy())
+                mod_slots_list.append(modified_slots.copy())
+                modified_domains.sort()
+                mod_domains_list.append(modified_domains)
+
+            # Form proper (usr, sys) turns
+            turn_itr = 0
+            diag_seen_slots_dict = {}
+            diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
+            diag_state = {slot: 'none' for slot in slot_list}
+            sys_utt_tok = []
+            usr_utt_tok = []
+            for i in range(1, len(utt_tok_list) - 1, 2):
+                sys_utt_tok_label_dict = {}
+                usr_utt_tok_label_dict = {}
+                value_dict = {}
+                inform_dict = {}
+                inform_slot_dict = {}
+                referral_dict = {}
+                class_type_dict = {}
+                updated_slots = {slot: 0 for slot in slot_list}
+
+                # Collect turn data
+                sys_utt_tok = utt_tok_list[i - 1]
+                usr_utt_tok = utt_tok_list[i]
+                turn_slots = mod_slots_list[i]
+                inform_mem = {}
+                for inform_itr in range(0, i, 2):
+                    inform_mem.update(inform_dict_list[inform_itr])
+                #inform_mem = inform_dict_list[i - 1]
+                user_act = user_act_dict_list[i]
+                turn_domains = mod_domains_list[i]
+
+                guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr))
+
+                if analyze:
+                    print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok)))
+                    print("%15s %2s [" % (dialog_id, turn_itr), end='')
+
+                new_diag_state = diag_state.copy()
+
+                for slot in slot_list:
+                    value_label = 'none'
+                    if slot in turn_slots:
+                        value_label = turn_slots[slot]
+                        # We keep the original labels so as to not
+                        # overlook unpointable values, as well as to not
+                        # modify any of the original labels for test sets,
+                        # since this would make comparison difficult.
+                        value_dict[slot] = value_label
+                    elif not no_label_value_repetitions and slot in diag_seen_slots_dict:
+                        value_label = diag_seen_slots_value_dict[slot]
+
+                    # Get dialog act annotations
+                    inform_label = list(['none'])
+                    inform_slot_dict[slot] = 0
+                    booking_slot = 'booking-' + slot.split('-')[1]
+                    if slot in inform_mem:
+                        inform_label = inform_mem[slot]
+                        inform_slot_dict[slot] = 1
+                    elif booking_slot in inform_mem:
+                        inform_label = inform_mem[booking_slot]
+                        inform_slot_dict[slot] = 1
+
+                    (informed_value,
+                     referred_slot,
+                     usr_utt_tok_label,
+                     class_type) = get_turn_label(value_label,
+                                                  inform_label,
+                                                  sys_utt_tok,
+                                                  usr_utt_tok,
+                                                  slot,
+                                                  diag_seen_slots_value_dict,
+                                                  slot_last_occurrence=True,
+                                                  label_maps=label_maps)
+
+                    inform_dict[slot] = informed_value
+
+                    # Requestable slots, domain indicator slots and general slots
+                    # should have class_type 'request', if they ought to be predicted.
+                    # Give other class_types preference.
+                    if 'request' in class_types:
+                        if class_type in ['none', 'unpointable'] and is_request(slot, user_act, turn_domains):
+                            class_type = 'request'
+
+                    # Generally don't use span prediction on sys utterance (but inform prediction instead).
+                    sys_utt_tok_label = [0 for _ in sys_utt_tok]
+
+                    # Determine what to do with value repetitions.
+                    # If the value is unique among the seen slots, then tag it, otherwise not,
+                    # since correct slot assignment cannot be guaranteed anymore.
+                    if not no_label_value_repetitions and slot in diag_seen_slots_dict:
+                        if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1:
+                            class_type = 'none'
+                            usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
+
+                    sys_utt_tok_label_dict[slot] = sys_utt_tok_label
+                    usr_utt_tok_label_dict[slot] = usr_utt_tok_label
+
+                    if diag_seen_slots_value_dict[slot] != value_label:
+                        updated_slots[slot] = 1
+
+                    # For now, we map all occurrences of unpointable slot values
+                    # to none. However, since the labels will still suggest
+                    # a presence of unpointable slot values, the task of the
+                    # DST is still to find those values. It is just not
+                    # possible to do that via span prediction on the current input.
+                    if class_type == 'unpointable':
+                        class_type_dict[slot] = 'none'
+                        referral_dict[slot] = 'none'
+                        if analyze:
+                            if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]:
+                                print("(%s): %s, " % (slot, value_label), end='')
+                    elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
+                        # If the slot has been seen before and its class type did not change, label this slot as not present,
+                        # assuming that the slot has not actually been mentioned in this turn.
+                        # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
+                        # this must mean there is evidence in the original labels, therefore consider
+                        # them as mentioned again.
+                        class_type_dict[slot] = 'none'
+                        referral_dict[slot] = 'none'
+                    else:
+                        class_type_dict[slot] = class_type
+                        referral_dict[slot] = referred_slot
+                    # Remember that this slot was mentioned during this dialog already.
+                    if class_type != 'none':
+                        diag_seen_slots_dict[slot] = class_type
+                        diag_seen_slots_value_dict[slot] = value_label
+                        new_diag_state[slot] = class_type
+                        # Unpointable is not a valid class, therefore replace with
+                        # some valid class for now...
+                        if class_type == 'unpointable':
+                            new_diag_state[slot] = 'copy_value'
+
+                if analyze:
+                    print("]")
+
+                if not swap_utterances:
+                    txt_a = usr_utt_tok
+                    txt_b = sys_utt_tok
+                    txt_a_lbl = usr_utt_tok_label_dict
+                    txt_b_lbl = sys_utt_tok_label_dict
+                else:
+                    txt_a = sys_utt_tok
+                    txt_b = usr_utt_tok
+                    txt_a_lbl = sys_utt_tok_label_dict
+                    txt_b_lbl = usr_utt_tok_label_dict
+                examples.append(DSTExample(
+                    guid=guid,
+                    text_a=txt_a,
+                    text_b=txt_b,
+                    text_a_label=txt_a_lbl,
+                    text_b_label=txt_b_lbl,
+                    values=diag_seen_slots_value_dict.copy(),
+                    inform_label=inform_dict,
+                    inform_slot_label=inform_slot_dict,
+                    refer_label=referral_dict,
+                    diag_state=diag_state,
+                    slot_update=updated_slots,
+                    class_label=class_type_dict))
+
+                # Update some variables.
+                diag_state = new_diag_state.copy()
+
+                turn_itr += 1
+
+            if analyze:
+                print("----------------------------------------------------------------------")
+
+    return examples
+
+
+def get_value_list(input_file, slot_list, boolean_slots=True):
+    def add_to_list(value_label, cs, value_list, slot_list, exclude):
+        if cs in slot_list and value_label not in exclude:
+            if value_label not in value_list[cs]:
+                value_list[cs][value_label] = 0
+            value_list[cs][value_label] += 1
+
+    exclude = ['none', 'dontcare']
+    if not boolean_slots:
+        exclude += ['true', 'false']
+    value_list = {slot: {} for slot in slot_list}
+    input_files = glob.glob(os.path.join(input_file, 'dialogues_*.json'))
+    for input_file in input_files:
+        with open(input_file, "r", encoding='utf-8') as reader:
+            input_data = json.load(reader)
+        for dialog in input_data:
+            usr_sys_switch = True
+            for utt in dialog['turns']:
+                is_sys_utt = utt['speaker'] == "SYSTEM"
+                usr_sys_switch = is_sys_utt
+                inform_dict = {}
+                user_act_dict = {}
+                for frame in utt['frames']:
+                    for action in frame['actions']:
+                        cs = "%s-%s" % (frame['service'], action['slot'])
+                        value_label = normalize_label(cs, action['values'])
+                        if is_sys_utt and action['act'] in ['INFORM', 'CONFIRM', 'OFFER']:
+                            add_to_list(value_label, cs, value_list, slot_list, exclude)
+                        elif not is_sys_utt and action['act'] in ['INFORM']:
+                            add_to_list(value_label, cs, value_list, slot_list, exclude)
+                    if not is_sys_utt:
+                        for slot in frame['state']['slot_values']:
+                            cs = "%s-%s" % (frame['service'], slot)
+                            value_label = normalize_label(cs, frame['state']['slot_values'][slot])
+                            add_to_list(value_label, cs, value_list, slot_list, exclude)
+    return value_list
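For reference, get_value_list aggregates per-slot value counts from system INFORM/CONFIRM/OFFER acts and from user states. A hypothetical call (paths, slot name and counts are invented):

    from dataset_sgd import get_slot_list, get_value_list

    slot_list = get_slot_list("/path/to/dstc8-schema-guided-dialogue/train/schema.json")
    value_list = get_value_list("/path/to/dstc8-schema-guided-dialogue/train", slot_list)
    # e.g. value_list["Restaurants_1-price_range"] == {"moderate": 412, "inexpensive": 230, ...}
    # 'none' and 'dontcare' are always excluded; 'true'/'false' only when boolean_slots=False.
    # A dummy value is appended later by DataProcessor._add_dummy_value_to_value_list().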
diff --git a/dst_train.py b/dst_train.py
index b01ba7bd773b6dc718cefe968be1d4539e357059..e485a965124e49b168c7bac610171a7c16a535c7 100644
--- a/dst_train.py
+++ b/dst_train.py
@@ -41,7 +41,7 @@ from utils_run import (set_seed, to_device, from_device,
 logger = logging.getLogger(__name__)
 
 
-def train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer, processor):
+def train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer, processor, continue_from_global_step=0):
     """ Train the model """
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()
@@ -103,6 +103,9 @@ def train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer,
     logger.info("  Total optimization steps = %d", t_total)
     logger.info("  Warmup steps = %d", num_warmup_steps)
 
+    if continue_from_global_step > 0:
+        logger.info("Fast forwarding to global step %d to resume training from latest checkpoint...", continue_from_global_step)
+
     global_step = 0
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
@@ -111,11 +114,20 @@ def train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer,
     for e_itr, _ in enumerate(train_iterator):
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
 
-        train_dataset.dropout_input()
-        train_dataset.encode_slots()
-        train_dataset.encode_slot_values()
+        if global_step >= continue_from_global_step:
+            train_dataset.dropout_input()
+            train_dataset.encode_slots()
+            train_dataset.encode_slot_values()
 
         for step, batch in enumerate(epoch_iterator):
+            # If training is continued from a checkpoint, fast forward
+            # to the state of that checkpoint.
+            if global_step < continue_from_global_step:
+                if (step + 1) % args.gradient_accumulation_steps == 0:
+                    scheduler.step()  # Update learning rate schedule
+                    global_step += 1
+                continue
+
             model.train()
 
             # Add tokenized or encoded slot descriptions and encoded values to batch.
@@ -327,7 +339,7 @@ def eval_metric(args, model, tokenizer, batch, outputs, threshold=0.0, dae=False
         per_example_tp_loss = per_slot_per_example_tp_loss[slot]
         class_logits = per_slot_class_logits[slot]
         start_logits = per_slot_start_logits[slot]
-        value_logits = per_slot_value_logits[slot]
+        #value_logits = per_slot_value_logits[slot]
         refer_logits = per_slot_refer_logits[slot]
 
         mean = []
diff --git a/modeling_dst.py b/modeling_dst.py
index 72bc8d4d2685737a9d7dd7736da390a42d6e7e12..de78b016cbe7566b5e591b216e9718f0d18d471e 100644
--- a/modeling_dst.py
+++ b/modeling_dst.py
@@ -328,8 +328,9 @@ def TransformerForDST(parent_name):
 
             per_slot_class_logits[slot] = class_logits
             per_slot_token_weights[slot] = token_weights
-            per_slot_value_weights[slot] = value_weights
             per_slot_refer_logits[slot] = refer_logits
+            if self.value_matching_weight > 0.0:
+                per_slot_value_weights[slot] = value_weights
 
             # If there are no labels, don't compute loss
             if class_label_id is not None and token_pos is not None and refer_id is not None:
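The fast-forward branch in dst_train.py above replays only the scheduler, so a resumed run sees the same learning-rate schedule as an uninterrupted one. A self-contained sketch of that idea (toy optimizer and schedule, not the project's actual setup):

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=1.0)
    scheduler = LambdaLR(optimizer, lambda step: 1.0 / (step + 1))
    continue_from_global_step = 3
    for global_step in range(continue_from_global_step):
        scheduler.step()  # advance the schedule only; no forward/backward pass
    # (PyTorch may warn that optimizer.step() was skipped; that is the point here.)
    print(scheduler.get_last_lr())  # LR as if 3 optimization steps had already run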
diff --git a/run_dst.py b/run_dst.py
index 99eacc722602cae18161b46adb841085dbebe773..c9c1310d68ae7a9cf1777a1def924b0688b73c95 100644
--- a/run_dst.py
+++ b/run_dst.py
@@ -217,7 +217,7 @@ def main():
     if task_name not in PROCESSORS:
         raise ValueError("Task not found: %s" % (task_name))
 
-    processor = PROCESSORS[task_name](args.dataset_config, args.data_dir)
+    processor = PROCESSORS[task_name](args.dataset_config, args.data_dir, 'train' if not args.do_eval else args.predict_type)
     slot_list = processor.slot_list
     noncategorical = processor.noncategorical
     class_types = processor.class_types
@@ -300,7 +300,7 @@ def main():
     if args.training_phase in [-1, 0, 1]:
         train_dataset = load_and_cache_examples(args, model, tokenizer, processor, dset="train", evaluate=False)
         dev_dataset = None
-        if args.do_train and args.evaluate_during_training:
+        if args.evaluate_during_training:
             dev_dataset = load_and_cache_examples(args, model, tokenizer, processor, dset=args.predict_type, evaluate=True)
 
     # Step 1: Pretrain attention layer for random sequence tagging.
@@ -319,7 +319,8 @@ def main():
             model = model_class.from_pretrained(proto_checkpoint)
             model.to(args.device)
             train_dataset.update_model(model)
-            dev_dataset.update_model(model)
+            if dev_dataset is not None:
+                dev_dataset.update_model(model)
 
     max_tag_goal = 0.0
     max_tag_thresh = 0.0
     max_dae = True  # default should be true
@@ -356,7 +357,8 @@ def main():
         if args.do_train and args.evaluate_during_training:
             dev_dataset = load_and_cache_examples(args, model, tokenizer, processor, dset=args.predict_type, evaluate=True)
         train_dataset.compute_vectors()
-        dev_dataset.compute_vectors()
+        if dev_dataset is not None:
+            dev_dataset.compute_vectors()
 
         global_step, tr_loss = train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer, processor)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
     else:
@@ -364,13 +366,23 @@ def main():
 
     # Train full model with original training.
     if args.training_phase == -1:
-        if len(checkpoints) == 0:
-            train_dataset.compute_vectors()
+        # If output files already exist, continue training from the latest checkpoint (unless overwrite_output_dir is set)
+        continue_from_global_step = 0
+        if len(checkpoints) > 0:
+            with open(os.path.join(args.output_dir, 'last_checkpoint.txt'), 'r') as f:
+                continue_from_global_step = int((f.readline()).split('-')[-1])
+            checkpoint = os.path.join(args.output_dir, 'checkpoint-%s' % continue_from_global_step)
+            logger.warning(" Resuming training from the latest checkpoint: %s", checkpoint)
+            model = model_class.from_pretrained(checkpoint)
+            model.to(args.device)
+            train_dataset.update_model(model)
+            if dev_dataset is not None:
+                dev_dataset.update_model(model)
+        train_dataset.compute_vectors()
+        if dev_dataset is not None:
             dev_dataset.compute_vectors()
-            global_step, tr_loss = train(args, train_dataset, dev_dataset, None, model, tokenizer, processor)
-            logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
-        else:
-            logger.warning(" Preconditions for training not fulfilled! Skipping.")
+        global_step, tr_loss = train(args, train_dataset, dev_dataset, None, model, tokenizer, processor, continue_from_global_step)
+        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
 
     # Save the trained model and the tokenizer
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
diff --git a/utils_dst.py b/utils_dst.py
index 00728dd0488a6f22dc4e315eb6bbabcc4a0804b4..5d6f4129b5b6e818cfe7b7cef845d57dba6cd58d 100644
--- a/utils_dst.py
+++ b/utils_dst.py
@@ -582,7 +582,7 @@ class TrippyDataset(Dataset):
 
     def __getitem__(self, index):
         result = {}
-        # Static elements. Copy, because they will be modified below.
+        # Static elements. Copy, because they will be modified below. # TODO: make more efficient
        for key, element in self.data.items():
             if isinstance(element, dict):
                 result[key] = {k: v[index].detach().clone() for k, v in element.items()}
@@ -648,16 +648,18 @@ class TrippyDataset(Dataset):
             # TODO: Test subsampling negative samples.
 
             # For attention based value matching
-            result['value_labels'][slot] = torch.zeros((len(self.encoded_slot_values[slot])), dtype=torch.float)
-            result['dropout_value_feat'][slot] = torch.zeros((1, self.model.config.hidden_size), dtype=torch.float)
-            # Only train value matching, if value is extractable
-            if token_is_pointable and pos_value in self.encoded_slot_values[slot]:
-                result['value_labels'][slot][list(self.encoded_slot_values[slot].keys()).index(pos_value)] = 1.0
-            # In case of dropout, forward new representation as target for value matching instead.
-            if self.dropout_value_seq is not None:
-                if result['example_id'].item() in self.dropout_value_seq[slot]:
-                    dropout_value_seq = tuple(self.dropout_value_seq[slot][result['example_id'].item()])
-                    result['dropout_value_feat'][slot] = self.encoded_dropout_slot_values[slot][dropout_value_seq]
+            if self.encoded_slot_values is not None:
+                result['value_labels'][slot] = torch.zeros((len(self.encoded_slot_values[slot])), dtype=torch.float)
+                result['dropout_value_feat'][slot] = torch.zeros((1, self.model.config.hidden_size), dtype=torch.float)
+                # Only train value matching, if value is extractable
+                if token_is_pointable and pos_value in self.encoded_slot_values[slot]:
+                    result['value_labels'][slot][list(self.encoded_slot_values[slot].keys()).index(pos_value)] = 1.0
+                # In case of dropout, forward new representation as target for value matching instead.
+                if self.dropout_value_seq is not None:
+                    if result['example_id'].item() in self.dropout_value_seq[slot]:
+                        dropout_value_seq = tuple(self.dropout_value_seq[slot][result['example_id'].item()])
+                        result['dropout_value_feat'][slot] = self.encoded_dropout_slot_values[slot][dropout_value_seq]
+
         return result
 
     def _encode_text(self, text, input_ids, input_mask, mode="represent", train=False):
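The resume logic in run_dst.py above assumes a last_checkpoint.txt whose single line ends in the step number, e.g. "checkpoint-24000" (the value here is illustrative):

    # Recovering the global step as done in run_dst.py:
    line = "checkpoint-24000"  # hypothetical content of last_checkpoint.txt
    continue_from_global_step = int(line.split('-')[-1])  # -> 24000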