diff --git a/convlab/base_models/t5/__init__.py b/convlab/base_models/t5/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/convlab/base_models/t5/create_data.py b/convlab/base_models/t5/create_data.py
index 75c282848fa30deca8ca66ecd176df35d8588ed0..a1b10adb2ef0440868f4575e2ce2ce76508f13c4 100644
--- a/convlab/base_models/t5/create_data.py
+++ b/convlab/base_models/t5/create_data.py
@@ -3,7 +3,7 @@ import json
 from tqdm import tqdm
 import re
 from transformers import AutoTokenizer
-from convlab.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
+from convlab.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data, retrieve_utterances
 from convlab.base_models.t5.nlu.serialization import serialize_dialogue_acts, deserialize_dialogue_acts, equal_da_seq
 from convlab.base_models.t5.dst.serialization import serialize_dialogue_state, deserialize_dialogue_state, equal_state_seq
 
@@ -120,6 +120,59 @@ def create_goal2dialogue_data(dataset, data_dir, args):
         data_by_split[data_split] = data
     return data_by_split
 
+def create_retnlu_data(dataset, data_dir, args):
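+    """Create data for retrieval-augmented NLU: each sample is paired with the top-k most similar turns retrieved from other datasets."""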
+    dataset_name = dataset[list(dataset.keys())[0]][0]['dataset']
+    data_by_split = load_nlu_data(dataset, speaker=args.speaker, use_context=args.context_window_size>0, context_window_size=args.context_window_size)
+    data_dir = os.path.join(data_dir, args.speaker, f'context_{args.context_window_size}', \
+        f'in_context_{args.retrieval_in_context}', f'topk_{args.retrieval_topk}')
+    os.makedirs(data_dir, exist_ok=True)
+
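+    # build the retrieval pool from the train splits of the auxiliary datasets,
+    # keeping only turns with at least one dialogue act annotation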
+    turn_pool = []
+    for d in args.retrieval_datasets:
+        pool_dataset = load_dataset(d)
+        for turn in load_nlu_data(pool_dataset, data_split='train', speaker=args.speaker)['train']:
+            if any([len(das) > 0 for da_type, das in turn['dialogue_acts'].items()]):
+                turn_pool.append({'dataset': d, **turn})
+
+    data_splits = data_by_split.keys()
+    query_turns = []
+    for data_split in data_splits:
+        query_turns.extend(data_by_split[data_split])
+    augmented_dataset = retrieve_utterances(query_turns, turn_pool, args.retrieval_topk, 'all-MiniLM-L6-v2')
+
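+    # `augmented_dataset` is a flat list over all splits; `i` tracks each split's offset into it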
+    i = 0
+    for data_split in data_splits:
+        data = []
+        for j in tqdm(range(len(data_by_split[data_split])), desc=f'{data_split} sample', leave=False):
+            sample = augmented_dataset[i+j]
+            response = f"{sample['speaker']}: {sample['utterance']}"
+            if args.context_window_size>0:
+                context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[response])
+            else:
+                context = response
+            context = ' '.join([dataset_name, context])
+            dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts'])
+            assert equal_da_seq(sample['dialogue_acts'], dialogue_acts_seq), print(sample['dialogue_acts'], dialogue_acts_seq, deserialize_dialogue_acts(dialogue_acts_seq))
+
+            retrieved_turns = sample['retrieved_turns']
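+            # two uses of retrieval: prepend (utterance => acts) exemplars to the input for
+            # in-context learning, or add retrieved turns as extra training samples (never for test)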
+            for t in retrieved_turns:
+                # in-context learning
+                retrieved_utterance = f"{t['dataset']} {t['speaker']}: {t['utterance']}"
+                retrieved_dialogue_acts_seq = serialize_dialogue_acts(t['dialogue_acts'])
+                if args.retrieval_in_context:
+                    context = f"{retrieved_utterance} => {retrieved_dialogue_acts_seq}\n\n" + context
+                elif data_split != 'test':
+                    data.append(json.dumps({'context': retrieved_utterance, 'dialogue_acts_seq': retrieved_dialogue_acts_seq}, ensure_ascii=False)+'\n')        
+
+            data.append(json.dumps({'context': context, 'dialogue_acts_seq': dialogue_acts_seq}, ensure_ascii=False)+'\n')
+        i += len(data_by_split[data_split])
+
+        file_name = os.path.join(data_dir, f"{data_split}.json")
+        with open(file_name, "w", encoding='utf-8') as f:
+            f.writelines(data)
+        data_by_split[data_split] = data
+    return data_by_split
+
 def get_max_len(data_by_split, tokenizer):
     for data_split in data_by_split.keys():
         seq_len = {}
@@ -136,13 +189,16 @@ def get_max_len(data_by_split, tokenizer):
 if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser(description="create data for seq2seq training")
-    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'dst', 'nlg', 'goal2dialogue'], help='names of tasks')
+    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'dst', 'nlg', 'goal2dialogue', 'retnlu', 'retnlg'], help='names of tasks')
     parser.add_argument('--datasets', '-d', metavar='dataset_name', nargs='*', help='names of unified datasets')
     parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s)')
     parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
     parser.add_argument('--len_tokenizer', '-l', type=str, default=None, help='name or path of tokenizer that used to get seq len')
     parser.add_argument('--ratio', '-r', type=float, default=None, help='how many data is used for training and evaluation')
     parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments')
+    parser.add_argument('--retrieval_datasets', metavar='retrieval_dataset_name', nargs='*', help='names of unified datasets used as the retrieval pool')
+    parser.add_argument('--retrieval_topk', type=int, default=3, help='how many utterances to retrieve')
+    parser.add_argument('--retrieval_in_context', action='store_true', default=False, help='whether to use the retrieved utterances via in-context learning')
     args = parser.parse_args()
     print(args)
     if args.len_tokenizer:
diff --git a/convlab/base_models/t5/dst/merge_data.py b/convlab/base_models/t5/dst/merge_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b76cdcdbf84e794abefc3a67227a6da4f133aee
--- /dev/null
+++ b/convlab/base_models/t5/dst/merge_data.py
@@ -0,0 +1,21 @@
+import json
+import os
+import sys
+
+if __name__ == '__main__':
+    merged_data = {'train': [], 'validation': [], 'test': []}
+    print(sys.argv)
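+    # usage: python merge_data.py <dataset> ...  (e.g. python merge_data.py sgd tm1 tm2 tm3 multiwoz21)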
+    for dataset_name in sys.argv[1:]:
+        data_dir = os.path.join('data/dst', dataset_name, 'user/context_100')
+        for data_split in merged_data:
+            with open(os.path.join(data_dir, f'{data_split}.json'), 'r') as f:
+                for line in f:
+                    item = json.loads(line)
+                    item['context'] = f"{dataset_name}: {item['context']}"
+                    merged_data[data_split].append(item)
+    for data_split in merged_data:
+        data_dir = os.path.join('data/dst', '+'.join(sys.argv[1:]), 'user/context_100')
+        os.makedirs(data_dir, exist_ok=True)
+        with open(os.path.join(data_dir, f'{data_split}.json'), 'w') as f:
+            for item in merged_data[data_split]:
+                f.write(json.dumps(item)+'\n')
diff --git a/convlab/base_models/t5/dst/merge_predict_res.py b/convlab/base_models/t5/dst/merge_predict_res.py
index 6d21d07c8091350f19ed45668c5a02b7c485d479..f25279a87a8404faab03523a072782c1a08b738c 100755
--- a/convlab/base_models/t5/dst/merge_predict_res.py
+++ b/convlab/base_models/t5/dst/merge_predict_res.py
@@ -4,10 +4,8 @@ from convlab.util import load_dataset, load_dst_data
 from convlab.base_models.t5.dst.serialization import deserialize_dialogue_state
 
 
-def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
+def merge(dataset_names, speaker, save_dir, context_window_size, predict_result):
     assert os.path.exists(predict_result)
-    dataset = load_dataset(dataset_name, args.dial_ids_order)
-    data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
     
     if save_dir is None:
         save_dir = os.path.dirname(predict_result)
@@ -15,10 +13,19 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
         os.makedirs(save_dir, exist_ok=True)
     predict_result = [deserialize_dialogue_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
 
-    for sample, prediction in zip(data, predict_result):
-        sample['predictions'] = {'state': prediction}
+    merged = []
+    i = 0
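+    # predictions were generated on the '+'-concatenated test sets, so walk the datasets
+    # in the same order and align samples via the running index i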
+    for dataset_name in dataset_names.split('+'):
+        print(dataset_name)
+        dataset = load_dataset(dataset_name, args.dial_ids_order)
+        data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
+    
+        for sample in data:
+            sample['predictions'] = {'state': predict_result[i]}
+            i += 1
+            merged.append(sample)
 
-    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
 
 
 if __name__ == '__main__':
diff --git a/convlab/base_models/t5/dst/run_dst.sh b/convlab/base_models/t5/dst/run_dst.sh
index 2dfc622d88a9b1b38e70e15f1f5cefd2d4a78661..0704ebf9257be910c2148d052574b535182be07e 100644
--- a/convlab/base_models/t5/dst/run_dst.sh
+++ b/convlab/base_models/t5/dst/run_dst.sh
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
diff --git a/convlab/base_models/t5/dst/run_dst_fewshot.sh b/convlab/base_models/t5/dst/run_dst_fewshot.sh
index d45719112e50dd44672ab52b28c04014cb5d6e5c..f548c053b544b51101f0cfbcc0b1a7b3a09c8088 100644
--- a/convlab/base_models/t5/dst/run_dst_fewshot.sh
+++ b/convlab/base_models/t5/dst/run_dst_fewshot.sh
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --early_stopping_patience 10 \
     --prediction_loss_only \
     --load_best_model_at_end \
diff --git a/convlab/base_models/t5/dst/run_dst_multitask.sh b/convlab/base_models/t5/dst/run_dst_multitask.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0f3b60a63a2f1bf861cb430a247121d966aac822
--- /dev/null
+++ b/convlab/base_models/t5/dst/run_dst_multitask.sh
@@ -0,0 +1,94 @@
+n_gpus=1
+task_name="dst"
+dataset_name="sgd+tm1+tm2+tm3+multiwoz21"
+speaker="user"
+context_window_size=100
+data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
+output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="dst_metric.py"
+metric_for_best_model="accuracy"
+source_column="context"
+target_column="state_seq"
+truncation_side="left"
+max_source_length=1024
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=64
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=10
+
+names=$(echo ${dataset_name} | tr "+" "\n")
+rm -rf ${data_dir}
+mkdir -p ${data_dir}
+for name in ${names};
+do
+    echo "preprocessing ${name}"
+    python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size}
+done
+
+python merge_data.py $(echo ${dataset_name} | tr "+" " ")
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+
+python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/base_models/t5/nlg/merge_predict_res.py b/convlab/base_models/t5/nlg/merge_predict_res.py
index d21fd489225aab8d75dde9f6f266b778e956512c..7d2995d84737378958a54765f3efc5f996f112e3 100755
--- a/convlab/base_models/t5/nlg/merge_predict_res.py
+++ b/convlab/base_models/t5/nlg/merge_predict_res.py
@@ -24,6 +24,7 @@ def merge(dataset_names, speaker, save_dir, context_window_size, predict_result)
                 continue
             sample['predictions'] = {'utterance': predict_result[i]}
             i += 1
+            merged.append(sample)
 
     json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
 
diff --git a/convlab/base_models/t5/nlu/merge_data.py b/convlab/base_models/t5/nlu/merge_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb692de53b073e18e4cebcfb4762c0ef473c7b69
--- /dev/null
+++ b/convlab/base_models/t5/nlu/merge_data.py
@@ -0,0 +1,21 @@
+import json
+import os
+import sys
+
+if __name__ == '__main__':
+    merged_data = {'train': [], 'validation': [], 'test': []}
+    print(sys.argv)
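+    # usage: python merge_data.py <dataset> ...  (e.g. python merge_data.py tm1 tm2 tm3)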
+    for dataset_name in sys.argv[1:]:
+        data_dir = os.path.join('data/nlu', dataset_name, 'user/context_0')
+        for data_split in merged_data:
+            with open(os.path.join(data_dir, f'{data_split}.json'), 'r') as f:
+                for line in f:
+                    item = json.loads(line)
+                    item['context'] = f"{dataset_name}: {item['context']}"
+                    merged_data[data_split].append(item)
+    for data_split in merged_data:
+        data_dir = os.path.join('data/nlu', '+'.join(sys.argv[1:]), 'user/context_0')
+        os.makedirs(data_dir, exist_ok=True)
+        with open(os.path.join(data_dir, f'{data_split}.json'), 'w') as f:
+            for item in merged_data[data_split]:
+                f.write(json.dumps(item)+'\n')
diff --git a/convlab/base_models/t5/nlu/merge_predict_res.py b/convlab/base_models/t5/nlu/merge_predict_res.py
index 58cf29d194272accd7578d58ba8bac415c025541..e247160769f7e5b0c9445b38e4dc2a5caa567fd0 100755
--- a/convlab/base_models/t5/nlu/merge_predict_res.py
+++ b/convlab/base_models/t5/nlu/merge_predict_res.py
@@ -4,10 +4,8 @@ from convlab.util import load_dataset, load_nlu_data
 from convlab.base_models.t5.nlu.serialization import deserialize_dialogue_acts
 
 
-def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
+def merge(dataset_names, speaker, save_dir, context_window_size, predict_result):
     assert os.path.exists(predict_result)
-    dataset = load_dataset(dataset_name, args.dial_ids_order)
-    data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
     
     if save_dir is None:
         save_dir = os.path.dirname(predict_result)
@@ -15,10 +13,19 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
         os.makedirs(save_dir, exist_ok=True)
     predict_result = [deserialize_dialogue_acts(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
 
-    for sample, prediction in zip(data, predict_result):
-        sample['predictions'] = {'dialogue_acts': prediction}
+    merged = []
+    i = 0
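+    # predictions were generated on the '+'-concatenated test sets, so walk the datasets
+    # in the same order and align samples via the running index i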
+    for dataset_name in dataset_names.split('+'):
+        print(dataset_name)
+        dataset = load_dataset(dataset_name, args.dial_ids_order)
+        data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
+        
+        for sample in data:
+            sample['predictions'] = {'dialogue_acts': predict_result[i]}
+            i += 1
+            merged.append(sample)
 
-    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
 
 
 if __name__ == '__main__':
diff --git a/convlab/base_models/t5/nlu/run_nlu.sh b/convlab/base_models/t5/nlu/run_nlu.sh
index fb9be0227b3cced261ed6ccbffa9857e477012a2..8cba74aca0510464d176aa44ef0388c914796f5f 100644
--- a/convlab/base_models/t5/nlu/run_nlu.sh
+++ b/convlab/base_models/t5/nlu/run_nlu.sh
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
diff --git a/convlab/base_models/t5/nlu/run_nlu_fewshot.sh b/convlab/base_models/t5/nlu/run_nlu_fewshot.sh
index 568c271323cf2472f7989e0cb68e9af051bcc89b..8da69801df77d8f72d23204c8cf008ea7512d10c 100644
--- a/convlab/base_models/t5/nlu/run_nlu_fewshot.sh
+++ b/convlab/base_models/t5/nlu/run_nlu_fewshot.sh
@@ -42,7 +42,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --load_best_model_at_end \
     --cache_dir ${cache_dir} \
diff --git a/convlab/base_models/t5/nlu/run_nlu_multitask.sh b/convlab/base_models/t5/nlu/run_nlu_multitask.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6380acff2fc5e8a2712e530823c5d0b61af451a2
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_nlu_multitask.sh
@@ -0,0 +1,94 @@
+n_gpus=1
+task_name="nlu"
+dataset_name="tm1+tm2+tm3"
+speaker="user"
+context_window_size=0
+data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
+output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=10
+
+names=$(echo ${dataset_name} | tr "+" "\n")
+rm -rf ${data_dir}
+mkdir -p ${data_dir}
+for name in ${names};
+do
+    echo "preprocessing ${name}"
+    python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size}
+done
+
+python merge_data.py $(echo ${dataset_name} | tr "+" " ")
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/base_models/t5/nlu/run_retnlu.sh b/convlab/base_models/t5/nlu/run_retnlu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b45a0e45643fd5a2247633305df7e0c1f11ce848
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu.sh
@@ -0,0 +1,86 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+retrieval_topk=1
+data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=10
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk}
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh b/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d165859b01485b7885f88e5b1ae3a279e41f4caf
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh
@@ -0,0 +1,89 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+ratio=$1
+dial_ids_order=$2
+retrieval_topk=$3
+data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=100
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} -r ${ratio} -o ${dial_ids_order}
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --load_best_model_at_end \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/base_models/t5/nlu/run_retnlu_in_context.sh b/convlab/base_models/t5/nlu/run_retnlu_in_context.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82dae873ebb419d1d311f347a813f1da6071dccb
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu_in_context.sh
@@ -0,0 +1,86 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+retrieval_topk=$1
+data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=10
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} --retrieval_in_context
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh b/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh
new file mode 100644
index 0000000000000000000000000000000000000000..836152f80e9d21695aaadde0016aa7399eedbdf2
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh
@@ -0,0 +1,89 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+ratio=$1
+dial_ids_order=$2
+retrieval_topk=$3
+data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=100
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} --retrieval_in_context -r ${ratio} -o ${dial_ids_order}
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --load_best_model_at_end \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/base_models/t5/run_seq2seq.py b/convlab/base_models/t5/run_seq2seq.py
index 8ce0f8d7f305b13b10ff0f5b899094fc2a4c96df..7aac3c70746e877469fc34892cd3f93f9fd01f22 100644
--- a/convlab/base_models/t5/run_seq2seq.py
+++ b/convlab/base_models/t5/run_seq2seq.py
@@ -39,14 +39,13 @@ from transformers import (
     AutoTokenizer,
     DataCollatorForSeq2Seq,
     HfArgumentParser,
-    Seq2SeqTrainer,
-    Seq2SeqTrainingArguments,
     EarlyStoppingCallback,
     set_seed,
 )
 from transformers.trainer_utils import EvalPrediction, get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
+from convlab.base_models.t5.trainer import ConvLabSeq2SeqTrainer, ConvLabSeq2SeqTrainingArguments
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -249,7 +248,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ConvLabSeq2SeqTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
@@ -556,7 +555,7 @@ def main():
         training_args.generation_max_length = data_args.val_max_target_length
 
     # Initialize our Trainer
-    trainer = Seq2SeqTrainer(
+    trainer = ConvLabSeq2SeqTrainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
diff --git a/convlab/base_models/t5/trainer.py b/convlab/base_models/t5/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..77f125575e0ebfb44e05ef16e4d8d041e016cc81
--- /dev/null
+++ b/convlab/base_models/t5/trainer.py
@@ -0,0 +1,132 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+from dataclasses import dataclass, field
+import torch
+from torch import nn
+
+from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.utils import logging
+from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+
+
+logger = logging.get_logger(__name__)
+
+@dataclass
+class ConvLabSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
+    """
+    `ConvLabSeq2SeqTrainingArguments` is a subclass of `Seq2SeqTrainingArguments` that adds the
+    following generation arguments: `do_sample`, `temperature`, `top_k`, `top_p`, and
+    `num_return_sequences`.
+    """
+    do_sample: bool = field(default=False, metadata={"help": "Whether or not to use sampling; use greedy decoding otherwise."})
+    temperature: Optional[float] = field(default=1.0, metadata={"help": "The value used to module the next token probabilities."})
+    top_k: Optional[int] = field(default=0, metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k-filtering."})
+    top_p: Optional[float] = field(default=1.0, metadata={"help": "If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation."})
+    num_return_sequences: Optional[int] = field(default=1, metadata={"help": "The number of independently computed returned sequences for each element in the batch."})
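+    # these fields surface as extra CLI flags of run_seq2seq.py,
+    # e.g. --do_sample --top_p 0.9 --num_return_sequences 5 (illustrative values)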
+
+
+class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Perform an evaluation step on `model` using `inputs`.
+        Subclass and override to inject custom behavior.
+        Args:
+            model (`nn.Module`):
+                The model to evaluate.
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+            prediction_loss_only (`bool`):
+                Whether or not to return the loss only.
+        Return:
+            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
+            labels (each being optional).
+        """
+
+        if not self.args.predict_with_generate or prediction_loss_only:
+            return super().prediction_step(
+                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+            )
+
+        has_labels = "labels" in inputs
+        inputs = self._prepare_inputs(inputs)
+
+        # XXX: adapt synced_gpus for fairscale as well
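+        # unlike the stock Seq2SeqTrainer, also forward the sampling arguments from
+        # ConvLabSeq2SeqTrainingArguments so that predict_with_generate can sample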
+        gen_kwargs = {
+            "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
+            "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
+            "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
+            "do_sample": self.args.do_sample,
+            "temperature": self.args.temperature,
+            "top_k": self.args.top_k,
+            "top_p": self.args.top_p,
+            "num_return_sequences": self.args.num_return_sequences
+        }
+
+        if "attention_mask" in inputs:
+            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
+        if "global_attention_mask" in inputs:
+            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
+
+        # prepare generation inputs
+        # some encoder-decoder models can have varying encoder's and thus
+        # varying model input names
+        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
+            generation_inputs = inputs[self.model.encoder.main_input_name]
+        else:
+            generation_inputs = inputs[self.model.main_input_name]
+
+        generated_tokens = self.model.generate(
+            generation_inputs,
+            **gen_kwargs,
+        )
+        # in case the batch is shorter than max length, the output should be padded
+        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
+
+        with torch.no_grad():
+            with self.autocast_smart_context_manager():
+                outputs = model(**inputs)
+            if has_labels:
+                if self.label_smoother is not None:
+                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
+                else:
+                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
+            else:
+                loss = None
+
+        if self.args.prediction_loss_only:
+            return (loss, None, None)
+
+        if has_labels:
+            labels = inputs["labels"]
+            if labels.shape[-1] < gen_kwargs["max_length"]:
+                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
+        else:
+            labels = None
+
+        return (loss, generated_tokens, labels)
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index e24658410738b290da97149382c8c89030936679..a35199bf344efa75e1736c2a46f20a33f5d8d28c 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -9,7 +9,9 @@ from abc import ABC, abstractmethod
 from pprint import pprint
 from convlab.util.file_util import cached_path
 import shutil
-import importlib
+from sentence_transformers import SentenceTransformer, util
+import torch
+from tqdm import tqdm
 
 
 class BaseDatabase(ABC):
@@ -433,6 +435,36 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
     return dataset, sorted(list(delex_vocab))
 
 
+def retrieve_utterances(query_turns, turn_pool, top_k, model_name):
+    """
+    It takes a list of query turns, a pool of turns, and a top_k value, and returns the query
+    turns with a new key called 'retrieved_turns' that contains the top_k most similar turns
+    retrieved from the turn pool.
+
+    :param query_turns: a list of turns that you want to retrieve utterances for
+    :param turn_pool: the pool of turns to retrieve from
+    :param top_k: the number of utterances to retrieve for each query turn
+    :param model_name: the name of the SentenceTransformer model to use
+    :return: A list of dictionaries, with a new key 'retrieved_turns' that is a list of retrieved turns and similarity scores.
+    """
+    embedder = SentenceTransformer(model_name)
+    # fall back to CPU when CUDA is unavailable
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    corpus = [turn['utterance'] for turn in turn_pool]
+    corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
+    corpus_embeddings = corpus_embeddings.to(device)
+    corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
+
+    queries = [turn['utterance'] for turn in query_turns]
+    query_embeddings = embedder.encode(queries, convert_to_tensor=True)
+    query_embeddings = query_embeddings.to(device)
+    query_embeddings = util.normalize_embeddings(query_embeddings)
+
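+    # embeddings are L2-normalized, so dot-product search is equivalent to cosine similarity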
+    hits = util.semantic_search(query_embeddings, corpus_embeddings, score_function=util.dot_score, top_k=top_k)
+
+    for i, turn in enumerate(query_turns):
+        turn['retrieved_turns'] = [{'score': hit['score'], **turn_pool[hit['corpus_id']]} for hit in hits[i]]
+    return query_turns
+
+
 if __name__ == "__main__":
     dataset = load_dataset('multiwoz21', dial_ids_order=0)
     train_ratio = 0.1
@@ -447,7 +479,11 @@ if __name__ == "__main__":
     print(res[0], len(res))
     
     data_by_split = load_nlu_data(dataset, data_split='test', speaker='user')
-    pprint(data_by_split['test'][0])
+    query_turns = data_by_split['test'][:10]
+    pool_dataset = load_dataset('camrest')
+    turn_pool = load_nlu_data(pool_dataset, data_split='train', speaker='user')['train']
+    augmented_dataset = retrieve_utterances(query_turns, turn_pool, 3, 'all-MiniLM-L6-v2')
+    pprint(augmented_dataset[0])
 
     def delex_slot(domain, slot, value):
         # only use slot name for delexicalization
diff --git a/data/unified_datasets/opendialkg/preprocess.py b/data/unified_datasets/opendialkg/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..a010d4084950c419f71835ee72501519195d964d
--- /dev/null
+++ b/data/unified_datasets/opendialkg/preprocess.py
@@ -0,0 +1,143 @@
+from zipfile import ZipFile, ZIP_DEFLATED
+from shutil import rmtree
+import json
+import os
+from tqdm import tqdm
+from collections import Counter
+from pprint import pprint
+import re
+import requests
+from dateutil import parser as date_parser
+from string import punctuation
+from copy import deepcopy
+import csv
+import random
+
+
+def value_in_utt(value, utt):
+    """return character level (start, end) if value in utt"""
+    value = value.strip(punctuation).lower()
+    # escape the hyphen in the trailing character class so it is not treated as a range
+    p = r'(^|[\s,\.:\?!-])(?P<v>{})([\s,\.:\?!\-\']|$)'.format(re.escape(value))
+    p = re.compile(p, re.I)
+    m = re.search(p, utt)
+    if m:
+        # very few values appear more than once; take the first span
+        return True, m.span('v')
+    else:
+        try:
+            # solve date representation, e.g. '3 pm' vs '3pm'
+            date_parser.parse(value)
+            if (value.endswith('pm') or value.endswith('am')) and ''.join(value.split(' ')) in ''.join(utt.split(' ')):
+                return True, None
+
+        except Exception:
+            if value in utt:
+                # value appears, but may be in the plural, -ing, -ly, etc.
+                return True, None
+
+    return False, None
+
+
+def preprocess():
+    random.seed(42)
+
+    data_file = "opendialkg.csv"
+    if not os.path.exists(data_file):
+        response = requests.get("https://github.com/facebookresearch/opendialkg/raw/main/data/opendialkg.csv")
+        open(data_file, "wb").write(response.content)
+
+    new_data_dir = 'data'
+
+    os.makedirs(new_data_dir, exist_ok=True)
+
+    dataset = 'opendialkg'
+    splits = ['train', 'validation', 'test']
+    dialogues_by_split = {split:[] for split in splits}
+
+    ontology = {'domains': {},
+                'intents': {},
+                'state': {},
+                'dialogue_acts': {
+                    "categorical": {},
+                    "non-categorical": {},
+                    "binary": {}
+                }}
+
+    data = []
+    with open(data_file) as csv_file:
+        csv_reader = csv.reader(csv_file, delimiter=',')
+        header = next(csv_reader)
+        for row in csv_reader:
+            sample = {}
+            for i, col in enumerate(row):
+                sample[header[i]] = col
+            data.append(sample)
+
+    # shuffle for random split to train:validation:test = 70:15:15
+    random.shuffle(data)
+    split2range = {
+        'train': [0, round(len(data)*0.7)],
+        'validation': [round(len(data)*0.7), round(len(data)*0.85)],
+        'test': [round(len(data)*0.85), len(data)],
+    }
+    cnt = 0
+    for data_split in splits:
+        for i in tqdm(range(*split2range[data_split])):
+            item = data[i]
+            dialogue_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
+            dialogue = {
+                'dataset': dataset,
+                'data_split': data_split,
+                'dialogue_id': dialogue_id,
+                'original_id': f'{data_split}-{len(dialogues_by_split[data_split])}',
+                'user_rating': eval(item['User Rating']),
+                'system_rating': eval(item['Assistant Rating']),
+                'turns': [],
+            }
+
+            for turn in eval(item['Messages']):
+                speaker = 'user' if turn['sender'] == 'user' else 'system'
+                turn_type = turn['type']
+                if turn_type == 'chat':
+                    assert len(turn) == 3
+                    if len(dialogue['turns'])>0 and speaker == dialogue['turns'][-1]['speaker']:
+                        # join consecutive messages from the same speaker with a space
+                        dialogue['turns'][-1]['utterance'] += ' ' + turn['message']
+                    else:
+                        dialogue['turns'].append({
+                            'speaker': speaker,
+                            'utterance': turn['message'],
+                            'utt_idx': len(dialogue['turns']),
+                        })
+                elif turn['action_id'] == "meta_thread/send_meta_message":
+                    # skip annotator communication
+                    pass
+                else:
+                    assert turn_type == 'action' and turn['action_id'] == "kgwalk/choose_path"
+                    assert len(dialogue['turns'])==0 or (speaker != dialogue['turns'][-1]['speaker']), print(turn)
+                    dialogue['turns'].append({
+                        'speaker': speaker,
+                        'utterance': '',
+                        'kg_path': {k: v for k, v in zip(['score', 'triples', 'rendering'], turn['metadata']['path'])},
+                        'utt_idx': len(dialogue['turns']),
+                    })
+            if len(dialogue['turns']) != 0:
+                dialogues_by_split[data_split].append(dialogue)
+                if any(['kg_path' in turn for turn in dialogue['turns']]):
+                    cnt+=1
+    
+    dialogues = dialogues_by_split['train']+dialogues_by_split['validation']+dialogues_by_split['test']
+    print(cnt, len(dialogues), cnt/len(dialogues))
+    json.dump(dialogues[:10], open(f'dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf:
+        for filename in os.listdir(new_data_dir):
+            zf.write(f'{new_data_dir}/{filename}')
+    rmtree(new_data_dir)
+    return dialogues, ontology
+
+
+if __name__ == '__main__':
+    preprocess()
diff --git a/data/unified_datasets/wikidialog/preprocess.py b/data/unified_datasets/wikidialog/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc2b0b73bbb52a941f041de4bd7194b1b79d9103
--- /dev/null
+++ b/data/unified_datasets/wikidialog/preprocess.py
@@ -0,0 +1,78 @@
+import gzip
+import json
+from zipfile import ZipFile, ZIP_DEFLATED
+import os
+from shutil import rmtree
+from tqdm import tqdm
+
+def preprocess():
+    original_data_dir = 'WikiDialog-OQ'
+    new_data_dir = 'data'
+    os.makedirs(new_data_dir, exist_ok=True)
+
+    dataset = 'wikidialog'
+    splits = ['train', 'validation']
+    dialogues_by_split = {split:[] for split in splits}
+
+    ontology = {
+        'domains': {},
+        'intents': {},
+        'state': {},
+        "dialogue_acts": {
+            "categorical": {},
+            "non-categorical": {},
+            "binary": {}
+        }
+    }
+
+    def process_dial(line, dial_id, data_split):
+        item = json.loads(line)
+        dialogue = {
+            'dataset': dataset,
+            'data_split': data_split,
+            'dialogue_id': dial_id,
+            'original_id': item['pid'],
+            'topic': item['title'],
+            'turns': []
+        }
+        for speaker, utterance in zip(item['author_num'], item['utterances']):
+            speaker = 'system' if speaker == 0 else 'user'
+            turn = {
+                'speaker': speaker,
+                'utterance': utterance.strip(),
+                'utt_idx': len(dialogue['turns']),
+            }
+            dialogue['turns'].append(turn)
+        return dialogue
+            
+    data_split = 'train'
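+    # note: only the first of the 99 training shards is loaded here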
+    for shard in tqdm(range(1)):
+        with gzip.open(f'{original_data_dir}/data_train.jsonl-000{shard:02}-of-00099.gz','r') as fin:
+            for line in fin:
+                dial_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
+                dialogue = process_dial(line, dial_id, data_split)
+                dialogues_by_split[data_split].append(dialogue)
+
+    data_split = 'validation'
+    with gzip.open(f'{original_data_dir}/data_validation.jsonl.gz','r') as fin:
+        for line in fin:
+            dial_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
+            dialogue = process_dial(line, dial_id, data_split)
+            dialogues_by_split[data_split].append(dialogue)
+            if len(dialogues_by_split[data_split]) >= len(dialogues_by_split['train']) // 10:
+                break
+    
+    dialogues = dialogues_by_split['train']+dialogues_by_split['validation']
+    json.dump(dialogues[:10], open(f'dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf:
+        for filename in os.listdir(new_data_dir):
+            zf.write(f'{new_data_dir}/{filename}')
+    # rmtree(original_data_dir)
+    rmtree(new_data_dir)
+    return dialogues, ontology
+
+
+if __name__ == '__main__':
+    preprocess()
diff --git a/requirements.txt b/requirements.txt
index eba77a0fab5eac6f59af19f4456a9e35c2ca5834..328d5738877ff8eb6bc0935bcc42195679d2504a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -79,6 +79,7 @@ s3transfer==0.6.0
 sacrebleu==2.1.0
 scikit-learn==1.1.1
 scipy==1.8.1
+sentence-transformers==2.2.2
 seqeval==1.2.2
 simplejson==3.17.6
 six==1.16.0
diff --git a/setup.py b/setup.py
index 953e5486fbc68a02c2db65a4f4760076681f78f6..07d966b7a758e700fba48e16ec5d9e32cf62991a 100755
--- a/setup.py
+++ b/setup.py
@@ -39,6 +39,7 @@ setup(
         'tensorboard',
         'torch>=1.6',
         'transformers>=4.0',
+        'sentence-transformers',
         'datasets>=1.8',
         'seqeval',
         'spacy',