Skip to content
Snippets Groups Projects
Commit 8b96d4e1 authored by Christian's avatar Christian
Browse files

first version of booking_remapper, there are some assertion errors due to...

first version of booking_remapper, there are some assertion errors due to incorrect labelling of data
parent 027f3425
Branches
No related tags found
No related merge requests found
class BookingActRemapper:
def __init__(self):
self.reset()
def reset(self):
self.current_domains_user = []
self.current_domains_system = []
self.booked_domains = []
def retrieve_current_domain_from_user(self, turn_id, ori_dialog):
prev_user_turn = ori_dialog[turn_id - 1]
dialog_acts = prev_user_turn.get('dialog_act', [])
keyword_domains_user = get_keyword_domains(prev_user_turn)
current_domains_temp = get_current_domains_from_act(dialog_acts)
self.current_domains_user = current_domains_temp if current_domains_temp else self.current_domains_user
next_user_domains = get_next_user_act_domains(ori_dialog, turn_id)
return keyword_domains_user, next_user_domains
def retrieve_current_domain_from_system(self, turn_id, ori_dialog):
system_turn = ori_dialog[turn_id]
dialog_acts = system_turn.get('dialog_act', [])
keyword_domains_system = get_keyword_domains(system_turn)
current_domains_temp = get_current_domains_from_act(dialog_acts)
self.current_domains_system = current_domains_temp if current_domains_temp else self.current_domains_system
self.booked_domains, booked_domain_current = check_domain_booked(system_turn, self.booked_domains)
return keyword_domains_system, booked_domain_current
def remap(self, turn_id, ori_dialog):
# only need to remap if there is a dialog action labelled
dialog_acts = ori_dialog[turn_id].get('dialog_act', [])
spans = ori_dialog[turn_id].get('span_info', [])
if ori_dialog[turn_id].get('dialog_act', []):
keyword_domains_user, next_user_domains = self.retrieve_current_domain_from_user(turn_id, ori_dialog)
keyword_domains_system, booked_domain_current = self.retrieve_current_domain_from_system(turn_id, ori_dialog)
flattened_acts = flatten_acts(dialog_acts)
flattened_spans = flatten_span_acts(spans)
remapped_acts, error_local = remap_acts(flattened_acts, self.current_domains_user,
booked_domain_current, keyword_domains_user,
keyword_domains_system, self.current_domains_system,
next_user_domains)
remapped_spans, _ = remap_acts(flattened_spans, self.current_domains_user,
booked_domain_current, keyword_domains_user,
keyword_domains_system, self.current_domains_system,
next_user_domains)
deflattened_remapped_acts = deflat_acts(remapped_acts)
deflattened_remapped_spans = deflat_span_acts(remapped_spans)
return deflattened_remapped_acts, deflattened_remapped_spans
else:
return dialog_acts, spans
def get_keyword_domains(turn):
keyword_domains = []
text = turn['text']
for d in ["Hotel", "Restaurant", "Train"]:
if d.lower() in text.lower():
keyword_domains.append(d)
return keyword_domains
def get_current_domains_from_act(dialog_acts):
current_domains_temp = []
for dom_int in dialog_acts:
domain, intent = dom_int.split('-')
if domain in ["general", "Booking"]:
continue
if domain not in current_domains_temp:
current_domains_temp.append(domain)
return current_domains_temp
def get_next_user_act_domains(ori_dialog, turn_id):
domains = []
try:
next_user_act = ori_dialog['log'][turn_id + 1]['dialog_act']
domains = get_current_domains_from_act(next_user_act)
except:
# will fail if system act is the last act of the dialogue
pass
return domains
def check_domain_booked(turn, booked_domains):
booked_domain_current = None
for domain in turn['metadata']:
if turn['metadata'][domain]["book"]["booked"] and domain not in booked_domains:
booked_domain_current = domain.capitalize()
booked_domains.append(domain)
return booked_domains, booked_domain_current
def flatten_acts(dialog_acts):
flattened_acts = []
for dom_int in dialog_acts:
domain, intent = dom_int.split('-')
for slot_value in dialog_acts[dom_int]:
slot = slot_value[0]
value = slot_value[1]
flattened_acts.append((domain, intent, slot, value))
return flattened_acts
def flatten_span_acts(span_acts):
flattened_acts = []
for span_act in span_acts:
domain, intent = span_act[0].split("-")
flattened_acts.append((domain, intent, span_act[1], span_act[2:]))
return flattened_acts
def deflat_acts(flattened_acts):
dialog_acts = dict()
for act in flattened_acts:
domain, intent, slot, value = act
if f"{domain}-{intent}" not in dialog_acts.keys():
dialog_acts[f"{domain}-{intent}"] = [[slot, value]]
else:
dialog_acts[f"{domain}-{intent}"].append([slot, value])
return dialog_acts
def deflat_span_acts(flattened_acts):
dialog_span_acts = []
for act in flattened_acts:
domain, intent, slot, value = act
if value == 'none':
continue
new_act = [f"{domain}-{intent}", slot]
new_act.extend(value)
dialog_span_acts.append(new_act)
return dialog_span_acts
def remap_acts(flattened_acts, current_domains, booked_domain=None, keyword_domains_user=None,
keyword_domains_system=None, current_domain_system=None, next_user_domain=None):
# We now look for all cases that can happen: Booking domain, Booking within a domain or taxi-inform-car for booking
error = 0
remapped_acts = []
# if there is more than one current domain or none at all, we try to get booked domain differently
if len(current_domains) != 1 and booked_domain:
current_domains = [booked_domain]
elif len(current_domains) != 1 and len(keyword_domains_user) == 1:
current_domains = keyword_domains_user
elif len(current_domains) != 1 and len(keyword_domains_system) == 1:
current_domains = keyword_domains_system
elif len(current_domains) != 1 and len(current_domain_system) == 1:
current_domains = current_domain_system
elif len(current_domains) != 1 and len(next_user_domain) == 1:
current_domains = next_user_domain
for act in flattened_acts:
try:
domain, intent, slot, value = act
if f"{domain}-{intent}-{slot}" == "Booking-Book-Ref":
# We need to remap that booking act now
assert len(current_domains) == 1, "Can not resolve booking-book act because there are more current domains"
remapped_acts.append((current_domains[0], "Book", "none", "none"))
remapped_acts.append((current_domains[0], "Inform", "Ref", value))
elif domain == "Booking" and intent == "Book" and slot != "Ref":
# the book intent is here actually an inform intent according to the data
remapped_acts.append((current_domains[0], "Inform", slot, value))
elif domain == "Booking" and intent == "Inform":
# the inform intent is here actually a request intent according to the data
remapped_acts.append((current_domains[0], "OfferBook", slot, value))
elif domain == "Booking" and intent in ["NoBook", "Request"]:
remapped_acts.append((current_domains[0], intent, slot, value))
elif f"{domain}-{intent}-{slot}" == "Taxi-Inform-Car":
# taxi-inform-car actually triggers the booking and informs on a car
remapped_acts.append((domain, "Book", "none", "none"))
remapped_acts.append((domain, intent, slot, value))
elif f"{domain}-{intent}-{slot}" in ["Train-Inform-Ref", "Train-OfferBooked-Ref"]:
# train-inform/offerbooked-ref actually triggers the booking and informs on the reference number
remapped_acts.append((domain, "Book", "none", "none"))
remapped_acts.append((domain, "Inform", slot, value))
elif domain == "Train" and intent == "OfferBooked" and slot != "Ref":
# this is actually an inform act
remapped_acts.append((domain, "Inform", slot, value))
else:
remapped_acts.append(act)
except Exception as e:
print("Error detected:", e)
error += 1
return remapped_acts, error
\ No newline at end of file
...@@ -8,6 +8,7 @@ from tqdm import tqdm ...@@ -8,6 +8,7 @@ from tqdm import tqdm
from collections import Counter from collections import Counter
from pprint import pprint from pprint import pprint
from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
from data.unified_datasets.multiwoz21.booking_remapper import BookingActRemapper
init_ontology = { init_ontology = {
"domains": { # descriptions are adapted from multiwoz22, but is_categorical may be different "domains": { # descriptions are adapted from multiwoz22, but is_categorical may be different
...@@ -765,6 +766,7 @@ def preprocess(): ...@@ -765,6 +766,7 @@ def preprocess():
dialogues_by_split = {split:[] for split in splits} dialogues_by_split = {split:[] for split in splits}
sent_tokenizer = PunktSentenceTokenizer() sent_tokenizer = PunktSentenceTokenizer()
word_tokenizer = TreebankWordTokenizer() word_tokenizer = TreebankWordTokenizer()
booking_remapper = BookingActRemapper()
for ori_dialog_id, ori_dialog in tqdm(original_data.items()): for ori_dialog_id, ori_dialog in tqdm(original_data.items()):
if ori_dialog_id in val_list: if ori_dialog_id in val_list:
split = 'validation' split = 'validation'
...@@ -811,6 +813,7 @@ def preprocess(): ...@@ -811,6 +813,7 @@ def preprocess():
'turns': [] 'turns': []
} }
booking_remapper.reset()
for turn_id, turn in enumerate(ori_dialog['log']): for turn_id, turn in enumerate(ori_dialog['log']):
# correct some grammar errors in the text, mainly following `tokenization.md` in MultiWOZ_2.1 # correct some grammar errors in the text, mainly following `tokenization.md` in MultiWOZ_2.1
text = turn['text'] text = turn['text']
...@@ -827,6 +830,11 @@ def preprocess(): ...@@ -827,6 +830,11 @@ def preprocess():
das = turn.get('dialog_act', []) das = turn.get('dialog_act', [])
spans = turn.get('span_info', []) spans = turn.get('span_info', [])
if speaker == 'system':
das, spans = booking_remapper.remap(turn_id, ori_dialog['log'])
print(ori_dialog['log'][turn_id])
da_dict = {} da_dict = {}
# transform DA # transform DA
for Domain_Intent in das: for Domain_Intent in das:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment