diff --git a/.gitignore b/.gitignore
index 6012819da84535f1108916153f3116b3a609ba07..ea07c374a8d5c858256dcd34100b5eee1b52e8ab 100644
--- a/.gitignore
+++ b/.gitignore
@@ -101,7 +101,6 @@
 convlab/dst/trade/multiwoz/model/
 convlab/dst/trade/multiwoz_config/
 convlab/deploy/bert_multiwoz_all.zip
 convlab/deploy/templates/dialog_eg.html
-test.py
 *convlab/policy/vector/action_dicts
 *.egg-info
diff --git a/convlab/nlu/jointBERT/test.py b/convlab/nlu/jointBERT/test.py
new file mode 100755
index 0000000000000000000000000000000000000000..4979e3fa21e8b33a54950d2dcaa7042c8306f006
--- /dev/null
+++ b/convlab/nlu/jointBERT/test.py
@@ -0,0 +1,133 @@
+import argparse
+import os
+import json
+import random
+import numpy as np
+import torch
+from convlab.nlu.jointBERT.dataloader import Dataloader
+from convlab.nlu.jointBERT.jointBERT import JointBERT
+
+
+def set_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+
+parser = argparse.ArgumentParser(description="Test a model.")
+parser.add_argument('--config_path',
+                    help='path to config file')
+
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    config = json.load(open(args.config_path))
+    data_dir = config['data_dir']
+    output_dir = config['output_dir']
+    log_dir = config['log_dir']
+    DEVICE = config['DEVICE']
+
+    set_seed(config['seed'])
+
+    if 'unified_datasets' in data_dir:
+        dataset_name = config['dataset_name']
+        print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
+        from convlab.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
+    elif 'multiwoz' in data_dir:
+        print('-' * 20 + 'dataset:multiwoz' + '-' * 20)
+        from convlab.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
+    elif 'camrest' in data_dir:
+        print('-' * 20 + 'dataset:camrest' + '-' * 20)
+        from convlab.nlu.jointBERT.camrest.postprocess import is_slot_da, calculateF1, recover_intent
+    elif 'crosswoz' in data_dir:
+        print('-' * 20 + 'dataset:crosswoz' + '-' * 20)
+        from convlab.nlu.jointBERT.crosswoz.postprocess import is_slot_da, calculateF1, recover_intent
+
+    intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json')))
+    tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))
+    dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab,
+                            pretrained_weights=config['model']['pretrained_weights'])
+    print('intent num:', len(intent_vocab))
+    print('tag num:', len(tag_vocab))
+    for data_key in ['val', 'test']:
+        dataloader.load_data(json.load(open(os.path.join(data_dir, '{}_data.json'.format(data_key)))), data_key,
+                             cut_sen_len=0, use_bert_tokenizer=config['use_bert_tokenizer'])
+        print('{} set size: {}'.format(data_key, len(dataloader.data[data_key])))
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    if not os.path.exists(log_dir):
+        os.makedirs(log_dir)
+
+    model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim)
+    model.load_state_dict(torch.load(os.path.join(output_dir, 'pytorch_model.bin'), DEVICE))
+    model.to(DEVICE)
+    model.eval()
+
+    batch_size = config['model']['batch_size']
+
+    data_key = 'test'
+    predict_golden = {'intent': [], 'slot': [], 'overall': []}
+    slot_loss, intent_loss = 0, 0
+    for pad_batch, ori_batch, real_batch_size in dataloader.yield_batches(batch_size, data_key=data_key):
+        pad_batch = tuple(t.to(DEVICE) for t in pad_batch)
+        word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch
+        if not config['model']['context']:
+            context_seq_tensor, context_mask_tensor = None, None
+
+        with torch.no_grad():
+            slot_logits, intent_logits, batch_slot_loss, batch_intent_loss = model.forward(word_seq_tensor,
+                                                                                           word_mask_tensor,
+                                                                                           tag_seq_tensor,
+                                                                                           tag_mask_tensor,
+                                                                                           intent_tensor,
+                                                                                           context_seq_tensor,
+                                                                                           context_mask_tensor)
+        slot_loss += batch_slot_loss.item() * real_batch_size
+        intent_loss += batch_intent_loss.item() * real_batch_size
+        for j in range(real_batch_size):
+            predicts = recover_intent(dataloader, intent_logits[j], slot_logits[j], tag_mask_tensor[j],
+                                      ori_batch[j][0], ori_batch[j][-4])
+            labels = ori_batch[j][3]
+
+            predict_golden['overall'].append({
+                'predict': predicts,
+                'golden': labels
+            })
+            if isinstance(predicts, dict):
+                predict_golden['slot'].append({
+                    'predict': {k: v for k, v in predicts.items() if is_slot_da(k)},
+                    'golden': {k: v for k, v in labels.items() if is_slot_da(k)}
+                })
+                predict_golden['intent'].append({
+                    'predict': {k: v for k, v in predicts.items() if not is_slot_da(k)},
+                    'golden': {k: v for k, v in labels.items() if not is_slot_da(k)}
+                })
+            else:
+                assert isinstance(predicts, list)
+                predict_golden['slot'].append({
+                    'predict': [x for x in predicts if is_slot_da(x)],
+                    'golden': [x for x in labels if is_slot_da(x)]
+                })
+                predict_golden['intent'].append({
+                    'predict': [x for x in predicts if not is_slot_da(x)],
+                    'golden': [x for x in labels if not is_slot_da(x)]
+                })
+        print('[%d|%d] samples' % (len(predict_golden['overall']), len(dataloader.data[data_key])))
+
+    total = len(dataloader.data[data_key])
+    slot_loss /= total
+    intent_loss /= total
+    print('%d samples %s' % (total, data_key))
+    print('\t slot loss:', slot_loss)
+    print('\t intent loss:', intent_loss)
+
+    for x in ['intent', 'slot', 'overall']:
+        precision, recall, F1 = calculateF1(predict_golden[x])
+        print('-' * 20 + x + '-' * 20)
+        print('\t Precision: %.2f' % (100 * precision))
+        print('\t Recall: %.2f' % (100 * recall))
+        print('\t F1: %.2f' % (100 * F1))
+
+    output_file = os.path.join(output_dir, 'output.json')
+    json.dump(predict_golden['overall'], open(output_file, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
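A minimal usage sketch for the new script. The file paths and values below are illustrative assumptions, not files added by this diff; any JSON config supplying the keys the script reads (data_dir, output_dir, log_dir, DEVICE, seed, use_bert_tokenizer, dataset_name for unified datasets, and a model block, which will also need whatever hyperparameters JointBERT's constructor and Dataloader expect, presumably mirroring the training configs) should work. Note that data_dir must contain one of the substrings the dataset dispatch checks for (unified_datasets, multiwoz, camrest, crosswoz), since the if/elif chain has no else branch and recover_intent, is_slot_da, and calculateF1 are only imported inside it; otherwise the script fails later with a NameError.

    # hypothetical config, e.g. saved as my_multiwoz_config.json
    {
        "data_dir": "convlab/nlu/jointBERT/multiwoz/data/all_data",
        "output_dir": "convlab/nlu/jointBERT/multiwoz/output/all",
        "log_dir": "convlab/nlu/jointBERT/multiwoz/log/all",
        "DEVICE": "cuda:0",
        "seed": 2019,
        "use_bert_tokenizer": true,
        "model": {
            "pretrained_weights": "bert-base-uncased",
            "batch_size": 128,
            "context": false
        }
    }

    # invocation; expects <output_dir>/pytorch_model.bin from a prior training run
    python convlab/nlu/jointBERT/test.py --config_path my_multiwoz_config.json

The script evaluates on the test split only (val is loaded but unused here), prints per-batch progress, average slot/intent losses, and precision/recall/F1 for intent, slot, and overall dialog acts, and writes the predictions to <output_dir>/output.json.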