Skip to content
Snippets Groups Projects
Commit 316e99da authored by Christian
Browse files

added remapping of span acts

parent a671faf3
No related branches found
No related tags found
No related merge requests found
......@@ -60,6 +60,15 @@ def flatten_acts(dialog_acts):
return flattened_acts
def flatten_span_acts(span_acts):
    """Turn span acts into flat ``(domain, intent, slot, rest)`` tuples.

    For every entry, the first element is split on ``"-"`` into exactly two
    parts (domain and intent), the second element is carried over unchanged,
    and everything from the third element onward is kept together as a slice.

    Args:
        span_acts: iterable of indexable entries whose first element is a
            ``"<domain>-<intent>"`` string.
            NOTE(review): presumably ``[act, slot, value, start, end]`` —
            confirm against the dataset.

    Returns:
        A list with one 4-tuple per entry, in the input order.

    Raises:
        ValueError: if the first element does not split into exactly two parts.
    """

    def _flatten_one(entry):
        # Two-target unpacking enforces the "exactly one hyphen" shape.
        domain, intent = entry[0].split("-")
        return domain, intent, entry[1], entry[2:]

    return [_flatten_one(entry) for entry in span_acts]
def deflat_acts(flattened_acts):
dialog_acts = dict()
......@@ -74,6 +83,20 @@ def deflat_acts(flattened_acts):
return dialog_acts
def deflat_span_acts(flattened_acts):
    """Rebuild span-act lists from flat ``(domain, intent, slot, value)`` tuples.

    Each kept act becomes ``["<domain>-<intent>", slot, *value]``, i.e. the
    items of ``value`` are spliced in after the slot. Acts whose ``value``
    compares equal to the string ``'none'`` are dropped.

    Args:
        flattened_acts: iterable of 4-item acts; ``value`` must be iterable
            for every act that is kept.

    Returns:
        A list of span-act lists, in the input order.
    """
    return [
        [f"{domain}-{intent}", slot, *value]
        for domain, intent, slot, value in flattened_acts
        if not value == 'none'
    ]
def remap_acts(flattened_acts, current_domains, booked_domain=None, keyword_domains_user=None,
keyword_domains_system=None, current_domain_system=None, next_user_domain=None):
......@@ -151,7 +174,6 @@ def preprocess():
copy2(f'{original_data_dir}/{filename}', new_data_dir)
original_data = json.load(open(f'{original_data_dir}/data.json'))
global init_ontology, cnt_domain_slot
val_list = set(open(f'{original_data_dir}/valListFile.txt').read().split())
test_list = set(open(f'{original_data_dir}/testListFile.txt').read().split())
......@@ -184,6 +206,7 @@ def preprocess():
else:
dialog_acts = turn.get('dialog_act', [])
span_acts = turn.get('span_info', [])
if dialog_acts:
# only need to go through that process if we have a dialogue act
......@@ -195,17 +218,26 @@ def preprocess():
next_user_domains = get_next_user_act_domains(ori_dialog, turn_id)
flattened_acts = flatten_acts(dialog_acts)
flattened_span_acts = flatten_span_acts(span_acts)
remapped_acts, error_local = remap_acts(flattened_acts, current_domains_user,
booked_domain_current, keyword_domains_user,
keyword_domains_system, current_domains_system,
next_user_domains)
remapped_span_acts, _ = remap_acts(flattened_span_acts, current_domains_user,
booked_domain_current, keyword_domains_user,
keyword_domains_system, current_domains_system,
next_user_domains)
errors += error_local
if error_local > 0:
print(ori_dialog_id)
deflattened_remapped_acts = deflat_acts(remapped_acts)
deflattened_remapped_span_acts = deflat_span_acts(remapped_span_acts)
turn['dialog_act'] = deflattened_remapped_acts
turn['span_info'] = deflattened_remapped_span_acts
print("Errors:", errors)
json.dump(original_data, open(f'{new_data_dir}/data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment