Commit c8c19090 authored by Carel van Niekerk

Refactor and rename calibration to evaluate

parent 06491f2b
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,33 +16,22 @@
"""Run SetSUMBT Calibration""" """Run SetSUMBT Calibration"""
import logging import logging
import random
import os import os
from shutil import copy2 as copy
import torch import torch
from transformers import (BertModel, BertConfig, BertTokenizer, from transformers import (BertModel, BertConfig, BertTokenizer,
RobertaModel, RobertaConfig, RobertaTokenizer, RobertaModel, RobertaConfig, RobertaTokenizer)
AdamW, get_linear_schedule_with_warmup)
from tqdm import tqdm, trange from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
from tensorboardX import SummaryWriter from convlab.dst.setsumbt.dataset import unified_format
from torch.distributions import Categorical from convlab.dst.setsumbt.dataset import ontology as embeddings
from convlab.dst.setsumbt.utils import get_args, update_args
from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT from convlab.dst.setsumbt.modeling import evaluation_utils
from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT from convlab.dst.setsumbt.loss.uncertainty_measures import ece, jg_ece, l2_acc
from convlab.dst.setsumbt.multiwoz import multiwoz21 from convlab.dst.setsumbt.modeling import training
from convlab.dst.setsumbt.multiwoz import ontology as embeddings
from convlab.dst.setsumbt.utils import get_args, upload_local_directory_to_gcs, update_args
from convlab.dst.setsumbt.modeling import calibration_utils
from convlab.dst.setsumbt.modeling import ensemble_utils
from convlab.dst.setsumbt.loss.ece import ece, jg_ece, l2_acc
# Datasets
DATASETS = {
'multiwoz21': multiwoz21
}
# Available model
MODELS = { MODELS = {
'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer), 'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer) 'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
@@ -54,12 +43,6 @@ def main(args=None, config=None):
     if args is None:
         args, config = get_args(MODELS)

-    # Select Dataset object
-    if args.dataset in DATASETS:
-        Dataset = DATASETS[args.dataset]
-    else:
-        raise NameError('NotImplemented')
-
     if args.model_type in MODELS:
         SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
     else:
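Note: each MODELS entry bundles the SetSUMBT wrapper with its underlying encoder, config, and tokenizer classes, which keeps main() model-agnostic. A short sketch of how the unpacked classes fit together, with a hypothetical checkpoint path as a placeholder:

SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS['roberta']
config = ConfigClass.from_pretrained('path/to/checkpoint')  # placeholder path
tokenizer = Tokenizer(config.candidate_embedding_model_name)
model = SetSumbtModel.from_pretrained('path/to/checkpoint', config=config)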
@@ -67,69 +50,35 @@ def main(args=None, config=None):
     # Set up output directory
     OUTPUT_DIR = args.output_dir
-    if not os.path.exists(OUTPUT_DIR):
-        os.mkdir(OUTPUT_DIR)
     args.output_dir = OUTPUT_DIR
     if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
         os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))

-    paths = os.listdir(args.output_dir) if os.path.exists(
-        args.output_dir) else []
+    # Set pretrained model path to the trained checkpoint
+    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
     if 'pytorch_model.bin' in paths and 'config.json' in paths:
         args.model_name_or_path = args.output_dir
         config = ConfigClass.from_pretrained(args.model_name_or_path)
     else:
-        paths = os.listdir(args.output_dir) if os.path.exists(
-            args.output_dir) else []
-        paths = [os.path.join(args.output_dir, p)
-                 for p in paths if 'checkpoint-' in p]
+        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
         if paths:
             paths = paths[0]
             args.model_name_or_path = paths
             config = ConfigClass.from_pretrained(args.model_name_or_path)

-    if args.ensemble_size > 0:
-        paths = os.listdir(args.output_dir) if os.path.exists(
-            args.output_dir) else []
-        paths = [os.path.join(args.output_dir, p)
-                 for p in paths if 'ensemble_' in p]
-        if paths:
-            args.model_name_or_path = args.output_dir
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
-
     args = update_args(args, config)

-    # Set up data directory
-    DATA_DIR = args.data_dir
-    Dataset.set_datadir(DATA_DIR)
-    embeddings.set_datadir(DATA_DIR)
-
-    if args.shrink_active_domains and args.dataset == 'multiwoz21':
-        Dataset.set_active_domains(
-            ['attraction', 'hotel', 'restaurant', 'taxi', 'train'])
-
-    # Download and preprocess
-    Dataset.create_examples(
-        args.max_turn_len, args.predict_intents, args.force_processing)
-
     # Create logger
     global logger
     logger = logging.getLogger(__name__)
     logger.setLevel(logging.INFO)

-    formatter = logging.Formatter(
-        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

-    if 'stream' not in args.logging_path:
-        fh = logging.FileHandler(args.logging_path)
-        fh.setLevel(logging.INFO)
-        fh.setFormatter(formatter)
-        logger.addHandler(fh)
-    else:
-        ch = logging.StreamHandler()
-        ch.setLevel(level=logging.INFO)
-        ch.setFormatter(formatter)
-        logger.addHandler(ch)
+    fh = logging.FileHandler(args.logging_path)
+    fh.setLevel(logging.INFO)
+    fh.setFormatter(formatter)
+    logger.addHandler(fh)

     # Get device
     if torch.cuda.is_available() and args.n_gpu > 0:
@@ -142,18 +91,12 @@ def main(args=None, config=None):
         args.fp16 = False

     # Set up model training/evaluation
-    calibration.set_logger(logger, None)
-    calibration.set_seed(args)
-
-    if args.ensemble_size > 0:
-        ensemble.set_logger(logger, tb_writer)
-        ensemble_utils.set_seed(args)
+    evaluation_utils.set_logger(logger, None)
+    evaluation_utils.set_seed(args)

     # Perform tasks
     if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')):
-        pred = torch.load(os.path.join(
-            OUTPUT_DIR, 'predictions', 'test.predictions'))
+        pred = torch.load(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
         labels = pred['labels']
         belief_states = pred['belief_states']
         if 'request_labels' in pred:
@@ -166,100 +109,41 @@ def main(args=None, config=None):
         else:
             request_belief = None
         del pred
-    elif args.ensemble_size > 0:
-        # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-        else:
-            # Create Tokenizer and embedding model for Data Loaders and ontology
-            encoder = CandidateEncoderModel.from_pretrained(
-                config.candidate_embedding_model_name)
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            embeddings.get_slot_candidate_embeddings(
-                'test', args, tokenizer, encoder)
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-
-        exists = False
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size == args.test_batch_size:
-                exists = True
-        if not exists:
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                     config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
-        config, models = ensemble.get_models(
-            args.model_name_or_path, device, ConfigClass, SetSumbtModel)
-
-        belief_states, labels = ensemble_utils.get_predictions(
-            args, models, device, test_dataloader, test_slots)
-        torch.save({'belief_states': belief_states, 'labels': labels},
-                   os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
     else:
         # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-        else:
-            # Create Tokenizer and embedding model for Data Loaders and ontology
-            encoder = CandidateEncoderModel.from_pretrained(
-                config.candidate_embedding_model_name)
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            embeddings.get_slot_candidate_embeddings(
-                'test', args, tokenizer, encoder)
-            test_slots = torch.load(os.path.join(
-                OUTPUT_DIR, 'database', 'test.db'))
-
-        exists = False
         if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size == args.test_batch_size:
-                exists = True
-        if not exists:
+            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+            if test_dataloader.batch_size != args.test_batch_size:
+                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
+        else:
             tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                     config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(
-                OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+            test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
+                                                            args.test_batch_size, tokenizer, args.max_dialogue_len,
+                                                            config.max_turn_len)
+            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
+
+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
+            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
+        else:
+            encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
+            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
+                                                                  'test', args, tokenizer, encoder)
         # Initialise Model
-        model = SetSumbtModel.from_pretrained(
-            args.model_name_or_path, config=config)
+        model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
         model = model.to(device)

-        # Get slot and value embeddings
-        slots = {slot: test_slots[slot] for slot in test_slots}
-        values = {slot: test_slots[slot][1] for slot in test_slots}
-
-        # Load model ontology
-        model.add_slot_candidates(slots)
-        for slot in model.informable_slot_ids:
-            model.add_value_candidates(slot, values[slot], replace=True)
+        training.set_ontology_embeddings(model, test_slots)

-        belief_states = calibration.get_predictions(
-            args, model, device, test_dataloader)
+        belief_states = evaluation_utils.get_predictions(args, model, device, test_dataloader)
         belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = belief_states
         out = {'belief_states': belief_states, 'labels': labels,
                'request_belief': request_belief, 'request_labels': request_labels,
                'domain_belief': domain_belief, 'domain_labels': domain_labels,
                'greeting_belief': greeting_belief, 'greeting_labels': greeting_labels}
-        torch.save(out, os.path.join(
-            OUTPUT_DIR, 'predictions', 'test.predictions'))
-
-    # err = [ece(belief_states[slot].reshape(-1, belief_states[slot].size(-1)), labels[slot].reshape(-1), 10)
-    #        for slot in belief_states]
-    # err = max(err)
-    # logger.info('ECE: %f' % err)
+        torch.save(out, os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
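Note: the refactored else-branch follows a cache-or-rebuild pattern, reusing the saved test dataloader when its batch size still matches (otherwise resizing it via unified_format.change_batch_size) and caching rebuilt artefacts for later runs. A condensed sketch of that pattern; load_or_build is a hypothetical helper, not part of convlab:

import os
import torch

def load_or_build(path, build_fn):
    # Load a cached artefact if present; otherwise build it and cache it for the next run.
    if os.path.exists(path):
        return torch.load(path)
    artefact = build_fn()
    torch.save(artefact, path)
    return artefact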
     # Calculate calibration metrics
     jg = jg_ece(belief_states, labels, 10)
     logger.info('Joint Goal ECE: %f' % jg)
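Note: expected calibration error (ECE), computed here with 10 bins, partitions turns by predicted confidence and takes a frequency-weighted average of the gap between each bin's confidence and its accuracy. A minimal sketch of the idea; the repository's own ece and jg_ece in convlab.dst.setsumbt.loss.uncertainty_measures may differ in binning and weighting details:

import torch

def ece_sketch(probs, labels, n_bins=10):
    # probs: [n_turns, n_values] predicted distribution; labels: [n_turns] gold indices.
    conf, pred = probs.max(dim=-1)
    correct = (pred == labels).float()
    bins = torch.linspace(0.0, 1.0, n_bins + 1)
    err = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (conf > lo) & (conf <= hi)
        if mask.any():
            gap = (correct[mask].mean() - conf[mask].mean()).abs().item()
            err += gap * mask.float().mean().item()  # weight each bin by its frequency
    return err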
@@ -298,11 +182,11 @@ def main(args=None, config=None):
     logger.info('Slot presence Binary ECE: %f' % jg)

     jg_acc = 0.0
-    padding = torch.cat([item.unsqueeze(-1)
-                         for _, item in labels.items()], -1).sum(-1) * -1.0
+    padding = torch.cat([item.unsqueeze(-1) for _, item in labels.items()], -1).sum(-1) * -1.0
     padding = (padding == len(labels))
     padding = padding.reshape(-1)
     for slot in belief_states:
+        args.accuracy_topn = 1
         topn = args.accuracy_topn
         p_ = belief_states[slot]
         gold = labels[slot]
@@ -317,8 +201,7 @@ def main(args=None, config=None):
             labs = labs[:, :topn]
         else:
             labs = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1)
-        acc = [lab in s for lab, s, pad in zip(
-            gold.reshape(-1), labs, padding) if not pad]
+        acc = [lab in s for lab, s, pad in zip(gold.reshape(-1), labs, padding) if not pad]
         acc = torch.tensor(acc).float()
         jg_acc += acc
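Note: the accuracy test above counts a turn as correct when the gold value index appears among the topn highest-probability candidates. A tiny illustration of the membership check, with hypothetical tensors:

import torch

p_ = torch.tensor([[0.1, 0.6, 0.3],   # gold 1 is in the top 2
                   [0.5, 0.4, 0.1]])  # gold 2 is not in the top 2
gold = torch.tensor([1, 2])
labs = p_.topk(2, dim=-1).indices     # top-2 value indices per turn
acc = torch.tensor([g in s for g, s in zip(gold, labs)]).float()
# acc -> tensor([1., 0.])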
@@ -337,6 +220,34 @@ def main(args=None, config=None):
     l2 = l2_acc(belief_states, labels, remove_belief=True)
     logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}')

+    padding = torch.cat([item.unsqueeze(-1) for _, item in labels.items()], -1).sum(-1) * -1.0
+    padding = (padding == len(labels))
+    padding = padding.reshape(-1)
+
+    tp, fp, fn, tn, n = 0.0, 0.0, 0.0, 0.0, 0.0
+    for slot in belief_states:
+        p_ = belief_states[slot]
+        gold = labels[slot].reshape(-1)
+        p_ = p_.reshape(-1, p_.size(-1))
+
+        p_ = p_[~padding].argmax(-1)
+        gold = gold[~padding]
+
+        tp += (p_ == gold)[gold != 0].int().sum().item()
+        fp += (p_ != 0)[gold == 0].int().sum().item()
+        fp += (p_ != gold)[gold != 0].int().sum().item()
+        fp -= (p_ == 0)[gold != 0].int().sum().item()
+        fn += (p_ == 0)[gold != 0].int().sum().item()
+        tn += (p_ == 0)[gold == 0].int().sum().item()
+        n += p_.size(0)
+
+    acc = (tp + tn) / n
+    prec = tp / (tp + fp)
+    rec = tp / (tp + fn)
+    f1 = 2 * (prec * rec) / (prec + rec)
+
+    logger.info(f"Slot Accuracy: {acc}, Slot F1: {f1}, Slot Precision: {prec}, Slot Recall: {rec}")
     for slot in belief_states:
         p = belief_states[slot]
         p = p.reshape(-1, p.size(-1))
@@ -347,7 +258,6 @@ def main(args=None, config=None):
         l = labels[slot].reshape(-1)
         l[l > 0] = 1
         labels[slot] = l
-
     f1 = 0.0
     for slot in belief_states:
         prd = belief_states[slot].argmax(-1)
...
 # -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Discriminative models calibration"""
+"""Evaluation Utilities"""

 import random
@@ -119,7 +119,9 @@ def get_predictions(args, model, device, dataloader):
     else:
         request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = [None]*6

-    return belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels
+    out = (belief_states, labels, request_belief, request_labels)
+    out += (domain_belief, domain_labels, greeting_belief, greeting_labels)
+    return out

 def normalise(p):
...
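Note: get_predictions still yields the same eight fields; building the tuple in two steps only keeps the source lines short. Callers unpack it positionally, as the evaluate script above does:

(belief_states, labels, request_belief, request_labels,
 domain_belief, domain_labels, greeting_belief, greeting_labels) = get_predictions(args, model, device, test_dataloader)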