Skip to content
Snippets Groups Projects
Commit 8e1da6bc authored by Carel van Niekerk's avatar Carel van Niekerk
Browse files

Initial commit

parents
No related branches found
No related tags found
No related merge requests found
Showing
with 1591 additions and 0 deletions
# Neural Belief Tracking
Implementation of different deep learning uncertainty methods for improved dialogue ***belief*** tracking.
## NBT Model
A transformer-based Slot-Utterance Matching Belief Tracker (SUMBT). It uses a fine-tuned BERT language model, a slot-utterance matching attention mechanism and an RNN latent state tracker to generate a latent belief state. The model logits are then generated by comparing embedding distances.
## Uncertainty
### Loss
* Standard Cross Entropy Loss
* Label Smoothing KL divergence Loss (improved uncertainty)
* Belief Matching (Dirichlet activation function with ELBO Loss)
### Strategies
* Temperature Scaling
* Dropout Ensembling
* Model Ensembling
transformers==2.8.0
torch==1.4.0
urllib3
tqdm
tensorboardX
numpy
google-cloud-storage
\ No newline at end of file
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calibration Plot plotting script"""
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import torch
from matplotlib import pyplot as plt
def main():
    """Plot joint goal calibration curves for the models found in ``--data_dir``.

    Every sub-directory of ``--data_dir`` that contains a ``test.belief`` file
    is treated as one model. For each of them the (confidence, accuracy)
    curve returned by ``get_calibration`` is drawn, together with the
    diagonal of a perfectly calibrated model, and the figure is written to
    ``--output``.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--data_dir', help='Location of the belief states', required=True)
    parser.add_argument('--output', help='Output image path', default='calibration_plot.png')
    parser.add_argument('--n_bins', help='Number of bins', default=10, type=int)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Sort for a reproducible line style assignment and skip entries that do
    # not hold a belief file (stray files, unfinished runs).
    models = [os.path.join(args.data_dir, model, 'test.belief')
              for model in sorted(os.listdir(args.data_dir))]
    models = [model for model in models if os.path.isfile(model)]

    plt.figure(figsize=(14, 8))
    font = 20
    plt.tick_params(labelsize=font - 2)

    linestyle = ['-', ':', (0, (3, 5, 1, 5)), '-.', (0, (5, 10))]
    for i, model in enumerate(models):
        conf, acc = get_calibration(model, device, n_bins=args.n_bins)
        # The model name is the directory that holds the belief file
        name = os.path.basename(os.path.dirname(model)).strip()
        print(name, conf, acc)
        # Cycle through the styles so more than five models do not crash
        plt.plot(conf.tolist(), acc.tolist(), label=name,
                 linestyle=linestyle[i % len(linestyle)], linewidth=3)

    # Reference line of a perfectly calibrated model
    plt.plot([0, 1], [0, 1], linestyle='--', color='black', linewidth=3)
    plt.xlabel('Confidence', fontsize=font)
    plt.ylabel('Joint Goal Accuracy', fontsize=font)
    plt.legend(fontsize=font)
    plt.savefig(args.output)
def get_calibration(path, device, n_bins=10, temperature=1.00):
    """Compute the joint goal calibration curve of a stored belief file.

    Parameters:
        path (str): File written with ``torch.save`` holding a dict with the
            keys 'logits' and 'labels', each a dict of per-slot tensors.
            The 'logits' entries are binned on [0, 1], so they are presumably
            probabilities -- confirm against the code that writes the file.
        device (torch.device): Device the tensors are mapped to on load.
        n_bins (int): Number of equally wide confidence bins on [0, 1].
        temperature (float): Unused; kept for backward compatibility.
    Returns:
        conf (Tensor): Mean confidence of every non-empty bin, prefixed by 0.
        acc (Tensor): Joint goal accuracy of the same bins, prefixed by 0.
    """
    belief = torch.load(path, map_location=device)
    y_true = belief['labels']
    probs = belief['logits']

    # Flatten (dialogue, turn) into one axis: one row per turn
    flat = {slot: probs[slot].reshape(-1, probs[slot].size(-1)) for slot in probs}

    # A turn is jointly correct when the prediction of every slot is correct
    n_correct = sum((flat[slot].argmax(-1) == y_true[slot].reshape(-1)).int() for slot in flat)
    goal_acc = (n_correct == len(y_true)).int()

    # Joint confidence of a turn is its least confident slot prediction
    scores = torch.stack([flat[slot].max(-1)[0] for slot in flat], 0).min(0)[0]

    # Padded turns carry negative labels (checked on the first slot, as before)
    first_slot = next(iter(y_true))
    valid = y_true[first_slot].reshape(-1) >= 0

    step = 1.0 / float(n_bins)
    bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)

    conf, acc = [0.0], [0.0]
    for b in range(n_bins):
        lower, upper = bin_ranges[b], bin_ranges[b + 1]
        # Only the first bin includes its lower edge
        above_lower = (scores >= lower) if b == 0 else (scores > lower)
        in_bin = above_lower & (scores <= upper)
        in_bin_valid = in_bin & valid
        # Skip bins without any real (non padded) turn. The original test
        # `size(0) >= 0` was always true and produced a NaN accuracy here.
        if not in_bin_valid.any():
            continue
        conf.append(scores[in_bin].mean().item())
        acc.append(goal_acc[in_bin_valid].float().mean().item())

    return torch.tensor(conf), torch.tensor(acc)
# Run the plotting script only when the file is executed directly
if __name__ == '__main__':
main()
it's it is
don't do not
doesn't does not
didn't did not
you'd you would
you're you are
you'll you will
i'm i am
they're they are
that's that is
what's what is
couldn't could not
i've i have
we've we have
can't cannot
i'd i would
i'd i would
aren't are not
isn't is not
wasn't was not
weren't were not
won't will not
there's there is
there're there are
. . .
restaurants restaurant -s
hotels hotel -s
laptops laptop -s
cheaper cheap -er
dinners dinner -s
lunches lunch -s
breakfasts breakfast -s
expensively expensive -ly
moderately moderate -ly
cheaply cheap -ly
prices price -s
places place -s
venues venue -s
ranges range -s
meals meal -s
locations location -s
areas area -s
policies policy -s
children child -s
kids kid -s
kidfriendly kid friendly
cards card -s
upmarket expensive
inpricey cheap
inches inch -s
uses use -s
dimensions dimension -s
driverange drive range
includes include -s
computers computer -s
machines machine -s
families family -s
ratings rating -s
constraints constraint -s
pricerange price range
batteryrating battery rating
requirements requirement -s
drives drive -s
specifications specification -s
weightrange weight range
harddrive hard drive
batterylife battery life
businesses business -s
hours hour -s
one 1
two 2
three 3
four 4
five 5
six 6
seven 7
eight 8
nine 9
ten 10
eleven 11
twelve 12
anywhere any where
good bye goodbye
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MultiWOZ 2.1 Dialogue Dataset"""
import os
import json
import requests
import zipfile
import io
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from tqdm import tqdm
from utils import clean_text, ACTIVE_DOMAINS, IGNORE_GOALS, get_domains, fix_delexicalisation, extract_dialogue, get_acts
# Set up global data_directory
def set_datadir(dir):
    """Store ``dir`` in the module level ``DATA_DIR`` variable.

    Parameters:
        dir (str): Directory used by the functions of this module to read and
            write the dataset files.
    """
    globals()['DATA_DIR'] = dir
# Default download location of the MultiWOZ2.1 archive
URL = 'https://www.repository.cam.ac.uk/bitstream/handle/1810/294507/MULTIWOZ2.1.zip?sequence=1&isAllowed=y'


def set_url(url):
    """Replace the module level download link ``URL``.

    Parameters:
        url (str): Location the data archive is downloaded from.
    """
    globals()['URL'] = url
# Create Dialogue examples from the dataset
def create_examples(max_utt_len, force_processing=False):
    """Download, preprocess and split the dataset inside ``DATA_DIR``.

    The raw archive is only downloaded when 'data_raw.json' is missing, and
    the dialogues are only processed when 'data_train.json' is missing or
    ``force_processing`` is set.

    Parameters:
        max_utt_len (int): Passed on to ``_split_data``.
        force_processing (bool): Re-create the train/dev/test files even when
            they already exist. (default=False)
    """
    # makedirs also creates missing parent directories
    os.makedirs(DATA_DIR, exist_ok=True)

    # Load or download Raw Data
    raw_path = os.path.join(DATA_DIR, 'data_raw.json')
    if not os.path.exists(raw_path):
        # Download data archive and extract
        archive = _download()
        data = _extract(archive)
        with open(raw_path, 'w') as writer:
            json.dump(data, writer, indent=2)
        del archive
    else:
        with open(raw_path, 'r') as reader:
            data = json.load(reader)

    if force_processing or not os.path.exists(os.path.join(DATA_DIR, 'data_train.json')):
        # Preprocess all dialogues
        data_processed = _process(data['data'], data['dialogue_acts'])

        # Format data and split train, test and development sets
        train, dev, test = _split_data(data_processed, data['testListFile'],
                                       data['valListFile'], max_utt_len)

        # Write data, closing every file so the content is flushed to disk
        for set_type, dials in (('train', train), ('test', test), ('dev', dev)):
            with open(os.path.join(DATA_DIR, 'data_%s.json' % set_type), 'w') as writer:
                json.dump(dials, writer, indent=2)

    # Extract slots and slot value candidates from the dataset.
    # NOTE(review): the original nesting of this loop was lost; it is run on
    # every call here, which only rewrites the same ontology files when the
    # data files did not change.
    for set_type in ['train', 'dev', 'test']:
        _get_ontology(set_type)
# Fixed candidate lists that replace the observed values of matching slots
# NOTE(review): 'saterday' looks like a misspelling of 'saturday'. Confirm
# which spelling the processed dialogue states use before changing it.
PRICERANGE = ['do not care', 'cheap', 'moderate', 'expensive']
BOOLEAN = ['do not care', 'yes', 'no']
DAYS = ['do not care', 'monday', 'tuesday', 'wednesday', 'thursday',
        'friday', 'saterday', 'sunday']
QUANTITIES = ['do not care'] + [str(count) for count in range(1, 10)] + ['10 or more']
# Every time of day in five minute steps, formatted as HH:MM
TIME = ['do not care'] + ['%02i:%02i' % (hour, minute)
                          for minute in range(0, 60, 5)
                          for hour in range(24)]
def _get_ontology(set_type):
    """Collect the slots and their value candidates and write the ontology.

    The train ontology is built from the training dialogues only; the dev
    and test ontologies are built from all three sets. The result is written
    to 'ontology_<set_type>.json' inside ``DATA_DIR``.

    Parameters:
        set_type (str): One of 'train', 'dev' or 'test'.
    """
    datasets = ['train']
    if set_type in ['test', 'dev']:
        datasets.append('dev')
        datasets.append('test')

    # Load examples
    data = []
    for dataset in datasets:
        with open(os.path.join(DATA_DIR, 'data_%s.json' % dataset), 'r') as reader:
            data += json.load(reader)

    # Collect the distinct values observed for every slot
    ontology = dict()
    for dial in data:
        for turn in dial['dialogue']:
            for slot, value in turn['dialogue_state']:
                ontology.setdefault(slot, set()).add(value)

    # Slots with a closed value set get the fixed candidate lists. The checks
    # are applied in order, so a later match overrides an earlier one.
    for slot in ontology:
        if 'price' in slot:
            ontology[slot] = set(PRICERANGE)
        if 'parking' in slot or 'internet' in slot:
            ontology[slot] = set(BOOLEAN)
        if 'day' in slot:
            ontology[slot] = set(DAYS)
        if 'people' in slot or 'duration' in slot or 'stay' in slot:
            ontology[slot] = set(QUANTITIES)
        if 'time' in slot or 'leave' in slot or 'arrive' in slot:
            ontology[slot] = set(TIME)
        if 'stars' in slot:
            # Adds '0' to '4'; any other rating is only present if observed
            ontology[slot].update(str(i) for i in range(5))

    # Sort slot values and add none and dontcare values
    special = ['none', 'do not care']
    ontology = {slot: special + sorted(val for val in values if val not in special)
                for slot, values in ontology.items()}

    with open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'w') as writer:
        json.dump(ontology, writer, indent=2)
# Convert dialogue examples to model input features and labels
def convert_examples_to_features(set_type, tokenizer, max_turns=12, max_seq_len=64):
    """Create the model input tensors and slot labels of one data set.

    Parameters:
        set_type (str): Selects 'data_<set_type>.json' and
            'ontology_<set_type>.json' inside ``DATA_DIR``.
        tokenizer: Tokenizer providing ``encode_plus``.
        max_turns (int): Dialogues are cut or padded to this many turns.
        max_seq_len (int): Maximum number of tokens per turn.
    Returns:
        features (dict): 'input_ids', 'token_type_ids' and 'attention_mask'
            (the latter two are None when the tokenizer does not return
            them) plus one 'labels-<slot>' tensor per ontology slot. Padded
            turns are labelled -1.
    """
    features = dict()

    # Load examples
    with open(os.path.join(DATA_DIR, 'data_%s.json' % set_type), 'r') as reader:
        data = json.load(reader)

    # Get encoder input for system, user utterance pairs
    input_feats = []
    for dial in data:
        dial_feats = []
        for turn in dial['dialogue']:
            # The user utterance comes first; the system utterance is only
            # added as second segment when there is one.
            texts = [turn['transcript']]
            if len(turn['system_transcript']) != 0:
                texts.append(turn['system_transcript'])
            dial_feats.append(tokenizer.encode_plus(*texts, add_special_tokens=True,
                                                    max_length=max_seq_len, pad_to_max_length='right',
                                                    truncation_strategy='longest_first'))
            if len(dial_feats) >= max_turns:
                break
        input_feats.append(dial_feats)

    def _pad_turns(key):
        # Perform turn level padding with all zero turns
        return [[turn[key] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
                for dial in input_feats]

    input_ids = _pad_turns('input_ids')
    token_type_ids = _pad_turns('token_type_ids') if 'token_type_ids' in input_feats[0][0] else None
    attention_mask = _pad_turns('attention_mask') if 'attention_mask' in input_feats[0][0] else None
    del input_feats

    # Create torch data tensors
    features['input_ids'] = torch.tensor(input_ids)
    features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
    features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
    del input_ids, token_type_ids, attention_mask

    # Load ontology
    with open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r') as reader:
        ontology = json.load(reader)

    # Create slot labels
    for slot in ontology:
        labels = []
        for dial in data:
            labs = []
            for turn in dial['dialogue']:
                # First value stated for the slot, 'none' when it is inactive
                value = next((v for s, v in turn['dialogue_state'] if s == slot), 'none')
                labs.append(ontology[slot].index(value))
                if len(labs) >= max_turns:
                    break
            labs = labs + [-1] * (max_turns - len(labs))
            labels.append(labs)
        features['labels-' + slot] = torch.tensor(labels)

    return features
# MultiWOZ2.1 Dataset object
class MultiWoz21(Dataset):
    """Torch dataset over the features of one MultiWOZ 2.1 data set.

    ``self.features`` maps a feature name to a tensor whose first dimension
    indexes the dialogues, or to None for features the tokenizer did not
    provide. None entries are passed through unchanged by every method.
    """

    def __init__(self, set_type, tokenizer, max_turns=12, max_seq_len=64):
        """Build the features with ``convert_examples_to_features``."""
        self.features = convert_examples_to_features(set_type, tokenizer, max_turns, max_seq_len)

    def __getitem__(self, index):
        """Return the features of the dialogue at ``index`` as a dict."""
        return {label: feat[index] if feat is not None else None
                for label, feat in self.features.items()}

    def __len__(self):
        """Return the number of dialogues."""
        return self.features['input_ids'].size(0)

    def resample(self, size=None):
        """Replace the data by ``size`` dialogues drawn with replacement.

        Parameters:
            size (int): Number of dialogues to draw. (default=current size)
        """
        n_dialogues = self.__len__()
        if not size:
            size = n_dialogues

        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
        # Missing (None) features used to raise a TypeError here
        self.features = {label: feat[dialogues] if feat is not None else None
                         for label, feat in self.features.items()}

    def to(self, device):
        """Move all feature tensors to ``device``."""
        self.device = device
        # Missing (None) features used to raise an AttributeError here
        self.features = {label: feat.to(device) if feat is not None else None
                         for label, feat in self.features.items()}
# Create a torch dataloader for one data set
def get_dataloader(set_type, batch_size, tokenizer, max_turns=12, max_seq_len=64, device=None, resampled_size=None):
    """Build a randomly sampled DataLoader over a ``MultiWoz21`` data set.

    Parameters:
        set_type (str): Data set to load.
        batch_size (int): Number of dialogues per batch.
        tokenizer: Tokenizer handed to the dataset.
        max_turns (int): Maximum number of turns per dialogue.
        max_seq_len (int): Maximum number of tokens per turn.
        device: When given, the dataset is moved to this device first.
        resampled_size (int): When set, the dataset is resampled to this
            number of dialogues.
    Returns:
        loader (DataLoader): Loader drawing the dialogues in random order.
    """
    dataset = MultiWoz21(set_type, tokenizer, max_turns, max_seq_len)

    if device is not None:
        dataset.to(device)
    if resampled_size:
        dataset.resample(resampled_size)

    return DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)
def _download(chunk_size=1048576):
    """Download data archive.

    Parameters:
        chunk_size (int): Download chunk size. (default=1048576)
    Returns:
        archive: ZipFile archive object.
    Raises:
        requests.HTTPError: When the server answers with an error status.
    """
    buffer = io.BytesIO()
    with requests.get(URL, stream=True) as req:
        # Fail here instead of handing an error page to ZipFile
        req.raise_for_status()
        # Write into the buffer directly; `bytes +=` copies everything
        # downloaded so far on every chunk.
        for chunk in tqdm(req.iter_content(chunk_size=chunk_size), desc='Download Chunk'):
            if chunk:
                buffer.write(chunk)

    # Convert the downloaded bytes into a zipfile object
    buffer.seek(0)
    return zipfile.ZipFile(buffer)
def _extract(archive):
    """Extract the json dictionaries from the archive.

    Members that are not valid JSON are returned as a list of their lines.

    Parameters:
        archive: ZipFile archive object.
    Returns:
        data: Data dictionary keyed by the member file name without its
            directory and extension.
    """
    members = [info for info in archive.filelist
               if '.json' in info.filename and 'MACOSX' not in info.filename]

    data = {}
    for info in tqdm(members, desc='File'):
        # ZipFile.read opens and closes the member itself
        raw = archive.read(info)

        # Get data objects from the files
        try:
            content = json.loads(raw)
        except json.decoder.JSONDecodeError:
            content = raw.decode().split('\n')

        key = info.filename.split('/')[-1].split('.')[0]
        data[key] = content

    return data
# Process files
def _process(dialogue_data, acts_data):
    """Clean the turn texts and attach domains and dialogue acts.

    Parameters:
        dialogue_data (dict): Raw dialogues keyed by dialogue name. Every
            dialogue holds a 'log' list of turns with 'text' and 'metadata'.
        acts_data: Dialogue act annotations, handed to ``get_acts`` and
            ``fix_delexicalisation``.
    Returns:
        processed (dict): Processed dialogues keyed by dialogue name.
    """
    print('Processing Dialogues')
    processed = {}
    for dial_name in tqdm(dialogue_data):
        dialogue = dialogue_data[dial_name]
        act_id, prev_dom = 1, ''
        for turn_id, turn in enumerate(dialogue['log']):
            # Always index through `dialogue`: it is rebound at the end of
            # every iteration and may no longer be the object `turn` is from.
            dialogue['log'][turn_id]['text'] = clean_text(turn['text'])

            # Only turns with non-empty metadata receive a domain and acts
            if len(turn['metadata']) != 0:
                crnt_dom = get_domains(dialogue['log'], turn_id, prev_dom)
                # NOTE(review): last_dom is never read and prev_dom is never
                # updated. Possibly prev_dom was meant to track the previous
                # domain; confirm against get_domains.
                last_dom = [crnt_dom]
                dialogue['log'][turn_id - 1]['domain'] = crnt_dom
                dialogue['log'][turn_id]['dialogue_acts'] = get_acts(dial_name, acts_data, act_id)
                act_id += 1

            dialogue = fix_delexicalisation(dial_name, dialogue, acts_data, turn_id, act_id)

        processed[dial_name] = dialogue

    return processed
# Split data (train, dev, test)
def _split_data(dial_data, test, dev, max_utt_len):
    """Format the processed dialogues and split them into the three sets.

    Dialogues for which ``extract_dialogue`` returns nothing, or that have no
    domain in ``ACTIVE_DOMAINS``, are dropped.

    Parameters:
        dial_data (dict): Processed dialogues keyed by dialogue name.
        test: Names of the test dialogues.
        dev: Names of the development dialogues.
        max_utt_len (int): Passed on to ``extract_dialogue``.
    Returns:
        train_dials, dev_dials, test_dials: Lists of formatted dialogues.
    """
    train_dials, test_dials, dev_dials = [], [], []

    # Name lists are searched once per dialogue, so look them up in sets
    test_names = set(test) if isinstance(test, (list, tuple)) else test
    dev_names = set(dev) if isinstance(dev, (list, tuple)) else dev

    print('Formatting and Splitting Data')
    for name in tqdm(dial_data):
        dial = extract_dialogue(dial_data[name], max_utt_len)
        if not dial:
            continue

        dialogue = dict()
        dialogue['dialogue_idx'] = name
        dialogue['domains'] = []
        dialogue['dialogue'] = []

        for turn_id, turn in enumerate(dial):
            # The system side of a turn is the reply to the previous user turn
            previous = dial[turn_id - 1] if turn_id > 0 else None
            turn_dialog = dict()
            turn_dialog['system_transcript'] = previous['sys'] if previous is not None else ''
            turn_dialog['turn_idx'] = turn_id
            turn_dialog['dialogue_state'] = turn['ds']
            turn_dialog['transcript'] = turn['usr']
            turn_dialog['system_acts'] = previous['sys_a'] if previous is not None else []
            turn_dialog['domain'] = turn['domain']
            dialogue['domains'].append(turn['domain'])
            dialogue['dialogue'].append(turn_dialog)

        # Unique active domains in order of first appearance, so the written
        # files do not change between runs.
        dialogue['domains'] = [dom for dom in dict.fromkeys(dialogue['domains'])
                               if dom in ACTIVE_DOMAINS]

        if dialogue['domains']:
            if name in test_names:
                test_dials.append(dialogue)
            elif name in dev_names:
                dev_dials.append(dialogue)
            else:
                train_dials.append(dialogue)

    print('Number of Dialogues:\nTrain: %i\nDev: %i\nTest: %i' % (len(train_dials), len(dev_dials), len(test_dials)))

    return train_dials, dev_dials, test_dials
This diff is collapsed.
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create Ontology Embeddings"""
import json
import os
import torch
# Set up global data directory
def set_datadir(dir):
    """Store ``dir`` in the module level ``DATA_DIR`` variable.

    Parameters:
        dir (str): Directory the ontology and the embeddings are read from
            and written to.
    """
    globals()['DATA_DIR'] = dir
# Get embeddings for slots and candidates
def get_slot_candidate_embeddings(set_type, max_slot_len, max_candidate_len, tokenizer, embedding_model):
    """Embed every slot and its value candidates and save the tensors.

    The result maps each slot to a ``(slot_embedding, value_embeddings)``
    tuple and is saved with ``torch.save`` to '<set_type>.slots' inside
    ``DATA_DIR``.

    Parameters:
        set_type (str): Selects 'ontology_<set_type>.json' in ``DATA_DIR``.
        max_slot_len (int): Maximum number of tokens of a slot name.
        max_candidate_len (int): Maximum number of tokens of a value.
        tokenizer: Tokenizer providing ``encode_plus``.
        embedding_model: Model called with the encoded inputs; the first of
            its two outputs is used as the token embeddings.
    """
    # Get set slots and candidates
    with open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r') as reader:
        ontology = json.load(reader)

    slots = dict()
    for slot in ontology:
        slot_emb = _embed_texts([slot], max_slot_len, tokenizer, embedding_model)
        value_embs = _embed_texts(ontology[slot], max_candidate_len, tokenizer, embedding_model)
        slots[slot] = (slot_emb, value_embs)

    # Dump tensors for use in training
    torch.save(slots, os.path.join(DATA_DIR, '%s.slots' % set_type))


def _embed_texts(texts, max_len, tokenizer, embedding_model):
    """Return one normalised, detached embedding row per text."""
    encoded = [tokenizer.encode_plus(text, add_special_tokens=True, max_length=max_len,
                                     pad_to_max_length='right')
               for text in texts]

    # Only pass the inputs the tokenizer actually produced
    inputs = {'input_ids': torch.tensor([enc['input_ids'] for enc in encoded])}
    for key in ('token_type_ids', 'attention_mask'):
        if key in encoded[0]:
            inputs[key] = torch.tensor([enc[key] for enc in encoded])

    # The embeddings are detached anyway, so do not record a graph
    with torch.no_grad():
        feats, _ = embedding_model(**inputs)

        # Drop the first and last position, sum up the remaining token
        # embeddings and normalise.
        # NOTE(review): the normalised shape is the full tensor size, so all
        # rows of a batch are normalised jointly. Kept as is, because
        # changing it would change the saved embeddings.
        feats = feats[:, 1:-1, :].sum(1)
        feats = torch.nn.functional.layer_norm(feats, feats.size())

    return feats.detach()
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run NBT Calibration"""
import logging
import random
import os
import torch
from torch.nn import DataParallel
from torch.distributions import Categorical
from transformers import (BertModel, BertConfig, BertTokenizer,
AdamW, get_linear_schedule_with_warmup)
from tqdm import tqdm, trange
import numpy as np
from tensorboardX import SummaryWriter
from modeling.nbt import SumbtModel
from dataset import multiwoz21, simr, simm, woz2
from dataset import ontology as embeddings
from utils import get_args
from modeling import calibration
from modeling import ensemble
from loss.ece import ece, jg_ece
# Dataset modules selectable through args.dataset
DATASETS = dict(multiwoz21=multiwoz21)
def main(args=None, config=None):
# Entry point of the calibration run: optionally fits a temperature on the
# 'dev' belief states (args.do_train) and evaluates ECE, accuracy and entropy
# on the 'test' belief states (args.do_test / args.do_eval).
# When args is None, both args and config are obtained from get_args.
# Get arguments
if args is None:
args, config = get_args(BertConfig)
ROOT = args.root
SCRATCH = args.scratch
# Select Dataset object
if args.dataset in DATASETS:
Dataset = DATASETS[args.dataset]
else:
raise NameError('NotImplemented')
# Set up data directory
# Path arguments are resolved the same way throughout this function: a value
# containing 'root' or 'scratch' (substring test) is joined onto args.root or
# args.scratch, and in every case only the text after the first '-' is used.
if 'root' in args.data_dir:
DATA_DIR = os.path.join(ROOT, args.data_dir.split('-', 1)[-1])
elif 'scratch' in args.data_dir:
DATA_DIR = os.path.join(SCRATCH, args.data_dir.split('-', 1)[-1])
else:
DATA_DIR = args.data_dir.split('-', 1)[-1]
Dataset.set_datadir(DATA_DIR)
embeddings.set_datadir(DATA_DIR)
# Download and preprocess
Dataset.create_examples(args.max_turn_len, args.force_processing)
# Set up output directory
if 'root' in args.output_dir:
OUTPUT_DIR = os.path.join(ROOT, args.output_dir.split('-', 1)[-1])
elif 'scratch' in args.output_dir:
OUTPUT_DIR = os.path.join(SCRATCH, args.output_dir.split('-', 1)[-1])
else:
OUTPUT_DIR = args.output_dir.split('-', 1)[-1]
# NOTE(review): os.mkdir does not create missing parent directories
if not os.path.exists(OUTPUT_DIR):
os.mkdir(OUTPUT_DIR)
args.output_dir = OUTPUT_DIR
# Create TensorboardX writer
if 'root' in args.tensorboard_path:
tb_writer = SummaryWriter(logdir=os.path.join(ROOT, args.tensorboard_path.split('-', 1)[-1]))
elif 'scratch' in args.tensorboard_path:
tb_writer = SummaryWriter(logdir=os.path.join(SCRATCH, args.tensorboard_path.split('-', 1)[-1]))
else:
tb_writer = SummaryWriter(logdir=args.tensorboard_path.split('-', 1)[-1])
# Create logger
# Logs go to a file unless args.logging_path contains 'stream', in which case
# a StreamHandler is used instead.
global logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
if 'stream' not in args.logging_path:
if 'root' in args.logging_path:
fh = logging.FileHandler(os.path.join(ROOT, args.logging_path.split('-', 1)[-1]))
elif 'scratch' in args.logging_path:
fh = logging.FileHandler(os.path.join(SCRATCH, args.logging_path.split('-', 1)[-1]))
else:
fh = logging.FileHandler(args.logging_path.split('-', 1)[-1])
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
logger.addHandler(fh)
else:
ch = logging.StreamHandler()
ch.setLevel(level=logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(ch)
# Select the device; args.n_gpu is overwritten with 1 (cuda) or 0 (cpu)
if torch.cuda.is_available() and args.n_gpu > 0:
device = torch.device('cuda')
args.n_gpu = 1
else:
device = torch.device('cpu')
args.n_gpu = 0
# Set up model training/evaluation
calibration.set_logger(logger, tb_writer)
calibration.set_seed(args)
if args.ensemble_size > 0:
ensemble.set_logger(logger, tb_writer)
ensemble.set_seed(args)
# Perform tasks
# TRAINING
# The temperature is fitted on the 'dev' set: the variables are named train_*
# but all files loaded and written in this section are the dev.* ones.
if args.do_train:
# Get training batch loaders and ontology embeddings
if os.path.exists(os.path.join(DATA_DIR, 'dev.slots')):
train_slots = torch.load(os.path.join(DATA_DIR, 'dev.slots'))
else:
# Create Tokenizer and embedding model for Data Loaders and ontology
encoder = BertModel.from_pretrained(args.candidate_embedding_model_name)
tokenizer = BertTokenizer.from_pretrained(config.tokenizer_name, config=config)
embeddings.get_slot_candidate_embeddings('dev', args.max_slot_len, args.max_candidate_len, tokenizer, encoder)
train_slots = torch.load(os.path.join(DATA_DIR, 'dev.slots'))
# A dataloader cached in OUTPUT_DIR is only reused when its batch size
# matches args.dev_batch_size; otherwise it is rebuilt and saved again.
exists = False
if os.path.exists(os.path.join(OUTPUT_DIR, 'dev.dataloader')):
train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dev.dataloader'))
if train_dataloader.batch_size == args.dev_batch_size:
exists = True
if not exists:
tokenizer = BertTokenizer.from_pretrained(config.tokenizer_name, config=config)
train_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len,
config.max_turn_len)
torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dev.dataloader'))
# Belief states are read from the 'dev.belief' cache when present; otherwise
# they are computed by the ensemble or by a single model and then cached.
if os.path.exists(os.path.join(OUTPUT_DIR, 'dev.belief')):
logits = torch.load(os.path.join(OUTPUT_DIR, 'dev.belief'))
labels = logits['labels']
logits = logits['logits']
elif args.ensemble_size > 0:
config, models = ensemble.get_models(args.model_name_or_path, device)
logits, labels = ensemble.get_logits(args, models, device, train_dataloader, train_slots)
torch.save({'logits': logits, 'labels': labels}, os.path.join(OUTPUT_DIR, 'dev.belief'))
else:
# Initialise Model
model = SumbtModel.from_pretrained(args.model_name_or_path, config=config)
model = model.to(device)
# Get slot and value embeddings
slots = {slot: train_slots[slot][0] for slot in train_slots}
values = {slot: train_slots[slot][1] for slot in train_slots}
# Load model ontology
model.add_slot_candidates(slots)
for slot in values:
model.add_value_candidates(slot, values[slot], replace=True)
logits, labels = calibration.get_logits(args, model, device, train_dataloader, train_slots)
torch.save({'logits': logits, 'labels': labels}, os.path.join(OUTPUT_DIR, 'dev.belief'))
# The fitted temperature replaces args.temp_scaling for the evaluation below
best = calibration.train(logits, labels, args.patience)
logger.info('Best Temperature: %f' % best['temperature'])
args.temp_scaling = best['temperature']
# Testing
if args.do_test or args.do_eval:
# Get training batch loaders and ontology embeddings
if os.path.exists(os.path.join(DATA_DIR, 'test.slots')):
test_slots = torch.load(os.path.join(DATA_DIR, 'test.slots'))
else:
# Create Tokenizer and embedding model for Data Loaders and ontology
encoder = BertModel.from_pretrained(args.candidate_embedding_model_name)
tokenizer = BertTokenizer.from_pretrained(config.tokenizer_name, config=config)
embeddings.get_slot_candidate_embeddings('test', args.max_slot_len, args.max_candidate_len, tokenizer, encoder)
test_slots = torch.load(os.path.join(DATA_DIR, 'test.slots'))
# Same dataloader caching rule as above, keyed on args.test_batch_size
exists = False
if os.path.exists(os.path.join(OUTPUT_DIR, 'test.dataloader')):
test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'test.dataloader'))
if test_dataloader.batch_size == args.test_batch_size:
exists = True
if not exists:
tokenizer = BertTokenizer.from_pretrained(config.tokenizer_name, config=config)
test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
config.max_turn_len)
torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'test.dataloader'))
# Same belief state caching rule as above, using 'test.belief'
if os.path.exists(os.path.join(OUTPUT_DIR, 'test.belief')):
logits = torch.load(os.path.join(OUTPUT_DIR, 'test.belief'))
labels = logits['labels']
logits = logits['logits']
elif args.ensemble_size > 0:
config, models = ensemble.get_models(args.model_name_or_path, device)
logits, labels = ensemble.get_logits(args, models, device, test_dataloader, test_slots)
torch.save({'logits': logits, 'labels': labels}, os.path.join(OUTPUT_DIR, 'test.belief'))
else:
# Initialise Model
model = SumbtModel.from_pretrained(args.model_name_or_path, config=config)
model = model.to(device)
# Get slot and value embeddings
slots = {slot: test_slots[slot][0] for slot in test_slots}
values = {slot: test_slots[slot][1] for slot in test_slots}
# Load model ontology
model.add_slot_candidates(slots)
for slot in values:
model.add_value_candidates(slot, values[slot], replace=True)
logits, labels = calibration.get_logits(args, model, device, test_dataloader, test_slots)
torch.save({'logits': logits, 'labels': labels}, os.path.join(OUTPUT_DIR, 'test.belief'))
# Apply temperature scaling unless the temperature is exactly 1.0, in which
# case only a small constant is added to the stored values.
if args.temp_scaling != 1.0:
logits = {slot: calibration.calibrate(args.temp_scaling, logits[slot]) for slot in logits}
else:
logits = {slot: logits[slot] + 1e-8 for slot in logits}
# Per slot ECE with 10 bins; the reported value is the maximum over the slots
err = [ece(logits[slot].reshape(-1, logits[slot].size(-1)), labels[slot].reshape(-1), 10)
for slot in logits]
err = max(err)
logger.info('Temperature: %f' % args.temp_scaling)
logger.info('ECE: %f' % err)
jg = jg_ece(logits, labels, 10)
logger.info('Joint Goal ECE: %f' % jg)
# Top-n accuracy: jg_acc accumulates, per turn, the number of slots whose gold
# label is among the topn highest scoring candidates.
jg_acc = 0.0
for slot in logits:
topn = args.accuracy_topn
p_ = logits[slot]
gold = labels[slot]
# Cap topn below the number of candidates of the slot, but never below 1
if p_.size(-1) <= topn:
topn = p_.size(-1) - 1
if topn <= 0:
topn = 1
labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True)
labs = labs[:, :topn]
acc = [lab in s for lab, s in zip(gold.reshape(-1), labs)]
acc = torch.tensor(acc).float()
jg_acc += acc
# Number of turns with a non-negative label, taken from the last slot iterated
n_turns = (gold >= 0).reshape(-1).sum().float().item()
sl_acc = sum(jg_acc / len(logits)).float()
# int() truncates the per turn fraction, leaving 1 only where all slots match
jg_acc = sum((jg_acc / len(logits)).int()).float()
sl_acc /= n_turns
jg_acc /= n_turns
logger.info('Joint Goal Accuracy: %f, Slot Accuracy %f' % (jg_acc, sl_acc))
# Mean entropy per slot; the reported value is the minimum over the slots
entropy = [logits[slot] for slot in logits]
entropy = [Categorical(p + 1e-8) for p in entropy]
entropy = [p.entropy().mean() for p in entropy]
entropy = min(entropy)
logger.info('Entropy: %f' % entropy)
# Run the calibration only when the file is executed directly
if __name__ == "__main__":
main()
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run Neural Belief Tracker"""
import logging
import random
import os
import torch
from transformers import (BertModel, BertConfig, BertTokenizer,
AdamW, get_linear_schedule_with_warmup)
from tqdm import tqdm, trange
import numpy as np
from tensorboardX import SummaryWriter
from modeling.nbt import SumbtModel
from dataset import multiwoz21, simr, simm, woz2
from modeling import training
from dataset import ontology as embeddings
from utils import get_args, upload_local_directory_to_gcs
from modeling import ensemble
# Dataset modules selectable through args.dataset
DATASETS = dict(multiwoz21=multiwoz21)
def _resolve_path(spec, root, scratch):
    """Turn a '<location>-<path>' style argument into a concrete path.

    The text after the first '-' is used as the path. It is joined onto
    `root` when the argument contains 'root', onto `scratch` when it contains
    'scratch', and used unchanged otherwise.
    """
    path = spec.split('-', 1)[-1]
    if 'root' in spec:
        return os.path.join(root, path)
    if 'scratch' in spec:
        return os.path.join(scratch, path)
    return path


def _load_slot_embeddings(set_type, data_dir, args, tokenizer, encoder):
    """Load '<set_type>.slots' from data_dir, generating it first if missing."""
    path = os.path.join(data_dir, '%s.slots' % set_type)
    if not os.path.exists(path):
        embeddings.get_slot_candidate_embeddings(set_type, args.max_slot_len, args.max_candidate_len,
                                                 tokenizer, encoder)
    return torch.load(path)


def _load_dataloader(path, batch_size, build):
    """Return the dataloader cached at path, or build and cache a new one.

    The cached loader is reused only when its batch size equals `batch_size`;
    otherwise `build()` is called and its result is saved to `path`.
    """
    if os.path.exists(path):
        dataloader = torch.load(path)
        if dataloader.batch_size == batch_size:
            return dataloader
    dataloader = build()
    torch.save(dataloader, path)
    return dataloader


def _set_ontology(model, slot_embeddings):
    """Register the slot and value candidate embeddings of one data split on the model."""
    slots = {slot: slot_embeddings[slot][0] for slot in slot_embeddings}
    values = {slot: slot_embeddings[slot][1] for slot in slot_embeddings}
    model.add_slot_candidates(slots)
    for slot in values:
        model.add_value_candidates(slot, values[slot], replace=True)


def main(args=None, config=None):
    """Train, evaluate and/or test the belief tracker according to `args`.

    Args:
        args: Parsed run arguments. When None, both `args` and `config` are
            obtained from `get_args(BertConfig)`.
        config: Model configuration matching `args`.

    Raises:
        NameError: If `args.dataset` is not a key of `DATASETS`.
    """
    # Get arguments
    if args is None:
        args, config = get_args(BertConfig)

    ROOT = args.root
    SCRATCH = args.scratch

    # Select Dataset object
    if args.dataset not in DATASETS:
        raise NameError('NotImplemented')
    Dataset = DATASETS[args.dataset]

    # Set up data directory
    DATA_DIR = _resolve_path(args.data_dir, ROOT, SCRATCH)
    Dataset.set_datadir(DATA_DIR)
    embeddings.set_datadir(DATA_DIR)

    # Download and preprocess
    Dataset.create_examples(args.max_turn_len, args.force_processing)

    # Set up output directory (makedirs also creates missing parent directories)
    OUTPUT_DIR = _resolve_path(args.output_dir, ROOT, SCRATCH)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    args.output_dir = OUTPUT_DIR

    # Create TensorboardX writer
    tb_writer = SummaryWriter(logdir=_resolve_path(args.tensorboard_path, ROOT, SCRATCH))

    # Create logger; it is module-global so that code outside main() can use it
    global logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    if 'stream' not in args.logging_path:
        handler = logging.FileHandler(_resolve_path(args.logging_path, ROOT, SCRATCH))
    else:
        handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    # Select the device; n_gpu is normalised to 1 (GPU) or 0 (CPU)
    if torch.cuda.is_available() and args.n_gpu > 0:
        device = torch.device('cuda')
        args.n_gpu = 1
    else:
        device = torch.device('cpu')
        args.n_gpu = 0

    # Initialise Model
    model = SumbtModel.from_pretrained(args.model_name_or_path, config=config)
    model = model.to(device)

    # Create Tokenizer and embedding model for Data Loaders and ontology
    encoder = BertModel.from_pretrained(args.candidate_embedding_model_name)
    tokenizer = BertTokenizer.from_pretrained(config.tokenizer_name, config=config)

    # Set up model training/evaluation
    training.set_logger(logger, tb_writer)
    training.set_seed(args)

    if args.ensemble_size > 1:
        ensemble.set_logger(logger, tb_writer)
        ensemble.set_seed(args)
        logger.info('Building %i resampled dataloaders each of size %i',
                    args.ensemble_size, args.data_sampling_size)
        dataloaders = ensemble.build_train_loaders(args, config, Dataset)
        logger.info('Dataloaders built.')
        # One sub-directory per ensemble member, each holding its own train loader
        for i, loader in enumerate(dataloaders):
            path = os.path.join(OUTPUT_DIR, 'ensemble_%i' % i)
            os.makedirs(path, exist_ok=True)
            torch.save(loader, os.path.join(path, 'train.dataloader'))
        logger.info('Dataloaders saved.')

    # Perform tasks
    # TRAINING
    if args.do_train:
        # Get training batch loaders and ontology embeddings
        train_slots = _load_slot_embeddings('train', DATA_DIR, args, tokenizer, encoder)
        dev_slots = _load_slot_embeddings('dev', DATA_DIR, args, tokenizer, encoder)

        def build_train_dataloader():
            # A non-positive sampling size means "do not resample"
            if args.data_sampling_size <= 0:
                args.data_sampling_size = None
            return Dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len,
                                          config.max_turn_len, resampled_size=args.data_sampling_size)

        train_dataloader = _load_dataloader(os.path.join(OUTPUT_DIR, 'train.dataloader'),
                                            args.train_batch_size, build_train_dataloader)

        # Get development set batch loaders and ontology embeddings
        if args.do_eval:
            dev_dataloader = _load_dataloader(
                os.path.join(OUTPUT_DIR, 'dev.dataloader'), args.dev_batch_size,
                lambda: Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len,
                                               config.max_turn_len))
        else:
            # Plain None (the original trailing comma produced the tuple (None,))
            dev_dataloader = None
            dev_slots = None

        # Load model ontology
        _set_ontology(model, train_slots)

        # TRAINING !!!!!!!!!!!!!!!!!!
        training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots)

    # Evaluation on the development set
    if args.do_eval:
        # Get development set batch loaders and ontology embeddings
        dev_slots = _load_slot_embeddings('dev', DATA_DIR, args, tokenizer, encoder)
        dev_dataloader = _load_dataloader(
            os.path.join(OUTPUT_DIR, 'dev.dataloader'), args.dev_batch_size,
            lambda: Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len,
                                           config.max_turn_len))

        # Load model ontology
        _set_ontology(model, dev_slots)

        # EVALUATION
        jg_acc, sl_acc, loss, logits = training.evaluate(args, model, device, dev_dataloader, dev_slots)
        logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f', loss, jg_acc, sl_acc)

    # Evaluation on the test set
    if args.do_test:
        # Get test set batch loaders and ontology embeddings
        test_slots = _load_slot_embeddings('test', DATA_DIR, args, tokenizer, encoder)
        test_dataloader = _load_dataloader(
            os.path.join(OUTPUT_DIR, 'test.dataloader'), args.test_batch_size,
            lambda: Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
                                           config.max_turn_len))

        # Load model ontology
        _set_ontology(model, test_slots)

        # TESTING
        jg_acc, sl_acc, loss, logits = training.evaluate(args, model, device, test_dataloader, test_slots)
        logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f', loss, jg_acc, sl_acc)

    tb_writer.close()

    # Optionally mirror the output directory to a Google Cloud Storage bucket
    if args.gcs_bucket_name:
        remote = os.path.basename(args.output_dir)
        upload_local_directory_to_gcs(args.output_dir, args.gcs_bucket_name, remote)
# Script entry point: with no arguments passed, main() obtains args and config
# itself via get_args.
if __name__ == "__main__":
main()
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extracting the Turn Encoder from the model checkpoint"""
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import json
import torch
def main():
    """Extract the turn encoder from a model checkpoint.

    Copies `config.json` from `--model_dir` to `--output_dir` and writes a
    `pytorch_model.bin` containing only the state-dict entries whose key
    contains 'turn_encoder'.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model_dir', required=True,
                        help='Directory containing the model checkpoint (config.json and pytorch_model.bin)')
    # NOTE(review): the default looks copied from the plotting script; it is kept
    # unchanged so existing invocations still write to the same place.
    parser.add_argument('--output_dir', default='calibration_plot.png',
                        help='Directory in which the extracted turn encoder is saved')
    args = parser.parse_args()

    # Also creates missing parent directories; no error if it already exists
    os.makedirs(args.output_dir, exist_ok=True)

    # Context managers make sure the files are closed even if (de)serialisation fails
    with open(os.path.join(args.model_dir, 'config.json'), 'r') as reader:
        config = json.load(reader)
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as writer:
        json.dump(config, writer)

    state_dict = torch.load(os.path.join(args.model_dir, 'pytorch_model.bin'), map_location='cpu')
    state_dict = {key: item for key, item in state_dict.items() if 'turn_encoder' in key}
    torch.save(state_dict, os.path.join(args.output_dir, 'pytorch_model.bin'))
# Script entry point: run the extraction only when executed directly.
if __name__ == '__main__':
main()
\ No newline at end of file
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculating the l2 norm accuracies"""
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import json
import torch
import numpy as np
def main():
    """Load saved belief states and labels and print the L2 norm goal accuracy.

    The file given by `--belief_path` is loaded with `torch.load` and must
    contain the keys 'labels' and 'logits'.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    # Required: without it torch.load would be called with None
    parser.add_argument('--belief_path', help='Path to the belief states and labels', required=True)
    args = parser.parse_args()

    belief = torch.load(args.belief_path)
    labels = belief['labels']
    belief = belief['logits']

    l2 = l2_acc(belief, labels)
    print(f'Model L2 Norm Goal Accuracy: {l2}')
def l2_acc(belief_state, labels):
    """Mean L2 distance between the predicted and the one-hot target belief state.

    For every slot the predictive distribution is flattened to
    (observations, candidates) and observations with a negative label are
    dropped. Predictions and one-hot targets of all slots are concatenated
    along the candidate axis, and the Euclidean norm of their difference is
    averaged over the observations.

    Args:
        belief_state: dict mapping slot name to a tensor whose last dimension
            indexes the candidate values.
        labels: dict mapping slot name to an integer tensor of target indices;
            negative entries are ignored.
            NOTE(review): the concatenation presumably requires every slot to
            have the same number of non-negative labels — confirm with callers.

    Returns:
        Zero-dimensional tensor holding the mean L2 error.
    """
    state = []
    targets = []
    for slot, bs in belief_state.items():
        # Predictive Distribution
        bs = bs.reshape(-1, bs.size(-1))
        lab = labels[slot].reshape(-1)
        valid = lab >= 0
        bs = bs[valid]
        lab = lab[valid]

        # One-hot targets, created on the same device/dtype as the predictions
        # rather than unconditionally on the GPU
        y = torch.zeros_like(bs)
        y[range(y.size(0)), lab] = 1.0

        state.append(bs)
        targets.append(y)

    state = torch.cat(state, -1)
    targets = torch.cat(targets, -1)

    # L2 norm of the difference between the predicted and target state per
    # observation, averaged across all observations
    err = torch.sqrt(((targets - state) ** 2).sum(-1))
    return err.mean()
# Script entry point: compute and print the metric only when executed directly.
if __name__ == '__main__':
main()
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bayesian Matching Activation and Loss Functions"""
import torch
from torch import digamma
from torch.distributions import Dirichlet
from torch.nn import Module
from torch.nn.functional import kl_div
# Inverse Linear activation function
def invlinear(x):
    """Inverse linear activation mapping any real input to a positive value.

    Element-wise: 1 / (1 - x) for x < 0 and 1 + x for x >= 0.

    Args:
        x: Input tensor.

    Returns:
        Tensor of the same shape as `x` with strictly positive entries.
    """
    # Clamp before dividing so the branch that is not selected stays finite:
    # at x == 1 the unclamped 1 / (1 - x) is inf, and masking it by
    # multiplication gives inf * 0 = NaN (in the value and in the gradient).
    negative_branch = 1.0 / (1.0 - x.clamp(max=0))
    positive_branch = 1.0 + x
    return torch.where(x < 0, negative_branch, positive_branch)
# Dirichlet activation function for the model
def dirichlet(a):
    """Return the mean of the Dirichlet distribution parameterised by invlinear(a)."""
    concentration = invlinear(a)
    return Dirichlet(concentration).mean
# Pytorch BayesianMatchingLoss nn.Module
class BayesianMatchingLoss(Module):
    """Negative ELBO loss for a Dirichlet predictive distribution.

    The loss is -(E[log p(label)] - lamb * KL), where the expected log
    likelihood uses the digamma form for a Dirichlet and the KL term is
    computed between the mean of the predicted Dirichlet and the mean of the
    prior Dirichlet.

    Args:
        lamb: Weight of the KL term.
        ignore_index: Label value marking observations that are excluded.
    """

    def __init__(self, lamb=0.01, ignore_index=-1):
        super(BayesianMatchingLoss, self).__init__()

        self.lamb = lamb
        self.ignore_index = ignore_index

    def forward(self, alpha, labels, prior=None):
        """Compute the loss.

        Args:
            alpha: 2-D tensor (observations, classes) of raw model outputs;
                `invlinear` is applied to obtain Dirichlet concentrations.
            labels: 1-D tensor with one class index per observation.
            prior: Optional per-observation prior concentrations, indexed
                along the first dimension like `alpha`. Defaults to an
                all-ones concentration vector.

        Returns:
            Zero-dimensional loss tensor.

        Raises:
            NameError: If a label is not a valid class index for `alpha`.
        """
        # Assert input sizes
        assert alpha.dim() == 2                  # Observations, predictive distribution
        assert labels.dim() == 1                 # Label for each observation
        assert labels.size(0) == alpha.size(0)   # Equal number of observation

        # Confirm predictive distribution dimension. Labels are zero-based, so
        # the largest valid label is dimension - 1 (the original `<=` let a
        # label equal to the dimension through).
        dimension = alpha.size(-1)
        largest_label = int(labels.max())
        if largest_label >= dimension:
            raise NameError('Label index %i is out of range for prediction dimension %i.'
                            % (largest_label, dimension))

        # Remove observations with no labels
        keep = labels != self.ignore_index
        if prior is not None:
            prior = prior[keep]
        alpha = invlinear(alpha[keep])
        labels = labels[keep]

        # Initialise and reshape prior parameters
        if prior is None:
            prior = torch.ones(dimension)
        prior = prior.to(alpha.device)
        prior = Dirichlet(prior).mean

        # KL divergence term (kl_div expects the first argument in log space)
        predicted = torch.log(Dirichlet(alpha).mean)
        kl = kl_div(predicted, prior, reduction='none').sum(-1).mean()
        kl = self.lamb * kl

        # Expected log likelihood
        alpha_target = alpha[range(labels.size(0)), labels]
        alpha_0 = alpha.sum(1)
        expected_likelihood = digamma(alpha_target) - digamma(alpha_0)

        # Apply ELBO loss and mean reduction
        elbo = expected_likelihood.mean() - kl
        return -1.0 * elbo
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Expected calibration error"""
import torch
def fill_bins(n_bins, logits):
    """Assign every observation to an equal-width confidence bin on [0, 1].

    The confidence of an observation is the maximum of its row in `logits`.
    The first bin is closed on both sides; every later bin excludes its
    lower edge.

    Args:
        n_bins: Number of bins.
        logits: 2-D tensor (observations, classes).

    Returns:
        List of `n_bins` index tensors, one per bin.
    """
    assert logits.dim() == 2
    confidence = logits.max(-1)[0]

    edges = torch.arange(0.0, 1.0 + 1e-10, 1.0 / n_bins)

    bins = []
    for idx in range(n_bins):
        lower, upper = edges[idx], edges[idx + 1]
        above_lower = confidence >= lower if idx == 0 else confidence > lower
        bins.append(torch.where(above_lower * (confidence <= upper))[0])
    return bins
def bin_confidence(bins, logits):
    """Mean confidence (maximum of each row of `logits`) in every bin.

    Args:
        bins: Sequence of index tensors (or None) selecting the observations
            of each bin.
        logits: Tensor whose last dimension indexes the classes.

    Returns:
        Tensor with one entry per bin; -1 marks a bin that is None or empty.
    """
    confidence = logits.max(-1)[0]

    scores = []
    for b in bins:
        # An empty bin has no mean (it would be NaN); use the -1 sentinel
        if b is None or b.size(0) == 0:
            scores.append(-1)
        else:
            scores.append(confidence[b].mean())
    return torch.tensor(scores)
def bin_accuracy(bins, logits, y_true):
    """Accuracy of the argmax prediction in every bin.

    Observations with a negative label are excluded from the accuracy.

    Args:
        bins: Sequence of index tensors (or None) selecting the observations
            of each bin.
        logits: Tensor whose last dimension indexes the classes.
        y_true: Tensor of target class indices; negative entries are ignored.

    Returns:
        Tensor with one entry per bin; -1 marks a bin that is None or holds
        no labelled observation.
    """
    y_pred = logits.argmax(-1)

    acc = []
    for b in bins:
        if b is None:
            acc.append(-1)
            continue
        targets = y_true[b]
        correct = (y_pred[b] == targets).float()
        correct = correct[targets >= 0]
        # `> 0`, not `>= 0`: the mean of an empty selection is NaN
        if correct.size(0) > 0:
            acc.append(correct.mean())
        else:
            acc.append(-1)
    return torch.tensor(acc)
def ece(logits, y_true, n_bins):
    """Expected calibration error of `logits` against `y_true`.

    Sums, over the bins with a non-negative accuracy, the absolute gap between
    mean confidence and accuracy weighted by the share of observations in
    the bin.

    Args:
        logits: 2-D tensor (observations, classes).
        y_true: Tensor of target class indices.
        n_bins: Number of confidence bins.

    Returns:
        The calibration error as a Python float.
    """
    bins = fill_bins(n_bins, logits)

    confidence = bin_confidence(bins, logits)
    accuracy = bin_accuracy(bins, logits, y_true)

    total = logits.size(0)
    bin_sizes = torch.tensor([b.size(0) for b in bins])

    weighted_gap = torch.abs(confidence - accuracy) * bin_sizes / total
    # Drop bins whose accuracy is undefined (negative sentinel or NaN)
    weighted_gap = weighted_gap[accuracy >= 0.0]
    return weighted_gap.sum().item()
def jg_ece(logits, y_true, n_bins):
    """Expected calibration error of the joint goal prediction.

    An observation counts as correct only when every slot is predicted
    correctly. Its confidence is the minimum over slots of each slot's
    maximum probability.

    Args:
        logits: dict mapping slot name to a tensor whose last dimension
            indexes the candidate values; leading dimensions are flattened.
        y_true: dict mapping slot name to a tensor of target indices.
            Observations whose label in the first slot is negative are
            excluded from the accuracy.
        n_bins: Number of equal-width confidence bins on [0, 1].

    Returns:
        The joint goal calibration error as a Python float.
    """
    # Flatten every slot to (observations, candidates)
    flat = {slot: logits[slot].reshape(-1, logits[slot].size(-1)) for slot in logits}

    # Joint goal correctness: all slots must match their label
    y_pred = {slot: flat[slot].argmax(-1) for slot in flat}
    n_correct = sum([(y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred])
    goal_acc = (n_correct == len(y_true)).int()

    # Joint confidence: the least confident slot
    scores = [flat[slot].max(-1)[0].unsqueeze(0) for slot in flat]
    scores = torch.cat(scores, 0).min(0)[0]

    # First bin is closed on both sides; later bins exclude their lower edge
    bin_ranges = torch.arange(0.0, 1.0 + 1e-10, 1.0 / n_bins)
    bins = []
    for b in range(n_bins):
        lower, upper = bin_ranges[b], bin_ranges[b + 1]
        above_lower = scores >= lower if b == 0 else scores > lower
        bins.append(torch.where(above_lower * (scores <= upper))[0])

    # Mean confidence per bin; -1 marks an empty bin
    conf = torch.tensor([scores[b].mean() if b.size(0) > 0 else -1 for b in bins])

    # Accuracy per bin over labelled observations; -1 marks "no observation"
    slot = [s for s in y_true][0]
    labelled = y_true[slot].reshape(-1) >= 0
    acc = []
    for b in bins:
        acc_ = goal_acc[b][labelled[b]]
        acc.append(acc_.float().mean() if acc_.size(0) > 0 else -1)
    acc = torch.tensor(acc)

    # Normalise by the number of flattened observations, the same population
    # the bin sizes are counted over (size(0) of an unflattened tensor would
    # be too small and inflate the error).
    n = scores.size(0)
    bk = torch.tensor([b.size(0) for b in bins])
    ece = torch.abs(conf - acc) * bk / n
    ece = ece[acc >= 0.0]
    return ece.sum().item()
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment