diff --git a/convlab/base_models/t5/create_data.py b/convlab/base_models/t5/create_data.py
index 75c282848fa30deca8ca66ecd176df35d8588ed0..a1b10adb2ef0440868f4575e2ce2ce76508f13c4 100644
--- a/convlab/base_models/t5/create_data.py
+++ b/convlab/base_models/t5/create_data.py
@@ -3,7 +3,7 @@ import json
 from tqdm import tqdm
 import re
 from transformers import AutoTokenizer
-from convlab.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
+from convlab.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data, retrieve_utterances
 from convlab.base_models.t5.nlu.serialization import serialize_dialogue_acts, deserialize_dialogue_acts, equal_da_seq
 from convlab.base_models.t5.dst.serialization import serialize_dialogue_state, deserialize_dialogue_state, equal_state_seq
 
@@ -120,6 +120,59 @@ def create_goal2dialogue_data(dataset, data_dir, args):
         data_by_split[data_split] = data
     return data_by_split
 
+def create_retnlu_data(dataset, data_dir, args):
+    dataset_name = dataset[list(dataset.keys())[0]][0]['dataset']
+    data_by_split = load_nlu_data(dataset, speaker=args.speaker, use_context=args.context_window_size>0, context_window_size=args.context_window_size)
+    data_dir = os.path.join(data_dir, args.speaker, f'context_{args.context_window_size}', \
+        f'in_context_{args.retrieval_in_context}', f'topk_{args.retrieval_topk}')
+    os.makedirs(data_dir, exist_ok=True)
+
+    turn_pool = []
+    for d in args.retrieval_datasets:
+        pool_dataset = load_dataset(d)
+        for turn in load_nlu_data(pool_dataset, data_split='train', speaker=args.speaker)['train']:
+            if any([len(das) > 0 for da_type, das in turn['dialogue_acts'].items()]):
+                turn_pool.append({'dataset': d, **turn})
+
+    data_splits = data_by_split.keys()
+    query_turns = []
+    for data_split in data_splits:
+        query_turns.extend(data_by_split[data_split])
+    augmented_dataset = retrieve_utterances(query_turns, turn_pool, args.retrieval_topk, 'all-MiniLM-L6-v2')
+
+    i = 0
+    for data_split in data_splits:
+        data = []
+        for j in tqdm(range(len(data_by_split[data_split])), desc=f'{data_split} sample', leave=False):
+            sample = augmented_dataset[i+j]
+            response = f"{sample['speaker']}: {sample['utterance']}"
+            if args.context_window_size>0:
+                context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[response])
+            else:
+                context = response
+            context = ' '.join([dataset_name, context])
+            dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts'])
+            assert equal_da_seq(sample['dialogue_acts'], dialogue_acts_seq), print(sample['dialogue_acts'], dialogue_acts_seq, deserialize_dialogue_acts(dialogue_acts_seq))
+
+            retrieved_turns = sample['retrieved_turns']
+            for t in retrieved_turns:
+                # in-context learning
+                retrieved_utterance = f"{t['dataset']} {t['speaker']}: {t['utterance']}"
+                retrieved_dialogue_acts_seq = serialize_dialogue_acts(t['dialogue_acts'])
+                if args.retrieval_in_context:
+                    context = f"{retrieved_utterance} => {retrieved_dialogue_acts_seq}\n\n" + context
+                elif data_split != 'test':
+                    data.append(json.dumps({'context': retrieved_utterance, 'dialogue_acts_seq': retrieved_dialogue_acts_seq}, ensure_ascii=False)+'\n')        
+
+            data.append(json.dumps({'context': context, 'dialogue_acts_seq': dialogue_acts_seq}, ensure_ascii=False)+'\n')
+        i += len(data_by_split[data_split])
+
+        file_name = os.path.join(data_dir, f"{data_split}.json")
+        with open(file_name, "w", encoding='utf-8') as f:
+            f.writelines(data)
+        data_by_split[data_split] = data
+    return data_by_split
+
 def get_max_len(data_by_split, tokenizer):
     for data_split in data_by_split.keys():
         seq_len = {}
@@ -136,13 +189,16 @@ def get_max_len(data_by_split, tokenizer):
 if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser(description="create data for seq2seq training")
-    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'dst', 'nlg', 'goal2dialogue'], help='names of tasks')
+    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'dst', 'nlg', 'goal2dialogue', 'retnlu', 'retnlg'], help='names of tasks')
     parser.add_argument('--datasets', '-d', metavar='dataset_name', nargs='*', help='names of unified datasets')
     parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s)')
     parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
     parser.add_argument('--len_tokenizer', '-l', type=str, default=None, help='name or path of tokenizer that used to get seq len')
     parser.add_argument('--ratio', '-r', type=float, default=None, help='how many data is used for training and evaluation')
     parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments')
+    parser.add_argument('--retrieval_datasets', metavar='dataset_name for retrieval augmentation', nargs='*', help='names of unified datasets for retrieval')
+    parser.add_argument('--retrieval_topk', type=int, default=3, help='how many utterances to be retrieved')
+    parser.add_argument('--retrieval_in_context', action='store_true', default=False, help='whether use the retrieved utterance by in-context learning')
     args = parser.parse_args()
     print(args)
     if args.len_tokenizer:
diff --git a/convlab/base_models/t5/dst/run_dst_fewshot.sh b/convlab/base_models/t5/dst/run_dst_fewshot.sh
index d45719112e50dd44672ab52b28c04014cb5d6e5c..f548c053b544b51101f0cfbcc0b1a7b3a09c8088 100644
--- a/convlab/base_models/t5/dst/run_dst_fewshot.sh
+++ b/convlab/base_models/t5/dst/run_dst_fewshot.sh
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --early_stopping_patience 10 \
     --prediction_loss_only \
     --load_best_model_at_end \
diff --git a/convlab/base_models/t5/nlu/run_nlu.sh b/convlab/base_models/t5/nlu/run_nlu.sh
index fb9be0227b3cced261ed6ccbffa9857e477012a2..8cba74aca0510464d176aa44ef0388c914796f5f 100644
--- a/convlab/base_models/t5/nlu/run_nlu.sh
+++ b/convlab/base_models/t5/nlu/run_nlu.sh
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
diff --git a/convlab/base_models/t5/nlu/run_nlu_fewshot.sh b/convlab/base_models/t5/nlu/run_nlu_fewshot.sh
index 568c271323cf2472f7989e0cb68e9af051bcc89b..8da69801df77d8f72d23204c8cf008ea7512d10c 100644
--- a/convlab/base_models/t5/nlu/run_nlu_fewshot.sh
+++ b/convlab/base_models/t5/nlu/run_nlu_fewshot.sh
@@ -42,7 +42,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --load_best_model_at_end \
     --cache_dir ${cache_dir} \
diff --git a/convlab/base_models/t5/nlu/run_retnlu.sh b/convlab/base_models/t5/nlu/run_retnlu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b45a0e45643fd5a2247633305df7e0c1f11ce848
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu.sh
@@ -0,0 +1,86 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+retrieval_topk=1
+data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=10
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk}
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh b/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d165859b01485b7885f88e5b1ae3a279e41f4caf
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh
@@ -0,0 +1,89 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+ratio=$1
+dial_ids_order=$2
+retrieval_topk=$3
+data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=100
+
+# python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} -r ${ratio} -o ${dial_ids_order}
+
+# python ../run_seq2seq.py \
+#     --task_name ${task_name} \
+#     --train_file ${train_file} \
+#     --validation_file ${validation_file} \
+#     --source_column ${source_column} \
+#     --target_column ${target_column} \
+#     --max_source_length ${max_source_length} \
+#     --max_target_length ${max_target_length} \
+#     --truncation_side ${truncation_side} \
+#     --model_name_or_path ${model_name_or_path} \
+#     --do_train \
+#     --do_eval \
+#     --save_strategy epoch \
+#     --evaluation_strategy epoch \
+#     --save_total_limit 1 \
+#     --prediction_loss_only \
+#     --load_best_model_at_end \
+#     --cache_dir ${cache_dir} \
+#     --output_dir ${output_dir} \
+#     --logging_dir ${logging_dir} \
+#     --overwrite_output_dir \
+#     --preprocessing_num_workers 4 \
+#     --per_device_train_batch_size ${per_device_train_batch_size} \
+#     --per_device_eval_batch_size ${per_device_eval_batch_size} \
+#     --gradient_accumulation_steps ${gradient_accumulation_steps} \
+#     --learning_rate ${lr} \
+#     --num_train_epochs ${num_train_epochs} \
+#     --adafactor \
+#     --gradient_checkpointing
+
+# python ../run_seq2seq.py \
+#     --task_name ${task_name} \
+#     --test_file ${test_file} \
+#     --source_column ${source_column} \
+#     --target_column ${target_column} \
+#     --max_source_length ${max_source_length} \
+#     --max_target_length ${max_target_length} \
+#     --truncation_side ${truncation_side} \
+#     --model_name_or_path ${output_dir} \
+#     --do_predict \
+#     --predict_with_generate \
+#     --metric_name_or_path ${metric_name_or_path} \
+#     --cache_dir ${cache_dir} \
+#     --output_dir ${output_dir} \
+#     --logging_dir ${logging_dir} \
+#     --overwrite_output_dir \
+#     --preprocessing_num_workers 4 \
+#     --per_device_train_batch_size ${per_device_train_batch_size} \
+#     --per_device_eval_batch_size ${per_device_eval_batch_size} \
+#     --gradient_accumulation_steps ${gradient_accumulation_steps} \
+#     --learning_rate ${lr} \
+#     --num_train_epochs ${num_train_epochs} \
+#     --adafactor \
+#     --gradient_checkpointing
+
+# python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/base_models/t5/nlu/run_retnlu_in_context.sh b/convlab/base_models/t5/nlu/run_retnlu_in_context.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82dae873ebb419d1d311f347a813f1da6071dccb
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu_in_context.sh
@@ -0,0 +1,86 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+retrieval_topk=$1
+data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=10
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} --retrieval_in_context
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh b/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh
new file mode 100644
index 0000000000000000000000000000000000000000..836152f80e9d21695aaadde0016aa7399eedbdf2
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh
@@ -0,0 +1,89 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+ratio=$1
+dial_ids_order=$2
+retrieval_topk=$3
+data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=100
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} --retrieval_in_context -r ${ratio} -o ${dial_ids_order}
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --load_best_model_at_end \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index e24658410738b290da97149382c8c89030936679..a35199bf344efa75e1736c2a46f20a33f5d8d28c 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -9,7 +9,9 @@ from abc import ABC, abstractmethod
 from pprint import pprint
 from convlab.util.file_util import cached_path
 import shutil
-import importlib
+from sentence_transformers import SentenceTransformer, util
+import torch
+from tqdm import tqdm
 
 
 class BaseDatabase(ABC):
@@ -433,6 +435,36 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
     return dataset, sorted(list(delex_vocab))
 
 
+def retrieve_utterances(query_turns, turn_pool, top_k, model_name):
+    """
+    It takes a list of query turns, a list of turn pool, and a top_k value, and returns a list of query
+    turns with a new key called 'retrieve_utterances' that contains a list of top_k retrieved utterances
+    from the turn pool
+    
+    :param query_turns: a list of turns that you want to retrieve utterances for
+    :param turn_pool: the pool of turns to retrieve from
+    :param top_k: the number of utterances to retrieve for each query turn
+    :param model_name: the name of the model you want to use
+    :return: A list of dictionaries, with a new key 'retrieve_utterances' that is a list of retrieved turns and similarity scores.
+    """
+    embedder = SentenceTransformer(model_name)
+    corpus = [turn['utterance'] for turn in turn_pool]
+    corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
+    corpus_embeddings = corpus_embeddings.to('cuda')
+    corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
+
+    queries = [turn['utterance'] for turn in query_turns]
+    query_embeddings = embedder.encode(queries, convert_to_tensor=True)
+    query_embeddings = query_embeddings.to('cuda')
+    query_embeddings = util.normalize_embeddings(query_embeddings)
+
+    hits = util.semantic_search(query_embeddings, corpus_embeddings, score_function=util.dot_score, top_k=top_k)
+
+    for i, turn in enumerate(query_turns):
+        turn['retrieved_turns'] = [{'score': hit['score'], **turn_pool[hit['corpus_id']]} for hit in hits[i]]
+    return query_turns
+
+
 if __name__ == "__main__":
     dataset = load_dataset('multiwoz21', dial_ids_order=0)
     train_ratio = 0.1
@@ -447,7 +479,11 @@ if __name__ == "__main__":
     print(res[0], len(res))
     
     data_by_split = load_nlu_data(dataset, data_split='test', speaker='user')
-    pprint(data_by_split['test'][0])
+    query_turns = data_by_split['test'][:10]
+    pool_dataset = load_dataset('camrest')
+    turn_pool = load_nlu_data(pool_dataset, data_split='train', speaker='user')['train']
+    augmented_dataset = retrieve_utterances(query_turns, turn_pool, 3, 'all-MiniLM-L6-v2')
+    pprint(augmented_dataset[0])
 
     def delex_slot(domain, slot, value):
         # only use slot name for delexicalization
diff --git a/requirements.txt b/requirements.txt
index eba77a0fab5eac6f59af19f4456a9e35c2ca5834..380a19d536a5770658d9c6da82e883f65e9a9961 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -79,6 +79,7 @@ s3transfer==0.6.0
 sacrebleu==2.1.0
 scikit-learn==1.1.1
 scipy==1.8.1
+sentence-transformers=2.2.2
 seqeval==1.2.2
 simplejson==3.17.6
 six==1.16.0
diff --git a/setup.py b/setup.py
index 953e5486fbc68a02c2db65a4f4760076681f78f6..07d966b7a758e700fba48e16ec5d9e32cf62991a 100755
--- a/setup.py
+++ b/setup.py
@@ -39,6 +39,7 @@ setup(
         'tensorboard',
         'torch>=1.6',
         'transformers>=4.0',
+        'sentence-transformers',
         'datasets>=1.8',
         'seqeval',
         'spacy',