Skip to content
Snippets Groups Projects
Unverified Commit 350f36c4 authored by Carel van Niekerk's avatar Carel van Niekerk Committed by GitHub
Browse files

Setsumbt updates

parent ffb3dc42
Branches
No related tags found
No related merge requests found
Showing
with 779 additions and 1732 deletions
from convlab.dst.setsumbt.tracker import SetSUMBTTracker
\ No newline at end of file
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calibration Plot plotting script"""
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import torch
from matplotlib import pyplot as plt
def main():
    """Plot calibration curves (confidence vs joint goal accuracy) for a set of models.

    Expects ``--data_dir`` to contain one subdirectory per model, each holding a
    ``test.predictions`` file produced by the SetSUMBT evaluation scripts.
    The resulting figure is written to ``--output``.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--data_dir', help='Location of the belief states', required=True)
    parser.add_argument('--output', help='Output image path', default='calibration_plot.png')
    parser.add_argument('--n_bins', help='Number of bins', default=10, type=int)
    args = parser.parse_args()

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # One prediction file per model subdirectory.
    path = args.data_dir
    models = os.listdir(path)
    models = [os.path.join(path, model, 'test.predictions') for model in models]

    fig = plt.figure(figsize=(14, 8))
    font = 20
    plt.tick_params(labelsize=font - 2)
    linestyle = ['-', ':', (0, (3, 5, 1, 5)), '-.', (0, (5, 10))]
    for i, model in enumerate(models):
        conf, acc = get_calibration(model, device, n_bins=args.n_bins)
        name = model.split('/')[-2].strip()
        print(name, conf, acc)
        # BUG FIX: cycle through the available styles so more than five model
        # directories no longer raise an IndexError.
        plt.plot(conf, acc, label=name, linestyle=linestyle[i % len(linestyle)], linewidth=3)

    # Diagonal reference line: a perfectly calibrated model lies on y = x.
    plt.plot(torch.tensor([0, 1]), torch.tensor([0, 1]), linestyle='--', color='black', linewidth=3)

    plt.xlabel('Confidence', fontsize=font)
    plt.ylabel('Joint Goal Accuracy', fontsize=font)
    plt.legend(fontsize=font)
    plt.savefig(args.output)
def get_calibration(path, device, n_bins=10, temperature=1.00):
    """Compute a calibration curve from saved belief-state predictions.

    Args:
        path (str): Path to a ``.predictions`` file saved via ``torch.save`` containing
            ``state_labels`` (slot -> label tensor) and ``belief_states``
            (slot -> probability distribution tensor).
        device (torch.device): Device to map the loaded tensors to.
        n_bins (int): Number of equal-width confidence bins.
        temperature (float): Unused; kept for interface compatibility.

    Returns:
        conf (torch.tensor): Mean confidence per non-empty bin (leading 0.0 anchor).
        acc (torch.tensor): Joint goal accuracy per non-empty bin (leading 0.0 anchor).
    """
    probs = torch.load(path, map_location=device)
    y_true = probs['state_labels']
    probs = probs['belief_states']

    # Predicted value index for each slot, flattened across dialogues/turns.
    y_pred = {slot: probs[slot].reshape(-1, probs[slot].size(-1)).argmax(-1) for slot in probs}

    # Joint goal accuracy: a turn is correct only if every slot is correct.
    goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred}
    goal_acc = sum(goal_acc[slot] for slot in goal_acc)
    goal_acc = (goal_acc == len(y_true)).int()

    # Turn-level confidence: minimum over slots of the max slot probability.
    scores = [probs[slot].reshape(-1, probs[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in probs]
    scores = torch.cat(scores, 0).min(0)[0]

    # Partition turns into n_bins equal-width confidence bins; the first bin
    # is closed on the left so scores exactly equal to 0.0 are included.
    step = 1.0 / float(n_bins)
    bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
    bins = []
    for b in range(n_bins):
        lower, upper = bin_ranges[b], bin_ranges[b + 1]
        if b == 0:
            ids = torch.where((scores >= lower) * (scores <= upper))[0]
        else:
            ids = torch.where((scores > lower) * (scores <= upper))[0]
        bins.append(ids)

    # Mean confidence per bin; -1 marks empty bins for later filtering.
    conf = [0.0]
    for b in bins:
        if b.size(0) > 0:
            conf.append(scores[b].mean())
        else:
            conf.append(-1)
    conf = torch.tensor(conf)

    slot = [s for s in y_true][0]
    acc = [0.0]
    for b in bins:
        if b.size(0) > 0:
            acc_ = goal_acc[b]
            # Discard padded turns (label == -1) from the accuracy computation.
            acc_ = acc_[y_true[slot].reshape(-1)[b] >= 0]
            # BUG FIX: was `>= 0`, which is always true, so a bin containing only
            # padded turns took .mean() of an empty tensor and injected NaN into
            # the curve (NaN != -1, so it survived the filter below).
            if acc_.size(0) > 0:
                acc.append(acc_.float().mean())
            else:
                acc.append(-1)
        else:
            acc.append(-1)
    acc = torch.tensor(acc)

    # Drop bins marked empty.
    conf = conf[acc != -1]
    acc = acc[acc != -1]

    return conf, acc
if __name__ == '__main__':
main()
{
"model_type": "SetSUMBT",
"dataset": "multiwoz21",
"no_action_prediction": false,
"loss_function": "distribution_distillation",
"model_name_or_path": "roberta-base",
"candidate_embedding_model_name": "roberta-base",
"train_batch_size": 3,
"dev_batch_size": 12,
"test_batch_size": 16
}
\ No newline at end of file
{
"model_type": "Ensemble-SetSUMBT",
"dataset": "multiwoz21",
"ensemble_size": 5,
"data_sampling_size": 7500,
"no_action_prediction": false,
"model_name_or_path": "roberta-base",
"candidate_embedding_model_name": "roberta-base",
"train_batch_size": 3,
"dev_batch_size": 3,
"test_batch_size": 3
}
\ No newline at end of file
...@@ -4,9 +4,7 @@ ...@@ -4,9 +4,7 @@
"no_action_prediction": true, "no_action_prediction": true,
"model_name_or_path": "roberta-base", "model_name_or_path": "roberta-base",
"candidate_embedding_model_name": "roberta-base", "candidate_embedding_model_name": "roberta-base",
"transformers_local_files_only": false,
"train_batch_size": 3, "train_batch_size": 3,
"dev_batch_size": 16, "dev_batch_size": 12,
"test_batch_size": 16, "test_batch_size": 16
"run_nbt": true
} }
\ No newline at end of file
from convlab.dst.setsumbt.dataset.unified_format import get_dataloader, change_batch_size
from convlab.dst.setsumbt.dataset.ontology import get_slot_candidate_embeddings
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create Ontology Embeddings"""
import json
import os
import random
from copy import deepcopy
import torch
import numpy as np
from tqdm import tqdm
def set_seed(args):
    """
    Seed every random number generator used during training.

    Args:
        args (Arguments class): Arguments class containing seed and number of gpus to use
    """
    seed_value = args.seed
    # Seed the Python, NumPy and PyTorch generators with the same value.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    # Seed every CUDA device as well when GPUs are in use.
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed_value)
def encode_candidates(candidates: list, args, tokenizer, embedding_model) -> torch.tensor:
    """
    Embed candidates

    Args:
        candidates (list): List of candidate descriptions
        args (argument class): Runtime arguments; reads max_candidate_len, set_similarity
            and candidate_pooling
        tokenizer (transformers Tokenizer): Tokenizer for the embedding_model
        embedding_model (transformer Model): Transformer model for embedding candidate descriptions

    Returns:
        feats (torch.tensor): Embeddings of the candidate descriptions
    """
    # Tokenize candidate descriptions, padded/truncated to a fixed length
    feats = [tokenizer.encode_plus(val, add_special_tokens=True,max_length=args.max_candidate_len,
                                   padding='max_length', truncation='longest_first')
             for val in candidates]

    # Encode tokenized descriptions (no gradients needed: embeddings are fixed features)
    with torch.no_grad():
        feats = {key: torch.tensor([f[key] for f in feats]).to(embedding_model.device) for key in feats[0]}
        embedded_feats = embedding_model(**feats)  # [num_candidates, max_candidate_len, hidden_dim]

    # Reduce/pool descriptions embeddings if required
    if args.set_similarity:
        # Keep per-token embeddings for set-similarity matching
        feats = embedded_feats.last_hidden_state.detach().cpu()  # [num_candidates, max_candidate_len, hidden_dim]
    elif args.candidate_pooling == 'cls':
        # Use the [CLS] pooled output as a single vector per candidate
        feats = embedded_feats.pooler_output.detach().cpu()  # [num_candidates, hidden_dim]
    elif args.candidate_pooling == "mean":
        # NOTE(review): this branch sums token embeddings and layer-norms the result;
        # there is no division by sequence length — confirm "mean" is the intended name.
        feats = embedded_feats.last_hidden_state.detach().cpu()
        feats = feats.sum(1)
        feats = torch.nn.functional.layer_norm(feats, feats.size())
        feats = feats.detach().cpu()  # [num_candidates, hidden_dim]

    return feats
def get_slot_candidate_embeddings(ontology: dict, set_type: str, args, tokenizer, embedding_model, save_to_file=True):
    """
    Get embeddings for slots and candidates

    Args:
        ontology (dict): Dictionary of domain-slot pair descriptions and possible value sets
        set_type (str): Subset of the dataset being used (train/validation/test)
        args (argument class): Runtime arguments
        tokenizer (transformers Tokenizer): Tokenizer for the embedding_model
        embedding_model (transformer Model): Transformer model for embedding candidate descriptions
        save_to_file (bool): Indication of whether to save information to file

    Returns:
        slots (dict): domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
    """
    # Set model to eval mode (disables dropout so embeddings are deterministic)
    embedding_model.eval()

    slots = dict()
    for domain, subset in tqdm(ontology.items(), desc='Domains'):
        for slot, slot_info in tqdm(subset.items(), desc='Slots'):
            # Get description or use "domain-slot"
            if args.use_descriptions:
                desc = slot_info['description']
            else:
                desc = f"{domain}-{slot}"

            # Encode domain-slot pair description
            slot_emb = encode_candidates([desc], args, tokenizer, embedding_model)[0]

            # Obtain possible value set and discard requestable value
            # ('?' is a sentinel marking the slot as requestable, not a real value)
            values = deepcopy(slot_info['possible_values'])
            is_requestable = False
            if '?' in values:
                is_requestable = True
                values.remove('?')

            # Encode value candidates (None when the slot has no informable values)
            if values:
                feats = encode_candidates(values, args, tokenizer, embedding_model)
            else:
                feats = None

            # Store domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
            slots[f"{domain}-{slot}"] = (slot_emb, feats, is_requestable)

    # Dump tensors and ontology for use in training and evaluation
    if save_to_file:
        writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type)
        torch.save(slots, writer)

        writer = open(os.path.join(args.output_dir, 'database', '%s.json' % set_type), 'w')
        json.dump(ontology, writer, indent=2)
        writer.close()

    return slots
from convlab.dst.setsumbt.datasets.unified_format import get_dataloader, change_batch_size, dataloader_sample_dialogues
from convlab.dst.setsumbt.datasets.metrics import (JointGoalAccuracy, BeliefStateUncertainty,
ActPredictionAccuracy, Metrics)
from convlab.dst.setsumbt.datasets.distillation import get_dataloader as get_distillation_dataloader
# -*- coding: utf-8 -*-
# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Get ensemble predictions and build distillation dataloaders"""
import os
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from convlab.dst.setsumbt.datasets.unified_format import UnifiedFormatDataset
from convlab.dst.setsumbt.datasets.utils import IdTensor
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_dataloader(ensemble_path: str, set_type: str = 'train', batch_size: int = 3,
                   reduction: str = 'none') -> DataLoader:
    """
    Get dataloader for distillation of ensemble.

    Args:
        ensemble_path: Path to ensemble model and predictive distributions.
        set_type: Dataset split to load.
        batch_size: Batch size.
        reduction: Reduction to apply to ensemble predictive distributions.

    Returns:
        loader: Dataloader for distillation.
    """
    # Recover the dataset underlying the ensemble's saved dataloader.
    loader_path = os.path.join(ensemble_path, 'dataloaders', f"{set_type}.dataloader")
    dataset = torch.load(loader_path).dataset

    # Load the ensemble's predictive distributions for this split.
    predictions_path = os.path.join(ensemble_path, 'predictions', f"{set_type}.data")
    predictions = torch.load(predictions_path)
    dialogue_ids = predictions.pop('dialogue_ids')

    # Preprocess: reduce ensemble distributions, flatten nested features,
    # then mark padded turns with the -1 ignore label.
    predictions = reduce_data(predictions, reduction=reduction)
    predictions = do_label_padding(flatten_data(predictions))

    # Build the distillation dataset, reusing the original ontology/embeddings.
    split_name = 'validation' if set_type == 'dev' else set_type
    data = UnifiedFormatDataset.from_datadict(set_type=split_name,
                                              data=predictions,
                                              ontology=dataset.ontology,
                                              ontology_embeddings=dataset.ontology_embeddings)
    data.features['dialogue_ids'] = IdTensor(dialogue_ids)

    # Shuffle only the training split.
    sampler = RandomSampler(data) if set_type == 'train' else SequentialSampler(data)
    return DataLoader(data, sampler=sampler, batch_size=batch_size)
def reduce_data(data: dict, reduction: str = 'none') -> dict:
    """
    Reduce ensemble predictive distributions.

    Args:
        data: Dictionary of ensemble predictive distributions.
        reduction: Reduction to apply to ensemble predictive distributions.

    Returns:
        data: Reduced ensemble predictive distributions.
    """
    # Any reduction other than 'mean' leaves the distributions untouched.
    if reduction != 'mean':
        return data

    # Average belief states over the ensemble member dimension.
    data['belief_state'] = {name: dist.mean(-2) for name, dist in data['belief_state'].items()}
    if 'request_probabilities' in data:
        # Action-prediction outputs are only present together.
        data['request_probabilities'] = {name: dist.mean(-1)
                                         for name, dist in data['request_probabilities'].items()}
        data['active_domain_probabilities'] = {name: dist.mean(-1)
                                               for name, dist in data['active_domain_probabilities'].items()}
        data['general_act_probabilities'] = data['general_act_probabilities'].mean(-2)
    return data
def do_label_padding(data: dict) -> dict:
    """
    Add padding to the ensemble predictions (used as labels in distillation)

    Args:
        data: Dictionary of ensemble predictions

    Returns:
        data: Padded ensemble predictions
    """
    # A padded turn is one whose encoder input is entirely zero; prefer the
    # attention mask when present, otherwise fall back to the input ids.
    mask_source = data['attention_mask'] if 'attention_mask' in data else data['input_ids']
    dialogs, turns = torch.where(mask_source.sum(-1) == 0.0)

    # Mark every label feature at the padded positions with the ignore index -1.
    encoder_keys = ('input_ids', 'attention_mask', 'token_type_ids')
    for key in data:
        if key not in encoder_keys:
            data[key][dialogs, turns] = -1

    return data
def flatten_data(data: dict) -> dict:
    """
    Map data to flattened feature format used in training

    Args:
        data: Ensemble prediction data

    Returns:
        data: Flattened ensemble prediction data
    """
    flattened = dict()
    for name, value in data.items():
        # Nested dictionaries become "outer-inner" keys; everything else passes through.
        if type(value) == dict:
            flattened.update({f"{name}-{inner}": inner_value
                              for inner, inner_value in value.items()})
        else:
            flattened[name] = value
    return flattened
This diff is collapsed.
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf # Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de) # Authors: Carel van Niekerk (niekerk@hhu.de)
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -14,260 +14,81 @@ ...@@ -14,260 +14,81 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Convlab3 Unified Format Dialogue Datasets""" """Convlab3 Unified Format Dialogue Datasets"""
import pdb
from copy import deepcopy
import torch import torch
import transformers import transformers
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
from tqdm import tqdm
from convlab.util import load_dataset from convlab.util import load_dataset
from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values, from convlab.dst.setsumbt.datasets.utils import (get_ontology_slots, ontology_add_values,
get_values_from_data, ontology_add_requestable_slots, get_values_from_data, ontology_add_requestable_slots,
get_requestable_slots, load_dst_data, extract_dialogues, get_requestable_slots, load_dst_data, extract_dialogues,
combine_value_sets, IdTensor) combine_value_sets)
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
def convert_examples_to_features(data: list,
ontology: dict,
tokenizer: PreTrainedTokenizer,
max_turns: int = 12,
max_seq_len: int = 64) -> dict:
"""
Convert dialogue examples to model input features and labels
Args:
data (list): List of all extracted dialogues
ontology (dict): Ontology dictionary containing slots, slot descriptions and
possible value sets including requests
tokenizer (PreTrainedTokenizer): Tokenizer for the encoder model used
max_turns (int): Maximum numbers of turns in a dialogue
max_seq_len (int): Maximum number of tokens in a dialogue turn
Returns:
features (dict): All inputs and labels required to train the model
"""
features = dict()
ontology = deepcopy(ontology)
# Get encoder input for system, user utterance pairs
input_feats = []
for dial in tqdm(data):
dial_feats = []
for turn in dial:
if len(turn['system_utterance']) == 0:
usr = turn['user_utterance']
dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens=True,
max_length=max_seq_len, padding='max_length',
truncation='longest_first'))
else:
usr = turn['user_utterance']
sys = turn['system_utterance']
dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens=True,
max_length=max_seq_len, padding='max_length',
truncation='longest_first'))
# Truncate
if len(dial_feats) >= max_turns:
break
input_feats.append(dial_feats)
del dial_feats
# Perform turn level padding
dial_ids = list()
for dial in data:
_ids = [turn['dialogue_id'] for turn in dial][:max_turns]
_ids += [''] * (max_turns - len(_ids))
dial_ids.append(_ids)
input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
for dial in input_feats]
if 'token_type_ids' in input_feats[0][0]:
token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
for dial in input_feats]
else:
token_type_ids = None
if 'attention_mask' in input_feats[0][0]:
attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
for dial in input_feats]
else:
attention_mask = None
del input_feats
# Create torch data tensors
features['dialogue_ids'] = IdTensor(dial_ids)
features['input_ids'] = torch.tensor(input_ids)
features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
del input_ids, token_type_ids, attention_mask
# Extract all informable and requestable slots from the ontology
informable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain]
if ontology[domain][slot]['possible_values']
and ontology[domain][slot]['possible_values'] != ['?']]
requestable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain]
if '?' in ontology[domain][slot]['possible_values']]
for slot in requestable_slots:
domain, slot = slot.split('-', 1)
ontology[domain][slot]['possible_values'].remove('?')
# Extract a list of domains from the ontology slots
domains = list(set(informable_slots + requestable_slots))
domains = list(set([slot.split('-', 1)[0] for slot in domains]))
# Create slot labels
for domslot in tqdm(informable_slots):
labels = []
for dial in data:
labs = []
for turn in dial:
value = [v for d, substate in turn['state'].items() for s, v in substate.items()
if f'{d}-{s}' == domslot]
domain, slot = domslot.split('-', 1)
if turn['dataset_name'] in ontology[domain][slot]['dataset_names']:
value = value[0] if value else 'none'
else:
value = -1
if value in ontology[domain][slot]['possible_values'] and value != -1:
value = ontology[domain][slot]['possible_values'].index(value)
else:
value = -1 # If value is not in ontology then we do not penalise the model
labs.append(value)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['state_labels-' + domslot] = labels
# Create requestable slot labels
for domslot in tqdm(requestable_slots):
labels = []
for dial in data:
labs = []
for turn in dial:
domain, slot = domslot.split('-', 1)
if turn['dataset_name'] in ontology[domain][slot]['dataset_names']:
acts = [act['intent'] for act in turn['dialogue_acts']
if act['domain'] == domain and act['slot'] == slot]
if acts:
act_ = acts[0]
if act_ == 'request':
labs.append(1)
else:
labs.append(0)
else:
labs.append(0)
else:
labs.append(-1)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['request_labels-' + domslot] = labels
# General act labels (1-goodbye, 2-thank you)
labels = []
for dial in tqdm(data):
labs = []
for turn in dial:
acts = [act['intent'] for act in turn['dialogue_acts'] if act['intent'] in ['bye', 'thank']]
if acts:
if 'bye' in acts:
labs.append(1)
else:
labs.append(2)
else:
labs.append(0)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['general_act_labels'] = labels
# Create active domain labels
for domain in tqdm(domains):
labels = []
for dial in data:
labs = []
for turn in dial:
possible_domains = list()
for dom in ontology:
for slt in ontology[dom]:
if turn['dataset_name'] in ontology[dom][slt]['dataset_names']:
possible_domains.append(dom)
if domain in turn['active_domains']:
labs.append(1)
elif domain in possible_domains:
labs.append(0)
else:
labs.append(-1)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['active_domain_labels-' + domain] = labels
del labels
return features
class UnifiedFormatDataset(Dataset): class UnifiedFormatDataset(Dataset):
""" """
Class for preprocessing, and storing data easily from the Convlab3 unified format. Class for preprocessing, and storing data easily from the Convlab3 unified format.
Attributes: Attributes:
dataset_dict (dict): Dictionary containing all the data in dataset set_type (str): Subset of the dataset to load (train, validation or test)
dataset_dicts (dict): Dictionary containing all the data in dataset
ontology (dict): Set of all domain-slot-value triplets in the ontology of the model ontology (dict): Set of all domain-slot-value triplets in the ontology of the model
ontology_embeddings (dict): Set of all domain-slot-value triplets in the ontology of the model
features (dict): Set of numeric features containing all inputs and labels formatted for the SetSUMBT model features (dict): Set of numeric features containing all inputs and labels formatted for the SetSUMBT model
""" """
def __init__(self, def __init__(self,
dataset_name: str, dataset_name: str,
set_type: str, set_type: str,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
ontology_encoder,
max_turns: int = 12, max_turns: int = 12,
max_seq_len: int = 64, max_seq_len: int = 64,
train_ratio: float = 1.0, train_ratio: float = 1.0,
seed: int = 0, seed: int = 0,
data: dict = None, data: dict = None,
ontology: dict = None): ontology: dict = None,
ontology_embeddings: dict = None):
""" """
Args: Args:
dataset_name (str): Name of the dataset/s to load (multiple to be seperated by +) dataset_name (str): Name of the dataset/s to load (multiple to be seperated by +)
set_type (str): Subset of the dataset to load (train, validation or test) set_type (str): Subset of the dataset to load (train, validation or test)
tokenizer (transformers tokenizer): Tokenizer for the encoder model used tokenizer (transformers tokenizer): Tokenizer for the encoder model used
ontology_encoder (transformers model): Ontology encoder model
max_turns (int): Maximum numbers of turns in a dialogue max_turns (int): Maximum numbers of turns in a dialogue
max_seq_len (int): Maximum number of tokens in a dialogue turn max_seq_len (int): Maximum number of tokens in a dialogue turn
train_ratio (float): Fraction of training data to use during training train_ratio (float): Fraction of training data to use during training
seed (int): Seed governing random order of ids for subsampling seed (int): Seed governing random order of ids for subsampling
data (dict): Dataset features for loading from dict data (dict): Dataset features for loading from dict
ontology (dict): Ontology dict for loading from dict ontology (dict): Ontology dict for loading from dict
ontology_embeddings (dict): Ontology embeddings for loading from dict
""" """
# Load data from dict if provided
if data is not None: if data is not None:
self.set_type = set_type
self.ontology = ontology self.ontology = ontology
self.ontology_embeddings = ontology_embeddings
self.features = data self.features = data
# Load data from dataset if data is not provided
else: else:
if '+' in dataset_name: if '+' in dataset_name:
dataset_args = [{"dataset_name": name} for name in dataset_name.split('+')] dataset_args = [{"dataset_name": name} for name in dataset_name.split('+')]
else: else:
dataset_args = [{"dataset_name": dataset_name}] dataset_args = [{"dataset_name": dataset_name}]
self.dataset_dicts = [load_dataset(**dataset_args_) for dataset_args_ in dataset_args] self.dataset_dicts = [load_dataset(**dataset_args_) for dataset_args_ in dataset_args]
self.set_type = set_type
self.ontology = get_ontology_slots(dataset_name) self.ontology = get_ontology_slots(dataset_name)
values = [get_values_from_data(dataset, set_type) for dataset in self.dataset_dicts] values = [get_values_from_data(dataset, set_type) for dataset in self.dataset_dicts]
self.ontology = ontology_add_values(self.ontology, combine_value_sets(values), set_type) self.ontology = ontology_add_values(self.ontology, combine_value_sets(values), set_type)
self.ontology = ontology_add_requestable_slots(self.ontology, get_requestable_slots(self.dataset_dicts)) self.ontology = ontology_add_requestable_slots(self.ontology, get_requestable_slots(self.dataset_dicts))
tokenizer.set_setsumbt_ontology(self.ontology)
self.ontology_embeddings = ontology_encoder.get_slot_candidate_embeddings()
if train_ratio != 1.0: if train_ratio != 1.0:
for dataset_args_ in dataset_args: for dataset_args_ in dataset_args:
dataset_args_['dial_ids_order'] = seed dataset_args_['dial_ids_order'] = seed
...@@ -282,7 +103,7 @@ class UnifiedFormatDataset(Dataset): ...@@ -282,7 +103,7 @@ class UnifiedFormatDataset(Dataset):
data = [] data = []
for idx, data_ in enumerate(data_list): for idx, data_ in enumerate(data_list):
data += extract_dialogues(data_, dataset_args[idx]["dataset_name"]) data += extract_dialogues(data_, dataset_args[idx]["dataset_name"])
self.features = convert_examples_to_features(data, self.ontology, tokenizer, max_turns, max_seq_len) self.features = tokenizer.encode(data, max_turns, max_seq_len)
def __getitem__(self, index: int) -> dict: def __getitem__(self, index: int) -> dict:
""" """
...@@ -350,14 +171,15 @@ class UnifiedFormatDataset(Dataset): ...@@ -350,14 +171,15 @@ class UnifiedFormatDataset(Dataset):
if self.features[label] is not None} if self.features[label] is not None}
@classmethod @classmethod
def from_datadict(cls, data: dict, ontology: dict): def from_datadict(cls, set_type: str, data: dict, ontology: dict, ontology_embeddings: dict):
return cls(None, None, None, data=data, ontology=ontology) return cls(None, set_type, None, None, data=data, ontology=ontology, ontology_embeddings=ontology_embeddings)
def get_dataloader(dataset_name: str, def get_dataloader(dataset_name: str,
set_type: str, set_type: str,
batch_size: int, batch_size: int,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
ontology_encoder,
max_turns: int = 12, max_turns: int = 12,
max_seq_len: int = 64, max_seq_len: int = 64,
device='cpu', device='cpu',
...@@ -372,6 +194,7 @@ def get_dataloader(dataset_name: str, ...@@ -372,6 +194,7 @@ def get_dataloader(dataset_name: str,
set_type (str): Subset of the dataset to load (train, validation or test) set_type (str): Subset of the dataset to load (train, validation or test)
batch_size (int): Batch size for the dataloader batch_size (int): Batch size for the dataloader
tokenizer (transformers tokenizer): Tokenizer for the encoder model used tokenizer (transformers tokenizer): Tokenizer for the encoder model used
ontology_encoder (OntologyEncoder): Ontology encoder object
max_turns (int): Maximum numbers of turns in a dialogue max_turns (int): Maximum numbers of turns in a dialogue
max_seq_len (int): Maximum number of tokens in a dialogue turn max_seq_len (int): Maximum number of tokens in a dialogue turn
device (torch device): Device to map data to device (torch device): Device to map data to
...@@ -382,8 +205,8 @@ def get_dataloader(dataset_name: str, ...@@ -382,8 +205,8 @@ def get_dataloader(dataset_name: str,
Returns: Returns:
loader (torch dataloader): Dataloader to train and evaluate the setsumbt model loader (torch dataloader): Dataloader to train and evaluate the setsumbt model
''' '''
data = UnifiedFormatDataset(dataset_name, set_type, tokenizer, max_turns, max_seq_len, train_ratio=train_ratio, data = UnifiedFormatDataset(dataset_name, set_type, tokenizer, ontology_encoder, max_turns, max_seq_len,
seed=seed) train_ratio=train_ratio, seed=seed)
data.to(device) data.to(device)
if resampled_size: if resampled_size:
...@@ -418,6 +241,7 @@ def change_batch_size(loader: DataLoader, batch_size: int) -> DataLoader: ...@@ -418,6 +241,7 @@ def change_batch_size(loader: DataLoader, batch_size: int) -> DataLoader:
return loader return loader
def dataloader_sample_dialogues(loader: DataLoader, sample_size: int) -> DataLoader: def dataloader_sample_dialogues(loader: DataLoader, sample_size: int) -> DataLoader:
""" """
Sample a subset of the dialogues in a dataloader Sample a subset of the dialogues in a dataloader
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf # Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de) # Authors: Carel van Niekerk (niekerk@hhu.de)
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -16,10 +16,9 @@ ...@@ -16,10 +16,9 @@
"""Convlab3 Unified dataset data processing utilities""" """Convlab3 Unified dataset data processing utilities"""
import numpy import numpy
import pdb
from convlab.util import load_ontology, load_dst_data, load_nlu_data from convlab.util import load_ontology, load_dst_data, load_nlu_data
from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME from convlab.dst.setsumbt.datasets.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
def get_ontology_slots(dataset_name: str) -> dict: def get_ontology_slots(dataset_name: str) -> dict:
...@@ -424,6 +423,7 @@ class IdTensor: ...@@ -424,6 +423,7 @@ class IdTensor:
def extract_dialogues(data: list, dataset_name: str) -> list: def extract_dialogues(data: list, dataset_name: str) -> list:
""" """
Extract all dialogues from dataset Extract all dialogues from dataset
Args: Args:
data (list): List of all dialogues in a subset of the data data (list): List of all dialogues in a subset of the data
dataset_name (str): Name of the dataset to which the dialogues belongs dataset_name (str): Name of the dataset to which the dialogues belongs
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf # Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de) # Authors: Carel van Niekerk (niekerk@hhu.de)
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Get ensemble predictions and build distillation dataloaders"""
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import os
import json
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from convlab.dst.setsumbt.dataset.unified_format import UnifiedFormatDataset, change_batch_size
from convlab.dst.setsumbt.modeling import EnsembleSetSUMBT
from convlab.dst.setsumbt.modeling import training
# Run on GPU when available, otherwise fall back to CPU
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size: int = 3) -> DataLoader:
    """
    Build a dataloader from ensemble prediction data.

    Args:
        data: Dictionary of ensemble predictions
        ontology: Data ontology
        set_type: Data subset (train/validation/test)
        batch_size: Number of dialogues per batch

    Returns:
        DataLoader over the flattened, padded ensemble predictions
    """
    # Flatten nested feature dicts and mark padded turns before wrapping in a dataset
    features = do_label_padding(flatten_data(data))
    dataset = UnifiedFormatDataset.from_datadict(features, ontology)

    # Shuffle only the training subset; keep evaluation order deterministic
    sampler = RandomSampler(dataset) if set_type == 'train' else SequentialSampler(dataset)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size)
