diff --git a/data/unified_datasets/multiwoz21/preprocess.py b/data/unified_datasets/multiwoz21/preprocess.py index 07fcc261e557855c34e77cca393345b3898e44f0..fa88ecf256c1c242d9a77b6536874199699c4576 100644 --- a/data/unified_datasets/multiwoz21/preprocess.py +++ b/data/unified_datasets/multiwoz21/preprocess.py @@ -567,8 +567,29 @@ slot_name_map = { 'train': { 'day': 'day', 'time': "duration" }, - 'police': {}, - 'booking': {} + 'police': {} +} + +reverse_da_slot_name_map = { + 'address': 'Addr', + 'postcode': 'Post', + 'price range': 'Price', + 'arrive by': 'Arrive', + 'leave at': 'Leave', + 'departure': 'Depart', + 'destination': 'Dest', + 'entrance fee': 'Fee', + 'open hours': 'Open', + 'price': 'Ticket', + 'train id': 'Id', + 'book people': 'People', + 'book stay': 'Stay', + 'book day': 'Day', + 'book time': 'Time', + 'duration': 'Time', + 'taxi': { + 'type': 'Car' + } } digit2word = { @@ -578,6 +599,40 @@ digit2word = { cnt_domain_slot = Counter() +def reverse_da(dialogue_acts): + global reverse_da_slot_name_map + das = {} + for da_type in dialogue_acts: + for da in dialogue_acts[da_type]: + intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da.get('value', '') + if domain == 'general': + Domain_Intent = '-'.join([domain, intent]) + elif intent == 'nooffer': + Domain_Intent = '-'.join([domain.capitalize(), 'NoOffer']) + elif intent == 'nobook': + Domain_Intent = '-'.join([domain.capitalize(), 'NoBook']) + elif intent == 'offerbook': + Domain_Intent = '-'.join([domain.capitalize(), 'OfferBook']) + else: + Domain_Intent = '-'.join([domain.capitalize(), intent.capitalize()]) + das.setdefault(Domain_Intent, []) + if slot in reverse_da_slot_name_map: + Slot = reverse_da_slot_name_map[slot] + elif domain in reverse_da_slot_name_map and slot in reverse_da_slot_name_map[domain]: + Slot = reverse_da_slot_name_map[domain][slot] + else: + Slot = slot.capitalize() + if value == '': + if intent == 'request': + value = '?' + else: + value = 'none' + if Slot == '': + Slot = 'none' + das[Domain_Intent].append([Slot, value]) + return das + + def normalize_domain_slot_value(domain, slot, value): global ontology, slot_name_map domain = domain.lower() @@ -818,6 +873,15 @@ def preprocess(): dialogue_acts = convert_da(da_dict, utt, sent_tokenizer, word_tokenizer) + # reverse_das = reverse_da(dialogue_acts) + # das_list = sorted([(Domain_Intent, Slot, ''.join(value.split()).lower()) for Domain_Intent in das for Slot, value in das[Domain_Intent]]) + # reverse_das_list = sorted([(Domain_Intent, Slot, ''.join(value.split()).lower()) for Domain_Intent in reverse_das for Slot, value in reverse_das[Domain_Intent]]) + # if das_list != reverse_das_list: + # print(das_list) + # print(reverse_das_list) + # print() + # print() + dialogue['turns'].append({ 'speaker': speaker, 'utterance': utt,