class SgdProcessor(DataProcessor):
    """Data processor for the Schema-Guided Dialogue (SGD) dataset in its native format.

    The slot list is not taken from the dataset config but derived from the
    training schema file (``<data_dir>/train/schema.json``).

    NOTE(review): only the *train* schema is read, so slots of services that
    are declared exclusively in the dev/test schemas are not part of the slot
    list -- confirm that this is intended.
    """

    # Default location of the SGD corpus used to build the slot list. This is
    # the path the original implementation hard-coded; it is kept as the
    # fallback so that existing setups behave exactly as before.
    DEFAULT_DATA_DIR = "/gpfs/project/heckmi/data/dstc8-schema-guided-dialogue"

    # Name of the environment variable that overrides DEFAULT_DATA_DIR, so the
    # processor can be used on machines where the default path does not exist.
    DATA_DIR_ENV_VAR = "SGD_DATA_DIR"

    def _get_slot_list(self):
        """Return all "<service>-<slot>" names found in the training schema."""
        data_dir = os.environ.get(self.DATA_DIR_ENV_VAR, self.DEFAULT_DATA_DIR)
        return dataset_sgd.get_slot_list(os.path.join(data_dir, 'train', 'schema.json'))

    def get_train_examples(self, data_dir, args):
        """Create examples from ``<data_dir>/train``.

        ``args`` is a dict that is unpacked into keyword arguments of
        ``dataset_sgd.create_examples``.
        """
        return dataset_sgd.create_examples(os.path.join(data_dir, 'train'),
                                           'train', self.class_types, self.slot_list, self.label_maps, **args)

    def get_dev_examples(self, data_dir, args):
        """Create examples from ``<data_dir>/dev`` (``args`` as in get_train_examples)."""
        return dataset_sgd.create_examples(os.path.join(data_dir, 'dev'),
                                           'dev', self.class_types, self.slot_list, self.label_maps, **args)

    def get_test_examples(self, data_dir, args):
        """Create examples from ``<data_dir>/test`` (``args`` as in get_train_examples)."""
        return dataset_sgd.create_examples(os.path.join(data_dir, 'test'),
                                           'test', self.class_types, self.slot_list, self.label_maps, **args)
