diff --git a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py index 28109eae096e8266a44f463488b904dc9a4d52fb..b6ef65db298378b744a45130fd71c072243bcfca 100644 --- a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py +++ b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py @@ -14,7 +14,6 @@ def main(args): data_split = filename.split('/')[-1].split('_')[-1].split('.')[0] output_file = os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[-1]}") print(f'processing {dataset_name}: {filename} => {output_file}') - cnt = 0 with open(filename, 'rb') as fin, open(output_file, 'w', encoding='utf-8') as fout: for dial in tqdm(json_lines.reader(fin)): context = [] @@ -57,7 +56,10 @@ def main(args): continue if args.mode == 'key2gen_noisy': - possible_keywords_sents = turn['keywords'][:] + if random.random() < 0.8: + possible_keywords_sents = turn['keywords'][:] + else: + possible_keywords_sents = [] num_possible_keywords_turns = min(random.randint(1, 5), len(turns_keywords) - 1) for turn_keywords in random.sample(turns_keywords[:i] + turns_keywords[i+1:], num_possible_keywords_turns): possible_keywords_sents.extend(turn_keywords) diff --git a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh index ea1d4d51c7cb0dc56c74b47d15b4c1292a08f847..eb67a18b22e0480323f132abbf42a2d1508755b9 100644 --- a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh +++ b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh @@ -18,4 +18,23 @@ do cat "data/${task_name}/${model_type}/${name}/validation.json" >> ${validation_file} fi done +done + +# merge key2gen+key2gen_noisy data +task_name="key2gen+key2gen_noisy" +dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog" +names=$(echo ${task_name} | tr "+" "\n") +model_type="gpt" +data_dir=data/${task_name}/${model_type}/${dataset_name} +mkdir -p ${data_dir} +train_file="${data_dir}/train.json" +validation_file="${data_dir}/validation.json" +rm ${train_file} ${validation_file} +for name in ${names} +do + echo "preprocessing ${name}" + if [ "${name}" != "${task_name}" ]; then + cat "data/${name}/${model_type}/${dataset_name}/train.json" >> ${train_file} + cat "data/${name}/${model_type}/${dataset_name}/validation.json" >> ${validation_file} + fi done \ No newline at end of file diff --git a/convlab/base_models/gpt/keyword_extraction/train_t5_multitask.sh b/convlab/base_models/gpt/keyword_extraction/train_t5_multitask.sh index e63d2e894c7fc0cabea6d994df41245547943327..dd8d5a460478ae068143b5d18914500e31bae439 100644 --- a/convlab/base_models/gpt/keyword_extraction/train_t5_multitask.sh +++ b/convlab/base_models/gpt/keyword_extraction/train_t5_multitask.sh @@ -1,24 +1,23 @@ set -e -n_gpus=4 -master_port=23457 -task_name="rg+key2gen+key2gen_noisy+sen2gen+sen2gen_noisy" -dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog" +n_gpus=8 +master_port=23456 +task_name="key2gen+key2gen_noisy" +dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog" model_type="gpt" data_dir="data/${task_name}/${model_type}/${dataset_name}" output_dir="output/${task_name}/${model_type}/${dataset_name}" cache_dir="../cache" logging_dir="${output_dir}/runs" train_file="${data_dir}/train.json" -validation_file="${data_dir}/validation.json" source_column="source" target_column="target" truncation_side="left" max_source_length=512 max_target_length=128 -model_name_or_path="t5-small" +model_name_or_path="output/rg/${model_type}/${dataset_name}" per_device_train_batch_size=64 per_device_eval_batch_size=128 -gradient_accumulation_steps=2 +gradient_accumulation_steps=1 num_workers=16 lr=1e-3 num_train_epochs=1 @@ -27,7 +26,6 @@ python -m torch.distributed.launch --master_port ${master_port} \ --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \ --task_name ${task_name} \ --train_file ${train_file} \ - --validation_file ${validation_file} \ --source_column ${source_column} \ --target_column ${target_column} \ --max_source_length ${max_source_length} \ @@ -35,15 +33,11 @@ python -m torch.distributed.launch --master_port ${master_port} \ --truncation_side ${truncation_side} \ --model_name_or_path ${model_name_or_path} \ --do_train \ - --do_eval \ - --save_strategy epoch \ - --evaluation_strategy epoch \ - --load_best_model_at_end \ - --prediction_loss_only \ + --save_steps 5000 \ + --save_total_limit 3 \ --cache_dir ${cache_dir} \ --output_dir ${output_dir} \ --logging_dir ${logging_dir} \ - --overwrite_output_dir \ --preprocessing_num_workers ${num_workers} \ --dataloader_num_workers ${num_workers} \ --per_device_train_batch_size ${per_device_train_batch_size} \ @@ -52,4 +46,5 @@ python -m torch.distributed.launch --master_port ${master_port} \ --learning_rate ${lr} \ --num_train_epochs ${num_train_epochs} \ --optim adafactor \ + --lr_scheduler_type constant \ --gradient_checkpointing