diff --git a/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py b/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce8cd9f074ec3a60fb084c16eed6d1b41e53027c
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
@@ -0,0 +1,89 @@
+import json
+import datasets
+from tabulate import tabulate
+
+
+def _parse_keywords(prompt, prefix):
+    """Return the non-empty keywords listed after ``prefix`` in ``prompt``.
+
+    Only the text before the first blank line is read; keywords are split on ' | '.
+    """
+    keyword_part = prompt.split('\n\n')[0][len(prefix):]
+    return [k for k in keyword_part.split(' | ') if len(k) > 0]
+
+
+def main(predict_result):
+    """Score a prediction file and print one row of metrics per prompt type.
+
+    Each line of ``predict_result`` is a JSON object with the keys
+    "keywords+context", "predictions" and "response". Prompts that start with
+    "keywords" are treated as listing only positive keywords. Prompts that start
+    with "possible keywords" list positives mixed with other keywords; the record
+    does not mark which is which, so the positives of the most recent
+    "keywords" record are reused and everything else is counted as negative.
+    NOTE(review): this presumes each "possible keywords" record follows the
+    "keywords" record of the same sample -- confirm against the data builder.
+
+    Raises:
+        ValueError: if a "possible keywords" record appears before any
+            "keywords" record.
+    """
+    data = {
+        "keywords": {
+            "positive_keywords": [], "negative_keywords": None,
+            "predictions": [], "references": []
+        },
+        "possible keywords": {
+            "positive_keywords": [], "negative_keywords": [],
+            "predictions": [], "references": []
+        }
+    }
+    # Positives of the most recent "keywords" record; None until one is seen.
+    positive_keywords = None
+    with open(predict_result, encoding='utf-8') as f:
+        for line in f:
+            if not line.strip():
+                # json.loads rejects empty input; skip blank lines instead.
+                continue
+            item = json.loads(line)
+            prompt = item["keywords+context"]
+            if prompt.startswith("keywords"):
+                bucket = data["keywords"]
+                positive_keywords = _parse_keywords(prompt, "keywords: ")
+                bucket["positive_keywords"].append(positive_keywords)
+            elif prompt.startswith("possible keywords"):
+                if positive_keywords is None:
+                    raise ValueError(
+                        '"possible keywords" record found before any "keywords" record; '
+                        'cannot tell positive keywords from negative ones'
+                    )
+                bucket = data["possible keywords"]
+                negative_keywords = _parse_keywords(prompt, "possible keywords: ")
+                for keyword in positive_keywords:
+                    # list.remove raises on a missing element, so check first.
+                    if keyword in negative_keywords:
+                        negative_keywords.remove(keyword)
+                bucket["positive_keywords"].append(positive_keywords)
+                bucket["negative_keywords"].append(negative_keywords)
+            else:
+                # Records with any other prompt prefix are ignored.
+                continue
+            bucket["predictions"].append(item['predictions'].strip())
+            bucket["references"].append(item['response'].strip())
+    metric = datasets.load_metric('./key2gen_metric.py')
+    # Only score prompt types that actually occur in the file.
+    table = [
+        {'prompt': prompt_type, **metric.compute(**data[prompt_type])}
+        for prompt_type in ("keywords", "possible keywords")
+        if data[prompt_type]["predictions"]
+    ]
+    print(tabulate(table, headers='keys', tablefmt='github'))
+
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="evaluate keywords to response generation performance")
+    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json')
+    args = parser.parse_args()
+    print(args)
+    main(args.predict_result)
diff --git a/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py b/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..57eaa502ef11e35a316728646f00b1e671c7d89a
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
@@ -0,0 +1,120 @@
+# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""key2gen Metric"""
+
+import datasets
+import sacrebleu
+
+# Citation for sacreBLEU, which computes the corpus-level BLEU score below.
+_CITATION = """\
+@inproceedings{post-2018-call,
+    title = "A Call for Clarity in Reporting {BLEU} Scores",
+    author = "Post, Matt",
+    booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers",
+    month = oct,
+    year = "2018",
+    address = "Belgium, Brussels",
+    publisher = "Association for Computational Linguistics",
+    url = "https://www.aclweb.org/anthology/W18-6319",
+    pages = "186--191",
+}
+"""
+
+_DESCRIPTION = """\
+Metric to evaluate text-to-text models on the keywords grounded generation task.
+"""
+
+_KWARGS_DESCRIPTION = """
+Calculates corpus-bleu4, positive keywords recall, negative keywords recall
+Args:
+    positive_keywords: for each prediction, the list of keywords (list of string) in the ground truth reference
+    negative_keywords: for each prediction, the list of keywords (list of string) in the random sampled references
+    predictions: list of predictions to score. Each predictions
+        should be a string.
+    references: list of reference for each prediction. Each
+        reference should be a string.
+Returns:
+    bleu: corpus-bleu score
+    positive_keywords_recall: how many keywords in the ground truth response are generated, micro-averaged
+    negative_keywords_recall: how many keywords in the random sampled response are generated, micro-averaged
+Examples:
+
+    >>> key2gen_metric = datasets.load_metric("key2gen_metric.py")
+    >>> predictions = ["hello there general kenobi", "foo bar foobar"]
+    >>> references = ["hello there kenobi", "foo bar foobar"]
+    >>> positive_keywords = [["kenobi"], ["foobar", "baz"]]
+    >>> results = key2gen_metric.compute(predictions=predictions, references=references, positive_keywords=positive_keywords)
+    >>> print(sorted(results))
+    ['bleu', 'negative_keywords_recall', 'positive_keywords_recall']
+"""
+
+
+@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+class Key2GenMetrics(datasets.Metric):
+    """Metric to evaluate text-to-text models on the keywords grounded generation task."""
+
+    def _info(self):
+        return datasets.MetricInfo(
+            description=_DESCRIPTION,
+            citation=_CITATION,
+            inputs_description=_KWARGS_DESCRIPTION,
+            # This defines the format of each prediction and reference
+            features=datasets.Features({
+                'predictions': datasets.Value('string'),
+                'references': datasets.Value('string'),
+            })
+        )
+
+    def _compute(self, predictions, references, positive_keywords, negative_keywords=None):
+        """Return corpus BLEU plus micro-averaged positive/negative keyword recall.
+
+        Args:
+            predictions: generated responses, one string per sample.
+            references: reference responses, one string per sample.
+            positive_keywords: one list of keyword strings per sample.
+            negative_keywords: one list of keyword strings per sample; when
+                omitted or empty, every sample is treated as having none.
+
+        Returns:
+            dict with "bleu", "positive_keywords_recall" and
+            "negative_keywords_recall". A recall is 0 when there are no
+            keywords of that kind at all.
+
+        Raises:
+            ValueError: if a keyword list does not have one entry per prediction.
+        """
+        if not negative_keywords:
+            negative_keywords = [[] for _ in positive_keywords]
+        # zip() below would silently drop the unmatched tail and skew the recall.
+        if not len(positive_keywords) == len(negative_keywords) == len(predictions):
+            raise ValueError(
+                "positive_keywords and negative_keywords need one entry per prediction, "
+                f"got {len(positive_keywords)}, {len(negative_keywords)} and {len(predictions)}"
+            )
+        bleu = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score
+        cnt = {'pos': 0, 'neg': 0, 'pos_recall': 0, 'neg_recall': 0}
+        for poskeys, negkeys, prediction in zip(positive_keywords, negative_keywords, predictions):
+            cnt['pos'] += len(poskeys)
+            cnt['neg'] += len(negkeys)
+
+            # A keyword counts as generated on a case-insensitive substring match.
+            prediction = prediction.lower()
+            cnt['pos_recall'] += sum(key.lower() in prediction for key in poskeys)
+            cnt['neg_recall'] += sum(key.lower() in prediction for key in negkeys)
+
+        return {
+            "bleu": bleu,
+            "positive_keywords_recall": cnt['pos_recall']/cnt['pos'] if cnt['pos'] > 0 else 0,
+            "negative_keywords_recall": cnt['neg_recall']/cnt['neg'] if cnt['neg'] > 0 else 0,
+        }
diff --git a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
index 9a274f8fe344efe3125920903bc688e8aeb7c38e..469ec695ba681a835c7d9c51e95803c674c87d11 100644
--- a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
@@ -1,7 +1,7 @@
 set -e
-n_gpus=1
-task_name="key2gen"
-dataset_name="multiwoz21"
+n_gpus=2
+task_name="key2gen_shuffle_noisy"
+dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
 speaker="all"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
@@ -16,7 +16,7 @@ target_column="response"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
-model_name_or_path="output/key2gen/gpt/metalwoz+sgd+tm1+tm2+tm3"
+model_name_or_path="output/${task_name}/${model_type}/${dataset_name}"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
 gradient_accumulation_steps=4
@@ -40,4 +40,11 @@ python -m torch.distributed.launch \
     --logging_dir ${logging_dir} \
     --overwrite_output_dir \
     --preprocessing_num_workers 4 \
-    --per_device_eval_batch_size ${per_device_eval_batch_size}
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing
diff --git a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
index aa648116737f10909d1d83a0e9ec1ac0a5a682d2..d92365e787e2c58fdd8b6f4a4f870053c7561f2e 100644
--- a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
@@ -23,7 +23,7 @@ gradient_accumulation_steps=4
 lr=1e-3
 num_train_epochs=1
 
-python -m torch.distributed.launch --master_port 23456\
+python -m torch.distributed.launch \
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \