diff --git a/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py b/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
index ce8cd9f074ec3a60fb084c16eed6d1b41e53027c..3c04ea8a29aebecf06dffac16fdf79fef41c763a 100644
--- a/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
+++ b/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
@@ -19,16 +19,27 @@ def main(predict_result):
             if item["keywords+context"].startswith("keywords"):
                 data["keywords"]["predictions"].append(item['predictions'].strip())
                 data["keywords"]["references"].append(item['response'].strip())
-                positive_keywords = [k for k in item['keywords+context'].split('\n\n')[0][len("keywords: "):].split(' | ') if len(k) > 0]
+                positive_keywords = [k.strip() for k in item['keywords+context'].split('\n\n')[0][len("keywords: "):].split('|')[1].split(' : ') if len(k) > 0]
                 data["keywords"]["positive_keywords"].append(positive_keywords)
             elif item["keywords+context"].startswith("possible keywords"):
                 data["possible keywords"]["predictions"].append(item['predictions'].strip())
                 data["possible keywords"]["references"].append(item['response'].strip())
-                possible_keywords = [k for k in item['keywords+context'].split('\n\n')[0][len("possible keywords: "):].split(' | ') if len(k) > 0]
+                possible_keywords = [k.strip() for ks in item['keywords+context'].split('\n\n')[0][len("possible keywords: "):].split('|') for k in ks.split(' : ') if len(k) > 0]
+                has_positive = True
                 for keyword in positive_keywords:
-                    possible_keywords.remove(keyword)
-                data["possible keywords"]["positive_keywords"].append(positive_keywords)
+                    if keyword in possible_keywords:
+                        possible_keywords.remove(keyword)
+                    else:
+                        has_positive = False
+                        break
+                if has_positive:
+                    data["possible keywords"]["positive_keywords"].append(positive_keywords)
+                else:
+                    data["possible keywords"]["positive_keywords"].append([])
                 data["possible keywords"]["negative_keywords"].append(possible_keywords)
+            # print(data)
+            # if len(data["possible keywords"]["positive_keywords"])>0:
+            #     break
     metric = datasets.load_metric('./key2gen_metric.py')
     table = [{'prompt': "keywords", **metric.compute(**data["keywords"])}]
     if len(data["possible keywords"]["predictions"]) > 0:
diff --git a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
index 3a64b766256df0b733450a1f0f52df241ca3c6c6..0a1d63457a19c3ff42cd77a04c3d4505ad60a9bf 100644
--- a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
+++ b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
@@ -4,36 +4,46 @@ import random
 from tqdm import tqdm
 
 def main(args):
-    random.seed(45)
+    random.seed(42)
     os.makedirs(args.output_dir, exist_ok=True)
     filenames = [f for (_, _, fs) in os.walk(args.input_dir) for f in fs if 'keywords' in f]
     for filename in filenames:
         data = json.load(open(os.path.join(args.input_dir, filename)))
         fout = open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}.json"), 'w', encoding='utf-8')
-        turn_keywords = [turn['keywords'] for dial in data for turn in dial]
-        random.shuffle(turn_keywords)
-        cnt = 0
-        # keywords_set = {keyword for keywords in turn_keywords_set for keyword in keywords}
         for dial in tqdm(data):
             context = []
+            turn_keywords = [turn['keywords'] for turn in dial]
             for i, turn in enumerate(dial):
-                speaker = 'user' if i%2 == 0 else 'system'
-                random.shuffle(turn['keywords'])
-                keywords = ' | '.join(turn['keywords'])
+                speaker = 'user' if i % 2 == 0 else 'system'
                 utt = turn['utterance']
                 context_seq = '\n'.join([f"{turn['speaker']}: {turn['utt']}" for turn in context]+[f'{speaker}: '])
-                input_seq = f'keywords: {keywords}\n\ncontext: {context_seq}'
-                context.append({'speaker': speaker, 'utt':utt})
-                fout.write(json.dumps({'keywords+context': input_seq, 'response': utt}, ensure_ascii=False)+'\n')
+                context.append({'speaker': speaker, 'utt': utt})
+                if i == 0:
+                    continue
+                
+                input_seq = f'generate a response: context:\n\n{context_seq}'
+                fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
+                if args.mode == 'rg':
+                    continue
+
+                random.shuffle(turn['keywords'])
+                keywords = ' : '.join(turn['keywords'])
+                input_seq = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'
+                fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
+                if args.mode == 'key2gen':
+                    continue
 
-                negative_keywords = turn_keywords[cnt]
-                cnt += 1
-                possible_keywords = turn['keywords'] + list(negative_keywords)
-                random.shuffle(possible_keywords)
-                possible_keywords = ' | '.join(possible_keywords)
-                input_seq = f'possible keywords: {possible_keywords}\n\ncontext: {context_seq}'
-                if args.noisy:
-                    fout.write(json.dumps({'keywords+context': input_seq, 'response': utt}, ensure_ascii=False)+'\n')
+                possible_keywords_turns = [turn['keywords']]
+                num_possible_keywords_turns = min(random.randint(1, 5), len(turn_keywords) - 1)
+                possible_keywords_turns += random.sample(turn_keywords[:i] + turn_keywords[i+1:], num_possible_keywords_turns)
+                random.shuffle(possible_keywords_turns)
+                for possible_keywords_turn in possible_keywords_turns:
+                    random.shuffle(possible_keywords_turn)
+                possible_keywords = ' | '.join([' : '.join(possible_keywords_turn) for possible_keywords_turn in possible_keywords_turns])
+                input_seq = f'generate a response: all knowledge: | {possible_keywords} | context:\n\n{context_seq}'
+                fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
+                if args.mode == 'key2gen_noisy':
+                    continue
     
 
 if __name__ == '__main__':
@@ -41,7 +51,7 @@ if __name__ == '__main__':
     parser = ArgumentParser(description="calculate NLU metrics for unified datasets")
     parser.add_argument('--input_dir', '-i', type=str, help='path to the input files')
     parser.add_argument('--output_dir', '-o', type=str, help='path to the output files')
-    parser.add_argument('--noisy', action='store_true', help='whether add noisy keywords samples')
+    parser.add_argument('--mode', '-m', type=str, choices=['rg', 'key2gen', 'key2gen_noisy'], help='which task to perform')
     args = parser.parse_args()
     print(args)
     main(args)
diff --git a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
index 00b8223fece4adf5a4b350771945631ad56d0f9e..ca3017eaac6864a13c0ef055b589b3d177417518 100644
--- a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
@@ -1,4 +1,4 @@
-task_name="key2gen_shuffle_noisy"
+task_name="key2gen_noisy"
 dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
 names=$(echo ${dataset_name} | tr "+" "\n")
 model_type="gpt"
@@ -11,11 +11,11 @@ test_file="${data_dir}/test.json"
 for name in ${names}
 do
     echo "preprocessing ${name}"
-    python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/${task_name}/${model_type}/${name} --noisy
+    python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/${task_name}/${model_type}/${name} -m ${task_name}
     if [ "${name}" != "${dataset_name}" ]; then
         cat "data/${task_name}/gpt/${name}/train.json" >> ${train_file}
         cat "data/${task_name}/gpt/${name}/validation.json" >> ${validation_file}
         cat "data/${task_name}/gpt/${name}/test.json" >> ${test_file}
     fi
 done
-python gen_pretraining_data.py -i data/lm/multiwoz21/${model_type} -o data/${task_name}/${model_type}/multiwoz21 --noisy
\ No newline at end of file
+python gen_pretraining_data.py -i data/lm/multiwoz21/${model_type} -o data/${task_name}/${model_type}/multiwoz21 -m ${task_name}
\ No newline at end of file
diff --git a/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py b/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
index 57eaa502ef11e35a316728646f00b1e671c7d89a..418f8d78af372b893e82f67273bc46010260cfc2 100644
--- a/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
+++ b/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
@@ -48,14 +48,6 @@ Returns:
     bleu: corpus-bleu score
     positive_keywords_recall: how many keywords in the ground truth response are generated, micro-averaged
     negative_keywords_recall: how many keywords in the random sampled response are generated, micro-averaged
-Examples:
-
-    >>> key2gen_metric = datasets.load_metric("key2gen_metric.py")
-    >>> predictions = ["hello there general kenobi", "foo bar foobar"]
-    >>> references = ["hello there kenobi", "foo bar foobar"]
-    >>> results = nlg_metric.compute(predictions=predictions, references=references)
-    >>> print(results)
-    {'bleu': 35.35533905932737}
 """
 
 
@@ -77,6 +69,7 @@ class Key2GenMetrics(datasets.Metric):
 
     def _compute(self, predictions, references, positive_keywords, negative_keywords=None):
         """Returns the scores: bleu, positive_keywords_recall, negative_keywords_recall"""
+        # rouge-1/2/L bleu-1/2 distinct-1/2
         if not negative_keywords:
             negative_keywords = [[]] * len(positive_keywords)
         bleu = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score
diff --git a/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
index 04d743db32ad9ee4f6e88747fa8a02bd989fa35d..256c15952b2365c793205ae96783c3dc232ddb72 100644
--- a/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
+++ b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
@@ -41,7 +41,7 @@ def merge_tokens(tokens, losses):
                 res[-1][0].append(token)
                 res[-1][1].append(loss)
             else:
-                res.append([token, loss])
+                res.append([[token], [loss]])
         i += 1
     return res
 
diff --git a/convlab2/base_models/gpt/keyword_extraction/merge_keywords_res.py b/convlab2/base_models/gpt/keyword_extraction/merge_keywords_res.py
index a8ebd5ba0b623cae37521bfe93fe046d8cd0c53e..94af288a38845f1cf72470da4af916be5a6f0dda 100644
--- a/convlab2/base_models/gpt/keyword_extraction/merge_keywords_res.py
+++ b/convlab2/base_models/gpt/keyword_extraction/merge_keywords_res.py
@@ -6,9 +6,9 @@ def main(args):
     dialogs = []
     for i in range(len(filename2data[first_filename])):
         turns = []
-        for j in range(len(filename2data[first_filename][i])):
+        for j in range(min([len(filename2data[filename][i]) for filename in filename2data])):
             utt = filename2data[first_filename][i][j]['utterance']
-            keywords = {filename.split('_')[2]+'_nonstopword'+filename.split('_')[-1]: ' | '.join([x[0] for x in filename2data[filename][i][j]['keywords']]) for filename in filename2data}
+            keywords = {filename.split('_')[3]+'_nonstopword'+filename.split('_')[-1]: ' | '.join(filename2data[filename][i][j]['keywords']) for filename in filename2data}
             turns.append({
                 "utterance": utt,
                 **keywords
diff --git a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
index ac204b5d564fe0acb5fb2ac1b49d4d1d6bcad17d..69d4c9980addceab7d01b4ba13e8069107a9555e 100644
--- a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
@@ -1,29 +1,29 @@
 set -e
-n_gpus=2
-task_name="key2gen_shuffle_noisy"
+n_gpus=4
+master_port=23457
+task_name="key2gen_noisy"
 dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
-speaker="all"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
-cache_dir="../../t5/cache"
+cache_dir="../cache"
 logging_dir="${output_dir}/runs"
 train_file="${data_dir}/train.json"
 validation_file="${data_dir}/validation.json"
 test_file="${data_dir}/test.json"
-source_column="keywords+context"
-target_column="response"
+source_column="source"
+target_column="target"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
-model_name_or_path="output/${task_name}/${model_type}/${dataset_name}"
+model_name_or_path="output/${task_name}/${model_type}/dailydialog+metalwoz+sgd+tm1+tm2+tm3"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
-gradient_accumulation_steps=4
+gradient_accumulation_steps=2
 lr=1e-3
-num_train_epochs=1
+num_train_epochs=3
 
-python -m torch.distributed.launch \
+python -m torch.distributed.launch --master_port ${master_port} \
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
     --test_file ${test_file} \
diff --git a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
index 2c795ecf58e331e2acbe8ada66b4cf057ed83037..8d9a4490b6e85039041e2ec05fef22128e99a18e 100644
--- a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
@@ -1,29 +1,29 @@
 set -e
-n_gpus=2
-task_name="key2gen_shuffle_noisy"
+n_gpus=4
+master_port=23457
+task_name="key2gen_noisy"
 dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
-speaker="all"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
-cache_dir="../../t5/cache"
+cache_dir="../cache"
 logging_dir="${output_dir}/runs"
 train_file="${data_dir}/train.json"
 validation_file="${data_dir}/validation.json"
 test_file="${data_dir}/test.json"
-source_column="keywords+context"
-target_column="response"
+source_column="source"
+target_column="target"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
 model_name_or_path="t5-small"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
-gradient_accumulation_steps=4
+gradient_accumulation_steps=2
 lr=1e-3
-num_train_epochs=1
+num_train_epochs=3
 
-python -m torch.distributed.launch \
+python -m torch.distributed.launch --master_port ${master_port} \
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
diff --git a/convlab2/base_models/gpt/keyword_extraction/train_t5_rg.sh b/convlab2/base_models/gpt/keyword_extraction/train_t5_rg.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6e697e84545c30657f3b32a8a9d44fa94231c01d
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/train_t5_rg.sh
@@ -0,0 +1,56 @@
+set -e
+n_gpus=2
+master_port=23456
+task_name="rg"
+dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
+model_type="gpt"
+data_dir="data/${task_name}/${model_type}/${dataset_name}"
+output_dir="output/${task_name}/${model_type}/${dataset_name}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+source_column="source"
+target_column="target"
+truncation_side="left"
+max_source_length=512
+max_target_length=128
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=128
+gradient_accumulation_steps=4
+lr=1e-3
+num_train_epochs=3
+
+python -m torch.distributed.launch --master_port ${master_port} \
+    --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --do_predict \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --load_best_model_at_end \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py
index c702897d5c2d19d164ae00ee058718ff0dc0be96..8ce0f8d7f305b13b10ff0f5b899094fc2a4c96df 100644
--- a/convlab2/base_models/t5/run_seq2seq.py
+++ b/convlab2/base_models/t5/run_seq2seq.py
@@ -445,7 +445,7 @@ def main():
                 inputs.append(examples[source_column][i])
                 targets.append(examples[target_column][i])
 
-        inputs = [prefix + '\n\n' + inp for inp in inputs]
+        inputs = [prefix + inp for inp in inputs]
         if padding:
             model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
         else:
@@ -566,7 +566,8 @@ def main():
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
     )
     if training_args.load_best_model_at_end:
-        trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=data_args.early_stopping_patience))
+        if data_args.early_stopping_patience > 0:
+            trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=data_args.early_stopping_patience))
 
     # Training
     if training_args.do_train: