diff --git a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8b5af5c41d5d40db1ec16ac63427920c3cd8b5e
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
@@ -0,0 +1,29 @@
+import json
+import os
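+# convert the keyword files produced by get_keywords.sh into line-delimited json
+# pairs {"keywords+context": ..., "response": ...} used for key2gen pretraining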
+
+def main(args):
+    os.makedirs(args.output_dir, exist_ok=True)
+    filenames = [f for (_, _, fs) in os.walk(args.input_dir) for f in fs if 'keywords' in f]
+    for filename in filenames:
+        data = json.load(open(os.path.join(args.input_dir, filename)))
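+        # keyword files are named "{dataset}_{split}_keywords_...json"; name the output file after the split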
+        fout = open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}.json"), 'w', encoding='utf-8')
+        for dial in data:
+            context = []
+            for i, turn in enumerate(dial):
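+                # turns are assumed to strictly alternate between user and system, starting with the user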
+                speaker = 'user' if i%2 == 0 else 'system'
+                keywords = ', '.join(turn['keywords'])
+                utt = turn['utterance']
+                input_seq = '\n'.join([f"{t['speaker']}: {t['utt']}" for t in context]+[f'{speaker}: '])
+                input_seq = f'{keywords}\n{input_seq}'
+                context.append({'speaker': speaker, 'utt':utt})
+                fout.write(json.dumps({'keywords+context': input_seq, 'response': utt}, ensure_ascii=False)+'\n')
+        fout.close()
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="generate key2gen pretraining data from extracted keywords")
+    parser.add_argument('--input_dir', '-i', type=str, help='path to the input files')
+    parser.add_argument('--output_dir', '-o', type=str, help='path to the output files')
+    args = parser.parse_args()
+    print(args)
+    main(args)
diff --git a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7fd24bdfee565e73b7e326f71514d565573e5e7a
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
@@ -0,0 +1,18 @@
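+# generate key2gen pretraining data for each dataset, then merge them into one combined corpus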
+dataset_name="sgd+metalwoz+tm1+tm2+tm3"
+names=$(echo ${dataset_name} | tr "+" "\n")
+model_type="gpt"
+data_dir=data/key2gen/${model_type}/${dataset_name}
+mkdir -p ${data_dir}
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+for name in ${names}
+do
+    echo "preprocessing ${name}"
+    python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/key2gen/${model_type}/${name}
+    if [ "${name}" != "${dataset_name}" ]; then
+        cat "data/keygen/gpt/${name}/train.json" >> ${train_file}
+        cat "data/keygen/gpt/${name}/validation.json" >> ${validation_file}
+        cat "data/keygen/gpt/${name}/test.json" >> ${test_file}
+    fi
+done
diff --git a/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh b/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh
index 6dd2680bc3c4390cf2d85cff46d7000c5293ef70..cffa944b0374cff67abe223b4c2ea252ebd889f4 100644
--- a/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh
@@ -1,20 +1,23 @@
-model_type=dialogpt
-dataset_name=multiwoz21
-model_name=dialogpt-large
-data_dir="data/lm/${dataset_name}/${model_type}"
-word_loss_file="${data_dir}/${model_name}_${dataset_name}_word_loss.json"
-keywords_num=5
-keywords_ratio=1
-keywords_th=0
+task_name="lm"
+dataset_name=$1
+model_type="gpt"
+data_dir="data/${task_name}/${dataset_name}/${model_type}"
+model_name_or_path="gpt2-large"
+keywords_num=100
+keywords_ratio=0.3
+keywords_th_ratio=0
 stopwords=True
-output_file="${data_dir}/${dataset_name}_keywords_${model_name}_topk_${keywords_num}_ratio_${keywords_ratio}_th_${keywords_th}_stopwords_${stopwords}.json"
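+# for each data split, turn the word-level LM losses into per-utterance keyword lists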
+for data_split in validation test train
+do
+    word_loss_file="${data_dir}/${model_name_or_path}_${dataset_name}_${data_split}_word_loss.json"
+    output_file="${data_dir}/${dataset_name}_${data_split}_keywords_${model_name_or_path}_topk_${keywords_num}_ratio_${keywords_ratio}_th_${keywords_th_ratio}_stopwords_${stopwords}.json"
 
-python lmloss2keywords.py \
-    --model_type ${model_type} \
-    --word_loss_file ${word_loss_file} \
-    --keywords_num ${keywords_num} \
-    --keywords_ratio ${keywords_ratio} \
-    --keywords_th ${keywords_th} \
-    --stopwords ${stopwords} \
-    --output_file ${output_file}
-    
\ No newline at end of file
+    python lmloss2keywords.py \
+        --model_type ${model_type} \
+        --word_loss_file ${word_loss_file} \
+        --keywords_num ${keywords_num} \
+        --keywords_ratio ${keywords_ratio} \
+        --keywords_th_ratio ${keywords_th_ratio} \
+        --stopwords ${stopwords} \
+        --output_file ${output_file}
+done
\ No newline at end of file
diff --git a/convlab2/base_models/gpt/keyword_extraction/get_word_loss.sh b/convlab2/base_models/gpt/keyword_extraction/get_word_loss.sh
index 2aad467cf181c08532505a1523af746e52aacb4a..e0b8c1499ade1faa90ea26cde1aa988b06ed84d6 100644
--- a/convlab2/base_models/gpt/keyword_extraction/get_word_loss.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/get_word_loss.sh
@@ -1,65 +1,33 @@
 set -e
 n_gpus=1
 task_name="lm"
-dataset_name="multiwoz21"
-model_type="dialogpt"
+dataset_name=$1
+model_type="gpt"
 data_dir="data/${task_name}/${dataset_name}/${model_type}"
 output_dir="output/${task_name}/${dataset_name}/${model_type}"
 cache_dir="../cache"
 validation_file="${data_dir}/validation.json"
 source_column="dialogue"
 max_length=512
-model_name_or_path="microsoft/DialoGPT-large"
-per_device_eval_batch_size=4
-
-dump_eval_loss_to="${data_dir}/dialogpt-large_${dataset_name}_token_loss.json"
-python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type dialogpt
-python ../run_clm.py \
-    --dump_eval_loss_to ${dump_eval_loss_to}\
-    --model_name_or_path ${model_name_or_path} \
-    --output_dir ${data_dir} \
-    --validation_file ${validation_file} \
-    --source_column ${source_column} \
-    --max_length ${max_length} \
-    --do_eval \
-    --prediction_loss_only \
-    --cache_dir ${cache_dir} \
-    --preprocessing_num_workers 4 \
-    --per_device_eval_batch_size ${per_device_eval_batch_size}
-python lmloss2keywords.py --token_loss_file ${dump_eval_loss_to} --model_type ${model_type}
-
-dump_eval_loss_to="${data_dir}/dialogpt-large-mwoz_${dataset_name}_token_loss.json"
-python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type dialogpt
-python ../run_clm.py \
-    --dump_eval_loss_to ${dump_eval_loss_to}\
-    --model_name_or_path ${output_dir} \
-    --output_dir ${data_dir} \
-    --validation_file ${validation_file} \
-    --source_column ${source_column} \
-    --max_length ${max_length} \
-    --do_eval \
-    --prediction_loss_only \
-    --cache_dir ${cache_dir} \
-    --preprocessing_num_workers 4 \
-    --per_device_eval_batch_size ${per_device_eval_batch_size}
-python lmloss2keywords.py --token_loss_file ${dump_eval_loss_to} --model_type ${model_type}
-
-model_type="gpt"
-data_dir="data/${task_name}/${dataset_name}/${model_type}"
-validation_file="${data_dir}/validation.json"
 model_name_or_path="gpt2-large"
-dump_eval_loss_to="${data_dir}/gpt2-large_${dataset_name}_token_loss.json"
-python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type gpt
-python ../run_clm.py \
-    --dump_eval_loss_to ${dump_eval_loss_to}\
-    --model_name_or_path ${model_name_or_path} \
-    --output_dir ${data_dir} \
-    --validation_file ${validation_file} \
-    --source_column ${source_column} \
-    --max_length ${max_length} \
-    --do_eval \
-    --prediction_loss_only \
-    --cache_dir ${cache_dir} \
-    --preprocessing_num_workers 4 \
-    --per_device_eval_batch_size ${per_device_eval_batch_size}
-python lmloss2keywords.py --token_loss_file ${dump_eval_loss_to} --model_type ${model_type}
+per_device_eval_batch_size=16
+
+python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type ${model_type}
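+# dump the token-level loss of ${model_name_or_path} on each data split, then convert it to word-level loss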
+for data_split in validation test train
+do
+    validation_file="${data_dir}/${data_split}.json"
+    dump_eval_loss_to="${data_dir}/${model_name_or_path}_${dataset_name}_${data_split}_token_loss.json"
+    python ../run_clm.py \
+        --dump_eval_loss_to ${dump_eval_loss_to} \
+        --model_name_or_path ${model_name_or_path} \
+        --output_dir ${data_dir} \
+        --validation_file ${validation_file} \
+        --source_column ${source_column} \
+        --max_length ${max_length} \
+        --do_eval \
+        --prediction_loss_only \
+        --cache_dir ${cache_dir} \
+        --preprocessing_num_workers 4 \
+        --per_device_eval_batch_size ${per_device_eval_batch_size}
+    python lmloss2keywords.py --token_loss_file ${dump_eval_loss_to} --model_type ${model_type}
+done
diff --git a/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
index 307d57edf4d09c8a72968f35051d451afe21bc64..b0e14c86f58baaca8af6e246cef4c58eddde6447 100644
--- a/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
+++ b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
@@ -80,8 +80,36 @@ def main(args):
 
     stop_words = set(stopwords.words('english'))
 
+    if args.keywords_th_ratio > 0:
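+        # threshold = the loss at the top keywords_th_ratio fraction of all non-stopword word losses;
+        # only words whose loss exceeds this value can become keywords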
+        losses = [loss for x in word_loss_list for word, loss in zip(x['words'], x['losses']) if not any([w.lower() in stop_words for w in word_tokenize(word)])]
+        loss_th = sorted(losses, reverse=True)[round(args.keywords_th_ratio*len(losses))]
+        print(f'loss th for top {args.keywords_th_ratio*100}%: {loss_th}')
+    else:
+        loss_th = 0
+
+    def keywords_filter(word_loss_pairs):
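+        # candidates: optionally skip stopwords and skip words whose loss does not exceed loss_th;
+        # keep at most min(keywords_ratio * #words, keywords_num) highest-loss candidates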
+        candidate_indexes = []
+        for i, word_loss_pair in enumerate(word_loss_pairs):
+            if args.stopwords and any([w.lower() in stop_words for w in word_tokenize(word_loss_pair[0])]):
+                continue
+            if word_loss_pair[1] <= loss_th:
+                continue
+            candidate_indexes.append(i)
+
+        topk = min(round(args.keywords_ratio*len(word_loss_pairs)), args.keywords_num)
+        topk_indexes = sorted(candidate_indexes, key=lambda x: word_loss_pairs[x][1], reverse=True)[:topk]
+        topk_indexes = sorted(topk_indexes)
+        keywords = []
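+        # merge runs of adjacent selected words into multi-word keywords, preserving utterance order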
+        for i, index in enumerate(topk_indexes):
+            if i > 0 and index == topk_indexes[i-1] + 1:
+                keywords[-1] += ' ' + word_loss_pairs[index][0]
+            else:
+                keywords.append(word_loss_pairs[index][0])
+
+        return keywords
+
     dialogs = []
-    for item in word_loss_list:
+    for item in tqdm(word_loss_list):
         words = item['words']
         losses = item['losses']
         turns = []
@@ -90,11 +118,9 @@ def main(args):
             if word == '<|endoftext|>':
                 # switch turn
                 turn['utterance'] = ' '.join(turn['words'])
-                turn['keywords'] = list(zip(turn['words'], turn['losses']))
-                if args.stopwords:
-                    turn['keywords'] = [x for x in turn['keywords'] if not any([w.lower() in stop_words for w in word_tokenize(x[0])])]
-                turn['keywords'] = sorted(turn['keywords'], key=lambda x: x[1], reverse=True)
-                turn['keywords'] = [x for x in turn['keywords'] if x[1] > args.keywords_th][:min(round(args.keywords_ratio*len(turn['keywords'])), args.keywords_num)]
+                keywords = keywords_filter(list(zip(turn['words'], turn['losses'])))
+                turn['keywords'] = keywords
+                # turn['keywords'] = ' | '.join(keywords)
                 turn.pop('words')
                 turn.pop('losses')
                 turns.append(turn)
@@ -116,7 +142,7 @@ if __name__ == '__main__':
     parser.add_argument('--output_file', '-o', type=str, help='path to the output file')
     parser.add_argument('--keywords_num', '-n', type=int, default=100, help='how many words in an utterance serve as keywords')
     parser.add_argument('--keywords_ratio', '-r', type=float, default=1.0, help='how many words (in ratio) in an utterance serve as keywords')
-    parser.add_argument('--keywords_th', '-th', type=float, default=0., help='loss threshold for the keywords')
+    parser.add_argument('--keywords_th_ratio', '-th', type=float, default=0., help='loss threshold for the keywords, ratio of all word losses')
     parser.add_argument('--stopwords', '-s', type=lambda x: bool(eval(x)), default=True, help='filter out stopwords')
     args = parser.parse_args()
     print(args)
diff --git a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9a274f8fe344efe3125920903bc688e8aeb7c38e
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
@@ -0,0 +1,43 @@
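+# generate responses on the multiwoz21 test split with the key2gen model pretrained on metalwoz+sgd+tm1+tm2+tm3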
+set -e
+n_gpus=1
+task_name="key2gen"
+dataset_name="multiwoz21"
+speaker="all"
+model_type="gpt"
+data_dir="data/${task_name}/${model_type}/${dataset_name}"
+output_dir="output/${task_name}/${model_type}/${dataset_name}"
+cache_dir="../../t5/cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+source_column="keywords+context"
+target_column="response"
+truncation_side="left"
+max_source_length=512
+max_target_length=128
+model_name_or_path="output/key2gen/gpt/metalwoz+sgd+tm1+tm2+tm3"
+per_device_train_batch_size=128
+per_device_eval_batch_size=128
+gradient_accumulation_steps=4
+lr=1e-3
+num_train_epochs=1
+
+python -m torch.distributed.launch \
+    --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_predict \
+    --predict_with_generate \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_eval_batch_size ${per_device_eval_batch_size}
diff --git a/convlab2/base_models/gpt/keyword_extraction/train_lm.sh b/convlab2/base_models/gpt/keyword_extraction/train_lm_dialogpt.sh
similarity index 97%
rename from convlab2/base_models/gpt/keyword_extraction/train_lm.sh
rename to convlab2/base_models/gpt/keyword_extraction/train_lm_dialogpt.sh
index 4ae47c3296e5ca7150cbbffcb7a7d247973613de..303ecb3e0c660a13e190b193c5b1769fbe70812d 100644
--- a/convlab2/base_models/gpt/keyword_extraction/train_lm.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/train_lm_dialogpt.sh
@@ -19,7 +19,7 @@ gradient_accumulation_steps=4
 lr=5e-5
 num_train_epochs=3
 
-python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type dialogpt
+python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type ${model_type}
 
 python ../run_clm.py \
     --model_name_or_path ${model_name_or_path} \
diff --git a/convlab2/base_models/gpt/keyword_extraction/train_lm_gpt.sh b/convlab2/base_models/gpt/keyword_extraction/train_lm_gpt.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fb510c880b25505e83780eeab76760e30dbccf9d
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/train_lm_gpt.sh
@@ -0,0 +1,47 @@
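+# fine-tune gpt2-large with a causal LM objective on the serialized multiwoz21 dialogues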
+set -e
+n_gpus=1
+task_name="lm"
+dataset_name="multiwoz21"
+model_type="gpt"
+data_dir="data/${task_name}/${dataset_name}/${model_type}"
+output_dir="output/${task_name}/${dataset_name}/${model_type}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+source_column="dialogue"
+max_length=512
+model_name_or_path="gpt2-large"
+per_device_train_batch_size=16
+per_device_eval_batch_size=16
+gradient_accumulation_steps=4
+lr=5e-5
+num_train_epochs=3
+
+python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type ${model_type}
+
+python ../run_clm.py \
+    --model_name_or_path ${model_name_or_path} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --max_length ${max_length} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --load_best_model_at_end \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --gradient_checkpointing
diff --git a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8bd5de0590addaed4c6f32feb7f7d7bd18acd23f
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
@@ -0,0 +1,57 @@
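+# pretrain t5-small on the combined key2gen corpus: generate the response from keywords + dialogue context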
+set -e
+n_gpus=2
+task_name="key2gen"
+dataset_name="metalwoz+sgd+tm1+tm2+tm3"
+speaker="all"
+model_type="gpt"
+data_dir="data/${task_name}/${model_type}/${dataset_name}"
+output_dir="output/${task_name}/${model_type}/${dataset_name}"
+cache_dir="../../t5/cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+source_column="keywords+context"
+target_column="response"
+truncation_side="left"
+max_source_length=512
+max_target_length=128
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=128
+gradient_accumulation_steps=4
+lr=1e-3
+num_train_epochs=1
+
+python -m torch.distributed.launch \
+    --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --do_predict \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --load_best_model_at_end \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing