diff --git a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py
index d6c2ec7a4e6f97973a7a0da4811e72ab36734801..28109eae096e8266a44f463488b904dc9a4d52fb 100644
--- a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py
+++ b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py
@@ -1,4 +1,5 @@
 import json
+import json_lines
 import os
 import random
 from tqdm import tqdm
@@ -7,85 +8,19 @@ from nltk import sent_tokenize
 def main(args):
     random.seed(42)
     os.makedirs(args.output_dir, exist_ok=True)
-    if args.mode == 'multitask':
-        dataset_name = args.output_dir.split('/')[-1]
-        for data_split in ['validation', 'train']:
-            with open(os.path.join(args.output_dir, f"{data_split}.json"), 'w', encoding='utf-8') as fout:
-                for task_name in ['rg', 'key2gen', 'key2gen_noisy']:
-                    with open(os.path.join(args.input_dir, task_name, 'gpt', dataset_name, f"{data_split}.json")) as fin:
-                        for line in fin:
-                            item = json.loads(line)
-                            fout.write(json.dumps({'source': item['source'], 'target': item['target']}, ensure_ascii=False)+'\n')
-        return
-
-    if args.mode == 'sen2gen':
-        generated_filenames = [f for (_, _, fs) in os.walk(args.input_dir) for f in fs if f.startswith('gen_')]
-        original_filenames = [f[4:] for f in generated_filenames]
-        for ori_f, gen_f in zip(original_filenames, generated_filenames):
-            fori = open(os.path.join(args.input_dir, ori_f))
-            fgen = open(os.path.join(args.input_dir, gen_f))
-            fout = open(os.path.join(args.output_dir, f"{ori_f.split('_')[0]}.json"), 'w', encoding='utf-8')
-            for ori_line, gen_line in zip(fori, fgen):
-                ori_item = json.loads(ori_line)
-                gen_item = json.loads(gen_line)
-                context = ori_item['source'][ori_item['source'].index('context:\n\n'):]
-                gen_sen = gen_item['predictions']
-                ori_item['source'] = f'generate a response: grounded knowledge: | {gen_sen} | {context}'
-                ori_item['gen_sen'] = gen_sen
-                fout.write(json.dumps(ori_item, ensure_ascii=False)+'\n')
-        return
-    if args.mode == 'sen2gen_noisy':
-        def gen_samples(dialog_samples):
-            turn_gen_sens = [sent_tokenize(item['gen_sen']) for item in dialog_samples]
-            for i, sample in enumerate(dialog_samples):
-                possible_sens_turns = turn_gen_sens[i][:]
-                num_possible_sens_turns = min(random.randint(1, 5), len(turn_gen_sens) - 1)
-                for turn_sens in random.sample(turn_gen_sens[:i] + turn_gen_sens[i+1:], num_possible_sens_turns):
-                    possible_sens_turns.extend(turn_sens)
-                random.shuffle(possible_sens_turns)
-                possible_sens = ' | '.join(possible_sens_turns)
-                context = sample['source'][sample['source'].index('context:\n\n'):]
-                sample['source'] = f'generate a response: all knowledge: | {possible_sens} | {context}'
-                yield sample
-
-        for ori_f in [f for (_, _, fs) in os.walk(args.input_dir) for f in fs]:
-            fori = open(os.path.join(args.input_dir, ori_f))
-            fout = open(os.path.join(args.output_dir, ori_f), 'w', encoding='utf-8')
-            dialog = []
-            prev_num_turns = 0
-            for line in fori:
-                item = json.loads(line)
-                num_turns = item['source'].count('\n')
-                if len(dialog) == 0 or num_turns < prev_num_turns:
-                    # process a dialog with augmented responses
-                    for sample in gen_samples(dialog):
-                        fout.write(json.dumps(sample, ensure_ascii=False)+'\n')
-                    # next dialog
-                    dialog = [item]
-                else:
-                    # next turn
-                    dialog.append(item)
-                prev_num_turns = num_turns
-            for sample in gen_samples(dialog):
-                fout.write(json.dumps(sample, ensure_ascii=False)+'\n')
-        return
-    filenames = [f for (_, _, fs) in os.walk(args.input_dir) for f in fs if 'keywords' in f]
+    filenames = [os.path.join(args.input_dir, f) for (_, _, fs) in os.walk(args.input_dir) for f in fs if 'keywords' in f]
     for filename in filenames:
-        data = json.load(open(os.path.join(args.input_dir, filename)))
-        # Splitting the data into multiple pieces.
-        if args.n_splits > 1:
-            len_data_pieces = len(data)//args.n_splits
-            fouts = [open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}_split_{i}-of-{args.n_splits}.json"), 'w', encoding='utf-8') for i in range(args.n_splits)]
-            random.shuffle(data)
-        else:
-            len_data_pieces = len(data)
-            fouts = [open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}.json"), 'w', encoding='utf-8')]
-        for i, fout in enumerate(fouts):
-            for dial in tqdm(data[i*len_data_pieces:(i+1)*len_data_pieces]):
+        dataset_name = filename.split('/')[-2]
+        data_split = filename.split('/')[-1].split('_')[-1].split('.')[0]
+        output_file = os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[-1]}")
+        print(f'processing {dataset_name}: {filename} => {output_file}')
+        cnt = 0
+        with open(filename, 'rb') as fin, open(output_file, 'w', encoding='utf-8') as fout:
+            for dial in tqdm(json_lines.reader(fin)):
                 context = []
                 turns_keywords = [turn['keywords'] for turn in dial]
                 for i, turn in enumerate(dial):
-                    if 'wikidialog' in filename:
+                    if dataset_name == 'wikidialog':
                         # skip user turns that generated by T5 in wikidialog 
                         speaker = 'user' if i % 2 == 1 else 'system'
                     else:
@@ -93,18 +28,16 @@ def main(args):
                     utt = turn['utterance']
                     context_seq = '\n'.join([f"{turn['speaker']}: {turn['utt']}" for turn in context]+[f'{speaker}: '])
                     context.append({'speaker': speaker, 'utt': utt})
-                    if i == 0 or ('wikidialog' in filename and speaker == 'user'):
+                    if i == 0 or (dataset_name == 'wikidialog' and speaker == 'user'):
                         continue
                     
                     if args.mode == 'rg':
-                        input_seq = f'generate a response: context:\n\n{context_seq}'
+                        input_seq = f'generate a response: all knowledge: | | context:\n\n{context_seq}'
                         fout.write(json.dumps({
-                            'dataset': filename.split('/')[-1].split('_')[0], 
+                            'dataset': dataset_name,
                             'source': input_seq, 
-                            'target': utt}, ensure_ascii=False)+'\n')
-                        continue
-
-                    if len(turn['keywords']) == 0 or max([len(k) for k in turn['keywords']])>10:
+                            'target': utt
+                            }, ensure_ascii=False)+'\n')
                         continue
 
                     if args.mode == 'key2gen':
@@ -113,11 +46,14 @@ def main(args):
                             random.shuffle(turn['keywords'][j])
                         keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in turn['keywords']])
                         input_seq = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'
-                        fout.write(json.dumps({
-                            'dataset': filename.split('/')[-1].split('_')[0], 
+                        json2dump = {
+                            'dataset': dataset_name,
                             'source': input_seq, 
-                            'target': utt, 
-                            'keywords': turn['keywords']}, ensure_ascii=False)+'\n')
+                            'target': utt
+                            }
+                        if data_split == 'validation':
+                            json2dump.update({'keywords': turn['keywords']})
+                        fout.write(json.dumps(json2dump, ensure_ascii=False)+'\n')
                         continue
 
                     if args.mode == 'key2gen_noisy':
@@ -128,12 +64,14 @@ def main(args):
                         random.shuffle(possible_keywords_sents)
                         possible_keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in possible_keywords_sents])
                         input_seq = f'generate a response: all knowledge: | {possible_keywords} | context:\n\n{context_seq}'
-                        fout.write(json.dumps({
-                            'dataset': filename.split('/')[-1].split('_')[0], 
+                        json2dump = {
+                            'dataset': dataset_name,
                             'source': input_seq, 
-                            'target': utt, 
-                            'keywords': turn['keywords'], 
-                            'all_keywords': possible_keywords_sents}, ensure_ascii=False)+'\n')
+                            'target': utt
+                            }
+                        if data_split == 'validation':
+                            json2dump.update({'keywords': turn['keywords'], 'all_keywords': possible_keywords_sents})
+                        fout.write(json.dumps(json2dump, ensure_ascii=False)+'\n')
                         continue
     
 
@@ -142,8 +80,7 @@ if __name__ == '__main__':
     parser = ArgumentParser(description="calculate NLU metrics for unified datasets")
     parser.add_argument('--input_dir', '-i', type=str, help='path to the input files')
     parser.add_argument('--output_dir', '-o', type=str, help='path to the output files')
-    parser.add_argument('--mode', '-m', type=str, choices=['rg', 'key2gen', 'key2gen_noisy', 'sen2gen', 'sen2gen_noisy', 'multitask'], help='which task to perform')
-    parser.add_argument('--n_splits', '-n', type=int, default=1, help='split the data into multiple pieces')
+    parser.add_argument('--mode', '-m', type=str, choices=['rg', 'key2gen', 'key2gen_noisy'], help='which task to perform')
     args = parser.parse_args()
     print(args)
     main(args)
diff --git a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
index 7abc21eeb4244de235c9870a8103d5556dc8eec8..ea1d4d51c7cb0dc56c74b47d15b4c1292a08f847 100644
--- a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
+++ b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
@@ -1,5 +1,5 @@
 # generate data for response generation, key2gen, key2gen_noisy
-for task_name in rg
+for task_name in rg key2gen key2gen_noisy
 do
     dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
     names=$(echo ${dataset_name} | tr "+" "\n")
@@ -8,8 +8,7 @@ do
     mkdir -p ${data_dir}
     train_file="${data_dir}/train.json"
     validation_file="${data_dir}/validation.json"
-    test_file="${data_dir}/test.json"
-    rm ${train_file} ${validation_file} ${test_file}
+    rm ${train_file} ${validation_file}
     for name in ${names}
     do
         echo "preprocessing ${name}"
@@ -17,60 +16,6 @@ do
         if [ "${name}" != "${dataset_name}" ]; then
             cat "data/${task_name}/${model_type}/${name}/train.json" >> ${train_file}
             cat "data/${task_name}/${model_type}/${name}/validation.json" >> ${validation_file}
-            cat "data/${task_name}/${model_type}/${name}/test.json" >> ${test_file}
         fi
     done
-done
-
-
-# # generate data for sentence grounded generation
-# task_name="key2gen"
-# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
-# names=$(echo ${dataset_name} | tr "+" "\n")
-# model_type="gpt"
-# data_dir=data/${task_name}/${model_type}/${dataset_name}
-# mkdir -p ${data_dir}
-# n_splits=2
-# for ((n=0;n<${n_splits};n++))
-# do
-#     rm ${data_dir}/train_split_${n}-of-${n_splits}.json ${data_dir}/validation_split_${n}-of-${n_splits}.json ${data_dir}/test_split_${n}-of-${n_splits}.json
-# done
-# for name in ${names}
-# do
-#     echo "preprocessing ${name}"
-#     python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/${task_name}/${model_type}/${name} -m ${task_name} -n ${n_splits}
-#     if [ "${name}" != "${dataset_name}" ]; then
-#         for ((n=0;n<${n_splits};n++))
-#         do
-#             cat "data/${task_name}/gpt/${name}/train_split_${n}-of-${n_splits}.json" >> "${data_dir}/train_split_${n}-of-${n_splits}.json"
-#             cat "data/${task_name}/gpt/${name}/validation_split_${n}-of-${n_splits}.json" >> "${data_dir}/validation_split_${n}-of-${n_splits}.json"
-#             cat "data/${task_name}/gpt/${name}/test_split_${n}-of-${n_splits}.json" >> "${data_dir}/test_split_${n}-of-${n_splits}.json"
-#         done
-#     fi
-# done
-
-# # merge generated data with original data
-# task_name="sen2gen"
-# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
-# names=$(echo ${dataset_name} | tr "+" "\n")
-# model_type="gpt"
-# data_dir=data/${task_name}/${model_type}/${dataset_name}
-# mkdir -p ${data_dir}
-# python gen_pretraining_data.py -i data/key2gen/${model_type}/${dataset_name} -o data/${task_name}/${model_type}/${dataset_name} -m ${task_name}
-
-# # generate sen2gen_noisy data with original data
-# task_name="sen2gen_noisy"
-# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
-# names=$(echo ${dataset_name} | tr "+" "\n")
-# model_type="gpt"
-# data_dir=data/${task_name}/${model_type}/${dataset_name}
-# mkdir -p ${data_dir}
-# python gen_pretraining_data.py -i data/sen2gen/${model_type}/${dataset_name} -o data/${task_name}/${model_type}/${dataset_name} -m ${task_name}
-
-# merge data for multitask training
-# task_name="rg+key2gen+key2gen_noisy+sen2gen+sen2gen_noisy"
-# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
-# model_type="gpt"
-# data_dir=data/${task_name}/${model_type}/${dataset_name}
-# mkdir -p ${data_dir}
-# python gen_pretraining_data.py -i data/ -o data/${task_name}/${model_type}/${dataset_name} -m multitask
+done
\ No newline at end of file
diff --git a/convlab/base_models/gpt/keyword_extraction/infer_t5_key2gen_half.sh b/convlab/base_models/gpt/keyword_extraction/infer_t5_key2gen_half.sh
deleted file mode 100644
index 536e7588f3e6bfaa0e4bc083962a7aade6a66d12..0000000000000000000000000000000000000000
--- a/convlab/base_models/gpt/keyword_extraction/infer_t5_key2gen_half.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-n_gpus=2
-master_port=$1
-task_name="key2gen"
-dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
-model_type="gpt"
-split_id=$2
-n_splits=$3
-data_dir="data/${task_name}/${model_type}/${dataset_name}"
-output_dir="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}/gen"
-cache_dir="../cache"
-logging_dir="${output_dir}/runs"
-# train_file="${data_dir}/train_split_${split_id}-of-${n_splits}.json"
-# validation_file="${data_dir}/validation_split_${split_id}-of-${n_splits}.json"
-let infer_split_id=($split_id+1)%$n_splits
-test_file="${data_dir}/validation_split_${infer_split_id}-of-${n_splits}.json"
-source_column="source"
-target_column="target"
-truncation_side="left"
-max_source_length=512
-max_target_length=128
-model_name_or_path="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}"
-per_device_train_batch_size=128
-per_device_eval_batch_size=128
-gradient_accumulation_steps=4
-lr=1e-3
-num_train_epochs=10
-
-python -m torch.distributed.launch --master_port ${master_port} \
-    --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
-    --task_name ${task_name} \
-    --test_file ${test_file} \
-    --source_column ${source_column} \
-    --target_column ${target_column} \
-    --max_source_length ${max_source_length} \
-    --max_target_length ${max_target_length} \
-    --truncation_side ${truncation_side} \
-    --model_name_or_path ${model_name_or_path} \
-    --do_predict \
-    --predict_with_generate \
-    --do_sample \
-    --top_p 0.9 \
-    --cache_dir ${cache_dir} \
-    --output_dir ${output_dir} \
-    --logging_dir ${logging_dir} \
-    --overwrite_output_dir \
-    --preprocessing_num_workers 16 \
-    --per_device_train_batch_size ${per_device_train_batch_size} \
-    --per_device_eval_batch_size ${per_device_eval_batch_size} \
-    --gradient_accumulation_steps ${gradient_accumulation_steps} \
-    --learning_rate ${lr} \
-    --num_train_epochs ${num_train_epochs} \
-    --adafactor \
-    --gradient_checkpointing
diff --git a/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py b/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
index bdd0f99e8165c74f51fafda67f44967958a55aae..bb221f6d78b026c61f10846c385b5fa903c64e7f 100644
--- a/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
+++ b/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
@@ -127,7 +127,7 @@ def main(args):
     fin = open(word_loss_file, 'rb')
     fout = open(args.output_file, 'w', encoding='utf-8')
 
-    for item in json_lines.reader(fin):
+    for item in tqdm(json_lines.reader(fin)):
         words = [tokenizer.convert_tokens_to_string(tokens) for tokens in item['words']]
         losses = [np.mean(loss) for loss in item['losses']]
         dialog_keywords, keywords_turn_sent2idx = keywords_filter(words, losses)
diff --git a/convlab/base_models/gpt/keyword_extraction/train_t5_key2gen_half.sh b/convlab/base_models/gpt/keyword_extraction/train_t5_key2gen_half.sh
deleted file mode 100644
index 0c52477f539ab13211f46ad17f1ee3e10e93164f..0000000000000000000000000000000000000000
--- a/convlab/base_models/gpt/keyword_extraction/train_t5_key2gen_half.sh
+++ /dev/null
@@ -1,61 +0,0 @@
-set -e
-n_gpus=2
-master_port=$1
-task_name="key2gen"
-dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
-model_type="gpt"
-split_id=$2
-n_splits=$3
-data_dir="data/${task_name}/${model_type}/${dataset_name}"
-output_dir="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}"
-cache_dir="../cache"
-logging_dir="${output_dir}/runs"
-train_file="${data_dir}/train_split_${split_id}-of-${n_splits}.json"
-validation_file="${data_dir}/validation_split_${split_id}-of-${n_splits}.json"
-test_file="${data_dir}/test_split_${split_id}-of-${n_splits}.json"
-source_column="source"
-target_column="target"
-truncation_side="left"
-max_source_length=512
-max_target_length=128
-model_name_or_path="t5-small"
-per_device_train_batch_size=128
-per_device_eval_batch_size=128
-gradient_accumulation_steps=4
-num_workers=16
-lr=1e-3
-num_train_epochs=1
-
-python -m torch.distributed.launch --master_port ${master_port} \
-    --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
-    --task_name ${task_name} \
-    --train_file ${train_file} \
-    --validation_file ${validation_file} \
-    --test_file ${test_file} \
-    --source_column ${source_column} \
-    --target_column ${target_column} \
-    --max_source_length ${max_source_length} \
-    --max_target_length ${max_target_length} \
-    --truncation_side ${truncation_side} \
-    --model_name_or_path ${model_name_or_path} \
-    --do_train \
-    --do_eval \
-    --do_predict \
-    --save_strategy epoch \
-    --evaluation_strategy epoch \
-    --save_total_limit 1 \
-    --load_best_model_at_end \
-    --prediction_loss_only \
-    --cache_dir ${cache_dir} \
-    --output_dir ${output_dir} \
-    --logging_dir ${logging_dir} \
-    --overwrite_output_dir \
-    --preprocessing_num_workers ${num_workers} \
-    --dataloader_num_workers ${num_workers} \
-    --per_device_train_batch_size ${per_device_train_batch_size} \
-    --per_device_eval_batch_size ${per_device_eval_batch_size} \
-    --gradient_accumulation_steps ${gradient_accumulation_steps} \
-    --learning_rate ${lr} \
-    --num_train_epochs ${num_train_epochs} \
-    --adafactor \
-    --gradient_checkpointing
diff --git a/convlab/base_models/gpt/keyword_extraction/train_t5_rg.sh b/convlab/base_models/gpt/keyword_extraction/train_t5_rg.sh
index 6e697e84545c30657f3b32a8a9d44fa94231c01d..4d628f7f2c53d766c4a0b92861ac7681b8c80b02 100644
--- a/convlab/base_models/gpt/keyword_extraction/train_t5_rg.sh
+++ b/convlab/base_models/gpt/keyword_extraction/train_t5_rg.sh
@@ -1,34 +1,31 @@
 set -e
-n_gpus=2
+n_gpus=8
 master_port=23456
 task_name="rg"
-dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
 cache_dir="../cache"
 logging_dir="${output_dir}/runs"
 train_file="${data_dir}/train.json"
-validation_file="${data_dir}/validation.json"
-test_file="${data_dir}/test.json"
 source_column="source"
 target_column="target"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
 model_name_or_path="t5-small"
-per_device_train_batch_size=128
+per_device_train_batch_size=64
 per_device_eval_batch_size=128
-gradient_accumulation_steps=4
+gradient_accumulation_steps=1
+num_workers=16
 lr=1e-3
-num_train_epochs=3
+num_train_epochs=1
 
 python -m torch.distributed.launch --master_port ${master_port} \
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
-    --validation_file ${validation_file} \
-    --test_file ${test_file} \
     --source_column ${source_column} \
     --target_column ${target_column} \
     --max_source_length ${max_source_length} \
@@ -36,21 +33,18 @@ python -m torch.distributed.launch --master_port ${master_port} \
     --truncation_side ${truncation_side} \
     --model_name_or_path ${model_name_or_path} \
     --do_train \
-    --do_eval \
-    --do_predict \
-    --save_strategy epoch \
-    --evaluation_strategy epoch \
-    --load_best_model_at_end \
-    --prediction_loss_only \
+    --save_steps 5000 \
+    --save_total_limit 3 \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
     --logging_dir ${logging_dir} \
-    --overwrite_output_dir \
-    --preprocessing_num_workers 4 \
+    --preprocessing_num_workers ${num_workers} \
+    --dataloader_num_workers ${num_workers} \
     --per_device_train_batch_size ${per_device_train_batch_size} \
     --per_device_eval_batch_size ${per_device_eval_batch_size} \
     --gradient_accumulation_steps ${gradient_accumulation_steps} \
     --learning_rate ${lr} \
     --num_train_epochs ${num_train_epochs} \
-    --adafactor \
+    --optim adafactor \
+    --lr_scheduler_type constant \
     --gradient_checkpointing