diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py index e4344bd838785dda7a3736c37d7577a2887fd9d7..ed9c211f9ae1276df257376056df0677057eaca0 100644 --- a/convlab2/util/unified_datasets_util.py +++ b/convlab2/util/unified_datasets_util.py @@ -3,6 +3,7 @@ from typing import Dict, List, Tuple from zipfile import ZipFile import json import os +import re import importlib from abc import ABC, abstractmethod from pprint import pprint @@ -180,6 +181,81 @@ def load_rg_data(dataset, data_split='all', speaker='system', context_window_siz kwargs.setdefault('utterance', True) return load_unified_data(dataset, **kwargs) + +def create_delex_data(dataset, delex_format='[({domain})-({slot})]', ignore_values=['yes', 'no']): + # add delex_utterance to the dataset according to dialogue acts and belief_state + + def delex_inplace(texts_placeholders, value_pattern): + res = [] + for substring, is_placeholder in texts_placeholders: + if not is_placeholder: + matches = value_pattern.findall(substring) + res.append(len(matches) == 1) + else: + res.append(False) + if sum(res) == 1: + # only one piece matches + idx = res.index(True) + substring = texts_placeholders[idx][0] + searchObj = re.search(value_pattern, substring) + assert searchObj + start, end = searchObj.span(2) + texts_placeholders[idx:idx+1] = [(substring[0:start], False), (placeholder, True), (substring[end:], False)] + return True + return False + + delex_vocab = set() + for data_split in dataset: + for dialog in dataset[data_split]: + state = {} + for turn in dialog['turns']: + utt = turn['utterance'] + delex_utt = [] + last_end = 0 + # ignore the non-categorical das that do not have span annotation + spans = [x for x in turn['dialogue_acts']['non-categorical'] if 'start' in x] + for da in sorted(spans, key=lambda x: x['start']): + # from left to right + start, end = da['start'], da['end'] + domain, slot, value = da['domain'], da['slot'], da['value'] + assert utt[start:end] == value + # make sure there are no words/number prepend & append and no overlap with other spans + if start >= last_end and (start == 0 or re.match('\W', utt[start-1])) and (end == len(utt) or re.match('\W', utt[end])): + placeholder = delex_format.format(domain=domain, slot=slot, value=value) + delex_vocab.add(placeholder) + delex_utt.append((utt[last_end:start], False)) + delex_utt.append((placeholder, True)) + last_end = end + delex_utt.append((utt[last_end:], False)) + + # search for value in categorical dialogue acts and belief state + for da in sorted(turn['dialogue_acts']['categorical'], key=lambda x: len(x['value'])): + domain, slot, value = da['domain'], da['slot'], da['value'] + if value.lower() not in ignore_values: + placeholder = delex_format.format(domain=domain, slot=slot, value=value) + pattern = re.compile(r'(\W|^)'+f'({value})'+r'(\W|$)', flags=re.I) + if delex_inplace(delex_utt, pattern): + delex_vocab.add(placeholder) + + # for domain in turn['state'] + if 'state' in turn: + state = turn['state'] + for domain in state: + for slot, values in state[domain].items(): + if len(values) > 0: + # has value + for value in values.split('|'): + if value.lower() not in ignore_values: + placeholder = delex_format.format(domain=domain, slot=slot, value=value) + pattern = re.compile(r'(\W|^)'+f'({value})'+r'(\W|$)', flags=re.I) + if delex_inplace(delex_utt, pattern): + delex_vocab.add(placeholder) + + turn['delex_utterance'] = ''.join([x[0] for x in delex_utt]) + + return dataset, sorted(list(delex_vocab)) + + if __name__ == "__main__": dataset = load_dataset('multiwoz21') print(dataset.keys()) @@ -192,3 +268,13 @@ if __name__ == "__main__": data_by_split = load_nlu_data(dataset, data_split='test', speaker='user') pprint(data_by_split['test'][0]) + + dataset, delex_vocab = create_delex_data(dataset) + json.dump(dataset['test'], open('delex_multiwoz21_test.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(delex_vocab, open('delex_vocab.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + with open('delex_cmp.txt', 'w') as f: + for dialog in dataset['test']: + for turn in dialog['turns']: + f.write(turn['utterance']+'\n') + f.write(turn['delex_utterance']+'\n') + f.write('\n')