# Threshold used by write_system_act_set: an act is kept only when its count
# is strictly greater than this value.
MIN_OCCURENCE_ACT = 50


def write_system_act_set(new_sys_acts, path="sys_da_voc_remapped.txt",
                         min_occurrence=None):
    """Write the frequently seen system acts to a text file, one per line.

    Acts are written in sorted order. An act is kept only when its count is
    strictly greater than the threshold.

    Args:
        new_sys_acts: Mapping from an act string to the number of times it
            was counted.
        path: File to write. Defaults to the location that was previously
            hard-coded, relative to the current working directory.
        min_occurrence: Threshold an act's count must exceed. ``None``
            (the default) means ``MIN_OCCURENCE_ACT``, looked up at call
            time.
    """
    if min_occurrence is None:
        min_occurrence = MIN_OCCURENCE_ACT

    # Build the vocabulary before opening the file, so an error while
    # filtering cannot leave a freshly truncated file behind.
    frequent_acts = sorted(
        act for act, count in new_sys_acts.items() if count > min_occurrence
    )

    # Explicit encoding keeps the output independent of the platform locale.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(act + "\n" for act in frequent_acts)
    print("Saved new action dict.")


def delexicalize_da(da):
    """Replace the value of every dialogue act with a placeholder.

    Args:
        da: Iterable of ``(domain, intent, slot, value)`` items.

    Returns:
        A list of ``[domain, intent, slot, placeholder]`` lists in input
        order. The placeholder is ``'?'`` when the intent equals
        ``"request"`` ignoring case, ``'none'`` when the slot is exactly
        ``'none'``, and otherwise the 1-based running count (as a string)
        of that intent/domain/slot combination within ``da``.
    """
    delexicalized_da = []
    seen = {}
    for domain, intent, slot, _value in da:
        if intent.lower() == "request":
            placeholder = "?"
        elif slot == "none":
            placeholder = "none"
        else:
            # The key keeps the intent's original casing, so intents that
            # differ only by case are numbered independently.
            key = "-".join([intent, domain, slot])
            seen[key] = seen.get(key, 0) + 1
            placeholder = str(seen[key])
        delexicalized_da.append([domain, intent, slot, placeholder])
    return delexicalized_da