Skip to content
Snippets Groups Projects
Commit a671faf3 authored by Christian's avatar Christian
Browse files

refactored remap_actions

parent b33beb54
No related branches found
No related tags found
No related merge requests found
......@@ -5,19 +5,59 @@ import os
from tqdm import tqdm
def get_keyword_domains(turn, domains=("Hotel", "Restaurant", "Train")):
    """Return the domains whose name is mentioned in the turn's text.

    Each candidate domain name is matched as a case-insensitive substring
    of ``turn['text']``.

    Args:
        turn: dialogue turn; its ``'text'`` entry holds the utterance string.
        domains: candidate domain names. Matches are reported with exactly
            the casing given here. Defaults to the three domains that were
            previously hard-coded.

    Returns:
        List of the matching domain names, in the order of ``domains``.
    """
    # Lower-case the utterance once instead of once per candidate domain.
    text = turn['text'].lower()
    return [d for d in domains if d.lower() in text]
def get_current_domains_from_act(dialog_acts):
    """Collect the distinct domains addressed by a turn's dialogue acts.

    Args:
        dialog_acts: iterable of ``"<domain>-<intent>"`` keys (typically the
            dialogue-act dict of a turn; an empty list is accepted too).

    Returns:
        List of domains in first-seen order, without duplicates. The
        pseudo-domains ``"general"`` and ``"Booking"`` are skipped.

    Raises:
        ValueError: if a key contains no ``'-'`` separator at all.
    """
    current_domains = []
    for dom_int in dialog_acts:
        # Split on the first hyphen only, so an intent that itself contains
        # a hyphen does not break the two-way unpacking.
        domain, _intent = dom_int.split('-', 1)
        if domain in ("general", "Booking"):
            continue
        if domain not in current_domains:
            current_domains.append(domain)
    return current_domains
def get_next_user_act_domains(ori_dialog, turn_id):
    """Return the domains of the turn that follows ``turn_id``.

    Args:
        ori_dialog: dialogue dict whose ``'log'`` entry is the list of turns.
        turn_id: index of the current turn in ``ori_dialog['log']``.

    Returns:
        The domains extracted by ``get_current_domains_from_act`` from the
        next turn's ``'dialog_act'``, or an empty list when there is no next
        turn or that turn has no ``'dialog_act'`` entry.
    """
    try:
        next_user_act = ori_dialog['log'][turn_id + 1]['dialog_act']
    except (IndexError, KeyError):
        # IndexError: the current turn is the last one of the dialogue.
        # KeyError: the next turn carries no dialogue-act annotation.
        # Only these lookup failures are expected; anything else (including
        # KeyboardInterrupt) must propagate instead of being swallowed.
        return []
    return get_current_domains_from_act(next_user_act)
def check_domain_booked(turn, booked_domains):
    """Detect domains that became booked in this turn.

    A domain counts as booked when ``turn['metadata'][domain]['book']['booked']``
    is truthy. Metadata entries without a ``'book'``/``'booked'`` field are
    treated as not booked instead of raising ``KeyError``.

    Args:
        turn: dialogue turn whose ``'metadata'`` maps domain names to their
            state dicts.
        booked_domains: list of domains already known to be booked. It is
            extended IN PLACE with every newly booked domain.

    Returns:
        Tuple ``(booked_domains, booked_domain_current)``: the same list
        object that was passed in, and the domain newly booked in this turn
        (``None`` if there is none). If several domains are newly booked,
        the last one in metadata order is reported.
    """
    booked_domain_current = None
    for domain, state in turn['metadata'].items():
        if domain in booked_domains:
            continue
        if state.get("book", {}).get("booked"):
            booked_domain_current = domain
            booked_domains.append(domain)
    return booked_domains, booked_domain_current
def flatten_acts(dialog_acts):
    """Flatten a dialogue-act mapping into (domain, intent, slot, value) tuples.

    Args:
        dialog_acts: mapping from ``"<domain>-<intent>"`` keys to a list of
            slot/value pairs (each pair is indexed as ``[slot, value]``).
            An empty container yields an empty result.

    Returns:
        List of ``(domain, intent, slot, value)`` tuples, in the iteration
        order of ``dialog_acts`` and of each pair list.

    Raises:
        ValueError: if a key contains no ``'-'`` separator.
    """
    # NOTE(review): the block previously also carried a `contains_booking`
    # flag and a second, unreachable `return`. The caller that forwards the
    # result directly to remap_acts expects just the list, so that is the
    # single return value kept here — confirm no caller still unpacks a tuple.
    flattened_acts = []
    for dom_int in dialog_acts:
        # Split on the first hyphen only so hyphenated intents stay intact.
        domain, intent = dom_int.split('-', 1)
        for slot_value in dialog_acts[dom_int]:
            flattened_acts.append((domain, intent, slot_value[0], slot_value[1]))
    return flattened_acts
