Skip to content
Snippets Groups Projects
Commit 0f4ecebf authored by Christian's avatar Christian
Browse files

Can set an occurrence bound when creating the act set

parent 316e99da
Branches
No related tags found
No related merge requests found
......@@ -4,6 +4,40 @@ import json
import os
from tqdm import tqdm
# Occurrence bound for the system act vocabulary: write_system_act_set keeps
# an act only when its count is strictly greater than this value.
MIN_OCCURENCE_ACT = 50
def write_system_act_set(new_sys_acts, min_occurrence=None,
                         path="sys_da_voc_remapped.txt"):
    """Write the frequently occurring system acts to a vocabulary file.

    The kept acts are sorted and written one per line.

    Args:
        new_sys_acts: Mapping from act string to the number of times the act
            was counted.
        min_occurrence: Occurrence bound. Only acts whose count is strictly
            greater than this value are written. When ``None`` (the default)
            the module-level ``MIN_OCCURENCE_ACT`` is used.
        path: File to write. It is overwritten if it already exists.
    """
    if min_occurrence is None:
        min_occurrence = MIN_OCCURENCE_ACT

    frequent_acts = sorted(
        act for act, count in new_sys_acts.items() if count > min_occurrence
    )

    # An explicit encoding keeps the file content independent of the
    # platform's default locale.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(act + "\n" for act in frequent_acts)
    print("Saved new action dict.")
def delexicalize_da(da):
    """Return a copy of the dialogue acts with each value replaced by a placeholder.

    Each item of ``da`` is unpacked as ``(domain, intent, slot, value)``; the
    original value is discarded. The placeholder is:

    * ``'?'`` when the lower-cased intent is ``"request"``;
    * ``'none'`` when the slot is ``'none'``;
    * otherwise the running count (as a string, starting at ``"1"``) of how
      often the same intent/domain/slot combination has appeared so far in
      this call.

    Returns:
        A list of ``[domain, intent, slot, placeholder]`` lists in input order.
    """
    occurrences = {}
    result = []
    for domain, intent, slot, _ in da:
        if intent.lower() == "request":
            placeholder = '?'
        elif slot == 'none':
            placeholder = 'none'
        else:
            key = '-'.join((intent, domain, slot))
            occurrences[key] = occurrences.get(key, 0) + 1
            placeholder = str(occurrences[key])
        result.append([domain, intent, slot, placeholder])
    return result
def get_keyword_domains(turn):
keyword_domains = []
......@@ -43,7 +77,7 @@ def check_domain_booked(turn, booked_domains):
booked_domain_current = None
for domain in turn['metadata']:
if turn['metadata'][domain]["book"]["booked"] and domain not in booked_domains:
booked_domain_current = domain
booked_domain_current = domain.capitalize()
booked_domains.append(domain)
return booked_domains, booked_domain_current
......@@ -178,6 +212,7 @@ def preprocess():
val_list = set(open(f'{original_data_dir}/valListFile.txt').read().split())
test_list = set(open(f'{original_data_dir}/testListFile.txt').read().split())
new_sys_acts = dict()
errors = 0
for ori_dialog_id, ori_dialog in tqdm(original_data.items()):
......@@ -224,6 +259,14 @@ def preprocess():
keyword_domains_system, current_domains_system,
next_user_domains)
delex_acts = delexicalize_da(remapped_acts)
for act in delex_acts:
act = "-".join(act)
if act not in new_sys_acts:
new_sys_acts[act] = 1
else:
new_sys_acts[act] += 1
remapped_span_acts, _ = remap_acts(flattened_span_acts, current_domains_user,
booked_domain_current, keyword_domains_user,
keyword_domains_system, current_domains_system,
......@@ -242,9 +285,12 @@ def preprocess():
print("Errors:", errors)
json.dump(original_data, open(f'{new_data_dir}/data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
write_system_act_set(new_sys_acts)
with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf:
for filename in os.listdir(new_data_dir):
zf.write(f'{new_data_dir}/{filename}')
print("Saved new data.")
rmtree(original_data_dir)
rmtree(new_data_dir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment