Skip to content
Snippets Groups Projects
Commit fcee8dd3 authored by Michael Heck's avatar Michael Heck
Browse files

Support for SGD, resume training

parent 9cc92266
No related branches found
No related tags found
No related merge requests found
Showing
with 993 additions and 27 deletions
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
......@@ -25,6 +25,7 @@ import dataset_sim
import dataset_multiwoz21
import dataset_multiwoz21_legacy
import dataset_unified
import dataset_sgd
class DataProcessor(object):
......@@ -37,20 +38,24 @@ class DataProcessor(object):
label_maps = {}
value_list = {'train': {}, 'dev': {}, 'test': {}}
def __init__(self, dataset_config, data_dir, predict_type='train'):
    """Load the dataset configuration and derive any missing slot metadata.

    Args:
        dataset_config: Path to a JSON config file. ``class_types`` is
            required; ``dataset_name``, ``slots``, ``noncategorical``,
            ``boolean`` and ``label_maps`` are optional.
        data_dir: Directory holding the dataset files; stored on the
            instance as ``self.data_dir``.
        predict_type: Forwarded to the ``_get_*`` hooks when a piece of
            slot metadata has to be generated instead of read from the
            config.

    Raises:
        NotImplementedError: If a piece of metadata is missing from the
            config and the subclass does not override the matching
            ``_get_*`` hook.
    """
    self.data_dir = data_dir

    # Load dataset config file.
    with open(dataset_config, "r", encoding='utf-8') as f:
        raw_config = json.load(f)
    self.dataset_name = raw_config['dataset_name'] if 'dataset_name' in raw_config else ""
    self.class_types = raw_config['class_types']  # Required
    self.slot_list = raw_config['slots'] if 'slots' in raw_config else None
    self.noncategorical = raw_config['noncategorical'] if 'noncategorical' in raw_config else None
    self.boolean = raw_config['boolean'] if 'boolean' in raw_config else None
    self.label_maps = raw_config['label_maps'] if 'label_maps' in raw_config else {}

    # If no slot metadata is provided, generate it from the data.
    # ``None`` (key absent) triggers generation; an explicitly empty
    # list in the config is kept as-is.
    if self.slot_list is None:
        self.slot_list = self._get_slot_list(predict_type)
    if self.noncategorical is None:
        self.noncategorical = self._get_noncategorical(predict_type)
    if self.boolean is None:
        # Store the result in ``self.boolean``; writing it to
        # ``self.noncategorical`` would clobber the list set above and
        # leave ``self.boolean`` as None.
        self.boolean = self._get_boolean(predict_type)
def _add_dummy_value_to_value_list(self):
for dset in self.value_list:
......@@ -77,7 +82,13 @@ class DataProcessor(object):
self.value_list['train'][s][v] += new_value_list[s][v]
self._add_dummy_value_to_value_list()
def _get_slot_list(self):
def _get_slot_list(self, predict_type):
raise NotImplementedError()
def _get_noncategorical(self, predict_type):
raise NotImplementedError()
def _get_boolean(self, predict_type):
raise NotImplementedError()
def prediction_normalization(self, slot, value):
......@@ -94,8 +105,8 @@ class DataProcessor(object):
class Woz2Processor(DataProcessor):
def __init__(self, dataset_config, data_dir):
super(Woz2Processor, self).__init__(dataset_config, data_dir)
def __init__(self, dataset_config, data_dir, predict_type='train'):
super(Woz2Processor, self).__init__(dataset_config, data_dir, predict_type)
self.value_list['train'] = dataset_woz2.get_value_list(os.path.join(self.data_dir, 'woz_train_en.json'),
self.slot_list)
self.value_list['dev'] = dataset_woz2.get_value_list(os.path.join(self.data_dir, 'woz_validate_en.json'),
......@@ -117,8 +128,8 @@ class Woz2Processor(DataProcessor):
class Multiwoz21Processor(DataProcessor):
def __init__(self, dataset_config, data_dir):
super(Multiwoz21Processor, self).__init__(dataset_config, data_dir)
def __init__(self, dataset_config, data_dir, predict_type='train'):
super(Multiwoz21Processor, self).__init__(dataset_config, data_dir, predict_type)
self.value_list['train'] = dataset_multiwoz21.get_value_list(os.path.join(self.data_dir, 'train_dials.json'),
self.slot_list)
self.value_list['dev'] = dataset_multiwoz21.get_value_list(os.path.join(self.data_dir, 'val_dials.json'),
......@@ -144,8 +155,8 @@ class Multiwoz21Processor(DataProcessor):
class Multiwoz21LegacyProcessor(DataProcessor):
def __init__(self, dataset_config, data_dir):
super(Multiwoz21LegacyProcessor, self).__init__(dataset_config, data_dir)
def __init__(self, dataset_config, data_dir, predict_type='train'):
super(Multiwoz21LegacyProcessor, self).__init__(dataset_config, data_dir, predict_type)
self.value_list['train'] = dataset_multiwoz21_legacy.get_value_list(os.path.join(self.data_dir, 'train_dials.json'),
self.slot_list)
self.value_list['dev'] = dataset_multiwoz21_legacy.get_value_list(os.path.join(self.data_dir, 'val_dials.json'),
......@@ -174,8 +185,8 @@ class Multiwoz21LegacyProcessor(DataProcessor):
class SimProcessor(DataProcessor):
def __init__(self, dataset_config, data_dir):
super(SimProcessor, self).__init__(dataset_config, data_dir)
def __init__(self, dataset_config, data_dir, predict_type='train'):
super(SimProcessor, self).__init__(dataset_config, data_dir, predict_type)
self.value_list['train'] = dataset_sim.get_value_list(os.path.join(self.data_dir, 'train.json'),
self.slot_list)
self.value_list['dev'] = dataset_sim.get_value_list(os.path.join(self.data_dir, 'dev.json'),
......@@ -197,8 +208,8 @@ class SimProcessor(DataProcessor):
class UnifiedDatasetProcessor(DataProcessor):
def __init__(self, dataset_config, data_dir):
super(UnifiedDatasetProcessor, self).__init__(dataset_config, data_dir)
def __init__(self, dataset_config, data_dir, predict_type='train'):
super(UnifiedDatasetProcessor, self).__init__(dataset_config, data_dir, predict_type)
self.value_list['train'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list)
self.value_list['dev'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list)
self.value_list['test'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list)
......@@ -207,7 +218,7 @@ class UnifiedDatasetProcessor(DataProcessor):
def prediction_normalization(self, slot, value):
return dataset_unified.prediction_normalization(self.dataset_name, slot, value)
def _get_slot_list(self):
def _get_slot_list(self, predict_type):
return dataset_unified.get_slot_list(self.dataset_name)
def get_train_examples(self, args):
......@@ -223,9 +234,46 @@ class UnifiedDatasetProcessor(DataProcessor):
self.slot_list, self.label_maps, **args)
class SgdProcessor(DataProcessor):
    """Data processor backed by the ``dataset_sgd`` module.

    Dialogues are read from the ``train``, ``dev`` and ``test``
    sub-directories of ``data_dir``. Slot metadata that is missing from the
    dataset config is derived from ``<data_dir>/<predict_type>/schema.json``.
    """

    def __init__(self, dataset_config, data_dir, predict_type='train'):
        super(SgdProcessor, self).__init__(dataset_config, data_dir, predict_type)
        self.value_list['train'] = dataset_sgd.get_value_list(os.path.join(self.data_dir, 'train'), self.slot_list)
        self.value_list['dev'] = dataset_sgd.get_value_list(os.path.join(self.data_dir, 'dev'), self.slot_list)
        self.value_list['test'] = dataset_sgd.get_value_list(os.path.join(self.data_dir, 'test'), self.slot_list)
        self._add_dummy_value_to_value_list()

    def prediction_normalization(self, slot, value):
        return dataset_sgd.prediction_normalization(slot, value)

    def _schema_file(self, predict_type):
        """Return the path of the schema file for the given data split."""
        # The schema is looked up below the configured data directory
        # rather than a machine-specific absolute path, so the processor
        # works wherever the dataset has been placed.
        return os.path.join(self.data_dir, predict_type, 'schema.json')

    def _get_slot_list(self, predict_type='train'):
        return dataset_sgd.get_slot_list(self._schema_file(predict_type))

    def _get_noncategorical(self, predict_type='train'):
        return dataset_sgd.get_noncategorical(self._schema_file(predict_type))

    def _get_boolean(self, predict_type='train'):
        return dataset_sgd.get_boolean(self._schema_file(predict_type))

    def get_train_examples(self, args):
        return dataset_sgd.create_examples(os.path.join(self.data_dir, 'train'),
                                           'train', self.class_types, self.slot_list, self.label_maps, **args)

    def get_dev_examples(self, args):
        return dataset_sgd.create_examples(os.path.join(self.data_dir, 'dev'),
                                           'dev', self.class_types, self.slot_list, self.label_maps, **args)

    def get_test_examples(self, args):
        return dataset_sgd.create_examples(os.path.join(self.data_dir, 'test'),
                                           'test', self.class_types, self.slot_list, self.label_maps, **args)
PROCESSORS = {"woz2": Woz2Processor,
"sim-m": SimProcessor,
"sim-r": SimProcessor,
"multiwoz21": Multiwoz21Processor,
"multiwoz21_legacy": Multiwoz21LegacyProcessor,
"unified": UnifiedDatasetProcessor}
"unified": UnifiedDatasetProcessor,
"sgd": SgdProcessor}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"true",
"false",
"refer",
"inform"
],
"label_maps": {
"inexpensive": [
"cheap",
"lower price",
"lower range",
"cheaply",
"cheaper",
"cheapest",
"very affordable",
"low cost",
"low priced",
"low-cost",
"budget",
"bargain priced"
],
"moderate": [
"moderately",
"reasonable",
"reasonably",
"affordable",
"afforadable",
"mid range",
"mid-range",
"priced moderately",
"decently priced",
"mid price",
"mid-price",
"mid priced",
"mid-priced",
"middle price",
"medium price",
"medium priced",
"not too expensive",
"not too cheap",
"not too costly",
"not very costly",
"economical",
"intermediate",
"average"
],
"expensive": [
"high end",
"high-end",
"high class",
"high-class",
"high scale",
"high-scale",
"high price",
"high priced",
"higher price",
"above average",
"fancy",
"upscale",
"expensively",
"luxury",
"pricey",
"costly"
],
"very expensive": [
"very fancy",
"lavish",
"extravagant"
],
"0": [
"zero"
],
"1": [
"one"
],
"2": [
"two"
],
"3": [
"three"
],
"4": [
"four"
],
"5": [
"five"
],
"6": [
"six"
],
"7": [
"seven"
],
"8": [
"eight"
],
"9": [
"nine"
],
"10": [
"ten"
],
"kitchen speaker": [
"kitchen"
],
"bedroom speaker": [
"bedroom"
],
"music": [
"concert"
],
"sports": [
"match",
"matches",
"game",
"games"
],
"standard": [
"medium-sized",
"intermediate"
],
"compact": [
"small"
],
"tv": [
"television",
"display"
],
"full-size": [
"large",
"spacious"
],
"park": [
"gardens",
"garden"
],
"nature preserve": [
"natural spot",
"wildlife spot"
],
"historical landmark": [
"historical spot"
],
"tourist attraction": [
"place of interest"
],
"theme park": [
"amusement park"
],
"sports venue": [
"playground"
],
"place of worship": [
"religious spot"
],
"shopping area": [
"mall"
],
"performing arts venue": [
"performance venue"
]
}
}
......@@ -14,7 +14,5 @@
},
"noncategorical": [
"movie"
],
"boolean": [],
"label_maps": {}
]
}
......@@ -18,7 +18,5 @@
},
"noncategorical": [
"restaurant_name"
],
"boolean": [],
"label_maps": {}
]
}
......@@ -10,7 +10,6 @@
"inform",
"request"
],
"slots": [],
"noncategorical": [
"taxi-leaveAt",
"taxi-destination",
......
{
"dataset_name": "sgd",
"class_types": [
"none",
"dontcare",
"copy_value",
"true",
"false",
"refer",
"inform",
"request"
],
"label_maps": {
"inexpensive": [
"cheap",
"lower price",
"lower range",
"cheaply",
"cheaper",
"cheapest",
"very affordable",
"low cost",
"low priced",
"low-cost",
"budget",
"bargain priced"
],
"moderate": [
"moderately",
"reasonable",
"reasonably",
"affordable",
"afforadable",
"mid range",
"mid-range",
"priced moderately",
"decently priced",
"mid price",
"mid-price",
"mid priced",
"mid-priced",
"middle price",
"medium price",
"medium priced",
"not too expensive",
"not too cheap",
"not too costly",
"not very costly",
"economical",
"intermediate",
"average"
],
"expensive": [
"high end",
"high-end",
"high class",
"high-class",
"high scale",
"high-scale",
"high price",
"high priced",
"higher price",
"above average",
"fancy",
"upscale",
"expensively",
"luxury",
"pricey",
"costly"
],
"very expensive": [
"very fancy",
"lavish",
"extravagant"
],
"0": [
"zero"
],
"1": [
"one"
],
"2": [
"two"
],
"3": [
"three"
],
"4": [
"four"
],
"5": [
"five"
],
"6": [
"six"
],
"7": [
"seven"
],
"8": [
"eight"
],
"9": [
"nine"
],
"10": [
"ten"
],
"kitchen speaker": [
"kitchen"
],
"bedroom speaker": [
"bedroom"
],
"music": [
"concert"
],
"sports": [
"match",
"matches",
"game",
"games"
],
"standard": [
"medium-sized",
"intermediate"
],
"compact": [
"small"
],
"tv": [
"television",
"display"
],
"full-size": [
"large",
"spacious"
],
"park": [
"gardens",
"garden"
],
"nature preserve": [
"natural spot",
"wildlife spot"
],
"historical landmark": [
"historical spot"
],
"tourist attraction": [
"place of interest"
],
"theme park": [
"amusement park"
],
"sports venue": [
"playground"
],
"place of worship": [
"religious spot"
],
"shopping area": [
"mall"
],
"performing arts venue": [
"performance venue"
]
}
}
# coding=utf-8
#
# Copyright 2020-2022 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import os
import glob
from tqdm import tqdm
from utils_dst import (DSTExample, convert_to_unicode)
# TODO: check what's actually needed
def prediction_normalization(slot, value):
    """Return the predicted ``value`` unchanged.

    No normalization is applied at present; ``slot`` is accepted but not
    used.
    """
    return value
def get_slot_list(input_file):
    """Build a mapping from slot identifier to lower-cased description.

    Args:
        input_file: Path to a JSON schema file containing a list of
            services, each with ``service_name`` and a ``slots`` list whose
            entries carry ``name`` and ``description``.

    Returns:
        Dict keyed by ``"<service_name>-<slot name>"`` with the slot's
        lower-cased description as value.
    """
    with open(input_file, "r", encoding='utf-8') as reader:
        schema = json.load(reader)
    return {
        "%s-%s" % (service['service_name'], slot['name']): slot['description'].lower()
        for service in schema
        for slot in service['slots']
    }
def get_noncategorical(input_file):
    """List the identifiers of all slots not flagged as categorical.

    Args:
        input_file: Path to a JSON schema file (list of services with
            ``service_name`` and ``slots``; each slot has ``name`` and
            ``is_categorical``).

    Returns:
        List of ``"<service_name>-<slot name>"`` strings, in schema order,
        for every slot whose ``is_categorical`` value is falsy.
    """
    with open(input_file, "r", encoding='utf-8') as reader:
        schema = json.load(reader)
    return [
        "%s-%s" % (service['service_name'], slot['name'])
        for service in schema
        for slot in service['slots']
        if not slot['is_categorical']
    ]
def get_boolean(input_file):
    """List the identifiers of all slots whose only values are True/False.

    Args:
        input_file: Path to a JSON schema file (list of services with
            ``service_name`` and ``slots``; each slot has ``name`` and
            ``possible_values``).

    Returns:
        List of ``"<service_name>-<slot name>"`` strings, in schema order,
        for every slot that has exactly two possible values and both
        ``"True"`` and ``"False"`` among them.
    """
    with open(input_file, "r", encoding='utf-8') as reader:
        schema = json.load(reader)

    def _is_boolean(slot):
        candidates = slot['possible_values']
        return len(candidates) == 2 and all(flag in candidates for flag in ("True", "False"))

    return [
        "%s-%s" % (service['service_name'], slot['name'])
        for service in schema
        for slot in service['slots']
        if _is_boolean(slot)
    ]
# This should only contain label normalizations, no label mappings.
def normalize_label(slot, value_label):
    """Normalize a raw value label to a canonical form.

    A list is reduced to its first element (or ``""`` when empty); strings
    are lower-cased and stripped. The empty string becomes ``"none"`` and
    ``"dont care"`` becomes ``"dontcare"``. ``slot`` is not used.

    Args:
        slot: Slot identifier the label belongs to (unused).
        value_label: A string or a list of value variants.

    Returns:
        The normalized label.
    """
    if isinstance(value_label, list):
        # TODO: Workaround. Only the first variant is kept; note that
        # Multiwoz 2.2 supports variants directly in the labels.
        value_label = value_label[0] if value_label else ""

    # Normalization of capitalization
    if isinstance(value_label, str):
        value_label = value_label.lower().strip()

    # Normalization of empty slots # TODO: needed?
    if value_label == '':
        return "none"

    # Normalization of 'dontcare'
    if value_label == 'dont care':
        return "dontcare"

    return value_label
def get_token_pos(tok_list, value_label):
    """Find every occurrence of a value label inside a token list.

    The label is split on runs of non-word characters (the separators are
    kept as tokens), each piece is stripped, and empty pieces are dropped.
    The resulting token sequence is then matched against every window of
    ``tok_list`` of the same length.

    Args:
        tok_list: List of tokens to search in.
        value_label: String whose token sequence is searched for.

    Returns:
        Tuple ``(found, positions)`` where ``positions`` is a list of
        ``(start, exclusive_end)`` index pairs and ``found`` is True when
        at least one position was found.
    """
    # Raw string: "\W" in a plain literal is an invalid escape sequence.
    label_list = [item for item in map(str.strip, re.split(r"(\W+)", value_label)) if len(item) > 0]
    len_label = len(label_list)
    find_pos = [
        (i, i + len_label)  # start, exclusive_end
        for i in range(len(tok_list) + 1 - len_label)
        if tok_list[i:i + len_label] == label_list
    ]
    return bool(find_pos), find_pos
def check_label_existence(value_label, usr_utt_tok, label_maps={}):
    """Check whether a value label (or one of its variants) is in the user utterance.

    The label itself is tried first. Only if it is not found and
    ``label_maps`` has an entry for it are the listed variants tried in
    order, stopping at the first one that matches.

    Args:
        value_label: Label to look for.
        usr_utt_tok: Tokenized user utterance.
        label_maps: Mapping from a label to a list of variant strings.

    Returns:
        Tuple ``(in_usr, usr_pos)`` as returned by ``get_token_pos`` for the
        last string that was tried.
    """
    hit, positions = get_token_pos(usr_utt_tok, value_label)
    if hit or value_label not in label_maps:
        return hit, positions

    # If no hit even though there should be one, check for value label variants
    for variant in label_maps[value_label]:
        hit, positions = get_token_pos(usr_utt_tok, variant)
        if hit:
            return hit, positions
    return hit, positions
def check_slot_referral(value_label, slot, seen_slots, label_maps={}):
    """Find a previously seen slot whose value this slot's value refers to.

    Args:
        value_label: Value of ``slot`` in the current turn.
        slot: Slot being labelled.
        seen_slots: Mapping from slot to the value last seen for it.
        label_maps: Mapping from a label to a list of variant strings.

    Returns:
        The name of the referred slot, or ``'none'`` if there is none.
    """
    # Slots that share values with different meaning are never matched.
    # hotel-internet and -parking are handled separately as Boolean slots.
    never_referred = ('hotel-stars', 'hotel-internet', 'hotel-parking')
    people_pattern = "(hotel|restaurant)-book_people"
    stay_slot = 'hotel-book_stay'

    referred_slot = 'none'
    if slot in never_referred:
        return referred_slot

    for candidate in seen_slots:
        if candidate in never_referred:
            continue
        if re.match(people_pattern, candidate) and slot == stay_slot:
            continue
        if re.match(people_pattern, slot) and candidate == stay_slot:
            continue
        # A slot cannot refer to itself, nor is it a referral when the slot
        # already holds this very value.
        if slot == candidate:
            continue
        if slot in seen_slots and seen_slots[slot] == value_label:
            continue
        if seen_slots[candidate] == value_label:
            referred_slot = candidate
            break
        if value_label in label_maps:
            # A variant match only leaves this inner loop; the scan over
            # the remaining seen slots continues afterwards.
            for variant in label_maps[value_label]:
                if seen_slots[candidate] == variant:
                    referred_slot = candidate
                    break
    return referred_slot
def is_in_list(tok, value):
    """Tell whether the tokens of ``value`` occur contiguously in ``tok``.

    Both strings are split on runs of non-word characters (separators are
    kept as tokens), pieces are stripped and empty pieces dropped; the
    token sequence of ``value`` is then searched as a contiguous run in the
    token sequence of ``tok``.

    Args:
        tok: String to search in.
        value: String to search for.

    Returns:
        True if the token sequence of ``value`` is found, else False.
    """
    # Raw string: "\W" in a plain literal is an invalid escape sequence.
    splitter = re.compile(r"(\W+)")
    tok_list = [item for item in map(str.strip, splitter.split(tok)) if len(item) > 0]
    value_list = [item for item in map(str.strip, splitter.split(value)) if len(item) > 0]
    value_len = len(value_list)
    return any(
        tok_list[i:i + value_len] == value_list
        for i in range(len(tok_list) + 1 - value_len)
    )
def delex_utt(utt, values, unk_token="[UNK]"):
    """Tokenize an utterance and mask every occurrence of the given values.

    Args:
        utt: Raw utterance string.
        values: Mapping from slot to an iterable of value strings.
        unk_token: Token written over each token of a matched value.

    Returns:
        The tokenized utterance in which every token span equal to a
        tokenized value (other than ``'none'``) is replaced, token by
        token, with ``unk_token``.
    """
    tokens = tokenize(utt)
    for candidates in values.values():
        for candidate in candidates:
            if candidate == 'none':
                continue
            cand_tokens = tokenize(candidate)
            width = len(cand_tokens)
            # Replacement keeps the token count, so the scan range is fixed.
            for start in range(len(tokens) + 1 - width):
                if tokens[start:start + width] == cand_tokens:
                    tokens[start:start + width] = [unk_token] * width
    return tokens
# Fuzzy matching to label informed slot values
def check_slot_inform(value_label, inform_label, label_maps={}):
    """Fuzzily match a value label against the informed values.

    Two strings match when they are equal or when the token sequence of
    one is contained in the other (see ``is_in_list``). For each informed
    value the direct match is tried first; failing that, the variants of
    the informed value from ``label_maps`` are compared with the label, or,
    only if the informed value has no ``label_maps`` entry, the variants
    of the label are compared with the informed value.

    Args:
        value_label: Value label of the slot.
        inform_label: Iterable of informed values.
        label_maps: Mapping from a label to a list of variant strings.

    Returns:
        Tuple ``(result, informed_value)``: ``(True, <first matching
        informed value>)`` or ``(False, 'none')``.
    """
    def _fuzzy_equal(left, right):
        return left == right or is_in_list(right, left) or is_in_list(left, right)

    vl = ' '.join(tokenize(value_label))
    for il in inform_label:
        if _fuzzy_equal(vl, il):
            return True, il
        if il in label_maps:
            if any(_fuzzy_equal(vl, il_variant) for il_variant in label_maps[il]):
                return True, il
        elif vl in label_maps:
            if any(_fuzzy_equal(vl_variant, il) for vl_variant in label_maps[vl]):
                return True, il
    return False, 'none'
def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, seen_slots, slot_last_occurrence, label_maps={}):
    """Derive the class type and token labels of one slot for one turn.

    Args:
        value_label: Normalized value of the slot in this turn.
        inform_label: Informed values for the slot.
        sys_utt_tok: Tokenized system utterance (not used here).
        usr_utt_tok: Tokenized user utterance.
        slot: Slot identifier.
        seen_slots: Mapping from slot to the value seen so far.
        slot_last_occurrence: If truthy, only the last occurrence of the
            value in the user utterance is tagged; otherwise all are.
        label_maps: Mapping from a label to a list of variant strings.

    Returns:
        Tuple ``(informed_value, referred_slot, usr_utt_tok_label,
        class_type)``. ``class_type`` is the label itself for
        ``none``/``dontcare``/``true``/``false``, otherwise one of
        ``copy_value``, ``inform``, ``refer`` or ``unpointable``.
    """
    usr_utt_tok_label = [0] * len(usr_utt_tok)

    # Special labels map directly onto a class of the same name.
    if value_label in ('none', 'dontcare', 'true', 'false'):
        return 'none', 'none', usr_utt_tok_label, value_label

    in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok, label_maps)
    is_informed, informed_value = check_slot_inform(value_label, inform_label, label_maps)
    referred_slot = 'none'

    if in_usr:
        class_type = 'copy_value'
        spans = [usr_pos[-1]] if slot_last_occurrence else usr_pos
        for (start, end) in spans:
            for idx in range(start, end):
                usr_utt_tok_label[idx] = 1
    elif is_informed:
        class_type = 'inform'
    else:
        referred_slot = check_slot_referral(value_label, slot, seen_slots, label_maps)
        class_type = 'refer' if referred_slot != 'none' else 'unpointable'

    return informed_value, referred_slot, usr_utt_tok_label, class_type
# Requestable slots, general acts and domain indicator slots
def is_request(slot, user_act, turn_domains):
    """Decide whether a slot should be labelled as requested in this turn.

    Args:
        slot: Slot identifier of the form ``"<domain>-<name>"``.
        user_act: Mapping from slot to either one act dict or a list of
            act dicts, each with an ``'intent'`` entry.
        turn_domains: Collection of domains active in the turn.

    Returns:
        True if any act recorded for the slot has intent ``REQUEST``,
        ``GOODBYE`` or ``THANK_YOU``, or if the slot name part is ``none``
        and its domain is in ``turn_domains``; False otherwise.
    """
    request_intents = ['REQUEST', 'GOODBYE', 'THANK_YOU']
    if slot in user_act:
        acts = user_act[slot]
        if not isinstance(acts, list):
            acts = [acts]
        if any(act['intent'] in request_intents for act in acts):
            return True

    # Domain indicator slots ("<domain>-none") count as requested whenever
    # their domain is active in the turn.
    domain, slot_name = slot.split('-')
    return slot_name == 'none' and domain in turn_domains
def tokenize(utt):
    """Convert an utterance to unicode, lower-case it and split it into tokens."""
    return utt_to_token(convert_to_unicode(utt).lower())
def utt_to_token(utt):
    """Split an utterance into word and punctuation tokens.

    The string is split on runs of non-word characters, keeping the
    separators; space characters are removed from every piece and pieces
    that end up empty are dropped.

    Args:
        utt: Utterance string.

    Returns:
        List of token strings.
    """
    # Raw string: "\W" in a plain literal is an invalid escape sequence.
    pieces = (piece.replace(" ", "") for piece in re.split(r"(\W+)", utt))
    return [tok for tok in pieces if len(tok) > 0]
