train.py
    import argparse
    import json
    import os
    import random
    import zipfile

    import numpy as np
    import torch
    from torch.utils.tensorboard import SummaryWriter
    from transformers import AdamW, get_linear_schedule_with_warmup

    from convlab2.nlu.jointBERT.dataloader import Dataloader
    from convlab2.nlu.jointBERT.jointBERT import JointBERT
    
    
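    # Seed Python's, NumPy's and PyTorch's RNGs so that training runs are reproducible.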
    def set_seed(seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
    
    
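    # The script is driven entirely by a JSON config file; the only CLI argument is its path.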
    parser = argparse.ArgumentParser(description="Train a model.")
    parser.add_argument('--config_path',
                        help='path to config file')
    
    
    if __name__ == '__main__':
        args = parser.parse_args()
        config = json.load(open(args.config_path))
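        # Top-level config keys read by this script: data_dir, output_dir, log_dir,
        # DEVICE, seed, cut_sen_len, use_bert_tokenizer, zipped_model_path (plus
        # dataset_name for unified_datasets), and a nested 'model' section with
        # pretrained_weights, finetune, weight_decay, learning_rate, adam_epsilon,
        # warmup_steps, max_step, check_step, batch_size and context.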
        data_dir = config['data_dir']
        output_dir = config['output_dir']
        log_dir = config['log_dir']
        DEVICE = config['DEVICE']
    
        set_seed(config['seed'])
    
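        # The dataset family is inferred from the data_dir path, and the matching
        # postprocess helpers (is_slot_da, calculateF1, recover_intent) are imported
        # for that dataset.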
        if 'unified_datasets' in data_dir:
            dataset_name = config['dataset_name']
            print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
            from convlab2.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
        elif 'multiwoz' in data_dir:
            print('-' * 20 + 'dataset:multiwoz' + '-' * 20)
            from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
        elif 'camrest' in data_dir:
            print('-' * 20 + 'dataset:camrest' + '-' * 20)
            from convlab2.nlu.jointBERT.camrest.postprocess import is_slot_da, calculateF1, recover_intent
        elif 'crosswoz' in data_dir:
            print('-' * 20 + 'dataset:crosswoz' + '-' * 20)
            from convlab2.nlu.jointBERT.crosswoz.postprocess import is_slot_da, calculateF1, recover_intent
        else:
            raise ValueError('unsupported dataset in data_dir: {}'.format(data_dir))
    
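        # Build the dataloader from the intent/tag vocabularies and the pretrained BERT
        # weights named in the config, then load each split, capping sentence length at
        # cut_sen_len and tokenizing with the BERT tokenizer when use_bert_tokenizer is set.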
        intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json')))
        tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))
        dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab,
                                pretrained_weights=config['model']['pretrained_weights'])
        print('intent num:', len(intent_vocab))
        print('tag num:', len(tag_vocab))
        for data_key in ['train', 'val', 'test']:
            dataloader.load_data(json.load(open(os.path.join(data_dir, '{}_data.json'.format(data_key)))), data_key,
                                 cut_sen_len=config['cut_sen_len'], use_bert_tokenizer=config['use_bert_tokenizer'])
            print('{} set size: {}'.format(data_key, len(dataloader.data[data_key])))
    
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
    
        writer = SummaryWriter(log_dir)
    
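        # JointBERT predicts slot tags and utterance intents jointly; its tag and intent
        # dimensions (and the intent weights) come from the dataloader's vocabularies.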
        model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim, dataloader.intent_weight)
        model.to(DEVICE)
    
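        # Two optimization regimes: when finetuning, AdamW is used over all trainable
        # parameters with weight decay disabled for biases and LayerNorm weights, plus
        # a linear warmup schedule; otherwise every parameter with 'bert' in its name
        # is frozen and only the remaining heads are trained with plain Adam.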
        if config['model']['finetune']:
            no_decay = ['bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [
                {'params': [p for n, p in model.named_parameters() if
                            not any(nd in n for nd in no_decay) and p.requires_grad],
                 'weight_decay': config['model']['weight_decay']},
                {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
                 'weight_decay': 0.0}
            ]
            optimizer = AdamW(optimizer_grouped_parameters, lr=config['model']['learning_rate'],
                              eps=config['model']['adam_epsilon'])
            scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config['model']['warmup_steps'],
                                                        num_training_steps=config['model']['max_step'])
        else:
            for n, p in model.named_parameters():
                if 'bert' in n:
                    p.requires_grad = False
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                         lr=config['model']['learning_rate'])
    
        for name, param in model.named_parameters():
            print(name, param.shape, param.device, param.requires_grad)
    
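        # Training is step-based rather than epoch-based: run max_step updates and
        # evaluate on the validation set every check_step steps.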
        max_step = config['model']['max_step']
        check_step = config['model']['check_step']
        batch_size = config['model']['batch_size']
        model.zero_grad()
        train_slot_loss, train_intent_loss = 0, 0
        best_val_f1 = 0.
    
        writer.add_text('config', json.dumps(config))
    
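        # Each step draws a training batch, drops the context tensors if context is
        # disabled, computes slot and intent losses in one forward pass, sums them,
        # clips gradients at 1.0 and steps the optimizer (and the LR scheduler when
        # finetuning).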
        for step in range(1, max_step + 1):
            model.train()
            batched_data = dataloader.get_train_batch(batch_size)
            batched_data = tuple(t.to(DEVICE) for t in batched_data)
            word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = batched_data
            if not config['model']['context']:
                context_seq_tensor, context_mask_tensor = None, None
            _, _, slot_loss, intent_loss = model.forward(word_seq_tensor, word_mask_tensor, tag_seq_tensor, tag_mask_tensor,
                                                         intent_tensor, context_seq_tensor, context_mask_tensor)
            train_slot_loss += slot_loss.item()
            train_intent_loss += intent_loss.item()
            loss = slot_loss + intent_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if config['model']['finetune']:
                scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            if step % check_step == 0:
                train_slot_loss = train_slot_loss / check_step
                train_intent_loss = train_intent_loss / check_step
                print('[%d|%d] step' % (step, max_step))
                print('\t slot loss:', train_slot_loss)
                print('\t intent loss:', train_intent_loss)
    
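                # Validation: iterate over the val split, accumulate losses weighted by
                # the real (unpadded) batch size, recover dialogue acts from the logits
                # with recover_intent, and use is_slot_da to split predictions and gold
                # labels into slot and intent subsets for separate F1 scores.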
                predict_golden = {'intent': [], 'slot': [], 'overall': []}
    
                val_slot_loss, val_intent_loss = 0, 0
                model.eval()
                for pad_batch, ori_batch, real_batch_size in dataloader.yield_batches(batch_size, data_key='val'):
                    pad_batch = tuple(t.to(DEVICE) for t in pad_batch)
                    word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch
                    if not config['model']['context']:
                        context_seq_tensor, context_mask_tensor = None, None
    
                    with torch.no_grad():
                        slot_logits, intent_logits, slot_loss, intent_loss = model.forward(word_seq_tensor,
                                                                                           word_mask_tensor,
                                                                                           tag_seq_tensor,
                                                                                           tag_mask_tensor,
                                                                                           intent_tensor,
                                                                                           context_seq_tensor,
                                                                                           context_mask_tensor)
                    val_slot_loss += slot_loss.item() * real_batch_size
                    val_intent_loss += intent_loss.item() * real_batch_size
                    for j in range(real_batch_size):
                        predicts = recover_intent(dataloader, intent_logits[j], slot_logits[j], tag_mask_tensor[j],
                                                  ori_batch[j][0], ori_batch[j][-4])
                        labels = ori_batch[j][3]
    
                        predict_golden['overall'].append({
                            'predict': predicts,
                            'golden': labels
                        })
                        if isinstance(predicts, dict):
                            predict_golden['slot'].append({
                                'predict': {k:v for k, v in predicts.items() if is_slot_da(k)},
                                'golden': {k:v for k, v in labels.items() if is_slot_da(k)}
                            })
                            predict_golden['intent'].append({
                                'predict': {k:v for k, v in predicts.items() if not is_slot_da(k)},
                                'golden': {k:v for k, v in labels.items() if not is_slot_da(k)}
                            })
                        else:
                            assert isinstance(predicts, list)
                            predict_golden['slot'].append({
                                'predict': [x for x in predicts if is_slot_da(x)],
                                'golden': [x for x in labels if is_slot_da(x)]
                            })
                            predict_golden['intent'].append({
                                'predict': [x for x in predicts if not is_slot_da(x)],
                                'golden': [x for x in labels if not is_slot_da(x)]
                            })
    
                # Log up to 10 validation samples to TensorBoard for qualitative inspection.
                for j in range(min(10, len(predict_golden['overall']))):
                    writer.add_text('val_sample_{}'.format(j),
                                    json.dumps(predict_golden['overall'][j], indent=2, ensure_ascii=False),
                                    global_step=step)
    
                total = len(dataloader.data['val'])
                val_slot_loss /= total
                val_intent_loss /= total
                print('%d samples val' % total)
                print('\t slot loss:', val_slot_loss)
                print('\t intent loss:', val_intent_loss)
    
                writer.add_scalar('intent_loss/train', train_intent_loss, global_step=step)
                writer.add_scalar('intent_loss/val', val_intent_loss, global_step=step)
    
                writer.add_scalar('slot_loss/train', train_slot_loss, global_step=step)
                writer.add_scalar('slot_loss/val', val_slot_loss, global_step=step)
    
                for x in ['intent', 'slot', 'overall']:
                    precision, recall, F1 = calculateF1(predict_golden[x])
                    print('-' * 20 + x + '-' * 20)
                    print('\t Precision: %.2f' % (100 * precision))
                    print('\t Recall: %.2f' % (100 * recall))
                    print('\t F1: %.2f' % (100 * F1))
    
                    writer.add_scalar('val_{}/precision'.format(x), precision, global_step=step)
                    writer.add_scalar('val_{}/recall'.format(x), recall, global_step=step)
                    writer.add_scalar('val_{}/F1'.format(x), F1, global_step=step)
    
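                # After the metrics loop, F1 holds the 'overall' score (the last key
                # iterated), so the checkpoint with the best overall F1 is what gets
                # saved to output_dir as pytorch_model.bin.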
                if F1 > best_val_f1:
                    best_val_f1 = F1
                    torch.save(model.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
                    print('best val F1 %.4f' % best_val_f1)
                    print('save on', output_dir)
    
                train_slot_loss, train_intent_loss = 0, 0
    
        writer.add_text('val overall F1', '%.2f' % (100 * best_val_f1))
        writer.close()
    
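        # Package the best checkpoint into a zip archive at zipped_model_path,
        # presumably for distribution and loading elsewhere in ConvLab-2. Note that
        # zf.write(model_path) stores the file under its full relative path inside
        # the archive, since no arcname is given.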
        model_path = os.path.join(output_dir, 'pytorch_model.bin')
        zip_path = config['zipped_model_path']
        print('zip model to', zip_path)
    
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            zf.write(model_path)