Commit ca6cf298 authored by zqwerty

multitask pre-training

parent e4872bae
@@ -14,7 +14,6 @@ def main(args):
         data_split = filename.split('/')[-1].split('_')[-1].split('.')[0]
         output_file = os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[-1]}")
         print(f'processing {dataset_name}: {filename} => {output_file}')
-        cnt = 0
         with open(filename, 'rb') as fin, open(output_file, 'w', encoding='utf-8') as fout:
             for dial in tqdm(json_lines.reader(fin)):
                 context = []
@@ -57,7 +56,10 @@ def main(args):
                     continue
                 if args.mode == 'key2gen_noisy':
-                    possible_keywords_sents = turn['keywords'][:]
+                    if random.random() < 0.8:
+                        possible_keywords_sents = turn['keywords'][:]
+                    else:
+                        possible_keywords_sents = []
                     num_possible_keywords_turns = min(random.randint(1, 5), len(turns_keywords) - 1)
                     for turn_keywords in random.sample(turns_keywords[:i] + turns_keywords[i+1:], num_possible_keywords_turns):
                         possible_keywords_sents.extend(turn_keywords)
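For readers skimming the diff: in key2gen_noisy mode an example's own keywords are now withheld with probability 0.2, while keywords sampled from one to five other turns of the same dialogue are always mixed in as distractors. Below is a minimal, self-contained sketch of that branch; the function name and signature are hypothetical and only the logic inside mirrors the diff.

import random

def sample_noisy_keywords(turn_keywords, turns_keywords, i, keep_prob=0.8):
    # Hypothetical helper; mirrors the key2gen_noisy branch in the diff above.
    # turn_keywords:  keywords of the current turn (turn['keywords'])
    # turns_keywords: per-turn keyword lists for the whole dialogue
    # i:              index of the current turn
    if random.random() < keep_prob:
        # keep the turn's own keywords 80% of the time
        possible_keywords_sents = turn_keywords[:]
    else:
        # 20% of the time the gold keywords are withheld entirely
        possible_keywords_sents = []
    # always add keywords from 1-5 randomly chosen other turns as noise
    num_other_turns = min(random.randint(1, 5), len(turns_keywords) - 1)
    other_turns = turns_keywords[:i] + turns_keywords[i + 1:]
    for other_keywords in random.sample(other_turns, num_other_turns):
        possible_keywords_sents.extend(other_keywords)
    return possible_keywords_sents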
@@ -19,3 +19,22 @@ do
         fi
     done
 done
+
+# merge key2gen+key2gen_noisy data
+task_name="key2gen+key2gen_noisy"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
+names=$(echo ${task_name} | tr "+" "\n")
+model_type="gpt"
+data_dir=data/${task_name}/${model_type}/${dataset_name}
+mkdir -p ${data_dir}
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+rm ${train_file} ${validation_file}
+for name in ${names}
+do
+    echo "preprocessing ${name}"
+    if [ "${name}" != "${task_name}" ]; then
+        cat "data/${name}/${model_type}/${dataset_name}/train.json" >> ${train_file}
+        cat "data/${name}/${model_type}/${dataset_name}/validation.json" >> ${validation_file}
+    fi
+done
\ No newline at end of file
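The new merge step above simply concatenates the per-task JSON Lines files produced earlier into one combined training set (rm will report an error on the first run, before the merged files exist). A rough Python equivalent of the same loop, assuming each line already holds a record with the source and target fields that the training script reads, might look like this:

import os

# Hypothetical Python equivalent of the shell merge loop above; paths and
# variable names copy the script, and the JSON Lines assumption comes from the
# source_column/target_column settings in the training script below.
task_name = "key2gen+key2gen_noisy"
dataset_name = "dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
model_type = "gpt"
data_dir = os.path.join("data", task_name, model_type, dataset_name)
os.makedirs(data_dir, exist_ok=True)

for split in ("train.json", "validation.json"):
    merged_path = os.path.join(data_dir, split)
    with open(merged_path, "w", encoding="utf-8") as fout:
        for name in task_name.split("+"):  # key2gen, key2gen_noisy
            part_path = os.path.join("data", name, model_type, dataset_name, split)
            with open(part_path, encoding="utf-8") as fin:
                fout.write(fin.read())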
 set -e
-n_gpus=4
-master_port=23457
-task_name="rg+key2gen+key2gen_noisy+sen2gen+sen2gen_noisy"
-dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
+n_gpus=8
+master_port=23456
+task_name="key2gen+key2gen_noisy"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
 cache_dir="../cache"
 logging_dir="${output_dir}/runs"
 train_file="${data_dir}/train.json"
-validation_file="${data_dir}/validation.json"
 source_column="source"
 target_column="target"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
-model_name_or_path="t5-small"
+model_name_or_path="output/rg/${model_type}/${dataset_name}"
 per_device_train_batch_size=64
 per_device_eval_batch_size=128
-gradient_accumulation_steps=2
+gradient_accumulation_steps=1
 num_workers=16
 lr=1e-3
 num_train_epochs=1
@@ -27,7 +26,6 @@ python -m torch.distributed.launch --master_port ${master_port} \
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
-    --validation_file ${validation_file} \
     --source_column ${source_column} \
     --target_column ${target_column} \
     --max_source_length ${max_source_length} \
@@ -35,15 +33,11 @@ python -m torch.distributed.launch --master_port ${master_port} \
     --truncation_side ${truncation_side} \
     --model_name_or_path ${model_name_or_path} \
     --do_train \
-    --do_eval \
-    --save_strategy epoch \
-    --evaluation_strategy epoch \
-    --load_best_model_at_end \
-    --prediction_loss_only \
+    --save_steps 5000 \
+    --save_total_limit 3 \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
     --logging_dir ${logging_dir} \
-    --overwrite_output_dir \
     --preprocessing_num_workers ${num_workers} \
     --dataloader_num_workers ${num_workers} \
     --per_device_train_batch_size ${per_device_train_batch_size} \
@@ -52,4 +46,5 @@ python -m torch.distributed.launch --master_port ${master_port} \
     --learning_rate ${lr} \
     --num_train_epochs ${num_train_epochs} \
     --optim adafactor \
+    --lr_scheduler_type constant \
     --gradient_checkpointing
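Once this run finishes, the checkpoint under ${output_dir} is a plain T5 sequence-to-sequence model: the script now warm-starts from the response-generation checkpoint under output/rg/ instead of t5-small, drops evaluation during training, and saves every 5000 steps with a constant learning rate. A hedged loading sketch with Hugging Face Transformers follows; the checkpoint path mirrors the script's variables, and the example input string is purely illustrative, since the real source format is whatever the preprocessing script emits.

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Path assembled from ${output_dir} in the script above; adjust to your layout.
ckpt = "output/key2gen+key2gen_noisy/gpt/dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt)

# Illustrative placeholder only: the actual "source" text (grounding keywords
# plus dialogue context) is produced by the preprocessing script, not this string.
source = "museum | opening hours : user: any ideas for the afternoon?"
inputs = tokenizer(source, return_tensors="pt", truncation=True, max_length=512)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))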