diff --git a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py
index 28109eae096e8266a44f463488b904dc9a4d52fb..b6ef65db298378b744a45130fd71c072243bcfca 100644
--- a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py
+++ b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py
@@ -14,7 +14,6 @@ def main(args):
         data_split = filename.split('/')[-1].split('_')[-1].split('.')[0]
         output_file = os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[-1]}")
         print(f'processing {dataset_name}: {filename} => {output_file}')
-        cnt = 0
         with open(filename, 'rb') as fin, open(output_file, 'w', encoding='utf-8') as fout:
             for dial in tqdm(json_lines.reader(fin)):
                 context = []
@@ -57,7 +56,11 @@ def main(args):
                         continue
 
                     if args.mode == 'key2gen_noisy':
-                        possible_keywords_sents = turn['keywords'][:]
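+                        # with 20% probability, drop this turn's own keywords so the "possible keywords" input is noisier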
+                        if random.random() < 0.8:
+                            possible_keywords_sents = turn['keywords'][:]
+                        else:
+                            possible_keywords_sents = []
                         num_possible_keywords_turns = min(random.randint(1, 5), len(turns_keywords) - 1)
                         for turn_keywords in random.sample(turns_keywords[:i] + turns_keywords[i+1:], num_possible_keywords_turns):
                             possible_keywords_sents.extend(turn_keywords)
diff --git a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
index ea1d4d51c7cb0dc56c74b47d15b4c1292a08f847..eb67a18b22e0480323f132abbf42a2d1508755b9 100644
--- a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
+++ b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
@@ -18,4 +18,24 @@ do
             cat "data/${task_name}/${model_type}/${name}/validation.json" >> ${validation_file}
         fi
     done
+done
+
+# merge key2gen+key2gen_noisy data
+task_name="key2gen+key2gen_noisy"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
+names=$(echo ${task_name} | tr "+" "\n")
+model_type="gpt"
+data_dir=data/${task_name}/${model_type}/${dataset_name}
+mkdir -p ${data_dir}
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
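+# remove any previously merged files, since the loop below appends with >>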
+rm -f ${train_file} ${validation_file}
+for name in ${names}
+do
+    echo "merging ${name}"
+    if [ "${name}" != "${task_name}" ]; then
+        cat "data/${name}/${model_type}/${dataset_name}/train.json" >> ${train_file}
+        cat "data/${name}/${model_type}/${dataset_name}/validation.json" >> ${validation_file}
+    fi
 done
\ No newline at end of file
diff --git a/convlab/base_models/gpt/keyword_extraction/train_t5_multitask.sh b/convlab/base_models/gpt/keyword_extraction/train_t5_multitask.sh
index e63d2e894c7fc0cabea6d994df41245547943327..dd8d5a460478ae068143b5d18914500e31bae439 100644
--- a/convlab/base_models/gpt/keyword_extraction/train_t5_multitask.sh
+++ b/convlab/base_models/gpt/keyword_extraction/train_t5_multitask.sh
@@ -1,24 +1,23 @@
 set -e
-n_gpus=4
-master_port=23457
-task_name="rg+key2gen+key2gen_noisy+sen2gen+sen2gen_noisy"
-dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
+n_gpus=8
+master_port=23456
+task_name="key2gen+key2gen_noisy"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
 cache_dir="../cache"
 logging_dir="${output_dir}/runs"
 train_file="${data_dir}/train.json"
-validation_file="${data_dir}/validation.json"
 source_column="source"
 target_column="target"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
-model_name_or_path="t5-small"
+model_name_or_path="output/rg/${model_type}/${dataset_name}"
 per_device_train_batch_size=64
 per_device_eval_batch_size=128
-gradient_accumulation_steps=2
+gradient_accumulation_steps=1
 num_workers=16
 lr=1e-3
 num_train_epochs=1
@@ -27,7 +26,6 @@ python -m torch.distributed.launch --master_port ${master_port} \
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
-    --validation_file ${validation_file} \
     --source_column ${source_column} \
     --target_column ${target_column} \
     --max_source_length ${max_source_length} \
@@ -35,15 +33,11 @@ python -m torch.distributed.launch --master_port ${master_port} \
     --truncation_side ${truncation_side} \
     --model_name_or_path ${model_name_or_path} \
     --do_train \
-    --do_eval \
-    --save_strategy epoch \
-    --evaluation_strategy epoch \
-    --load_best_model_at_end \
-    --prediction_loss_only \
+    --save_steps 5000 \
+    --save_total_limit 3 \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
     --logging_dir ${logging_dir} \
-    --overwrite_output_dir \
     --preprocessing_num_workers ${num_workers} \
     --dataloader_num_workers ${num_workers} \
     --per_device_train_batch_size ${per_device_train_batch_size} \
@@ -52,4 +46,5 @@ python -m torch.distributed.launch --master_port ${master_port} \
     --learning_rate ${lr} \
     --num_train_epochs ${num_train_epochs} \
     --optim adafactor \
+    --lr_scheduler_type constant \
     --gradient_checkpointing