diff --git a/convlab/policy/emoTUS/train_model.py b/convlab/policy/emoTUS/train_model.py index 8c79c8a6cd8b34f374aa334aadd35989e997bdae..93bc762c19600d97246ed15ef5d72f44ca52d3e6 100644 --- a/convlab/policy/emoTUS/train_model.py +++ b/convlab/policy/emoTUS/train_model.py @@ -37,6 +37,7 @@ def arg_parser(): parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--model-checkpoint", type=str, default="facebook/bart-base") + parser.add_argument("--fine-tune", action="store_true") return parser.parse_args() @@ -66,6 +67,25 @@ def gentus_compute_metrics(eval_preds): return result +def basic_metric(eval_preds): + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = TOKENIZER.batch_decode( + preds, skip_special_tokens=True, max_length=MAX_OUT_LEN) + + # Replace -100 in the labels as we can't decode them. + labels = np.where(labels != -100, labels, TOKENIZER.pad_token_id) + decoded_labels = TOKENIZER.batch_decode( + labels, skip_special_tokens=True, max_length=MAX_OUT_LEN) + labels = [[x] for x in decoded_labels] + + result = METRIC.compute( + predictions=decoded_preds, references=labels) + result = {"bleu": result["score"]} + return result + + def postprocess_text(preds, labels): act = {"preds": [], "labels": []} text = {"preds": [], "labels": []} @@ -164,6 +184,27 @@ class TrainerHelper: return tokenized_datasets + def remove_dialmage_action(self): + self.dir_name = "fine_tune" + folder = "convlab/policy/emoTUS/unify/data" + data_name = { + "emowoz": "EmoUS_emowoz_0_1", + "dialmage": "EmoUS_dialmage_0_1_emotion_only"} + data = {} + for d, d_n in data_name.items(): + data[d] = {} + for d_type in ["train", "validation", "test"]: + f_name = os.path.join(folder, d_n, f"{d_type}.json") + data[d][d_type] = json.load(open(f_name)) + + tokenized_datasets = {} + for d_n, d in data.items(): + tokenized_datasets[d_n] = {} + for s_d_n, s_d in d.items(): + tokenized_datasets[d_n][s_d_n] = Dataset.from_dict( + self._preprocess(s_d["dialog"])) + return tokenized_datasets + def _preprocess(self, examples): model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} if isinstance(examples, dict): @@ -240,17 +281,117 @@ def train(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max trainer.save_model() +def fine_tune_on_dialmage(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"): + tokenizer = TOKENIZER + + train_helper = TrainerHelper( + tokenizer=tokenizer, max_input_length=max_input_length, max_target_length=max_target_length) + data = train_helper.remove_dialmage_action() + + model = BartForConditionalGeneration.from_pretrained(model_checkpoint) + model.resize_token_embeddings(len(tokenizer)) + fp16 = False + if torch.cuda.is_available(): + print("use cuda") + fp16 = True + model.to("cuda") + + model_dir = os.path.join( + train_helper.get_model_folder(model_type), + f"{datetime.now().strftime('%y-%m-%d-%H-%M')}") + + # Emowoz + + args = Seq2SeqTrainingArguments( + model_dir, + evaluation_strategy="epoch", + learning_rate=2e-5, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + weight_decay=0.01, + save_total_limit=2, + num_train_epochs=4, + predict_with_generate=True, + fp16=fp16, + push_to_hub=False, + generation_max_length=max_target_length, + logging_dir=os.path.join(model_dir, 'log') + ) + data_collator = DataCollatorForSeq2Seq( + tokenizer, model=model, padding=True) + + trainer = Seq2SeqTrainer( + model=model, + args=args, + train_dataset=data["emowoz"]["train"], + eval_dataset=data["emowoz"]["test"], + data_collator=data_collator, + tokenizer=tokenizer, + compute_metrics=gentus_compute_metrics) + print("start training...") + trainer.train() + print("saving model...") + trainer.save_model() + + # dialmage + args = Seq2SeqTrainingArguments( + model_dir+"_dialmage_fine_tune", + evaluation_strategy="epoch", + learning_rate=2e-5, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + weight_decay=0.01, + save_total_limit=2, + num_train_epochs=1, + predict_with_generate=True, + fp16=fp16, + push_to_hub=False, + generation_max_length=max_target_length, + logging_dir=os.path.join(model_dir, 'log') + ) + data_collator = DataCollatorForSeq2Seq( + tokenizer, model=model, padding=True) + + trainer = Seq2SeqTrainer( + model=model, + args=args, + train_dataset=data["dialmage"]["train"], + eval_dataset=data["dialmage"]["test"], + data_collator=data_collator, + tokenizer=tokenizer, + compute_metrics=basic_metric) + print("start training...") + trainer.train() + print("saving model...") + trainer.save_model() + + def main(): args = arg_parser() print("---> data_name", args.data_name) - train(model_type=args.model_type, - data_name=args.data_name, - dial_ids_order=args.dial_ids_order, - split2ratio=args.split2ratio, - batch_size=args.batch_size, - max_input_length=MAX_IN_LEN, - max_target_length=MAX_OUT_LEN, - model_checkpoint=args.model_checkpoint) + if args.fine_tune: + fine_tune_on_dialmage( + model_type=args.model_type, + data_name=args.data_name, + dial_ids_order=args.dial_ids_order, + split2ratio=args.split2ratio, + batch_size=args.batch_size, + max_input_length=MAX_IN_LEN, + max_target_length=MAX_OUT_LEN, + model_checkpoint=args.model_checkpoint + ) + else: + + train( + model_type=args.model_type, + data_name=args.data_name, + dial_ids_order=args.dial_ids_order, + split2ratio=args.split2ratio, + batch_size=args.batch_size, + max_input_length=MAX_IN_LEN, + max_target_length=MAX_OUT_LEN, + model_checkpoint=args.model_checkpoint + ) if __name__ == "__main__": diff --git a/convlab/policy/emoTUS/unify/build_data.py b/convlab/policy/emoTUS/unify/build_data.py index 86071b8075da16c89a11489c27e7faf3f9520caa..1b5d0bdc3dd2072e22d25ec4e7787e1facd71bab 100644 --- a/convlab/policy/emoTUS/unify/build_data.py +++ b/convlab/policy/emoTUS/unify/build_data.py @@ -23,6 +23,7 @@ def arg_parser(): parser.add_argument("--use-sentiment", action="store_true") parser.add_argument("--add-persona", action="store_true") parser.add_argument("--emotion-mid", action="store_true") + parser.add_argument("--emotion-only", action="store_true") return parser.parse_args() @@ -33,6 +34,7 @@ class DataBuilder(GenTUSDataBuilder): self.use_sentiment = kwargs.get("use_sentiment", False) self.emotion_mid = kwargs.get("emotion_mid", False) self.add_persona = kwargs.get("add_persona", False) + self.emotion_only = kwargs.get("emotion_only", False) self.emotion = {} for emotion, index in json.load(open("convlab/policy/emoTUS/emotion.json")).items(): @@ -128,9 +130,12 @@ class DataBuilder(GenTUSDataBuilder): "action": usr_act, "text": text} elif not self.use_sentiment and not self.emotion_mid: - out_str = {"emotion": usr_emotion, - "action": usr_act, - "text": text} + if self.emotion_only: + out_str = {"emotion": usr_emotion} + else: + out_str = {"emotion": usr_emotion, + "action": usr_act, + "text": text} else: out_str = {"action": usr_act, "emotion": usr_emotion, @@ -183,6 +188,9 @@ if __name__ == "__main__": dir_name = f"SentEmoUS_noPersona_{dir_name}" else: print("NOT DEFINED", use_sentiment, add_persona, emotion_mid) + + if args.emotion_only: + dir_name = dir_name + '_emotion_only' print("dir_name", dir_name) folder_name = os.path.join(base_name, dir_name) @@ -197,7 +205,8 @@ if __name__ == "__main__": dataset=args.dataset, use_sentiment=use_sentiment, add_persona=add_persona, - emotion_mid=emotion_mid) + emotion_mid=emotion_mid, + emotion_only=args.emotion_only) data = data_builder.setup_data( raw_data=dataset, random_order=False, diff --git a/convlab/policy/tus/unify/util.py b/convlab/policy/tus/unify/util.py index 1e2bf12a8c024faa6299e810152c42330f5055cc..b3e24a57028933ee15776192c83d650a45cbbf53 100644 --- a/convlab/policy/tus/unify/util.py +++ b/convlab/policy/tus/unify/util.py @@ -8,6 +8,7 @@ NOT_MENTIONED = "not mentioned" def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2ratio=1): ratio = {'train': split2ratio, 'validation': split2ratio} + print("data_name", data_name) if data_name == "all" or data_name == "sgd+tm" or data_name == "tm": print("merge all datasets...") if data_name == "all": @@ -31,7 +32,7 @@ def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2rati datasets[name] = load_dataset( name, dial_ids_order=None) raw_data = merge_dataset(datasets, all_dataset[0]) - elif data_name == "dialmage": + elif data_name in ["dialmage", "emowoz"]: raw_data = load_dataset(data_name, dial_ids_order=None) else: