import os
import json
from collections import Counter

from tqdm import tqdm
from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer


def create_bio_data(dataset, data_dir):
    """Create BIO-tagged token data, one JSON line per utterance."""
    data_by_split = load_nlu_data(dataset, speaker='all')
    os.makedirs(data_dir, exist_ok=True)

    sent_tokenizer = PunktSentenceTokenizer()
    word_tokenizer = TreebankWordTokenizer()

    data_splits = data_by_split.keys()
    cnt = Counter()
    for data_split in data_splits:
        data = []
        for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
            utterance = sample['utterance']
            # keep only non-categorical dialogue acts that carry a char-level span annotation
            dialogue_acts = [da for da in sample['dialogue_acts']['non-categorical'] if 'start' in da]
            cnt[len(dialogue_acts)] += 1

            # tokenize into sentences, then words, keeping char offsets into the original utterance
            sentences = sent_tokenizer.tokenize(utterance)
            sent_spans = sent_tokenizer.span_tokenize(utterance)
            tokens = [token for sent in sentences for token in word_tokenizer.tokenize(sent)]
            token_spans = [(sent_span[0] + token_span[0], sent_span[0] + token_span[1])
                           for sent, sent_span in zip(sentences, sent_spans)
                           for token_span in word_tokenizer.span_tokenize(sent)]

            labels = ['O'] * len(tokens)
            for da in dialogue_acts:
                char_start = da['start']
                char_end = da['end']
                # map the char span to a token span; both boundaries must align exactly
                word_start, word_end = -1, -1
                for i, token_span in enumerate(token_spans):
                    if char_start == token_span[0]:
                        word_start = i
                    if char_end == token_span[1]:
                        word_end = i + 1
                if word_start == -1 or word_end == -1:
                    # char span does not align with token boundaries, skip
                    continue
                labels[word_start] = 'B'
                for i in range(word_start + 1, word_end):
                    labels[i] = 'I'
            data.append(json.dumps({'tokens': tokens, 'labels': labels}, ensure_ascii=False) + '\n')

        file_name = os.path.join(data_dir, f"{data_split}.json")
        with open(file_name, "w", encoding='utf-8') as f:
            f.writelines(data)
    print('number of spans per utterance:', cnt)
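
# Minimal standalone sketch (not used by the pipeline) of the char-span ->
# token-span alignment performed above. The utterance and the character span
# are made-up example values, not taken from any dataset.
def _demo_bio_alignment():
    sent_tokenizer = PunktSentenceTokenizer()
    word_tokenizer = TreebankWordTokenizer()
    utterance = "I need a cheap hotel in the north."
    char_start, char_end = 9, 14  # hypothetical annotated span covering "cheap"
    sentences = sent_tokenizer.tokenize(utterance)
    sent_spans = sent_tokenizer.span_tokenize(utterance)
    tokens = [token for sent in sentences for token in word_tokenizer.tokenize(sent)]
    token_spans = [(sent_span[0] + token_span[0], sent_span[0] + token_span[1])
                   for sent, sent_span in zip(sentences, sent_spans)
                   for token_span in word_tokenizer.span_tokenize(sent)]
    labels = ['O'] * len(tokens)
    word_start, word_end = -1, -1
    for i, token_span in enumerate(token_spans):
        if char_start == token_span[0]:
            word_start = i
        if char_end == token_span[1]:
            word_end = i + 1
    if word_start != -1 and word_end != -1:
        labels[word_start] = 'B'
        for i in range(word_start + 1, word_end):
            labels[i] = 'I'
    print(list(zip(tokens, labels)))
    # -> [('I', 'O'), ('need', 'O'), ('a', 'O'), ('cheap', 'B'), ('hotel', 'O'), ...]
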
def create_dialogBIO_data(dataset, data_dir):
    """Create BIO-tagged token data, one JSON line per dialogue, with turns concatenated as "speaker : utterance"."""
    data_by_split = load_nlu_data(dataset, split_to_turn=False)
    os.makedirs(data_dir, exist_ok=True)

    sent_tokenizer = PunktSentenceTokenizer()
    word_tokenizer = TreebankWordTokenizer()

    data_splits = data_by_split.keys()
    cnt = Counter()
    for data_split in data_splits:
        data = []
        for dialog in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
            all_tokens, all_labels = [], []
            for sample in dialog['turns']:
                speaker = sample['speaker']
                utterance = sample['utterance']
                dialogue_acts = [da for da in sample['dialogue_acts']['non-categorical'] if 'start' in da]
                cnt[len(dialogue_acts)] += 1

                sentences = sent_tokenizer.tokenize(utterance)
                sent_spans = sent_tokenizer.span_tokenize(utterance)
                tokens = [token for sent in sentences for token in word_tokenizer.tokenize(sent)]
                token_spans = [(sent_span[0] + token_span[0], sent_span[0] + token_span[1])
                               for sent, sent_span in zip(sentences, sent_spans)
                               for token_span in word_tokenizer.span_tokenize(sent)]

                labels = ['O'] * len(tokens)
                for da in dialogue_acts:
                    char_start = da['start']
                    char_end = da['end']
                    word_start, word_end = -1, -1
                    for i, token_span in enumerate(token_spans):
                        if char_start == token_span[0]:
                            word_start = i
                        if char_end == token_span[1]:
                            word_end = i + 1
                    if word_start == -1 or word_end == -1:
                        # char span does not align with token boundaries, skip
                        continue
                    labels[word_start] = 'B'
                    for i in range(word_start + 1, word_end):
                        labels[i] = 'I'
                # prefix each turn with "speaker :" tokens, labeled O
                all_tokens.extend([speaker, ':'] + tokens)
                all_labels.extend(['O', 'O'] + labels)
            data.append(json.dumps({'tokens': all_tokens, 'labels': all_labels}, ensure_ascii=False) + '\n')

        file_name = os.path.join(data_dir, f"{data_split}.json")
        with open(file_name, "w", encoding='utf-8') as f:
            f.writelines(data)
    print('number of spans per utterance:', cnt)


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser(description="create data for BIO sequence labeling")
    parser.add_argument('--tasks', metavar='task_name', nargs='*', choices=['bio', 'dialogBIO'], help='names of tasks')
    parser.add_argument('--datasets', metavar='dataset_name', nargs='*', help='names of unified datasets')
    parser.add_argument('--save_dir', metavar='save_directory', type=str, default='data',
                        help='directory to save the data, default: data/$task_name/$dataset_name')
    args = parser.parse_args()
    print(args)
    for dataset_name in tqdm(args.datasets, desc='datasets'):
        dataset = load_dataset(dataset_name)
        for task_name in tqdm(args.tasks, desc='tasks', leave=False):
            data_dir = os.path.join(args.save_dir, task_name, dataset_name)
            eval(f"create_{task_name}_data")(dataset, data_dir)
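
# Example invocation (the script name and dataset name below are illustrative;
# any unified dataset known to load_dataset works):
#   python create_data.py --tasks bio dialogBIO --datasets multiwoz21 --save_dir data
# Each line of the resulting {split}.json is a JSON object such as:
#   {"tokens": ["user", ":", "I", "need", "a", "cheap", "hotel"],
#    "labels": ["O", "O", "O", "O", "O", "B", "O"]}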