Commit ca6cf298 authored by zqwerty
multitask pre-training

parent e4872bae
@@ -14,7 +14,6 @@ def main(args):
        data_split = filename.split('/')[-1].split('_')[-1].split('.')[0]
        output_file = os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[-1]}")
        print(f'processing {dataset_name}: {filename} => {output_file}')
-       cnt = 0
        with open(filename, 'rb') as fin, open(output_file, 'w', encoding='utf-8') as fout:
            for dial in tqdm(json_lines.reader(fin)):
                context = []
@@ -57,7 +56,10 @@
                    continue
                if args.mode == 'key2gen_noisy':
                    if random.random() < 0.8:
                        possible_keywords_sents = turn['keywords'][:]
                    else:
                        possible_keywords_sents = []
                    num_possible_keywords_turns = min(random.randint(1, 5), len(turns_keywords) - 1)
                    for turn_keywords in random.sample(turns_keywords[:i] + turns_keywords[i+1:], num_possible_keywords_turns):
                        possible_keywords_sents.extend(turn_keywords)
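Read on its own, the key2gen_noisy branch above builds a noisy keyword pool per turn: with probability 0.8 it keeps the current turn's own keywords (otherwise starts empty), then mixes in keywords from one to five randomly chosen other turns as distractors. A self-contained sketch of that sampling step (the function name sample_noisy_keywords is mine, and I assume turns_keywords[i] holds the same list as turn['keywords']):

import random

def sample_noisy_keywords(turns_keywords, i):
    """Noisy keyword pool for turn i; assumes len(turns_keywords) >= 1."""
    # With probability 0.8 keep the current turn's own keywords, else start
    # empty, so the model sometimes sees distractor-only inputs.
    if random.random() < 0.8:
        possible_keywords_sents = turns_keywords[i][:]
    else:
        possible_keywords_sents = []
    # Mix in the keywords of 1-5 randomly chosen *other* turns as distractors.
    num_other = min(random.randint(1, 5), len(turns_keywords) - 1)
    for turn_keywords in random.sample(turns_keywords[:i] + turns_keywords[i+1:], num_other):
        possible_keywords_sents.extend(turn_keywords)
    return possible_keywords_sents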
@@ -19,3 +19,22 @@ do
fi
done
done
+# merge key2gen+key2gen_noisy data
+task_name="key2gen+key2gen_noisy"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
+names=$(echo ${task_name} | tr "+" "\n")
+model_type="gpt"
+data_dir=data/${task_name}/${model_type}/${dataset_name}
+mkdir -p ${data_dir}
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+rm -f ${train_file} ${validation_file}
+for name in ${names}
+do
+    echo "preprocessing ${name}"
+    if [ "${name}" != "${task_name}" ]; then
+        cat "data/${name}/${model_type}/${dataset_name}/train.json" >> ${train_file}
+        cat "data/${name}/${model_type}/${dataset_name}/validation.json" >> ${validation_file}
+    fi
+done
\ No newline at end of file
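For reference, a minimal Python equivalent of the merge loop above (paths mirror the script's variables; assumes each per-task train.json/validation.json is a JSON-lines file, which is what byte-wise concatenation via cat requires):

import os

task_name = "key2gen+key2gen_noisy"
dataset_name = "dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
model_type = "gpt"
data_dir = os.path.join("data", task_name, model_type, dataset_name)
os.makedirs(data_dir, exist_ok=True)
for split in ("train.json", "validation.json"):
    # Overwrite any previous merged file, then append each task's split.
    with open(os.path.join(data_dir, split), "w", encoding="utf-8") as fout:
        for name in task_name.split("+"):
            if name == task_name:
                continue  # single-task case: nothing to merge (same guard as the script)
            src = os.path.join("data", name, model_type, dataset_name, split)
            with open(src, encoding="utf-8") as fin:
                fout.writelines(fin)  # one example per line, so plain concatenation is safe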
set -e
-n_gpus=4
-master_port=23457
-task_name="rg+key2gen+key2gen_noisy+sen2gen+sen2gen_noisy"
-dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
+n_gpus=8
+master_port=23456
+task_name="key2gen+key2gen_noisy"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
model_type="gpt"
data_dir="data/${task_name}/${model_type}/${dataset_name}"
output_dir="output/${task_name}/${model_type}/${dataset_name}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
source_column="source"
target_column="target"
truncation_side="left"
max_source_length=512
max_target_length=128
-model_name_or_path="t5-small"
+model_name_or_path="output/rg/${model_type}/${dataset_name}"
per_device_train_batch_size=64
per_device_eval_batch_size=128
-gradient_accumulation_steps=2
+gradient_accumulation_steps=1
num_workers=16
lr=1e-3
num_train_epochs=1
@@ -27,7 +26,6 @@ python -m torch.distributed.launch --master_port ${master_port} \
--nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
-    --validation_file ${validation_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
@@ -35,15 +33,11 @@ python -m torch.distributed.launch --master_port ${master_port} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
-    --do_eval \
-    --save_strategy epoch \
-    --evaluation_strategy epoch \
-    --load_best_model_at_end \
-    --prediction_loss_only \
+    --save_steps 5000 \
+    --save_total_limit 3 \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers ${num_workers} \
--dataloader_num_workers ${num_workers} \
--per_device_train_batch_size ${per_device_train_batch_size} \
@@ -52,4 +46,5 @@ python -m torch.distributed.launch --master_port ${master_port} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--optim adafactor \
+    --lr_scheduler_type constant \
--gradient_checkpointing
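With the new settings (n_gpus=8, per_device_train_batch_size=64, gradient_accumulation_steps=1), the effective batch is 8 × 64 × 1 = 512 sequences per optimizer step, the same as the previous 4-GPU configuration (4 × 64 × 2 = 512). Adafactor with a constant lr of 1e-3 is the commonly used T5 fine-tuning recipe.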