Commit 3a38a9b5 authored by zqwerty

add gpt dialogLM

parent 824788e1
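The commit adds a data-creation script (invoked as create_data.py by the training script further down) that serializes each dialogue of a unified dataset into a single training string: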
import os
import json
import re
from tqdm import tqdm
from convlab2.util import load_dataset


def create_lm_data(dataset, data_dir, args):
    """Serialize each dialogue into one string per line for causal LM training."""
    data_by_split = dataset
    os.makedirs(data_dir, exist_ok=True)

    data_splits = data_by_split.keys()
    for data_split in data_splits:
        data = []
        for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
            if args.model_type == 'dialogpt':
                # DialoGPT expects turns joined by its end-of-text token.
                dialogue = ' <|endoftext|> '.join([turn['utterance'] for turn in sample['turns']])
            else:
                # Plain GPT-style models get speaker-prefixed turns.
                dialogue = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['turns']])
            data.append(json.dumps({'dialogue': dialogue}, ensure_ascii=False) + '\n')

        # One JSON object per line (JSON Lines), despite the .json extension.
        file_name = os.path.join(data_dir, f"{data_split}.json")
        with open(file_name, "w", encoding='utf-8') as f:
            f.writelines(data)


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser(description="create data for seq2seq training")
    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['lm'], help='names of tasks')
    parser.add_argument('--datasets', '-d', metavar='dataset_name', nargs='*', help='names of unified datasets')
    parser.add_argument('--model_type', '-m', metavar='model_type', help='type of the language model: gpt, dialogpt, ...')
    args = parser.parse_args()
    print(args)
    for dataset_name in tqdm(args.datasets, desc='datasets'):
        dataset = load_dataset(dataset_name)
        for task_name in tqdm(args.tasks, desc='tasks', leave=False):
            data_dir = os.path.join('data', task_name, dataset_name)
            # Dispatch by name to create_{task_name}_data, e.g. create_lm_data.
            eval(f"create_{task_name}_data")(dataset, data_dir, args)
#!/bin/bash
set -e

n_gpus=1  # not used by the commands below
task_name="lm"
dataset_name="multiwoz21"
data_dir="data/${task_name}/${dataset_name}"
output_dir="output/${task_name}/${dataset_name}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"  # set for completeness; not passed to run_clm.py
source_column="dialogue"
max_length=512
model_name_or_path="microsoft/DialoGPT-large"
per_device_train_batch_size=16
per_device_eval_batch_size=16
gradient_accumulation_steps=4
lr=5e-5
num_train_epochs=3

# Build the JSON Lines data, then fine-tune with a causal LM objective.
python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type dialogpt
python ../run_clm.py \
    --model_name_or_path ${model_name_or_path} \
    --train_file ${train_file} \
    --validation_file ${validation_file} \
    --source_column ${source_column} \
    --max_length ${max_length} \
    --do_train \
    --do_eval \
    --save_strategy epoch \
    --evaluation_strategy epoch \
    --load_best_model_at_end \
    --prediction_loss_only \
    --cache_dir ${cache_dir} \
    --output_dir ${output_dir} \
    --logging_dir ${logging_dir} \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --learning_rate ${lr} \
    --num_train_epochs ${num_train_epochs} \
    --debug underflow_overflow \
    --gradient_checkpointing
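Once training finishes, the checkpoint in ${output_dir} can be loaded with the standard transformers API. The following sketch is not part of the commit; the prompt and sampling parameters are illustrative:

# Minimal sketch: load the fine-tuned checkpoint and sample a system turn.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "output/lm/multiwoz21"  # output_dir from the script above
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)

# Encode the dialogue history the same way the training data was built:
# turns separated by DialoGPT's <|endoftext|> token.
history = "i need a cheap hotel in the centre . <|endoftext|> "
input_ids = tokenizer.encode(history, return_tensors="pt")

# Sample a continuation and strip the prompt tokens from the output.
output_ids = model.generate(
    input_ids,
    max_new_tokens=50,
    do_sample=True,
    top_p=0.9,
    pad_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
print(response)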