def do_label_padding(data: dict) -> dict:
    """
    Add padding to the ensemble predictions (used as labels in distillation).

    A turn is considered padding when its attention mask (or, failing that, its
    input ids) sums to zero; all label-style features for such turns are set to -1.

    Args:
        data: Dictionary of ensemble predictions

    Returns:
        data: Padded ensemble predictions (modified in place and returned)
    """
    # Prefer the attention mask for locating empty (padded) turns
    mask_key = 'attention_mask' if 'attention_mask' in data else 'input_ids'
    dialogs, turns = torch.where(data[mask_key].sum(-1) == 0.0)

    # -1 marks padded positions for every non-input feature
    for name, feats in data.items():
        if name not in ('input_ids', 'attention_mask', 'token_type_ids'):
            feats[dialogs, turns] = -1

    return data
def flatten_data(data: dict) -> dict:
    """
    Map data to flattened feature format used in training.

    Nested dictionaries are flattened one level deep, joining the outer and
    inner keys with a ``-`` separator (e.g. ``state_labels`` + ``hotel-area``
    becomes ``state_labels-hotel-area``).

    Args:
        data: Ensemble prediction data

    Returns:
        data: Flattened ensemble prediction data
    """
    data_new = dict()
    for label, feats in data.items():
        # isinstance (rather than `type(feats) == dict`) also accepts dict
        # subclasses such as OrderedDict
        if isinstance(feats, dict):
            for label_, feats_ in feats.items():
                data_new[label + '-' + label_] = feats_
        else:
            data_new[label] = feats

    return data_new
def get_ensemble_distributions(args):
    """
    Load data and get ensemble predictions.

    Runs the ensemble model over the stored dataloader for ``args.set_type``,
    collecting the input features, the gold labels and the ensemble output
    distributions, and saves everything to
    ``<model_path>/dataloaders/<set_type>.data``.

    Args:
        args: Runtime arguments (model_path, set_type, batch_size, reduction)
    """
    device = DEVICE

    model = EnsembleSetSUMBT.from_pretrained(args.model_path)
    model = model.to(device)
    print('Model Loaded!')

    # Dataloader and ontology embedding database are cached inside the model directory
    dataloader = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
    database = os.path.join(args.model_path, 'database', f'{args.set_type}.db')

    dataloader = torch.load(dataloader)
    database = torch.load(database)

    if dataloader.batch_size != args.batch_size:
        dataloader = change_batch_size(dataloader, args.batch_size)

    training.set_ontology_embeddings(model, database)
    print('Environment set up.')

    # Accumulators for input features and gold labels, batch by batch
    input_ids = []
    token_type_ids = []
    attention_mask = []
    state_labels = {slot: [] for slot in model.informable_slot_ids}
    request_labels = {slot: [] for slot in model.requestable_slot_ids}
    active_domain_labels = {domain: [] for domain in model.domain_ids}
    general_act_labels = []

    # 'is_noisy' flags are only present for some datasets
    is_noisy = [] if 'is_noisy' in dataloader.dataset.features else None

    # Accumulators for the ensemble output distributions
    belief_state = {slot: [] for slot in model.informable_slot_ids}
    request_probs = {slot: [] for slot in model.requestable_slot_ids}
    active_domain_probs = {domain: [] for domain in model.domain_ids}
    general_act_probs = []
    model.eval()
    for batch in tqdm(dataloader, desc='Batch:'):
        ids = batch['input_ids']
        tt_ids = batch['token_type_ids'] if 'token_type_ids' in batch else None
        mask = batch['attention_mask'] if 'attention_mask' in batch else None

        if 'is_noisy' in batch:
            is_noisy.append(batch['is_noisy'])

        # Keep CPU copies of the raw inputs before moving them to the device
        input_ids.append(ids)
        token_type_ids.append(tt_ids)
        attention_mask.append(mask)

        ids = ids.to(device)
        tt_ids = tt_ids.to(device) if tt_ids is not None else None
        mask = mask.to(device) if mask is not None else None

        for slot in state_labels:
            state_labels[slot].append(batch['state_labels-' + slot])

        if model.config.predict_actions:
            for slot in request_labels:
                request_labels[slot].append(batch['request_labels-' + slot])
            for domain in active_domain_labels:
                active_domain_labels[domain].append(batch['active_domain_labels-' + domain])
            general_act_labels.append(batch['general_act_labels'])

        with torch.no_grad():
            p, p_req, p_dom, p_gen, _ = model(ids, mask, tt_ids, reduction=args.reduction)

        # Move predictions to CPU immediately to limit device memory usage
        for slot in belief_state:
            belief_state[slot].append(p[slot].cpu())
        if model.config.predict_actions:
            for slot in request_probs:
                request_probs[slot].append(p_req[slot].cpu())
            for domain in active_domain_probs:
                active_domain_probs[domain].append(p_dom[domain].cpu())
            general_act_probs.append(p_gen.cpu())

    # Concatenate per-batch accumulators along the dialogue dimension
    input_ids = torch.cat(input_ids, 0) if input_ids[0] is not None else None
    token_type_ids = torch.cat(token_type_ids, 0) if token_type_ids[0] is not None else None
    attention_mask = torch.cat(attention_mask, 0) if attention_mask[0] is not None else None
    is_noisy = torch.cat(is_noisy, 0) if is_noisy is not None else None

    state_labels = {slot: torch.cat(l, 0) for slot, l in state_labels.items()}
    if model.config.predict_actions:
        request_labels = {slot: torch.cat(l, 0) for slot, l in request_labels.items()}
        active_domain_labels = {domain: torch.cat(l, 0) for domain, l in active_domain_labels.items()}
        general_act_labels = torch.cat(general_act_labels, 0)

    belief_state = {slot: torch.cat(p, 0) for slot, p in belief_state.items()}
    if model.config.predict_actions:
        request_probs = {slot: torch.cat(p, 0) for slot, p in request_probs.items()}
        active_domain_probs = {domain: torch.cat(p, 0) for domain, p in active_domain_probs.items()}
        general_act_probs = torch.cat(general_act_probs, 0)

    # Assemble a single dictionary with inputs, labels and distributions
    data = {'input_ids': input_ids}
    if token_type_ids is not None:
        data['token_type_ids'] = token_type_ids
    if attention_mask is not None:
        data['attention_mask'] = attention_mask
    if is_noisy is not None:
        data['is_noisy'] = is_noisy
    data['state_labels'] = state_labels
    data['belief_state'] = belief_state
    if model.config.predict_actions:
        data['request_labels'] = request_labels
        data['active_domain_labels'] = active_domain_labels
        data['general_act_labels'] = general_act_labels
        data['request_probs'] = request_probs
        data['active_domain_probs'] = active_domain_probs
        data['general_act_probs'] = general_act_probs

    file = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
    torch.save(data, file)
def ensemble_distribution_data_to_predictions_format(model_path: str, set_type: str):
    """
    Convert ensemble predictions to predictions file format.

    Loads ``<model_path>/dataloaders/<set_type>.data``, marginalises the
    per-member ensemble distributions into single distributions and writes the
    result to ``<model_path>/predictions/<set_type>.predictions``.

    Args:
        model_path: Path to ensemble location.
        set_type: Evaluation dataset (train/dev/test).
    """
    data = torch.load(os.path.join(model_path, 'dataloaders', f"{set_type}.data"))

    # Carry the oracle labels over unchanged; action labels only exist when
    # the ensemble was trained to predict actions
    predictions = {'state_labels': data['state_labels']}
    if 'request_probs' in data:
        predictions['request_labels'] = data['request_labels']
        predictions['active_domain_labels'] = data['active_domain_labels']
        predictions['general_act_labels'] = data['general_act_labels']

    # Marginalise across the ensemble members by averaging their distributions
    predictions['belief_states'] = {slot: dist.mean(-2) for slot, dist in data['belief_state'].items()}
    if 'request_probs' in data:
        predictions['request_probs'] = {slot: dist.mean(-1)
                                        for slot, dist in data['request_probs'].items()}
        predictions['active_domain_probs'] = {dom: dist.mean(-1)
                                              for dom, dist in data['active_domain_probs'].items()}
        predictions['general_act_probs'] = data['general_act_probs'].mean(-2)

    # Save predictions file
    predictions_dir = os.path.join(model_path, 'predictions')
    if not os.path.exists(predictions_dir):
        os.mkdir(predictions_dir)
    torch.save(predictions, os.path.join(predictions_dir, f"{set_type}.predictions"))