def deflat_acts(flattened_acts):
......@@ -41,7 +81,7 @@ def remap_acts(flattened_acts, current_domains, booked_domain=None, keyword_doma
error = 0
remapped_acts = []
# if there are more current domains, we try to get booked domain from booked_domain
# if there is more than one current domain or none at all, we try to get booked domain differently
if len(current_domains) != 1 and booked_domain:
current_domains = [booked_domain]
elif len(current_domains) != 1 and len(keyword_domains_user) == 1:
......@@ -74,7 +114,7 @@ def remap_acts(flattened_acts, current_domains, booked_domain=None, keyword_doma
remapped_acts.append((domain, "Book", "none", "none"))
remapped_acts.append((domain, intent, slot, value))
elif f"{domain}-{intent}-{slot}" in ["Train-Inform-Ref", "Train-OfferBooked-Ref"]:
# train-inform-ref actually triggers the booking and informs on the reference number
# train-inform/offerbooked-ref actually triggers the booking and informs on the reference number
remapped_acts.append((domain, "Book", "none", "none"))
remapped_acts.append((domain, "Inform", slot, value))
elif domain == "Train" and intent == "OfferBook":
......@@ -116,7 +156,6 @@ def preprocess():
val_list = set(open(f'{original_data_dir}/valListFile.txt').read().split())
test_list = set(open(f'{original_data_dir}/testListFile.txt').read().split())
booking_counter = 0
errors = 0
for ori_dialog_id, ori_dialog in tqdm(original_data.items()):
......@@ -129,84 +168,37 @@ def preprocess():
# add information to which split the dialogue belongs
ori_dialog['split'] = split
current_domains_after = []
current_domains_user = []
current_domains_system = []
booked_domains = []
current_domain_system = []
for turn_id, turn in enumerate(ori_dialog['log']):
# if it is a user turn, try to extract the current domain
if turn_id % 2 == 0:
dialog_acts = turn.get('dialog_act', [])
keyword_domains_user = []
text = turn['text']
for d in ["hotel", "restaurant", "train"]:
if d in text.lower():
keyword_domains_user.append(d)
if len(dialog_acts) != 0:
current_domains_after_ = []
for dom_int in dialog_acts.keys():
domain, intent = dom_int.split('-')
if domain == "general":
continue
if domain not in current_domains_after_:
current_domains_after_.append(domain)
if len(current_domains_after_) > 0:
current_domains_after = current_domains_after_
keyword_domains_user = get_keyword_domains(turn)
current_domains_temp = get_current_domains_from_act(dialog_acts)
current_domains_user = current_domains_temp if current_domains_temp else current_domains_user
else:
# get next user act
next_user_domain = []
try:
next_user_act = ori_dialog['log'][turn_id + 1]['dialog_act']
for dom_int in next_user_act.keys():
domain, intent = dom_int.split('-')
if domain == "general":
continue
if domain not in next_user_domain:
next_user_domain.append(domain)
except:
# will fail if system act is the last act of the dialogue
pass
# it is a system turn, we now want to map the actions
keyword_domains_system = []
text = turn['text']
for d in ["hotel", "restaurant", "train"]:
if d in text.lower():
keyword_domains_system.append(d)
booked_domain_current = None
for domain in turn['metadata']:
if turn['metadata'][domain]["book"]["booked"] and domain not in booked_domains:
booked_domain_current = domain
booked_domains.append(domain)
dialog_acts = turn.get('dialog_act', [])
if dialog_acts:
# only need to go through that process if we have a dialogue act
spans = turn.get('span_info', [])
flattened_acts, contains_booking = flatten_acts(dialog_acts)
current_domain_temp = []
for a in flattened_acts:
d = a[0]
if d not in ["general", "Booking"] and d not in current_domain_temp:
current_domain_temp.append(d)
if len(current_domain_temp) != 0:
current_domain_system = current_domain_temp
keyword_domains_system = get_keyword_domains(turn)
current_domains_temp = get_current_domains_from_act(dialog_acts)
current_domains_system = current_domains_temp if current_domains_temp else current_domains_system
if contains_booking:
booking_counter += 1
booked_domains, booked_domain_current = check_domain_booked(turn, booked_domains)
next_user_domains = get_next_user_act_domains(ori_dialog, turn_id)
remapped_acts, error_local = remap_acts(flattened_acts, current_domains_after,
flattened_acts = flatten_acts(dialog_acts)
remapped_acts, error_local = remap_acts(flattened_acts, current_domains_user,
booked_domain_current, keyword_domains_user,
keyword_domains_system, current_domain_system,
next_user_domain)
keyword_domains_system, current_domains_system,
next_user_domains)
errors += error_local
if error_local > 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment