diff --git a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py
index e7e40112c9ccaff28b98e220ae5e1e24bba4ebf6..d6c2ec7a4e6f97973a7a0da4811e72ab36734801 100644
--- a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py
+++ b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.py
@@ -2,49 +2,139 @@ import json
 import os
 import random
 from tqdm import tqdm
+from nltk import sent_tokenize
 
 def main(args):
     random.seed(42)
     os.makedirs(args.output_dir, exist_ok=True)
+    if args.mode == 'multitask':
+        dataset_name = args.output_dir.split('/')[-1]
+        for data_split in ['validation', 'train']:
+            with open(os.path.join(args.output_dir, f"{data_split}.json"), 'w', encoding='utf-8') as fout:
+                for task_name in ['rg', 'key2gen', 'key2gen_noisy']:
+                    with open(os.path.join(args.input_dir, task_name, 'gpt', dataset_name, f"{data_split}.json")) as fin:
+                        for line in fin:
+                            item = json.loads(line)
+                            fout.write(json.dumps({'source': item['source'], 'target': item['target']}, ensure_ascii=False)+'\n')
+        return
+
+    if args.mode == 'sen2gen':
+        generated_filenames = [f for (_, _, fs) in os.walk(args.input_dir) for f in fs if f.startswith('gen_')]
+        original_filenames = [f[4:] for f in generated_filenames]
+        for ori_f, gen_f in zip(original_filenames, generated_filenames):
+            fori = open(os.path.join(args.input_dir, ori_f))
+            fgen = open(os.path.join(args.input_dir, gen_f))
+            fout = open(os.path.join(args.output_dir, f"{ori_f.split('_')[0]}.json"), 'w', encoding='utf-8')
+            for ori_line, gen_line in zip(fori, fgen):
+                ori_item = json.loads(ori_line)
+                gen_item = json.loads(gen_line)
+                context = ori_item['source'][ori_item['source'].index('context:\n\n'):]
+                gen_sen = gen_item['predictions']
+                ori_item['source'] = f'generate a response: grounded knowledge: | {gen_sen} | {context}'
+                ori_item['gen_sen'] = gen_sen
+                fout.write(json.dumps(ori_item, ensure_ascii=False)+'\n')
+        return
+    if args.mode == 'sen2gen_noisy':
+        def gen_samples(dialog_samples):
+            turn_gen_sens = [sent_tokenize(item['gen_sen']) for item in dialog_samples]
+            for i, sample in enumerate(dialog_samples):
+                possible_sens_turns = turn_gen_sens[i][:]
+                num_possible_sens_turns = min(random.randint(1, 5), len(turn_gen_sens) - 1)
+                for turn_sens in random.sample(turn_gen_sens[:i] + turn_gen_sens[i+1:], num_possible_sens_turns):
+                    possible_sens_turns.extend(turn_sens)
+                random.shuffle(possible_sens_turns)
+                possible_sens = ' | '.join(possible_sens_turns)
+                context = sample['source'][sample['source'].index('context:\n\n'):]
+                sample['source'] = f'generate a response: all knowledge: | {possible_sens} | {context}'
+                yield sample
+
+        for ori_f in [f for (_, _, fs) in os.walk(args.input_dir) for f in fs]:
+            fori = open(os.path.join(args.input_dir, ori_f))
+            fout = open(os.path.join(args.output_dir, ori_f), 'w', encoding='utf-8')
+            dialog = []
+            prev_num_turns = 0
+            for line in fori:
+                item = json.loads(line)
+                num_turns = item['source'].count('\n')
+                if len(dialog) == 0 or num_turns < prev_num_turns:
+                    # process a dialog with augmented responses
+                    for sample in gen_samples(dialog):
+                        fout.write(json.dumps(sample, ensure_ascii=False)+'\n')
+                    # next dialog
+                    dialog = [item]
+                else:
+                    # next turn
+                    dialog.append(item)
+                prev_num_turns = num_turns
+            for sample in gen_samples(dialog):
+                fout.write(json.dumps(sample, ensure_ascii=False)+'\n')
+        return
     filenames = [f for (_, _, fs) in os.walk(args.input_dir) for f in fs if 'keywords' in f]
     for filename in filenames:
         data = json.load(open(os.path.join(args.input_dir, filename)))
-        fout = open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}.json"), 'w', encoding='utf-8')
-        for dial in tqdm(data):
-            context = []
-            turns_keywords = [turn['keywords'] for turn in dial]
-            for i, turn in enumerate(dial):
-                speaker = 'user' if i % 2 == 0 else 'system'
-                utt = turn['utterance']
-                context_seq = '\n'.join([f"{turn['speaker']}: {turn['utt']}" for turn in context]+[f'{speaker}: '])
-                context.append({'speaker': speaker, 'utt': utt})
-                if i == 0:
-                    continue
-                
-                input_seq = f'generate a response: context:\n\n{context_seq}'
-                fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
-                if args.mode == 'rg':
-                    continue
+        # Splitting the data into multiple pieces.
+        if args.n_splits > 1:
+            len_data_pieces = len(data)//args.n_splits
+            fouts = [open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}_split_{i}-of-{args.n_splits}.json"), 'w', encoding='utf-8') for i in range(args.n_splits)]
+            random.shuffle(data)
+        else:
+            len_data_pieces = len(data)
+            fouts = [open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}.json"), 'w', encoding='utf-8')]
+        for i, fout in enumerate(fouts):
+            for dial in tqdm(data[i*len_data_pieces:(i+1)*len_data_pieces]):
+                context = []
+                turns_keywords = [turn['keywords'] for turn in dial]
+                for i, turn in enumerate(dial):
+                    if 'wikidialog' in filename:
+                        # skip user turns that generated by T5 in wikidialog 
+                        speaker = 'user' if i % 2 == 1 else 'system'
+                    else:
+                        speaker = 'user' if i % 2 == 0 else 'system'
+                    utt = turn['utterance']
+                    context_seq = '\n'.join([f"{turn['speaker']}: {turn['utt']}" for turn in context]+[f'{speaker}: '])
+                    context.append({'speaker': speaker, 'utt': utt})
+                    if i == 0 or ('wikidialog' in filename and speaker == 'user'):
+                        continue
+                    
+                    if args.mode == 'rg':
+                        input_seq = f'generate a response: context:\n\n{context_seq}'
+                        fout.write(json.dumps({
+                            'dataset': filename.split('/')[-1].split('_')[0], 
+                            'source': input_seq, 
+                            'target': utt}, ensure_ascii=False)+'\n')
+                        continue
+
+                    if len(turn['keywords']) == 0 or max([len(k) for k in turn['keywords']])>10:
+                        continue
 
-                random.shuffle(turn['keywords'])
-                for j in range(len(turn['keywords'])):
-                    random.shuffle(turn['keywords'][j])
-                keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in turn['keywords']])
-                input_seq = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'
-                fout.write(json.dumps({'source': input_seq, 'target': utt, 'keywords': turn['keywords']}, ensure_ascii=False)+'\n')
-                if args.mode == 'key2gen':
-                    continue
+                    if args.mode == 'key2gen':
+                        random.shuffle(turn['keywords'])
+                        for j in range(len(turn['keywords'])):
+                            random.shuffle(turn['keywords'][j])
+                        keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in turn['keywords']])
+                        input_seq = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'
+                        fout.write(json.dumps({
+                            'dataset': filename.split('/')[-1].split('_')[0], 
+                            'source': input_seq, 
+                            'target': utt, 
+                            'keywords': turn['keywords']}, ensure_ascii=False)+'\n')
+                        continue
 
-                possible_keywords_sents = turn['keywords'][:]
-                num_possible_keywords_turns = min(random.randint(1, 5), len(turns_keywords) - 1)
-                for turn_keywords in random.sample(turns_keywords[:i] + turns_keywords[i+1:], num_possible_keywords_turns):
-                    possible_keywords_sents.extend(turn_keywords)
-                random.shuffle(possible_keywords_sents)
-                possible_keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in possible_keywords_sents])
-                input_seq = f'generate a response: all knowledge: | {possible_keywords} | context:\n\n{context_seq}'
-                fout.write(json.dumps({'source': input_seq, 'target': utt, 'keywords': turn['keywords'], 'all_keywords': possible_keywords_sents}, ensure_ascii=False)+'\n')
-                if args.mode == 'key2gen_noisy':
-                    continue
+                    if args.mode == 'key2gen_noisy':
+                        possible_keywords_sents = turn['keywords'][:]
+                        num_possible_keywords_turns = min(random.randint(1, 5), len(turns_keywords) - 1)
+                        for turn_keywords in random.sample(turns_keywords[:i] + turns_keywords[i+1:], num_possible_keywords_turns):
+                            possible_keywords_sents.extend(turn_keywords)
+                        random.shuffle(possible_keywords_sents)
+                        possible_keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in possible_keywords_sents])
+                        input_seq = f'generate a response: all knowledge: | {possible_keywords} | context:\n\n{context_seq}'
+                        fout.write(json.dumps({
+                            'dataset': filename.split('/')[-1].split('_')[0], 
+                            'source': input_seq, 
+                            'target': utt, 
+                            'keywords': turn['keywords'], 
+                            'all_keywords': possible_keywords_sents}, ensure_ascii=False)+'\n')
+                        continue
     
 
 if __name__ == '__main__':
@@ -52,7 +142,8 @@ if __name__ == '__main__':
     parser = ArgumentParser(description="calculate NLU metrics for unified datasets")
     parser.add_argument('--input_dir', '-i', type=str, help='path to the input files')
     parser.add_argument('--output_dir', '-o', type=str, help='path to the output files')
-    parser.add_argument('--mode', '-m', type=str, choices=['rg', 'key2gen', 'key2gen_noisy'], help='which task to perform')
+    parser.add_argument('--mode', '-m', type=str, choices=['rg', 'key2gen', 'key2gen_noisy', 'sen2gen', 'sen2gen_noisy', 'multitask'], help='which task to perform')
+    parser.add_argument('--n_splits', '-n', type=int, default=1, help='split the data into multiple pieces')
     args = parser.parse_args()
     print(args)
     main(args)
diff --git a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
index c98060c6044359fabbb4ee9295f2e3e70df86eec..7abc21eeb4244de235c9870a8103d5556dc8eec8 100644
--- a/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
+++ b/convlab/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
@@ -1,21 +1,76 @@
-task_name="key2gen_noisy"
-dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
-names=$(echo ${dataset_name} | tr "+" "\n")
-model_type="gpt"
-data_dir=data/${task_name}/${model_type}/${name}/${dataset_name}
-rm -r ${data_dir}
-mkdir -p ${data_dir}
-train_file="${data_dir}/train.json"
-validation_file="${data_dir}/validation.json"
-test_file="${data_dir}/test.json"
-for name in ${names}
+# generate data for response generation, key2gen, key2gen_noisy
+for task_name in rg
 do
-    echo "preprocessing ${name}"
-    python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/${task_name}/${model_type}/${name} -m ${task_name}
-    if [ "${name}" != "${dataset_name}" ]; then
-        cat "data/${task_name}/gpt/${name}/train.json" >> ${train_file}
-        cat "data/${task_name}/gpt/${name}/validation.json" >> ${validation_file}
-        cat "data/${task_name}/gpt/${name}/test.json" >> ${test_file}
-    fi
+    dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
+    names=$(echo ${dataset_name} | tr "+" "\n")
+    model_type="gpt"
+    data_dir=data/${task_name}/${model_type}/${dataset_name}
+    mkdir -p ${data_dir}
+    train_file="${data_dir}/train.json"
+    validation_file="${data_dir}/validation.json"
+    test_file="${data_dir}/test.json"
+    rm ${train_file} ${validation_file} ${test_file}
+    for name in ${names}
+    do
+        echo "preprocessing ${name}"
+        python gen_pretraining_data.py -i data/lm/${model_type}/${name} -o data/${task_name}/${model_type}/${name} -m ${task_name}
+        if [ "${name}" != "${dataset_name}" ]; then
+            cat "data/${task_name}/${model_type}/${name}/train.json" >> ${train_file}
+            cat "data/${task_name}/${model_type}/${name}/validation.json" >> ${validation_file}
+            cat "data/${task_name}/${model_type}/${name}/test.json" >> ${test_file}
+        fi
+    done
 done
-python gen_pretraining_data.py -i data/lm/multiwoz21/${model_type} -o data/${task_name}/${model_type}/multiwoz21 -m ${task_name}
\ No newline at end of file
+
+
+# # generate data for sentence grounded generation
+# task_name="key2gen"
+# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
+# names=$(echo ${dataset_name} | tr "+" "\n")
+# model_type="gpt"
+# data_dir=data/${task_name}/${model_type}/${dataset_name}
+# mkdir -p ${data_dir}
+# n_splits=2
+# for ((n=0;n<${n_splits};n++))
+# do
+#     rm ${data_dir}/train_split_${n}-of-${n_splits}.json ${data_dir}/validation_split_${n}-of-${n_splits}.json ${data_dir}/test_split_${n}-of-${n_splits}.json
+# done
+# for name in ${names}
+# do
+#     echo "preprocessing ${name}"
+#     python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/${task_name}/${model_type}/${name} -m ${task_name} -n ${n_splits}
+#     if [ "${name}" != "${dataset_name}" ]; then
+#         for ((n=0;n<${n_splits};n++))
+#         do
+#             cat "data/${task_name}/gpt/${name}/train_split_${n}-of-${n_splits}.json" >> "${data_dir}/train_split_${n}-of-${n_splits}.json"
+#             cat "data/${task_name}/gpt/${name}/validation_split_${n}-of-${n_splits}.json" >> "${data_dir}/validation_split_${n}-of-${n_splits}.json"
+#             cat "data/${task_name}/gpt/${name}/test_split_${n}-of-${n_splits}.json" >> "${data_dir}/test_split_${n}-of-${n_splits}.json"
+#         done
+#     fi
+# done
+
+# # merge generated data with original data
+# task_name="sen2gen"
+# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
+# names=$(echo ${dataset_name} | tr "+" "\n")
+# model_type="gpt"
+# data_dir=data/${task_name}/${model_type}/${dataset_name}
+# mkdir -p ${data_dir}
+# python gen_pretraining_data.py -i data/key2gen/${model_type}/${dataset_name} -o data/${task_name}/${model_type}/${dataset_name} -m ${task_name}
+
+# # generate sen2gen_noisy data with original data
+# task_name="sen2gen_noisy"
+# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
+# names=$(echo ${dataset_name} | tr "+" "\n")
+# model_type="gpt"
+# data_dir=data/${task_name}/${model_type}/${dataset_name}
+# mkdir -p ${data_dir}
+# python gen_pretraining_data.py -i data/sen2gen/${model_type}/${dataset_name} -o data/${task_name}/${model_type}/${dataset_name} -m ${task_name}
+
+# merge data for multitask training
+# task_name="rg+key2gen+key2gen_noisy+sen2gen+sen2gen_noisy"
+# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
+# model_type="gpt"
+# data_dir=data/${task_name}/${model_type}/${dataset_name}
+# mkdir -p ${data_dir}
+# python gen_pretraining_data.py -i data/ -o data/${task_name}/${model_type}/${dataset_name} -m multitask
diff --git a/convlab/base_models/gpt/keyword_extraction/infer_t5_key2gen_half.sh b/convlab/base_models/gpt/keyword_extraction/infer_t5_key2gen_half.sh
new file mode 100644
index 0000000000000000000000000000000000000000..536e7588f3e6bfaa0e4bc083962a7aade6a66d12
--- /dev/null
+++ b/convlab/base_models/gpt/keyword_extraction/infer_t5_key2gen_half.sh
@@ -0,0 +1,53 @@
+n_gpus=2
+master_port=$1
+task_name="key2gen"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
+model_type="gpt"
+split_id=$2
+n_splits=$3
+data_dir="data/${task_name}/${model_type}/${dataset_name}"
+output_dir="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}/gen"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+# train_file="${data_dir}/train_split_${split_id}-of-${n_splits}.json"
+# validation_file="${data_dir}/validation_split_${split_id}-of-${n_splits}.json"
+let infer_split_id=($split_id+1)%$n_splits
+test_file="${data_dir}/validation_split_${infer_split_id}-of-${n_splits}.json"
+source_column="source"
+target_column="target"
+truncation_side="left"
+max_source_length=512
+max_target_length=128
+model_name_or_path="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}"
+per_device_train_batch_size=128
+per_device_eval_batch_size=128
+gradient_accumulation_steps=4
+lr=1e-3
+num_train_epochs=10
+
+python -m torch.distributed.launch --master_port ${master_port} \
+    --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_predict \
+    --predict_with_generate \
+    --do_sample \
+    --top_p 0.9 \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 16 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
diff --git a/convlab/base_models/gpt/keyword_extraction/train_t5_key2gen_half.sh b/convlab/base_models/gpt/keyword_extraction/train_t5_key2gen_half.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0c52477f539ab13211f46ad17f1ee3e10e93164f
--- /dev/null
+++ b/convlab/base_models/gpt/keyword_extraction/train_t5_key2gen_half.sh
@@ -0,0 +1,61 @@
+set -e
+n_gpus=2
+master_port=$1
+task_name="key2gen"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
+model_type="gpt"
+split_id=$2
+n_splits=$3
+data_dir="data/${task_name}/${model_type}/${dataset_name}"
+output_dir="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train_split_${split_id}-of-${n_splits}.json"
+validation_file="${data_dir}/validation_split_${split_id}-of-${n_splits}.json"
+test_file="${data_dir}/test_split_${split_id}-of-${n_splits}.json"
+source_column="source"
+target_column="target"
+truncation_side="left"
+max_source_length=512
+max_target_length=128
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=128
+gradient_accumulation_steps=4
+num_workers=16
+lr=1e-3
+num_train_epochs=1
+
+python -m torch.distributed.launch --master_port ${master_port} \
+    --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --do_predict \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --load_best_model_at_end \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers ${num_workers} \
+    --dataloader_num_workers ${num_workers} \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
diff --git a/convlab/base_models/gpt/keyword_extraction/train_t5_multitask.sh b/convlab/base_models/gpt/keyword_extraction/train_t5_multitask.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e63d2e894c7fc0cabea6d994df41245547943327
--- /dev/null
+++ b/convlab/base_models/gpt/keyword_extraction/train_t5_multitask.sh
@@ -0,0 +1,55 @@
+set -e
+n_gpus=4
+master_port=23457
+task_name="rg+key2gen+key2gen_noisy+sen2gen+sen2gen_noisy"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
+model_type="gpt"
+data_dir="data/${task_name}/${model_type}/${dataset_name}"
+output_dir="output/${task_name}/${model_type}/${dataset_name}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+source_column="source"
+target_column="target"
+truncation_side="left"
+max_source_length=512
+max_target_length=128
+model_name_or_path="t5-small"
+per_device_train_batch_size=64
+per_device_eval_batch_size=128
+gradient_accumulation_steps=2
+num_workers=16
+lr=1e-3
+num_train_epochs=1
+
+python -m torch.distributed.launch --master_port ${master_port} \
+    --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --load_best_model_at_end \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers ${num_workers} \
+    --dataloader_num_workers ${num_workers} \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --optim adafactor \
+    --gradient_checkpointing