Skip to content
Snippets Groups Projects
Commit c854a517 authored by heck's avatar heck
Browse files

fixes in dataset_*

parent 732944e2
No related branches found
No related tags found
No related merge requests found
...@@ -423,10 +423,13 @@ def create_examples(input_file, acts_file, set_type, slot_list, ...@@ -423,10 +423,13 @@ def create_examples(input_file, acts_file, set_type, slot_list,
# Get dialog act annotations # Get dialog act annotations
inform_label = list(['none']) inform_label = list(['none'])
inform_slot_dict[slot] = 0
if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]]) inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]])
inform_slot_dict[slot] = 1
elif (str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1]) in sys_inform_dict: elif (str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1]) in sys_inform_dict:
inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1])]]) inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1])]])
inform_slot_dict[slot] = 1
(informed_value, (informed_value,
referred_slot, referred_slot,
...@@ -440,10 +443,6 @@ def create_examples(input_file, acts_file, set_type, slot_list, ...@@ -440,10 +443,6 @@ def create_examples(input_file, acts_file, set_type, slot_list,
slot_last_occurrence=True) slot_last_occurrence=True)
inform_dict[slot] = informed_value inform_dict[slot] = informed_value
if informed_value != 'none':
inform_slot_dict[slot] = 1
else:
inform_slot_dict[slot] = 0
# Generally don't use span prediction on sys utterance (but inform prediction instead). # Generally don't use span prediction on sys utterance (but inform prediction instead).
sys_utt_tok_label = [0 for _ in sys_utt_tok] sys_utt_tok_label = [0 for _ in sys_utt_tok]
......
...@@ -22,6 +22,29 @@ import json ...@@ -22,6 +22,29 @@ import json
from utils_dst import (DSTExample) from utils_dst import (DSTExample)
# Loads the dialogue_acts.json and returns a list
# of slot-value pairs.
def load_acts(input_file):
with open(input_file) as f:
acts = json.load(f)
s_dict = {}
for d in acts:
d_id = d["dialogue_id"]
for t_id, t in enumerate(d["turns"]):
# Only process, if turn has annotation
if "system_acts" in t:
for a in t["system_acts"]:
if "value" in a:
key = d_id, t_id, a["slot"]
# In case of multiple mentioned values...
# ... Option 1: Keep first informed value
if key not in s_dict:
s_dict[key] = a["value"]
# ... Option 2: Keep last informed value
#s_dict[key] = a["value"]
return s_dict
def dialogue_state_to_sv_dict(sv_list): def dialogue_state_to_sv_dict(sv_list):
sv_dict = {} sv_dict = {}
for d in sv_list: for d in sv_list:
...@@ -98,7 +121,7 @@ def delex_utt(utt, values): ...@@ -98,7 +121,7 @@ def delex_utt(utt, values):
return utt_delex return utt_delex
def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, sys_inform_dict,
delexicalize_sys_utts=False, slot_last_occurrence=True): delexicalize_sys_utts=False, slot_last_occurrence=True):
"""Make turn_label a dictionary of slot with value positions or being dontcare / none: """Make turn_label a dictionary of slot with value positions or being dontcare / none:
Turn label contains: Turn label contains:
...@@ -126,6 +149,7 @@ def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, ...@@ -126,6 +149,7 @@ def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id,
slot_last_occurrence=slot_last_occurrence) slot_last_occurrence=slot_last_occurrence)
if sum(sys_utt_tok_label) > 0: if sum(sys_utt_tok_label) > 0:
inform_label_dict[slot_type] = cur_ds_dict[slot_type] inform_label_dict[slot_type] = cur_ds_dict[slot_type]
if (dial_id, turn_id, slot_type) in sys_inform_dict:
inform_slot_label_dict[slot_type] = 1 inform_slot_label_dict[slot_type] = 1
sys_utt_tok_label = [0 for _ in sys_utt_tok_label] # Don't use token labels for sys utt sys_utt_tok_label = [0 for _ in sys_utt_tok_label] # Don't use token labels for sys utt
sys_utt_tok_label_dict[slot_type] = sys_utt_tok_label sys_utt_tok_label_dict[slot_type] = sys_utt_tok_label
...@@ -150,6 +174,9 @@ def create_examples(input_file, set_type, slot_list, ...@@ -150,6 +174,9 @@ def create_examples(input_file, set_type, slot_list,
delexicalize_sys_utts=False, delexicalize_sys_utts=False,
analyze=False): analyze=False):
"""Read a DST json file into a list of DSTExample.""" """Read a DST json file into a list of DSTExample."""
sys_inform_dict = load_acts(input_file)
with open(input_file, "r", encoding='utf-8') as reader: with open(input_file, "r", encoding='utf-8') as reader:
input_data = json.load(reader) input_data = json.load(reader)
...@@ -178,6 +205,7 @@ def create_examples(input_file, set_type, slot_list, ...@@ -178,6 +205,7 @@ def create_examples(input_file, set_type, slot_list,
slot_list, slot_list,
dial_id, dial_id,
turn_id, turn_id,
sys_inform_dict,
delexicalize_sys_utts=delexicalize_sys_utts, delexicalize_sys_utts=delexicalize_sys_utts,
slot_last_occurrence=True) slot_last_occurrence=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment