diff --git a/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py b/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
index 3c04ea8a29aebecf06dffac16fdf79fef41c763a..6b1068cef045550f57621fe0ab4aad8a4047cfbb 100644
--- a/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
+++ b/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
@@ -4,46 +4,51 @@ from tabulate import tabulate
 
 def main(predict_result):
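+    # evaluate each prompt type separately: grounded keywords, all keywords (with distractors), and no keywords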
     data = {
-        "keywords": {
+        "grounded keywords": {
             "positive_keywords": [], "negative_keywords": None,
             "predictions": [], "references": []
         },
-        "possible keywords": {
+        "all keywords": {
             "positive_keywords": [], "negative_keywords": [],
             "predictions": [], "references": []
+        },
+        "no keywords": {
+            "positive_keywords": None, "negative_keywords": None,
+            "predictions": [], "references": []
         }
     }
     with open(predict_result) as f:
         for line in f:
             item = json.loads(line)
-            if item["keywords+context"].startswith("keywords"):
-                data["keywords"]["predictions"].append(item['predictions'].strip())
-                data["keywords"]["references"].append(item['response'].strip())
-                positive_keywords = [k.strip() for k in item['keywords+context'].split('\n\n')[0][len("keywords: "):].split('|')[1].split(' : ') if len(k) > 0]
-                data["keywords"]["positive_keywords"].append(positive_keywords)
-            elif item["keywords+context"].startswith("possible keywords"):
-                data["possible keywords"]["predictions"].append(item['predictions'].strip())
-                data["possible keywords"]["references"].append(item['response'].strip())
-                possible_keywords = [k.strip() for ks in item['keywords+context'].split('\n\n')[0][len("possible keywords: "):].split('|') for k in ks.split(' : ') if len(k) > 0]
-                has_positive = True
+            prediction = item['predictions'].strip()
+            reference = item['target'].strip()
+            if 'all_keywords' in item and item['all_keywords']:
+                sample_type = 'all keywords'
+
+                positive_keywords = [k for g in item['keywords'] for k in g]
+                data[sample_type]["positive_keywords"].append(positive_keywords)
+
+                all_keywords = [k for g in item['all_keywords'] for k in g]
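+                # remove the grounded keywords so that only the distractor keywords remain as negatives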
                 for keyword in positive_keywords:
-                    if keyword in possible_keywords:
-                        possible_keywords.remove(keyword)
-                    else:
-                        has_positive = False
-                        break
-                if has_positive:
-                    data["possible keywords"]["positive_keywords"].append(positive_keywords)
-                else:
-                    data["possible keywords"]["positive_keywords"].append([])
-                data["possible keywords"]["negative_keywords"].append(possible_keywords)
-            # print(data)
-            # if len(data["possible keywords"]["positive_keywords"])>0:
-            #     break
+                    all_keywords.remove(keyword)
+                data[sample_type]["negative_keywords"].append(all_keywords)
+
+            elif 'keywords' in item and item['keywords']:
+                sample_type = 'grounded keywords'
+
+                positive_keywords = [k for g in item['keywords'] for k in g]
+                data[sample_type]["positive_keywords"].append(positive_keywords)
+
+            else:
+                sample_type = 'no keywords'
+
+            data[sample_type]["predictions"].append(prediction)
+            data[sample_type]["references"].append(reference)
+
     metric = datasets.load_metric('./key2gen_metric.py')
-    table = [{'prompt': "keywords", **metric.compute(**data["keywords"])}]
-    if len(data["possible keywords"]["predictions"]) > 0:
-        table.append({'prompt': "possible keywords", **metric.compute(**data["possible keywords"])})
+    table = []
+    for sample_type in data:
+        table.append({'sample_type': sample_type, **metric.compute(**data[sample_type])})
     print(tabulate(table, headers='keys', tablefmt='github'))
 
 
diff --git a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
index 0a1d63457a19c3ff42cd77a04c3d4505ad60a9bf..e7e40112c9ccaff28b98e220ae5e1e24bba4ebf6 100644
--- a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
+++ b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
@@ -12,7 +12,7 @@ def main(args):
         fout = open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}.json"), 'w', encoding='utf-8')
         for dial in tqdm(data):
             context = []
-            turn_keywords = [turn['keywords'] for turn in dial]
+            turns_keywords = [turn['keywords'] for turn in dial]
             for i, turn in enumerate(dial):
                 speaker = 'user' if i % 2 == 0 else 'system'
                 utt = turn['utterance']
@@ -27,21 +27,22 @@ def main(args):
                     continue
 
                 random.shuffle(turn['keywords'])
-                keywords = ' : '.join(turn['keywords'])
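+                # keywords are grouped by sentence; also shuffle the keywords within each group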
+                for j in range(len(turn['keywords'])):
+                    random.shuffle(turn['keywords'][j])
+                keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in turn['keywords']])
                 input_seq = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'
-                fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
+                fout.write(json.dumps({'source': input_seq, 'target': utt, 'keywords': turn['keywords']}, ensure_ascii=False)+'\n')
                 if args.mode == 'key2gen':
                     continue
 
-                possible_keywords_turns = [turn['keywords']]
-                num_possible_keywords_turns = min(random.randint(1, 5), len(turn_keywords) - 1)
-                possible_keywords_turns += random.sample(turn_keywords[:i] + turn_keywords[i+1:], num_possible_keywords_turns)
-                random.shuffle(possible_keywords_turns)
-                for possible_keywords_turn in possible_keywords_turns:
-                    random.shuffle(possible_keywords_turn)
-                possible_keywords = ' | '.join([' : '.join(possible_keywords_turn) for possible_keywords_turn in possible_keywords_turns])
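+                # noisy variant: mix this turn's keyword groups with groups sampled from other turns as distractors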
+                possible_keywords_sents = turn['keywords'][:]
+                num_possible_keywords_turns = min(random.randint(1, 5), len(turns_keywords) - 1)
+                for turn_keywords in random.sample(turns_keywords[:i] + turns_keywords[i+1:], num_possible_keywords_turns):
+                    possible_keywords_sents.extend(turn_keywords)
+                random.shuffle(possible_keywords_sents)
+                possible_keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in possible_keywords_sents])
                 input_seq = f'generate a response: all knowledge: | {possible_keywords} | context:\n\n{context_seq}'
-                fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
+                fout.write(json.dumps({'source': input_seq, 'target': utt, 'keywords': turn['keywords'], 'all_keywords': possible_keywords_sents}, ensure_ascii=False)+'\n')
                 if args.mode == 'key2gen_noisy':
                     continue
     
diff --git a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
index ca3017eaac6864a13c0ef055b589b3d177417518..f24058ecfa63c40c9100f03f061d64b58946796f 100644
--- a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
@@ -1,5 +1,5 @@
 task_name="key2gen_noisy"
-dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3"
 names=$(echo ${dataset_name} | tr "+" "\n")
 model_type="gpt"
 data_dir=data/${task_name}/${model_type}/${name}/${dataset_name}
diff --git a/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh b/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh
index cffa944b0374cff67abe223b4c2ea252ebd889f4..c060c9a03f7799ae59407c86c7c356788b3f529d 100644
--- a/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh
@@ -4,7 +4,7 @@ model_type="gpt"
 data_dir="data/${task_name}/${dataset_name}/${model_type}"
 model_name_or_path="gpt2-large"
 keywords_num=100
-keywords_ratio=0.3
+keywords_ratio=0.4
 keywords_th_ratio=0
 stopwords=True
 for data_split in validation test train
diff --git a/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py b/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
index 418f8d78af372b893e82f67273bc46010260cfc2..d9722d96ca71a961dc7ad837191fa202848111f3 100644
--- a/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
+++ b/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
@@ -69,25 +69,25 @@ class Key2GenMetrics(datasets.Metric):
 
     def _compute(self, predictions, references, positive_keywords, negative_keywords=None):
         """Returns the scores: bleu, positive_keywords_recall, negative_keywords_recall"""
-        # rouge-1/2/L bleu-1/2 distinct-1/2
-        if not negative_keywords:
-            negative_keywords = [[]] * len(positive_keywords)
         bleu = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score
         cnt = {'pos': 0, 'neg': 0, 'pos_recall': 0, 'neg_recall': 0}
-        for poskeys, negkeys, prediction in zip(positive_keywords, negative_keywords, predictions):
-            cnt['pos'] += len(poskeys)
-            cnt['neg'] += len(negkeys)
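+        # keyword recall is only computed when positive keywords are provided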
+        if positive_keywords:
+            if not negative_keywords:
+                negative_keywords = [[]] * len(positive_keywords)
+            for poskeys, negkeys, prediction in zip(positive_keywords, negative_keywords, predictions):
+                cnt['pos'] += len(poskeys)
+                cnt['neg'] += len(negkeys)
 
-            prediction = prediction.lower()
-            for key in poskeys:
-                key = key.lower()
-                if key in prediction:
-                    cnt['pos_recall'] += 1
-            
-            for key in negkeys:
-                key = key.lower()
-                if key in prediction:
-                    cnt['neg_recall'] += 1
+                prediction = prediction.lower()
+                for key in poskeys:
+                    key = key.lower()
+                    if key in prediction:
+                        cnt['pos_recall'] += 1
+
+                for key in negkeys:
+                    key = key.lower()
+                    if key in prediction:
+                        cnt['neg_recall'] += 1
             
         return {
             "bleu": bleu,
diff --git a/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
index 256c15952b2365c793205ae96783c3dc232ddb72..01581d2d34dee0b1f64b8ee65ab68375e6c7d5db 100644
--- a/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
+++ b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
@@ -5,7 +5,7 @@ import os
 from tqdm import tqdm
 import numpy as np
 from nltk.corpus import stopwords
-from nltk.tokenize import word_tokenize
+from nltk.tokenize import word_tokenize, PunktSentenceTokenizer
 from transformers import GPT2Tokenizer
 from string import punctuation
 
@@ -34,7 +34,7 @@ def merge_tokens(tokens, losses):
             # token = token.replace("Ġ", "")
             res.append([[token], [loss]])
         elif token == '<|endoftext|>':
-            res.append([[token], [loss]])
+            res.append([[token], [0.]])
         else:
             assert 'Ġ' not in token
             if len(res) > 0:
@@ -80,6 +80,7 @@ def main(args):
 
     stop_words = set(stopwords.words('english'))
     tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
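+    # Punkt sentence splitter, used to group extracted keywords by the sentence they appear in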
+    sent_tokenizer = PunktSentenceTokenizer()
 
     if args.keywords_th_ratio > 0:
         losses = [loss for x in word_loss_list for word, loss in zip(x['words'], x['losses']) if not any([w.lower() in stop_words for w in word_tokenize(word)])]
@@ -88,9 +89,19 @@ def main(args):
     else:
         loss_th = 0
 
-    def keywords_filter(word_loss_pairs):
+    def keywords_filter(words, losses):
+        word_loss_pairs = list(zip(words, losses))
         index2keyword = {}
+        index2turn_sent = {}
+        num_turns = 0
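+        # sentence character spans for each turn, computed over the reconstructed utterances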
+        turns_sent_spans = [list(sent_tokenizer.span_tokenize(utt)) for utt in ''.join(words).strip().split('<|endoftext|>')]
+        utt = ''
         for i, word_loss_pair in enumerate(word_loss_pairs):
+            if word_loss_pair[0].startswith('<|endoftext|>'):
+                num_turns += 1
+                utt = ''
+                continue
+            utt += word_loss_pair[0]
             words = word_tokenize(word_loss_pair[0])
             if args.stopwords and any([w.lower() in stop_words for w in words]):
                 # skip stopwords
@@ -104,42 +115,54 @@ def main(args):
                 # skip punctuation
                 continue
             index2keyword[i] = strip_punctuation
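+            # record which turn and sentence this candidate keyword belongs to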
+            for sent_idx, (sent_start, sent_end) in enumerate(turns_sent_spans[num_turns]):
+                if len(utt.strip()) <= sent_end:
+                    index2turn_sent[i] = (num_turns, sent_idx)
+                    break
         candidate_indexes = list(index2keyword.keys())
-        topk = min(round(args.keywords_ratio*len(word_loss_pairs)), args.keywords_num)
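+        # exclude the <|endoftext|> separators from the length when budgeting the number of keywords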
+        topk = min(round(args.keywords_ratio*(len(word_loss_pairs)-num_turns)), args.keywords_num)
         topk_indexes = sorted(candidate_indexes, key=lambda x: word_loss_pairs[x][1], reverse=True)[:topk]
         topk_indexes = sorted(topk_indexes)
         keywords = []
+        keywords_turn_sent2idx = {}
         for i, index in enumerate(topk_indexes):
             if i > 0 and index == topk_indexes[i-1] + 1 and \
                 word_loss_pairs[index][0].strip().startswith(index2keyword[index]) and \
                 word_loss_pairs[topk_indexes[i-1]][0].strip().endswith(index2keyword[topk_indexes[i-1]]):
                 keywords[-1]+= ' '+index2keyword[index]
             else:
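+                # map (turn index, sentence index) to the positions of its keywords in the keywords list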
+                keywords_turn_sent2idx.setdefault(index2turn_sent[index][0], {})
+                keywords_turn_sent2idx[index2turn_sent[index][0]].setdefault(index2turn_sent[index][1], [])
+                keywords_turn_sent2idx[index2turn_sent[index][0]][index2turn_sent[index][1]].append(len(keywords))
                 keywords.append(index2keyword[index])
 
-        return keywords
+        return keywords, keywords_turn_sent2idx
 
     dialogs = []
     for item in tqdm(word_loss_list):
-        words = item['words']
-        losses = item['losses']
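+        # detokenize each merged word and average its token-level losses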
+        words = [tokenizer.convert_tokens_to_string(tokens) for tokens in item['words']]
+        losses = [np.mean(loss) for loss in item['losses']]
+        dialog_keywords, keywords_turn_sent2idx = keywords_filter(words, losses)
         turns = []
         turn = {'words': [], 'losses': []}
-        for word, loss in zip(words, losses):
-            if word == ['<|endoftext|>']:
+        for i, (word, loss) in enumerate(zip(words, losses)):
+            if word != '<|endoftext|>':
+                turn['words'].append(word)
+                turn['losses'].append(loss)
+            if word == '<|endoftext|>' or i == len(words) - 1:
                 # switch turn
-                turn['words'] = [tokenizer.convert_tokens_to_string(tokens) for tokens in turn['words']]
-                turn['losses'] = [np.mean(losses) for losses in turn['losses']]
                 turn['utterance'] = ''.join(turn['words']).strip()
-                keywords = keywords_filter(list(zip(turn['words'], turn['losses'])))
-                turn['keywords'] = keywords
+                # option 1) extract keywords according to LM loss within the turn
+                # keywords, _ = keywords_filter(turn['words'], turn['losses'])
+                # turn['turn-level_keywords'] = keywords
+                # option 2) extract keywords according to LM loss over the whole dialog, and group them by sentence
+                turn['keywords'] = [[dialog_keywords[idx] for idx in k_idxes] for sent_idx, k_idxes in keywords_turn_sent2idx.get(len(turns), {}).items()]
                 turn.pop('words')
                 turn.pop('losses')
                 turns.append(turn)
                 turn = {'words': [], 'losses': []}
-            else:
-                turn['words'].append(word)
-                turn['losses'].append(loss)
+
         dialogs.append(turns)
     json.dump(dialogs, open(args.output_file, "w", encoding='utf-8'), indent=2, ensure_ascii=False)
 
diff --git a/convlab2/base_models/gpt/keyword_extraction/run.sh b/convlab2/base_models/gpt/keyword_extraction/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f35c2403ce21f9450d3d7a84dc8e7076ee6f5f89
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/run.sh
@@ -0,0 +1,5 @@
+set -e
+for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd multiwoz21
+do
+    bash get_keywords.sh ${dataset_name}
+done
\ No newline at end of file
diff --git a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
index 69d4c9980addceab7d01b4ba13e8069107a9555e..faaef560c20bd1a928f9c99503277780c4e8c26d 100644
--- a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
@@ -1,8 +1,8 @@
 set -e
-n_gpus=4
+n_gpus=2
 master_port=23457
 task_name="key2gen_noisy"
-dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
+dataset_name=$1
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
@@ -16,7 +16,7 @@ target_column="target"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
-model_name_or_path="output/${task_name}/${model_type}/dailydialog+metalwoz+sgd+tm1+tm2+tm3"
+model_name_or_path="output/${task_name}/${model_type}/dailydialog+metalwoz+tm1+tm2+tm3"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
 gradient_accumulation_steps=2
@@ -39,7 +39,7 @@ python -m torch.distributed.launch --master_port ${master_port} \
     --output_dir ${output_dir} \
     --logging_dir ${logging_dir} \
     --overwrite_output_dir \
-    --preprocessing_num_workers 4 \
+    --preprocessing_num_workers 16 \
     --per_device_train_batch_size ${per_device_train_batch_size} \
     --per_device_eval_batch_size ${per_device_eval_batch_size} \
     --gradient_accumulation_steps ${gradient_accumulation_steps} \
diff --git a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
index 8d9a4490b6e85039041e2ec05fef22128e99a18e..9a413f2dd91381cd0a05b13fdc79e5be588f8cc2 100644
--- a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
@@ -1,8 +1,8 @@
 set -e
-n_gpus=4
+n_gpus=2
 master_port=23457
 task_name="key2gen_noisy"
-dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
@@ -19,7 +19,7 @@ max_target_length=128
 model_name_or_path="t5-small"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
-gradient_accumulation_steps=2
+gradient_accumulation_steps=4
 lr=1e-3
 num_train_epochs=3
 
@@ -46,7 +46,7 @@ python -m torch.distributed.launch --master_port ${master_port} \
     --output_dir ${output_dir} \
     --logging_dir ${logging_dir} \
     --overwrite_output_dir \
-    --preprocessing_num_workers 4 \
+    --preprocessing_num_workers 16 \
     --per_device_train_batch_size ${per_device_train_batch_size} \
     --per_device_eval_batch_size ${per_device_eval_batch_size} \
     --gradient_accumulation_steps ${gradient_accumulation_steps} \