if __name__ == "__main__":
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--set_type', type=str)
    parser.add_argument('--batch_size', type=int, default=3)
    parser.add_argument('--reduction', type=str, default='none')
    parser.add_argument('--get_ensemble_distributions', action='store_true')
    parser.add_argument('--convert_distributions_to_predictions', action='store_true')
    parser.add_argument('--build_dataloaders', action='store_true')
    args = parser.parse_args()

    # Run the ensemble over the requested subset and cache its distributions
    if args.get_ensemble_distributions:
        get_ensemble_distributions(args)
    # Marginalise the cached distributions into a predictions file
    if args.convert_distributions_to_predictions:
        ensemble_distribution_data_to_predictions_format(args.model_path, args.set_type)
    # Build distillation dataloaders from the cached distributions
    if args.build_dataloaders:
        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
        data = torch.load(path)

        # Use a context manager so the ontology file is always closed
        with open(os.path.join(args.model_path, 'database', f'{args.set_type}.json'), 'r') as reader:
            ontology = json.load(reader)

        loader = get_loader(data, ontology, args.set_type, args.batch_size)

        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
        torch.save(loader, path)
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run SetSUMBT Calibration"""
import logging
import os
import torch
from transformers import (BertModel, BertConfig, BertTokenizer,
RobertaModel, RobertaConfig, RobertaTokenizer)
from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
from convlab.dst.setsumbt.dataset import unified_format
from convlab.dst.setsumbt.dataset import ontology as embeddings
from convlab.dst.setsumbt.utils import get_args, update_args
from convlab.dst.setsumbt.modeling import evaluation_utils
from convlab.dst.setsumbt.loss.uncertainty_measures import ece, jg_ece, l2_acc
from convlab.dst.setsumbt.modeling import training
# Available models: maps a model-type name to its
# (SetSUMBT model, candidate encoder model, config class, tokenizer class)
MODELS = {
    'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
    'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
}
def main(args=None, config=None):
    """Evaluate a trained SetSUMBT model and log calibration/accuracy metrics.

    Args:
        args: Runtime arguments; parsed from the command line when None.
        config: Model configuration; loaded from the trained checkpoint when None.
    """
    # Get arguments
    if args is None:
        args, config = get_args(MODELS)

    if args.model_type in MODELS:
        SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
    else:
        raise NameError('NotImplemented')

    # Set up output directory
    OUTPUT_DIR = args.output_dir
    args.output_dir = OUTPUT_DIR
    if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
        os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))

    # Set pretrained model path to the trained checkpoint
    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
    if 'pytorch_model.bin' in paths and 'config.json' in paths:
        args.model_name_or_path = args.output_dir
        config = ConfigClass.from_pretrained(args.model_name_or_path)
    else:
        # Fall back to the first 'checkpoint-*' directory found in the output dir
        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
        if paths:
            paths = paths[0]
            args.model_name_or_path = paths
            config = ConfigClass.from_pretrained(args.model_name_or_path)

    args = update_args(args, config)

    # Create logger
    global logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')

    fh = logging.FileHandler(args.logging_path)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # Get device
    if torch.cuda.is_available() and args.n_gpu > 0:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
        args.n_gpu = 0

    if args.n_gpu == 0:
        # Mixed precision requires CUDA
        args.fp16 = False

    # Set up model training/evaluation
    evaluation_utils.set_seed(args)

    # Perform tasks
    # Reuse cached test-set predictions when available, otherwise run the model
    if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')):
        pred = torch.load(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
        state_labels = pred['state_labels']
        belief_states = pred['belief_states']
        if 'request_labels' in pred:
            request_labels = pred['request_labels']
            request_probs = pred['request_probs']
            active_domain_labels = pred['active_domain_labels']
            active_domain_probs = pred['active_domain_probs']
            general_act_labels = pred['general_act_labels']
            general_act_probs = pred['general_act_probs']
        else:
            request_probs = None
        del pred
    else:
        # Get training batch loaders and ontology embeddings
        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
            if test_dataloader.batch_size != args.test_batch_size:
                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
        else:
            tokenizer = Tokenizer(config.candidate_embedding_model_name)
            test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
                                                            args.test_batch_size, tokenizer, args.max_dialogue_len,
                                                            config.max_turn_len)
            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))

        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
        else:
            # NOTE(review): `tokenizer` is only defined in the dataloader-rebuild branch
            # above; if the dataloader was cached but the database was not, this call
            # would hit an unbound name — confirm.
            encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
                                                                  'test', args, tokenizer, encoder)

        # Initialise Model
        model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
        model = model.to(device)

        training.set_ontology_embeddings(model, test_slots)

        # get_predictions returns a tuple; unpack it into named variables
        belief_states = evaluation_utils.get_predictions(args, model, device, test_dataloader)
        state_labels = belief_states[1]
        request_probs = belief_states[2]
        request_labels = belief_states[3]
        active_domain_probs = belief_states[4]
        active_domain_labels = belief_states[5]
        general_act_probs = belief_states[6]
        general_act_labels = belief_states[7]
        belief_states = belief_states[0]

        # Cache the predictions so later runs skip inference
        out = {'belief_states': belief_states, 'state_labels': state_labels, 'request_probs': request_probs,
               'request_labels': request_labels, 'active_domain_probs': active_domain_probs,
               'active_domain_labels': active_domain_labels, 'general_act_probs': general_act_probs,
               'general_act_labels': general_act_labels}
        torch.save(out, os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))

    # Calculate calibration metrics
    jg = jg_ece(belief_states, state_labels, 10)
    logger.info('Joint Goal ECE: %f' % jg)

    # Joint goal accuracy: a turn counts as correct only when every slot is correct.
    # Padded turns are detected by all slot labels being -1 (sum * -1 == number of slots).
    jg_acc = 0.0
    padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0
    padding = (padding == len(state_labels))
    padding = padding.reshape(-1)
    for slot in belief_states:
        p_ = belief_states[slot]
        gold = state_labels[slot]

        pred = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1)
        acc = [lab in s for lab, s, pad in zip(gold.reshape(-1), pred, padding) if not pad]
        acc = torch.tensor(acc).float()

        jg_acc += acc

    n_turns = jg_acc.size(0)
    # A turn scores 1 only when all len(belief_states) slots matched
    jg_acc = sum((jg_acc / len(belief_states)).int()).float()

    jg_acc /= n_turns

    logger.info(f'Joint Goal Accuracy: {jg_acc}')

    l2 = l2_acc(belief_states, state_labels, remove_belief=False)
    logger.info(f'Model L2 Norm Goal Accuracy: {l2}')
    l2 = l2_acc(belief_states, state_labels, remove_belief=True)
    logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}')

    # Slot-level accuracy/precision/recall/F1; value index 0 is treated as "slot empty"
    padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0
    padding = (padding == len(state_labels))
    padding = padding.reshape(-1)

    tp, fp, fn, tn, n = 0.0, 0.0, 0.0, 0.0, 0.0
    for slot in belief_states:
        p_ = belief_states[slot]
        gold = state_labels[slot].reshape(-1)
        p_ = p_.reshape(-1, p_.size(-1))

        p_ = p_[~padding].argmax(-1)
        gold = gold[~padding]

        # False positives: predicted non-empty where gold empty, plus wrong
        # non-empty predictions (excluding predicted-empty misses, counted as fn)
        tp += (p_ == gold)[gold != 0].int().sum().item()
        fp += (p_ != 0)[gold == 0].int().sum().item()
        fp += (p_ != gold)[gold != 0].int().sum().item()
        fp -= (p_ == 0)[gold != 0].int().sum().item()
        fn += (p_ == 0)[gold != 0].int().sum().item()
        tn += (p_ == 0)[gold == 0].int().sum().item()
        n += p_.size(0)

    acc = (tp + tn) / n
    prec = tp / (tp + fp)
    rec = tp / (tp + fn)
    f1 = 2 * (prec * rec) / (prec + rec)

    logger.info(f"Slot Accuracy: {acc}, Slot F1: {f1}, Slot Precision: {prec}, Slot Recall: {rec}")

    # Action-prediction metrics (only present when the model predicts actions)
    if request_probs is not None:
        tp, fp, fn = 0.0, 0.0, 0.0
        for slot in request_probs:
            p = request_probs[slot]
            l = request_labels[slot]

            tp += (p.round().int() * (l == 1)).reshape(-1).float()
            fp += (p.round().int() * (l == 0)).reshape(-1).float()
            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
        tp /= len(request_probs)
        fp /= len(request_probs)
        fn /= len(request_probs)
        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
        logger.info('Request F1 Score: %f' % f1.item())

        # Expand binary request probabilities into two-class distributions for ECE
        for slot in request_probs:
            p = request_probs[slot]
            p = p.unsqueeze(-1)
            p = torch.cat((1 - p, p), -1)
            request_probs[slot] = p
        jg = jg_ece(request_probs, request_labels, 10)
        logger.info('Request Joint Goal ECE: %f' % jg)

        tp, fp, fn = 0.0, 0.0, 0.0
        for dom in active_domain_probs:
            p = active_domain_probs[dom]
            l = active_domain_labels[dom]

            tp += (p.round().int() * (l == 1)).reshape(-1).float()
            fp += (p.round().int() * (l == 0)).reshape(-1).float()
            fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float()
        tp /= len(active_domain_probs)
        fp /= len(active_domain_probs)
        fn /= len(active_domain_probs)
        f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum()))
        logger.info('Domain F1 Score: %f' % f1.item())

        # Expand binary domain probabilities into two-class distributions for ECE
        for dom in active_domain_probs:
            p = active_domain_probs[dom]
            p = p.unsqueeze(-1)
            p = torch.cat((1 - p, p), -1)
            active_domain_probs[dom] = p
        jg = jg_ece(active_domain_probs, active_domain_labels, 10)
        logger.info('Domain Joint Goal ECE: %f' % jg)

        tp = ((general_act_probs.argmax(-1) > 0) *
              (general_act_labels > 0)).reshape(-1).float().sum()
        fp = ((general_act_probs.argmax(-1) > 0) *
              (general_act_labels == 0)).reshape(-1).float().sum()
        fn = ((general_act_probs.argmax(-1) == 0) *
              (general_act_labels > 0)).reshape(-1).float().sum()
        f1 = tp / (tp + 0.5 * (fp + fn))
        logger.info('General Act F1 Score: %f' % f1.item())

        err = ece(general_act_probs.reshape(-1, general_act_probs.size(-1)),
                  general_act_labels.reshape(-1), 10)
        logger.info('General Act ECE: %f' % err)

        # NOTE(review): request_probs were already expanded to two-class
        # distributions above, so this second unsqueeze/cat appears to apply
        # the transformation twice — confirm intended shape for l2_acc.
        for slot in request_probs:
            p = request_probs[slot].unsqueeze(-1)
            request_probs[slot] = torch.cat((1 - p, p), -1)

        l2 = l2_acc(request_probs, request_labels, remove_belief=False)
        logger.info(f'Model L2 Norm Request Accuracy: {l2}')
        l2 = l2_acc(request_probs, request_labels, remove_belief=True)
        logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}')

        # NOTE(review): same suspected double conversion as for request_probs — confirm.
        for slot in active_domain_probs:
            p = active_domain_probs[slot].unsqueeze(-1)
            active_domain_probs[slot] = torch.cat((1 - p, p), -1)

        l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=False)
        logger.info(f'Model L2 Norm Domain Accuracy: {l2}')
        l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=True)
        logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}')

        general_act_labels = {'general': general_act_labels}
        general_act_probs = {'general': general_act_probs}

        l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False)
        logger.info(f'Model L2 Norm General Act Accuracy: {l2}')
        # NOTE(review): remove_belief=False is used again although the message says
        # "Binary"; the parallel goal/request/domain sections pass remove_belief=True
        # here — likely a copy-paste slip, confirm before fixing.
        l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False)
        logger.info(f'Binary Model L2 Norm General Act Accuracy: {l2}')
# Script entry point
if __name__ == "__main__":
    main()
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run SetSUMBT training/eval"""
import logging
import os
from shutil import copy2 as copy
import json
from copy import deepcopy
import pdb
import torch
import transformers
from transformers import (BertModel, BertConfig, BertTokenizer,
RobertaModel, RobertaConfig, RobertaTokenizer)
from tensorboardX import SummaryWriter
from tqdm import tqdm
from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
from convlab.dst.setsumbt.dataset import unified_format
from convlab.dst.setsumbt.modeling import training
from convlab.dst.setsumbt.dataset import ontology as embeddings
from convlab.dst.setsumbt.utils import get_args, update_args
from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble
from convlab.util.custom_util import model_downloader
# Available models: maps a model-type name to its
# (SetSUMBT model, candidate encoder model, config class, tokenizer class)
MODELS = {
    'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
    'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
}
def main(args=None, config=None):
    """Run SetSUMBT training and/or evaluation.

    Args:
        args: Runtime arguments; parsed from the command line when None.
        config: Model configuration; loaded from the checkpoint/pretrained model when None.
    """
    # Get arguments
    if args is None:
        args, config = get_args(MODELS)

    if args.model_type in MODELS:
        SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
    else:
        raise NameError('NotImplemented')

    # Set up output directory: either create a fresh local directory or
    # download and unpack a remote model archive
    OUTPUT_DIR = args.output_dir

    if not os.path.exists(OUTPUT_DIR):
        if "http" not in OUTPUT_DIR:
            os.makedirs(OUTPUT_DIR)
            os.mkdir(os.path.join(OUTPUT_DIR, 'database'))
            os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
        else:
            # Get path /.../convlab/dst/setsumbt/multiwoz/models
            download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            download_path = os.path.join(download_path, 'models')
            if not os.path.exists(download_path):
                os.mkdir(download_path)
            model_downloader(download_path, OUTPUT_DIR)
            # Downloadable model path format http://.../model_name.zip
            OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '')
            OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR)

            # Redirect tensorboard/log paths into the unpacked model directory
            args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1])
            args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1])
            os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
    args.output_dir = OUTPUT_DIR

    # Set pretrained model path to the trained checkpoint
    paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
    if 'pytorch_model.bin' in paths and 'config.json' in paths:
        args.model_name_or_path = args.output_dir
        config = ConfigClass.from_pretrained(args.model_name_or_path,
                                             local_files_only=args.transformers_local_files_only)
    else:
        # Fall back to the first 'checkpoint-*' directory found in the output dir
        paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
        if paths:
            paths = paths[0]
            args.model_name_or_path = paths
            config = ConfigClass.from_pretrained(args.model_name_or_path,
                                                 local_files_only=args.transformers_local_files_only)

    args = update_args(args, config)

    # Create TensorboardX writer
    tb_writer = SummaryWriter(logdir=args.tensorboard_path)

    # Create logger
    global logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')

    fh = logging.FileHandler(args.logging_path)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # Get device
    if torch.cuda.is_available() and args.n_gpu > 0:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
        args.n_gpu = 0

    if args.n_gpu == 0:
        # Mixed precision requires CUDA
        args.fp16 = False

    # Initialise Model
    transformers.utils.logging.set_verbosity_info()
    model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config,
                                          local_files_only=args.transformers_local_files_only)
    model = model.to(device)

    # Create Tokenizer and embedding model for Data Loaders and ontology
    encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name,
                                                    local_files_only=args.transformers_local_files_only)
    tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config,
                                          local_files_only=args.transformers_local_files_only)

    # Set up model training/evaluation
    training.set_logger(logger, tb_writer)
    training.set_seed(args)
    embeddings.set_seed(args)

    transformers.utils.logging.set_verbosity_error()
    # Ensemble setup: build the shared dataloaders and ontology embeddings plus one
    # resampled training dataloader per ensemble member, then exit without training
    if args.ensemble_size > 1:
        # Build all dataloaders
        train_dataloader = unified_format.get_dataloader(args.dataset,
                                                         'train',
                                                         args.train_batch_size,
                                                         tokenizer,
                                                         args.max_dialogue_len,
                                                         args.max_turn_len,
                                                         train_ratio=args.dataset_train_ratio,
                                                         seed=args.seed)
        torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
        dev_dataloader = unified_format.get_dataloader(args.dataset,
                                                       'validation',
                                                       args.dev_batch_size,
                                                       tokenizer,
                                                       args.max_dialogue_len,
                                                       args.max_turn_len,
                                                       train_ratio=args.dataset_train_ratio,
                                                       seed=args.seed)
        torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
        test_dataloader = unified_format.get_dataloader(args.dataset,
                                                        'test',
                                                        args.test_batch_size,
                                                        tokenizer,
                                                        args.max_dialogue_len,
                                                        args.max_turn_len,
                                                        train_ratio=args.dataset_train_ratio,
                                                        seed=args.seed)
        torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))

        embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology, 'train', args, tokenizer, encoder)
        embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, 'dev', args, tokenizer, encoder)
        embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, 'test', args, tokenizer, encoder)

        setup_ensemble(OUTPUT_DIR, args.ensemble_size)

        logger.info(f'Building {args.ensemble_size} resampled dataloaders each of size {args.data_sampling_size}.')
        dataloaders = [unified_format.dataloader_sample_dialogues(deepcopy(train_dataloader), args.data_sampling_size)
                       for _ in tqdm(range(args.ensemble_size))]
        logger.info('Dataloaders built.')

        # NOTE(review): only the 'ens-%i' directory is created here; the nested
        # 'dataloaders' subdirectory is presumably created by setup_ensemble — confirm.
        for i, loader in enumerate(dataloaders):
            path = os.path.join(OUTPUT_DIR, 'ens-%i' % i)
            if not os.path.exists(path):
                os.mkdir(path)
            path = os.path.join(path, 'dataloaders', 'train.dataloader')
            torch.save(loader, path)
        logger.info('Dataloaders saved.')

        # Do not perform standard training after ensemble setup is created
        return 0

    # Perform tasks
    # TRAINING
    if args.do_train:
        # Reuse the cached training dataloader when present, otherwise build it
        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')):
            train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
            if train_dataloader.batch_size != args.train_batch_size:
                train_dataloader = unified_format.change_batch_size(train_dataloader, args.train_batch_size)
        else:
            if args.data_sampling_size <= 0:
                args.data_sampling_size = None
            train_dataloader = unified_format.get_dataloader(args.dataset,
                                                             'train',
                                                             args.train_batch_size,
                                                             tokenizer,
                                                             args.max_dialogue_len,
                                                             config.max_turn_len,
                                                             resampled_size=args.data_sampling_size,
                                                             train_ratio=args.dataset_train_ratio,
                                                             seed=args.seed)
            torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))

        # Get training batch loaders and ontology embeddings
        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')):
            train_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'train.db'))
        else:
            train_slots = embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology,
                                                                   'train', args, tokenizer, encoder)

        # Get development set batch loaders and ontology embeddings
        if args.do_eval:
            if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
                dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
                if dev_dataloader.batch_size != args.dev_batch_size:
                    dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size)
            else:
                dev_dataloader = unified_format.get_dataloader(args.dataset,
                                                               'validation',
                                                               args.dev_batch_size,
                                                               tokenizer,
                                                               args.max_dialogue_len,
                                                               config.max_turn_len)
                torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
            if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
                dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
            else:
                dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
                                                                     'dev', args, tokenizer, encoder)
        else:
            dev_dataloader = None
            dev_slots = None

        # Load model ontology
        training.set_ontology_embeddings(model, train_slots)

        # TRAINING !!!!!!!!!!!!!!!!!!
        training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots)

        # Copy final best model to the output dir (highest checkpoint number wins)
        checkpoints = os.listdir(OUTPUT_DIR)
        checkpoints = [p for p in checkpoints if 'checkpoint' in p]
        checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints])
        best_checkpoint = os.path.join(OUTPUT_DIR, f'checkpoint-{checkpoints[-1]}')
        copy(os.path.join(best_checkpoint, 'pytorch_model.bin'), os.path.join(OUTPUT_DIR, 'pytorch_model.bin'))
        copy(os.path.join(best_checkpoint, 'config.json'), os.path.join(OUTPUT_DIR, 'config.json'))

        # Load best model for evaluation
        model = SetSumbtModel.from_pretrained(OUTPUT_DIR)
        model = model.to(device)

    # Evaluation on the development set
    if args.do_eval:
        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
            dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
            if dev_dataloader.batch_size != args.dev_batch_size:
                dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size)
        else:
            dev_dataloader = unified_format.get_dataloader(args.dataset,
                                                           'validation',
                                                           args.dev_batch_size,
                                                           tokenizer,
                                                           args.max_dialogue_len,
                                                           config.max_turn_len)
            torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
            dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
        else:
            dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
                                                                 'dev', args, tokenizer, encoder)

        # Load model ontology
        training.set_ontology_embeddings(model, dev_slots)

        # EVALUATION
        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss = training.evaluate(args, model, device, dev_dataloader)
        training.log_info('dev', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)

    # Evaluation on the test set
    if args.do_test:
        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
            if test_dataloader.batch_size != args.test_batch_size:
                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
        else:
            test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
                                                            args.test_batch_size, tokenizer, args.max_dialogue_len,
                                                            config.max_turn_len)
            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))

        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
        else:
            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
                                                                  'test', args, tokenizer, encoder)

        # Load model ontology
        training.set_ontology_embeddings(model, test_slots)

        # TESTING
        jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, output = training.evaluate(args, model, device, test_dataloader,
                                                                                 return_eval_output=True)

        # Persist the per-dialogue evaluation output as JSON
        if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
            os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
        writer = open(os.path.join(OUTPUT_DIR, 'predictions', 'test.json'), 'w')
        json.dump(output, writer)
        writer.close()

        training.log_info('test', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1)

    tb_writer.close()
# Script entry point
if __name__ == "__main__":
    main()
import json
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import os
from tqdm import tqdm
from convlab.util import load_dataset
from convlab.util import load_dst_data
from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
def extract_data(dataset_names: str) -> list:
    """Load the test-split dialogues of one or more unified-format datasets.

    Args:
        dataset_names: Dataset name or several names joined with '+'

    Returns:
        List of test-split dialogue dictionaries from all requested datasets
    """
    dialogues = []
    for name in dataset_names.split('+'):
        dataset_dict = load_dataset(dataset_name=name)
        subset = load_dst_data(dataset_dict, data_split='test', speaker='all',
                               dialogue_acts=True, split_to_turn=False)
        dialogues.extend(subset['test'])
    return dialogues
def clean_state(state):
    """Normalise a dialogue state dictionary into the standard value space.

    Values are mapped through VALUE_MAP, quantities are clipped to the
    standard quantity set, times are snapped to the nearest 5 minutes in
    24h HH:MM format, and boolean slots are mapped to yes/no.

    Args:
        state (dict): Dialogue state mapping domain -> slot -> value string.

    Returns:
        dict: New state dict with normalised values ('' for empty/none).
    """
    new_state = dict()
    for domain, subset in state.items():
        new_state[domain] = {}
        for slot, value in subset.items():
            # Map each pipe separated value using VALUE_MAP
            value = value.split('|')
            for old, new in VALUE_MAP.items():
                value = [val.replace(old, new) for val in value]
            value = '|'.join(value)

            # Map dontcare to "do not care" and empty to 'none'
            value = value.replace('dontcare', 'do not care')
            value = value if value else 'none'

            # Map quantity values to the integer quantity value
            if 'people' in slot or 'duration' in slot or 'stay' in slot:
                try:
                    if value not in ['do not care', 'none']:
                        value = int(value)
                        value = str(value) if value < 10 else QUANTITIES[-1]
                except Exception:
                    # Best effort: non-numeric quantities are left untouched
                    pass
            # Map time values to the most appropriate value in the standard time set
            elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
                try:
                    if value not in ['do not care', 'none']:
                        # Strip after/before from time value
                        value = value.replace('after ', '').replace('before ', '')
                        # Extract hours and minutes from different possible formats
                        if ':' not in value and len(value) == 4:
                            h, m = value[:2], value[2:]
                        elif len(value) == 1:
                            h = int(value)
                            m = 0
                        elif 'pm' in value:
                            h = int(value.replace('pm', '')) + 12
                            m = 0
                        elif 'am' in value:
                            # BUG FIX: previously stripped 'pm' here, so AM
                            # times such as '8am' always failed to parse.
                            h = int(value.replace('am', ''))
                            m = 0
                        elif ':' in value:
                            h, m = value.split(':')
                        elif ';' in value:
                            h, m = value.split(';')
                        # Map to closest 5 minutes
                        if int(m) % 5 != 0:
                            m = round(int(m) / 5) * 5
                            h = int(h)
                            if m == 60:
                                m = 0
                                h += 1
                            if h >= 24:
                                h -= 24
                        # Set in standard 24 hour format
                        h, m = int(h), int(m)
                        value = '%02i:%02i' % (h, m)
                except Exception:
                    # Unparseable times are kept as-is (possibly with
                    # 'after '/'before ' already stripped), as before.
                    pass
            # Map boolean slots to yes/no value
            elif 'parking' in slot or 'internet' in slot:
                if value not in ['do not care', 'none']:
                    if value == 'free':
                        value = 'yes'
                    elif True in [v in value.lower() for v in ['yes', 'no']]:
                        # BUG FIX: extract case-insensitively as well,
                        # otherwise values like 'Yes' raised IndexError.
                        value = [v for v in ['yes', 'no'] if v in value.lower()][0]

            value = value if value != 'none' else ''
            new_state[domain][slot] = value
    return new_state
def extract_states(data):
    """Extract and normalise the golden states for every dialogue.

    Args:
        data (list): Dialogues, each with a 'dialogue_id' and 'turns'.

    Returns:
        dict: dialogue_id -> list of cleaned states (one per state-bearing turn).
    """
    return {
        dial['dialogue_id']: [clean_state(turn['state'])
                              for turn in dial['turns'] if 'state' in turn]
        for dial in data
    }
def get_golden_state(prediction, data):
    """Attach the golden state to a prediction and align the predicted state.

    The predicted state is re-keyed to the golden state's domains/slots
    (missing slots default to ''). Mutates and returns `prediction`.

    Args:
        prediction (dict): Prediction entry with 'dial_idx', 'utt_idx' and
            'predictions'['state'].
        data (dict): dialogue_id -> list of golden states per turn.

    Returns:
        dict: The prediction with 'state' and aligned 'predictions'['state'].
    """
    golden = data[prediction['dial_idx']][prediction['utt_idx']]
    predicted = prediction['predictions']['state']
    aligned = dict()
    for domain in golden:
        # Predictions use the mapped domain name (falls back to lowercase)
        domain_preds = predicted.get(DOMAINS_MAP.get(domain, domain.lower()), dict())
        aligned[domain] = {slot: domain_preds.get(slot, '') for slot in golden[domain]}
    prediction['state'] = golden
    prediction['predictions']['state'] = aligned
    return prediction
if __name__ == "__main__":
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
    parser.add_argument('--model_path', type=str, help='Path to model dir')
    args = parser.parse_args()

    # Build golden state annotations for the requested dataset(s)
    golden_states = extract_states(extract_data(args.dataset_name))

    # Load the model's raw test predictions
    with open(os.path.join(args.model_path, "predictions", "test.json"), 'r') as reader:
        predictions = json.load(reader)

    # Attach the golden state to every prediction entry
    predictions = [get_golden_state(pred, golden_states) for pred in tqdm(predictions)]

    # Save the enriched predictions alongside the originals
    with open(os.path.join(args.model_path, "predictions", f"test_{args.dataset_name}.json"), 'w') as writer:
        json.dump(predictions, writer)
from convlab.dst.setsumbt.loss.bayesian_matching import BayesianMatchingLoss, BinaryBayesianMatchingLoss
from convlab.dst.setsumbt.loss.kl_distillation import KLDistillationLoss, BinaryKLDistillationLoss
from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss
from convlab.dst.setsumbt.loss.endd_loss import RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Uncertainty evaluation metrics for dialogue belief tracking"""
import torch
def fill_bins(n_bins: int, probs: torch.Tensor) -> list:
    """
    Function to split observations into bins based on predictive probabilities

    Args:
        n_bins (int): Number of bins
        probs (Tensor): Predictive probabilities for the observations

    Returns:
        bins (list): List of observation ids for each bin
    """
    assert probs.dim() == 2
    # Confidence of each observation is its maximum class probability
    confidence = probs.max(-1)[0]

    # Equal-width bin edges over [0, 1] (1e-10 keeps the final edge at 1.0)
    width = 1.0 / n_bins
    edges = torch.arange(0.0, 1.0 + 1e-10, width)

    bins = []
    for idx in range(n_bins):
        lo, hi = edges[idx], edges[idx + 1]
        # First bin is closed on the left so confidence == 0.0 is included;
        # all others are half-open (lo, hi]
        if idx == 0:
            member = (confidence >= lo) * (confidence <= hi)
        else:
            member = (confidence > lo) * (confidence <= hi)
        bins.append(torch.where(member)[0])
    return bins
def bin_confidence(bins: list, probs: torch.Tensor) -> torch.Tensor:
    """
    Compute the confidence score within each bin

    Args:
        bins (list): List of observation ids for each bin
        probs (Tensor): Predictive probabilities for the observations

    Returns:
        scores (Tensor): Average confidence score within each bin
    """
    # Confidence of each observation is its maximum class probability
    confidence = probs.max(-1)[0]

    # Mean confidence per bin; -1 marks bins with no id tensor at all
    scores = [confidence[ids].mean() if ids is not None else -1 for ids in bins]
    return torch.tensor(scores)
def bin_accuracy(bins: list, probs: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Compute the accuracy score for observations in each bin

    Bins with no valid (non-padding) observations are marked with -1 so the
    caller can exclude them (ece filters on acc >= 0).

    Args:
        bins (list): List of observation ids for each bin
        probs (Tensor): Predictive probabilities for the observations
        y_true (Tensor): Labels for the observations (< 0 marks padding)

    Returns:
        acc (Tensor): Accuracies for the observations in each bin
    """
    y_pred = probs.argmax(-1)

    acc = []
    for b in bins:
        # BUG FIX: guard against empty bins as well; the original tested
        # `acc_.size(0) >= 0`, which is always true, so empty bins produced
        # NaN means instead of the intended -1 sentinel.
        if b is not None and b.size(0) > 0:
            p = y_pred[b]
            acc_ = (p == y_true[b]).float()
            # Drop padded observations (label < 0)
            acc_ = acc_[y_true[b] >= 0]
            if acc_.size(0) > 0:
                acc.append(acc_.mean())
            else:
                acc.append(-1)
        else:
            acc.append(-1)
    acc = torch.tensor(acc)

    return acc
def ece(probs: torch.Tensor, y_true: torch.Tensor, n_bins: int) -> float:
    """
    Expected calibration error calculation

    Args:
        probs (Tensor): Predictive probabilities for the observations
        y_true (Tensor): Labels for the observations
        n_bins (int): Number of bins

    Returns:
        ece (float): Expected calibration error
    """
    bins = fill_bins(n_bins, probs)
    confidences = bin_confidence(bins, probs)
    accuracies = bin_accuracy(bins, probs, y_true)

    # Weight each bin's |confidence - accuracy| gap by its share of samples
    total = probs.size(0)
    counts = torch.tensor([b.size(0) for b in bins])
    gaps = torch.abs(confidences - accuracies) * counts / total

    # Exclude bins without valid observations (accuracy sentinel < 0 or NaN)
    return gaps[accuracies >= 0.0].sum().item()
def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float:
    """
    Joint goal expected calibration error calculation

    Args:
        belief_state (dict): Belief state probabilities for the dialogue turns
        y_true (dict): Labels for the state in dialogue turns (< 0 is padding)
        n_bins (int): Number of bins

    Returns:
        ece (float): Joint goal expected calibration error
    """
    # Per-slot hard predictions across all (flattened) turns
    y_pred = {slot: bs.reshape(-1, bs.size(-1)).argmax(-1) for slot, bs in belief_state.items()}
    goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred}
    # A turn is jointly correct only if every slot is predicted correctly
    goal_acc = sum([goal_acc[slot] for slot in goal_acc])
    goal_acc = (goal_acc == len(y_true)).int()

    # Confidence score is minimum across slots as a single bad predictions leads to incorrect prediction in state
    scores = [bs.reshape(-1, bs.size(-1)).max(-1)[0].unsqueeze(0) for slot, bs in belief_state.items()]
    scores = torch.cat(scores, 0).min(0)[0]

    bins = fill_bins(n_bins, scores.unsqueeze(-1))

    conf = bin_confidence(bins, scores.unsqueeze(-1))

    # Any slot's labels identify the padding turns (padding label < 0)
    slot = [s for s in y_true][0]
    acc = []
    for b in bins:
        # BUG FIX: guard empty bins; the original tested `acc_.size(0) >= 0`,
        # which is always true, so empty bins produced NaN instead of -1.
        if b is not None and b.size(0) > 0:
            acc_ = goal_acc[b]
            acc_ = acc_[y_true[slot].reshape(-1)[b] >= 0]
            if acc_.size(0) > 0:
                acc.append(acc_.float().mean())
            else:
                acc.append(-1)
        else:
            acc.append(-1)
    acc = torch.tensor(acc)

    # Weight each bin's |confidence - accuracy| gap by its share of turns,
    # excluding bins without valid observations (sentinel < 0 or NaN)
    n = belief_state[slot].reshape(-1, belief_state[slot].size(-1)).size(0)
    bk = torch.tensor([b.size(0) for b in bins])
    ece = torch.abs(conf - acc) * bk / n
    ece = ece[acc >= 0.0]
    ece = ece.sum().item()

    return ece
def l2_acc(belief_state: dict, labels: dict, remove_belief: bool = False) -> float:
    """
    Compute L2 Error of belief state prediction

    Args:
        belief_state (dict): Belief state probabilities for the dialogue turns
        labels (dict): Labels for the state in dialogue turns (-1 is padding)
        remove_belief (bool): Convert belief state to dialogue state

    Returns:
        err (Tensor): Scalar tensor with the mean L2 error per turn
    """
    # BUG FIX: the tensors were previously moved to CUDA unconditionally,
    # which crashed on CPU-only machines. Compute on the input device instead.
    device = next(iter(belief_state.values())).device

    # Get ids used for removing padding turns.
    padding = labels[list(labels.keys())[0]].reshape(-1)
    padding = torch.where(padding != -1)[0].to(device)

    state = []
    labs = []
    for slot, bs in belief_state.items():
        # Predictive Distribution
        bs = bs.reshape(-1, bs.size(-1)).to(device)

        # Replace distribution by a 1 hot prediction
        if remove_belief:
            one_hot = torch.zeros(bs.shape, dtype=torch.float, device=device)
            one_hot[range(bs.size(0)), bs.argmax(-1)] = 1.0
            bs = one_hot

        # Remove padding turns
        lab = labels[slot].reshape(-1).to(device)
        bs = bs[padding]
        lab = lab[padding]

        # Target distribution (one hot encoding of the label)
        y = torch.zeros(bs.shape, device=device)
        y[range(y.size(0)), lab] = 1.0

        state.append(bs)
        labs.append(y)

    # Concatenate all slots into a single belief state
    state = torch.cat(state, -1)
    labs = torch.cat(labs, -1)

    # Calculate L2-Error for each turn
    err = torch.sqrt(((labs - state) ** 2).sum(-1))
    return err.mean()
from transformers import BertConfig, RobertaConfig

from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT
from convlab.dst.setsumbt.modeling.setsumbt_nbt import BertSetSUMBT, RobertaSetSUMBT, EnsembleSetSUMBT
from convlab.dst.setsumbt.modeling.ontology_encoder import OntologyEncoder
from convlab.dst.setsumbt.modeling.temperature_scheduler import LinearTemperatureScheduler
from convlab.dst.setsumbt.modeling.trainer import SetSUMBTTrainer
from convlab.dst.setsumbt.modeling.tokenization import SetSUMBTTokenizer
# Concrete tokenizer classes specialised for each supported base encoder.
class BertSetSUMBTTokenizer(SetSUMBTTokenizer('bert')): pass
class RobertaSetSUMBTTokenizer(SetSUMBTTokenizer('roberta')): pass
# Registry mapping a model type name to its
# (model class, ontology encoder, config class, tokenizer class) tuple.
# NOTE(review): the 'ensemble' entry provides no encoder/config/tokenizer;
# presumably callers handle the None entries specially — confirm at call sites.
SetSUMBTModels = {
    'bert': (BertSetSUMBT, OntologyEncoder('bert'), BertConfig, BertSetSUMBTTokenizer),
    'roberta': (RobertaSetSUMBT, OntologyEncoder('roberta'), RobertaConfig, RobertaSetSUMBTTokenizer),
    'ensemble': (EnsembleSetSUMBT, None, None, None)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment