Commit 301e106f authored by zqwerty

train bert for token classification

parent 582758df
import os
import json
from tqdm import tqdm
from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
from collections import Counter


def create_bio_data(dataset, data_dir):
    data_by_split = load_nlu_data(dataset, speaker='all')
    os.makedirs(data_dir, exist_ok=True)

    sent_tokenizer = PunktSentenceTokenizer()
    word_tokenizer = TreebankWordTokenizer()

    data_splits = data_by_split.keys()
    cnt = Counter()
    for data_split in data_splits:
        data = []
        for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
            utterance = sample['utterance']
            dialogue_acts = [da for da in sample['dialogue_acts']['non-categorical'] if 'start' in da]
            cnt[len(dialogue_acts)] += 1

            # Tokenize into sentences, then words, keeping character offsets so the
            # character-level dialogue-act spans can be aligned to tokens.
            sentences = sent_tokenizer.tokenize(utterance)
            sent_spans = sent_tokenizer.span_tokenize(utterance)
            tokens = [token for sent in sentences for token in word_tokenizer.tokenize(sent)]
            token_spans = [(sent_span[0] + token_span[0], sent_span[0] + token_span[1])
                           for sent, sent_span in zip(sentences, sent_spans)
                           for token_span in word_tokenizer.span_tokenize(sent)]

            labels = ['O'] * len(tokens)
            for da in dialogue_acts:
                char_start = da['start']
                char_end = da['end']
                word_start, word_end = -1, -1
                for i, token_span in enumerate(token_spans):
                    if char_start == token_span[0]:
                        word_start = i
                    if char_end == token_span[1]:
                        word_end = i + 1
                if word_start == -1 or word_end == -1:
                    # char span does not align with token boundaries, skip
                    continue
                labels[word_start] = 'B'
                for i in range(word_start + 1, word_end):
                    labels[i] = 'I'
            data.append(json.dumps({'tokens': tokens, 'labels': labels}) + '\n')

        file_name = os.path.join(data_dir, f"{data_split}.json")
        with open(file_name, "w") as f:
            f.writelines(data)
    print('num of spans in utterances', cnt)


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser(description="create data for token classification (BIO tagging) training")
    parser.add_argument('--tasks', metavar='task_name', nargs='*', choices=['bio'], help='names of tasks')
    parser.add_argument('--datasets', metavar='dataset_name', nargs='*', help='names of unified datasets')
    parser.add_argument('--save_dir', metavar='save_directory', type=str, default='data',
                        help='directory to save the data, default: data/$task_name/$dataset_name')
    args = parser.parse_args()
    print(args)
    for dataset_name in tqdm(args.datasets, desc='datasets'):
        dataset = load_dataset(dataset_name)
        for task_name in tqdm(args.tasks, desc='tasks', leave=False):
            data_dir = os.path.join(args.save_dir, task_name, dataset_name)
            # dispatch to the create_<task_name>_data function defined above
            eval(f"create_{task_name}_data")(dataset, data_dir)
n_gpus=8
task_name="bio"
dataset_name="sgd"
data_dir="data/${task_name}/${dataset_name}"
output_dir="output/${task_name}/${dataset_name}"
cache_dir="cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
source_column="tokens"
target_column="labels"
model_name_or_path="bert-base-uncased"
per_device_train_batch_size=128
per_device_eval_batch_size=512
gradient_accumulation_steps=1
lr=2e-5
num_train_epochs=1
metric_for_best_model="f1"
python create_data.py --tasks ${task_name} --datasets ${dataset_name} --save_dir "data"
python -m torch.distributed.launch \
    --nproc_per_node ${n_gpus} run_token_classification.py \
    --task_name ${task_name} \
    --train_file ${train_file} \
    --validation_file ${validation_file} \
    --test_file ${test_file} \
    --source_column ${source_column} \
    --target_column ${target_column} \
    --model_name_or_path ${model_name_or_path} \
    --do_train \
    --do_eval \
    --do_predict \
    --save_strategy epoch \
    --evaluation_strategy epoch \
    --load_best_model_at_end \
    --metric_for_best_model ${metric_for_best_model} \
    --cache_dir ${cache_dir} \
    --output_dir ${output_dir} \
    --logging_dir ${logging_dir} \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --learning_rate ${lr} \
    --num_train_epochs ${num_train_epochs} \
    --debug underflow_overflow
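
Once training finishes, the best checkpoint saved in ${output_dir} can be reloaded to tag new utterances. This is only a rough inference sketch: it assumes the run above completed and that run_token_classification.py saves a standard Transformers token-classification checkpoint to output/bio/sgd; the example input is hypothetical.

# Rough inference sketch (assumes a standard Transformers checkpoint in output/bio/sgd).
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

model_dir = "output/bio/sgd"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForTokenClassification.from_pretrained(model_dir)
model.eval()

words = ["i", "want", "a", "cheap", "hotel", "in", "the", "north"]  # hypothetical input
enc = tokenizer(words, is_split_into_words=True, return_tensors="pt")
with torch.no_grad():
    pred_ids = model(**enc).logits.argmax(-1)[0].tolist()

# Keep the prediction of the first sub-token of each word.
labels, seen = [], set()
for idx, wid in enumerate(enc.word_ids(0)):
    if wid is not None and wid not in seen:
        seen.add(wid)
        labels.append(model.config.id2label[pred_ids[idx]])
print(list(zip(words, labels)))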
setup.py
@@ -42,6 +42,7 @@ setup(
        'torch>=1.6',
        'transformers>=4.0',
        'datasets>=1.8',
+       'seqeval',
        'spacy',
        'allennlp',
        'simplejson',
...
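
The seqeval dependency added above is presumably what run_token_classification.py uses for its span-level metrics (the f1 passed to --metric_for_best_model), in the style of the HuggingFace token-classification examples. A small sketch with hypothetical label sequences; the untyped B/I tags produced by create_bio_data should be treated by seqeval as a single, unnamed entity type.

# Hypothetical gold/predicted BIO sequences illustrating span-level evaluation with seqeval.
from seqeval.metrics import classification_report, f1_score

y_true = [["O", "B", "I", "O", "B", "O"]]  # two gold spans
y_pred = [["O", "B", "I", "O", "O", "O"]]  # only the first span is predicted
print(f1_score(y_true, y_pred))            # precision 1.0, recall 0.5 -> F1 ~ 0.67
print(classification_report(y_true, y_pred))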