From 732944e24eb53c997f9961960fff4fc286f9bf37 Mon Sep 17 00:00:00 2001 From: heck <heckmi@hhu.de> Date: Fri, 13 Nov 2020 10:01:40 +0000 Subject: [PATCH] fixes for dataset_* --- dataset_multiwoz21.py | 15 +++++++-------- dataset_sim.py | 21 +++++++++++---------- dataset_woz2.py | 9 +++++---- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py index 9751f5c..269c2bb 100644 --- a/dataset_multiwoz21.py +++ b/dataset_multiwoz21.py @@ -264,6 +264,7 @@ def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, se class_type = value_label else: in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok) + is_informed, informed_value = check_slot_inform(value_label, inform_label) if in_usr: class_type = 'copy_value' if slot_last_occurrence: @@ -274,16 +275,14 @@ def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, se for (s, e) in usr_pos: for i in range(s, e): usr_utt_tok_label[i] = 1 + elif is_informed: + class_type = 'inform' else: - is_informed, informed_value = check_slot_inform(value_label, inform_label) - if is_informed: - class_type = 'inform' + referred_slot = check_slot_referral(value_label, slot, seen_slots) + if referred_slot != 'none': + class_type = 'refer' else: - referred_slot = check_slot_referral(value_label, slot, seen_slots) - if referred_slot != 'none': - class_type = 'refer' - else: - class_type = 'unpointable' + class_type = 'unpointable' return informed_value, referred_slot, usr_utt_tok_label, class_type diff --git a/dataset_sim.py b/dataset_sim.py index e30ed09..6685f4e 100644 --- a/dataset_sim.py +++ b/dataset_sim.py @@ -62,23 +62,24 @@ def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, for label_d in usr_slot_label: if label_d['slot'] == slot_type and value == ' '.join( usr_utt_tok[label_d['start']:label_d['exclusive_end']]): - for idx in range(label_d['start'], label_d['exclusive_end']): usr_utt_tok_label[idx] = 1 in_usr = True class_type = 'copy_value' if slot_last_occurrence: break - if not in_usr or not slot_last_occurrence: - for label_d in sys_slot_label: - if label_d['slot'] == slot_type and value == ' '.join( - sys_utt_tok[label_d['start']:label_d['exclusive_end']]): - for idx in range(label_d['start'], label_d['exclusive_end']): - sys_utt_tok_label[idx] = 1 - in_sys = True + + for label_d in sys_slot_label: + if label_d['slot'] == slot_type and value == ' '.join( + sys_utt_tok[label_d['start']:label_d['exclusive_end']]): + for idx in range(label_d['start'], label_d['exclusive_end']): + sys_utt_tok_label[idx] = 1 + in_sys = True + if not in_usr or not slot_last_occurrence: class_type = 'inform' - if slot_last_occurrence: - break + if slot_last_occurrence: + break + if not in_usr and not in_sys: assert sum(usr_utt_tok_label + sys_utt_tok_label) == 0 if (slot_type not in prev_ds_dict or value != prev_ds_dict[slot_type]): diff --git a/dataset_woz2.py b/dataset_woz2.py index c32af05..fcb802f 100644 --- a/dataset_woz2.py +++ b/dataset_woz2.py @@ -87,7 +87,7 @@ def get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence): class_type = 'inform' else: class_type = 'unpointable' - return usr_utt_tok_label, class_type + return usr_utt_tok_label, class_type, in_sys def tokenize(utt): @@ -174,15 +174,16 @@ def create_examples(input_file, set_type, slot_list, label = diag_seen_slots_value_dict[slot] (usr_utt_tok_label, - class_type) = get_turn_label(label, + class_type, + is_informed) = get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence=True) if class_type == 'inform': inform_dict[slot] = label - if label != 'none': - inform_slot_dict[slot] = 1 + if is_informed and label != 'none': + inform_slot_dict[slot] = 1 referral_dict[slot] = 'none' # Referral is not present in woz2 data -- GitLab