# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

from utils_dst import (DSTExample)


def dialogue_state_to_sv_dict(sv_list):
    """Turn a list of ``{'slot': ..., 'value': ...}`` dicts into a dict.

    Returns a mapping from slot name to value. If the same slot occurs
    more than once in ``sv_list``, the value of its last entry wins.
    """
    return {d['slot']: d['value'] for d in sv_list}


def get_token_and_slot_label(turn):
    """Extract tokens and slot annotations of a turn.

    ``turn['user_utterance']`` must be present; ``turn['system_utterance']``
    is optional and is replaced by empty lists when missing.

    Returns:
        Tuple ``(sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label)``
        taken from the ``'tokens'`` and ``'slots'`` entries of the system
        and user utterance.
    """
    if 'system_utterance' in turn:
        sys_utt_tok = turn['system_utterance']['tokens']
        sys_slot_label = turn['system_utterance']['slots']
    else:
        sys_utt_tok = []
        sys_slot_label = []

    usr_utt_tok = turn['user_utterance']['tokens']
    usr_slot_label = turn['user_utterance']['slots']
    return sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label


def _label_matching_spans(utt_tok, slot_label, slot_type, value,
                          utt_tok_label, stop_at_first):
    """Mark in ``utt_tok_label`` the annotated spans of ``slot_type`` whose
    space-joined tokens equal ``value``; return True if any span matched.

    ``utt_tok_label`` is modified in place. When ``stop_at_first`` is true
    the search ends after the first matching annotation.
    """
    found = False
    for label_d in slot_label:
        # Check the slot name first so that the span bounds are only read
        # for annotations of the requested slot.
        if label_d['slot'] != slot_type:
            continue
        start = label_d['start']
        end = label_d['exclusive_end']
        if value != ' '.join(utt_tok[start:end]):
            continue
        for idx in range(start, end):
            utt_tok_label[idx] = 1
        found = True
        if stop_at_first:
            break
    return found


def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok,
                  sys_slot_label, usr_utt_tok, usr_slot_label, dial_id,
                  turn_id, slot_last_occurrence=True):
    """Compute token-level labels and the class type of one slot in a turn.

    The class type is determined as follows:

    * ``'none'`` if the slot is not in the current dialogue state, or if
      its value is not annotated in this turn but is unchanged with
      respect to the previous dialogue state.
    * ``'dontcare'`` if the value is ``'dontcare'`` and the previous state
      did not already hold ``'dontcare'`` for this slot (i.e. only the
      first occurrence in the dialogue is labelled).
    * ``'copy_value'`` if an annotated span of the user utterance spells
      out the value.
    * ``'inform'`` if an annotated span of the system utterance spells
      out the value. If both utterances match (only possible when
      ``slot_last_occurrence`` is false), ``'inform'`` takes precedence.

    NOTE: despite its name, when ``slot_last_occurrence`` is true the
    search stops at the *first* matching annotation (in the order given
    by the slot annotation list), and the system utterance is only
    searched if the user utterance contained no match. When it is false,
    every matching annotation in both utterances is labelled.

    Args:
        prev_ds_dict: slot -> value dict of the previous dialogue state.
        cur_ds_dict: slot -> value dict of the current dialogue state.
        slot_type: name of the slot to label.
        sys_utt_tok: tokens of the system utterance.
        sys_slot_label: slot annotations of the system utterance, each a
            dict with keys ``'slot'``, ``'start'`` and ``'exclusive_end'``.
        usr_utt_tok: tokens of the user utterance.
        usr_slot_label: slot annotations of the user utterance (same
            format as ``sys_slot_label``).
        dial_id: dialogue identifier, used in the error message only.
        turn_id: turn identifier, used in the error message only.
        slot_last_occurrence: see the note above.

    Returns:
        Tuple ``(sys_utt_tok_label, usr_utt_tok_label, class_type)`` where
        the first two are 0/1 lists aligned with the respective tokens.

    Raises:
        ValueError: if the value is new or changed with respect to the
            previous state but no annotated span in this turn matches it.
    """
    sys_utt_tok_label = [0 for _ in sys_utt_tok]
    usr_utt_tok_label = [0 for _ in usr_utt_tok]

    if slot_type not in cur_ds_dict:
        return sys_utt_tok_label, usr_utt_tok_label, 'none'

    value = cur_ds_dict[slot_type]
    if value == 'dontcare' and (slot_type not in prev_ds_dict or
                                prev_ds_dict[slot_type] != 'dontcare'):
        # Only label dontcare at its first occurrence in the dialog
        return sys_utt_tok_label, usr_utt_tok_label, 'dontcare'

    # If not none or dontcare, we have to identify whether
    # there is a span, or if the value is informed
    in_usr = _label_matching_spans(usr_utt_tok, usr_slot_label, slot_type,
                                   value, usr_utt_tok_label,
                                   slot_last_occurrence)
    in_sys = False
    if not in_usr or not slot_last_occurrence:
        in_sys = _label_matching_spans(sys_utt_tok, sys_slot_label,
                                       slot_type, value, sys_utt_tok_label,
                                       slot_last_occurrence)

    if not in_usr and not in_sys:
        assert sum(usr_utt_tok_label + sys_utt_tok_label) == 0
        if (slot_type not in prev_ds_dict or
                value != prev_ds_dict[slot_type]):
            raise ValueError('Copy value cannot found in Dial %s Turn %s' %
                             (str(dial_id), str(turn_id)))
        # Value carried over unchanged from the previous state.
        class_type = 'none'
    else:
        assert sum(usr_utt_tok_label + sys_utt_tok_label) > 0
        # A system-side match overrides a user-side match.
        class_type = 'inform' if in_sys else 'copy_value'

    return sys_utt_tok_label, usr_utt_tok_label, class_type


def delex_utt(utt, values):
    """Return a copy of ``utt`` with every annotated span replaced by
    ``'[UNK]'`` tokens.

    ``values`` is a list of dicts with ``'start'`` and ``'exclusive_end'``
    token indices. The input list is not modified and the number of
    tokens is preserved for spans that lie within the utterance.
    """
    utt_delex = utt.copy()
    for v in values:
        utt_delex[v['start']:v['exclusive_end']] = (
            ['[UNK]'] * (v['exclusive_end'] - v['start']))
    return utt_delex


def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id,
                   delexicalize_sys_utts=False, slot_last_occurrence=True):
    """Make turn_label a dictionary of slot with value positions or being
    dontcare / none.

    Turn label contains:
      (1) the updates from previous to current dialogue state,
      (2) values in current dialogue state explicitly mentioned in system
          or user utterance.

    Args:
        turn: turn dict with ``'dialogue_state'``, ``'user_utterance'``
            and optionally ``'system_utterance'``.
        prev_dialogue_state: list of slot/value dicts of the previous turn.
        slot_list: slots for which labels are produced.
        dial_id: dialogue identifier (used for error reporting).
        turn_id: turn identifier (used for error reporting).
        delexicalize_sys_utts: if true, annotated spans in the returned
            system utterance are replaced by ``'[UNK]'`` tokens.
        slot_last_occurrence: forwarded to ``get_tok_label``.

    Returns:
        Tuple ``(sys_utt_tok, sys_utt_tok_label_dict, usr_utt_tok,
        usr_utt_tok_label_dict, inform_label_dict, inform_slot_label_dict,
        referral_label_dict, cur_ds_dict, class_type_dict)``. All
        ``*_dict`` entries except ``cur_ds_dict`` are keyed by the slots
        in ``slot_list``. System token labels are always all zero; a
        system-side match is reported via the inform dicts instead.
    """
    prev_ds_dict = dialogue_state_to_sv_dict(prev_dialogue_state)
    cur_ds_dict = dialogue_state_to_sv_dict(turn['dialogue_state'])

    (sys_utt_tok, sys_slot_label,
     usr_utt_tok, usr_slot_label) = get_token_and_slot_label(turn)

    sys_utt_tok_label_dict = {}
    usr_utt_tok_label_dict = {}
    inform_label_dict = {}
    inform_slot_label_dict = {}
    referral_label_dict = {}
    class_type_dict = {}

    for slot_type in slot_list:
        inform_label_dict[slot_type] = 'none'
        inform_slot_label_dict[slot_type] = 0
        # Referral is not present in sim data
        referral_label_dict[slot_type] = 'none'

        sys_utt_tok_label, usr_utt_tok_label, class_type = get_tok_label(
            prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok,
            sys_slot_label, usr_utt_tok, usr_slot_label, dial_id, turn_id,
            slot_last_occurrence=slot_last_occurrence)

        if sum(sys_utt_tok_label) > 0:
            # The value was found in the system utterance: record it as
            # an informed value.
            inform_label_dict[slot_type] = cur_ds_dict[slot_type]
            inform_slot_label_dict[slot_type] = 1

        # Don't use token labels for sys utt
        sys_utt_tok_label_dict[slot_type] = [0] * len(sys_utt_tok_label)
        usr_utt_tok_label_dict[slot_type] = usr_utt_tok_label
        class_type_dict[slot_type] = class_type

    if delexicalize_sys_utts:
        sys_utt_tok = delex_utt(sys_utt_tok, sys_slot_label)

    return (sys_utt_tok, sys_utt_tok_label_dict,
            usr_utt_tok, usr_utt_tok_label_dict,
            inform_label_dict, inform_slot_label_dict,
            referral_label_dict, cur_ds_dict, class_type_dict)


def create_examples(input_file, set_type, slot_list,
                    label_maps=None,
                    append_history=False,
                    use_history_labels=False,
                    swap_utterances=False,
                    label_value_repetitions=False,
                    delexicalize_sys_utts=False,
                    analyze=False):
    """Read a DST json file into a list of DSTExample.

    Args:
        input_file: path of a UTF-8 JSON file holding a list of dialogues,
            each with a ``'dialogue_id'`` and a list of ``'turns'``.
        set_type: prefix used for the example guid
            (``'<set_type>-<dialogue_id>-<turn index>'``).
        slot_list: slots for which labels are produced.
        label_maps: not used by this function. Defaults to ``None``
            rather than a shared mutable ``{}``.
        append_history: if true, the tokens of each processed turn are
            prepended to the history passed to subsequent examples.
        use_history_labels: if true (and ``append_history`` is set), the
            token labels of earlier turns are kept in the history labels;
            otherwise the history labels are all zero.
        swap_utterances: if true, the user utterance becomes ``text_a``
            and the system utterance ``text_b``; by default it is the
            other way round.
        label_value_repetitions: not used by this function.
        delexicalize_sys_utts: forwarded to ``get_turn_label``.
        analyze: not used by this function.

    Returns:
        List with one ``DSTExample`` per turn, in file order. ``history``,
        ``history_label`` and ``diag_state`` of an example describe the
        dialogue *before* the example's turn.

    Raises:
        ValueError: propagated from ``get_tok_label`` when a new slot
            value cannot be located in the turn's annotations.
    """
    with open(input_file, "r", encoding='utf-8') as reader:
        input_data = json.load(reader)

    examples = []
    for entry in input_data:
        dial_id = entry['dialogue_id']
        prev_ds = []
        hst = []
        prev_hst_lbl_dict = {slot: [] for slot in slot_list}
        prev_ds_lbl_dict = {slot: 'none' for slot in slot_list}

        for turn_id, turn in enumerate(entry['turns']):
            guid = '%s-%s-%s' % (set_type, dial_id, str(turn_id))

            # Work on copies so the example below still receives the
            # state from before this turn.
            ds_lbl_dict = prev_ds_lbl_dict.copy()
            hst_lbl_dict = prev_hst_lbl_dict.copy()

            (text_a,
             text_a_label,
             text_b,
             text_b_label,
             inform_label,
             inform_slot_label,
             referral_label,
             cur_ds_dict,
             class_label) = get_turn_label(
                 turn,
                 prev_ds,
                 slot_list,
                 dial_id,
                 turn_id,
                 delexicalize_sys_utts=delexicalize_sys_utts,
                 slot_last_occurrence=True)

            if swap_utterances:
                txt_a, txt_b = text_b, text_a
                txt_a_lbl, txt_b_lbl = text_b_label, text_a_label
            else:
                txt_a, txt_b = text_a, text_b
                txt_a_lbl, txt_b_lbl = text_a_label, text_b_label

            value_dict = {}
            for slot in slot_list:
                value_dict[slot] = cur_ds_dict.get(slot, 'none')

                if class_label[slot] != 'none':
                    ds_lbl_dict[slot] = class_label[slot]

                if append_history:
                    # Newest turn first, mirroring how hst is built below.
                    turn_and_hst_lbl = (txt_a_lbl[slot] + txt_b_lbl[slot] +
                                        hst_lbl_dict[slot])
                    if use_history_labels:
                        hst_lbl_dict[slot] = turn_and_hst_lbl
                    else:
                        hst_lbl_dict[slot] = [0] * len(turn_and_hst_lbl)

            examples.append(DSTExample(
                guid=guid,
                text_a=txt_a,
                text_b=txt_b,
                history=hst,
                text_a_label=txt_a_lbl,
                text_b_label=txt_b_lbl,
                history_label=prev_hst_lbl_dict,
                values=value_dict,
                inform_label=inform_label,
                inform_slot_label=inform_slot_label,
                refer_label=referral_label,
                diag_state=prev_ds_lbl_dict,
                class_label=class_label))

            prev_ds = turn['dialogue_state']
            prev_ds_lbl_dict = ds_lbl_dict.copy()
            prev_hst_lbl_dict = hst_lbl_dict.copy()

            if append_history:
                hst = txt_a + txt_b + hst

    return examples