Commit 2d6882db authored by zqwerty

add back bertnlu test.py. rm test.py in .gitignore

parent 58f9a0d4
.gitignore
@@ -101,7 +101,6 @@
 convlab/dst/trade/multiwoz/model/
 convlab/dst/trade/multiwoz_config/
 convlab/deploy/bert_multiwoz_all.zip
 convlab/deploy/templates/dialog_eg.html
-test.py
 *convlab/policy/vector/action_dicts
 *.egg-info

convlab/nlu/jointBERT/test.py (new file)
import argparse
import os
import json
import random
import numpy as np
import torch
from convlab.nlu.jointBERT.dataloader import Dataloader
from convlab.nlu.jointBERT.jointBERT import JointBERT


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

parser = argparse.ArgumentParser(description="Test a model.")
parser.add_argument('--config_path',
                    help='path to config file')

if __name__ == '__main__':
    args = parser.parse_args()
    config = json.load(open(args.config_path))
    data_dir = config['data_dir']
    output_dir = config['output_dir']
    log_dir = config['log_dir']
    DEVICE = config['DEVICE']

    set_seed(config['seed'])

    # Pick the dataset-specific postprocessing utilities based on the data path.
    if 'unified_datasets' in data_dir:
        dataset_name = config['dataset_name']
        print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
        from convlab.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
    elif 'multiwoz' in data_dir:
        print('-' * 20 + 'dataset:multiwoz' + '-' * 20)
        from convlab.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
    elif 'camrest' in data_dir:
        print('-' * 20 + 'dataset:camrest' + '-' * 20)
        from convlab.nlu.jointBERT.camrest.postprocess import is_slot_da, calculateF1, recover_intent
    elif 'crosswoz' in data_dir:
        print('-' * 20 + 'dataset:crosswoz' + '-' * 20)
        from convlab.nlu.jointBERT.crosswoz.postprocess import is_slot_da, calculateF1, recover_intent

    intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json')))
    tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))
    dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab,
                            pretrained_weights=config['model']['pretrained_weights'])
    print('intent num:', len(intent_vocab))
    print('tag num:', len(tag_vocab))
    for data_key in ['val', 'test']:
        dataloader.load_data(json.load(open(os.path.join(data_dir, '{}_data.json'.format(data_key)))), data_key,
                             cut_sen_len=0, use_bert_tokenizer=config['use_bert_tokenizer'])
        print('{} set size: {}'.format(data_key, len(dataloader.data[data_key])))

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # Load the trained checkpoint from output_dir and switch to evaluation mode.
    model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim)
    model.load_state_dict(torch.load(os.path.join(output_dir, 'pytorch_model.bin'), DEVICE))
    model.to(DEVICE)
    model.eval()

    batch_size = config['model']['batch_size']

    # Evaluate on the test set, accumulating losses and predicted/golden dialogue acts.
    data_key = 'test'
    predict_golden = {'intent': [], 'slot': [], 'overall': []}
    slot_loss, intent_loss = 0, 0
    for pad_batch, ori_batch, real_batch_size in dataloader.yield_batches(batch_size, data_key=data_key):
        pad_batch = tuple(t.to(DEVICE) for t in pad_batch)
        word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch
        if not config['model']['context']:
            context_seq_tensor, context_mask_tensor = None, None

        with torch.no_grad():
            slot_logits, intent_logits, batch_slot_loss, batch_intent_loss = model.forward(word_seq_tensor,
                                                                                           word_mask_tensor,
                                                                                           tag_seq_tensor,
                                                                                           tag_mask_tensor,
                                                                                           intent_tensor,
                                                                                           context_seq_tensor,
                                                                                           context_mask_tensor)
        slot_loss += batch_slot_loss.item() * real_batch_size
        intent_loss += batch_intent_loss.item() * real_batch_size
        for j in range(real_batch_size):
            predicts = recover_intent(dataloader, intent_logits[j], slot_logits[j], tag_mask_tensor[j],
                                      ori_batch[j][0], ori_batch[j][-4])
            labels = ori_batch[j][3]

            predict_golden['overall'].append({
                'predict': predicts,
                'golden': labels
            })
            # Depending on the dataset, dialogue acts are either dicts (grouped by act
            # type) or flat lists; split them into slot and intent subsets accordingly.
            if isinstance(predicts, dict):
                predict_golden['slot'].append({
                    'predict': {k: v for k, v in predicts.items() if is_slot_da(k)},
                    'golden': {k: v for k, v in labels.items() if is_slot_da(k)}
                })
                predict_golden['intent'].append({
                    'predict': {k: v for k, v in predicts.items() if not is_slot_da(k)},
                    'golden': {k: v for k, v in labels.items() if not is_slot_da(k)}
                })
            else:
                assert isinstance(predicts, list)
                predict_golden['slot'].append({
                    'predict': [x for x in predicts if is_slot_da(x)],
                    'golden': [x for x in labels if is_slot_da(x)]
                })
                predict_golden['intent'].append({
                    'predict': [x for x in predicts if not is_slot_da(x)],
                    'golden': [x for x in labels if not is_slot_da(x)]
                })
        print('[%d|%d] samples' % (len(predict_golden['overall']), len(dataloader.data[data_key])))

    total = len(dataloader.data[data_key])
    slot_loss /= total
    intent_loss /= total
    print('%d samples %s' % (total, data_key))
    print('\t slot loss:', slot_loss)
    print('\t intent loss:', intent_loss)

    # Report precision/recall/F1 for intent-only, slot-only, and overall dialogue acts.
    for x in ['intent', 'slot', 'overall']:
        precision, recall, F1 = calculateF1(predict_golden[x])
        print('-' * 20 + x + '-' * 20)
        print('\t Precision: %.2f' % (100 * precision))
        print('\t Recall: %.2f' % (100 * recall))
        print('\t F1: %.2f' % (100 * F1))

    output_file = os.path.join(output_dir, 'output.json')
    json.dump(predict_golden['overall'], open(output_file, 'w', encoding='utf-8'),
              indent=2, ensure_ascii=False)
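
For reference, test.py is driven entirely by the JSON file passed via --config_path. Below is a minimal sketch of that config, written as the Python dict json.load would return and inferred from the config[...] lookups in the script; all paths and values are illustrative placeholders (not part of this commit), and JointBERT may read further keys from config['model'] that this file does not touch.

# Hypothetical example only: the shape of the JSON config consumed by test.py.
example_config = {
    "data_dir": "path/to/unified_datasets/data",  # substring selects the postprocess module above
    "output_dir": "path/to/output",               # must already contain pytorch_model.bin
    "log_dir": "path/to/log",
    "DEVICE": "cuda:0",
    "seed": 2019,
    "use_bert_tokenizer": True,
    "dataset_name": "multiwoz21",                 # read only on the unified_datasets branch
    "model": {
        "pretrained_weights": "bert-base-uncased",
        "batch_size": 32,
        "context": False                          # True feeds the dialogue-context tensors
    }
}
# Invocation (hypothetical path):
#   python test.py --config_path path/to/config.json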