Skip to content
Snippets Groups Projects
Select Git revision
  • c2c24b4268a26d69da682a9fe6b63d75cbc140a3
  • master default protected
  • emoUS
  • add_default_vectorizer_and_pretrained_loading
  • clean_code
  • readme
  • issue127
  • generalized_action_dicts
  • ppo_num_dialogues
  • crossowoz_ddpt
  • issue_114
  • robust_masking_feature
  • scgpt_exp
  • e2e-soloist
  • convlab_exp
  • change_system_act_in_env
  • pre-training
  • nlg-scgpt
  • remapping_actions
  • soloist
20 results

env.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    training.py 16.54 KiB
    # -*- coding: utf-8 -*-
    # Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
    # Authors: Carel van Niekerk (niekerk@hhu.de)
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    """Training utils"""
    
    import random
    import os
    import logging
    
    import torch
    from torch.distributions import Categorical
    import numpy as np
    from transformers import AdamW, get_linear_schedule_with_warmup
    from tqdm import tqdm, trange
    
    from utils import clear_checkpoints, upload_local_directory_to_gcs
    
    
    # Register the module-level logger and tensorboard summary writer
    def set_logger(logger_, tb_writer_):
        """Store the given logger and TensorBoard writer as module globals.

        The other training utilities in this module report through these
        two globals rather than taking them as parameters.
        """
        global logger, tb_writer
        logger, tb_writer = logger_, tb_writer_
    
    
    # Seed all RNG sources
    def set_seed(args):
        """Seed python's, numpy's and torch's RNGs (incl. all CUDA devices) for reproducibility."""
        seed = args.seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Also seed every GPU when CUDA training is enabled
        if args.n_gpu > 0:
            torch.cuda.manual_seed_all(seed)
        logger.info('Seed set to %d.' % seed)
    
    
    # Linear learning rate scheduler
    def warmup_linear(progress, warmup=0.1):
        """Linear warmup followed by linear decay of the learning-rate multiplier.

        Args:
            progress: Fraction of training completed, expected in [0, 1].
            warmup: Fraction of training used for the linear warmup phase.

        Returns:
            A multiplier in [0, 1]: ramps 0 -> 1 over the first `warmup`
            fraction of training, then decays linearly 1 -> 0 afterwards.
        """
        # Past the end of training the multiplier is pinned at zero (also
        # avoids a zero division below when warmup == 1.0)
        if progress >= 1.0:
            return 0.0
        if progress < warmup:
            # Ramp up proportionally to progress through the warmup phase
            return progress / warmup
        # Decay from exactly 1.0 at the end of warmup down to 0.0 at the end
        # of training.  The previous `1.0 - progress` form jumped from ~1.0
        # down to `1.0 - warmup` at the warmup boundary and went negative for
        # progress > 1; this form is continuous and clamped at zero.
        return max((1.0 - progress) / (1.0 - warmup), 0.0)
    
    
    def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_dev):
        """Train model!

        Runs the optimisation loop with gradient accumulation, linear
        warmup/decay learning-rate scheduling, periodic dev-set evaluation,
        best-checkpoint saving (optionally mirrored to GCS) and early
        stopping on patience.  At the end the model is rolled back to the
        best checkpoint seen.

        Args:
            args: Hyperparameter/flag namespace (learning rate, epochs, fp16,
                save_steps, patience, output_dir, ...).
            model: Belief-tracking model exposing `named_parameters`,
                `add_value_candidates`, `save_pretrained`, state dict APIs.
            device: Torch device batches are moved to.
            train_dataloader: Dataloader yielding training batches.
            dev_dataloader: Dataloader yielding development batches.
            slots: Training ontology (slot -> (..., value candidates)).
            slots_dev: Dev ontology used during evaluation passes.
        """

        # Calculate the total number of training steps to be performed
        if args.max_training_steps > 0:
            t_total = args.max_training_steps
            args.num_train_epochs = args.max_training_steps // ((len(train_dataloader) // args.gradient_accumulation_steps) + 1)
        else:
            t_total = (len(train_dataloader) // args.gradient_accumulation_steps) * args.num_train_epochs

        if args.save_steps <= 0:
            args.save_steps = len(train_dataloader) // args.gradient_accumulation_steps

        # Group weight decay and no decay parameters in the model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
                "lr": args.learning_rate
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
                "lr": args.learning_rate
            },
        ]

        # Initialise the optimizer
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)

        # Load optimizer checkpoint if available
        if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')):
            logger.info("Optimizer loaded from previous run.")
            optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))

        # Log training set up
        logger.info("***** Running training *****")
        logger.info("  Num Batches = %d", len(train_dataloader))
        logger.info("  Num Epochs = %d", args.num_train_epochs)
        logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        # Initialise training parameters
        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        best_model = {'joint goal accuracy': 0.0,
                      'train loss': np.inf,
                      'state_dict': None}

        # Check if continuing training from a checkpoint
        if os.path.exists(args.model_name_or_path):
            try:
                # set global_step to gobal_step of last saved checkpoint from model path
                # (path is expected to end in "checkpoint-<global_step>")
                checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
                global_step = int(checkpoint_suffix)
                steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
                epochs_trained = global_step // steps_per_epoch
                steps_trained_in_current_epoch = global_step % steps_per_epoch

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                # Path does not contain a step suffix -> fresh fine-tuning run
                logger.info("  Starting fine-tuning.")

        # Prepare model for training
        tr_loss, logging_loss = 0.0, 0.0
        model.train()
        model.zero_grad()
        train_iterator = trange(
            epochs_trained, int(args.num_train_epochs), desc="Epoch"
        )

        steps_since_last_update = 0
        stop_training = False
        # Perform training
        for e in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            # Iterate over all batches
            for step, batch in enumerate(epoch_iterator):
                # Skip batches already trained on when resuming.  The counter
                # is decremented so only the FIRST epoch skips them (the old
                # `step < steps_trained_in_current_epoch` test skipped them in
                # every epoch since the counter was never reset).
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                # Get the dialogue information from the batch
                input_ids = batch['input_ids'].to(device)
                token_type_ids = batch['token_type_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                # Get labels
                labels = {slot: batch['labels-' + slot].to(device) for slot in slots}

                # Perform forward pass
                loss, _ = model(input_ids=input_ids,
                                token_type_ids=token_type_ids,
                                attention_mask=attention_mask,
                                labels=labels)

                if args.n_gpu > 1:
                    loss = loss.mean()

                # Scale so gradients accumulated over several batches average
                # to one optimisation step, and backpropagate EVERY batch.
                # (Previously backward() only ran on update steps, silently
                # discarding the gradients of all other accumulation batches.)
                loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    # NOTE(review): `scale_loss` is referenced without ever
                    # being imported in this file (NameError at runtime);
                    # presumably NVIDIA apex's `amp.scale_loss` was intended —
                    # confirm before enabling fp16.
                    with scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                        tb_writer.add_scalar('Scaled_Loss/train', scaled_loss, global_step)
                else:
                    loss.backward()
                tr_loss += loss.float().item()

                # Optimiser update once a full set of accumulation batches is done
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    tb_writer.add_scalar('Loss/train', loss, global_step)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                    # Linear warmup/decay learning rate schedule
                    lr = args.learning_rate * warmup_linear(global_step / t_total, args.warmup_proportion)
                    tb_writer.add_scalar('LearningRate', lr, global_step)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr

                    optimizer.step()
                    model.zero_grad()
                    global_step += 1

                    # Save model checkpoint.  Nested inside the update branch so
                    # the same global_step cannot retrigger a save on the
                    # in-between accumulation batches.
                    if global_step % args.save_steps == 0:
                        logging_loss = tr_loss - logging_loss

                        # Evaluate model
                        if args.do_eval:
                            # Set up model for evaluation with the dev ontology
                            model.eval()
                            values = {slot: slots_dev[slot][1] for slot in slots_dev}
                            for slot in slots:
                                if slot in values:
                                    model.add_value_candidates(slot, values[slot], replace=True)

                            jg_acc, sl_acc = train_eval(args, model, device, dev_dataloader, slots_dev)
                            logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f' \
                                        % (global_step, logging_loss / args.save_steps, jg_acc, sl_acc))
                            tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step)
                            tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step)

                            # Set model back to training mode with the training ontology
                            model.train()
                            model.zero_grad()
                            values = {slot: slots[slot][1] for slot in slots}
                            for slot in values:
                                model.add_value_candidates(slot, values[slot], replace=True)
                        else:
                            jg_acc = 0.0
                            logger.info('%i steps complete, Loss since last update = %f' % (global_step, logging_loss / args.save_steps))

                        logging_loss = tr_loss
                        # Keep this checkpoint if dev JGA improved, or (when no
                        # evaluation is run) if the mean training loss improved
                        if best_model['joint goal accuracy'] < jg_acc and jg_acc > 0.0:
                            update = True
                        elif best_model['train loss'] > (tr_loss / global_step) and best_model['joint goal accuracy'] == 0.0:
                            update = True
                        else:
                            update = False

                        if update:
                            steps_since_last_update = 0
                            logger.info('Model saved.')
                            best_model['joint goal accuracy'] = jg_acc
                            best_model['train loss'] = tr_loss / global_step

                            output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)

                            model.save_pretrained(output_dir)
                            best_model['state_dict'] = model.state_dict()

                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))

                            clear_checkpoints(args.output_dir)
                            if args.gcs_bucket_name:
                                remote = os.path.join(os.path.basename(args.output_dir), "checkpoint-{}".format(global_step))
                                upload_local_directory_to_gcs(output_dir, args.gcs_bucket_name, remote)
                        else:
                            steps_since_last_update += 1
                            logger.info('Model not saved.')

                # Stop training after max training steps or if the model has not updated for too long
                if args.max_training_steps > 0 and global_step > args.max_training_steps:
                    stop_training = True
                if args.patience > 0 and steps_since_last_update >= args.patience:
                    # This message previously sat AFTER the break and was unreachable
                    logger.info('Model has not improved for atleast %i steps. Training stopped!' % args.patience)
                    stop_training = True
                if stop_training:
                    epoch_iterator.close()
                    break

            # max(1, ...) guards against division by zero when the epoch ran no updates
            logger.info('Epoch %i complete, average training loss = %f' % (e + 1, tr_loss / max(global_step, 1)))
            if stop_training:
                train_iterator.close()
                break

        # Evaluate final model
        if args.do_eval:
            model.eval()
            values = {slot: slots_dev[slot][1] for slot in slots_dev}
            for slot in slots:
                if slot in values:
                    model.add_value_candidates(slot, values[slot], replace=True)

            jg_acc, sl_acc = train_eval(args, model, device, dev_dataloader, slots_dev)
            logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f' \
                        % (tr_loss / max(global_step, 1), jg_acc, sl_acc))
        else:
            jg_acc = 0.0
            logger.info('Training complete!')

        # Store final model if it is the best seen so far
        if best_model['joint goal accuracy'] < jg_acc and jg_acc > 0.0:
            update = True
        elif best_model['train loss'] > (tr_loss / max(global_step, 1)) and best_model['joint goal accuracy'] == 0.0:
            update = True
        else:
            update = False

        if update:
            logger.info('Final model saved.')
            output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

            model.save_pretrained(output_dir)
            best_model['state_dict'] = model.state_dict()

            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            clear_checkpoints(args.output_dir)
        else:
            logger.info('Final model not saved, since it is not the best performing model.')

        # At end of training roll back to best performing model (if any was
        # ever stored; load_state_dict(None) would raise)
        if best_model['state_dict'] is not None:
            model.load_state_dict(best_model['state_dict'])
    
    
    # Function for validation
    def train_eval(args, model, device, dev_dataloader, slots):
        """Evaluate Model during training!

        Runs the model over the dev set and returns the joint goal accuracy
        (every slot correct in a turn) and the mean slot accuracy, both
        normalised by the number of valid (label >= 0) turns.
        """
        jg_scores, sl_scores, turn_counts = [], [], []
        # Perform with no gradients stored
        with torch.no_grad():
            for batch in dev_dataloader:
                input_ids = batch['input_ids'].to(device)
                token_type_ids = batch['token_type_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                predictions = model(input_ids, token_type_ids, attention_mask)

                # Accumulate a per-turn count of correctly predicted slots
                hits = 0.0
                for slot in slots:
                    labels = batch['labels-' + slot].to(device)
                    correct = (predictions[slot].argmax(-1) == labels).reshape(-1).float()
                    hits = hits + correct

                # Slot accuracy: mean fraction of correct slots per turn.
                # Joint accuracy: int() truncation keeps only turns where ALL
                # slots were correct (fraction == 1.0).
                per_turn = hits / len(slots)
                sl_scores.append(sum(per_turn).float().item())
                jg_scores.append(sum(per_turn.int()).float().item())
                # Turn count taken from the last slot's labels — assumes all
                # slots share the same padding mask (TODO confirm)
                turn_counts.append((labels >= 0).reshape(-1).sum().float().item())

        # Global accuracy reduction across batches
        total_turns = sum(turn_counts)
        return sum(jg_scores) / total_turns, sum(sl_scores) / total_turns
    
    
    def evaluate(args, model, device, dataloader, slots):
        """Evaluate Model!

        Runs the model over `dataloader`, applies optional temperature scaling
        to the predicted per-slot distributions and computes accuracy either
        exactly (argmax), by sampling, or by top-n membership.

        Args:
            args: Namespace with `temp_scaling`, `accuracy_samples`, `accuracy_topn`.
            model: Model returning (loss, per-slot probability dict) when given labels.
            device: Torch device batches are moved to.
            dataloader: Dataloader yielding evaluation batches.
            slots: Iterable of slot names to score.

        Returns:
            (joint goal accuracy, slot accuracy, mean loss per batch,
             dict of concatenated temperature-scaled probabilities per slot)
        """
        # Evaluate!
        logger.info("***** Running evaluation *****")
        logger.info("  Num Batches = %d", len(dataloader))

        tr_loss = 0.0
        model.eval()

        logits = {slot: [] for slot in slots}
        accuracy_jg = []
        accuracy_sl = []
        turns = []
        # A temperature of 1.0 leaves the (renormalised) distribution unchanged;
        # this replaces the previous duplicated if/else that divided by 1.0.
        temp = args.temp_scaling if args.temp_scaling > 0.0 else 1.0
        epoch_iterator = tqdm(dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            with torch.no_grad():
                input_ids = batch['input_ids'].to(device)
                token_type_ids = batch['token_type_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                labels = {slot: batch['labels-' + slot].to(device) for slot in slots}

                loss, p = model(input_ids=input_ids,
                                token_type_ids=token_type_ids,
                                attention_mask=attention_mask,
                                labels=labels)

                jg_acc = 0.0
                for slot in slots:
                    labels = batch['labels-' + slot].to(device)

                    # Temperature-scale the probabilities: softmax of the
                    # scaled log-probs (epsilon guards log(0))
                    p_ = torch.log(p[slot] + 1e-10) / temp
                    p_ = torch.softmax(p_, -1)

                    logits[slot].append(p_)

                    if args.accuracy_samples > 0:
                        # Sampled accuracy: correct if the label appears among N draws
                        dist = Categorical(probs=p_.reshape(-1, p_.size(-1)))
                        lab_sample = dist.sample((args.accuracy_samples,))
                        lab_sample = lab_sample.transpose(0, 1)
                        acc = [lab in s for lab, s in zip(labels.reshape(-1), lab_sample)]
                        acc = torch.tensor(acc).float()
                    elif args.accuracy_topn > 0:
                        # Top-n accuracy: correct if the label is among the n most probable values
                        labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True)
                        labs = labs[:, :args.accuracy_topn]
                        acc = [lab in s for lab, s in zip(labels.reshape(-1), labs)]
                        acc = torch.tensor(acc).float()
                    else:
                        # Exact accuracy via argmax
                        acc = (p_.argmax(-1) == labels).reshape(-1).float()

                    jg_acc += acc
                # Slot accuracy: mean correct-slot fraction per turn;
                # joint accuracy: int() truncation keeps only all-correct turns
                sl_acc = sum(jg_acc / len(slots)).float()
                jg_acc = sum((jg_acc / len(slots)).int()).float()
                # Valid turn count from the last slot's labels — assumes a
                # shared padding mask across slots (TODO confirm)
                n_turns = (labels >= 0).reshape(-1).sum().float().item()

                accuracy_jg.append(jg_acc.item())
                accuracy_sl.append(sl_acc.item())
                turns.append(n_turns)
                tr_loss += loss.item()

        for slot in logits:
            logits[slot] = torch.cat(logits[slot], 0)

        # Global accuracy reduction across batches
        turns = sum(turns)
        jg_acc = sum(accuracy_jg) / turns
        sl_acc = sum(accuracy_sl) / turns

        return jg_acc, sl_acc, tr_loss / (step + 1), logits