diff --git a/data/unified_datasets/multiwoz21/booking_remapper.py b/data/unified_datasets/multiwoz21/booking_remapper.py index b94a32434e362737eea396ab14dcbe6591feecef..8ac40b1caa03312f1130d68f029b3081aa531608 100644 --- a/data/unified_datasets/multiwoz21/booking_remapper.py +++ b/data/unified_datasets/multiwoz21/booking_remapper.py @@ -12,6 +12,7 @@ class BookingActRemapper: def retrieve_current_domain_from_user(self, turn_id, ori_dialog): prev_user_turn = ori_dialog[turn_id - 1] + dialog_acts = prev_user_turn.get('dialog_act', []) keyword_domains_user = get_keyword_domains(prev_user_turn) current_domains_temp = get_current_domains_from_act(dialog_acts) @@ -27,19 +28,19 @@ class BookingActRemapper: keyword_domains_system = get_keyword_domains(system_turn) current_domains_temp = get_current_domains_from_act(dialog_acts) self.current_domains_system = current_domains_temp if current_domains_temp else self.current_domains_system - self.booked_domains, booked_domain_current = check_domain_booked(system_turn, self.booked_domains) + booked_domain_current = self.check_domain_booked(system_turn) return keyword_domains_system, booked_domain_current def remap(self, turn_id, ori_dialog): + keyword_domains_user, next_user_domains = self.retrieve_current_domain_from_user(turn_id, ori_dialog) + keyword_domains_system, booked_domain_current = self.retrieve_current_domain_from_system(turn_id, ori_dialog) + # only need to remap if there is a dialog action labelled dialog_acts = ori_dialog[turn_id].get('dialog_act', []) spans = ori_dialog[turn_id].get('span_info', []) - if ori_dialog[turn_id].get('dialog_act', []): - - keyword_domains_user, next_user_domains = self.retrieve_current_domain_from_user(turn_id, ori_dialog) - keyword_domains_system, booked_domain_current = self.retrieve_current_domain_from_system(turn_id, ori_dialog) + if dialog_acts: flattened_acts = flatten_acts(dialog_acts) flattened_spans = flatten_span_acts(spans) @@ -60,6 +61,15 @@ class BookingActRemapper: else: return dialog_acts, spans + def check_domain_booked(self, turn): + + booked_domain_current = None + for domain in turn['metadata']: + if turn['metadata'][domain]["book"]["booked"] and domain not in self.booked_domains: + booked_domain_current = domain.capitalize() + self.booked_domains.append(domain) + return booked_domain_current + def get_keyword_domains(turn): keyword_domains = [] @@ -86,7 +96,7 @@ def get_current_domains_from_act(dialog_acts): def get_next_user_act_domains(ori_dialog, turn_id): domains = [] try: - next_user_act = ori_dialog['log'][turn_id + 1]['dialog_act'] + next_user_act = ori_dialog[turn_id + 1]['dialog_act'] domains = get_current_domains_from_act(next_user_act) except: # will fail if system act is the last act of the dialogue @@ -94,16 +104,6 @@ def get_next_user_act_domains(ori_dialog, turn_id): return domains -def check_domain_booked(turn, booked_domains): - - booked_domain_current = None - for domain in turn['metadata']: - if turn['metadata'][domain]["book"]["booked"] and domain not in booked_domains: - booked_domain_current = domain.capitalize() - booked_domains.append(domain) - return booked_domains, booked_domain_current - - def flatten_acts(dialog_acts): flattened_acts = [] for dom_int in dialog_acts: diff --git a/data/unified_datasets/multiwoz21/preprocess.py b/data/unified_datasets/multiwoz21/preprocess.py index 7f0f5c6a8fd8c9f42919d9ded6d690813cb935aa..c354a8333dd9dc1ad93e3258034a09fe92c9a618 100644 --- a/data/unified_datasets/multiwoz21/preprocess.py +++ b/data/unified_datasets/multiwoz21/preprocess.py @@ -833,7 +833,6 @@ def preprocess(): if speaker == 'system': das, spans = booking_remapper.remap(turn_id, ori_dialog['log']) - print(ori_dialog['log'][turn_id]) da_dict = {} # transform DA @@ -841,18 +840,24 @@ def preprocess(): domain, intent = Domain_Intent.lower().split('-') assert intent in init_ontology['intents'], f'{ori_dialog_id}:{turn_id}:da\t{intent} not in ontology' for Slot, value in das[Domain_Intent]: - domain, slot, value = normalize_domain_slot_value(domain, Slot, value) - if domain not in cur_domains: - # update original cur_domains - cur_domains.append(domain) - da_dict[(intent, domain, slot, value,)] = [] + try: + domain, slot, value = normalize_domain_slot_value(domain, Slot, value) + if domain not in cur_domains: + # update original cur_domains + cur_domains.append(domain) + da_dict[(intent, domain, slot, value,)] = [] + except: + pass for span in spans: - Domain_Intent, Slot, value, start_word, end_word = span - domain, intent = Domain_Intent.lower().split('-') - domain, slot, value = normalize_domain_slot_value(domain, Slot, value) - assert (intent, domain, slot, value,) in da_dict - da_dict[(intent, domain, slot, value,)] = [start_word, end_word] + try: + Domain_Intent, Slot, value, start_word, end_word = span + domain, intent = Domain_Intent.lower().split('-') + domain, slot, value = normalize_domain_slot_value(domain, Slot, value) + assert (intent, domain, slot, value,) in da_dict + da_dict[(intent, domain, slot, value,)] = [start_word, end_word] + except: + pass dialogue_acts = convert_da(da_dict, utt, sent_tokenizer, word_tokenizer) # will also update ontology