diff --git a/data/multiwoz/remap_actions.py b/data/multiwoz/remap_actions.py index ffac523b64e9c5a41cf2086d2e39fb54c848b838..658ab0ea84cef53d54d4e42dcdc0b481bb437777 100644 --- a/data/multiwoz/remap_actions.py +++ b/data/multiwoz/remap_actions.py @@ -60,6 +60,15 @@ def flatten_acts(dialog_acts): return flattened_acts +def flatten_span_acts(span_acts): + + flattened_acts = [] + for span_act in span_acts: + domain, intent = span_act[0].split("-") + flattened_acts.append((domain, intent, span_act[1], span_act[2:])) + return flattened_acts + + def deflat_acts(flattened_acts): dialog_acts = dict() @@ -74,6 +83,20 @@ def deflat_acts(flattened_acts): return dialog_acts +def deflat_span_acts(flattened_acts): + + dialog_span_acts = [] + for act in flattened_acts: + domain, intent, slot, value = act + if value == 'none': + continue + new_act = [f"{domain}-{intent}", slot] + new_act.extend(value) + dialog_span_acts.append(new_act) + + return dialog_span_acts + + def remap_acts(flattened_acts, current_domains, booked_domain=None, keyword_domains_user=None, keyword_domains_system=None, current_domain_system=None, next_user_domain=None): @@ -151,7 +174,6 @@ def preprocess(): copy2(f'{original_data_dir}/{filename}', new_data_dir) original_data = json.load(open(f'{original_data_dir}/data.json')) - global init_ontology, cnt_domain_slot val_list = set(open(f'{original_data_dir}/valListFile.txt').read().split()) test_list = set(open(f'{original_data_dir}/testListFile.txt').read().split()) @@ -184,6 +206,7 @@ def preprocess(): else: dialog_acts = turn.get('dialog_act', []) + span_acts = turn.get('span_info', []) if dialog_acts: # only need to go through that process if we have a dialogue act @@ -195,17 +218,26 @@ def preprocess(): next_user_domains = get_next_user_act_domains(ori_dialog, turn_id) flattened_acts = flatten_acts(dialog_acts) + flattened_span_acts = flatten_span_acts(span_acts) remapped_acts, error_local = remap_acts(flattened_acts, current_domains_user, booked_domain_current, keyword_domains_user, keyword_domains_system, current_domains_system, next_user_domains) + + remapped_span_acts, _ = remap_acts(flattened_span_acts, current_domains_user, + booked_domain_current, keyword_domains_user, + keyword_domains_system, current_domains_system, + next_user_domains) + errors += error_local if error_local > 0: print(ori_dialog_id) deflattened_remapped_acts = deflat_acts(remapped_acts) + deflattened_remapped_span_acts = deflat_span_acts(remapped_span_acts) turn['dialog_act'] = deflattened_remapped_acts + turn['span_info'] = deflattened_remapped_span_acts print("Errors:", errors) json.dump(original_data, open(f'{new_data_dir}/data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)