Commit 20648d63 authored by Hsien-Chin Lin

wip

parent 15b002b0
@@ -37,6 +37,7 @@ def arg_parser():
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--model-checkpoint", type=str,
                        default="facebook/bart-base")
    parser.add_argument("--fine-tune", action="store_true")
    return parser.parse_args()
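The new `--fine-tune` switch is a plain `store_true` flag; a minimal sketch (argument values are illustrative) of how it surfaces in the parsed args and drives the routing in `main()` further down:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--fine-tune", action="store_true")

args = parser.parse_args(["--fine-tune"])
print(args.fine_tune)  # True  -> main() calls fine_tune_on_dialmage()
args = parser.parse_args([])
print(args.fine_tune)  # False -> main() calls train()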
@@ -66,6 +67,25 @@ def gentus_compute_metrics(eval_preds):
    return result


def basic_metric(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = TOKENIZER.batch_decode(
        preds, skip_special_tokens=True, max_length=MAX_OUT_LEN)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, TOKENIZER.pad_token_id)
    decoded_labels = TOKENIZER.batch_decode(
        labels, skip_special_tokens=True, max_length=MAX_OUT_LEN)
    # corpus BLEU expects a list of references per prediction
    labels = [[x] for x in decoded_labels]
    result = METRIC.compute(
        predictions=decoded_preds, references=labels)
    result = {"bleu": result["score"]}
    return result
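`basic_metric` assumes `METRIC` is a corpus-level BLEU metric whose `compute` returns a `score` field; a self-contained sketch of the same computation, assuming sacrebleu via the `evaluate` library (the diff does not show which metric is actually loaded):

import evaluate

metric = evaluate.load("sacrebleu")  # assumption: METRIC is defined this way
preds = ["i would like a hotel in the north"]
refs = [["i would like a hotel in the north"]]  # one reference list per prediction
result = metric.compute(predictions=preds, references=refs)
print({"bleu": result["score"]})  # sacrebleu scores on a 0-100 scale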
def postprocess_text(preds, labels):
    act = {"preds": [], "labels": []}
    text = {"preds": [], "labels": []}
@@ -164,6 +184,27 @@ class TrainerHelper:
        return tokenized_datasets

    def remove_dialmage_action(self):
        self.dir_name = "fine_tune"
        folder = "convlab/policy/emoTUS/unify/data"
        data_name = {
            "emowoz": "EmoUS_emowoz_0_1",
            "dialmage": "EmoUS_dialmage_0_1_emotion_only"}
        data = {}
        for d, d_n in data_name.items():
            data[d] = {}
            for d_type in ["train", "validation", "test"]:
                f_name = os.path.join(folder, d_n, f"{d_type}.json")
                data[d][d_type] = json.load(open(f_name))

        tokenized_datasets = {}
        for d_n, d in data.items():
            tokenized_datasets[d_n] = {}
            for s_d_n, s_d in d.items():
                tokenized_datasets[d_n][s_d_n] = Dataset.from_dict(
                    self._preprocess(s_d["dialog"]))

        return tokenized_datasets
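`remove_dialmage_action` expects each `{train,validation,test}.json` to carry a `dialog` list, and `_preprocess` (below) to return tokenizer-style columns; a minimal sketch of the `Dataset.from_dict` step with dummy values:

from datasets import Dataset

# dummy columns shaped like the model_inputs dict built in _preprocess
features = {
    "input_ids": [[0, 100, 2]],
    "attention_mask": [[1, 1, 1]],
    "labels": [[0, 200, 2]],
}
ds = Dataset.from_dict(features)
print(ds)  # Dataset({features: ['input_ids', 'attention_mask', 'labels'], num_rows: 1})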
    def _preprocess(self, examples):
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
        if isinstance(examples, dict):
@@ -240,17 +281,117 @@ def train(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"):
    trainer.save_model()
def fine_tune_on_dialmage(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"):
    tokenizer = TOKENIZER
    train_helper = TrainerHelper(
        tokenizer=tokenizer, max_input_length=max_input_length, max_target_length=max_target_length)
    data = train_helper.remove_dialmage_action()
    model = BartForConditionalGeneration.from_pretrained(model_checkpoint)
    model.resize_token_embeddings(len(tokenizer))
    fp16 = False
    if torch.cuda.is_available():
        print("use cuda")
        fp16 = True
        model.to("cuda")

    model_dir = os.path.join(
        train_helper.get_model_folder(model_type),
        f"{datetime.now().strftime('%y-%m-%d-%H-%M')}")

    # Stage 1: train on emowoz for 4 epochs, evaluated with the full
    # gentus metrics; checkpoints go to the timestamped model_dir.
    args = Seq2SeqTrainingArguments(
        model_dir,
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        weight_decay=0.01,
        save_total_limit=2,
        num_train_epochs=4,
        predict_with_generate=True,
        fp16=fp16,
        push_to_hub=False,
        generation_max_length=max_target_length,
        logging_dir=os.path.join(model_dir, 'log')
    )
    data_collator = DataCollatorForSeq2Seq(
        tokenizer, model=model, padding=True)

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=data["emowoz"]["train"],
        eval_dataset=data["emowoz"]["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=gentus_compute_metrics)
    print("start training...")
    trainer.train()
    print("saving model...")
    trainer.save_model()

    # Stage 2: fine-tune the same model on dialmage for 1 epoch, scored
    # by BLEU only; checkpoints go to a sibling "_dialmage_fine_tune" folder.
    args = Seq2SeqTrainingArguments(
        model_dir + "_dialmage_fine_tune",
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        weight_decay=0.01,
        save_total_limit=2,
        num_train_epochs=1,
        predict_with_generate=True,
        fp16=fp16,
        push_to_hub=False,
        generation_max_length=max_target_length,
        logging_dir=os.path.join(model_dir, 'log')
    )
    data_collator = DataCollatorForSeq2Seq(
        tokenizer, model=model, padding=True)
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=data["dialmage"]["train"],
        eval_dataset=data["dialmage"]["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=basic_metric)
    print("start training...")
    trainer.train()
    print("saving model...")
    trainer.save_model()
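For orientation, the two stages write to sibling folders; a sketch of the naming (the `get_model_folder` return value and the model type are assumptions, not shown in this diff):

import os
from datetime import datetime

# assumption: get_model_folder(model_type) returns something like "model/<model_type>"
model_dir = os.path.join("model", "EmoUS", datetime.now().strftime('%y-%m-%d-%H-%M'))
print(model_dir)                          # stage 1 (emowoz) checkpoints
print(model_dir + "_dialmage_fine_tune")  # stage 2 (dialmage) checkpoints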
def main():
    args = arg_parser()
    print("---> data_name", args.data_name)
    if args.fine_tune:
        fine_tune_on_dialmage(
            model_type=args.model_type,
            data_name=args.data_name,
            dial_ids_order=args.dial_ids_order,
            split2ratio=args.split2ratio,
            batch_size=args.batch_size,
            max_input_length=MAX_IN_LEN,
            max_target_length=MAX_OUT_LEN,
            model_checkpoint=args.model_checkpoint
        )
    else:
        train(
            model_type=args.model_type,
            data_name=args.data_name,
            dial_ids_order=args.dial_ids_order,
            split2ratio=args.split2ratio,
            batch_size=args.batch_size,
            max_input_length=MAX_IN_LEN,
            max_target_length=MAX_OUT_LEN,
            model_checkpoint=args.model_checkpoint
        )
if __name__ == "__main__":
......
@@ -23,6 +23,7 @@ def arg_parser():
    parser.add_argument("--use-sentiment", action="store_true")
    parser.add_argument("--add-persona", action="store_true")
    parser.add_argument("--emotion-mid", action="store_true")
    parser.add_argument("--emotion-only", action="store_true")
    return parser.parse_args()
@@ -33,6 +34,7 @@ class DataBuilder(GenTUSDataBuilder):
        self.use_sentiment = kwargs.get("use_sentiment", False)
        self.emotion_mid = kwargs.get("emotion_mid", False)
        self.add_persona = kwargs.get("add_persona", False)
        self.emotion_only = kwargs.get("emotion_only", False)

        self.emotion = {}
        for emotion, index in json.load(open("convlab/policy/emoTUS/emotion.json")).items():
@@ -128,6 +130,9 @@ class DataBuilder(GenTUSDataBuilder):
                       "action": usr_act,
                       "text": text}
        elif not self.use_sentiment and not self.emotion_mid:
            if self.emotion_only:
                # emotion-only labels drop the action and text fields
                out_str = {"emotion": usr_emotion}
            else:
                out_str = {"emotion": usr_emotion,
                           "action": usr_act,
                           "text": text}
@@ -183,6 +188,9 @@ if __name__ == "__main__":
        dir_name = f"SentEmoUS_noPersona_{dir_name}"
    else:
        print("NOT DEFINED", use_sentiment, add_persona, emotion_mid)

    if args.emotion_only:
        dir_name = dir_name + '_emotion_only'

    print("dir_name", dir_name)
    folder_name = os.path.join(base_name, dir_name)
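This suffix is what yields folder names like the `EmoUS_dialmage_0_1_emotion_only` directory loaded by `remove_dialmage_action` above; the composition in isolation (the base name is illustrative):

# sketch: composing the data folder name with the new suffix
dir_name = "EmoUS_dialmage_0_1"
emotion_only = True
if emotion_only:
    dir_name = dir_name + '_emotion_only'
print(dir_name)  # EmoUS_dialmage_0_1_emotion_only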
@@ -197,7 +205,8 @@ if __name__ == "__main__":
        dataset=args.dataset,
        use_sentiment=use_sentiment,
        add_persona=add_persona,
        emotion_mid=emotion_mid,
        emotion_only=args.emotion_only)
    data = data_builder.setup_data(
        raw_data=dataset,
        random_order=False,
......
@@ -8,6 +8,7 @@ NOT_MENTIONED = "not mentioned"


def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2ratio=1):
    ratio = {'train': split2ratio, 'validation': split2ratio}
    print("data_name", data_name)
    if data_name == "all" or data_name == "sgd+tm" or data_name == "tm":
        print("merge all datasets...")
        if data_name == "all":
@@ -31,7 +32,7 @@ def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2ratio=1):
            datasets[name] = load_dataset(
                name, dial_ids_order=None)
        raw_data = merge_dataset(datasets, all_dataset[0])
    elif data_name in ["dialmage", "emowoz"]:
        raw_data = load_dataset(data_name, dial_ids_order=None)
    else:
......
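With the widened branch, `emowoz` now loads exactly like `dialmage`: a single dataset fetched with `dial_ids_order=None` and no split subsampling. A usage sketch, assuming this module's `load_experiment_dataset` is in scope:

# sketch: both corpora now take the single-dataset branch
for name in ["dialmage", "emowoz"]:
    raw_data = load_experiment_dataset(
        data_name=name, dial_ids_order=0, split2ratio=1)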