diff --git a/.gitignore b/.gitignore index a2820f1eab2c138e99371195495312a75e1bcb0e..4e467920e8bcde078b70436a7297f84669d54c9a 100644 --- a/.gitignore +++ b/.gitignore @@ -69,7 +69,8 @@ convlab2.egg-info # configs - +*experiment* +*pretrained_models* .ipynb_checkpoints ## dst files diff --git a/convlab2/dialog_agent/agent.py b/convlab2/dialog_agent/agent.py index 1afbc936c036e877b1f7ca21ea585a73883d613c..2feed5ad6e0aebe0fea218d3801c9b1a150aea6f 100755 --- a/convlab2/dialog_agent/agent.py +++ b/convlab2/dialog_agent/agent.py @@ -196,14 +196,8 @@ class PipelineAgent(Agent): for intent, domain, slot, value in self.output_action: if domain.lower() not in ['general', 'booking']: self.cur_domain = domain - dial_act = f'{domain.lower()}-{intent.lower()}-{slot.lower()}' - if dial_act == 'booking-book-ref' and self.cur_domain.lower() in ['hotel', 'restaurant', 'train']: - if self.cur_domain: - self.dst.state['belief_state'][self.cur_domain.lower()]['book']['booked'] = [{slot.lower():value}] - elif dial_act == 'train-offerbooked-ref' or dial_act == 'train-inform-ref': - self.dst.state['belief_state']['train']['book']['booked'] = [{slot.lower():value}] - elif dial_act == 'taxi-inform-car': - self.dst.state['belief_state']['taxi']['book']['booked'] = [{slot.lower():value}] + if intent == "book": + self.dst.state['belief_state'][domain.lower()]['book']['booked'] = [{slot.lower(): value}] else: self.dst.state['user_action'] = self.output_action # user dst is also updated by itself diff --git a/convlab2/evaluator/multiwoz_eval.py b/convlab2/evaluator/multiwoz_eval.py index 202e248f217cf7f42f281a398c5f45c32068299a..0ec4aeedee5b9bdec736bc05cf490898959cff59 100755 --- a/convlab2/evaluator/multiwoz_eval.py +++ b/convlab2/evaluator/multiwoz_eval.py @@ -111,21 +111,18 @@ class MultiWozEvaluator(Evaluator): value = str(value) self.sys_da_array.append(da + '-' + value) - if da == 'booking-book-ref' and self.cur_domain in ['hotel', 'restaurant', 'train']: - if not self.booked[self.cur_domain] and re.match(r'^\d{8}$', value) and \ - len(self.dbs[self.cur_domain]) > int(value): - self.booked[self.cur_domain] = self.dbs[self.cur_domain][int( - value)].copy() - self.booked[self.cur_domain]['Ref'] = value - self.booked_states[self.cur_domain] = belief_state[self.cur_domain] - elif da == 'train-offerbooked-ref' or da == 'train-inform-ref': - if not self.booked['train'] and re.match(r'^\d{8}$', value) and len(self.dbs['train']) > int(value): - self.booked['train'] = self.dbs['train'][int(value)].copy() - self.booked['train']['Ref'] = value - self.booked_states[self.cur_domain] = belief_state[self.cur_domain] - elif da == 'taxi-inform-car': - if not self.booked['taxi']: - self.booked['taxi'] = 'booked' + # new booking actions make life easier + if intent.lower() == "book": + # taxi has no DB queries + if domain.lower() == "taxi": + if not self.booked['taxi']: + self.booked['taxi'] = 'booked' + else: + if not self.booked[domain] and re.match(r'^\d{8}$', value) and \ + len(self.dbs[domain]) > int(value): + self.booked[domain] = self.dbs[domain][int(value)].copy() + self.booked[domain]['Ref'] = value + self.booked_states[domain] = belief_state[domain] def add_usr_da(self, da_turn): """add usr_da into array diff --git a/convlab2/policy/evaluate.py b/convlab2/policy/evaluate.py index 471e6c8ea487e1d7df5a3f906e8adad5f4a6c468..da5d184f044cf5e7dcf11b15d3ebd61b078fb557 100755 --- a/convlab2/policy/evaluate.py +++ b/convlab2/policy/evaluate.py @@ -9,6 +9,7 @@ import json import logging import os import random +from convlab2.policy.vector.vector_multiwoz import MultiWozVector import numpy as np import torch @@ -167,7 +168,7 @@ def evaluate(args, dataset_name, model_name, load_path, calculate_reward=True, v if model_name == "PPO": from convlab2.policy.ppo import PPO if load_path: - policy_sys = PPO(False) + policy_sys = PPO(False, vectorizer=MultiWozVector()) policy_sys.load(load_path) else: policy_sys = PPO.from_pretrained() diff --git a/convlab2/policy/mle/loader.py b/convlab2/policy/mle/loader.py index 8bea164bb00fc5827153ba0c50b61f963257b57e..349783d8262d7dc137b3541023f069bba2a6772a 100755 --- a/convlab2/policy/mle/loader.py +++ b/convlab2/policy/mle/loader.py @@ -16,10 +16,12 @@ class ActMLEPolicyDataLoader: def _build_data(self, root_dir, processed_dir): self.data = {} + print("Initialise DataLoader") data_loader = ActPolicyDataloader(dataset_dataloader=MultiWOZDataloader()) + raw_data_all = data_loader.load_data(data_key='all', role='sys') for part in ['train', 'val', 'test']: self.data[part] = [] - raw_data = data_loader.load_data(data_key=part, role='sys')[part] + raw_data = raw_data_all[part] for belief_state, context_dialog_act, terminated, dialog_act, goal in \ zip(raw_data['belief_state'], raw_data['context_dialog_act'], raw_data['terminated'], diff --git a/convlab2/policy/mle/multiwoz/loader.py b/convlab2/policy/mle/multiwoz/loader.py index bc9467613bb49685af612699a096f99e9d8d9d28..13ea8f35dcd10eb2b970487c00eee22bb2408a01 100755 --- a/convlab2/policy/mle/multiwoz/loader.py +++ b/convlab2/policy/mle/multiwoz/loader.py @@ -7,13 +7,11 @@ class ActMLEPolicyDataLoaderMultiWoz(ActMLEPolicyDataLoader): def __init__(self, vectoriser=None): root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) - voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt') - voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt') if vectoriser: self.vector = vectoriser else: print("We use vanilla Vectoriser") - self.vector = MultiWozVector(voc_file, voc_opp_file) + self.vector = MultiWozVector() processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'processed_data') diff --git a/convlab2/policy/mle/multiwoz/mle.py b/convlab2/policy/mle/multiwoz/mle.py index f0377524be1dfadcd56ba1bee2a5c81e0bd1a841..b614b55e733c6b74565a2cdf1079bf3857869dd5 100755 --- a/convlab2/policy/mle/multiwoz/mle.py +++ b/convlab2/policy/mle/multiwoz/mle.py @@ -11,6 +11,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models") DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "mle_policy_multiwoz.zip") + class MLE(MLEAbstract): def __init__(self): diff --git a/convlab2/policy/ppo/semantic_level_config.json b/convlab2/policy/ppo/semantic_level_config.json index 095a21b5bc88ee4dcc515b0ec63b35d0ab651624..24cfeb577eed03c4e808a8343c49539760b9ad9e 100644 --- a/convlab2/policy/ppo/semantic_level_config.json +++ b/convlab2/policy/ppo/semantic_level_config.json @@ -1,7 +1,7 @@ { "model": { - "load_path": "convlab2/policy/mle/multiwoz/experiment_2021-12-15-11-12-07/save/supervised", - "use_pretrained_initialisation": true, + "load_path": "convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/save/supervised", + "use_pretrained_initialisation": false, "pretrained_load_path": "", "batchsz": 1000, "seed": 0, diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py index f5c5aa58d1d9d8c769ee0445dfff833e42c96cb7..b2778a31925aaee64374b17975899d54a1490a62 100755 --- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -486,7 +486,7 @@ class Agenda(object): continue slot_vals = sys_action[diaact] - if 'booking' in diaact: + if 'book' in diaact: if self.update_booking(diaact, slot_vals, goal): return elif 'general' in diaact: @@ -502,7 +502,7 @@ class Agenda(object): if slot == 'name': self._remove_item(diaact.split( '-')[0]+'-inform', 'choice') - if 'booking' in diaact and self.cur_domain: + if 'book' in diaact and self.cur_domain: g_book = self._get_goal_infos(self.cur_domain, goal)[-2] if len(g_book) == 0: self._push_item(self.cur_domain + @@ -533,8 +533,7 @@ class Agenda(object): :param goal: Goal :return: True:user want to close the session. False:session is continue """ - _, intent = diaact.split('-') - domain = self.cur_domain + domain, intent = diaact.split('-') self.domains['update_booking'] = domain isover = False if domain not in goal.domains: @@ -682,7 +681,7 @@ class Agenda(object): self._push_item(domain + '-inform', slot, g_book[slot]) info_right = False - if intent in ['book', 'offerbooked'] and info_right: + if intent in ['book'] and info_right: # booked ok if 'booked' in goal.domain_goals[domain]: goal.domain_goals[domain]['booked'] = DEF_VAL_BOOKED diff --git a/convlab2/policy/vector/vector_base.py b/convlab2/policy/vector/vector_base.py index 48e0529e9868a6e9baaeb84103c8d09422a69175..d040e6e5db3be49e9d4b54500d73b403f23d0082 100644 --- a/convlab2/policy/vector/vector_base.py +++ b/convlab2/policy/vector/vector_base.py @@ -57,7 +57,7 @@ class MultiWozVectorBase(Vector): if not voc_file or not voc_opp_file: voc_file = os.path.join( - root_dir, 'data/multiwoz/sys_da_voc.txt') + root_dir, 'data/multiwoz/sys_da_voc_remapped.txt') voc_opp_file = os.path.join( root_dir, 'data/multiwoz/usr_da_voc.txt') diff --git a/convlab2/util/dataloader/dataset_dataloader.py b/convlab2/util/dataloader/dataset_dataloader.py index 12c06b1446618421aec865e15335cb33a39f8114..5387320f7040731638d10c66c0ab242c34712a17 100755 --- a/convlab2/util/dataloader/dataset_dataloader.py +++ b/convlab2/util/dataloader/dataset_dataloader.py @@ -3,6 +3,7 @@ Dataloader base class. Every dataset should inherit this class and implement its """ from abc import ABC, abstractmethod import os +from zipfile import ZipFile import json import sys import zipfile @@ -69,57 +70,57 @@ class MultiWOZDataloader(DatasetDataloader): 'terminated', 'goal'])) self.data = {'train': {}, 'val': {}, 'test': {}, 'role': role, 'human_val': {}} - if data_key == 'all': - data_key_list = ['train', 'val', 'test'] - else: - data_key_list = [data_key] - for data_key in data_key_list: - data = read_zipped_json(os.path.join(data_dir, '{}.json.zip'.format(data_key)), '{}.json'.format(data_key)) - print('loaded {}, size {}'.format(data_key, len(data))) + + archive = ZipFile(os.path.join(data_dir, 'data.zip')) + archive.extractall() + data = json.load(open(os.path.join(data_dir, 'data/data.json'))) + + for k in ['train', 'test', 'val']: for x in info_list: - self.data[data_key][x] = [] - for sess_id, sess in data.items(): - cur_context = [] - cur_context_dialog_act = [] - entity_booked_dict = dict((domain, False) for domain in belief_domains) - for i, turn in enumerate(sess['log']): - text = turn['text'] - da = da2tuples(turn['dialog_act']) - if role == 'sys' and i % 2 == 0: - cur_context.append(text) - cur_context_dialog_act.append(da) - continue - elif role == 'usr' and i % 2 == 1: - cur_context.append(text) - cur_context_dialog_act.append(da) - continue - if utterance: - self.data[data_key]['utterance'].append(text) - if dialog_act: - self.data[data_key]['dialog_act'].append(da) - if context: - self.data[data_key]['context'].append(cur_context[-context_window_size:]) - if context_dialog_act: - self.data[data_key]['context_dialog_act'].append(cur_context_dialog_act[-context_window_size:]) - if belief_state: - entity_booked_dict, fixed_bs = self.fix_entity_booked_info(entity_booked_dict, turn['metadata']) - self.data[data_key]['belief_state'].append(fixed_bs) - if last_opponent_utterance: - self.data[data_key]['last_opponent_utterance'].append( - cur_context[-1] if len(cur_context) >= 1 else '') - if last_self_utterance: - self.data[data_key]['last_self_utterance'].append( - cur_context[-2] if len(cur_context) >= 2 else '') - if session_id: - self.data[data_key]['session_id'].append(sess_id) - if span_info: - self.data[data_key]['span_info'].append(turn['span_info']) - if terminated: - self.data[data_key]['terminated'].append(i + 2 >= len(sess['log'])) - if goal: - self.data[data_key]['goal'].append(sess['goal']) + self.data[k][x] = [] + for sess_id, sess in data.items(): + data_key = sess['split'] + cur_context = [] + cur_context_dialog_act = [] + entity_booked_dict = dict((domain, False) for domain in belief_domains) + for i, turn in enumerate(sess['log']): + text = turn['text'] + da = da2tuples(turn.get('dialog_act', {})) + if role == 'sys' and i % 2 == 0: + cur_context.append(text) + cur_context_dialog_act.append(da) + continue + elif role == 'usr' and i % 2 == 1: cur_context.append(text) cur_context_dialog_act.append(da) + continue + if utterance: + self.data[data_key]['utterance'].append(text) + if dialog_act: + self.data[data_key]['dialog_act'].append(da) + if context: + self.data[data_key]['context'].append(cur_context[-context_window_size:]) + if context_dialog_act: + self.data[data_key]['context_dialog_act'].append(cur_context_dialog_act[-context_window_size:]) + if belief_state: + entity_booked_dict, fixed_bs = self.fix_entity_booked_info(entity_booked_dict, turn['metadata']) + self.data[data_key]['belief_state'].append(fixed_bs) + if last_opponent_utterance: + self.data[data_key]['last_opponent_utterance'].append( + cur_context[-1] if len(cur_context) >= 1 else '') + if last_self_utterance: + self.data[data_key]['last_self_utterance'].append( + cur_context[-2] if len(cur_context) >= 2 else '') + if session_id: + self.data[data_key]['session_id'].append(sess_id) + if span_info: + self.data[data_key]['span_info'].append(turn['span_info']) + if terminated: + self.data[data_key]['terminated'].append(i + 2 >= len(sess['log'])) + if goal: + self.data[data_key]['goal'].append(sess['goal']) + cur_context.append(text) + cur_context_dialog_act.append(da) if ontology: ontology_path = os.path.join(data_dir, 'ontology.json') self.data['ontology'] = json.load(open(ontology_path)) diff --git a/convlab2/util/multiwoz/lexicalize.py b/convlab2/util/multiwoz/lexicalize.py index 427a54ababc12c9d60a49e8f6a7d265ae902ba87..3f798e46a8e70a38469047cde07ba4a1525b0074 100755 --- a/convlab2/util/multiwoz/lexicalize.py +++ b/convlab2/util/multiwoz/lexicalize.py @@ -76,6 +76,14 @@ def lexicalize_da(meta, entities, state, requestable, cur_domain=None): else: pair[1] = 'none' else: + if intent.lower() == "book": + for pair in v: + if len(entities[domain]) > 0: + slot = REF_SYS_DA[domain].get('Ref', 'Ref') + if slot in entities[domain][0]: + pair[1] = entities[domain][0][slot] + continue + if domain.lower() in ['booking']: if cur_domain and cur_domain in entities: domain = cur_domain diff --git a/data/multiwoz/remap_actions.py b/data/multiwoz/remap_actions.py new file mode 100644 index 0000000000000000000000000000000000000000..097ebce640435df04a57bc7757b7dbfff47f702e --- /dev/null +++ b/data/multiwoz/remap_actions.py @@ -0,0 +1,299 @@ +from zipfile import ZipFile, ZIP_DEFLATED +from shutil import copy2, rmtree +import json +import os +from tqdm import tqdm + +MIN_OCCURENCE_ACT = 50 + + +def write_system_act_set(new_sys_acts): + + with open("sys_da_voc_remapped.txt", "w") as f: + new_sys_acts_list = [] + for act in new_sys_acts: + if new_sys_acts[act] > MIN_OCCURENCE_ACT: + new_sys_acts_list.append(act) + + new_sys_acts_list.sort() + for act in new_sys_acts_list: + f.write(act + "\n") + print("Saved new action dict.") + + +def delexicalize_da(da): + delexicalized_da = [] + counter = {} + for domain, intent, slot, value in da: + if intent.lower() in ["request"]: + v = '?' + else: + if slot == 'none': + v = 'none' + else: + k = '-'.join([intent, domain, slot]) + counter.setdefault(k, 0) + counter[k] += 1 + v = str(counter[k]) + delexicalized_da.append([domain, intent, slot, v]) + return delexicalized_da + + +def get_keyword_domains(turn): + keyword_domains = [] + text = turn['text'] + for d in ["Hotel", "Restaurant", "Train"]: + if d.lower() in text.lower(): + keyword_domains.append(d) + return keyword_domains + + +def get_current_domains_from_act(dialog_acts): + + current_domains_temp = [] + for dom_int in dialog_acts: + domain, intent = dom_int.split('-') + if domain in ["general", "Booking"]: + continue + if domain not in current_domains_temp: + current_domains_temp.append(domain) + + return current_domains_temp + + +def get_next_user_act_domains(ori_dialog, turn_id): + domains = [] + try: + next_user_act = ori_dialog['log'][turn_id + 1]['dialog_act'] + domains = get_current_domains_from_act(next_user_act) + except: + # will fail if system act is the last act of the dialogue + pass + return domains + + +def check_domain_booked(turn, booked_domains): + + booked_domain_current = None + for domain in turn['metadata']: + if turn['metadata'][domain]["book"]["booked"] and domain not in booked_domains: + booked_domain_current = domain.capitalize() + booked_domains.append(domain) + return booked_domains, booked_domain_current + + +def flatten_acts(dialog_acts): + flattened_acts = [] + for dom_int in dialog_acts: + domain, intent = dom_int.split('-') + for slot_value in dialog_acts[dom_int]: + slot = slot_value[0] + value = slot_value[1] + flattened_acts.append((domain, intent, slot, value)) + + return flattened_acts + + +def flatten_span_acts(span_acts): + + flattened_acts = [] + for span_act in span_acts: + domain, intent = span_act[0].split("-") + flattened_acts.append((domain, intent, span_act[1], span_act[2:])) + return flattened_acts + + +def deflat_acts(flattened_acts): + + dialog_acts = dict() + + for act in flattened_acts: + domain, intent, slot, value = act + if f"{domain}-{intent}" not in dialog_acts.keys(): + dialog_acts[f"{domain}-{intent}"] = [[slot, value]] + else: + dialog_acts[f"{domain}-{intent}"].append([slot, value]) + + return dialog_acts + + +def deflat_span_acts(flattened_acts): + + dialog_span_acts = [] + for act in flattened_acts: + domain, intent, slot, value = act + if value == 'none': + continue + new_act = [f"{domain}-{intent}", slot] + new_act.extend(value) + dialog_span_acts.append(new_act) + + return dialog_span_acts + + +def remap_acts(flattened_acts, current_domains, booked_domain=None, keyword_domains_user=None, + keyword_domains_system=None, current_domain_system=None, next_user_domain=None): + + # We now look for all cases that can happen: Booking domain, Booking within a domain or taxi-inform-car for booking + error = 0 + remapped_acts = [] + + # if there is more than one current domain or none at all, we try to get booked domain differently + if len(current_domains) != 1 and booked_domain: + current_domains = [booked_domain] + elif len(current_domains) != 1 and len(keyword_domains_user) == 1: + current_domains = keyword_domains_user + elif len(current_domains) != 1 and len(keyword_domains_system) == 1: + current_domains = keyword_domains_system + elif len(current_domains) != 1 and len(current_domain_system) == 1: + current_domains = current_domain_system + elif len(current_domains) != 1 and len(next_user_domain) == 1: + current_domains = next_user_domain + + for act in flattened_acts: + try: + domain, intent, slot, value = act + if f"{domain}-{intent}-{slot}" == "Booking-Book-Ref": + # We need to remap that booking act now + assert len(current_domains) == 1, "Can not resolve booking-book act because there are more current domains" + remapped_acts.append((current_domains[0], "Book", "none", "none")) + remapped_acts.append((current_domains[0], "Inform", "Ref", value)) + elif domain == "Booking" and intent == "Book" and slot != "Ref": + # the book intent is here actually an inform intent according to the data + remapped_acts.append((current_domains[0], "Inform", slot, value)) + elif domain == "Booking" and intent == "Inform": + # the inform intent is here actually a request intent according to the data + remapped_acts.append((current_domains[0], "RequestBook", slot, value)) + elif domain == "Booking" and intent in ["NoBook", "Request"]: + remapped_acts.append((current_domains[0], intent, slot, value)) + elif f"{domain}-{intent}-{slot}" == "Taxi-Inform-Car": + # taxi-inform-car actually triggers the booking and informs on a car + remapped_acts.append((domain, "Book", "none", "none")) + remapped_acts.append((domain, intent, slot, value)) + elif f"{domain}-{intent}-{slot}" in ["Train-Inform-Ref", "Train-OfferBooked-Ref"]: + # train-inform/offerbooked-ref actually triggers the booking and informs on the reference number + remapped_acts.append((domain, "Book", "none", "none")) + remapped_acts.append((domain, "Inform", slot, value)) + elif domain == "Train" and intent == "OfferBook": + # make offerbook consistent with RequestBook above + remapped_acts.append(("Train", "RequestBook", slot, value)) + elif domain == "Train" and intent == "OfferBooked" and slot != "Ref": + # this is actually an inform act + remapped_acts.append((domain, "Inform", slot, value)) + else: + remapped_acts.append(act) + except Exception as e: + print("Error detected:", e) + error += 1 + + return remapped_acts, error + + +def preprocess(): + original_data_dir = 'MultiWOZ_2.1' + new_data_dir = 'data' + + if not os.path.exists(original_data_dir): + original_data_zip = 'MultiWOZ_2.1.zip' + if not os.path.exists(original_data_zip): + raise FileNotFoundError( + f'cannot find original data {original_data_zip} in multiwoz21/, should manually download MultiWOZ_2.1.zip from https://github.com/budzianowski/multiwoz/blob/master/data/MultiWOZ_2.1.zip') + else: + archive = ZipFile(original_data_zip) + archive.extractall() + + os.makedirs(new_data_dir, exist_ok=True) + for filename in os.listdir(original_data_dir): + if 'db' in filename: + copy2(f'{original_data_dir}/{filename}', new_data_dir) + + original_data = json.load(open(f'{original_data_dir}/data.json')) + + val_list = set(open(f'{original_data_dir}/valListFile.txt').read().split()) + test_list = set(open(f'{original_data_dir}/testListFile.txt').read().split()) + + new_sys_acts = dict() + errors = 0 + + for ori_dialog_id, ori_dialog in tqdm(original_data.items()): + if ori_dialog_id in val_list: + split = 'val' + elif ori_dialog_id in test_list: + split = 'test' + else: + split = 'train' + + # add information to which split the dialogue belongs + ori_dialog['split'] = split + current_domains_user = [] + current_domains_system = [] + booked_domains = [] + + for turn_id, turn in enumerate(ori_dialog['log']): + + # if it is a user turn, try to extract the current domain + if turn_id % 2 == 0: + dialog_acts = turn.get('dialog_act', []) + + keyword_domains_user = get_keyword_domains(turn) + current_domains_temp = get_current_domains_from_act(dialog_acts) + current_domains_user = current_domains_temp if current_domains_temp else current_domains_user + else: + + dialog_acts = turn.get('dialog_act', []) + span_acts = turn.get('span_info', []) + if dialog_acts: + # only need to go through that process if we have a dialogue act + + keyword_domains_system = get_keyword_domains(turn) + current_domains_temp = get_current_domains_from_act(dialog_acts) + current_domains_system = current_domains_temp if current_domains_temp else current_domains_system + + booked_domains, booked_domain_current = check_domain_booked(turn, booked_domains) + next_user_domains = get_next_user_act_domains(ori_dialog, turn_id) + + flattened_acts = flatten_acts(dialog_acts) + flattened_span_acts = flatten_span_acts(span_acts) + remapped_acts, error_local = remap_acts(flattened_acts, current_domains_user, + booked_domain_current, keyword_domains_user, + keyword_domains_system, current_domains_system, + next_user_domains) + + delex_acts = delexicalize_da(remapped_acts) + for act in delex_acts: + act = "-".join(act) + if act not in new_sys_acts: + new_sys_acts[act] = 1 + else: + new_sys_acts[act] += 1 + + remapped_span_acts, _ = remap_acts(flattened_span_acts, current_domains_user, + booked_domain_current, keyword_domains_user, + keyword_domains_system, current_domains_system, + next_user_domains) + + errors += error_local + + if error_local > 0: + print(ori_dialog_id) + + deflattened_remapped_acts = deflat_acts(remapped_acts) + deflattened_remapped_span_acts = deflat_span_acts(remapped_span_acts) + turn['dialog_act'] = deflattened_remapped_acts + turn['span_info'] = deflattened_remapped_span_acts + + print("Errors:", errors) + json.dump(original_data, open(f'{new_data_dir}/data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + + write_system_act_set(new_sys_acts) + + with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf: + for filename in os.listdir(new_data_dir): + zf.write(f'{new_data_dir}/{filename}') + print("Saved new data.") + rmtree(original_data_dir) + rmtree(new_data_dir) + + +if __name__ == '__main__': + preprocess() diff --git a/data/unified_datasets/README.md b/data/unified_datasets/README.md index 52ecc087c09cde18e0b2929f839b011f6209ee17..77ffb19169b8fd55678ac55fe3d9e2100ecfe13c 100644 --- a/data/unified_datasets/README.md +++ b/data/unified_datasets/README.md @@ -79,10 +79,10 @@ We first introduce the unified format of `ontology` and `dialogues`. To transfor - `intents`: (*dict*) descriptions for intents. - `$intent_name`: (*dict*) - `description`: (*str*) description for this intent. -- `dialogue_acts`: (*dict*) dialogue act dictionaries extracted from the data, separated by their types. Each dialogue act is a *str* converted by a *dict* like `"{'speaker': 'system', 'intent': 'inform', 'domain': 'attraction', 'slot': 'area'}"` that includes speaker, intent, domain, slot (and value for binary dialogue acts). +- `dialogue_acts`: (*dict*) dialogue act dictionaries extracted from the data, separated by their types. Each dialogue act is a *str* converted by a *dict* like `"{'user': True, 'system': True, 'intent': 'inform', 'domain': 'attraction', 'slot': 'area'}"` that includes intent, domain, slot, and whether the speakers use this dialogue act. - `categorical`: (*list* of *str*) dictionary for categorical dialogue acts. - `non-categorical`: (*list* of *str*) dictionary for non-categorical dialogue acts. - - `binary`: (*list* of *str*) dictionary for binary dialogue acts that are more detailed intents where the values are not extracted from dialogues, e.g. request the address of a hotel. + - `binary`: (*list* of *str*) dictionary for binary dialogue acts that are more detailed intents without values, e.g. request the address of a hotel. Note that the `slot` in a binary dialogue act may not be an actual slot that presents in `ontology['domains'][domain]['slots']`. - `state`: (*dict*) dialogue state of all domains. - `$domain_name`: (*dict*) - `$slot_name: ""`: slot with empty value. Note that the slot set are the subset of the slot set in Part 1 definition. @@ -109,7 +109,7 @@ We first introduce the unified format of `ontology` and `dialogues`. To transfor - `non-categorical` (*list* of *dict*, could be empty) for non-categorical slots. - `{"intent": (str), "domain": (str), "slot": (str), "value": (str), "start": (int), "end": (int)}`. `start` and `end` are character indexes for the value span in the utterance and can be absent. - `binary` (*list* of *dict*, could be empty) for binary dialogue acts in ontology. - - `{"intent": (str), "domain": (str), "slot": (str), "value": (str)}`. Possible dialogue acts are listed in the `ontology['binary_dialogue_acts']`. + - `{"intent": (str), "domain": (str), "slot": (str)}`. Binary dialogue acts are more detailed intents without values, e.g. request the address of a hotel. - `state`: (*dict*, user side, could be empty) dialogue state of involved domains. full state is shown in `ontology['state']`. - `$domain_name`: (*dict*) contains all slots in this domain. - `$slot_name`: (*str*) value for this slot. diff --git a/data/unified_datasets/check.py b/data/unified_datasets/check.py index ec88a1f38ff818231997de8661132c4338cde776..098fac3f121e8a7943313e93e8fecb01f8d6569c 100644 --- a/data/unified_datasets/check.py +++ b/data/unified_datasets/check.py @@ -27,21 +27,18 @@ def check_ontology(ontology): intent name: { "description": intent description } - }, - "binary_dialogue_acts": { - [ - { - "intent": intent name, - "domain": domain name, - "slot": slot name, - "value": some value - } - ] } "state": { domain name: { slot name: "" } + }, + "dialogue_acts": { + "categorical": [ + "{'user': True/False, 'system': True/False, 'intent': intent, 'domain': domain, 'slot': slot}", + ], + "non-categorical": {}, + "binary": {} } } """ @@ -77,6 +74,13 @@ def check_ontology(ontology): assert slot_name in ontology['domains'][domain_name]['slots'] assert value == "", "should set value in state to \"\"" + ontology['da_dict'] = {} + for da_type in ontology['dialogue_acts']: + ontology['da_dict'][da_type] = {} + for da_str in ontology['dialogue_acts'][da_type]: + da = eval(da_str) + ontology["da_dict"][da_type][(da['intent'], da['domain'], da['slot'])] = {'user': da['user'], 'system': da['system']} + # print('description existence:', descriptions, '\n') for description, value in descriptions.items(): if not value: @@ -207,12 +211,10 @@ def check_dialogues(name, dialogues, ontology): stat[split][f'non-cat slot span(dialogue act)'][0] += 1 for da_type in dialogue_acts: - if da_type == 'binary': - for da in dialogue_acts[da_type]: - assert str({'speaker': turn['speaker'], 'intent': da['intent'], 'domain': da['domain'], 'slot': da['slot'], 'value': da['value']}) in ontology['dialogue_acts'][da_type] - else: - for da in dialogue_acts[da_type]: - assert str({'speaker': turn['speaker'], 'intent': da['intent'], 'domain': da['domain'], 'slot': da['slot']}) in ontology['dialogue_acts'][da_type] + for da in dialogue_acts[da_type]: + assert ontology['da_dict'][da_type][(da['intent'], da['domain'], da['slot'])][turn['speaker']] == True + if da_type == 'binary': + assert 'value' not in da, f'{dialogue_id}:{turn_id}\tbinary dialogue act should not have value' if turn['speaker'] == 'user': assert 'db_results' not in turn diff --git a/data/unified_datasets/multiwoz21/booking_remapper.py b/data/unified_datasets/multiwoz21/booking_remapper.py new file mode 100644 index 0000000000000000000000000000000000000000..aa0f38fb48f2387b5fd3016c40375908f4b27b32 --- /dev/null +++ b/data/unified_datasets/multiwoz21/booking_remapper.py @@ -0,0 +1,266 @@ + +slot_name_map = { + 'addr': "address", + 'post': "postcode", + 'pricerange': "price range", + 'arrive': "arrive by", + 'arriveby': "arrive by", + 'leave': "leave at", + 'leaveat': "leave at", + 'depart': "departure", + 'dest': "destination", + 'fee': "entrance fee", + 'open': 'open hours', + 'car': "type", + 'car type': "type", + 'ticket': 'price', + 'trainid': 'train id', + 'id': 'train id', + 'people': 'book people', + 'stay': 'book stay', + 'none': '', + 'attraction': { + 'price': 'entrance fee' + }, + 'hospital': {}, + 'hotel': { + 'day': 'book day', 'price': "price range" + }, + 'restaurant': { + 'day': 'book day', 'time': 'book time', 'price': "price range" + }, + 'taxi': {}, + 'train': { + 'day': 'day', 'time': "duration" + }, + 'police': {}, + 'booking': {} +} + + +class BookingActRemapper: + + def __init__(self, ontology): + self.ontology = ontology + self.reset() + + def reset(self): + self.current_domains_user = [] + self.current_domains_system = [] + self.booked_domains = [] + + def retrieve_current_domain_from_user(self, turn_id, ori_dialog): + prev_user_turn = ori_dialog[turn_id - 1] + + dialog_acts = prev_user_turn.get('dialog_act', []) + keyword_domains_user = get_keyword_domains(prev_user_turn) + current_domains_temp = get_current_domains_from_act(dialog_acts) + self.current_domains_user = current_domains_temp if current_domains_temp else self.current_domains_user + next_user_domains = get_next_user_act_domains(ori_dialog, turn_id) + + return keyword_domains_user, next_user_domains + + def retrieve_current_domain_from_system(self, turn_id, ori_dialog): + + system_turn = ori_dialog[turn_id] + dialog_acts = system_turn.get('dialog_act', []) + keyword_domains_system = get_keyword_domains(system_turn) + current_domains_temp = get_current_domains_from_act(dialog_acts) + self.current_domains_system = current_domains_temp if current_domains_temp else self.current_domains_system + booked_domain_current = self.check_domain_booked(system_turn) + + return keyword_domains_system, booked_domain_current + + def remap(self, turn_id, ori_dialog): + + keyword_domains_user, next_user_domains = self.retrieve_current_domain_from_user(turn_id, ori_dialog) + keyword_domains_system, booked_domain_current = self.retrieve_current_domain_from_system(turn_id, ori_dialog) + + # only need to remap if there is a dialog action labelled + dialog_acts = ori_dialog[turn_id].get('dialog_act', []) + spans = ori_dialog[turn_id].get('span_info', []) + if dialog_acts: + + flattened_acts = flatten_acts(dialog_acts) + flattened_spans = flatten_span_acts(spans) + remapped_acts, error_local = remap_acts(flattened_acts, self.current_domains_user, + booked_domain_current, keyword_domains_user, + keyword_domains_system, self.current_domains_system, + next_user_domains, self.ontology) + + remapped_spans, _ = remap_acts(flattened_spans, self.current_domains_user, + booked_domain_current, keyword_domains_user, + keyword_domains_system, self.current_domains_system, + next_user_domains, self.ontology) + + deflattened_remapped_acts = deflat_acts(remapped_acts) + deflattened_remapped_spans = deflat_span_acts(remapped_spans) + + return deflattened_remapped_acts, deflattened_remapped_spans + else: + return dialog_acts, spans + + def check_domain_booked(self, turn): + + booked_domain_current = None + for domain in turn['metadata']: + if turn['metadata'][domain]["book"]["booked"] and domain not in self.booked_domains: + booked_domain_current = domain.capitalize() + self.booked_domains.append(domain) + return booked_domain_current + + +def get_keyword_domains(turn): + keyword_domains = [] + text = turn['text'] + for d in ["Hotel", "Restaurant", "Train"]: + if d.lower() in text.lower(): + keyword_domains.append(d) + return keyword_domains + + +def get_current_domains_from_act(dialog_acts): + + current_domains_temp = [] + for dom_int in dialog_acts: + domain, intent = dom_int.split('-') + if domain in ["general", "Booking"]: + continue + if domain not in current_domains_temp: + current_domains_temp.append(domain) + + return current_domains_temp + + +def get_next_user_act_domains(ori_dialog, turn_id): + domains = [] + try: + next_user_act = ori_dialog[turn_id + 1]['dialog_act'] + domains = get_current_domains_from_act(next_user_act) + except: + # will fail if system act is the last act of the dialogue + pass + return domains + + +def flatten_acts(dialog_acts): + flattened_acts = [] + for dom_int in dialog_acts: + domain, intent = dom_int.split('-') + for slot_value in dialog_acts[dom_int]: + slot = slot_value[0] + value = slot_value[1] + flattened_acts.append((domain, intent, slot, value)) + + return flattened_acts + + +def flatten_span_acts(span_acts): + + flattened_acts = [] + for span_act in span_acts: + domain, intent = span_act[0].split("-") + flattened_acts.append((domain, intent, span_act[1], span_act[2:])) + return flattened_acts + + +def deflat_acts(flattened_acts): + + dialog_acts = dict() + + for act in flattened_acts: + domain, intent, slot, value = act + if f"{domain}-{intent}" not in dialog_acts.keys(): + dialog_acts[f"{domain}-{intent}"] = [[slot, value]] + else: + dialog_acts[f"{domain}-{intent}"].append([slot, value]) + + return dialog_acts + + +def deflat_span_acts(flattened_acts): + + dialog_span_acts = [] + for act in flattened_acts: + domain, intent, slot, value = act + if value == 'none': + continue + new_act = [f"{domain}-{intent}", slot] + new_act.extend(value) + dialog_span_acts.append(new_act) + + return dialog_span_acts + + +def remap_acts(flattened_acts, current_domains, booked_domain=None, keyword_domains_user=None, + keyword_domains_system=None, current_domain_system=None, next_user_domain=None, ontology=None): + + # We now look for all cases that can happen: Booking domain, Booking within a domain or taxi-inform-car for booking + error = 0 + remapped_acts = [] + + # if there is more than one current domain or none at all, we try to get booked domain differently + if len(current_domains) != 1 and booked_domain: + current_domains = [booked_domain] + elif len(current_domains) != 1 and len(keyword_domains_user) == 1: + current_domains = keyword_domains_user + elif len(current_domains) != 1 and len(keyword_domains_system) == 1: + current_domains = keyword_domains_system + elif len(current_domains) != 1 and len(current_domain_system) == 1: + current_domains = current_domain_system + elif len(current_domains) != 1 and len(next_user_domain) == 1: + current_domains = next_user_domain + + for act in flattened_acts: + try: + domain, intent, slot, value = act + if f"{domain}-{intent}-{slot}" == "Booking-Book-Ref": + # We need to remap that booking act now + potential_domain = current_domains[0] + remapped_acts.append((potential_domain, "Book", "none", "none")) + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append((potential_domain, "Inform", "Ref", value)) + elif domain == "Booking" and intent == "Book" and slot != "Ref": + # the book intent is here actually an inform intent according to the data + potential_domain = current_domains[0] + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append((potential_domain, "Inform", slot, value)) + elif domain == "Booking" and intent == "Inform": + # the inform intent is here actually a request intent according to the data + potential_domain = current_domains[0] + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append((potential_domain, "OfferBook", slot, value)) + elif domain == "Booking" and intent in ["NoBook", "Request"]: + potential_domain = current_domains[0] + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append((potential_domain, intent, slot, value)) + elif f"{domain}-{intent}-{slot}" == "Taxi-Inform-Car": + # taxi-inform-car actually triggers the booking and informs on a car + remapped_acts.append((domain, "Book", "none", "none")) + remapped_acts.append((domain, intent, slot, value)) + elif f"{domain}-{intent}-{slot}" in ["Train-Inform-Ref", "Train-OfferBooked-Ref"]: + # train-inform/offerbooked-ref actually triggers the booking and informs on the reference number + remapped_acts.append((domain, "Book", "none", "none")) + remapped_acts.append((domain, "Inform", slot, value)) + elif domain == "Train" and intent == "OfferBooked" and slot != "Ref": + # this is actually an inform act + remapped_acts.append((domain, "Inform", slot, value)) + else: + remapped_acts.append(act) + except Exception as e: + print("Error detected:", e) + error += 1 + + return remapped_acts, error + + +def ontology_check(domain_, slot_, init_ontology): + + domain = domain_.lower() + slot = slot_.lower() + if slot not in init_ontology['domains'][domain]['slots']: + if slot in slot_name_map: + slot = slot_name_map[slot] + elif slot in slot_name_map[domain]: + slot = slot_name_map[domain][slot] + return slot in init_ontology['domains'][domain]['slots'] diff --git a/data/unified_datasets/multiwoz21/data.zip b/data/unified_datasets/multiwoz21/data.zip index 991588a59cacf5ab0181822d55b31d12aa3d52b2..95c1ebbf0794eaad22b945c5ad4d55e7916f41f7 100644 Binary files a/data/unified_datasets/multiwoz21/data.zip and b/data/unified_datasets/multiwoz21/data.zip differ diff --git a/data/unified_datasets/multiwoz21/dummy_data.json b/data/unified_datasets/multiwoz21/dummy_data.json index 9009018f77d201baf284e4f80459a4f99814fa7a..a012b907e3634e25eefbdbf8a2eb44ee0236b5ba 100644 --- a/data/unified_datasets/multiwoz21/dummy_data.json +++ b/data/unified_datasets/multiwoz21/dummy_data.json @@ -106,8 +106,7 @@ { "intent": "request", "domain": "hotel", - "slot": "area", - "value": "" + "slot": "area" } ] }, @@ -206,14 +205,12 @@ { "intent": "inform", "domain": "booking", - "slot": "", - "value": "" + "slot": "" }, { "intent": "inform", "domain": "hotel", - "slot": "parking", - "value": "" + "slot": "parking" } ] }, @@ -320,14 +317,12 @@ { "intent": "request", "domain": "booking", - "slot": "book stay", - "value": "" + "slot": "book stay" }, { "intent": "request", "domain": "booking", - "slot": "day", - "value": "" + "slot": "day" } ] }, @@ -421,8 +416,7 @@ { "intent": "reqmore", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -451,8 +445,7 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -510,8 +503,7 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -565,8 +557,7 @@ { "intent": "inform", "domain": "police", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -667,8 +658,7 @@ { "intent": "request", "domain": "police", - "slot": "postcode", - "value": "" + "slot": "postcode" } ] }, @@ -761,8 +751,7 @@ { "intent": "request", "domain": "police", - "slot": "address", - "value": "" + "slot": "address" } ] }, @@ -838,8 +827,7 @@ { "intent": "thank", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -897,8 +885,7 @@ { "intent": "welcome", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -974,14 +961,12 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" }, { "intent": "welcome", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -1128,8 +1113,7 @@ { "intent": "request", "domain": "train", - "slot": "departure", - "value": "" + "slot": "departure" } ] }, @@ -1214,8 +1198,7 @@ { "intent": "request", "domain": "train", - "slot": "day", - "value": "" + "slot": "day" } ] }, @@ -1590,8 +1573,7 @@ { "intent": "request", "domain": "booking", - "slot": "day", - "value": "" + "slot": "day" } ] }, @@ -1733,8 +1715,7 @@ { "intent": "thank", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -1792,8 +1773,7 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -1931,20 +1911,17 @@ { "intent": "request", "domain": "police", - "slot": "address", - "value": "" + "slot": "address" }, { "intent": "request", "domain": "police", - "slot": "postcode", - "value": "" + "slot": "postcode" }, { "intent": "request", "domain": "police", - "slot": "phone", - "value": "" + "slot": "phone" } ] }, @@ -2045,8 +2022,7 @@ { "intent": "thank", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -2104,8 +2080,7 @@ { "intent": "greet", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -2181,8 +2156,7 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -2312,14 +2286,12 @@ { "intent": "request", "domain": "hotel", - "slot": "price range", - "value": "" + "slot": "price range" }, { "intent": "inform", "domain": "hotel", - "slot": "internet", - "value": "" + "slot": "internet" } ] }, @@ -2417,14 +2389,12 @@ { "intent": "inform", "domain": "hotel", - "slot": "parking", - "value": "" + "slot": "parking" }, { "intent": "reqmore", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -2534,8 +2504,7 @@ { "intent": "request", "domain": "hotel", - "slot": "price range", - "value": "" + "slot": "price range" } ] }, @@ -2628,14 +2597,12 @@ { "intent": "inform", "domain": "booking", - "slot": "", - "value": "" + "slot": "" }, { "intent": "inform", "domain": "hotel", - "slot": "parking", - "value": "" + "slot": "parking" } ] }, @@ -2666,8 +2633,7 @@ { "intent": "request", "domain": "hotel", - "slot": "address", - "value": "" + "slot": "address" } ] }, @@ -2844,26 +2810,22 @@ { "intent": "request", "domain": "train", - "slot": "day", - "value": "" + "slot": "day" }, { "intent": "request", "domain": "train", - "slot": "destination", - "value": "" + "slot": "destination" }, { "intent": "request", "domain": "train", - "slot": "leave at", - "value": "" + "slot": "leave at" }, { "intent": "request", "domain": "train", - "slot": "arrive by", - "value": "" + "slot": "arrive by" } ] }, @@ -2965,8 +2927,7 @@ { "intent": "request", "domain": "train", - "slot": "day", - "value": "" + "slot": "day" } ] }, @@ -3106,8 +3067,7 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -3165,8 +3125,7 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -3306,8 +3265,7 @@ { "intent": "request", "domain": "train", - "slot": "destination", - "value": "" + "slot": "destination" } ] }, @@ -3409,8 +3367,7 @@ { "intent": "request", "domain": "train", - "slot": "departure", - "value": "" + "slot": "departure" } ] }, @@ -3510,8 +3467,7 @@ { "intent": "offerbook", "domain": "train", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -3560,20 +3516,17 @@ { "intent": "request", "domain": "train", - "slot": "arrive by", - "value": "" + "slot": "arrive by" }, { "intent": "request", "domain": "train", - "slot": "duration", - "value": "" + "slot": "duration" }, { "intent": "request", "domain": "train", - "slot": "price", - "value": "" + "slot": "price" } ] }, @@ -3764,8 +3717,7 @@ { "intent": "request", "domain": "hotel", - "slot": "price range", - "value": "" + "slot": "price range" } ] }, @@ -3873,14 +3825,12 @@ { "intent": "request", "domain": "hotel", - "slot": "internet", - "value": "" + "slot": "internet" }, { "intent": "request", "domain": "hotel", - "slot": "parking", - "value": "" + "slot": "parking" } ] }, @@ -3956,8 +3906,7 @@ { "intent": "inform", "domain": "booking", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -3981,8 +3930,7 @@ { "intent": "request", "domain": "hotel", - "slot": "ref", - "value": "" + "slot": "ref" } ] }, @@ -4040,8 +3988,7 @@ { "intent": "request", "domain": "booking", - "slot": "book stay", - "value": "" + "slot": "book stay" } ] }, @@ -4134,8 +4081,7 @@ { "intent": "request", "domain": "booking", - "slot": "day", - "value": "" + "slot": "day" } ] }, @@ -4235,8 +4181,7 @@ { "intent": "reqmore", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -4265,8 +4210,7 @@ { "intent": "thank", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -4324,8 +4268,7 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -4392,8 +4335,7 @@ { "intent": "inform", "domain": "attraction", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -4460,14 +4402,12 @@ { "intent": "request", "domain": "attraction", - "slot": "area", - "value": "" + "slot": "area" }, { "intent": "request", "domain": "attraction", - "slot": "type", - "value": "" + "slot": "type" } ] }, @@ -4592,8 +4532,7 @@ { "intent": "request", "domain": "attraction", - "slot": "address", - "value": "" + "slot": "address" } ] }, @@ -4660,8 +4599,7 @@ { "intent": "request", "domain": "restaurant", - "slot": "food", - "value": "" + "slot": "food" } ] }, @@ -4762,14 +4700,12 @@ { "intent": "request", "domain": "restaurant", - "slot": "food", - "value": "" + "slot": "food" }, { "intent": "request", "domain": "restaurant", - "slot": "area", - "value": "" + "slot": "area" } ] }, @@ -4817,8 +4753,7 @@ { "intent": "request", "domain": "attraction", - "slot": "type", - "value": "" + "slot": "type" } ] }, @@ -4938,20 +4873,17 @@ { "intent": "request", "domain": "restaurant", - "slot": "phone", - "value": "" + "slot": "phone" }, { "intent": "request", "domain": "restaurant", - "slot": "postcode", - "value": "" + "slot": "postcode" }, { "intent": "request", "domain": "restaurant", - "slot": "address", - "value": "" + "slot": "address" } ] }, @@ -5026,8 +4958,7 @@ { "intent": "reqmore", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -5051,8 +4982,7 @@ { "intent": "thank", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -5110,14 +5040,12 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" }, { "intent": "welcome", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -5165,8 +5093,7 @@ { "intent": "inform", "domain": "hospital", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -5224,8 +5151,7 @@ { "intent": "request", "domain": "hospital", - "slot": "department", - "value": "" + "slot": "department" } ] }, @@ -5249,8 +5175,7 @@ { "intent": "request", "domain": "hospital", - "slot": "phone", - "value": "" + "slot": "phone" } ] }, @@ -5335,8 +5260,7 @@ { "intent": "thank", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -5394,8 +5318,7 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -5551,8 +5474,7 @@ { "intent": "request", "domain": "train", - "slot": "leave at", - "value": "" + "slot": "leave at" } ] }, @@ -5687,14 +5609,12 @@ { "intent": "request", "domain": "train", - "slot": "arrive by", - "value": "" + "slot": "arrive by" }, { "intent": "request", "domain": "train", - "slot": "price", - "value": "" + "slot": "price" } ] }, @@ -5777,8 +5697,7 @@ { "intent": "offerbook", "domain": "train", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -5919,8 +5838,7 @@ { "intent": "reqmore", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -5962,20 +5880,17 @@ { "intent": "request", "domain": "attraction", - "slot": "address", - "value": "" + "slot": "address" }, { "intent": "request", "domain": "attraction", - "slot": "postcode", - "value": "" + "slot": "postcode" }, { "intent": "request", "domain": "attraction", - "slot": "phone", - "value": "" + "slot": "phone" } ] }, @@ -6066,8 +5981,7 @@ { "intent": "reqmore", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -6096,8 +6010,7 @@ { "intent": "thank", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -6155,14 +6068,12 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" }, { "intent": "welcome", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, @@ -6306,20 +6217,17 @@ { "intent": "request", "domain": "hospital", - "slot": "address", - "value": "" + "slot": "address" }, { "intent": "request", "domain": "hospital", - "slot": "postcode", - "value": "" + "slot": "postcode" }, { "intent": "request", "domain": "hospital", - "slot": "phone", - "value": "" + "slot": "phone" } ] }, @@ -6377,8 +6285,7 @@ { "intent": "bye", "domain": "general", - "slot": "", - "value": "" + "slot": "" } ] }, diff --git a/data/unified_datasets/multiwoz21/preprocess.py b/data/unified_datasets/multiwoz21/preprocess.py index 416a6b98d44d30ca7608abde5f29c07448a111de..f48f06e70a8c53f4cac4b5ae8ed2e0077f24fbda 100644 --- a/data/unified_datasets/multiwoz21/preprocess.py +++ b/data/unified_datasets/multiwoz21/preprocess.py @@ -8,8 +8,9 @@ from tqdm import tqdm from collections import Counter from pprint import pprint from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer +from data.unified_datasets.multiwoz21.booking_remapper import BookingActRemapper -init_ontology = { +ontology = { "domains": { # descriptions are adapted from multiwoz22, but is_categorical may be different "attraction": { "description": "find an attraction", @@ -566,9 +567,9 @@ init_ontology = { } }, "dialogue_acts": { - "categorical": set(), - "non-categorical": set(), - "binary": set() + "categorical": {}, + "non-categorical": {}, + "binary": {} } } @@ -618,7 +619,7 @@ digit2word = { cnt_domain_slot = Counter() def normalize_domain_slot_value(domain, slot, value): - global init_ontology, slot_name_map + global ontology, slot_name_map domain = domain.lower() slot = slot.lower() value = value.strip() @@ -626,16 +627,16 @@ def normalize_domain_slot_value(domain, slot, value): value = 'dontcare' if value in ['?', 'none', 'not mentioned']: value = "" - if domain not in init_ontology['domains']: + if domain not in ontology['domains']: raise Exception(f'{domain} not in ontology') - if slot not in init_ontology['domains'][domain]['slots']: + if slot not in ontology['domains'][domain]['slots']: if slot in slot_name_map: slot = slot_name_map[slot] elif slot in slot_name_map[domain]: slot = slot_name_map[domain][slot] else: raise Exception(f'{domain}-{slot} not in ontology') - assert slot=='' or slot in init_ontology['domains'][domain]['slots'], f'{(domain, slot, value)} not in ontology' + assert slot=='' or slot in ontology['domains'][domain]['slots'], f'{(domain, slot, value)} not in ontology' return domain, slot, value def convert_da(da_dict, utt, sent_tokenizer, word_tokenizer): @@ -644,7 +645,7 @@ def convert_da(da_dict, utt, sent_tokenizer, word_tokenizer): :param da_dict: dict[(intent, domain, slot, value)] = [word_start, word_end] :param utt: user or system utt ''' - global init_ontology, digit2word, cnt_domain_slot + global ontology, digit2word, cnt_domain_slot converted_da = { 'categorical': [], @@ -663,13 +664,13 @@ def convert_da(da_dict, utt, sent_tokenizer, word_tokenizer): for (intent, domain, slot, value), span in da_dict.items(): if intent == 'request' or slot == '' or value == '': # binary dialog acts + assert value == '' converted_da['binary'].append({ 'intent': intent, 'domain': domain, - 'slot': slot, - 'value': value + 'slot': slot }) - elif init_ontology['domains'][domain]['slots'][slot]['is_categorical']: + elif ontology['domains'][domain]['slots'][slot]['is_categorical']: # categorical dialog acts converted_da['categorical'].append({ 'intent': intent, @@ -759,7 +760,7 @@ def preprocess(): copy2(f'{original_data_dir}/{filename}', new_data_dir) original_data = json.load(open(f'{original_data_dir}/data.json')) - global init_ontology, cnt_domain_slot + global ontology, cnt_domain_slot val_list = set(open(f'{original_data_dir}/valListFile.txt').read().split()) test_list = set(open(f'{original_data_dir}/testListFile.txt').read().split()) @@ -768,6 +769,7 @@ def preprocess(): dialogues_by_split = {split:[] for split in splits} sent_tokenizer = PunktSentenceTokenizer() word_tokenizer = TreebankWordTokenizer() + booking_remapper = BookingActRemapper(init_ontology) for ori_dialog_id, ori_dialog in tqdm(original_data.items()): if ori_dialog_id in val_list: split = 'validation' @@ -785,7 +787,7 @@ def preprocess(): 'request': {} } for k, v in ori_dialog['goal'].items(): - if len(v) != 0 and k in init_ontology['domains']: + if len(v) != 0 and k in ontology['domains']: cur_domains.append(k) goal['inform'][k] = {} goal['request'][k] = {} @@ -814,6 +816,7 @@ def preprocess(): 'turns': [] } + booking_remapper.reset() for turn_id, turn in enumerate(ori_dialog['log']): # correct some grammar errors in the text, mainly following `tokenization.md` in MultiWOZ_2.1 text = turn['text'] @@ -828,13 +831,17 @@ def preprocess(): utt = text speaker = 'user' if turn_id % 2 == 0 else 'system' - das = turn.get('dialog_act', []) + das = turn.get('dialog_act', []) spans = turn.get('span_info', []) + + if speaker == 'system': + das, spans = booking_remapper.remap(turn_id, ori_dialog['log']) + da_dict = {} # transform DA for Domain_Intent in das: domain, intent = Domain_Intent.lower().split('-') - assert intent in init_ontology['intents'], f'{ori_dialog_id}:{turn_id}:da\t{intent} not in ontology' + assert intent in ontology['intents'], f'{ori_dialog_id}:{turn_id}:da\t{intent} not in ontology' for Slot, value in das[Domain_Intent]: domain, slot, value = normalize_domain_slot_value(domain, Slot, value) if domain not in cur_domains: @@ -862,17 +869,14 @@ def preprocess(): for da_type in dialogue_acts: das = dialogue_acts[da_type] for da in das: - intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da['value'] - if da_type == 'binary': - init_ontology["dialogue_acts"][da_type].add((speaker, intent, domain, slot, value)) - else: - init_ontology["dialogue_acts"][da_type].add((speaker, intent, domain, slot)) + ontology["dialogue_acts"][da_type].setdefault((da['intent'], da['domain'], da['slot']), {}) + ontology["dialogue_acts"][da_type][(da['intent'], da['domain'], da['slot'])][speaker] = True if speaker == 'system': # add state to last user turn # add empty db_results turn_state = turn['metadata'] - cur_state = copy.deepcopy(init_ontology['state']) + cur_state = copy.deepcopy(ontology['state']) booked = {} for domain in turn_state: if domain not in cur_state: @@ -882,7 +886,7 @@ def preprocess(): if slot == 'ticket': continue elif slot == 'booked': - assert domain in init_ontology['domains'] + assert domain in ontology['domains'] booked[domain] = value continue _, slot, value = normalize_domain_slot_value(domain, slot, value) @@ -895,20 +899,17 @@ def preprocess(): dialogues = [] for split in splits: dialogues += dialogues_by_split[split] - for da_type in init_ontology['dialogue_acts']: - if da_type == 'binary': - init_ontology["dialogue_acts"][da_type] = [str({'speaker': da[0], 'intent':da[1],'domain':da[2],'slot':da[3],'value':da[4]}) for da in sorted(init_ontology["dialogue_acts"][da_type])] - else: - init_ontology["dialogue_acts"][da_type] = [str({'speaker': da[0], 'intent':da[1],'domain':da[2],'slot':da[3]}) for da in sorted(init_ontology["dialogue_acts"][da_type])] + for da_type in ontology['dialogue_acts']: + ontology["dialogue_acts"][da_type] = sorted([str({'user': speakers.get('user', False), 'system': speakers.get('system', False), 'intent':da[0],'domain':da[1], 'slot':da[2]}) for da, speakers in ontology["dialogue_acts"][da_type].items()]) json.dump(dialogues[:10], open(f'dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) - json.dump(init_ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf: for filename in os.listdir(new_data_dir): zf.write(f'{new_data_dir}/{filename}') rmtree(original_data_dir) rmtree(new_data_dir) - return dialogues, init_ontology + return dialogues, ontology if __name__ == '__main__': preprocess() \ No newline at end of file diff --git a/data/unified_datasets/sgd/data.zip b/data/unified_datasets/sgd/data.zip index fc2398bc94a0f0c39484df06714c13fd0cc4f6be..e3b9142c51145230a2a5dd8de2b7d8e89694843f 100644 Binary files a/data/unified_datasets/sgd/data.zip and b/data/unified_datasets/sgd/data.zip differ diff --git a/data/unified_datasets/sgd/dummy_data.json b/data/unified_datasets/sgd/dummy_data.json index cbb161ea5e2034ee41a80153f0fd485b57a0959b..920dff19aaab96a2098eb40d7b6417cae265eb43 100644 --- a/data/unified_datasets/sgd/dummy_data.json +++ b/data/unified_datasets/sgd/dummy_data.json @@ -22,8 +22,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "FindRestaurants" + "slot": "FindRestaurants" } ], "categorical": [], @@ -60,8 +59,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "city", - "value": "" + "slot": "city" } ], "categorical": [], @@ -328,8 +326,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "street_address", - "value": "" + "slot": "street_address" } ], "categorical": [], @@ -389,8 +386,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "phone_number", - "value": "" + "slot": "phone_number" } ], "categorical": [], @@ -450,8 +446,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -517,8 +512,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [ @@ -690,14 +684,12 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "ReserveRestaurant" + "slot": "ReserveRestaurant" }, { "intent": "select", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -734,8 +726,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "time", - "value": "" + "slot": "time" } ], "categorical": [], @@ -845,14 +836,12 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "has_live_music", - "value": "" + "slot": "has_live_music" }, { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -891,8 +880,7 @@ { "intent": "notify_success", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [ @@ -944,14 +932,12 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "serves_alcohol", - "value": "" + "slot": "serves_alcohol" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "street_address", - "value": "" + "slot": "street_address" } ], "categorical": [], @@ -1019,14 +1005,12 @@ { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" }, { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -1063,8 +1047,7 @@ { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -1098,8 +1081,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "FindRestaurants" + "slot": "FindRestaurants" } ], "categorical": [], @@ -1136,14 +1118,12 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "cuisine", - "value": "" + "slot": "cuisine" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "city", - "value": "" + "slot": "city" } ], "categorical": [], @@ -1289,14 +1269,12 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "price_range", - "value": "" + "slot": "price_range" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "street_address", - "value": "" + "slot": "street_address" } ], "categorical": [], @@ -1364,8 +1342,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -1431,8 +1408,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -1546,8 +1522,7 @@ { "intent": "select", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -1584,8 +1559,7 @@ { "intent": "offer_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "ReserveRestaurant" + "slot": "ReserveRestaurant" } ], "categorical": [], @@ -1603,8 +1577,7 @@ { "intent": "affirm_intent", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -1650,8 +1623,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "time", - "value": "" + "slot": "time" } ], "categorical": [], @@ -1761,8 +1733,7 @@ { "intent": "negate", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [ @@ -1843,8 +1814,7 @@ { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -1881,8 +1851,7 @@ { "intent": "notify_success", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -1927,14 +1896,12 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "street_address", - "value": "" + "slot": "street_address" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "has_live_music", - "value": "" + "slot": "has_live_music" } ], "categorical": [], @@ -2002,8 +1969,7 @@ { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -2040,8 +2006,7 @@ { "intent": "req_more", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -2059,14 +2024,12 @@ { "intent": "negate", "domain": "", - "slot": "", - "value": "" + "slot": "" }, { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -2103,8 +2066,7 @@ { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -2138,8 +2100,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "FindRestaurants" + "slot": "FindRestaurants" } ], "categorical": [], @@ -2176,14 +2137,12 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "city", - "value": "" + "slot": "city" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "cuisine", - "value": "" + "slot": "cuisine" } ], "categorical": [], @@ -2337,8 +2296,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "has_live_music", - "value": "" + "slot": "has_live_music" } ], "categorical": [], @@ -2396,14 +2354,12 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "ReserveRestaurant" + "slot": "ReserveRestaurant" }, { "intent": "select", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -2449,8 +2405,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "time", - "value": "" + "slot": "time" } ], "categorical": [], @@ -2568,20 +2523,17 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "serves_alcohol", - "value": "" + "slot": "serves_alcohol" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "price_range", - "value": "" + "slot": "price_range" }, { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -2621,14 +2573,12 @@ { "intent": "notify_failure", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" }, { "intent": "req_more", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -2673,8 +2623,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "ReserveRestaurant" + "slot": "ReserveRestaurant" } ], "categorical": [], @@ -2772,8 +2721,7 @@ { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -2810,8 +2758,7 @@ { "intent": "notify_success", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -2856,14 +2803,12 @@ { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" }, { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -2900,8 +2845,7 @@ { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -2935,8 +2879,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "FindRestaurants" + "slot": "FindRestaurants" } ], "categorical": [], @@ -2973,14 +2916,12 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "city", - "value": "" + "slot": "city" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "cuisine", - "value": "" + "slot": "cuisine" } ], "categorical": [], @@ -3134,8 +3075,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -3201,8 +3141,7 @@ { "intent": "select", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -3239,8 +3178,7 @@ { "intent": "offer_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "ReserveRestaurant" + "slot": "ReserveRestaurant" } ], "categorical": [], @@ -3258,8 +3196,7 @@ { "intent": "affirm_intent", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [ @@ -3364,20 +3301,17 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "price_range", - "value": "" + "slot": "price_range" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "phone_number", - "value": "" + "slot": "phone_number" }, { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -3417,8 +3351,7 @@ { "intent": "notify_success", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [ @@ -3479,8 +3412,7 @@ { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -3517,8 +3449,7 @@ { "intent": "req_more", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -3536,14 +3467,12 @@ { "intent": "negate", "domain": "", - "slot": "", - "value": "" + "slot": "" }, { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -3580,8 +3509,7 @@ { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -3615,8 +3543,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "FindRestaurants" + "slot": "FindRestaurants" } ], "categorical": [], @@ -3653,8 +3580,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "city", - "value": "" + "slot": "city" } ], "categorical": [], @@ -3875,8 +3801,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "phone_number", - "value": "" + "slot": "phone_number" } ], "categorical": [], @@ -3936,8 +3861,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -4003,8 +3927,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "has_live_music", - "value": "" + "slot": "has_live_music" } ], "categorical": [], @@ -4062,8 +3985,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -4129,14 +4051,12 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "ReserveRestaurant" + "slot": "ReserveRestaurant" }, { "intent": "select", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -4242,20 +4162,17 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "serves_alcohol", - "value": "" + "slot": "serves_alcohol" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "has_live_music", - "value": "" + "slot": "has_live_music" }, { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -4295,8 +4212,7 @@ { "intent": "notify_success", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [ @@ -4354,8 +4270,7 @@ { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -4392,8 +4307,7 @@ { "intent": "req_more", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -4411,14 +4325,12 @@ { "intent": "negate", "domain": "", - "slot": "", - "value": "" + "slot": "" }, { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -4455,8 +4367,7 @@ { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -4490,8 +4401,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "FindRestaurants" + "slot": "FindRestaurants" } ], "categorical": [], @@ -4537,8 +4447,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "city", - "value": "" + "slot": "city" } ], "categorical": [], @@ -4674,8 +4583,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -4741,8 +4649,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "street_address", - "value": "" + "slot": "street_address" } ], "categorical": [ @@ -4817,14 +4724,12 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "ReserveRestaurant" + "slot": "ReserveRestaurant" }, { "intent": "select", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -4861,8 +4766,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "time", - "value": "" + "slot": "time" } ], "categorical": [], @@ -4980,20 +4884,17 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "serves_alcohol", - "value": "" + "slot": "serves_alcohol" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "phone_number", - "value": "" + "slot": "phone_number" }, { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5033,8 +4934,7 @@ { "intent": "notify_success", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [ @@ -5095,14 +4995,12 @@ { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" }, { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5139,8 +5037,7 @@ { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5174,8 +5071,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "FindRestaurants" + "slot": "FindRestaurants" } ], "categorical": [], @@ -5281,8 +5177,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "city", - "value": "" + "slot": "city" } ], "categorical": [], @@ -5480,8 +5375,7 @@ { "intent": "select", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5518,8 +5412,7 @@ { "intent": "offer_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "ReserveRestaurant" + "slot": "ReserveRestaurant" } ], "categorical": [], @@ -5537,8 +5430,7 @@ { "intent": "affirm_intent", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5584,8 +5476,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "time", - "value": "" + "slot": "time" } ], "categorical": [], @@ -5695,20 +5586,17 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "has_live_music", - "value": "" + "slot": "has_live_music" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "street_address", - "value": "" + "slot": "street_address" }, { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5748,8 +5636,7 @@ { "intent": "notify_success", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [ @@ -5810,8 +5697,7 @@ { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5848,8 +5734,7 @@ { "intent": "req_more", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5867,14 +5752,12 @@ { "intent": "negate", "domain": "", - "slot": "", - "value": "" + "slot": "" }, { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5911,8 +5794,7 @@ { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5946,8 +5828,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "FindRestaurants" + "slot": "FindRestaurants" } ], "categorical": [], @@ -5984,8 +5865,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "city", - "value": "" + "slot": "city" } ], "categorical": [], @@ -6215,8 +6095,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -6282,14 +6161,12 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "street_address", - "value": "" + "slot": "street_address" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "phone_number", - "value": "" + "slot": "phone_number" } ], "categorical": [], @@ -6358,8 +6235,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -6425,14 +6301,12 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "ReserveRestaurant" + "slot": "ReserveRestaurant" }, { "intent": "select", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -6530,8 +6404,7 @@ { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -6568,8 +6441,7 @@ { "intent": "notify_success", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -6614,8 +6486,7 @@ { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -6652,8 +6523,7 @@ { "intent": "req_more", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -6671,14 +6541,12 @@ { "intent": "negate", "domain": "", - "slot": "", - "value": "" + "slot": "" }, { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -6715,8 +6583,7 @@ { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -6750,8 +6617,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "FindRestaurants" + "slot": "FindRestaurants" } ], "categorical": [ @@ -6972,14 +6838,12 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "street_address", - "value": "" + "slot": "street_address" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "has_live_music", - "value": "" + "slot": "has_live_music" } ], "categorical": [], @@ -7047,8 +6911,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -7114,8 +6977,7 @@ { "intent": "request_alts", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -7181,8 +7043,7 @@ { "intent": "select", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -7219,8 +7080,7 @@ { "intent": "req_more", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -7238,8 +7098,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "ReserveRestaurant" + "slot": "ReserveRestaurant" } ], "categorical": [], @@ -7276,8 +7135,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "time", - "value": "" + "slot": "time" } ], "categorical": [], @@ -7387,8 +7245,7 @@ { "intent": "negate", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [ @@ -7469,8 +7326,7 @@ { "intent": "negate", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -7544,8 +7400,7 @@ { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -7582,8 +7437,7 @@ { "intent": "notify_success", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -7628,14 +7482,12 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "has_live_music", - "value": "" + "slot": "has_live_music" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "street_address", - "value": "" + "slot": "street_address" } ], "categorical": [], @@ -7703,14 +7555,12 @@ { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" }, { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -7747,8 +7597,7 @@ { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -7782,8 +7631,7 @@ { "intent": "inform_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "FindRestaurants" + "slot": "FindRestaurants" } ], "categorical": [], @@ -7820,8 +7668,7 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "city", - "value": "" + "slot": "city" } ], "categorical": [], @@ -7954,14 +7801,12 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "serves_alcohol", - "value": "" + "slot": "serves_alcohol" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "has_live_music", - "value": "" + "slot": "has_live_music" } ], "categorical": [], @@ -8026,8 +7871,7 @@ { "intent": "select", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -8064,8 +7908,7 @@ { "intent": "offer_intent", "domain": "Restaurants_1", - "slot": "intent", - "value": "ReserveRestaurant" + "slot": "ReserveRestaurant" } ], "categorical": [], @@ -8083,8 +7926,7 @@ { "intent": "affirm_intent", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [ @@ -8189,20 +8031,17 @@ { "intent": "request", "domain": "Restaurants_1", - "slot": "phone_number", - "value": "" + "slot": "phone_number" }, { "intent": "request", "domain": "Restaurants_1", - "slot": "price_range", - "value": "" + "slot": "price_range" }, { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -8242,8 +8081,7 @@ { "intent": "notify_failure", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [ @@ -8334,8 +8172,7 @@ { "intent": "affirm", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -8372,8 +8209,7 @@ { "intent": "notify_success", "domain": "Restaurants_1", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -8418,8 +8254,7 @@ { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -8456,8 +8291,7 @@ { "intent": "req_more", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -8475,14 +8309,12 @@ { "intent": "negate", "domain": "", - "slot": "", - "value": "" + "slot": "" }, { "intent": "thank_you", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -8519,8 +8351,7 @@ { "intent": "goodbye", "domain": "", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], diff --git a/data/unified_datasets/sgd/preprocess.py b/data/unified_datasets/sgd/preprocess.py index 9c120e7e250bd009fa05435e0295e814e31f9a41..4f334e3abbebcbd81fdb206726271669832f04ad 100644 --- a/data/unified_datasets/sgd/preprocess.py +++ b/data/unified_datasets/sgd/preprocess.py @@ -134,9 +134,9 @@ def preprocess(): 'intents': get_intent(), 'state': {}, 'dialogue_acts': { - "categorical": set(), - "non-categorical": set(), - "binary": set() + "categorical": {}, + "non-categorical": {}, + "binary": {} }} splits = ['train', 'validation', 'test'] dataset_name = 'sgd' @@ -216,8 +216,7 @@ def preprocess(): turn['dialogue_acts']['binary'].append({ "intent": intent, "domain": '', - "slot": '', - "value": '', + "slot": '' }) elif action['act'] in ['NOTIFY_SUCCESS', 'NOTIFY_FAILURE', 'REQUEST_ALTS', 'AFFIRM_INTENT', 'NEGATE_INTENT']: # Slot and values are always empty @@ -225,17 +224,15 @@ def preprocess(): turn['dialogue_acts']['binary'].append({ "intent": intent, "domain": domain, - "slot": '', - "value": '', + "slot": '' }) elif action['act'] in ['OFFER_INTENT', 'INFORM_INTENT']: - # always has "intent" as the slot, and a single value containing the intent being offered. + # slot containing the intent being offered. assert slot == 'intent' and len(value_list) == 1 turn['dialogue_acts']['binary'].append({ "intent": intent, "domain": domain, - "slot": slot, - "value": value_list[0], + "slot": value_list[0] }) elif action['act'] in ['REQUEST'] and len(value_list) == 0: # always contains a slot, but values are optional. @@ -243,8 +240,7 @@ def preprocess(): turn['dialogue_acts']['binary'].append({ "intent": intent, "domain": domain, - "slot": slot, - "value": '', + "slot": slot }) elif action['act'] in ['SELECT'] and len(value_list) == 0: # (slot=='' and len(value_list) == 0) or (slot!='' and len(value_list) > 0) @@ -252,8 +248,7 @@ def preprocess(): turn['dialogue_acts']['binary'].append({ "intent": intent, "domain": domain, - "slot": slot, - "value": '', + "slot": slot }) elif action['act'] in ['INFORM_COUNT']: # always has "count" as the slot, and a single element in values for the number of results obtained by the system. @@ -340,19 +335,13 @@ def preprocess(): for da_type in turn['dialogue_acts']: das = turn['dialogue_acts'][da_type] for da in das: - intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da['value'] - if da_type == 'binary': - ontology["dialogue_acts"][da_type].add((speaker, intent, domain, slot, value)) - else: - ontology["dialogue_acts"][da_type].add((speaker, intent, domain, slot)) + ontology["dialogue_acts"][da_type].setdefault((da['intent'], da['domain'], da['slot']), {}) + ontology["dialogue_acts"][da_type][(da['intent'], da['domain'], da['slot'])][speaker] = True dialogue['turns'].append(turn) dialogues.append(dialogue) for da_type in ontology['dialogue_acts']: - if da_type == 'binary': - ontology["dialogue_acts"][da_type] = [str({'speaker': da[0], 'intent':da[1],'domain':da[2],'slot':da[3],'value':da[4]}) for da in sorted(ontology["dialogue_acts"][da_type])] - else: - ontology["dialogue_acts"][da_type] = [str({'speaker': da[0], 'intent':da[1],'domain':da[2],'slot':da[3]}) for da in sorted(ontology["dialogue_acts"][da_type])] + ontology["dialogue_acts"][da_type] = sorted([str({'user': speakers.get('user', False), 'system': speakers.get('system', False), 'intent':da[0],'domain':da[1], 'slot':da[2]}) for da, speakers in ontology["dialogue_acts"][da_type].items()]) json.dump(dialogues[:10], open(f'dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) json.dump(ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) diff --git a/data/unified_datasets/tm1/data.zip b/data/unified_datasets/tm1/data.zip index d1f6a9c80a770c417892d9b5f188e5d9d46eb589..aa7a1e2b8bc7da0e4a2c76386950b8865e42d28e 100644 Binary files a/data/unified_datasets/tm1/data.zip and b/data/unified_datasets/tm1/data.zip differ diff --git a/data/unified_datasets/tm1/dummy_data.json b/data/unified_datasets/tm1/dummy_data.json index 982375d753b698cd6b427cf6c9e3e8bbfb2ab926..8e1797347f4eea0655c963d6fa720d39e04ec222 100644 --- a/data/unified_datasets/tm1/dummy_data.json +++ b/data/unified_datasets/tm1/dummy_data.json @@ -1490,8 +1490,7 @@ { "intent": "accept", "domain": "auto_repair", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -1914,8 +1913,7 @@ { "intent": "accept", "domain": "coffee_ordering", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -4677,8 +4675,7 @@ { "intent": "accept", "domain": "restaurant_reservation", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5081,8 +5078,7 @@ { "intent": "inform", "domain": "movie_ticket", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], @@ -5156,8 +5152,7 @@ { "intent": "accept", "domain": "movie_ticket", - "slot": "", - "value": "" + "slot": "" } ], "categorical": [], diff --git a/data/unified_datasets/tm1/preprocess.py b/data/unified_datasets/tm1/preprocess.py index 5bb2956b186a92f40c65423886647770f070e9c7..22df49aac5685a8e2c1e52eedafe781a385dd209 100644 --- a/data/unified_datasets/tm1/preprocess.py +++ b/data/unified_datasets/tm1/preprocess.py @@ -151,9 +151,9 @@ def preprocess(): }, 'state': {}, 'dialogue_acts': { - "categorical": set(), - "non-categorical": set(), - "binary": set() + "categorical": {}, + "non-categorical": {}, + "binary": {} }} global descriptions ori_ontology = {} @@ -260,7 +260,6 @@ def preprocess(): 'intent': intent, 'domain': domain, 'slot': '', - 'value': '' }) else: assert turn['utterance'][segment['start_index']:segment['end_index']] == segment['text'] @@ -282,18 +281,15 @@ def preprocess(): bdas = set() for da in turn['dialogue_acts']['binary']: - da_tuple = (da['intent'], da['domain'], da['slot'], da['value'],) + da_tuple = (da['intent'], da['domain'], da['slot'],) bdas.add(da_tuple) - turn['dialogue_acts']['binary'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2],'value':bda[3]} for bda in sorted(bdas)] + turn['dialogue_acts']['binary'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2]} for bda in sorted(bdas)] # add to dialogue_acts dictionary in the ontology for da_type in turn['dialogue_acts']: das = turn['dialogue_acts'][da_type] for da in das: - intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da['value'] - if da_type == 'binary': - ontology["dialogue_acts"][da_type].add((speaker, intent, domain, slot, value)) - else: - ontology["dialogue_acts"][da_type].add((speaker, intent, domain, slot)) + ontology["dialogue_acts"][da_type].setdefault((da['intent'], da['domain'], da['slot']), {}) + ontology["dialogue_acts"][da_type][(da['intent'], da['domain'], da['slot'])][speaker] = True for da in turn['dialogue_acts']['non-categorical']: slot, value = da['slot'], da['value'] @@ -311,14 +307,11 @@ def preprocess(): dialogues_by_split[data_split].append(dialogue) for da_type in ontology['dialogue_acts']: - if da_type == 'binary': - ontology["dialogue_acts"][da_type] = [str({'speaker': da[0], 'intent':da[1],'domain':da[2],'slot':da[3],'value':da[4]}) for da in sorted(ontology["dialogue_acts"][da_type])] - else: - ontology["dialogue_acts"][da_type] = [str({'speaker': da[0], 'intent':da[1],'domain':da[2],'slot':da[3]}) for da in sorted(ontology["dialogue_acts"][da_type])] + ontology["dialogue_acts"][da_type] = sorted([str({'user': speakers.get('user', False), 'system': speakers.get('system', False), 'intent':da[0],'domain':da[1], 'slot':da[2]}) for da, speakers in ontology["dialogue_acts"][da_type].items()]) dialogues = dialogues_by_split['train']+dialogues_by_split['validation']+dialogues_by_split['test'] json.dump(dialogues[:10], open(f'dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) - json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) json.dump(ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf: for filename in os.listdir(new_data_dir): zf.write(f'{new_data_dir}/{filename}') diff --git a/data/unified_datasets/tm2/data.zip b/data/unified_datasets/tm2/data.zip index e6b3a3ad77fa3ce405e3185ddce48299620624d7..4eae9b371e5eb90dfaf83b1fcd8d9ae7b83221be 100644 Binary files a/data/unified_datasets/tm2/data.zip and b/data/unified_datasets/tm2/data.zip differ diff --git a/data/unified_datasets/tm2/dummy_data.json b/data/unified_datasets/tm2/dummy_data.json index 6cb30f4be2581202be683487db67cdd37917ca3c..044c749923a1f4bb979611124f6eed0f090fe838 100644 --- a/data/unified_datasets/tm2/dummy_data.json +++ b/data/unified_datasets/tm2/dummy_data.json @@ -1045,8 +1045,7 @@ { "intent": "inform", "domain": "flights", - "slot": "flight_booked", - "value": "" + "slot": "flight_booked" } ], "categorical": [], @@ -2162,8 +2161,7 @@ { "intent": "inform", "domain": "flights", - "slot": "flight_booked", - "value": "" + "slot": "flight_booked" } ], "categorical": [], @@ -2809,8 +2807,7 @@ { "intent": "inform", "domain": "flights", - "slot": "flight_booked", - "value": "" + "slot": "flight_booked" } ], "categorical": [], @@ -4708,8 +4705,7 @@ { "intent": "inform", "domain": "flights", - "slot": "flight_booked", - "value": "" + "slot": "flight_booked" } ], "categorical": [], diff --git a/data/unified_datasets/tm2/preprocess.py b/data/unified_datasets/tm2/preprocess.py index fa2c9c376a082ad4043618e5e77334547be48a81..b07597d4eb5e79a97242f153948dbf1c2802f44e 100644 --- a/data/unified_datasets/tm2/preprocess.py +++ b/data/unified_datasets/tm2/preprocess.py @@ -245,9 +245,9 @@ def preprocess(): }, 'state': {}, 'dialogue_acts': { - "categorical": set(), - "non-categorical": set(), - "binary": set() + "categorical": {}, + "non-categorical": {}, + "binary": {} }} global descriptions global anno2slot @@ -357,7 +357,6 @@ def preprocess(): 'intent': intent, 'domain': domain, 'slot': slot, - 'value': '' }) continue else: @@ -382,18 +381,15 @@ def preprocess(): bdas = set() for da in turn['dialogue_acts']['binary']: - da_tuple = (da['intent'], da['domain'], da['slot'], da['value'],) + da_tuple = (da['intent'], da['domain'], da['slot'],) bdas.add(da_tuple) - turn['dialogue_acts']['binary'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2],'value':bda[3]} for bda in sorted(bdas)] + turn['dialogue_acts']['binary'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2]} for bda in sorted(bdas)] # add to dialogue_acts dictionary in the ontology for da_type in turn['dialogue_acts']: das = turn['dialogue_acts'][da_type] for da in das: - intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da['value'] - if da_type == 'binary': - ontology["dialogue_acts"][da_type].add((speaker, intent, domain, slot, value)) - else: - ontology["dialogue_acts"][da_type].add((speaker, intent, domain, slot)) + ontology["dialogue_acts"][da_type].setdefault((da['intent'], da['domain'], da['slot']), {}) + ontology["dialogue_acts"][da_type][(da['intent'], da['domain'], da['slot'])][speaker] = True for da in turn['dialogue_acts']['non-categorical']: slot, value = da['slot'], da['value'] @@ -409,14 +405,11 @@ def preprocess(): dialogues_by_split[data_split].append(dialogue) for da_type in ontology['dialogue_acts']: - if da_type == 'binary': - ontology["dialogue_acts"][da_type] = [str({'speaker': da[0], 'intent':da[1],'domain':da[2],'slot':da[3],'value':da[4]}) for da in sorted(ontology["dialogue_acts"][da_type])] - else: - ontology["dialogue_acts"][da_type] = [str({'speaker': da[0], 'intent':da[1],'domain':da[2],'slot':da[3]}) for da in sorted(ontology["dialogue_acts"][da_type])] + ontology["dialogue_acts"][da_type] = sorted([str({'user': speakers.get('user', False), 'system': speakers.get('system', False), 'intent':da[0],'domain':da[1], 'slot':da[2]}) for da, speakers in ontology["dialogue_acts"][da_type].items()]) dialogues = dialogues_by_split['train']+dialogues_by_split['validation']+dialogues_by_split['test'] json.dump(dialogues[:10], open(f'dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) - json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) json.dump(ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf: for filename in os.listdir(new_data_dir): zf.write(f'{new_data_dir}/{filename}') diff --git a/data/unified_datasets/tm3/data.zip b/data/unified_datasets/tm3/data.zip index 07f325ff146e69e0fcfaf4bc23570585b716760a..b53ea9fa00f2250f9dd974f254dc7e9342bbd566 100644 Binary files a/data/unified_datasets/tm3/data.zip and b/data/unified_datasets/tm3/data.zip differ diff --git a/data/unified_datasets/tm3/preprocess.py b/data/unified_datasets/tm3/preprocess.py index 78a3efb5d3f3e23385962f0e6c4a3d4dc8c62ac8..a16a8c2eb2ade1e02ffa2c932413b8ca5e94e88f 100644 --- a/data/unified_datasets/tm3/preprocess.py +++ b/data/unified_datasets/tm3/preprocess.py @@ -102,9 +102,9 @@ def preprocess(): }, 'state': {}, 'dialogue_acts': { - "categorical": set(), - "non-categorical": set(), - "binary": set() + "categorical": {}, + "non-categorical": {}, + "binary": {} }} global descriptions global anno2slot @@ -197,7 +197,6 @@ def preprocess(): 'intent': intent, 'domain': domain, 'slot': slot, - 'value': '' }) continue assert turn['utterance'][segment['start_index']:segment['end_index']] == segment['text'] @@ -219,18 +218,15 @@ def preprocess(): bdas = set() for da in turn['dialogue_acts']['binary']: - da_tuple = (da['intent'], da['domain'], da['slot'], da['value'],) + da_tuple = (da['intent'], da['domain'], da['slot'],) bdas.add(da_tuple) - turn['dialogue_acts']['binary'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2],'value':bda[3]} for bda in sorted(bdas)] + turn['dialogue_acts']['binary'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2]} for bda in sorted(bdas)] # add to dialogue_acts dictionary in the ontology for da_type in turn['dialogue_acts']: das = turn['dialogue_acts'][da_type] for da in das: - intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da['value'] - if da_type == 'binary': - ontology["dialogue_acts"][da_type].add((speaker, intent, domain, slot, value)) - else: - ontology["dialogue_acts"][da_type].add((speaker, intent, domain, slot)) + ontology["dialogue_acts"][da_type].setdefault((da['intent'], da['domain'], da['slot']), {}) + ontology["dialogue_acts"][da_type][(da['intent'], da['domain'], da['slot'])][speaker] = True for da in turn['dialogue_acts']['non-categorical']: slot, value = da['slot'], da['value'] @@ -250,14 +246,11 @@ def preprocess(): dialogues_by_split[data_split].append(dialogue) for da_type in ontology['dialogue_acts']: - if da_type == 'binary': - ontology["dialogue_acts"][da_type] = [str({'speaker': da[0], 'intent':da[1],'domain':da[2],'slot':da[3],'value':da[4]}) for da in sorted(ontology["dialogue_acts"][da_type])] - else: - ontology["dialogue_acts"][da_type] = [str({'speaker': da[0], 'intent':da[1],'domain':da[2],'slot':da[3]}) for da in sorted(ontology["dialogue_acts"][da_type])] + ontology["dialogue_acts"][da_type] = sorted([str({'user': speakers.get('user', False), 'system': speakers.get('system', False), 'intent':da[0],'domain':da[1], 'slot':da[2]}) for da, speakers in ontology["dialogue_acts"][da_type].items()]) dialogues = dialogues_by_split['train']+dialogues_by_split['validation']+dialogues_by_split['test'] json.dump(dialogues[:10], open(f'dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) - json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) json.dump(ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf: for filename in os.listdir(new_data_dir): zf.write(f'{new_data_dir}/{filename}')