def create_examples(input_file, set_type, class_types, slot_list,
label_maps={},
no_label_value_repetitions=False,
swap_utterances=False,
delexicalize_sys_utts=False,
unk_token="[UNK]",
boolean_slots=True,
analyze=False):
"""Read DST json files into a list of DSTExample.

input_file is a directory; every file matching 'dialogues_*.json' in it
is loaded. set_type becomes the prefix of each example guid. class_types
is only consulted for the presence of 'request'. slot_list provides the
slots that are labelled. label_maps is forwarded to get_turn_label.
no_label_value_repetitions controls whether previously seen slot values
are carried over as labels into later turns. swap_utterances selects
whether the user or the system utterance becomes text_a.
delexicalize_sys_utts and unk_token control masking of informed values in
system utterances. boolean_slots is not referenced in this function body.
analyze enables diagnostic printing. Returns the list of examples.
"""
# NOTE: from here on `input_file` is rebound to each matched file path.
input_files = glob.glob(os.path.join(input_file, 'dialogues_*.json'))
examples = []
for input_file in input_files:
with open(input_file, "r", encoding='utf-8') as reader:
input_data = json.load(reader)
for d_itr, dialog in enumerate(tqdm(input_data)):
dialog_id = dialog['dialogue_id']
domains = dialog['services']
utterances = dialog['turns']
# Collects all slot changes throughout the dialog
cumulative_labels = {slot: 'none' for slot in slot_list}
# The per-utterance lists start with an empty placeholder entry that
# stands in for a system utterance preceding the first user utterance.
# (Original note: first system utterance is empty, since multiwoz starts
# with user input.)
utt_tok_list = [[]]
mod_slots_list = [{}]
inform_dict_list = [{}]
user_act_dict_list = [{}]
mod_domains_list = [{}]
# Collect all utterances and their metadata
usr_sys_switch = True
turn_itr = 0
for utt in utterances:
# Assert that system and user utterances alternate
is_sys_utt = utt['speaker'] == "SYSTEM"
if usr_sys_switch == is_sys_utt:
print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id))
break
usr_sys_switch = is_sys_utt
if is_sys_utt:
turn_itr += 1
# Extract dialog_act information for sys and usr utts.
inform_dict = {}
user_act_dict = {}
modified_slots = {}
modified_domains = []
for frame in utt['frames']:
actions = frame['actions']
service = frame['service']
spans = frame['slots']
modified_domains.append(service) # Remember domains
for action in actions:
act = action['act']
slot = action['slot']
# Slot identifiers have the form "<service>-<slot>".
cs = "%s-%s" % (service, slot)
values = action['values'] # this is a list
#canonical_values = action['canonical_values']
values = normalize_label(cs, values)
# System INFORM/CONFIRM/OFFER acts are remembered as informed values.
if is_sys_utt and act in ['INFORM', 'CONFIRM', 'OFFER']:
if cs not in inform_dict:
inform_dict[cs] = []
inform_dict[cs].append(values)
elif not is_sys_utt:
if cs not in user_act_dict:
user_act_dict[cs] = []
user_act_dict[cs].append({'domain': service,
'intent': act,
'slot': slot,
'value': values})
if act in ['INFORM']:
modified_slots[cs] = values
# Only user frames are read for a dialog state here.
if not is_sys_utt:
state = frame['state']
#active_intent = state['active_intent']
#requested_slots = state['requested_slots']
for slot in state['slot_values']:
cs = "%s-%s" % (service, slot)
values = frame['state']['slot_values'][slot] # this is a list
values = normalize_label(cs, values)
# Remember modified slots and entire dialog state
if cs in slot_list and cumulative_labels[cs] != values:
modified_slots[cs] = values
cumulative_labels[cs] = values
# INFO: Since the model has no mechanism to predict
# one among several informed value candidates, we
# keep only one informed value. For fairness, we
# apply a global rule:
for e in inform_dict:
# ... Option 1: Always keep first informed value
inform_dict[e] = list([inform_dict[e][0]])
# ... Option 2: Always keep last informed value
#inform_dict[e] = list([inform_dict[e][-1]])
utterance = utt['utterance']
# Delexicalize sys utterance
if delexicalize_sys_utts and is_sys_utt:
utt_tok_list.append(delex_utt(utterance, inform_dict, unk_token)) # normalizes utterances
else:
utt_tok_list.append(tokenize(utterance)) # normalizes utterances
inform_dict_list.append(inform_dict.copy())
user_act_dict_list.append(user_act_dict.copy())
mod_slots_list.append(modified_slots.copy())
modified_domains.sort()
mod_domains_list.append(modified_domains)
# Form proper (usr, sys) turns
turn_itr = 0
diag_seen_slots_dict = {}
diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
diag_state = {slot: 'none' for slot in slot_list}
sys_utt_tok = []
usr_utt_tok = []
# Odd indices are paired with the entry just before them:
# index i is used as the user utterance, index i - 1 as the system one.
for i in range(1, len(utt_tok_list) - 1, 2):
sys_utt_tok_label_dict = {}
usr_utt_tok_label_dict = {}
value_dict = {}
inform_dict = {}
inform_slot_dict = {}
referral_dict = {}
class_type_dict = {}
updated_slots = {slot: 0 for slot in slot_list}
# Collect turn data
sys_utt_tok = utt_tok_list[i - 1]
usr_utt_tok = utt_tok_list[i]
turn_slots = mod_slots_list[i]
# Merge the informed values of all even-indexed entries before i;
# later entries overwrite earlier ones for the same slot.
inform_mem = {}
for inform_itr in range(0, i, 2):
inform_mem.update(inform_dict_list[inform_itr])
#inform_mem = inform_dict_list[i - 1]
user_act = user_act_dict_list[i]
turn_domains = mod_domains_list[i]
guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr))
if analyze:
print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok)))
print("%15s %2s [" % (dialog_id, turn_itr), end='')
new_diag_state = diag_state.copy()
for slot in slot_list:
value_label = 'none'
if slot in turn_slots:
value_label = turn_slots[slot]
# We keep the original labels so as to not
# overlook unpointable values, as well as to not
# modify any of the original labels for test sets,
# since this would make comparison difficult.
value_dict[slot] = value_label
elif not no_label_value_repetitions and slot in diag_seen_slots_dict:
value_label = diag_seen_slots_value_dict[slot]
# Get dialog act annotations
inform_label = list(['none'])
inform_slot_dict[slot] = 0
# Informed values are also looked up under a 'booking-<name>' key.
# NOTE(review): nothing in this function stores such keys — confirm
# whether this lookup is still needed for this dataset.
booking_slot = 'booking-' + slot.split('-')[1]
if slot in inform_mem:
inform_label = inform_mem[slot]
inform_slot_dict[slot] = 1
elif booking_slot in inform_mem:
inform_label = inform_mem[booking_slot]
inform_slot_dict[slot] = 1
(informed_value,
referred_slot,
usr_utt_tok_label,
class_type) = get_turn_label(value_label,
inform_label,
sys_utt_tok,
usr_utt_tok,
slot,
diag_seen_slots_value_dict,
slot_last_occurrence=True,
label_maps=label_maps)
inform_dict[slot] = informed_value
# Requestable slots, domain indicator slots and general slots
# should have class_type 'request', if they ought to be predicted.
# Give other class_types preference.
if 'request' in class_types:
if class_type in ['none', 'unpointable'] and is_request(slot, user_act, turn_domains):
class_type = 'request'
# Generally don't use span prediction on sys utterance (but inform prediction instead).
sys_utt_tok_label = [0 for _ in sys_utt_tok]
# Determine what to do with value repetitions.
# If value is unique in seen slots, then tag it, otherwise not,
# since correct slot assignment can not be guaranteed anymore.
if not no_label_value_repetitions and slot in diag_seen_slots_dict:
if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1:
class_type = 'none'
usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
sys_utt_tok_label_dict[slot] = sys_utt_tok_label
usr_utt_tok_label_dict[slot] = usr_utt_tok_label
if diag_seen_slots_value_dict[slot] != value_label:
updated_slots[slot] = 1
# For now, we map all occurrences of unpointable slot values
# to none. However, since the labels will still suggest
# a presence of unpointable slot values, the task of the
# DST is still to find those values. It is just not
# possible to do that via span prediction on the current input.
if class_type == 'unpointable':
class_type_dict[slot] = 'none'
referral_dict[slot] = 'none'
if analyze:
if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]:
print("(%s): %s, " % (slot, value_label), end='')
elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
# If the slot has been seen before and its class type did not change, label this slot as not present,
# assuming that the slot has not actually been mentioned in this turn.
# Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
# this must mean there is evidence in the original labels, therefore consider
# them as mentioned again.
class_type_dict[slot] = 'none'
referral_dict[slot] = 'none'
else:
class_type_dict[slot] = class_type
referral_dict[slot] = referred_slot
# Remember that this slot was mentioned during this dialog already.
if class_type != 'none':
diag_seen_slots_dict[slot] = class_type
diag_seen_slots_value_dict[slot] = value_label
new_diag_state[slot] = class_type
# Unpointable is not a valid class, therefore replace with
# some valid class for now...
if class_type == 'unpointable':
new_diag_state[slot] = 'copy_value'
if analyze:
print("]")
if not swap_utterances:
txt_a = usr_utt_tok
txt_b = sys_utt_tok
txt_a_lbl = usr_utt_tok_label_dict
txt_b_lbl = sys_utt_tok_label_dict
else:
txt_a = sys_utt_tok
txt_b = usr_utt_tok
txt_a_lbl = sys_utt_tok_label_dict
txt_b_lbl = usr_utt_tok_label_dict
# The example receives the dialog state from before this turn
# (diag_state); new_diag_state only takes effect for the next turn.
examples.append(DSTExample(
guid=guid,
text_a=txt_a,
text_b=txt_b,
text_a_label=txt_a_lbl,
text_b_label=txt_b_lbl,
values=diag_seen_slots_value_dict.copy(),
inform_label=inform_dict,
inform_slot_label=inform_slot_dict,
refer_label=referral_dict,
diag_state=diag_state,
slot_update=updated_slots,
class_label=class_type_dict))
# Update some variables.
diag_state = new_diag_state.copy()
turn_itr += 1
if analyze:
print("----------------------------------------------------------------------")
return examples
def get_value_list(input_file, slot_list, boolean_slots=True):
    """Count how often each value occurs for each slot in a data split.

    Args:
        input_file: Directory searched for ``dialogues_*.json`` files.
        slot_list: Slots to collect values for; other slots are ignored.
        boolean_slots: If falsy, the values ``'true'`` and ``'false'`` are
            excluded in addition to ``'none'`` and ``'dontcare'``.

    Returns:
        Dict mapping every slot in ``slot_list`` to a dict of
        ``{normalized value: count}``. Values are taken from system
        INFORM/CONFIRM/OFFER actions, user INFORM actions and the
        ``slot_values`` of user frame states.
    """
    exclude = ['none', 'dontcare']
    if not boolean_slots:
        exclude += ['true', 'false']

    value_list = {slot: {} for slot in slot_list}

    def _count(value_label, cs):
        # Only track known slots and skip the excluded special values.
        if cs in slot_list and value_label not in exclude:
            value_list[cs][value_label] = value_list[cs].get(value_label, 0) + 1

    sys_acts = ['INFORM', 'CONFIRM', 'OFFER']
    usr_acts = ['INFORM']

    for dialog_file in glob.glob(os.path.join(input_file, 'dialogues_*.json')):
        with open(dialog_file, "r", encoding='utf-8') as reader:
            dialogs = json.load(reader)
        for dialog in dialogs:
            for utt in dialog['turns']:
                is_sys_utt = utt['speaker'] == "SYSTEM"
                counted_acts = sys_acts if is_sys_utt else usr_acts
                for frame in utt['frames']:
                    for action in frame['actions']:
                        cs = "%s-%s" % (frame['service'], action['slot'])
                        value_label = normalize_label(cs, action['values'])
                        if action['act'] in counted_acts:
                            _count(value_label, cs)
                    if not is_sys_utt:
                        slot_values = frame['state']['slot_values']
                        for slot in slot_values:
                            cs = "%s-%s" % (frame['service'], slot)
                            _count(normalize_label(cs, slot_values[slot]), cs)
    return value_list
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment