# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training utils"""
import random
import os
import logging
import torch
from torch.distributions import Categorical
import numpy as np
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
try:
    # NVIDIA apex supplies the scale_loss context manager used for fp16 training below
    from apex.amp import scale_loss
except ImportError:
    scale_loss = None
from utils import clear_checkpoints, upload_local_directory_to_gcs
# Load logger and tensorboard summary writer
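# set_logger must be called before train or evaluate below, since both rely on the
# module-level logger and tensorboard writer assigned here.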
def set_logger(logger_, tb_writer_):
global logger, tb_writer
logger = logger_
tb_writer = tb_writer_
# Set seeds
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
logger.info('Seed set to %d.' % args.seed)
# Linear learning rate scheduler
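# The returned multiplier ramps linearly from 0 to 1 over the first `warmup` fraction
# of training progress, then decays linearly to 0 as progress approaches 1.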
def warmup_linear(progress, warmup=0.1):
if progress < warmup:
return progress/warmup
else:
return 1.0 - progress
def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_dev):
"""Train model!"""
# Calculate the total number of training steps to be performed
if args.max_training_steps > 0:
t_total = args.max_training_steps
args.num_train_epochs = args.max_training_steps // ((len(train_dataloader) // args.gradient_accumulation_steps) + 1)
else:
t_total = (len(train_dataloader) // args.gradient_accumulation_steps) * args.num_train_epochs
if args.save_steps <= 0:
args.save_steps = len(train_dataloader) // args.gradient_accumulation_steps
# Group weight decay and no decay parameters in the model
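# Biases and LayerNorm weights are excluded from weight decay, following standard
# practice for fine-tuning BERT-style transformer models.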
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
"lr": args.learning_rate
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
"lr":args.learning_rate
},
]
# Initialise the optimizer
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
# Load optimizer checkpoint if available
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')):
logger.info("Optimizer loaded from previous run.")
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
# Log training set up
logger.info("***** Running training *****")
logger.info(" Num Batches = %d", len(train_dataloader))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
# Initialise training parameters
global_step = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
best_model = {'joint goal accuracy': 0.0,
'train loss': np.inf,
'state_dict': None}
# Check if continuing training from a checkpoint
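# Checkpoint directories are named "checkpoint-<global_step>" (see the save logic
# below), so the completed step count can be recovered from the path suffix.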
if os.path.exists(args.model_name_or_path):
try:
# Set global_step to the global step of the last saved checkpoint from the model path
checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
global_step = int(checkpoint_suffix)
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError:
logger.info(" Starting fine-tuning.")
# Prepare model for training
tr_loss, logging_loss = 0.0, 0.0
model.train()
model.zero_grad()
train_iterator = trange(
epochs_trained, int(args.num_train_epochs), desc="Epoch"
)
steps_since_last_update = 0
# Perform training
for e in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration")
# Iterate over all batches
for step, batch in enumerate(epoch_iterator):
# Skip batches already trained on
if step < steps_trained_in_current_epoch:
continue
# Get the dialogue information from the batch
input_ids = batch['input_ids'].to(device)
token_type_ids = batch['token_type_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
# Get labels
labels = {slot: batch['labels-' + slot].to(device) for slot in slots}
# Perform forward pass
loss, _ = model(input_ids = input_ids,
token_type_ids = token_type_ids,
attention_mask = attention_mask,
labels = labels)
if args.n_gpu > 1:
loss = loss.mean()
# Update step
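# An optimizer update is performed once every gradient_accumulation_steps batches,
# with the loss scaled down by the same factor before backpropagation.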
if step % args.gradient_accumulation_steps == 0:
loss = loss / args.gradient_accumulation_steps
tb_writer.add_scalar('Loss/train', loss, global_step)
# Backpropagate accumulated loss (fp16 assumes apex amp has been initialised on the
# model and optimizer by the caller)
if args.fp16:
with scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
tb_writer.add_scalar('Scaled_Loss/train', scaled_loss, global_step)
else:
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
# Get learning rate
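# The learning rate follows the warmup_linear schedule defined above and is written
# directly into the optimizer's parameter groups rather than via a scheduler object.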
lr = args.learning_rate * warmup_linear(global_step / t_total, args.warmup_proportion)
tb_writer.add_scalar('LearningRate', lr, global_step)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
optimizer.step()
model.zero_grad()
tr_loss += loss.float().item()
loss = 0.0
global_step += 1
# Save model checkpoint
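# Every save_steps updates the model is evaluated on the dev set (when args.do_eval is
# set) and checkpointed only if it improves on the best model seen so far, measured by
# dev joint goal accuracy or, if evaluation is disabled, by average training loss.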
if global_step % args.save_steps == 0:
logging_loss = tr_loss - logging_loss
# Evaluate model
if args.do_eval:
# Set up model for evaluation
model.eval()
values = {slot: slots_dev[slot][1] for slot in slots_dev}
for slot in slots:
if slot in values:
model.add_value_candidates(slot, values[slot], replace=True)
jg_acc, sl_acc = train_eval(args, model, device, dev_dataloader, slots_dev)
logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f' \
% (global_step, logging_loss / args.save_steps, jg_acc, sl_acc))
tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)
# Set model back to training mode
model.train()
model.zero_grad()
values = {slot: slots[slot][1] for slot in slots}
for slot in values:
model.add_value_candidates(slot, values[slot], replace=True)
else:
jg_acc = 0.0
logger.info('%i steps complete, Loss since last update = %f' % (global_step, logging_loss / args.save_steps))
logging_loss = tr_loss
if best_model['joint goal accuracy'] < jg_acc and jg_acc > 0.0:
update = True
elif best_model['train loss'] > (tr_loss / global_step) and best_model['joint goal accuracy'] == 0.0:
update = True
else:
update = False
if update:
steps_since_last_update = 0
logger.info('Model saved.')
best_model['joint goal accuracy'] = jg_acc
best_model['train loss'] = tr_loss / global_step
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model.save_pretrained(output_dir)
best_model['state_dict'] = model.state_dict()
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
clear_checkpoints(args.output_dir)
if args.gcs_bucket_name:
remote = os.path.join(os.path.basename(args.output_dir), "checkpoint-{}".format(global_step))
upload_local_directory_to_gcs(output_dir, args.gcs_bucket_name, remote)
else:
steps_since_last_update += 1
logger.info('Model not saved.')
# Stop training after max training steps or if the model has not updated for too long
if args.max_training_steps > 0 and global_step > args.max_training_steps:
epoch_iterator.close()
break
if args.patience > 0 and steps_since_last_update >= args.patience:
epoch_iterator.close()
break
logger.info('Epoch %i complete, average training loss = %f' % (e + 1, tr_loss / global_step))
if args.max_training_steps > 0 and global_step > args.max_training_steps:
train_iterator.close()
break
if args.patience > 0 and steps_since_last_update >= args.patience:
train_iterator.close()
break
logger.info('Model has not improved for at least %i steps. Training stopped!' % args.patience)
# Evaluate final model
if args.do_eval:
model.eval()
values = {slot: slots_dev[slot][1] for slot in slots_dev}
for slot in slots:
if slot in values:
model.add_value_candidates(slot, values[slot], replace=True)
jg_acc, sl_acc = train_eval(args, model, device, dev_dataloader, slots_dev)
logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f' \
% (tr_loss / global_step, jg_acc, sl_acc))
else:
jg_acc = 0.0
logger.info('Training complete!')
# Store final model
if best_model['joint goal accuracy'] < jg_acc and jg_acc > 0.0:
update = True
elif best_model['train loss'] > (tr_loss / global_step) and best_model['joint goal accuracy'] == 0.0:
update = True
else:
update = False
if update:
logger.info('Final model saved.')
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model.save_pretrained(output_dir)
best_model['state_dict'] = model.state_dict()
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
clear_checkpoints(args.output_dir)
else:
logger.info('Final model not saved, since it is not the best performing model.')
# At end of training roll back to best performing model
model.load_state_dict(best_model['state_dict'])
# Function for validation
def train_eval(args, model, device, dev_dataloader, slots):
"""Evaluate Model during training!"""
accuracy_jg = []
accuracy_sl = []
turns = []
for step, batch in enumerate(dev_dataloader):
# Run the forward pass without tracking gradients
with torch.no_grad():
input_ids = batch['input_ids'].to(device)
token_type_ids = batch['token_type_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
jg_acc = 0.0
p = model(input_ids, token_type_ids, attention_mask)
for slot in slots:
labels = batch['labels-' + slot].to(device)
p_ = p[slot]
acc = (p_.argmax(-1) == labels).reshape(-1).float()
jg_acc += acc
sl_acc = sum(jg_acc / len(slots)).float()
jg_acc = sum((jg_acc / len(slots)).int()).float()
n_turns = (labels >= 0).reshape(-1).sum().float().item()
accuracy_jg.append(jg_acc.item())
accuracy_sl.append(sl_acc.item())
turns.append(n_turns)
# Global accuracy reduction across batches
turns = sum(turns)
jg_acc = sum(accuracy_jg) / turns
sl_acc = sum(accuracy_sl) / turns
return jg_acc, sl_acc
def evaluate(args, model, device, dataloader, slots):
"""Evaluate Model!"""
# Evaluate!
logger.info("***** Running evaluation *****")
logger.info(" Num Batches = %d", len(dataloader))
tr_loss = 0.0
model.eval()
logits = {slot: [] for slot in slots}
accuracy_jg = []
accuracy_sl = []
turns = []
epoch_iterator = tqdm(dataloader, desc="Iteration")
for step, batch in enumerate(epoch_iterator):
with torch.no_grad():
input_ids = batch['input_ids'].to(device)
token_type_ids = batch['token_type_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = {slot: batch['labels-' + slot].to(device) for slot in slots}
loss, p = model(input_ids = input_ids,
token_type_ids = token_type_ids,
attention_mask = attention_mask,
labels = labels)
jg_acc = 0.0
for slot in slots:
p_ = p[slot]
labels = batch['labels-' + slot].to(device)
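# Temperature scaling: the predicted distribution is flattened or sharpened by
# dividing the log-probabilities by args.temp_scaling before re-normalising.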
if args.temp_scaling > 0.0:
p_ = torch.log(p_ + 1e-10) / args.temp_scaling
p_ = torch.softmax(p_, -1)
else:
p_ = torch.log(p_ + 1e-10) / 1.0
p_ = torch.softmax(p_, -1)
logits[slot].append(p_)
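# Accuracy variants: sample labels from the predictive distribution and count a hit if
# the true label appears among the samples (accuracy_samples), take the top-n most
# probable labels (accuracy_topn), or fall back to standard argmax accuracy.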
if args.accuracy_samples > 0:
dist = Categorical(probs=p_.reshape(-1, p_.size(-1)))
lab_sample = dist.sample((args.accuracy_samples,))
lab_sample = lab_sample.transpose(0, 1)
acc = [lab in s for lab, s in zip(labels.reshape(-1), lab_sample)]
acc = torch.tensor(acc).float()
elif args.accuracy_topn > 0:
labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True)
labs = labs[:, :args.accuracy_topn]
acc = [lab in s for lab, s in zip(labels.reshape(-1), labs)]
acc = torch.tensor(acc).float()
else:
acc = (p_.argmax(-1) == labels).reshape(-1).float()
jg_acc += acc
sl_acc = sum(jg_acc / len(slots)).float()
jg_acc = sum((jg_acc / len(slots)).int()).float()
n_turns = (labels >= 0).reshape(-1).sum().float().item()
accuracy_jg.append(jg_acc.item())
accuracy_sl.append(sl_acc.item())
turns.append(n_turns)
tr_loss += loss.item()
for slot in logits:
logits[slot] = torch.cat(logits[slot], 0)
# Global accuracy reduction across batches
turns = sum(turns)
jg_acc = sum(accuracy_jg) / turns
sl_acc = sum(accuracy_sl) / turns
return jg_acc, sl_acc, tr_loss / (step + 1), logits