From c854a517b012a5c204e1160e0abc551813a5f9da Mon Sep 17 00:00:00 2001 From: heck <heckmi@hhu.de> Date: Tue, 17 Nov 2020 09:46:30 +0000 Subject: [PATCH] fixes in dataset_* --- dataset_multiwoz21.py | 7 +++---- dataset_sim.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py index 269c2bb..e514638 100644 --- a/dataset_multiwoz21.py +++ b/dataset_multiwoz21.py @@ -423,10 +423,13 @@ def create_examples(input_file, acts_file, set_type, slot_list, # Get dialog act annotations inform_label = list(['none']) + inform_slot_dict[slot] = 0 if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]]) + inform_slot_dict[slot] = 1 elif (str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1]) in sys_inform_dict: inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1])]]) + inform_slot_dict[slot] = 1 (informed_value, referred_slot, @@ -440,10 +443,6 @@ def create_examples(input_file, acts_file, set_type, slot_list, slot_last_occurrence=True) inform_dict[slot] = informed_value - if informed_value != 'none': - inform_slot_dict[slot] = 1 - else: - inform_slot_dict[slot] = 0 # Generally don't use span prediction on sys utterance (but inform prediction instead). sys_utt_tok_label = [0 for _ in sys_utt_tok] diff --git a/dataset_sim.py b/dataset_sim.py index 6685f4e..575a29f 100644 --- a/dataset_sim.py +++ b/dataset_sim.py @@ -22,6 +22,29 @@ import json from utils_dst import (DSTExample) +# Loads the system acts from input_file and returns a dict mapping +# (dialogue_id, turn index, slot) to the first informed value. 
+def load_acts(input_file): + with open(input_file, "r", encoding='utf-8') as f: + acts = json.load(f) + s_dict = {} + for d in acts: + d_id = d["dialogue_id"] + for t_id, t in enumerate(d["turns"]): + # Only process turns that carry system act annotations + if "system_acts" in t: + for a in t["system_acts"]: + if "value" in a: + key = d_id, t_id, a["slot"] + # In case of multiple mentioned values... + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = a["value"] + # ... Option 2: Keep last informed value + #s_dict[key] = a["value"] + return s_dict + + def dialogue_state_to_sv_dict(sv_list): sv_dict = {} for d in sv_list: @@ -98,7 +121,7 @@ def delex_utt(utt, values): return utt_delex -def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, +def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, sys_inform_dict, delexicalize_sys_utts=False, slot_last_occurrence=True): """Make turn_label a dictionary of slot with value positions or being dontcare / none: Turn label contains: @@ -126,6 +149,7 @@ def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, slot_last_occurrence=slot_last_occurrence) if sum(sys_utt_tok_label) > 0: inform_label_dict[slot_type] = cur_ds_dict[slot_type] + if (dial_id, turn_id, slot_type) in sys_inform_dict: inform_slot_label_dict[slot_type] = 1 sys_utt_tok_label = [0 for _ in sys_utt_tok_label] # Don't use token labels for sys utt sys_utt_tok_label_dict[slot_type] = sys_utt_tok_label @@ -150,6 +174,9 @@ def create_examples(input_file, set_type, slot_list, delexicalize_sys_utts=False, analyze=False): """Read a DST json file into a list of DSTExample.""" + + sys_inform_dict = load_acts(input_file) + with open(input_file, "r", encoding='utf-8') as reader: input_data = json.load(reader) @@ -178,6 +205,7 @@ def create_examples(input_file, set_type, slot_list, slot_list, dial_id, turn_id, + sys_inform_dict, delexicalize_sys_utts=delexicalize_sys_utts, slot_last_occurrence=True) -- GitLab