test.py
    import argparse
    import os
    import json
    import random
    import numpy as np
    import torch
    from convlab2.nlu.jointBERT.dataloader import Dataloader
    from convlab2.nlu.jointBERT.jointBERT import JointBERT
    
    
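    # Seed Python, NumPy, and PyTorch RNGs so the evaluation is reproducible.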
    def set_seed(seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
    
    
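    # Command-line interface: the only argument is the path to the JSON config
    # that was used during training.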
    parser = argparse.ArgumentParser(description="Test a model.")
    parser.add_argument('--config_path',
                        required=True,
                        help='path to config file')
    
    
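    # Entry point: load the config, restore the trained checkpoint, and score
    # it on the test split.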
    if __name__ == '__main__':
        args = parser.parse_args()
        config = json.load(open(args.config_path))
        data_dir = config['data_dir']
        output_dir = config['output_dir']
        log_dir = config['log_dir']
        DEVICE = config['DEVICE']
    
        set_seed(config['seed'])
    
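        # Each supported corpus ships its own postprocessing helpers (slot/intent
        # classification of dialogue acts, F1 computation, logit decoding); pick
        # the module that matches the data directory.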
        if 'unified_datasets' in data_dir:
            dataset_name = config['dataset_name']
            print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
            from convlab2.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
        elif 'multiwoz' in data_dir:
            print('-' * 20 + 'dataset:multiwoz' + '-' * 20)
            from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
        elif 'camrest' in data_dir:
            print('-' * 20 + 'dataset:camrest' + '-' * 20)
            from convlab2.nlu.jointBERT.camrest.postprocess import is_slot_da, calculateF1, recover_intent
        elif 'crosswoz' in data_dir:
            print('-' * 20 + 'dataset:crosswoz' + '-' * 20)
            from convlab2.nlu.jointBERT.crosswoz.postprocess import is_slot_da, calculateF1, recover_intent
        else:
            raise ValueError('cannot infer dataset from data_dir: {}'.format(data_dir))
    
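        # Rebuild the dataloader with the saved vocabularies and the same
        # pretrained weights used during training, so label indices line up.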
        intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json')))
        tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))
        dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab,
                                pretrained_weights=config['model']['pretrained_weights'])
        print('intent num:', len(intent_vocab))
        print('tag num:', len(tag_vocab))
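        # Load the val and test splits; only the test split is scored below.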
        for data_key in ['val', 'test']:
            dataloader.load_data(json.load(open(os.path.join(data_dir, '{}_data.json'.format(data_key)))), data_key,
                                 cut_sen_len=0, use_bert_tokenizer=config['use_bert_tokenizer'])
            print('{} set size: {}'.format(data_key, len(dataloader.data[data_key])))
    
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(log_dir, exist_ok=True)
    
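        # Restore the fine-tuned checkpoint; torch.load's second positional
        # argument is map_location, so tensors are loaded straight onto DEVICE.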
        model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim)
        model.load_state_dict(torch.load(os.path.join(output_dir, 'pytorch_model.bin'), DEVICE))
        model.to(DEVICE)
        model.eval()
    
        batch_size = config['model']['batch_size']
    
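        # Collect (prediction, golden) pairs, split three ways: intent-type
        # dialogue acts, slot-type dialogue acts, and everything combined.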
        data_key = 'test'
        predict_golden = {'intent': [], 'slot': [], 'overall': []}
        slot_loss, intent_loss = 0, 0
        for pad_batch, ori_batch, real_batch_size in dataloader.yield_batches(batch_size, data_key=data_key):
            pad_batch = tuple(t.to(DEVICE) for t in pad_batch)
            word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch
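            # A model trained without dialogue context expects no context tensors.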
            if not config['model']['context']:
                context_seq_tensor, context_mask_tensor = None, None
    
            with torch.no_grad():
                slot_logits, intent_logits, batch_slot_loss, batch_intent_loss = model(
                    word_seq_tensor, word_mask_tensor, tag_seq_tensor, tag_mask_tensor,
                    intent_tensor, context_seq_tensor, context_mask_tensor)
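            # The returned losses are assumed to be batch means, so scale by the
            # real (unpadded) batch size to accumulate per-sample sums.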
            slot_loss += batch_slot_loss.item() * real_batch_size
            intent_loss += batch_intent_loss.item() * real_batch_size
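            # Decode each example's logits back into dialogue acts and pair them
            # with the golden annotations from the original (untokenized) batch.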
            for j in range(real_batch_size):
                predicts = recover_intent(dataloader, intent_logits[j], slot_logits[j], tag_mask_tensor[j],
                                          ori_batch[j][0], ori_batch[j][-4])
                labels = ori_batch[j][3]
    
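                # Dialogue acts arrive either as dicts (as produced by the
                # unified_datasets postprocess) or as lists of act tuples (the
                # legacy datasets); split slot vs. intent acts in both cases.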
                predict_golden['overall'].append({
                    'predict': predicts,
                    'golden': labels
                })
                if isinstance(predicts, dict):
                    predict_golden['slot'].append({
                        'predict': {k: v for k, v in predicts.items() if is_slot_da(k)},
                        'golden': {k: v for k, v in labels.items() if is_slot_da(k)}
                    })
                    predict_golden['intent'].append({
                        'predict': {k: v for k, v in predicts.items() if not is_slot_da(k)},
                        'golden': {k: v for k, v in labels.items() if not is_slot_da(k)}
                    })
                else:
                    assert isinstance(predicts, list)
                    predict_golden['slot'].append({
                        'predict': [x for x in predicts if is_slot_da(x)],
                        'golden': [x for x in labels if is_slot_da(x)]
                    })
                    predict_golden['intent'].append({
                        'predict': [x for x in predicts if not is_slot_da(x)],
                        'golden': [x for x in labels if not is_slot_da(x)]
                    })
            print('[%d|%d] samples' % (len(predict_golden['overall']), len(dataloader.data[data_key])))
    
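        # Average the accumulated losses over the whole test set.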
        total = len(dataloader.data[data_key])
        slot_loss /= total
        intent_loss /= total
        print('%d samples %s' % (total, data_key))
        print('\t slot loss:', slot_loss)
        print('\t intent loss:', intent_loss)
    
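        # Report precision/recall/F1 for intent acts, slot acts, and overall.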
        for x in ['intent', 'slot', 'overall']:
            precision, recall, F1 = calculateF1(predict_golden[x])
            print('-' * 20 + x + '-' * 20)
            print('\t Precision: %.2f' % (100 * precision))
            print('\t Recall: %.2f' % (100 * recall))
            print('\t F1: %.2f' % (100 * F1))
    
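        # Persist raw predictions with their golden labels for error analysis.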
        output_file = os.path.join(output_dir, 'output.json')
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(predict_golden['overall'], f, indent=2, ensure_ascii=False)