"multiwoz21": Multiwoz21Processor, "multiwoz21_legacy": Multiwoz21LegacyProcessor, "unified": UnifiedDatasetProcessor, + "sgd": SgdProcessor, "aux_task": AuxTaskProcessor} diff --git a/dataset_config/sgd.json b/dataset_config/sgd.json new file mode 100644 index 0000000000000000000000000000000000000000..0987656ada296d96691e59694e790c6bd894c162 --- /dev/null +++ b/dataset_config/sgd.json @@ -0,0 +1,168 @@ +{ + "class_types": [ + "none", + "dontcare", + "copy_value", + "true", + "false", + "refer", + "inform" + ], + "slots": [], + "label_maps": { + "inexpensive": [ + "cheap", + "lower price", + "lower range", + "cheaply", + "cheaper", + "cheapest", + "very affordable", + "low cost", + "low priced", + "low-cost", + "budget", + "bargain priced" + ], + "moderate": [ + "moderately", + "reasonable", + "reasonably", + "affordable", + "afforadable", + "mid range", + "mid-range", + "priced moderately", + "decently priced", + "mid price", + "mid-price", + "mid priced", + "mid-priced", + "middle price", + "medium price", + "medium priced", + "not too expensive", + "not too cheap", + "not too costly", + "not very costly", + "economical", + "intermediate", + "average" + ], + "expensive": [ + "high end", + "high-end", + "high class", + "high-class", + "high scale", + "high-scale", + "high price", + "high priced", + "higher price", + "above average", + "fancy", + "upscale", + "expensively", + "luxury", + "pricey", + "costly" + ], + "very expensive": [ + "very fancy", + "lavish", + "extravagant" + ], + "0": [ + "zero" + ], + "1": [ + "one" + ], + "2": [ + "two" + ], + "3": [ + "three" + ], + "4": [ + "four" + ], + "5": [ + "five" + ], + "6": [ + "six" + ], + "7": [ + "seven" + ], + "8": [ + "eight" + ], + "9": [ + "nine" + ], + "10": [ + "ten" + ], + "kitchen speaker": [ + "kitchen" + ], + "bedroom speaker": [ + "bedroom" + ], + "music": [ + "concert" + ], + "sports": [ + "match", + "matches", + "game", + "games" + ], + "standard": [ + "medium-sized", + "intermediate" + ], + 
"compact": [ + "small" + ], + "tv": [ + "television", + "display" + ], + "full-size": [ + "large", + "spacious" + ], + "park": [ + "gardens", + "garden" + ], + "nature preserve": [ + "natural spot", + "wildlife spot" + ], + "historical landmark": [ + "historical spot" + ], + "tourist attraction": [ + "place of interest" + ], + "theme park": [ + "amusement park" + ], + "sports venue": [ + "playground" + ], + "place of worship": [ + "religious spot" + ], + "shopping area": [ + "mall" + ], + "performing arts venue": [ + "performance venue" + ] + } +} diff --git a/dataset_config/unified_sgd.json b/dataset_config/unified_sgd.json new file mode 100644 index 0000000000000000000000000000000000000000..3482a09a586bb33084f842cfd9c1353ecadce987 --- /dev/null +++ b/dataset_config/unified_sgd.json @@ -0,0 +1,170 @@ +{ + "dataset_name": "sgd", + "class_types": [ + "none", + "dontcare", + "copy_value", + "true", + "false", + "refer", + "inform", + "request" + ], + "slots": [], + "label_maps": { + "inexpensive": [ + "cheap", + "lower price", + "lower range", + "cheaply", + "cheaper", + "cheapest", + "very affordable", + "low cost", + "low priced", + "low-cost", + "budget", + "bargain priced" + ], + "moderate": [ + "moderately", + "reasonable", + "reasonably", + "affordable", + "afforadable", + "mid range", + "mid-range", + "priced moderately", + "decently priced", + "mid price", + "mid-price", + "mid priced", + "mid-priced", + "middle price", + "medium price", + "medium priced", + "not too expensive", + "not too cheap", + "not too costly", + "not very costly", + "economical", + "intermediate", + "average" + ], + "expensive": [ + "high end", + "high-end", + "high class", + "high-class", + "high scale", + "high-scale", + "high price", + "high priced", + "higher price", + "above average", + "fancy", + "upscale", + "expensively", + "luxury", + "pricey", + "costly" + ], + "very expensive": [ + "very fancy", + "lavish", + "extravagant" + ], + "0": [ + "zero" + ], + "1": [ + "one" + ], + 
"2": [ + "two" + ], + "3": [ + "three" + ], + "4": [ + "four" + ], + "5": [ + "five" + ], + "6": [ + "six" + ], + "7": [ + "seven" + ], + "8": [ + "eight" + ], + "9": [ + "nine" + ], + "10": [ + "ten" + ], + "kitchen speaker": [ + "kitchen" + ], + "bedroom speaker": [ + "bedroom" + ], + "music": [ + "concert" + ], + "sports": [ + "match", + "matches", + "game", + "games" + ], + "standard": [ + "medium-sized", + "intermediate" + ], + "compact": [ + "small" + ], + "tv": [ + "television", + "display" + ], + "full-size": [ + "large", + "spacious" + ], + "park": [ + "gardens", + "garden" + ], + "nature preserve": [ + "natural spot", + "wildlife spot" + ], + "historical landmark": [ + "historical spot" + ], + "tourist attraction": [ + "place of interest" + ], + "theme park": [ + "amusement park" + ], + "sports venue": [ + "playground" + ], + "place of worship": [ + "religious spot" + ], + "shopping area": [ + "mall" + ], + "performing arts venue": [ + "performance venue" + ] + } +} diff --git a/dataset_sgd.py b/dataset_sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..1b337d66160b0b242433b2089031349f3fa5e942 --- /dev/null +++ b/dataset_sgd.py @@ -0,0 +1,532 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import glob
import json
import os
import re

from tqdm import tqdm

from utils_dst import (DSTExample, convert_to_unicode)


# Splits a string at runs of non-word characters and, because of the capture
# group, keeps those separator runs in the result. Written as a raw string:
# "\W" inside a normal string literal is an invalid escape sequence.
_NON_WORD_SPLIT = re.compile(r"(\W+)")

# User dialog acts that make a slot count as "requested" in is_request().
_REQUEST_INTENTS = ('REQUEST', 'GOODBYE', 'THANK_YOU')


def _split_and_strip(text):
    """Split text at non-word runs, strip every piece and drop empty pieces."""
    return [item for item in map(str.strip, _NON_WORD_SPLIT.split(text)) if len(item) > 0]


def get_slot_list(input_file):
    """Return all slot names declared in an SGD schema file.

    Args:
        input_file: Path to a ``schema.json`` file, i.e. a JSON list of
            services that each carry a 'service_name' and a list of 'slots'.

    Returns:
        List of "<service_name>-<slot name>" strings in file order.
    """
    with open(input_file, "r", encoding='utf-8') as reader:
        input_data = json.load(reader)
    return ["%s-%s" % (service['service_name'], slot['name'])
            for service in input_data
            for slot in service['slots']]


# This should only contain label normalizations, no label mappings.
def normalize_label(slot, value_label):
    """Normalize a raw value annotation to a single lower-cased label.

    Args:
        slot: Slot name; currently not used by the normalization.
        value_label: A string or a list of value variants.

    Returns:
        The normalized label. Lists are reduced to their first element (an
        empty list becomes the empty label), strings are lower-cased and
        stripped, the empty label maps to "none" and "dont care" to "dontcare".
    """
    if isinstance(value_label, list):
        # TODO: Workaround. Only the first variant of a multi-valued
        # annotation is kept. Note that Multiwoz 2.2 supports variants
        # directly in the labels.
        value_label = value_label[0] if len(value_label) > 0 else ""

    # Normalization of capitalization
    if isinstance(value_label, str):
        value_label = value_label.lower().strip()

    # Normalization of empty slots # TODO: needed?
    if value_label == '':
        return "none"

    # Normalization of 'dontcare'
    if value_label == 'dont care':
        return "dontcare"

    return value_label


def get_token_pos(tok_list, value_label):
    """Find all occurrences of a value label in a token list.

    Args:
        tok_list: List of tokens to search in.
        value_label: Label string; it is split into tokens before matching.

    Returns:
        Tuple (found, positions) where positions is a list of
        (start, exclusive_end) token index pairs.
    """
    label_list = _split_and_strip(value_label)
    len_label = len(label_list)
    # A label without any token would otherwise "match" at every position
    # with a zero-length span, which is never a usable span label.
    if len_label == 0:
        return False, []
    find_pos = [(i, i + len_label)  # start, exclusive_end
                for i in range(len(tok_list) + 1 - len_label)
                if tok_list[i:i + len_label] == label_list]
    return len(find_pos) > 0, find_pos


def check_label_existence(value_label, usr_utt_tok, label_maps=None):
    """Locate a value label, or one of its variants, in the user utterance.

    Args:
        value_label: Normalized value label.
        usr_utt_tok: Tokenized user utterance.
        label_maps: Optional dict mapping a label to a list of variants.

    Returns:
        Tuple (in_usr, usr_pos) as returned by get_token_pos() for the label
        or for the first variant that is found.
    """
    label_maps = {} if label_maps is None else label_maps
    in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label)
    # If no hit even though there should be one, check for value label variants
    if not in_usr and value_label in label_maps:
        for value_label_variant in label_maps[value_label]:
            in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label_variant)
            if in_usr:
                break
    return in_usr, usr_pos


def check_slot_referral(value_label, slot, seen_slots, label_maps=None):
    """Find a previously seen slot whose value the current slot refers to.

    Args:
        value_label: Value label of the current slot.
        slot: Name of the current slot.
        seen_slots: Dict mapping slot names to the values seen so far.
        label_maps: Optional dict mapping a label to a list of variants.

    Returns:
        Name of the referred slot, or 'none' if there is none.
    """
    label_maps = {} if label_maps is None else label_maps
    # NOTE(review): the hotel/restaurant slot names below look like they were
    # carried over from the MultiWOZ reader; confirm whether they are needed
    # for SGD.
    excluded_slots = ('hotel-stars', 'hotel-internet', 'hotel-parking')
    referred_slot = 'none'
    if slot in excluded_slots:
        return referred_slot
    for s in seen_slots:
        # Avoid matches for slots that share values with different meaning.
        # hotel-internet and -parking are handled separately as Boolean slots.
        if s in excluded_slots:
            continue
        if re.match("(hotel|restaurant)-book_people", s) and slot == 'hotel-book_stay':
            continue
        if re.match("(hotel|restaurant)-book_people", slot) and s == 'hotel-book_stay':
            continue
        if slot != s and (slot not in seen_slots or seen_slots[slot] != value_label):
            if seen_slots[s] == value_label:
                referred_slot = s
                break
            elif value_label in label_maps:
                # A variant match does not stop the outer search; a later
                # seen slot may still replace this candidate.
                for value_label_variant in label_maps[value_label]:
                    if seen_slots[s] == value_label_variant:
                        referred_slot = s
                        break
    return referred_slot


def is_in_list(tok, value):
    """Return True if the tokens of `value` occur contiguously in `tok`.

    Both arguments are strings and are split into tokens before matching.
    """
    tok_list = _split_and_strip(tok)
    value_list = _split_and_strip(value)
    value_len = len(value_list)
    return any(tok_list[i:i + value_len] == value_list
               for i in range(len(tok_list) + 1 - value_len))


def delex_utt(utt, values, unk_token="[UNK]"):
    """Tokenize an utterance and mask all given values in it.

    Args:
        utt: Raw utterance string.
        values: Dict mapping slots to lists of value strings.
        unk_token: Token that replaces every token of a matched value.

    Returns:
        The tokenized utterance with matched values (except 'none') masked.
    """
    utt_norm = tokenize(utt)
    for vals in values.values():
        for v in vals:
            if v == 'none':
                continue
            v_norm = tokenize(v)
            v_len = len(v_norm)
            # Masking replaces v_len tokens by v_len tokens, so the length of
            # utt_norm does not change while scanning.
            for i in range(len(utt_norm) + 1 - v_len):
                if utt_norm[i:i + v_len] == v_norm:
                    utt_norm[i:i + v_len] = [unk_token] * v_len
    return utt_norm


# Fuzzy matching to label informed slot values
def check_slot_inform(value_label, inform_label, label_maps=None):
    """Check whether a value was informed by the system.

    A value counts as informed if it equals an informed value, if one contains
    the other as a token sub-sequence, or if the same holds for a variant
    listed in label_maps.

    Args:
        value_label: Value label of the slot.
        inform_label: List of values informed by the system for the slot.
        label_maps: Optional dict mapping a label to a list of variants.

    Returns:
        Tuple (result, informed_value); informed_value is the first matching
        informed value, or 'none' if nothing matched.
    """
    label_maps = {} if label_maps is None else label_maps

    def _fuzzy_match(a, b):
        # Equal, or one is contained in the other as a token sub-sequence.
        return a == b or is_in_list(b, a) or is_in_list(a, b)

    vl = ' '.join(tokenize(value_label))
    for il in inform_label:
        if _fuzzy_match(vl, il):
            return True, il
        # Variants of the value label are only consulted if the informed
        # value has no variants of its own.
        if il in label_maps:
            if any(_fuzzy_match(vl, il_variant) for il_variant in label_maps[il]):
                return True, il
        elif vl in label_maps:
            if any(_fuzzy_match(value_label_variant, il) for value_label_variant in label_maps[vl]):
                return True, il
    return False, 'none'


def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, seen_slots, slot_last_occurrence, label_maps=None):
    """Determine the class type and token labels of one slot for one turn.

    Args:
        value_label: Normalized value label of the slot in this turn.
        inform_label: List of values informed by the system for the slot.
        sys_utt_tok: Tokenized system utterance (not used here).
        usr_utt_tok: Tokenized user utterance.
        slot: Slot name.
        seen_slots: Dict mapping slot names to the values seen so far.
        slot_last_occurrence: If True, only the last occurrence of the value
            in the user utterance is labeled, otherwise all occurrences.
        label_maps: Optional dict mapping a label to a list of variants.

    Returns:
        Tuple (informed_value, referred_slot, usr_utt_tok_label, class_type).
        class_type is 'unpointable' if the value is neither found in the user
        utterance, nor informed, nor shared with a seen slot.
    """
    label_maps = {} if label_maps is None else label_maps
    usr_utt_tok_label = [0 for _ in usr_utt_tok]
    informed_value = 'none'
    referred_slot = 'none'
    if value_label in ('none', 'dontcare', 'true', 'false'):
        class_type = value_label
    else:
        in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok, label_maps)
        is_informed, informed_value = check_slot_inform(value_label, inform_label, label_maps)
        if in_usr:
            class_type = 'copy_value'
            spans = usr_pos[-1:] if slot_last_occurrence else usr_pos
            for (s, e) in spans:
                for i in range(s, e):
                    usr_utt_tok_label[i] = 1
        elif is_informed:
            class_type = 'inform'
        else:
            referred_slot = check_slot_referral(value_label, slot, seen_slots, label_maps)
            class_type = 'refer' if referred_slot != 'none' else 'unpointable'
    return informed_value, referred_slot, usr_utt_tok_label, class_type


# Requestable slots, general acts and domain indicator slots
def is_request(slot, user_act, turn_domains):
    """Return True if the slot should be labeled with class type 'request'.

    Args:
        slot: Slot name of the form "<domain>-<slot>".
        user_act: Dict mapping slot names to a user act dict, or to a list of
            user act dicts, each with an 'intent' entry.
        turn_domains: Domains (services) that are active in the turn.
    """
    if slot in user_act:
        acts = user_act[slot]
        if isinstance(acts, list):
            if any(act['intent'] in _REQUEST_INTENTS for act in acts):
                return True
        elif acts['intent'] in _REQUEST_INTENTS:
            return True
    # Only split at the first hyphen so that a hyphen inside the slot part
    # does not break the unpacking.
    do, sl = slot.split('-', 1)
    return sl == 'none' and do in turn_domains


def tokenize(utt):
    """Lower-case an utterance and split it into tokens."""
    utt_lower = convert_to_unicode(utt).lower()
    return utt_to_token(utt_lower)


def utt_to_token(utt):
    """Split an utterance at non-word runs and drop tokens that are empty
    after removing space characters."""
    tokens = (piece.replace(" ", "") for piece in _NON_WORD_SPLIT.split(utt))
    return [tok for tok in tokens if len(tok) > 0]


def _extract_turn_annotations(utt, is_sys_utt, slot_list, cumulative_labels):
    """Collect the dialog act and state annotations of a single SGD turn.

    Args:
        utt: Turn dict with a list of 'frames'.
        is_sys_utt: True if the turn is a system turn.
        slot_list: List of slot names that are tracked.
        cumulative_labels: Dict with the dialog state so far. It is updated
            in place with the slot values found in user turns.

    Returns:
        Tuple (inform_dict, user_act_dict, modified_slots, modified_domains).
    """
    inform_dict = {}
    user_act_dict = {}
    modified_slots = {}
    modified_domains = []
    for frame in utt['frames']:
        service = frame['service']
        modified_domains.append(service)  # Remember domains
        for action in frame['actions']:
            act = action['act']
            slot = action['slot']
            cs = "%s-%s" % (service, slot)
            values = normalize_label(cs, action['values'])  # 'values' is a list
            if is_sys_utt and act in ['INFORM', 'CONFIRM', 'OFFER']:
                inform_dict.setdefault(cs, []).append(values)
            elif not is_sys_utt:
                user_act_dict.setdefault(cs, []).append({'domain': service,
                                                         'intent': act,
                                                         'slot': slot,
                                                         'value': values})
                if act in ['INFORM']:
                    modified_slots[cs] = values
        if not is_sys_utt:
            for slot, raw_values in frame['state']['slot_values'].items():
                cs = "%s-%s" % (service, slot)
                values = normalize_label(cs, raw_values)  # raw_values is a list
                # Remember modified slots and entire dialog state
                if cs in slot_list and cumulative_labels[cs] != values:
                    modified_slots[cs] = values
                    cumulative_labels[cs] = values
    # INFO: Since the model has no mechanism to predict
    # one among several informed value candidates, we
    # keep only one informed value. For fairness, we
    # apply a global rule: Always keep the first informed value.
    for e in inform_dict:
        inform_dict[e] = [inform_dict[e][0]]
    modified_domains.sort()
    return inform_dict, user_act_dict, modified_slots, modified_domains


def create_examples(input_file, set_type, class_types, slot_list,
                    label_maps=None,
                    no_append_history=False,
                    no_use_history_labels=False,
                    no_label_value_repetitions=False,
                    swap_utterances=False,
                    delexicalize_sys_utts=False,
                    unk_token="[UNK]",
                    analyze=True):
    """Read SGD dialog files into a list of DSTExample.

    Args:
        input_file: Directory that contains the ``dialogues_*.json`` files of
            one data split (despite its name this is not a single file).
        set_type: Name of the split; used as prefix of the example guids.
        class_types: List of class types; 'request' is only assigned if it is
            contained in this list.
        slot_list: List of "<service>-<slot>" names to create labels for.
        label_maps: Optional dict mapping a label to a list of variants.
        no_append_history: If True, the dialog history is not accumulated.
        no_use_history_labels: If True, history token labels are all zero.
        no_label_value_repetitions: If True, values of already seen slots are
            not carried over to turns in which the slot is not modified.
        swap_utterances: If True, the system utterance becomes text_a and the
            user utterance text_b (default is the other way round).
        delexicalize_sys_utts: If True, informed values are masked with
            unk_token in system utterances.
        unk_token: Mask token for delexicalization.
        analyze: If True (the default), details of every turn are printed to
            stdout.

    Returns:
        List of DSTExample, one per (system, user) turn pair. Files are
        processed in sorted order so that the result is reproducible.
    """
    label_maps = {} if label_maps is None else label_maps

    # glob returns files in arbitrary order; sort for a deterministic result.
    input_files = sorted(glob.glob(os.path.join(input_file, 'dialogues_*.json')))

    examples = []
    for dialog_file in input_files:
        with open(dialog_file, "r", encoding='utf-8') as reader:
            input_data = json.load(reader)

        for dialog in tqdm(input_data):
            dialog_id = dialog['dialogue_id']
            utterances = dialog['turns']

            # Collects all slot changes throughout the dialog
            cumulative_labels = {slot: 'none' for slot in slot_list}

            # The first system utterance is an empty placeholder, since the
            # turn loop below expects dialogs to start with user input.
            utt_tok_list = [[]]
            mod_slots_list = [{}]
            inform_dict_list = [{}]
            user_act_dict_list = [{}]
            mod_domains_list = [{}]

            # Collect all utterances and their metadata
            usr_sys_switch = True
            for utt in utterances:
                # Assert that system and user utterances alternate
                is_sys_utt = utt['speaker'] == "SYSTEM"
                if usr_sys_switch == is_sys_utt:
                    print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id))
                    break
                usr_sys_switch = is_sys_utt

                # Extract dialog_act information for sys and usr utts.
                (inform_dict,
                 user_act_dict,
                 modified_slots,
                 modified_domains) = _extract_turn_annotations(utt, is_sys_utt, slot_list, cumulative_labels)

                utterance = utt['utterance']

                # Delexicalize sys utterance
                if delexicalize_sys_utts and is_sys_utt:
                    utt_tok_list.append(delex_utt(utterance, inform_dict, unk_token))  # normalizes utterances
                else:
                    utt_tok_list.append(tokenize(utterance))  # normalizes utterances

                inform_dict_list.append(inform_dict.copy())
                user_act_dict_list.append(user_act_dict.copy())
                mod_slots_list.append(modified_slots.copy())
                mod_domains_list.append(modified_domains)

            # Form proper (usr, sys) turns
            turn_itr = 0
            diag_seen_slots_dict = {}
            diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
            diag_state = {slot: 'none' for slot in slot_list}
            sys_utt_tok = []
            usr_utt_tok = []
            hst_utt_tok = []
            hst_utt_tok_label_dict = {slot: [] for slot in slot_list}
            for i in range(1, len(utt_tok_list) - 1, 2):
                sys_utt_tok_label_dict = {}
                usr_utt_tok_label_dict = {}
                inform_dict = {}
                inform_slot_dict = {}
                referral_dict = {}
                class_type_dict = {}

                # Collect turn data
                if not no_append_history:
                    if not swap_utterances:
                        hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
                    else:
                        hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
                sys_utt_tok = utt_tok_list[i - 1]
                usr_utt_tok = utt_tok_list[i]
                turn_slots = mod_slots_list[i]
                # Accumulate everything the system informed up to this turn;
                # later system turns overwrite earlier values of a slot.
                inform_mem = {}
                for inform_itr in range(0, i, 2):
                    inform_mem.update(inform_dict_list[inform_itr])
                user_act = user_act_dict_list[i]
                turn_domains = mod_domains_list[i]

                guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr))

                if analyze:
                    print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok)))
                    print("%15s %2s [" % (dialog_id, turn_itr), end='')

                new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
                new_diag_state = diag_state.copy()
                for slot in slot_list:
                    value_label = 'none'
                    if slot in turn_slots:
                        # We keep the original labels so as to not
                        # overlook unpointable values, as well as to not
                        # modify any of the original labels for test sets,
                        # since this would make comparison difficult.
                        value_label = turn_slots[slot]
                    elif not no_label_value_repetitions and slot in diag_seen_slots_dict:
                        value_label = diag_seen_slots_value_dict[slot]

                    # Get dialog act annotations
                    inform_label = ['none']
                    inform_slot_dict[slot] = 0
                    booking_slot = 'booking-' + slot.split('-')[1]
                    if slot in inform_mem:
                        inform_label = inform_mem[slot]
                        inform_slot_dict[slot] = 1
                    elif booking_slot in inform_mem:
                        inform_label = inform_mem[booking_slot]
                        inform_slot_dict[slot] = 1

                    (informed_value,
                     referred_slot,
                     usr_utt_tok_label,
                     class_type) = get_turn_label(value_label,
                                                  inform_label,
                                                  sys_utt_tok,
                                                  usr_utt_tok,
                                                  slot,
                                                  diag_seen_slots_value_dict,
                                                  slot_last_occurrence=True,
                                                  label_maps=label_maps)

                    inform_dict[slot] = informed_value

                    # Requestable slots, domain indicator slots and general slots
                    # should have class_type 'request', if they ought to be predicted.
                    # Give other class_types preference.
                    if 'request' in class_types:
                        if class_type in ['none', 'unpointable'] and is_request(slot, user_act, turn_domains):
                            class_type = 'request'

                    # Generally don't use span prediction on sys utterance (but inform prediction instead).
                    sys_utt_tok_label = [0 for _ in sys_utt_tok]

                    # Determine what to do with value repetitions.
                    # If value is unique in seen slots, then tag it, otherwise not,
                    # since correct slot assignment can not be guaranteed anymore.
                    if not no_label_value_repetitions and slot in diag_seen_slots_dict:
                        if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1:
                            class_type = 'none'
                            usr_utt_tok_label = [0 for _ in usr_utt_tok_label]

                    sys_utt_tok_label_dict[slot] = sys_utt_tok_label
                    usr_utt_tok_label_dict[slot] = usr_utt_tok_label

                    if not no_append_history:
                        if not no_use_history_labels:
                            if not swap_utterances:
                                new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
                            else:
                                new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
                        else:
                            new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]]

                    # For now, we map all occurrences of unpointable slot values
                    # to none. However, since the labels will still suggest
                    # a presence of unpointable slot values, the task of the
                    # DST is still to find those values. It is just not
                    # possible to do that via span prediction on the current input.
                    if class_type == 'unpointable':
                        class_type_dict[slot] = 'none'
                        referral_dict[slot] = 'none'
                        if analyze:
                            if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]:
                                print("(%s): %s, " % (slot, value_label), end='')
                    elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
                        # If slot has seen before and its class type did not change, label this slot a not present,
                        # assuming that the slot has not actually been mentioned in this turn.
                        # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
                        # this must mean there is evidence in the original labels, therefore consider
                        # them as mentioned again.
                        class_type_dict[slot] = 'none'
                        referral_dict[slot] = 'none'
                    else:
                        class_type_dict[slot] = class_type
                        referral_dict[slot] = referred_slot
                    # Remember that this slot was mentioned during this dialog already.
                    if class_type != 'none':
                        diag_seen_slots_dict[slot] = class_type
                        diag_seen_slots_value_dict[slot] = value_label
                        new_diag_state[slot] = class_type
                        # Unpointable is not a valid class, therefore replace with
                        # some valid class for now...
                        if class_type == 'unpointable':
                            new_diag_state[slot] = 'copy_value'

                if analyze:
                    print("]")

                if not swap_utterances:
                    txt_a = usr_utt_tok
                    txt_b = sys_utt_tok
                    txt_a_lbl = usr_utt_tok_label_dict
                    txt_b_lbl = sys_utt_tok_label_dict
                else:
                    txt_a = sys_utt_tok
                    txt_b = usr_utt_tok
                    txt_a_lbl = sys_utt_tok_label_dict
                    txt_b_lbl = usr_utt_tok_label_dict
                examples.append(DSTExample(
                    guid=guid,
                    text_a=txt_a,
                    text_b=txt_b,
                    history=hst_utt_tok,
                    text_a_label=txt_a_lbl,
                    text_b_label=txt_b_lbl,
                    history_label=hst_utt_tok_label_dict,
                    values=diag_seen_slots_value_dict.copy(),
                    inform_label=inform_dict,
                    inform_slot_label=inform_slot_dict,
                    refer_label=referral_dict,
                    diag_state=diag_state,
                    class_label=class_type_dict))

                # Update some variables.
                hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy()
                diag_state = new_diag_state.copy()

                turn_itr += 1

            if analyze:
                print("----------------------------------------------------------------------")

    return examples
if dataset_name == "multiwoz21": from dataset_multiwoz21 import (tokenize, normalize_label, get_turn_label, delex_utt, is_request) + elif dataset_name == "sgd": + from dataset_sgd import (tokenize, normalize_label, + get_turn_label, delex_utt, + is_request) else: raise ValueError("Unknown dataset_name.")