diff --git a/data/multiwoz/remap_actions.py b/data/multiwoz/remap_actions.py index bbcf1a2b71c3c39461b3b87d3cf508b35d3b1d29..ffac523b64e9c5a41cf2086d2e39fb54c848b838 100644 --- a/data/multiwoz/remap_actions.py +++ b/data/multiwoz/remap_actions.py @@ -5,19 +5,59 @@ import os from tqdm import tqdm +def get_keyword_domains(turn): + keyword_domains = [] + text = turn['text'] + for d in ["Hotel", "Restaurant", "Train"]: + if d.lower() in text.lower(): + keyword_domains.append(d) + return keyword_domains + + +def get_current_domains_from_act(dialog_acts): + + current_domains_temp = [] + for dom_int in dialog_acts: + domain, intent = dom_int.split('-') + if domain in ["general", "Booking"]: + continue + if domain not in current_domains_temp: + current_domains_temp.append(domain) + + return current_domains_temp + + +def get_next_user_act_domains(ori_dialog, turn_id): + domains = [] + try: + next_user_act = ori_dialog['log'][turn_id + 1]['dialog_act'] + domains = get_current_domains_from_act(next_user_act) + except: + # will fail if system act is the last act of the dialogue + pass + return domains + + +def check_domain_booked(turn, booked_domains): + + booked_domain_current = None + for domain in turn['metadata']: + if turn['metadata'][domain]["book"]["booked"] and domain not in booked_domains: + booked_domain_current = domain + booked_domains.append(domain) + return booked_domains, booked_domain_current + + def flatten_acts(dialog_acts): flattened_acts = [] - contains_booking = False for dom_int in dialog_acts: domain, intent = dom_int.split('-') for slot_value in dialog_acts[dom_int]: slot = slot_value[0] value = slot_value[1] flattened_acts.append((domain, intent, slot, value)) - if f"{domain}-{intent}-{slot}" == "Booking-Book-Ref": - contains_booking = True - return flattened_acts, contains_booking + return flattened_acts def deflat_acts(flattened_acts): @@ -41,7 +81,7 @@ def remap_acts(flattened_acts, current_domains, booked_domain=None, keyword_doma error = 0 remapped_acts = [] - # if there are more current domains, we try to get booked domain from booked_domain + # if there is more than one current domain or none at all, we try to get booked domain differently if len(current_domains) != 1 and booked_domain: current_domains = [booked_domain] elif len(current_domains) != 1 and len(keyword_domains_user) == 1: @@ -74,7 +114,7 @@ def remap_acts(flattened_acts, current_domains, booked_domain=None, keyword_doma remapped_acts.append((domain, "Book", "none", "none")) remapped_acts.append((domain, intent, slot, value)) elif f"{domain}-{intent}-{slot}" in ["Train-Inform-Ref", "Train-OfferBooked-Ref"]: - # train-inform-ref actually triggers the booking and informs on the reference number + # train-inform/offerbooked-ref actually triggers the booking and informs on the reference number remapped_acts.append((domain, "Book", "none", "none")) remapped_acts.append((domain, "Inform", slot, value)) elif domain == "Train" and intent == "OfferBook": @@ -116,7 +156,6 @@ def preprocess(): val_list = set(open(f'{original_data_dir}/valListFile.txt').read().split()) test_list = set(open(f'{original_data_dir}/testListFile.txt').read().split()) - booking_counter = 0 errors = 0 for ori_dialog_id, ori_dialog in tqdm(original_data.items()): @@ -129,84 +168,37 @@ def preprocess(): # add information to which split the dialogue belongs ori_dialog['split'] = split - current_domains_after = [] + current_domains_user = [] + current_domains_system = [] booked_domains = [] - current_domain_system = [] for turn_id, turn in enumerate(ori_dialog['log']): # if it is a user turn, try to extract the current domain if turn_id % 2 == 0: dialog_acts = turn.get('dialog_act', []) - keyword_domains_user = [] - text = turn['text'] - - for d in ["hotel", "restaurant", "train"]: - if d in text.lower(): - keyword_domains_user.append(d) - - if len(dialog_acts) != 0: - current_domains_after_ = [] - for dom_int in dialog_acts.keys(): - domain, intent = dom_int.split('-') - if domain == "general": - continue - if domain not in current_domains_after_: - current_domains_after_.append(domain) - if len(current_domains_after_) > 0: - current_domains_after = current_domains_after_ + keyword_domains_user = get_keyword_domains(turn) + current_domains_temp = get_current_domains_from_act(dialog_acts) + current_domains_user = current_domains_temp if current_domains_temp else current_domains_user else: - # get next user act - next_user_domain = [] - try: - next_user_act = ori_dialog['log'][turn_id + 1]['dialog_act'] - for dom_int in next_user_act.keys(): - domain, intent = dom_int.split('-') - if domain == "general": - continue - if domain not in next_user_domain: - next_user_domain.append(domain) - except: - # will fail if system act is the last act of the dialogue - pass - - # it is a system turn, we now want to map the actions - keyword_domains_system = [] - text = turn['text'] - for d in ["hotel", "restaurant", "train"]: - if d in text.lower(): - keyword_domains_system.append(d) - - booked_domain_current = None - for domain in turn['metadata']: - if turn['metadata'][domain]["book"]["booked"] and domain not in booked_domains: - booked_domain_current = domain - booked_domains.append(domain) - dialog_acts = turn.get('dialog_act', []) if dialog_acts: + # only need to go through that process if we have a dialogue act - spans = turn.get('span_info', []) - - flattened_acts, contains_booking = flatten_acts(dialog_acts) - - current_domain_temp = [] - for a in flattened_acts: - d = a[0] - if d not in ["general", "Booking"] and d not in current_domain_temp: - current_domain_temp.append(d) - if len(current_domain_temp) != 0: - current_domain_system = current_domain_temp + keyword_domains_system = get_keyword_domains(turn) + current_domains_temp = get_current_domains_from_act(dialog_acts) + current_domains_system = current_domains_temp if current_domains_temp else current_domains_system - if contains_booking: - booking_counter += 1 + booked_domains, booked_domain_current = check_domain_booked(turn, booked_domains) + next_user_domains = get_next_user_act_domains(ori_dialog, turn_id) - remapped_acts, error_local = remap_acts(flattened_acts, current_domains_after, + flattened_acts = flatten_acts(dialog_acts) + remapped_acts, error_local = remap_acts(flattened_acts, current_domains_user, booked_domain_current, keyword_domains_user, - keyword_domains_system, current_domain_system, - next_user_domain) + keyword_domains_system, current_domains_system, + next_user_domains) errors += error_local if error_local > 0: @@ -226,4 +218,4 @@ def preprocess(): if __name__ == '__main__': - preprocess() \ No newline at end of file + preprocess()