Unverified commit 88f2d263 authored by zhuqi, committed by GitHub

Merge pull request #30 from ConvLab/nlu

bertnlu can use unified datasets now
parents f2ec9960 95eb2201
Showing 547 additions and 20 deletions
# BERTNLU
On top of pre-trained BERT, BERTNLU uses one MLP for slot tagging and another MLP for intent classification. All parameters are fine-tuned to learn these two tasks jointly.
Dialog acts are split into two groups, depending on whether the values are in the utterances:
- For dialogue acts whose values are in the utterances, we use **slot tagging** to extract the values. For example, given `"Find me a cheap hotel"`, the dialog act is `{intent=Inform, domain=hotel, slot=price, value=cheap}` and the corresponding BIO tag sequence is `["O", "O", "O", "B-inform-hotel-price", "O"]`. An MLP classifier takes each token's BERT representation and outputs its tag.
- For dialogue acts whose values may not be present in the utterances, we treat them as **intents** of the utterances. Another MLP takes the `[CLS]` embedding of an utterance as input and performs binary classification for each intent independently. Since some intents are rare, we empirically set the weight of positive samples to $\lg(\frac{\#\ negative\ samples}{\#\ positive\ samples})$ for each intent.
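A minimal sketch of this weighting with PyTorch's `BCEWithLogitsLoss` (toy counts; the repository's actual loss code may differ):
```python
import torch
import torch.nn as nn

# Toy per-intent counts; in practice they are gathered from the training data.
num_positive = torch.tensor([120., 45., 3000.])
num_negative = torch.tensor([56000., 56075., 53120.])

# lg(#negative / #positive), as in the formula above (lg = log base 10).
pos_weight = torch.log10(num_negative / num_positive)

# Each intent is an independent binary classification over the [CLS] embedding.
intent_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(8, 3)                     # batch of 8 utterances, 3 intents
targets = torch.randint(0, 2, (8, 3)).float()  # multi-hot intent labels
loss = intent_loss(logits, targets)
```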
The model can also incorporate context information by setting `context=true` in the config file. The context utterances are concatenated (separated by `[SEP]`) and fed into BERT. The resulting `[CLS]` embedding serves as the context representation and is concatenated to every token representation of the target utterance right before the slot and intent classifiers.
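A rough sketch of this context fusion using Hugging Face `transformers` (illustrative only; the repository's `JointBERT` implementation may differ in details):
```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')

# Context utterances are joined with [SEP]; the [CLS] vector of the encoded
# context serves as the context representation.
context = ["I need a place to stay.", "Do you have a price range in mind?"]
ctx_ids = tokenizer(' [SEP] '.join(context), return_tensors='pt')['input_ids']
utt_ids = tokenizer("Something cheap, please.", return_tensors='pt')['input_ids']

with torch.no_grad():
    context_cls = bert(ctx_ids).last_hidden_state[:, 0]  # [1, 768]
    token_reprs = bert(utt_ids).last_hidden_state        # [1, seq_len, 768]

# Concatenate the context vector to every token representation; the classifier
# input size doubles (cf. hidden_units = 1536 in the context config below).
fused = torch.cat(
    [token_reprs, context_cls.unsqueeze(1).expand(-1, token_reprs.size(1), -1)],
    dim=-1)                                              # [1, seq_len, 1536]
```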
## Usage
Follow the instructions under each dataset's directory to prepare the data and a model config file for training and evaluation.
#### Train a model
```sh
$ python train.py --config_path path_to_a_config_file
```
The model (`pytorch_model.bin`) will be saved under the `output_dir` of the config file.
#### Test a model
```sh
$ python test.py --config_path path_to_a_config_file
```
The result (`output.json`) will be saved under the `output_dir` of the config file. Also, it will be zipped as `zipped_model_path` in the config file.
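Each entry of `output.json` pairs a prediction with the gold annotation for one utterance, roughly (a sketch; the exact element type depends on the dataset):
```json
[
  {
    "predict": "... dialogue acts predicted for one utterance ...",
    "golden": "... gold dialogue acts for the same utterance ..."
  }
]
```
The unified-datasets merge script shown later in this commit consumes this file.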
## References
```
@inproceedings{devlin2019bert,
  title     = "{BERT}: Pre-training of Deep Bidirectional Transformers for Language Understanding",
  author    = "Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina",
  booktitle = "Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)",
  pages     = "4171--4186",
  year      = "2019"
}
@inproceedings{zhu-etal-2020-convlab,
  title     = "{C}onv{L}ab-2: An Open-Source Toolkit for Building, Evaluating, and Diagnosing Dialogue Systems",
  author    = "Zhu, Qi and Zhang, Zheng and Fang, Yan and Li, Xiang and Takanobu, Ryuichi and Li, Jinchao and Peng, Baolin and Gao, Jianfeng and Zhu, Xiaoyan and Huang, Minlie",
  booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
  month     = jul,
  year      = "2020",
  address   = "Online",
  publisher = "Association for Computational Linguistics",
  url       = "https://aclanthology.org/2020.acl-demos.19",
  doi       = "10.18653/v1/2020.acl-demos.19",
  pages     = "142--149"
}
```
@@ -39,13 +39,13 @@ class Dataloader:
         for d in self.data[data_key]:
             max_sen_len = max(max_sen_len, len(d[0]))
             sen_len.append(len(d[0]))
-            # d = (tokens, tags, intents, da2triples(turn["dialog_act"], context(list of str))
+            # d = (tokens, tags, intents, original dialog acts, context(list of str))
             if cut_sen_len > 0:
                 d[0] = d[0][:cut_sen_len]
                 d[1] = d[1][:cut_sen_len]
                 d[4] = [' '.join(s.split()[:cut_sen_len]) for s in d[4]]
-            d[4] = self.tokenizer.encode('[CLS] ' + ' [SEP] '.join(d[4]))
+            d[4] = self.tokenizer.encode(' [SEP] '.join(d[4]))
             max_context_len = max(max_context_len, len(d[4]))
             context_len.append(len(d[4]))
@@ -29,7 +29,11 @@ if __name__ == '__main__':
     set_seed(config['seed'])

-    if 'multiwoz' in data_dir:
+    if 'unified_datasets' in data_dir:
+        dataset_name = config['dataset_name']
+        print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
+        from convlab2.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
+    elif 'multiwoz' in data_dir:
         print('-'*20 + 'dataset:multiwoz' + '-'*20)
         from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
     elif 'camrest' in data_dir:
@@ -90,6 +94,17 @@ if __name__ == '__main__':
                 'predict': predicts,
                 'golden': labels
             })
-            predict_golden['slot'].append({
-                'predict': [x for x in predicts if is_slot_da(x)],
-                'golden': [x for x in labels if is_slot_da(x)]
+            if isinstance(predicts, dict):
+                predict_golden['slot'].append({
+                    'predict': {k:v for k, v in predicts.items() if is_slot_da(k)},
+                    'golden': {k:v for k, v in labels.items() if is_slot_da(k)}
+                })
+                predict_golden['intent'].append({
+                    'predict': {k:v for k, v in predicts.items() if not is_slot_da(k)},
+                    'golden': {k:v for k, v in labels.items() if not is_slot_da(k)}
+                })
+            else:
+                assert isinstance(predicts, list)
+                predict_golden['slot'].append({
+                    'predict': [x for x in predicts if is_slot_da(x)],
+                    'golden': [x for x in labels if is_slot_da(x)]
@@ -32,7 +32,11 @@ if __name__ == '__main__':
     set_seed(config['seed'])

-    if 'multiwoz' in data_dir:
+    if 'unified_datasets' in data_dir:
+        dataset_name = config['dataset_name']
+        print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
+        from convlab2.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
+    elif 'multiwoz' in data_dir:
         print('-'*20 + 'dataset:multiwoz' + '-'*20)
         from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
     elif 'camrest' in data_dir:
@@ -149,6 +153,17 @@ if __name__ == '__main__':
                 'predict': predicts,
                 'golden': labels
             })
-            predict_golden['slot'].append({
-                'predict': [x for x in predicts if is_slot_da(x)],
-                'golden': [x for x in labels if is_slot_da(x)]
+            if isinstance(predicts, dict):
+                predict_golden['slot'].append({
+                    'predict': {k:v for k, v in predicts.items() if is_slot_da(k)},
+                    'golden': {k:v for k, v in labels.items() if is_slot_da(k)}
+                })
+                predict_golden['intent'].append({
+                    'predict': {k:v for k, v in predicts.items() if not is_slot_da(k)},
+                    'golden': {k:v for k, v in labels.items() if not is_slot_da(k)}
+                })
+            else:
+                assert isinstance(predicts, list)
+                predict_golden['slot'].append({
+                    'predict': [x for x in predicts if is_slot_da(x)],
+                    'golden': [x for x in labels if is_slot_da(x)]
# BERTNLU on datasets in unified format
We support training BERTNLU on datasets that are in our unified format.
- For **non-categorical** dialogue acts whose values are in the utterances, we use **slot tagging** to extract the values.
- For **categorical** and **binary** dialogue acts whose values may not be present in the utterances, we treat them as **intents** of the utterances (see the example annotation after this list).
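For illustration, a `dialogue_acts` annotation in the unified format looks roughly like this (field layout as consumed by `preprocess.py` below; the values are invented):
```json
{
  "categorical": [
    {"intent": "inform", "domain": "hotel", "slot": "parking", "value": "yes"}
  ],
  "non-categorical": [
    {"intent": "inform", "domain": "hotel", "slot": "price range", "value": "cheap", "start": 10, "end": 15}
  ],
  "binary": [
    {"intent": "request", "domain": "hotel", "slot": "phone"}
  ]
}
```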
## Usage
#### Preprocess data
```sh
$ python preprocess.py --dataset dataset_name --speaker {user,system,all} --context_window_size CONTEXT_WINDOW_SIZE --save_dir save_directory
```
Note that the dataset will be loaded by `convlab2.util.load_dataset(dataset_name)`. If you want to use custom datasets, make sure they follow the unified format and can be loaded using this function.
The processed data will be saved to the `${save_dir}/${dataset_name}/${speaker}/context_window_size_${context_window_size}` directory.
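For example, the following produces `data/multiwoz21/user/context_window_size_0`, which matches the `data_dir` of the first config shown below:
```sh
$ python preprocess.py --dataset multiwoz21 --speaker user --context_window_size 0 --save_dir data
```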
#### Train a model
Prepare a config file and run the training script in the parent directory:
```sh
$ python train.py --config_path path_to_a_config_file
```
The model (`pytorch_model.bin`) will be saved under the `output_dir` of the config file. Also, it will be zipped as `zipped_model_path` in the config file.
#### Test a model
Run the inference script in the parent directory:
```sh
$ python test.py --config_path path_to_a_config_file
```
The result (`output.json`) will be saved under the `output_dir` of the config file.
#### Predict
See `nlu.py` for usage. For instance, mirroring its `__main__` block:
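```python
from convlab2.nlu.jointBERT.unified_datasets.nlu import BERTNLU

nlu = BERTNLU(mode='user', config_file='multiwoz21_user.json')
# predict() returns a list of [intent, domain, slot, value] quadruples
print(nlu.predict("I want to leave after 17:15."))
```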
{
"dataset_name": "multiwoz21",
"data_dir": "unified_datasets/data/multiwoz21/user/context_window_size_0",
"output_dir": "unified_datasets/output/multiwoz21/user/context_window_size_0",
"zipped_model_path": "unified_datasets/output/multiwoz21/user/context_window_size_0/bertnlu_unified_multiwoz_user_context0.zip",
"log_dir": "unified_datasets/output/multiwoz21/user/context_window_size_0/log",
"DEVICE": "cuda:0",
"seed": 2019,
"cut_sen_len": 40,
"use_bert_tokenizer": true,
"context_window_size": 0,
"model": {
"finetune": true,
"context": false,
"context_grad": false,
"pretrained_weights": "bert-base-uncased",
"check_step": 1000,
"max_step": 10000,
"batch_size": 128,
"learning_rate": 1e-4,
"adam_epsilon": 1e-8,
"warmup_steps": 0,
"weight_decay": 0.0,
"dropout": 0.1,
"hidden_units": 768
}
}
{
"dataset_name": "multiwoz21",
"data_dir": "unified_datasets/data/multiwoz21/user/context_window_size_3",
"output_dir": "unified_datasets/output/multiwoz21/user/context_window_size_3",
"zipped_model_path": "unified_datasets/output/multiwoz21/user/context_window_size_3/bertnlu_unified_multiwoz_user_context3.zip",
"log_dir": "unified_datasets/output/multiwoz21/user/context_window_size_3/log",
"DEVICE": "cuda:0",
"seed": 2019,
"cut_sen_len": 40,
"use_bert_tokenizer": true,
"context_window_size": 3,
"model": {
"finetune": true,
"context": true,
"context_grad": true,
"pretrained_weights": "bert-base-uncased",
"check_step": 1000,
"max_step": 10000,
"batch_size": 128,
"learning_rate": 1e-4,
"adam_epsilon": 1e-8,
"warmup_steps": 0,
"weight_decay": 0.0,
"dropout": 0.1,
"hidden_units": 1536
}
}
import json
import os
from convlab2.util import load_dataset, load_nlu_data
def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
assert os.path.exists(predict_result)
dataset = load_dataset(dataset_name)
data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
if save_dir is None:
save_dir = os.path.dirname(predict_result)
else:
os.makedirs(save_dir, exist_ok=True)
predict_result = json.load(open(predict_result))
for sample, prediction in zip(data, predict_result):
sample['predictions'] = {'dialogue_acts': prediction['predict']}
json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description="merge predict results with original data for unified NLU evaluation")
parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result')
parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated by ../test.py')
args = parser.parse_args()
print(args)
merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
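# Example invocation (illustrative; assumes this script is saved as
# merge_predict_res.py and that ../test.py has written output.json under the
# config's output_dir):
#   python merge_predict_res.py -d multiwoz21 -s user -c 0 \
#       -p output/multiwoz21/user/context_window_size_0/output.json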
import logging
import os
import json
import torch
from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
import transformers
from convlab2.nlu.nlu import NLU
from convlab2.nlu.jointBERT.dataloader import Dataloader
from convlab2.nlu.jointBERT.jointBERT import JointBERT
from convlab2.nlu.jointBERT.unified_datasets.postprocess import recover_intent
from convlab2.util.custom_util import model_downloader
class BERTNLU(NLU):
def __init__(self, mode, config_file, model_file=None):
assert mode == 'user' or mode == 'sys' or mode == 'all'
self.mode = mode
config_file = os.path.join(os.path.dirname(
os.path.abspath(__file__)), 'configs/{}'.format(config_file))
config = json.load(open(config_file))
# print(config['DEVICE'])
# DEVICE = config['DEVICE']
DEVICE = 'cpu' if not torch.cuda.is_available() else config['DEVICE']
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
data_dir = os.path.join(root_dir, config['data_dir'])
output_dir = os.path.join(root_dir, config['output_dir'])
        assert os.path.exists(os.path.join(data_dir, 'intent_vocab.json')), 'Please run preprocess first'
intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json')))
tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))
dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab,
pretrained_weights=config['model']['pretrained_weights'])
logging.info('intent num:' + str(len(intent_vocab)))
logging.info('tag num:' + str(len(tag_vocab)))
if not os.path.exists(output_dir):
model_downloader(root_dir, model_file)
model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim)
state_dict = torch.load(os.path.join(output_dir, 'pytorch_model.bin'), DEVICE)
if int(transformers.__version__.split('.')[0]) >= 3 and 'bert.embeddings.position_ids' not in state_dict:
state_dict['bert.embeddings.position_ids'] = torch.tensor(range(512)).reshape(1, -1).to(DEVICE)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()
self.model = model
self.use_context = config['model']['context']
self.context_window_size = config['context_window_size']
self.dataloader = dataloader
self.sent_tokenizer = PunktSentenceTokenizer()
self.word_tokenizer = TreebankWordTokenizer()
logging.info("BERTNLU loaded")
def predict(self, utterance, context=list()):
sentences = self.sent_tokenizer.tokenize(utterance)
ori_word_seq = [token for sent in sentences for token in self.word_tokenizer.tokenize(sent)]
ori_tag_seq = [str(('O',))] * len(ori_word_seq)
if self.use_context:
if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
context = [item[1] for item in context]
context_seq = self.dataloader.tokenizer.encode(' [SEP] '.join(context[-self.context_window_size:]))
context_seq = context_seq[:510]
else:
context_seq = self.dataloader.tokenizer.encode('')
intents = []
da = {}
word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(ori_word_seq, ori_tag_seq)
word_seq = word_seq[:510]
tag_seq = tag_seq[:510]
batch_data = [[ori_word_seq, ori_tag_seq, intents, da, context_seq,
new2ori, word_seq, self.dataloader.seq_tag2id(tag_seq), self.dataloader.seq_intent2id(intents)]]
pad_batch = self.dataloader.pad_batch(batch_data)
pad_batch = tuple(t.to(self.model.device) for t in pad_batch)
word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch
slot_logits, intent_logits = self.model.forward(word_seq_tensor, word_mask_tensor,
context_seq_tensor=context_seq_tensor,
context_mask_tensor=context_mask_tensor)
das = recover_intent(self.dataloader, intent_logits[0], slot_logits[0], tag_mask_tensor[0],
batch_data[0][0], batch_data[0][-4])
dialog_act = []
for da_type in das:
for da in das[da_type]:
dialog_act.append([da['intent'], da['domain'], da['slot'], da.get('value','')])
return dialog_act
if __name__ == '__main__':
texts = [
"I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
"I want to leave after 17:15.",
"Thank you for all the help! I appreciate it.",
"Please find a restaurant called Nusha.",
"What is the train id, please? ",
"I don't care about the price and it doesn't need to have free parking."
]
nlu = BERTNLU(mode='user', config_file='multiwoz21_user.json')
for text in texts:
print(text)
print(nlu.predict(text))
print()
import re
import torch
def is_slot_da(da_type):
return da_type == 'non-categorical'
def calculateF1(predict_golden):
# F1 of all three types of dialogue acts
TP, FP, FN = 0, 0, 0
for item in predict_golden:
for da_type in ['non-categorical', 'categorical', 'binary']:
if da_type not in item['predict']:
assert da_type not in item['golden']
continue
if da_type == 'binary':
predicts = [(x['intent'], x['domain'], x['slot']) for x in item['predict'][da_type]]
labels = [(x['intent'], x['domain'], x['slot']) for x in item['golden'][da_type]]
else:
predicts = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in item['predict'][da_type]]
labels = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in item['golden'][da_type]]
for ele in predicts:
if ele in labels:
TP += 1
else:
FP += 1
for ele in labels:
if ele not in predicts:
FN += 1
# print(TP, FP, FN)
precision = 1.0 * TP / (TP + FP) if TP + FP else 0.
recall = 1.0 * TP / (TP + FN) if TP + FN else 0.
F1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0.
return precision, recall, F1
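# Sanity check on toy data: one matching binary act and one missed one give
# precision 1.0, recall 0.5, F1 of about 0.667:
#   sample = [{'predict': {'binary': [{'intent': 'thank', 'domain': 'general', 'slot': ''}],
#                          'categorical': [], 'non-categorical': []},
#              'golden': {'binary': [{'intent': 'thank', 'domain': 'general', 'slot': ''},
#                                    {'intent': 'bye', 'domain': 'general', 'slot': ''}],
#                         'categorical': [], 'non-categorical': []}}]
#   calculateF1(sample)  # -> (1.0, 0.5, 0.666...)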
def tag2triples(word_seq, tag_seq):
word_seq = word_seq[:len(tag_seq)]
assert len(word_seq)==len(tag_seq)
triples = []
i = 0
while i < len(tag_seq):
tag = eval(tag_seq[i])
if tag[-1] == 'B':
intent, domain, slot = tag[0], tag[1], tag[2]
value = word_seq[i]
j = i + 1
while j < len(tag_seq):
next_tag = eval(tag_seq[j])
if next_tag[-1] == 'I' and next_tag[:-1] == tag[:-1]:
value += ' ' + word_seq[j]
i += 1
j += 1
else:
break
triples.append([intent, domain, slot, value])
i += 1
return triples
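# Example (tags are stringified tuples, as written by preprocess.py):
#   words = ['Find', 'me', 'a', 'cheap', 'hotel']
#   tags = [str(('O',))] * 5
#   tags[3] = str(('inform', 'hotel', 'price', 'B'))
#   tag2triples(words, tags)  # -> [['inform', 'hotel', 'price', 'cheap']]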
def recover_intent(dataloader, intent_logits, tag_logits, tag_mask_tensor, ori_word_seq, new2ori):
# tag_logits = [sequence_length, tag_dim]
# intent_logits = [intent_dim]
# tag_mask_tensor = [sequence_length]
    # new2ori = {new_idx: old_idx, ...} (mapping after removing [CLS] and [SEP])
max_seq_len = tag_logits.size(0)
dialogue_acts = {
"categorical": [],
"non-categorical": [],
"binary": []
}
# for categorical & binary dialogue acts
for j in range(dataloader.intent_dim):
if intent_logits[j] > 0:
intent = eval(dataloader.id2intent[j])
if len(intent) == 3:
dialogue_acts['binary'].append({
'intent': intent[0],
'domain': intent[1],
'slot': intent[2]
})
else:
assert len(intent) == 4
dialogue_acts['categorical'].append({
'intent': intent[0],
'domain': intent[1],
'slot': intent[2],
'value': intent[3]
})
    # for non-categorical dialogue acts
tags = []
for j in range(1, max_seq_len-1):
if tag_mask_tensor[j] == 1:
value, tag_id = torch.max(tag_logits[j], dim=-1)
tags.append(dataloader.id2tag[tag_id.item()])
recover_tags = []
for i, tag in enumerate(tags):
if new2ori[i] >= len(recover_tags):
recover_tags.append(tag)
ori_word_seq = ori_word_seq[:len(recover_tags)]
tag_intent = tag2triples(ori_word_seq, recover_tags)
for intent in tag_intent:
dialogue_acts['non-categorical'].append({
'intent': intent[0],
'domain': intent[1],
'slot': intent[2],
'value': intent[3]
})
return dialogue_acts
import json
import os
from collections import Counter
from convlab2.util import load_dataset, load_ontology, load_nlu_data
from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
from tqdm import tqdm
def preprocess(dataset_name, speaker, save_dir, context_window_size):
dataset = load_dataset(dataset_name)
data_by_split = load_nlu_data(dataset, speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)
data_dir = os.path.join(save_dir, dataset_name, speaker, f'context_window_size_{context_window_size}')
os.makedirs(data_dir, exist_ok=True)
sent_tokenizer = PunktSentenceTokenizer()
word_tokenizer = TreebankWordTokenizer()
processed_data = {}
all_tags = set([str(('O',))])
all_intents = Counter()
for data_split, data in data_by_split.items():
if data_split == 'validation':
data_split = 'val'
processed_data[data_split] = []
for sample in tqdm(data, desc=f'{data_split} samples'):
utterance = sample['utterance']
sentences = sent_tokenizer.tokenize(utterance)
sent_spans = sent_tokenizer.span_tokenize(utterance)
tokens = [token for sent in sentences for token in word_tokenizer.tokenize(sent)]
token_spans = [(sent_span[0]+token_span[0], sent_span[0]+token_span[1]) for sent, sent_span in zip(sentences, sent_spans) for token_span in word_tokenizer.span_tokenize(sent)]
tags = [str(('O',))] * len(tokens)
for da in sample['dialogue_acts']['non-categorical']:
if 'start' not in da:
# skip da that doesn't have span annotation
continue
char_start = da['start']
char_end = da['end']
word_start, word_end = -1, -1
for i, token_span in enumerate(token_spans):
if char_start == token_span[0]:
word_start = i
if char_end == token_span[1]:
word_end = i + 1
if word_start == -1 and word_end == -1:
# char span does not match word, maybe there is an error in the annotation, skip
print('char span does not match word, skipping')
                    print('\t', 'utterance:', utterance)
print('\t', 'value:', utterance[char_start: char_end])
print('\t', 'da:', da, '\n')
continue
intent, domain, slot = da['intent'], da['domain'], da['slot']
all_tags.add(str((intent, domain, slot, 'B')))
all_tags.add(str((intent, domain, slot, 'I')))
tags[word_start] = str((intent, domain, slot, 'B'))
for i in range(word_start+1, word_end):
tags[i] = str((intent, domain, slot, 'I'))
intents = []
for da in sample['dialogue_acts']['categorical']:
intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da['value'].strip().lower()
intent = str((intent, domain, slot, value))
intents.append(intent)
all_intents[intent] += 1
for da in sample['dialogue_acts']['binary']:
intent, domain, slot = da['intent'], da['domain'], da['slot']
intent = str((intent, domain, slot))
intents.append(intent)
all_intents[intent] += 1
context = []
if context_window_size > 0:
context = [s['utterance'] for s in sample['context']]
processed_data[data_split].append([tokens, tags, intents, sample['dialogue_acts'], context])
json.dump(processed_data[data_split], open(os.path.join(data_dir, '{}_data.json'.format(data_split)), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
# filter out intents that occur only once to get intent vocabulary. however, these intents are still in the data
all_intents = {x: count for x, count in all_intents.items() if count > 1}
print('sentence label num:', len(all_intents))
print('tag num:', len(all_tags))
json.dump(sorted(all_intents), open(os.path.join(data_dir, 'intent_vocab.json'), 'w'), indent=2)
json.dump(sorted(all_tags), open(os.path.join(data_dir, 'tag_vocab.json'), 'w'), indent=2)
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description="create nlu data for bertnlu training")
parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
parser.add_argument('--save_dir', metavar='save_directory', type=str, default='data', help='directory to save the data, save_dir/$dataset_name/$speaker')
parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
args = parser.parse_args()
print(args)
preprocess(args.dataset, args.speaker, args.save_dir, args.context_window_size)
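# For orientation, one processed sample follows the
# [tokens, tags, intents, dialogue_acts, context] layout written above
# (a toy example; tags are stringified tuples, char spans index the utterance):
#   [["Find", "me", "a", "cheap", "hotel"],
#    ["('O',)", "('O',)", "('O',)", "('inform', 'hotel', 'price', 'B')", "('O',)"],
#    [],
#    {"categorical": [], "binary": [],
#     "non-categorical": [{"intent": "inform", "domain": "hotel", "slot": "price",
#                          "value": "cheap", "start": 10, "end": 15}]},
#    []]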
@@ -34,6 +34,7 @@ setup(
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
     ],
     install_requires=[
+        'matplotlib',
         'tabulate',
         'python-Levenshtein',
         'requests',