diff --git a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
index 0f9c841257387a293866b6d0727900d626c8047f..3a64b766256df0b733450a1f0f52df241ca3c6c6 100644
--- a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
+++ b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
@@ -26,17 +26,14 @@ def main(args):
                 context.append({'speaker': speaker, 'utt':utt})
                 fout.write(json.dumps({'keywords+context': input_seq, 'response': utt}, ensure_ascii=False)+'\n')
 
-                # min_neg = len(turn['keywords'])
-                # max_neg = 4 * min_neg
-                # negative_keywords = random.sample(keywords_set, random.randint(min_neg, max_neg))
-                # negative_keywords = random.sample(turn_keywords_set, 1)[0]
                 negative_keywords = turn_keywords[cnt]
                 cnt += 1
                 possible_keywords = turn['keywords'] + list(negative_keywords)
                 random.shuffle(possible_keywords)
                 possible_keywords = ' | '.join(possible_keywords)
                 input_seq = f'possible keywords: {possible_keywords}\n\ncontext: {context_seq}'
-                fout.write(json.dumps({'keywords+context': input_seq, 'response': utt}, ensure_ascii=False)+'\n')
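+                # when --noisy is set, also write a sample whose "possible keywords" mix in the sampled negative keywords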
+                if args.noisy:
+                    fout.write(json.dumps({'keywords+context': input_seq, 'response': utt}, ensure_ascii=False)+'\n')
     
 
 if __name__ == '__main__':
@@ -44,6 +41,7 @@ if __name__ == '__main__':
-    parser = ArgumentParser(description="calculate NLU metrics for unified datasets")
+    parser = ArgumentParser(description="generate keywords+context-to-response pretraining data")
     parser.add_argument('--input_dir', '-i', type=str, help='path to the input files')
     parser.add_argument('--output_dir', '-o', type=str, help='path to the output files')
+    parser.add_argument('--noisy', action='store_true', help='whether to add samples with noisy (negative) keywords')
     args = parser.parse_args()
     print(args)
     main(args)
diff --git a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
index c48f49b52947ff56b68710265d7aae79856c005a..00b8223fece4adf5a4b350771945631ad56d0f9e 100644
--- a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
@@ -1,7 +1,8 @@
-dataset_name="metalwoz+sgd+tm1+tm2+tm3"
+task_name="key2gen_shuffle_noisy"
+dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
 names=$(echo ${dataset_name} | tr "+" "\n")
 model_type="gpt"
-data_dir=data/key2gen_shuffle_noisy/${model_type}/${name}/${dataset_name}
+data_dir=data/${task_name}/${model_type}/${dataset_name}
 rm -r ${data_dir}
 mkdir -p ${data_dir}
 train_file="${data_dir}/train.json"
@@ -10,11 +11,11 @@ test_file="${data_dir}/test.json"
 for name in ${names}
 do
     echo "preprocessing ${name}"
-    python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/key2gen_shuffle_noisy/${model_type}/${name}
+    python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/${task_name}/${model_type}/${name} --noisy
     if [ "${name}" != "${dataset_name}" ]; then
-        cat "data/key2gen_shuffle_noisy/gpt/${name}/train.json" >> ${train_file}
-        cat "data/key2gen_shuffle_noisy/gpt/${name}/validation.json" >> ${validation_file}
-        cat "data/key2gen_shuffle_noisy/gpt/${name}/test.json" >> ${test_file}
+        cat "data/${task_name}/gpt/${name}/train.json" >> ${train_file}
+        cat "data/${task_name}/gpt/${name}/validation.json" >> ${validation_file}
+        cat "data/${task_name}/gpt/${name}/test.json" >> ${test_file}
     fi
 done
-python gen_pretraining_data.py -i data/lm/multiwoz21/${model_type} -o data/key2gen_shuffle_noisy/${model_type}/multiwoz21
\ No newline at end of file
+python gen_pretraining_data.py -i data/lm/multiwoz21/${model_type} -o data/${task_name}/${model_type}/multiwoz21 --noisy
\ No newline at end of file
diff --git a/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
index b0e14c86f58baaca8af6e246cef4c58eddde6447..04d743db32ad9ee4f6e88747fa8a02bd989fa35d 100644
--- a/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
+++ b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
@@ -6,17 +6,20 @@ from tqdm import tqdm
 import numpy as np
 from nltk.corpus import stopwords
 from nltk.tokenize import word_tokenize
+from transformers import GPT2Tokenizer
+from string import punctuation
 
 
-
-def merge_tokens(tokens, losses, loss_merge_func=np.mean):
+def merge_tokens(tokens, losses):
+    """Merge tokens into words"""
     res = []
     i = 0
     while i < len(tokens):
         token = tokens[i]
         loss = losses[i]
         if token in ['Ġ', 'Ċ']:
-            if token == 'Ċ' and i < len(tokens) - 1:
+            # "Ġ" means " ", "Ċ" means "\n"
+            if token == 'Ċ' and i < len(tokens) - 1 and not tokens[i+1].startswith('Ġ'):
                 tokens[i+1] = 'Ġ'+tokens[i+1]
             i += 1
             continue
@@ -28,26 +31,23 @@ def merge_tokens(tokens, losses, loss_merge_func=np.mean):
                 i += 2
             continue
         if token.startswith('Ġ'):
-            # Ġ means space
-            token = token.replace("Ġ", "")
-            res.append([token, loss])
+            # keep the leading 'Ġ' so the tokenizer can restore spacing when detokenizing later
+            res.append([[token], [loss]])
         elif token == '<|endoftext|>':
-            res.append([token, loss])
+            res.append([[token], [loss]])
         else:
             assert 'Ġ' not in token
             if len(res) > 0:
-                res[-1][0] += token
-                res[-1].append(loss)
+                res[-1][0].append(token)
+                res[-1][1].append(loss)
             else:
-                res.append([token, loss])
+                res.append([[token], [loss]])
         i += 1
-    if loss_merge_func:
-        for i in range(len(res)):
-            res[i] = [res[i][0], loss_merge_func(res[i][1:])]
     return res
 
 
-def convert_token_loss2word_loss(token_loss_file, loss_merge_func=np.mean):
+def convert_token_loss2word_loss(token_loss_file):
+    """generate a word loss file according to the token loss file"""
     word_loss_file = os.path.join(os.path.dirname(token_loss_file), token_loss_file.split('/')[-1].replace('token', 'word'))
     fin = open(token_loss_file, 'rb')
     fout = open(word_loss_file, 'w', encoding='utf-8')
@@ -56,7 +56,7 @@ def convert_token_loss2word_loss(token_loss_file, loss_merge_func=np.mean):
     for item in tqdm(json_lines.reader(fin)):
         tokens, losses = item['tokens'], item['losses']
         assert len(tokens) == len(losses)
-        word2losses = merge_tokens(tokens, losses, loss_merge_func)
+        word2losses = merge_tokens(tokens, losses)
         lines.append({"words": [x[0] for x in word2losses], "losses": [x[1] for x in word2losses]})
         fout.write(json.dumps(lines[-1], ensure_ascii=False)+'\n')
 
@@ -79,6 +79,7 @@ def main(args):
         return
 
     stop_words = set(stopwords.words('english'))
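+    # GPT-2 tokenizer is used below to turn grouped sub-word tokens back into words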
+    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
 
     if args.keywords_th_ratio > 0:
         losses = [loss for x in word_loss_list for word, loss in zip(x['words'], x['losses']) if not any([w.lower() in stop_words for w in word_tokenize(word)])]
@@ -88,23 +89,33 @@ def main(args):
         loss_th = 0
 
     def keywords_filter(word_loss_pairs):
-        candidate_indexes = []
+        index2keyword = {}
         for i, word_loss_pair in enumerate(word_loss_pairs):
-            if args.stopwords and any([w.lower() in stop_words for w in word_tokenize(word_loss_pair[0])]):
+            words = word_tokenize(word_loss_pair[0])
+            if args.stopwords and any([w.lower() in stop_words for w in words]):
+                # skip stopwords
                 continue
             if word_loss_pair[1] <= loss_th:
+                # skip if loss is too small
                 continue
-            candidate_indexes.append(i)
-
+            # strip punctuation
+            strip_punctuation = word_loss_pair[0].strip(punctuation).strip()
+            if len(strip_punctuation) == 0:
+                # skip punctuation
+                continue
+            index2keyword[i] = strip_punctuation
+        candidate_indexes = list(index2keyword.keys())
         topk = min(round(args.keywords_ratio*len(word_loss_pairs)), args.keywords_num)
         topk_indexes = sorted(candidate_indexes, key=lambda x: word_loss_pairs[x][1], reverse=True)[:topk]
         topk_indexes = sorted(topk_indexes)
         keywords = []
         for i, index in enumerate(topk_indexes):
-            if i > 0 and index == topk_indexes[i-1] + 1:
-                keywords[-1]+= ' '+word_loss_pairs[index][0]
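+            # join consecutive keywords into one phrase only when no punctuation was stripped at their shared boundary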
+            if i > 0 and index == topk_indexes[i-1] + 1 and \
+                word_loss_pairs[index][0].strip().startswith(index2keyword[index]) and \
+                word_loss_pairs[topk_indexes[i-1]][0].strip().endswith(index2keyword[topk_indexes[i-1]]):
+                keywords[-1]+= ' '+index2keyword[index]
             else:
-                keywords.append(word_loss_pairs[index][0])
+                keywords.append(index2keyword[index])
 
         return keywords
 
@@ -115,12 +126,13 @@ def main(args):
         turns = []
         turn = {'words': [], 'losses': []}
         for word, loss in zip(words, losses):
-            if word == '<|endoftext|>':
+            if word == ['<|endoftext|>']:
                 # switch turn
-                turn['utterance'] = ' '.join(turn['words'])
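+                # each entry of turn['words'] is a group of sub-word tokens: detokenize it and average its token losses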
+                turn['words'] = [tokenizer.convert_tokens_to_string(tokens) for tokens in turn['words']]
+                turn['losses'] = [np.mean(losses) for losses in turn['losses']]
+                turn['utterance'] = ''.join(turn['words']).strip()
                 keywords = keywords_filter(list(zip(turn['words'], turn['losses'])))
                 turn['keywords'] = keywords
-                # turn['keywords'] = ' | '.join([x[0] for x in keywords])
                 turn.pop('words')
                 turn.pop('losses')
                 turns.append(turn)
diff --git a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
index c04d68fc374c38eb27b78f0ac288d04470e98d05..aa648116737f10909d1d83a0e9ec1ac0a5a682d2 100644
--- a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
@@ -1,7 +1,7 @@
 set -e
-n_gpus=1
+n_gpus=2
 task_name="key2gen_shuffle_noisy"
-dataset_name="metalwoz+sgd+tm1+tm2+tm3"
+dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
 speaker="all"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
@@ -19,11 +19,11 @@ max_target_length=128
 model_name_or_path="t5-small"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
-gradient_accumulation_steps=8
+gradient_accumulation_steps=4
 lr=1e-3
 num_train_epochs=1
 
-python -m torch.distributed.launch \
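+# set an explicit master port so this launch does not collide with other torch.distributed jobs using the default port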
+python -m torch.distributed.launch --master_port 23456 \
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \