diff --git a/convlab/policy/emoTUS/train_model.py b/convlab/policy/emoTUS/train_model.py
index 8c79c8a6cd8b34f374aa334aadd35989e997bdae..93bc762c19600d97246ed15ef5d72f44ca52d3e6 100644
--- a/convlab/policy/emoTUS/train_model.py
+++ b/convlab/policy/emoTUS/train_model.py
@@ -37,6 +37,7 @@ def arg_parser():
     parser.add_argument("--batch-size", type=int, default=16)
     parser.add_argument("--model-checkpoint", type=str,
                         default="facebook/bart-base")
+    parser.add_argument("--fine-tune", action="store_true")
     return parser.parse_args()
 
 
@@ -66,6 +67,25 @@ def gentus_compute_metrics(eval_preds):
     return result
 
 
+def basic_metric(eval_preds):
+    """Decode predictions/labels and report METRIC's "score" as "bleu".
+
+    -100 ids cannot be decoded, so both arrays get pad ids in their place.
+    """
+    preds, labels = eval_preds
+    if isinstance(preds, tuple):
+        preds = preds[0]
+    preds = np.where(preds != -100, preds, TOKENIZER.pad_token_id)
+    labels = np.where(labels != -100, labels, TOKENIZER.pad_token_id)
+    decoded_preds = TOKENIZER.batch_decode(
+        preds, skip_special_tokens=True, max_length=MAX_OUT_LEN)
+    decoded_labels = TOKENIZER.batch_decode(
+        labels, skip_special_tokens=True, max_length=MAX_OUT_LEN)
+    references = [[text] for text in decoded_labels]
+    result = METRIC.compute(predictions=decoded_preds, references=references)
+    return {"bleu": result["score"]}
+
+
 def postprocess_text(preds, labels):
     act = {"preds": [], "labels": []}
     text = {"preds": [], "labels": []}
@@ -164,6 +184,27 @@ class TrainerHelper:
 
         return tokenized_datasets
 
+    def remove_dialmage_action(self):
+        """Load and tokenize the emowoz and emotion-only dialmage splits.
+
+        Sets ``self.dir_name``; returns ``{dataset: {split: Dataset}}``.
+        """
+        self.dir_name = "fine_tune"
+        folder = "convlab/policy/emoTUS/unify/data"
+        data_name = {
+            "emowoz": "EmoUS_emowoz_0_1",
+            "dialmage": "EmoUS_dialmage_0_1_emotion_only"}
+        tokenized_datasets = {}
+        for name, sub_dir in data_name.items():
+            tokenized_datasets[name] = {}
+            for split in ["train", "validation", "test"]:
+                f_name = os.path.join(folder, sub_dir, f"{split}.json")
+                with open(f_name) as f:  # close the handle once parsed
+                    dialog = json.load(f)["dialog"]
+                tokenized_datasets[name][split] = Dataset.from_dict(
+                    self._preprocess(dialog))
+        return tokenized_datasets
+
     def _preprocess(self, examples):
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         if isinstance(examples, dict):
@@ -240,17 +281,117 @@ def train(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max
     trainer.save_model()
 
 
+def fine_tune_on_dialmage(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"):
+    """Train on emowoz for 4 epochs, then on dialmage for 1 epoch.
+
+    Both stages update the same BART model in sequence. Stage one is scored
+    with ``gentus_compute_metrics`` and saved under a timestamped folder;
+    stage two is scored with ``basic_metric`` and saved under that folder
+    name plus "_dialmage_fine_tune". Each stage evaluates on the "test"
+    split after every epoch.
+
+    Args:
+        model_type: given to ``TrainerHelper.get_model_folder`` to pick
+            the output folder.
+        data_name, dial_ids_order, split2ratio: not read here; the data
+            files are fixed by ``TrainerHelper.remove_dialmage_action``.
+        batch_size: per-device batch size for training and evaluation.
+        max_input_length: forwarded to ``TrainerHelper``.
+        max_target_length: forwarded to ``TrainerHelper``; also the
+            generation length limit used during evaluation.
+        model_checkpoint: name or path given to ``from_pretrained``.
+    """
+    helper = TrainerHelper(
+        tokenizer=TOKENIZER,
+        max_input_length=max_input_length,
+        max_target_length=max_target_length)
+    datasets = helper.remove_dialmage_action()
+
+    model = BartForConditionalGeneration.from_pretrained(model_checkpoint)
+    model.resize_token_embeddings(len(TOKENIZER))
+    use_cuda = torch.cuda.is_available()
+    if use_cuda:
+        print("use cuda")
+        model.to("cuda")
+
+    base_dir = helper.get_model_folder(model_type)
+    stamp = datetime.now().strftime('%y-%m-%d-%H-%M')
+    model_dir = os.path.join(base_dir, stamp)
+
+    def run_stage(output_dir, epochs, name, metric_fn):
+        # One train-and-save pass of `model` on datasets[name]; only the
+        # output folder, epoch count, data and metric differ per stage.
+        stage_args = Seq2SeqTrainingArguments(
+            output_dir,
+            evaluation_strategy="epoch",
+            learning_rate=2e-5,
+            per_device_train_batch_size=batch_size,
+            per_device_eval_batch_size=batch_size,
+            weight_decay=0.01,
+            save_total_limit=2,
+            num_train_epochs=epochs,
+            predict_with_generate=True,
+            fp16=use_cuda,
+            push_to_hub=False,
+            generation_max_length=max_target_length,
+            # both stages log under the first stage's folder
+            logging_dir=os.path.join(model_dir, 'log'))
+        collator = DataCollatorForSeq2Seq(
+            TOKENIZER, model=model, padding=True)
+        stage_trainer = Seq2SeqTrainer(
+            model=model,
+            args=stage_args,
+            train_dataset=datasets[name]["train"],
+            eval_dataset=datasets[name]["test"],
+            data_collator=collator,
+            tokenizer=TOKENIZER,
+            compute_metrics=metric_fn)
+        print("start training...")
+        stage_trainer.train()
+        print("saving model...")
+        stage_trainer.save_model()
+
+    # Stage 1: emowoz
+    run_stage(
+        output_dir=model_dir,
+        epochs=4,
+        name="emowoz",
+        metric_fn=gentus_compute_metrics)
+
+    # Stage 2: dialmage
+    run_stage(
+        output_dir=model_dir + "_dialmage_fine_tune",
+        epochs=1,
+        name="dialmage",
+        metric_fn=basic_metric)
+
+
 def main():
+    """Parse CLI arguments and run the selected training routine.
+
+    With ``--fine-tune`` the two-stage ``fine_tune_on_dialmage`` routine
+    runs; otherwise the regular ``train`` routine runs. Both receive the
+    same keyword arguments.
+    """
     args = arg_parser()
     print("---> data_name", args.data_name)
-    train(model_type=args.model_type,
-          data_name=args.data_name,
-          dial_ids_order=args.dial_ids_order,
-          split2ratio=args.split2ratio,
-          batch_size=args.batch_size,
-          max_input_length=MAX_IN_LEN,
-          max_target_length=MAX_OUT_LEN,
-          model_checkpoint=args.model_checkpoint)
+    # Collected once: both routines take identical keyword arguments, so
+    # only the callee depends on the flag.
+    run_kwargs = dict(
+        model_type=args.model_type,
+        data_name=args.data_name,
+        dial_ids_order=args.dial_ids_order,
+        split2ratio=args.split2ratio,
+        batch_size=args.batch_size,
+        max_input_length=MAX_IN_LEN,
+        max_target_length=MAX_OUT_LEN,
+        model_checkpoint=args.model_checkpoint
+    )
+
+    if args.fine_tune:
+        fine_tune_on_dialmage(**run_kwargs)
+    else:
+        train(**run_kwargs)
 
 
 if __name__ == "__main__":
diff --git a/convlab/policy/emoTUS/unify/build_data.py b/convlab/policy/emoTUS/unify/build_data.py
index 86071b8075da16c89a11489c27e7faf3f9520caa..1b5d0bdc3dd2072e22d25ec4e7787e1facd71bab 100644
--- a/convlab/policy/emoTUS/unify/build_data.py
+++ b/convlab/policy/emoTUS/unify/build_data.py
@@ -23,6 +23,7 @@ def arg_parser():
     parser.add_argument("--use-sentiment", action="store_true")
     parser.add_argument("--add-persona", action="store_true")
     parser.add_argument("--emotion-mid", action="store_true")
+    parser.add_argument("--emotion-only", action="store_true")
 
     return parser.parse_args()
 
@@ -33,6 +34,7 @@ class DataBuilder(GenTUSDataBuilder):
         self.use_sentiment = kwargs.get("use_sentiment", False)
         self.emotion_mid = kwargs.get("emotion_mid", False)
         self.add_persona = kwargs.get("add_persona", False)
+        self.emotion_only = kwargs.get("emotion_only", False)
 
         self.emotion = {}
         for emotion, index in json.load(open("convlab/policy/emoTUS/emotion.json")).items():
@@ -128,9 +130,12 @@ class DataBuilder(GenTUSDataBuilder):
                        "action": usr_act,
                        "text": text}
         elif not self.use_sentiment and not self.emotion_mid:
-            out_str = {"emotion": usr_emotion,
-                       "action": usr_act,
-                       "text": text}
+            if self.emotion_only:
+                out_str = {"emotion": usr_emotion}
+            else:
+                out_str = {"emotion": usr_emotion,
+                           "action": usr_act,
+                           "text": text}
         else:
             out_str = {"action": usr_act,
                        "emotion": usr_emotion,
@@ -183,6 +188,9 @@ if __name__ == "__main__":
         dir_name = f"SentEmoUS_noPersona_{dir_name}"
     else:
         print("NOT DEFINED", use_sentiment, add_persona, emotion_mid)
+
+    if args.emotion_only:
+        dir_name = dir_name + '_emotion_only'
     print("dir_name", dir_name)
 
     folder_name = os.path.join(base_name, dir_name)
@@ -197,7 +205,8 @@ if __name__ == "__main__":
         dataset=args.dataset,
         use_sentiment=use_sentiment,
         add_persona=add_persona,
-        emotion_mid=emotion_mid)
+        emotion_mid=emotion_mid,
+        emotion_only=args.emotion_only)
     data = data_builder.setup_data(
         raw_data=dataset,
         random_order=False,
diff --git a/convlab/policy/tus/unify/util.py b/convlab/policy/tus/unify/util.py
index 1e2bf12a8c024faa6299e810152c42330f5055cc..b3e24a57028933ee15776192c83d650a45cbbf53 100644
--- a/convlab/policy/tus/unify/util.py
+++ b/convlab/policy/tus/unify/util.py
@@ -8,6 +8,7 @@ NOT_MENTIONED = "not mentioned"
 
 def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2ratio=1):
     ratio = {'train': split2ratio, 'validation': split2ratio}
+    print("data_name", data_name)
     if data_name == "all" or data_name == "sgd+tm" or data_name == "tm":
         print("merge all datasets...")
         if data_name == "all":
@@ -31,7 +32,7 @@ def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2rati
             datasets[name] = load_dataset(
                 name, dial_ids_order=None)
         raw_data = merge_dataset(datasets, all_dataset[0])
-    elif data_name == "dialmage":
+    elif data_name in ["dialmage", "emowoz"]:
         raw_data = load_dataset(data_name, dial_ids_order=None)
 
     else: