diff --git a/convlab2/nlu/jointBERT/README.md b/convlab2/nlu/jointBERT/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..c9756d3c1ebdd42e975bb86d32a532b066a29048
--- /dev/null
+++ b/convlab2/nlu/jointBERT/README.md
@@ -0,0 +1,57 @@
+# BERTNLU
+
+On top of the pre-trained BERT, BERTNLU uses one MLP for slot tagging and another MLP for intent classification. All parameters are fine-tuned to learn these two tasks jointly.
+
+Dialogue acts are split into two groups, depending on whether their values appear in the utterances:
+
+- For dialogue acts whose values are in the utterances, we use **slot tagging** to extract the values. For example, for the utterance `"Find me a cheap hotel"`, the dialogue act is `{intent=Inform, domain=hotel, slot=price, value=cheap}` and the corresponding BIO tag sequence is `["O", "O", "O", "B-inform-hotel-price", "O"]`. An MLP classifier takes a token's representation from BERT and outputs its tag.
+- For dialogue acts whose values may not be present in the utterances, we treat them as **intents** of the utterances. Another MLP takes the `[CLS]` embedding of an utterance as input and performs binary classification for each intent independently. Since some intents are rare, we empirically set the weight of positive samples to $\lg(\frac{\# \ negative\ samples}{\# \ positive\ samples})$ for each intent (see the sketch after this list).
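+
+A minimal sketch of this weighting, assuming the per-intent positive counts have already been collected (the function and variable names here are illustrative, not taken from the training code):
+
+```python
+import math
+
+def intent_pos_weights(intent_counts, n_samples):
+    """intent_counts: {intent: #positive samples}; n_samples: total #utterances."""
+    weights = {}
+    for intent, n_pos in intent_counts.items():
+        n_neg = n_samples - n_pos
+        # weight of positive samples: lg(#negative / #positive)
+        weights[intent] = math.log10(n_neg / n_pos)
+    return weights
+```
+
+For example, an intent with 10 positive and 990 negative samples gets a positive weight of $\lg(99) \approx 2$.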
+ + +## References + +``` +@inproceedings{devlin2019bert, + title={BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding}, + author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina}, + booktitle={Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)}, + pages={4171--4186}, + year={2019} +} + +@inproceedings{zhu-etal-2020-convlab, + title = "{C}onv{L}ab-2: An Open-Source Toolkit for Building, Evaluating, and Diagnosing Dialogue Systems", + author = "Zhu, Qi and Zhang, Zheng and Fang, Yan and Li, Xiang and Takanobu, Ryuichi and Li, Jinchao and Peng, Baolin and Gao, Jianfeng and Zhu, Xiaoyan and Huang, Minlie", + booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations", + month = jul, + year = "2020", + address = "Online", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2020.acl-demos.19", + doi = "10.18653/v1/2020.acl-demos.19", + pages = "142--149" +} +``` \ No newline at end of file diff --git a/convlab2/nlu/jointBERT/dataloader.py b/convlab2/nlu/jointBERT/dataloader.py index 38fc24ea0fdc288410a65146716996432ebd896c..d1fcbc7a4864211a9956cedacc7c3479c195733a 100755 --- a/convlab2/nlu/jointBERT/dataloader.py +++ b/convlab2/nlu/jointBERT/dataloader.py @@ -39,13 +39,13 @@ class Dataloader: for d in self.data[data_key]: max_sen_len = max(max_sen_len, len(d[0])) sen_len.append(len(d[0])) - # d = (tokens, tags, intents, da2triples(turn["dialog_act"], context(list of str)) + # d = (tokens, tags, intents, original dialog acts, context(list of str)) if cut_sen_len > 0: d[0] = d[0][:cut_sen_len] d[1] = d[1][:cut_sen_len] d[4] = [' '.join(s.split()[:cut_sen_len]) for s in d[4]] - d[4] = self.tokenizer.encode('[CLS] ' + ' [SEP] '.join(d[4])) + d[4] = self.tokenizer.encode(' [SEP] '.join(d[4])) max_context_len = max(max_context_len, len(d[4])) context_len.append(len(d[4])) diff --git a/convlab2/nlu/jointBERT/test.py b/convlab2/nlu/jointBERT/test.py index 7856e5ecc0c1be7471f9497339a7f9fbc2f3f9ec..2e1e1b51940c5d899a833fc9fba3a7f6aa257e7b 100755 --- a/convlab2/nlu/jointBERT/test.py +++ b/convlab2/nlu/jointBERT/test.py @@ -29,7 +29,11 @@ if __name__ == '__main__': set_seed(config['seed']) - if 'multiwoz' in data_dir: + if 'unified_datasets' in data_dir: + dataset_name = config['dataset_name'] + print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20) + from convlab2.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent + elif 'multiwoz' in data_dir: print('-'*20 + 'dataset:multiwoz' + '-'*20) from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent elif 'camrest' in data_dir: @@ -90,14 +94,25 @@ if __name__ == '__main__': 'predict': predicts, 'golden': labels }) - predict_golden['slot'].append({ - 'predict': [x for x in predicts if is_slot_da(x)], - 'golden': [x for x in labels if is_slot_da(x)] - }) - predict_golden['intent'].append({ - 'predict': [x for x in predicts if not is_slot_da(x)], - 'golden': [x for x in labels if not is_slot_da(x)] - }) + if isinstance(predicts, dict): + predict_golden['slot'].append({ + 'predict': {k:v for k, v in predicts.items() if is_slot_da(k)}, + 'golden': {k:v for k, v in labels.items() if is_slot_da(k)} + }) + predict_golden['intent'].append({ + 'predict': {k:v for k, v in predicts.items() if not 
+
+
+## Usage
+
+Follow the instructions under each dataset's directory to prepare the data and model config file for training and evaluation.
+
+#### Train a model
+
+```sh
+$ python train.py --config_path path_to_a_config_file
+```
+
+The model (`pytorch_model.bin`) will be saved under the `output_dir` of the config file.
+
+#### Test a model
+
+```sh
+$ python test.py --config_path path_to_a_config_file
+```
+
+The result (`output.json`) will be saved under the `output_dir` of the config file. Also, it will be zipped as `zipped_model_path` in the config file.
+
+
+## References
+
+```
+@inproceedings{devlin2019bert,
+  title={BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding},
+  author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina},
+  booktitle={Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)},
+  pages={4171--4186},
+  year={2019}
+}
+
+@inproceedings{zhu-etal-2020-convlab,
+  title = "{C}onv{L}ab-2: An Open-Source Toolkit for Building, Evaluating, and Diagnosing Dialogue Systems",
+  author = "Zhu, Qi and Zhang, Zheng and Fang, Yan and Li, Xiang and Takanobu, Ryuichi and Li, Jinchao and Peng, Baolin and Gao, Jianfeng and Zhu, Xiaoyan and Huang, Minlie",
+  booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
+  month = jul,
+  year = "2020",
+  address = "Online",
+  publisher = "Association for Computational Linguistics",
+  url = "https://aclanthology.org/2020.acl-demos.19",
+  doi = "10.18653/v1/2020.acl-demos.19",
+  pages = "142--149"
+}
+```
\ No newline at end of file
diff --git a/convlab2/nlu/jointBERT/dataloader.py b/convlab2/nlu/jointBERT/dataloader.py
index 38fc24ea0fdc288410a65146716996432ebd896c..d1fcbc7a4864211a9956cedacc7c3479c195733a 100755
--- a/convlab2/nlu/jointBERT/dataloader.py
+++ b/convlab2/nlu/jointBERT/dataloader.py
@@ -39,13 +39,13 @@ class Dataloader:
         for d in self.data[data_key]:
             max_sen_len = max(max_sen_len, len(d[0]))
             sen_len.append(len(d[0]))
-            # d = (tokens, tags, intents, da2triples(turn["dialog_act"], context(list of str))
+            # d = (tokens, tags, intents, original dialog acts, context(list of str))
             if cut_sen_len > 0:
                 d[0] = d[0][:cut_sen_len]
                 d[1] = d[1][:cut_sen_len]
                 d[4] = [' '.join(s.split()[:cut_sen_len]) for s in d[4]]
 
-            d[4] = self.tokenizer.encode('[CLS] ' + ' [SEP] '.join(d[4]))
+            d[4] = self.tokenizer.encode(' [SEP] '.join(d[4]))
             max_context_len = max(max_context_len, len(d[4]))
             context_len.append(len(d[4]))
diff --git a/convlab2/nlu/jointBERT/test.py b/convlab2/nlu/jointBERT/test.py
index 7856e5ecc0c1be7471f9497339a7f9fbc2f3f9ec..2e1e1b51940c5d899a833fc9fba3a7f6aa257e7b 100755
--- a/convlab2/nlu/jointBERT/test.py
+++ b/convlab2/nlu/jointBERT/test.py
@@ -29,7 +29,11 @@ if __name__ == '__main__':
 
     set_seed(config['seed'])
 
-    if 'multiwoz' in data_dir:
+    if 'unified_datasets' in data_dir:
+        dataset_name = config['dataset_name']
+        print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
+        from convlab2.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
+    elif 'multiwoz' in data_dir:
         print('-'*20 + 'dataset:multiwoz' + '-'*20)
         from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
     elif 'camrest' in data_dir:
@@ -90,14 +94,25 @@ if __name__ == '__main__':
             'predict': predicts,
             'golden': labels
         })
-        predict_golden['slot'].append({
-            'predict': [x for x in predicts if is_slot_da(x)],
-            'golden': [x for x in labels if is_slot_da(x)]
-        })
-        predict_golden['intent'].append({
-            'predict': [x for x in predicts if not is_slot_da(x)],
-            'golden': [x for x in labels if not is_slot_da(x)]
-        })
+        if isinstance(predicts, dict):
+            predict_golden['slot'].append({
+                'predict': {k: v for k, v in predicts.items() if is_slot_da(k)},
+                'golden': {k: v for k, v in labels.items() if is_slot_da(k)}
+            })
+            predict_golden['intent'].append({
+                'predict': {k: v for k, v in predicts.items() if not is_slot_da(k)},
+                'golden': {k: v for k, v in labels.items() if not is_slot_da(k)}
+            })
+        else:
+            assert isinstance(predicts, list)
+            predict_golden['slot'].append({
+                'predict': [x for x in predicts if is_slot_da(x)],
+                'golden': [x for x in labels if is_slot_da(x)]
+            })
+            predict_golden['intent'].append({
+                'predict': [x for x in predicts if not is_slot_da(x)],
+                'golden': [x for x in labels if not is_slot_da(x)]
+            })
 
     print('[%d|%d] samples' % (len(predict_golden['overall']), len(dataloader.data[data_key])))
 
     total = len(dataloader.data[data_key])
diff --git a/convlab2/nlu/jointBERT/train.py b/convlab2/nlu/jointBERT/train.py
index a6267b9403dd805ea7537763f1063cbcc03965d1..fad50eda9c3b6676d9ce5b9e00dcc961e14ae4e7 100755
--- a/convlab2/nlu/jointBERT/train.py
+++ b/convlab2/nlu/jointBERT/train.py
@@ -32,7 +32,11 @@ if __name__ == '__main__':
 
     set_seed(config['seed'])
 
-    if 'multiwoz' in data_dir:
+    if 'unified_datasets' in data_dir:
+        dataset_name = config['dataset_name']
+        print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
+        from convlab2.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
+    elif 'multiwoz' in data_dir:
         print('-'*20 + 'dataset:multiwoz' + '-'*20)
         from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
     elif 'camrest' in data_dir:
@@ -149,14 +153,25 @@ if __name__ == '__main__':
                     'predict': predicts,
                     'golden': labels
                 })
-                predict_golden['slot'].append({
-                    'predict': [x for x in predicts if is_slot_da(x)],
-                    'golden': [x for x in labels if is_slot_da(x)]
-                })
-                predict_golden['intent'].append({
-                    'predict': [x for x in predicts if not is_slot_da(x)],
-                    'golden': [x for x in labels if not is_slot_da(x)]
-                })
+                if isinstance(predicts, dict):
+                    predict_golden['slot'].append({
+                        'predict': {k: v for k, v in predicts.items() if is_slot_da(k)},
+                        'golden': {k: v for k, v in labels.items() if is_slot_da(k)}
+                    })
+                    predict_golden['intent'].append({
+                        'predict': {k: v for k, v in predicts.items() if not is_slot_da(k)},
+                        'golden': {k: v for k, v in labels.items() if not is_slot_da(k)}
+                    })
+                else:
+                    assert isinstance(predicts, list)
+                    predict_golden['slot'].append({
+                        'predict': [x for x in predicts if is_slot_da(x)],
+                        'golden': [x for x in labels if is_slot_da(x)]
+                    })
+                    predict_golden['intent'].append({
+                        'predict': [x for x in predicts if not is_slot_da(x)],
+                        'golden': [x for x in labels if not is_slot_da(x)]
+                    })
 
                 for j in range(10):
                     writer.add_text('val_sample_{}'.format(j),
diff --git a/convlab2/nlu/jointBERT/unified_datasets/README.md b/convlab2/nlu/jointBERT/unified_datasets/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..6ced41a7f793da71dcf07edfdc58c29cde4996d5
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/README.md
@@ -0,0 +1,41 @@
+# BERTNLU on datasets in unified format
+
+We support training BERTNLU on datasets that are in our unified format.
+
+- For **non-categorical** dialogue acts whose values are in the utterances, we use **slot tagging** to extract the values.
+- For **categorical** and **binary** dialogue acts whose values may not be present in the utterances, we treat them as **intents** of the utterances.
+
+## Usage
+
+#### Preprocess data
+
+```sh
+$ python preprocess.py --dataset dataset_name --speaker {user,system,all} --context_window_size CONTEXT_WINDOW_SIZE --save_dir save_directory
+```
+
+Note that the dataset will be loaded by `convlab2.util.load_dataset(dataset_name)`. If you want to use custom datasets, make sure they follow the unified format and can be loaded with this function. The processed data will be saved under the `${save_dir}/${dataset_name}/${speaker}/context_window_size_${context_window_size}` directory, in the format sketched below.
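+
+Each sample in the dumped `*_data.json` files is a list `[tokens, tags, intents, dialogue_acts, context]` (see `preprocess.py`). A hypothetical entry, with made-up values, could look like:
+
+```python
+[
+    ["i", "want", "a", "cheap", "hotel"],                               # tokens
+    ["('O',)", "('O',)", "('O',)",
+     "('inform', 'hotel', 'price', 'B')", "('O',)"],                    # BIO tags as stringified tuples
+    ["('request', 'hotel', 'phone')"],                                  # intents from categorical/binary acts
+    {"categorical": [...], "non-categorical": [...], "binary": [...]},  # original dialogue act annotation
+    ["what can i do for you ?"]                                         # context (empty if context_window_size=0)
+]
+```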
+
+#### Train a model
+
+Prepare a config file and run the training script in the parent directory:
+
+```sh
+$ python train.py --config_path path_to_a_config_file
+```
+
+The model (`pytorch_model.bin`) will be saved under the `output_dir` of the config file. Also, it will be zipped as `zipped_model_path` in the config file.
+
+#### Test a model
+
+Run the inference script in the parent directory:
+
+```sh
+$ python test.py --config_path path_to_a_config_file
+```
+
+The result (`output.json`) will be saved under the `output_dir` of the config file.
+
+#### Predict
+
+See `nlu.py` for usage; a minimal example follows.
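+
+A minimal sketch, adapted from the `__main__` block of `nlu.py` (it assumes a trained model and the `multiwoz21_user.json` config exist):
+
+```python
+from convlab2.nlu.jointBERT.unified_datasets import BERTNLU
+
+# loads vocabularies, the config, and the trained checkpoint
+nlu = BERTNLU(mode='user', config_file='multiwoz21_user.json')
+# returns a list of [intent, domain, slot, value] quadruples
+print(nlu.predict("I want to leave after 17:15."))
+```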
diff --git a/convlab2/nlu/jointBERT/unified_datasets/__init__.py b/convlab2/nlu/jointBERT/unified_datasets/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..74b08065ca97132195ab56c65cb7fe87ec7b4fca
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/__init__.py
@@ -0,0 +1 @@
+from convlab2.nlu.jointBERT.unified_datasets.nlu import BERTNLU
\ No newline at end of file
diff --git a/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py b/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py
new file mode 100755
index 0000000000000000000000000000000000000000..6de31fbea9825b54b2e29bcf51e035a571de1c6b
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py
@@ -0,0 +1,33 @@
+import json
+import os
+from convlab2.util import load_dataset, load_nlu_data
+
+
+def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
+    assert os.path.exists(predict_result)
+    dataset = load_dataset(dataset_name)
+    data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
+
+    if save_dir is None:
+        save_dir = os.path.dirname(predict_result)
+    else:
+        os.makedirs(save_dir, exist_ok=True)
+    predict_result = json.load(open(predict_result))
+
+    for sample, prediction in zip(data, predict_result):
+        sample['dialogue_acts_prediction'] = prediction['predict']
+
+    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="merge predict results with original data for unified NLU evaluation")
+    parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
+    parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
+    parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: the same directory as predict_result')
+    parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
+    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated by ../test.py')
+    args = parser.parse_args()
+    print(args)
+    merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
diff --git a/convlab2/nlu/jointBERT/unified_datasets/nlu.py b/convlab2/nlu/jointBERT/unified_datasets/nlu.py
new file mode 100755
index 0000000000000000000000000000000000000000..063ea036e4999626ca02452b1c1dc9f38ddb913f
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/nlu.py
@@ -0,0 +1,106 @@
+import logging
+import os
+import json
+import torch
+from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
+import transformers
+from convlab2.nlu.nlu import NLU
+from convlab2.nlu.jointBERT.dataloader import Dataloader
+from convlab2.nlu.jointBERT.jointBERT import JointBERT
+from convlab2.nlu.jointBERT.unified_datasets.postprocess import recover_intent
+from convlab2.util.custom_util import model_downloader
+
+
+class BERTNLU(NLU):
+    def __init__(self, mode, config_file, model_file=None):
+        assert mode == 'user' or mode == 'sys' or mode == 'all'
+        self.mode = mode
+        config_file = os.path.join(os.path.dirname(
+            os.path.abspath(__file__)), 'configs/{}'.format(config_file))
+        config = json.load(open(config_file))
+        # fall back to CPU when CUDA is not available
+        DEVICE = 'cpu' if not torch.cuda.is_available() else config['DEVICE']
+        root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        data_dir = os.path.join(root_dir, config['data_dir'])
+        output_dir = os.path.join(root_dir, config['output_dir'])
+
+        assert os.path.exists(os.path.join(data_dir, 'intent_vocab.json')), 'Please run preprocess first'
+
+        intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json')))
+        tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))
+        dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab,
+                                pretrained_weights=config['model']['pretrained_weights'])
+
+        logging.info('intent num:' + str(len(intent_vocab)))
+        logging.info('tag num:' + str(len(tag_vocab)))
+
+        if not os.path.exists(output_dir):
+            model_downloader(root_dir, model_file)
+        model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim)
+
+        state_dict = torch.load(os.path.join(output_dir, 'pytorch_model.bin'), DEVICE)
+        if int(transformers.__version__.split('.')[0]) >= 3 and 'bert.embeddings.position_ids' not in state_dict:
+            state_dict['bert.embeddings.position_ids'] = torch.tensor(range(512)).reshape(1, -1).to(DEVICE)
+
+        model.load_state_dict(state_dict)
+        model.to(DEVICE)
+        model.eval()
+
+        self.model = model
+        self.use_context = config['model']['context']
+        self.context_window_size = config['context_window_size']
+        self.dataloader = dataloader
+        self.sent_tokenizer = PunktSentenceTokenizer()
+        self.word_tokenizer = TreebankWordTokenizer()
+        logging.info("BERTNLU loaded")
+
+    def predict(self, utterance, context=list()):
+        sentences = self.sent_tokenizer.tokenize(utterance)
+        ori_word_seq = [token for sent in sentences for token in self.word_tokenizer.tokenize(sent)]
+        ori_tag_seq = [str(('O',))] * len(ori_word_seq)
+        if self.use_context:
+            if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
+                context = [item[1] for item in context]
+            context_seq = self.dataloader.tokenizer.encode(' [SEP] '.join(context[-self.context_window_size:]))
+            context_seq = context_seq[:510]
+        else:
+            context_seq = self.dataloader.tokenizer.encode('')
+        intents = []
+        da = {}
+
+        word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(ori_word_seq, ori_tag_seq)
+        word_seq = word_seq[:510]
+        tag_seq = tag_seq[:510]
+        batch_data = [[ori_word_seq, ori_tag_seq, intents, da, context_seq,
+                       new2ori, word_seq, self.dataloader.seq_tag2id(tag_seq), self.dataloader.seq_intent2id(intents)]]
+
+        pad_batch = self.dataloader.pad_batch(batch_data)
+        pad_batch = tuple(t.to(self.model.device) for t in pad_batch)
+        word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch
+        slot_logits, intent_logits = self.model.forward(word_seq_tensor, word_mask_tensor,
+                                                        context_seq_tensor=context_seq_tensor,
+                                                        context_mask_tensor=context_mask_tensor)
+        das = recover_intent(self.dataloader, intent_logits[0], slot_logits[0], tag_mask_tensor[0],
+                             batch_data[0][0], batch_data[0][-4])
+        dialog_act = []
+        for da_type in das:
+            for da in das[da_type]:
+                dialog_act.append([da['intent'], da['domain'], da['slot'], da.get('value', '')])
+        return dialog_act
+
+
+if __name__ == '__main__':
+    texts = [
+        "I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
+        "I want to leave after 17:15.",
+        "Thank you for all the help! I appreciate it.",
+        "Please find a restaurant called Nusha.",
+        "What is the train id, please? ",
+        "I don't care about the price and it doesn't need to have free parking."
+    ]
+    nlu = BERTNLU(mode='user', config_file='multiwoz21_user.json')
+    for text in texts:
+        print(text)
+        print(nlu.predict(text))
+        print()
diff --git a/convlab2/nlu/jointBERT/unified_datasets/postprocess.py b/convlab2/nlu/jointBERT/unified_datasets/postprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..982b4a92a6df10832ac313acd5984af63567dd1d
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/postprocess.py
@@ -0,0 +1,111 @@
+import re
+import torch
+
+
+def is_slot_da(da_type):
+    return da_type == 'non-categorical'
+
+
+def calculateF1(predict_golden):
+    # F1 over all three types of dialogue acts
+    TP, FP, FN = 0, 0, 0
+    for item in predict_golden:
+        for da_type in ['non-categorical', 'categorical', 'binary']:
+            if da_type not in item['predict']:
+                assert da_type not in item['golden']
+                continue
+            if da_type == 'binary':
+                predicts = [(x['intent'], x['domain'], x['slot']) for x in item['predict'][da_type]]
+                labels = [(x['intent'], x['domain'], x['slot']) for x in item['golden'][da_type]]
+            else:
+                predicts = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in item['predict'][da_type]]
+                labels = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in item['golden'][da_type]]
+
+            for ele in predicts:
+                if ele in labels:
+                    TP += 1
+                else:
+                    FP += 1
+            for ele in labels:
+                if ele not in predicts:
+                    FN += 1
+    precision = 1.0 * TP / (TP + FP) if TP + FP else 0.
+    recall = 1.0 * TP / (TP + FN) if TP + FN else 0.
+    F1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0.
+    return precision, recall, F1
+
+
+def tag2triples(word_seq, tag_seq):
+    word_seq = word_seq[:len(tag_seq)]
+    assert len(word_seq) == len(tag_seq)
+    triples = []
+    i = 0
+    while i < len(tag_seq):
+        tag = eval(tag_seq[i])
+        if tag[-1] == 'B':
+            intent, domain, slot = tag[0], tag[1], tag[2]
+            value = word_seq[i]
+            j = i + 1
+            while j < len(tag_seq):
+                next_tag = eval(tag_seq[j])
+                if next_tag[-1] == 'I' and next_tag[:-1] == tag[:-1]:
+                    value += ' ' + word_seq[j]
+                    i += 1
+                    j += 1
+                else:
+                    break
+            triples.append([intent, domain, slot, value])
+        i += 1
+    return triples
+
+
+def recover_intent(dataloader, intent_logits, tag_logits, tag_mask_tensor, ori_word_seq, new2ori):
+    # tag_logits = [sequence_length, tag_dim]
+    # intent_logits = [intent_dim]
+    # tag_mask_tensor = [sequence_length]
+    # new2ori = {new_idx: old_idx, ...} (mapping after removing [CLS] and [SEP])
+    max_seq_len = tag_logits.size(0)
+    dialogue_acts = {
+        "categorical": [],
+        "non-categorical": [],
+        "binary": []
+    }
+    # for categorical & binary dialogue acts
+    for j in range(dataloader.intent_dim):
+        if intent_logits[j] > 0:
+            intent = eval(dataloader.id2intent[j])
+            if len(intent) == 3:
+                dialogue_acts['binary'].append({
+                    'intent': intent[0],
+                    'domain': intent[1],
+                    'slot': intent[2]
+                })
+            else:
+                assert len(intent) == 4
+                dialogue_acts['categorical'].append({
+                    'intent': intent[0],
+                    'domain': intent[1],
+                    'slot': intent[2],
+                    'value': intent[3]
+                })
+    # for non-categorical dialogue acts
+    tags = []
+    for j in range(1, max_seq_len - 1):
+        if tag_mask_tensor[j] == 1:
+            value, tag_id = torch.max(tag_logits[j], dim=-1)
+            tags.append(dataloader.id2tag[tag_id.item()])
+    recover_tags = []
+    for i, tag in enumerate(tags):
+        if new2ori[i] >= len(recover_tags):
+            recover_tags.append(tag)
+    ori_word_seq = ori_word_seq[:len(recover_tags)]
+    tag_intent = tag2triples(ori_word_seq, recover_tags)
+    for intent in tag_intent:
+        dialogue_acts['non-categorical'].append({
+            'intent': intent[0],
+            'domain': intent[1],
+            'slot': intent[2],
+            'value': intent[3]
+        })
+    return dialogue_acts
diff --git a/convlab2/nlu/jointBERT/unified_datasets/preprocess.py b/convlab2/nlu/jointBERT/unified_datasets/preprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..055e0724ab4c8d5d8f467a6d11354426974da3d8
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/preprocess.py
@@ -0,0 +1,93 @@
+import json
+import os
+from collections import Counter
+from convlab2.util import load_dataset, load_ontology, load_nlu_data
+from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
+from tqdm import tqdm
+
+
+def preprocess(dataset_name, speaker, save_dir, context_window_size):
+    dataset = load_dataset(dataset_name)
+    data_by_split = load_nlu_data(dataset, speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)
+    data_dir = os.path.join(save_dir, dataset_name, speaker, f'context_window_size_{context_window_size}')
+    os.makedirs(data_dir, exist_ok=True)
+
+    sent_tokenizer = PunktSentenceTokenizer()
+    word_tokenizer = TreebankWordTokenizer()
+
+    processed_data = {}
+    all_tags = set([str(('O',))])
+    all_intents = Counter()
+    for data_split, data in data_by_split.items():
+        if data_split == 'validation':
+            data_split = 'val'
+        processed_data[data_split] = []
+        for sample in tqdm(data, desc=f'{data_split} samples'):
+            utterance = sample['utterance']
+
+            sentences = sent_tokenizer.tokenize(utterance)
+            sent_spans = sent_tokenizer.span_tokenize(utterance)
+            tokens = [token for sent in sentences for token in word_tokenizer.tokenize(sent)]
+            token_spans = [(sent_span[0] + token_span[0], sent_span[0] + token_span[1]) for sent, sent_span in zip(sentences, sent_spans) for token_span in word_tokenizer.span_tokenize(sent)]
+            tags = [str(('O',))] * len(tokens)
+            for da in sample['dialogue_acts']['non-categorical']:
+                if 'start' not in da:
+                    # skip da that doesn't have span annotation
+                    continue
+                char_start = da['start']
+                char_end = da['end']
+                word_start, word_end = -1, -1
+                for i, token_span in enumerate(token_spans):
+                    if char_start == token_span[0]:
+                        word_start = i
+                    if char_end == token_span[1]:
+                        word_end = i + 1
+                if word_start == -1 and word_end == -1:
+                    # char span does not match any word; possibly an annotation error, skip
+                    print('char span does not match word, skipping')
+                    print('\t', 'utterance:', utterance)
+                    print('\t', 'value:', utterance[char_start: char_end])
+                    print('\t', 'da:', da, '\n')
+                    continue
+                intent, domain, slot = da['intent'], da['domain'], da['slot']
+                all_tags.add(str((intent, domain, slot, 'B')))
+                all_tags.add(str((intent, domain, slot, 'I')))
+                tags[word_start] = str((intent, domain, slot, 'B'))
+                for i in range(word_start + 1, word_end):
+                    tags[i] = str((intent, domain, slot, 'I'))
+
+            intents = []
+            for da in sample['dialogue_acts']['categorical']:
+                intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da['value'].strip().lower()
+                intent = str((intent, domain, slot, value))
+                intents.append(intent)
+                all_intents[intent] += 1
+            for da in sample['dialogue_acts']['binary']:
+                intent, domain, slot = da['intent'], da['domain'], da['slot']
+                intent = str((intent, domain, slot))
+                intents.append(intent)
+                all_intents[intent] += 1
+            context = []
+            if context_window_size > 0:
+                context = [s['utterance'] for s in sample['context']]
+            processed_data[data_split].append([tokens, tags, intents, sample['dialogue_acts'], context])
+        json.dump(processed_data[data_split], open(os.path.join(data_dir, '{}_data.json'.format(data_split)), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+
+    # Filter out intents that occur only once to build the intent vocabulary;
+    # these rare intents still remain in the data.
+    all_intents = {x: count for x, count in all_intents.items() if count > 1}
+    print('sentence label num:', len(all_intents))
+    print('tag num:', len(all_tags))
+    json.dump(sorted(all_intents), open(os.path.join(data_dir, 'intent_vocab.json'), 'w'), indent=2)
+    json.dump(sorted(all_tags), open(os.path.join(data_dir, 'tag_vocab.json'), 'w'), indent=2)
+
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="create nlu data for bertnlu training")
+    parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
+    parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
+    parser.add_argument('--save_dir', metavar='save_directory', type=str, default='data', help='directory to save the data, save_dir/$dataset_name/$speaker')
+    parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
+    args = parser.parse_args()
+    print(args)
+    preprocess(args.dataset, args.speaker, args.save_dir, args.context_window_size)
diff --git a/setup.py b/setup.py
index 2e426a70fc9a9e6c74bbd1f5a5ae0e4f7cffc8d4..62cbf1bd880e55e5dc13b26322dcf67ce6b1b73b 100755
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,7 @@ setup(
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
     ],
     install_requires=[
+        'matplotlib',
         'tabulate',
         'python-Levenshtein',
         'requests',