Unverified Commit 350f36c4 authored by Carel van Niekerk, committed by GitHub
Setsumbt updates

parent ffb3dc42
Showing with 2158 additions and 1513 deletions
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT SetSUMBT"""
import torch
from torch.autograd import Variable
from transformers import BertModel, BertPreTrainedModel
from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead
class BertSetSUMBT(BertPreTrainedModel):
def __init__(self, config):
super(BertSetSUMBT, self).__init__(config)
self.config = config
# Turn Encoder
self.bert = BertModel(config)
if config.freeze_encoder:
for p in self.bert.parameters():
p.requires_grad = False
self.setsumbt = SetSUMBTHead(config)
self.add_slot_candidates = self.setsumbt.add_slot_candidates
self.add_value_candidates = self.setsumbt.add_value_candidates
def forward(self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None,
hidden_state: torch.Tensor = None,
state_labels: torch.Tensor = None,
request_labels: torch.Tensor = None,
active_domain_labels: torch.Tensor = None,
general_act_labels: torch.Tensor = None,
get_turn_pooled_representation: bool = False,
calculate_state_mutual_info: bool = False):
"""
Args:
input_ids: Input token ids
attention_mask: Input padding mask
token_type_ids: Token type indicator
hidden_state: Latent internal dialogue belief state
state_labels: Dialogue state labels
request_labels: User request action labels
active_domain_labels: Current active domain labels
general_act_labels: General user action labels
get_turn_pooled_representation: Return pooled representation of the current dialogue turn
calculate_state_mutual_info: Return mutual information in the dialogue state
Returns:
out: Tuple containing loss, predictive distributions, model statistics and state mutual information
"""
# Encode Dialogues
batch_size, dialogue_size, turn_size = input_ids.size()
input_ids = input_ids.reshape(-1, turn_size)
token_type_ids = token_type_ids.reshape(-1, turn_size)
attention_mask = attention_mask.reshape(-1, turn_size)
bert_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)  # keyword arguments: transformers' BertModel.forward orders attention_mask before token_type_ids
attention_mask = attention_mask.float().unsqueeze(2)
attention_mask = attention_mask.repeat((1, 1, bert_output.last_hidden_state.size(-1)))
turn_embeddings = bert_output.last_hidden_state * attention_mask
turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
if get_turn_pooled_representation:
return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask,
batch_size, dialogue_size, hidden_state, state_labels,
request_labels, active_domain_labels, general_act_labels,
calculate_state_mutual_info) + (bert_output.pooler_output,)
return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask, batch_size,
dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
general_act_labels, calculate_state_mutual_info)
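Below is a minimal, standalone sketch (not part of the commit) of the reshaping and masking performed in BertSetSUMBT.forward: dialogues arrive as [batch_size, dialogue_size, turn_size] tensors, are flattened to turn-level sequences for the turn encoder, and the masked token embeddings are regrouped per turn. The hidden size and the random stand-in for bert_output.last_hidden_state are illustrative assumptions.
import torch

batch_size, dialogue_size, turn_size, hidden_dim = 2, 12, 64, 768
input_ids = torch.randint(0, 30000, (batch_size, dialogue_size, turn_size))
attention_mask = torch.ones(batch_size, dialogue_size, turn_size)

# Flatten dialogues to turn-level sequences, as done before calling self.bert
flat_input_ids = input_ids.reshape(-1, turn_size)  # [batch * dialogue, turn]
flat_mask = attention_mask.reshape(-1, turn_size)

# Stand-in for bert_output.last_hidden_state (random tensor, illustration only)
last_hidden_state = torch.randn(flat_input_ids.size(0), turn_size, hidden_dim)

# Mask padding tokens and regroup the turn embeddings, mirroring the code above
mask = flat_mask.float().unsqueeze(2).repeat((1, 1, hidden_dim))
turn_embeddings = (last_hidden_state * mask).reshape(batch_size * dialogue_size, turn_size, -1)
print(turn_embeddings.shape)  # torch.Size([24, 64, 768])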
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Discriminative models calibration"""
import random
import os
import torch
import numpy as np
from torch.distributions import Categorical
from torch.nn.functional import kl_div
from torch.nn import Module
from tqdm import tqdm
# Load logger and tensorboard summary writer
def set_logger(logger_, tb_writer_):
global logger, tb_writer
logger = logger_
tb_writer = tb_writer_
# Set seeds
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
logger.info('Seed set to %d.' % args.seed)
def build_train_loaders(args, tokenizer, dataset):
dataloaders = [dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len,
args.max_turn_len, resampled_size=args.data_sampling_size)
for _ in range(args.ensemble_size)]
return dataloaders
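A hedged sketch of how these helpers fit together: each ensemble member receives its own resampled training dataloader (a bagging-style setup). The argument values are illustrative, and the tokenizer/dataset objects stand for whatever the training script constructs.
import logging
from types import SimpleNamespace

set_logger(logging.getLogger(__name__), None)  # set_seed logs the seed, so a logger must be registered first
args = SimpleNamespace(seed=20211202, n_gpu=0, ensemble_size=5, train_batch_size=3,
                       max_dialogue_len=12, max_turn_len=64, data_sampling_size=7500)
set_seed(args)
# One resampled dataloader per ensemble member; tokenizer and dataset are placeholders
# for the objects built elsewhere in the training pipeline:
# train_loaders = build_train_loaders(args, tokenizer, dataset)
# assert len(train_loaders) == args.ensemble_size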
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation Utilities"""
import random
import torch
import numpy as np
from tqdm import tqdm
def set_seed(args):
"""
Set random seeds
Args:
args (Arguments class): Arguments class containing seed and number of gpus to use
"""
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def get_predictions(args, model, device: torch.device, dataloader: torch.utils.data.DataLoader) -> tuple:
"""
Get model predictions
Args:
args: Runtime arguments
model: SetSUMBT Model
device: Torch device
dataloader: Dataloader containing eval data
"""
model.eval()
belief_states = {slot: [] for slot in model.setsumbt.informable_slot_ids}
request_probs = {slot: [] for slot in model.setsumbt.requestable_slot_ids}
active_domain_probs = {dom: [] for dom in model.setsumbt.domain_ids}
general_act_probs = []
state_labels = {slot: [] for slot in model.setsumbt.informable_slot_ids}
request_labels = {slot: [] for slot in model.setsumbt.requestable_slot_ids}
active_domain_labels = {dom: [] for dom in model.setsumbt.domain_ids}
general_act_labels = []
epoch_iterator = tqdm(dataloader, desc="Iteration")
for step, batch in enumerate(epoch_iterator):
with torch.no_grad():
input_ids = batch['input_ids'].to(device)
token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
p, p_req, p_dom, p_gen, _ = model(input_ids=input_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask)
for slot in belief_states:
p_ = p[slot]
labs = batch['state_labels-' + slot].to(device)
belief_states[slot].append(p_)
state_labels[slot].append(labs)
if p_req is not None:
for slot in request_probs:
p_ = p_req[slot]
labs = batch['request_labels-' + slot].to(device)
request_probs[slot].append(p_)
request_labels[slot].append(labs)
for domain in active_domain_probs:
p_ = p_dom[domain]
labs = batch['active_domain_labels-' + domain].to(device)
active_domain_probs[domain].append(p_)
active_domain_labels[domain].append(labs)
general_act_probs.append(p_gen)
general_act_labels.append(batch['general_act_labels'].to(device))
for slot in belief_states:
belief_states[slot] = torch.cat(belief_states[slot], 0)
state_labels[slot] = torch.cat(state_labels[slot], 0)
if p_req is not None:
for slot in request_probs:
request_probs[slot] = torch.cat(request_probs[slot], 0)
request_labels[slot] = torch.cat(request_labels[slot], 0)
for domain in active_domain_probs:
active_domain_probs[domain] = torch.cat(active_domain_probs[domain], 0)
active_domain_labels[domain] = torch.cat(active_domain_labels[domain], 0)
general_act_probs = torch.cat(general_act_probs, 0)
general_act_labels = torch.cat(general_act_labels, 0)
else:
request_probs, request_labels, active_domain_probs, active_domain_labels = [None] * 4
general_act_probs, general_act_labels = [None] * 2
out = (belief_states, state_labels, request_probs, request_labels)
out += (active_domain_probs, active_domain_labels, general_act_probs, general_act_labels)
return out
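A hedged usage sketch for the helpers above; `model` and `dev_dataloader` stand for the trained SetSUMBT model and evaluation dataloader built elsewhere, and the unpacking order follows the tuple returned by get_predictions.
import torch
from types import SimpleNamespace

args = SimpleNamespace(seed=20211202, n_gpu=torch.cuda.device_count())
set_seed(args)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# belief_states, state_labels, request_probs, request_labels, \
#     active_domain_probs, active_domain_labels, general_act_probs, general_act_labels = \
#     get_predictions(args, model, device, dev_dataloader)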
# -*- coding: utf-8 -*-
# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions for SetSUMBT"""
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from convlab.dst.setsumbt.modeling.loss.bayesian_matching import (BayesianMatchingLoss,
BinaryBayesianMatchingLoss)
from convlab.dst.setsumbt.modeling.loss.kl_distillation import KLDistillationLoss, BinaryKLDistillationLoss
from convlab.dst.setsumbt.modeling.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss
from convlab.dst.setsumbt.modeling.loss.endd_loss import (RKLDirichletMediatorLoss,
BinaryRKLDirichletMediatorLoss)
LOSS_MAP = {
'crossentropy': {'non-binary': CrossEntropyLoss,
'binary': BCEWithLogitsLoss,
'args': list()},
'bayesianmatching': {'non-binary': BayesianMatchingLoss,
'binary': BinaryBayesianMatchingLoss,
'args': ['kl_scaling_factor']},
'labelsmoothing': {'non-binary': LabelSmoothingLoss,
'binary': BinaryLabelSmoothingLoss,
'args': ['label_smoothing']},
'distillation': {'non-binary': KLDistillationLoss,
'binary': BinaryKLDistillationLoss,
'args': ['ensemble_smoothing']},
'distribution_distillation': {'non-binary': RKLDirichletMediatorLoss,
'binary': BinaryRKLDirichletMediatorLoss,
'args': []}
}
def load(loss_function, binary=False):
"""
Load loss function
Args:
loss_function (str): Loss function name
binary (bool): Whether to use binary loss function
Returns:
torch.nn.Module: Loss function
"""
assert loss_function in LOSS_MAP
args_list = LOSS_MAP[loss_function]['args']
loss_function = LOSS_MAP[loss_function]['binary' if binary else 'non-binary']
def __init__(ignore_index=-1, **kwargs):
args = {'ignore_index': ignore_index} if loss_function != BCEWithLogitsLoss else dict()
for arg, val in kwargs.items():
if arg in args_list:
args[arg] = val
return loss_function(**args)
return __init__
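A minimal example of the factory above, assuming the loss classes accept the keyword arguments named in their LOSS_MAP entries: `load` returns a constructor that takes the superset of supported keywords and keeps only those relevant to the selected loss.
# Categorical state loss with label smoothing (ignore_index is injected by the factory)
state_loss_fct = load('labelsmoothing', binary=False)(ignore_index=-1, label_smoothing=0.05)

# Binary Bayesian matching loss, e.g. for request or active-domain heads
request_loss_fct = load('bayesianmatching', binary=True)(kl_scaling_factor=0.001)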
# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,15 +23,15 @@ from torch.nn import Module
class BayesianMatchingLoss(Module):
"""Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation"""
-def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module:
+def __init__(self, kl_scaling_factor: float = 0.001, ignore_index: int = -1) -> Module:
"""
Args:
-lamb (float): Weighting factor for the KL Divergence component
+kl_scaling_factor (float): Weighting factor for the KL Divergence component
ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
"""
super(BayesianMatchingLoss, self).__init__()
-self.lamb = lamb
+self.lamb = kl_scaling_factor
self.ignore_index = ignore_index
def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor:
@@ -88,13 +88,13 @@ class BayesianMatchingLoss(Module):
class BinaryBayesianMatchingLoss(BayesianMatchingLoss):
"""Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation"""
-def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module:
+def __init__(self, kl_scaling_factor: float = 0.001, ignore_index: int = -1) -> Module:
"""
Args:
-lamb (float): Weighting factor for the KL Divergence component
+kl_scaling_factor (float): Weighting factor for the KL Divergence component
ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
"""
-super(BinaryBayesianMatchingLoss, self).__init__(lamb, ignore_index)
+super(BinaryBayesianMatchingLoss, self).__init__(kl_scaling_factor, ignore_index)
def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor:
"""
...
# -*- coding: utf-8 -*-
# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ensemble Distribution Distillation Loss Function (see https://arxiv.org/pdf/2105.06987.pdf for details)"""
import torch
from torch.nn import Module
from torch.nn.functional import kl_div
...
# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""KL Divergence Ensemble Distillation loss"""
+"""KL Divergence Ensemble Distillation loss (See https://arxiv.org/pdf/1503.02531.pdf for details)"""
import torch
from torch.nn import Module
@@ -23,7 +23,7 @@ from torch.nn.functional import kl_div
class KLDistillationLoss(Module):
"""Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation"""
-def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module:
+def __init__(self, ensemble_smoothing: float = 1e-4, ignore_index: int = -1) -> Module:
"""
Args:
lamb (float): Target smoothing parameter
@@ -31,7 +31,7 @@ class KLDistillationLoss(Module):
"""
super(KLDistillationLoss, self).__init__()
-self.lamb = lamb
+self.lamb = ensemble_smoothing
self.ignore_index = ignore_index
def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
@@ -71,13 +71,13 @@ class KLDistillationLoss(Module):
class BinaryKLDistillationLoss(KLDistillationLoss):
"""Binary Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation"""
-def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module:
+def __init__(self, ensemble_smoothing: float = 1e-4, ignore_index: int = -1) -> Module:
"""
Args:
lamb (float): Target smoothing parameter
ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient.
"""
-super(BinaryKLDistillationLoss, self).__init__(lamb, ignore_index)
+super(BinaryKLDistillationLoss, self).__init__(ensemble_smoothing, ignore_index)
def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
"""
@@ -101,4 +101,4 @@ class BinaryKLDistillationLoss(KLDistillationLoss):
targets = targets.unsqueeze(-1)
targets = torch.cat((1 - targets, targets), -1)
-return super().forward(input, targets, temp)
+return super().forward(inputs, targets, temp)
@@ -17,7 +17,7 @@
import torch
-from torch.nn import Softmax, Module, CrossEntropyLoss
+from torch.nn import Module
from torch.nn.functional import kl_div
...
# -*- coding: utf-8 -*-
# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ontology Encoder Model"""
import random
from copy import deepcopy
import torch
from transformers import RobertaModel, BertModel
import numpy as np
from tqdm import tqdm
PARENT_CLASSES = {'bert': BertModel,
'roberta': RobertaModel}
def OntologyEncoder(parent_name: str):
"""
Return the Ontology Encoder model based on the parent transformer model.
Args:
parent_name (str): Name of the parent transformer model
Returns:
OntologyEncoder (class): Ontology Encoder model
"""
parent_class = PARENT_CLASSES.get(parent_name.lower())
class OntologyEncoder(parent_class):
"""Ontology Encoder model based on parent transformer model"""
def __init__(self, config, args, tokenizer):
"""
Initialize Ontology Encoder model.
Args:
config (transformers.configuration_utils.PretrainedConfig): Configuration of the transformer model
args (argparse.Namespace): Arguments
tokenizer (transformers.tokenization_utils_base.PreTrainedTokenizer): Tokenizer
Returns:
OntologyEncoder (class): Ontology Encoder model
"""
super().__init__(config)
# Set random seeds
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
self.args = args
self.config = config
self.tokenizer = tokenizer
def _encode_candidates(self, candidates: list) -> torch.tensor:
"""
Embed candidates
Args:
candidates (list): List of candidate descriptions
Returns:
feats (torch.tensor): Embeddings of the candidate descriptions
"""
# Tokenize candidate descriptions
feats = [self.tokenizer.encode_plus(val, add_special_tokens=True, max_length=self.args.max_candidate_len,
padding='max_length', truncation='longest_first')
for val in candidates]
# Encode tokenized descriptions
with torch.no_grad():
feats = {key: torch.tensor([f[key] for f in feats]).to(self.device) for key in feats[0]}
embedded_feats = self(**feats) # [num_candidates, max_candidate_len, hidden_dim]
# Reduce/pool descriptions embeddings if required
if self.args.set_similarity:
feats = embedded_feats.last_hidden_state.detach().cpu() #[num_candidates, max_candidate_len, hidden_dim]
elif self.args.candidate_pooling == 'cls':
feats = embedded_feats.pooler_output.detach().cpu() # [num_candidates, hidden_dim]
elif self.args.candidate_pooling == "mean":
feats = embedded_feats.last_hidden_state.detach().cpu()
feats = feats.sum(1)
feats = torch.nn.functional.layer_norm(feats, feats.size())
feats = feats.detach().cpu() # [num_candidates, hidden_dim]
return feats
def get_slot_candidate_embeddings(self):
"""
Get embeddings for slots and candidates
Args:
set_type (str): Subset of the dataset being used (train/validation/test)
save_to_file (bool): Indication of whether to save information to file
Returns:
slots (dict): domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
"""
# Set model to eval mode
self.eval()
slots = dict()
for domain, subset in tqdm(self.tokenizer.ontology.items(), desc='Domains'):
for slot, slot_info in tqdm(subset.items(), desc='Slots'):
# Get description or use "domain-slot"
if self.args.use_descriptions:
desc = slot_info['description']
else:
desc = f"{domain}-{slot}"
# Encode domain-slot pair description
slot_emb = self._encode_candidates([desc])[0]
# Obtain possible value set and discard requestable value
values = deepcopy(slot_info['possible_values'])
is_requestable = False
if '?' in values:
is_requestable = True
values.remove('?')
# Encode value candidates
if values:
feats = self._encode_candidates(values)
else:
feats = None
# Store domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
slots[f"{domain}-{slot}"] = (slot_emb, feats, is_requestable)
return slots
return OntologyEncoder
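A hypothetical usage sketch of the factory above (the actual training scripts may instantiate it differently): build a RoBERTa-based encoder, attach a toy ontology to the tokenizer, and embed the slot and value candidates. All argument values and the toy ontology entry are illustrative assumptions.
from types import SimpleNamespace
from transformers import RobertaTokenizer

args = SimpleNamespace(seed=20211202, n_gpu=0, max_candidate_len=12,
                       set_similarity=True, candidate_pooling='cls', use_descriptions=True)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer.ontology = {'hotel': {'price range': {'description': 'price budget of the hotel',
                                                'possible_values': ['cheap', 'expensive', '?'],
                                                'dataset_names': ['multiwoz21']}}}

EncoderClass = OntologyEncoder('roberta')
encoder = EncoderClass.from_pretrained('roberta-base', args, tokenizer)
slots = encoder.get_slot_candidate_embeddings()
# slots['hotel-price range'] -> (description embedding, value candidate embeddings, is_requestable=True)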
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RoBERTa SetSUMBT"""
import torch
from transformers import RobertaModel, RobertaPreTrainedModel
from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead
class RobertaSetSUMBT(RobertaPreTrainedModel):
"""Roberta based SetSUMBT model"""
def __init__(self, config):
"""
Args:
config (configuration): Model configuration class
"""
super(RobertaSetSUMBT, self).__init__(config)
self.config = config
# Turn Encoder
self.roberta = RobertaModel(config)
if config.freeze_encoder:
for p in self.roberta.parameters():
p.requires_grad = False
self.setsumbt = SetSUMBTHead(config)
self.add_slot_candidates = self.setsumbt.add_slot_candidates
self.add_value_candidates = self.setsumbt.add_value_candidates
def forward(self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None,
hidden_state: torch.Tensor = None,
state_labels: torch.Tensor = None,
request_labels: torch.Tensor = None,
active_domain_labels: torch.Tensor = None,
general_act_labels: torch.Tensor = None,
get_turn_pooled_representation: bool = False,
calculate_state_mutual_info: bool = False):
"""
Args:
input_ids: Input token ids
attention_mask: Input padding mask
token_type_ids: Token type indicator
hidden_state: Latent internal dialogue belief state
state_labels: Dialogue state labels
request_labels: User request action labels
active_domain_labels: Current active domain labels
general_act_labels: General user action labels
get_turn_pooled_representation: Return pooled representation of the current dialogue turn
calculate_state_mutual_info: Return mutual information in the dialogue state
Returns:
out: Tuple containing loss, predictive distributions, model statistics and state mutual information
"""
if token_type_ids is not None:
token_type_ids = None
# Encode Dialogues
batch_size, dialogue_size, turn_size = input_ids.size()
input_ids = input_ids.reshape(-1, turn_size)
attention_mask = attention_mask.reshape(-1, turn_size)
roberta_output = self.roberta(input_ids, attention_mask)
# Apply mask and reshape the dialogue turn token embeddings
attention_mask = attention_mask.float().unsqueeze(2)
attention_mask = attention_mask.repeat((1, 1, roberta_output.last_hidden_state.size(-1)))
turn_embeddings = roberta_output.last_hidden_state * attention_mask
turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
if get_turn_pooled_representation:
return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask,
batch_size, dialogue_size, hidden_state, state_labels,
request_labels, active_domain_labels, general_act_labels,
calculate_state_mutual_info) + (roberta_output.pooler_output,)
return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size,
dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
general_act_labels, calculate_state_mutual_info)
This diff is collapsed.
# -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,21 +13,162 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Ensemble SetSUMBT"""
+"""SetSUMBT Models"""
import os
-from shutil import copy2 as copy
+from copy import deepcopy
import torch
from torch.nn import Module
-from transformers import RobertaConfig, BertConfig
+from transformers import (BertModel, BertPreTrainedModel, BertConfig,
+RobertaModel, RobertaPreTrainedModel, RobertaConfig)
-from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
+from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead, SetSUMBTOutput
-from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
-MODELS = {'bert': BertSetSUMBT, 'roberta': RobertaSetSUMBT}
class BertSetSUMBT(BertPreTrainedModel):
"""Bert based SetSUMBT model"""
def __init__(self, config):
"""
Args:
config (configuration): Model configuration class
"""
super(BertSetSUMBT, self).__init__(config)
self.config = config
# Turn Encoder
self.bert = BertModel(config)
if config.freeze_encoder:
for p in self.bert.parameters():
p.requires_grad = False
self.setsumbt = SetSUMBTHead(config)
self.add_slot_candidates = self.setsumbt.add_slot_candidates
self.add_value_candidates = self.setsumbt.add_value_candidates
def forward(self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None,
hidden_state: torch.Tensor = None,
state_labels: torch.Tensor = None,
request_labels: torch.Tensor = None,
active_domain_labels: torch.Tensor = None,
general_act_labels: torch.Tensor = None,
get_turn_pooled_representation: bool = False,
calculate_state_mutual_info: bool = False):
"""
Args:
input_ids: Input token ids
attention_mask: Input padding mask
token_type_ids: Token type indicator
hidden_state: Latent internal dialogue belief state
state_labels: Dialogue state labels
request_labels: User request action labels
active_domain_labels: Current active domain labels
general_act_labels: General user action labels
get_turn_pooled_representation: Return pooled representation of the current dialogue turn
calculate_state_mutual_info: Return mutual information in the dialogue state
Returns:
out: Tuple containing loss, predictive distributions, model statistics and state mutual information
"""
# Encode Dialogues
batch_size, dialogue_size, turn_size = input_ids.size()
input_ids = input_ids.reshape(-1, turn_size)
token_type_ids = token_type_ids.reshape(-1, turn_size)
attention_mask = attention_mask.reshape(-1, turn_size)
bert_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)  # keyword arguments: transformers' BertModel.forward orders attention_mask before token_type_ids
attention_mask = attention_mask.float().unsqueeze(2)
attention_mask = attention_mask.repeat((1, 1, bert_output.last_hidden_state.size(-1)))
turn_embeddings = bert_output.last_hidden_state * attention_mask
turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
output = self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask,
batch_size, dialogue_size, hidden_state, state_labels,
request_labels, active_domain_labels, general_act_labels,
calculate_state_mutual_info)
output.turn_pooled_representation = bert_output.pooler_output if get_turn_pooled_representation else None
return output
class RobertaSetSUMBT(RobertaPreTrainedModel):
"""Roberta based SetSUMBT model"""
def __init__(self, config):
"""
Args:
config (configuration): Model configuration class
"""
super(RobertaSetSUMBT, self).__init__(config)
self.config = config
# Turn Encoder
self.roberta = RobertaModel(config)
if config.freeze_encoder:
for p in self.roberta.parameters():
p.requires_grad = False
self.setsumbt = SetSUMBTHead(config)
self.add_slot_candidates = self.setsumbt.add_slot_candidates
self.add_value_candidates = self.setsumbt.add_value_candidates
def forward(self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None,
hidden_state: torch.Tensor = None,
state_labels: torch.Tensor = None,
request_labels: torch.Tensor = None,
active_domain_labels: torch.Tensor = None,
general_act_labels: torch.Tensor = None,
get_turn_pooled_representation: bool = False,
calculate_state_mutual_info: bool = False):
"""
Args:
input_ids: Input token ids
attention_mask: Input padding mask
token_type_ids: Token type indicator
hidden_state: Latent internal dialogue belief state
state_labels: Dialogue state labels
request_labels: User request action labels
active_domain_labels: Current active domain labels
general_act_labels: General user action labels
get_turn_pooled_representation: Return pooled representation of the current dialogue turn
calculate_state_mutual_info: Return mutual information in the dialogue state
Returns:
out: Tuple containing loss, predictive distributions, model statistics and state mutual information
"""
if token_type_ids is not None:
token_type_ids = None
# Encode Dialogues
batch_size, dialogue_size, turn_size = input_ids.size()
input_ids = input_ids.reshape(-1, turn_size)
attention_mask = attention_mask.reshape(-1, turn_size)
roberta_output = self.roberta(input_ids, attention_mask)
# Apply mask and reshape the dialogue turn token embeddings
attention_mask = attention_mask.float().unsqueeze(2)
attention_mask = attention_mask.repeat((1, 1, roberta_output.last_hidden_state.size(-1)))
turn_embeddings = roberta_output.last_hidden_state * attention_mask
turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
output = self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask,
batch_size, dialogue_size, hidden_state, state_labels,
request_labels, active_domain_labels, general_act_labels,
calculate_state_mutual_info)
output.turn_pooled_representation = roberta_output.pooler_output if get_turn_pooled_representation else None
return output
MODELS = {'bert': BertSetSUMBT, 'roberta': RobertaSetSUMBT}
class EnsembleSetSUMBT(Module):
"""Ensemble SetSUMBT Model for joint ensemble prediction"""
@@ -42,7 +183,18 @@ class EnsembleSetSUMBT(Module):
# Initialise ensemble members
model_cls = MODELS[self.config.model_type]
for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
-setattr(self, attr, model_cls(config))
+setattr(self, attr, model_cls(self.get_clean_config(config)))
@staticmethod
def get_clean_config(config):
config = deepcopy(config)
config.slot_ids = dict()
config.requestable_slot_ids = dict()
config.informable_slot_ids = dict()
config.domain_ids = dict()
config.num_values = dict()
return config
def _load(self, path: str):
"""
@@ -52,7 +204,8 @@ class EnsembleSetSUMBT(Module):
"""
for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
idx = attr.split('_', 1)[-1]
-state_dict = torch.load(os.path.join(path, f'ens-{idx}/pytorch_model.bin'))
+state_dict = torch.load(os.path.join(self._get_checkpoint_path(path, idx), 'pytorch_model.bin'))
state_dict = {key: itm for key, itm in state_dict.items() if '_value_embeddings' not in key}
getattr(self, attr).load_state_dict(state_dict)
def add_slot_candidates(self, slot_candidates: tuple):
@@ -66,9 +219,7 @@ class EnsembleSetSUMBT(Module):
"""
for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
getattr(self, attr).add_slot_candidates(slot_candidates)
-self.requestable_slot_ids = self.model_0.setsumbt.requestable_slot_ids
+self.setsumbt = self.model_0.setsumbt
-self.informable_slot_ids = self.model_0.setsumbt.informable_slot_ids
-self.domain_ids = self.model_0.setsumbt.domain_ids
def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
""" """
...@@ -86,7 +237,8 @@ class EnsembleSetSUMBT(Module): ...@@ -86,7 +237,8 @@ class EnsembleSetSUMBT(Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None, token_type_ids: torch.Tensor = None,
reduction: str = 'mean') -> tuple: reduction: str = 'mean',
**kwargs) -> tuple:
""" """
Args: Args:
input_ids: Input token ids input_ids: Input token ids
...@@ -97,23 +249,28 @@ class EnsembleSetSUMBT(Module): ...@@ -97,23 +249,28 @@ class EnsembleSetSUMBT(Module):
Returns: Returns:
""" """
belief_state_probs = {slot: [] for slot in self.informable_slot_ids} belief_state_probs = {slot: [] for slot in self.setsumbt.config.informable_slot_ids}
request_probs = {slot: [] for slot in self.requestable_slot_ids} request_probs = {slot: [] for slot in self.setsumbt.config.requestable_slot_ids}
active_domain_probs = {dom: [] for dom in self.domain_ids} active_domain_probs = {dom: [] for dom in self.setsumbt.config.domain_ids}
general_act_probs = [] general_act_probs = []
loss = 0.0 if 'state_labels' in kwargs else None
for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
# Prediction from each ensemble member
-b, r, d, g, _ = getattr(self, attr)(input_ids=input_ids,
+with torch.no_grad():
_out = getattr(self, attr)(input_ids=input_ids,
token_type_ids=token_type_ids,
-attention_mask=attention_mask)
+attention_mask=attention_mask,
**kwargs)
if loss is not None:
loss += _out.loss
for slot in belief_state_probs:
-belief_state_probs[slot].append(b[slot].unsqueeze(-2))
+belief_state_probs[slot].append(_out.belief_state[slot].unsqueeze(-2).detach().cpu())
if self.config.predict_actions:
for slot in request_probs:
-request_probs[slot].append(r[slot].unsqueeze(-1))
+request_probs[slot].append(_out.request_probabilities[slot].unsqueeze(-1).detach().cpu())
for dom in active_domain_probs:
-active_domain_probs[dom].append(d[dom].unsqueeze(-1))
+active_domain_probs[dom].append(_out.active_domain_probabilities[dom].unsqueeze(-1).detach().cpu())
-general_act_probs.append(g.unsqueeze(-2))
+general_act_probs.append(_out.general_act_probabilities.unsqueeze(-2).detach().cpu())
belief_state_probs = {slot: torch.cat(l, -2) for slot, l in belief_state_probs.items()}
if self.config.predict_actions:
@@ -134,15 +291,41 @@ class EnsembleSetSUMBT(Module):
elif reduction != 'none':
raise (NameError('Not Implemented!'))
-return belief_state_probs, request_probs, active_domain_probs, general_act_probs, _
+if loss is not None:
loss /= self.config.ensemble_size
output = SetSUMBTOutput(loss=loss,
belief_state=belief_state_probs,
request_probabilities=request_probs,
active_domain_probabilities=active_domain_probs,
general_act_probabilities=general_act_probs)
return output
@staticmethod
def _get_checkpoint_path(path: str, idx: int):
"""
Get checkpoint path for ensemble member
Args:
path: Location of ensemble
idx: Ensemble member index
Returns:
Checkpoint path
"""
checkpoints = os.listdir(os.path.join(path, f'ens-{idx}'))
checkpoints = [int(p.split('-', 1)[-1]) for p in checkpoints if 'checkpoint-' in p]
checkpoint = f"checkpoint-{max(checkpoints)}"
return os.path.join(path, f'ens-{idx}', checkpoint)
@classmethod
-def from_pretrained(cls, path):
+def from_pretrained(cls, path, config=None):
-config_path = os.path.join(path, 'ens-0', 'config.json')
+config_path = os.path.join(cls._get_checkpoint_path(path, 0), 'config.json')
if not os.path.exists(config_path):
raise (NameError('Could not find config.json in model path.'))
if config is None:
try:
config = RobertaConfig.from_pretrained(config_path)
except:
@@ -154,27 +337,3 @@
model._load(path)
return model
def setup_ensemble(model_path: str, ensemble_size: int):
"""
Setup ensemble model directory structure.
Args:
model_path: Path to ensemble model directory
ensemble_size: Number of ensemble members
"""
for i in range(ensemble_size):
path = os.path.join(model_path, f'ens-{i}')
if not os.path.exists(path):
os.mkdir(path)
os.mkdir(os.path.join(path, 'dataloaders'))
os.mkdir(os.path.join(path, 'database'))
# Add development set dataloader to each ensemble member directory
for set_type in ['dev']:
copy(os.path.join(model_path, 'dataloaders', f'{set_type}.dataloader'),
os.path.join(path, 'dataloaders', f'{set_type}.dataloader'))
# Add training and development set ontologies to each ensemble member directory
for set_type in ['train', 'dev']:
copy(os.path.join(model_path, 'database', f'{set_type}.db'),
os.path.join(path, 'database', f'{set_type}.db'))
# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +16,6 @@
"""Linear Temperature Scheduler Class"""
# Temp scheduler class for ensemble distillation
class LinearTemperatureScheduler:
"""
Temperature scheduler object used for distribution temperature scheduling in distillation
...
# -*- coding: utf-8 -*-
# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SetSUMBT Tokenizer"""
import json
import os
import torch
from transformers import RobertaTokenizer, BertTokenizer
from tqdm import tqdm
from convlab.dst.setsumbt.datasets.utils import IdTensor
PARENT_CLASSES = {'bert': BertTokenizer,
'roberta': RobertaTokenizer}
def SetSUMBTTokenizer(parent_name):
"""SetSUMBT Tokenizer Class Factory"""
parent_class = PARENT_CLASSES.get(parent_name.lower())
class SetSUMBTTokenizer(parent_class):
"""SetSUMBT Tokenizer Class"""
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
add_prefix_space=False,
**kwargs,
):
"""
Initialize the tokenizer.
Args:
vocab_file (str): Path to the vocabulary file.
merges_file (str): Path to the merges file.
errors (str): Error handling for the tokenizer.
bos_token (str): Beginning of sentence token.
eos_token (str): End of sentence token.
sep_token (str): Separator token.
cls_token (str): Classification token.
unk_token (str): Unknown token.
pad_token (str): Padding token.
mask_token (str): Masking token.
add_prefix_space (bool): Whether to add a space before the first token.
**kwargs: Additional arguments for the tokenizer.
"""
# Load ontology and tokenizer vocab
with open(vocab_file, 'r', encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
vocab_handle.close()
self.ontology = self.encoder['SETSUMBT_ONTOLOGY'] if 'SETSUMBT_ONTOLOGY' in self.encoder else dict()
self.encoder = {k: v for k, v in self.encoder.items() if 'SETSUMBT_ONTOLOGY' not in k}
vocab_dir = os.path.dirname(vocab_file)
vocab_file = os.path.basename(vocab_file).split('.')
vocab_file = vocab_file[0] + "_base." + vocab_file[-1]
vocab_file = os.path.join(vocab_dir, vocab_file)
with open(vocab_file, 'w', encoding="utf-8") as vocab_handle:
json.dump(self.encoder, vocab_handle)
vocab_handle.close()
super().__init__(vocab_file, merges_file, errors, bos_token, eos_token, sep_token, cls_token, unk_token,
pad_token, mask_token, add_prefix_space, **kwargs)
def set_setsumbt_ontology(self, ontology):
"""
Set the ontology for the tokenizer.
Args:
ontology (dict): The dialogue system ontology to use.
"""
self.ontology = ontology
def save_vocabulary(self, save_directory: str, filename_prefix: str = None) -> tuple:
"""
Save the tokenizer vocabulary and merges files to a directory.
Args:
save_directory (str): Directory to which to save.
filename_prefix (str): Optional prefix to add to the files.
Returns:
vocab_file (str): Path to the saved vocabulary file.
merge_file (str): Path to the saved merges file.
"""
self.encoder['SETSUMBT_ONTOLOGY'] = self.ontology
vocab_file, merge_file = super().save_vocabulary(save_directory, filename_prefix)
self.encoder = {k: v for k, v in self.encoder.items() if 'SETSUMBT_ONTOLOGY' not in k}
return vocab_file, merge_file
def decode_state(self, belief_state, request_probs=None, active_domain_probs=None, general_act_probs=None):
"""
Decode a belief state, request, active domain and general action distributions into a dialogue state.
Args:
belief_state (dict): The belief state distributions.
request_probs (dict): The request distributions.
active_domain_probs (dict): The active domain distributions.
general_act_probs (dict): The general action distributions.
Returns:
dialogue_state (dict): The decoded dialogue state.
"""
dialogue_state = {domain: {slot: '' for slot, slot_info in domain_info.items()
if slot_info['possible_values'] != ["?"] and slot_info['possible_values']}
for domain, domain_info in self.ontology.items()}
for slot, probs in belief_state.items():
dom, slot = slot.split('-', 1)
val = self.ontology.get(dom, dict()).get(slot, dict()).get('possible_values', [])
val = val[probs.argmax().item()] if val else 'none'
if val != 'none':
if dom in dialogue_state:
if slot in dialogue_state[dom]:
dialogue_state[dom][slot] = val
request_acts = list()
if request_probs is not None:
request_acts = [slot for slot, p in request_probs.items() if p.item() > 0.5]
request_acts = [slot.split('-', 1) for slot in request_acts]
request_acts = [[dom, slt] for dom, slt in request_acts
if '?' in self.ontology.get(dom, dict()).get(slt, dict()).get('possible_values', [])]
request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts]
# Construct active domain set
active_domains = dict()
if active_domain_probs is not None:
active_domains = {dom: active_domain_probs.get(dom, torch.tensor(0.0)).item() > 0.5
for dom in self.ontology}
# Construct general domain action
general_acts = list()
if general_act_probs is not None:
general_acts = general_act_probs.argmax(-1).item()
general_acts = [[], ['bye'], ['thank']][general_acts]
general_acts = [[act, 'general', 'none', 'none'] for act in general_acts]
user_acts = request_acts + general_acts
dialogue_state = {'belief_state': dialogue_state,
'user_action': user_acts,
'active_domains': active_domains}
return dialogue_state
def decode_state_batch(self,
belief_state,
request_probs=None,
active_domain_probs=None,
general_act_probs=None,
dialogue_ids=None):
"""
Decode a batch of belief state, request, active domain and general action distributions.
Args:
belief_state (dict): The belief state distributions.
request_probs (dict): The request distributions.
active_domain_probs (dict): The active domain distributions.
general_act_probs (dict): The general action distributions.
dialogue_ids (list): The dialogue IDs.
Returns:
data (dict): The decoded dialogue states.
"""
data = dict()
slot_0 = [key for key in belief_state.keys()][0]
if dialogue_ids is None:
dialogue_ids = [["{:06d}".format(i) for i in range(belief_state[slot_0].size(0))]]
for dial_idx in range(belief_state[slot_0].size(0)):
dialogue = list()
for turn_idx in range(belief_state[slot_0].size(1)):
if belief_state[slot_0][dial_idx, turn_idx].sum() != 0.0:
belief = {slot: p[dial_idx, turn_idx] for slot, p in belief_state.items()}
req = {slot: p[dial_idx, turn_idx]
for slot, p in request_probs.items()} if request_probs is not None else None
dom = {dom: p[dial_idx, turn_idx]
for dom, p in active_domain_probs.items()} if active_domain_probs is not None else None
gen = general_act_probs[dial_idx, turn_idx] if general_act_probs is not None else None
state = self.decode_state(belief, req, dom, gen)
dialogue.append(state)
data[dialogue_ids[0][dial_idx]] = dialogue
return data
def encode(self, dialogues: list, max_turns: int = 12, max_seq_len: int = 64) -> dict:
"""
Convert dialogue examples to model input features and labels
Args:
dialogues (list): List of all extracted dialogues
max_turns (int): Maximum numbers of turns in a dialogue
max_seq_len (int): Maximum number of tokens in a dialogue turn
Returns:
features (dict): All inputs and labels required to train the model
"""
features = dict()
# Get encoder input for system, user utterance pairs
input_feats = []
if len(dialogues) > 5:
iterator = tqdm(dialogues)
else:
iterator = dialogues
for dial in iterator:
dial_feats = []
for turn in dial:
if len(turn['system_utterance']) == 0:
usr = turn['user_utterance']
dial_feats.append(super().encode_plus(usr, add_special_tokens=True, max_length=max_seq_len,
padding='max_length', truncation='longest_first'))
else:
usr = turn['user_utterance']
sys = turn['system_utterance']
dial_feats.append(super().encode_plus(usr, sys, add_special_tokens=True,
max_length=max_seq_len, padding='max_length',
truncation='longest_first'))
# Truncate
if len(dial_feats) >= max_turns:
break
input_feats.append(dial_feats)
del dial_feats
# Perform turn level padding
if 'dialogue_id' in dialogues[0][0]:
dial_ids = list()
for dial in dialogues:
_ids = [turn['dialogue_id'] for turn in dial][:max_turns]
_ids += [''] * (max_turns - len(_ids))
dial_ids.append(_ids)
input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
for dial in input_feats]
if 'token_type_ids' in input_feats[0][0]:
token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
for dial in input_feats]
else:
token_type_ids = None
if 'attention_mask' in input_feats[0][0]:
attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
for dial in input_feats]
else:
attention_mask = None
del input_feats
# Create torch data tensors
if 'dialogue_id' in dialogues[0][0]:
features['dialogue_ids'] = IdTensor(dial_ids)
features['input_ids'] = torch.tensor(input_ids)
features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
del input_ids, token_type_ids, attention_mask
# Extract all informable and requestable slots from the ontology
informable_slots = [f"{domain}-{slot}" for domain in self.ontology for slot in self.ontology[domain]
if self.ontology[domain][slot]['possible_values']
and self.ontology[domain][slot]['possible_values'] != ['?']]
requestable_slots = [f"{domain}-{slot}" for domain in self.ontology for slot in self.ontology[domain]
if '?' in self.ontology[domain][slot]['possible_values']]
# Extract a list of domains from the ontology slots
domains = [domain for domain in self.ontology]
# Create slot labels
if 'state' in dialogues[0][0]:
for domslot in tqdm(informable_slots):
labels = []
for dial in dialogues:
labs = []
for turn in dial:
value = [v for d, substate in turn['state'].items() for s, v in substate.items()
if f'{d}-{s}' == domslot]
domain, slot = domslot.split('-', 1)
if turn['dataset_name'] in self.ontology[domain][slot]['dataset_names']:
value = value[0] if value else 'none'
else:
value = -1
if value in self.ontology[domain][slot]['possible_values'] and value != -1:
value = self.ontology[domain][slot]['possible_values'].index(value)
else:
value = -1 # If value is not in ontology then we do not penalise the model
labs.append(value)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['state_labels-' + domslot] = labels
# Create requestable slot labels
if 'dialogue_acts' in dialogues[0][0]:
for domslot in tqdm(requestable_slots):
labels = []
for dial in dialogues:
labs = []
for turn in dial:
domain, slot = domslot.split('-', 1)
if turn['dataset_name'] in self.ontology[domain][slot]['dataset_names']:
acts = [act['intent'] for act in turn['dialogue_acts']
if act['domain'] == domain and act['slot'] == slot]
if acts:
act_ = acts[0]
if act_ == 'request':
labs.append(1)
else:
labs.append(0)
else:
labs.append(0)
else:
labs.append(-1)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['request_labels-' + domslot] = labels
# General act labels (1-goodbye, 2-thank you)
labels = []
for dial in tqdm(dialogues):
labs = []
for turn in dial:
acts = [act['intent'] for act in turn['dialogue_acts'] if act['intent'] in ['bye', 'thank']]
if acts:
if 'bye' in acts:
labs.append(1)
else:
labs.append(2)
else:
labs.append(0)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['general_act_labels'] = labels
# Create active domain labels
if 'active_domains' in dialogues[0][0]:
for domain in tqdm(domains):
labels = []
for dial in dialogues:
labs = []
for turn in dial:
possible_domains = list()
for dom in self.ontology:
for slt in self.ontology[dom]:
if turn['dataset_name'] in self.ontology[dom][slt]['dataset_names']:
possible_domains.append(dom)
if domain in turn['active_domains']:
labs.append(1)
elif domain in possible_domains:
labs.append(0)
else:
labs.append(-1)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['active_domain_labels-' + domain] = labels
try:
del labels
except:
labels = None
return features
return SetSUMBTTokenizer
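A hypothetical end-to-end sketch of the tokenizer factory above, using a toy ontology and a single one-turn dialogue; the ontology entry and dialogue fields mirror the keys accessed in encode and are purely illustrative.
ontology = {'hotel': {'price range': {'description': 'price budget of the hotel',
                                      'possible_values': ['cheap', 'expensive', '?'],
                                      'dataset_names': ['multiwoz21']}}}
dialogues = [[{'dialogue_id': 'dial-000001', 'dataset_name': 'multiwoz21',
               'system_utterance': '', 'user_utterance': 'i need a cheap hotel',
               'state': {'hotel': {'price range': 'cheap'}},
               'dialogue_acts': [{'intent': 'inform', 'domain': 'hotel', 'slot': 'price range'}],
               'active_domains': ['hotel']}]]

Tokenizer = SetSUMBTTokenizer('roberta')
tokenizer = Tokenizer.from_pretrained('roberta-base')
tokenizer.set_setsumbt_ontology(ontology)
features = tokenizer.encode(dialogues, max_turns=12, max_seq_len=64)
# features holds input_ids/attention_mask tensors of shape [1, 12, 64] plus
# state, request, general act and active domain label tensors for the toy slot.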
This diff is collapsed.
This diff is collapsed.
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Predict dataset user action using SetSUMBT Model"""
from copy import deepcopy
import os
import json
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from convlab.util.custom_util import flatten_acts as flatten
from convlab.util import load_dataset, load_policy_data
from convlab.dst.setsumbt import SetSUMBTTracker
def flatten_acts(acts: dict) -> list:
"""
Flatten dictionary actions.
Args:
acts: Dictionary acts
Returns:
flat_acts: Flattened actions
"""
acts = flatten(acts)
flat_acts = []
for intent, domain, slot, value in acts:
flat_acts.append([intent,
domain,
slot if slot != 'none' else '',
value.lower() if value != 'none' else ''])
return flat_acts
def get_user_actions(context: list, system_acts: list) -> list:
"""
Extract user actions from the data.
Args:
context: Previous dialogue turns.
system_acts: List of flattened system actions.
Returns:
user_acts: List of flattened user actions.
"""
user_acts = context[-1]['dialogue_acts']
user_acts = flatten_acts(user_acts)
if len(context) == 3:
prev_state = context[-3]['state']
cur_state = context[-1]['state']
for domain, substate in cur_state.items():
for slot, value in substate.items():
if prev_state[domain][slot] != value:
act = ['inform', domain, slot, value]
if act not in user_acts and act not in system_acts:
user_acts.append(act)
return user_acts
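# Illustrative sketch of the state-difference heuristic above: with a full three-turn context
# window, any slot whose value changed between the previous user turn's state and the current
# one is added as an inform act, unless the user or system acts already contain it.
#
#     prev_state = {'hotel': {'area': ''}}       # hypothetical states
#     cur_state = {'hotel': {'area': 'east'}}
#     # -> ['inform', 'hotel', 'area', 'east'] is appended to user_acts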
def extract_dataset(dataset: str = 'multiwoz21') -> list:
"""
Extract acts and utterances from the dataset.
Args:
dataset: Dataset name
Returns:
data: Extracted data
"""
data = load_dataset(dataset_name=dataset)
raw_data = load_policy_data(data, data_split='test', context_window_size=3)['test']
dialogue = list()
data = list()
for turn in raw_data:
state = dict()
state['system_utterance'] = turn['context'][-2]['utterance'] if len(turn['context']) > 1 else ''
state['utterance'] = turn['context'][-1]['utterance']
state['system_actions'] = turn['context'][-2]['dialogue_acts'] if len(turn['context']) > 1 else {}
state['system_actions'] = flatten_acts(state['system_actions'])
state['user_actions'] = get_user_actions(turn['context'], state['system_actions'])
dialogue.append(state)
if turn['terminated']:
data.append(dialogue)
dialogue = list()
return data
def unflatten_acts(acts: list) -> dict:
"""
Convert acts from flat list format to dict format.
Args:
acts: List of flat actions.
Returns:
unflat_acts: Dictionary of acts.
"""
binary_acts = []
cat_acts = []
for intent, domain, slot, value in acts:
include = (domain == 'general') or (slot != 'none')
if include and (value == '' or value == 'none' or intent == 'request'):
binary_acts.append({'intent': intent,
'domain': domain,
'slot': slot if slot != 'none' else ''})
elif include:
cat_acts.append({'intent': intent,
'domain': domain,
'slot': slot if slot != 'none' else '',
'value': value})
unflat_acts = {'categorical': cat_acts, 'binary': binary_acts, 'non-categorical': list()}
return unflat_acts
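# Illustrative example of the grouping rule above (a sketch, not part of the original file):
#
#     unflatten_acts([['request', 'hotel', 'area', ''],
#                     ['inform', 'hotel', 'price range', 'cheap'],
#                     ['thank', 'general', '', '']])
#     # -> {'categorical': [{'intent': 'inform', 'domain': 'hotel',
#     #                      'slot': 'price range', 'value': 'cheap'}],
#     #     'binary': [{'intent': 'request', 'domain': 'hotel', 'slot': 'area'},
#     #                {'intent': 'thank', 'domain': 'general', 'slot': ''}],
#     #     'non-categorical': []}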
def predict_user_acts(data: list, tracker: SetSUMBTTracker) -> list:
"""
Predict the user actions using the SetSUMBT Tracker.
Args:
data: List of dialogues.
tracker: SetSUMBT Tracker
Returns:
predict_result: List of turns containing predictions and true user actions.
"""
tracker.init_session()
predict_result = []
for dial_idx, dialogue in enumerate(data):
for turn_idx, state in enumerate(dialogue):
sample = {'dial_idx': dial_idx, 'turn_idx': turn_idx}
tracker.state['history'].append(['sys', state['system_utterance']])
predicted_state = deepcopy(tracker.update(state['utterance']))
tracker.state['history'].append(['usr', state['utterance']])
tracker.state['system_action'] = state['system_actions']
sample['predictions'] = {'dialogue_acts': unflatten_acts(predicted_state['user_action'])}
sample['dialogue_acts'] = unflatten_acts(state['user_actions'])
predict_result.append(sample)
tracker.init_session()
return predict_result
if __name__ == "__main__":
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
parser.add_argument('--model_path', type=str, help='Path to model dir')
args = parser.parse_args()
dataset = extract_dataset(args.dataset_name)
tracker = SetSUMBTTracker(args.model_path)
predict_results = predict_user_acts(dataset, tracker)
with open(os.path.join(args.model_path, 'predictions', 'test_nlu.json'), 'w') as writer:
json.dump(predict_results, writer, indent=2)
writer.close()
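# Example invocation (illustrative; use whatever filename this script is saved under in the
# repository):
#
#     python predict_user_acts.py --dataset_name multiwoz21 --model_path /path/to/setsumbt_model
#
# Predictions are written to <model_path>/predictions/test_nlu.json, so the 'predictions'
# directory should already exist inside the model directory.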
# -*- coding: utf-8 -*-
# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
...@@ -13,12 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run SetSUMBT belief tracker training and evaluation."""
import logging
import os
from shutil import copy2 as copy
from copy import deepcopy
import torch
import transformers
from transformers import BertConfig, RobertaConfig
from tensorboardX import SummaryWriter
from tqdm import tqdm
from convlab.dst.setsumbt.modeling import SetSUMBTModels, SetSUMBTTrainer
from convlab.dst.setsumbt.datasets import (get_dataloader, change_batch_size, dataloader_sample_dialogues,
get_distillation_dataloader)
from convlab.dst.setsumbt.utils import get_args, update_args, setup_ensemble
MODELS = {
'bert': (BertConfig, "BertTokenizer"),
...@@ -27,15 +38,277 @@ MODELS = {
def main():
# Get arguments
args, config = get_args(MODELS)
if args.model_type in SetSUMBTModels:
SetSumbtModel, OntologyEncoderModel, ConfigClass, Tokenizer = SetSUMBTModels[args.model_type]
if args.ensemble:
SetSumbtModel, _, _, _ = SetSUMBTModels['ensemble']
else:
raise NameError('NotImplemented')
# Set up output directory
OUTPUT_DIR = args.output_dir
if not os.path.exists(OUTPUT_DIR):
os.makedirs(OUTPUT_DIR)
os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
args.output_dir = OUTPUT_DIR
# Set pretrained model path to the trained checkpoint
paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
if 'pytorch_model.bin' in paths and 'config.json' in paths:
args.model_name_or_path = args.output_dir
config = ConfigClass.from_pretrained(args.model_name_or_path)
elif 'ens-0' in paths:
paths = [p for p in os.listdir(os.path.join(args.output_dir, 'ens-0')) if 'checkpoint-' in p]
if paths:
args.model_name_or_path = os.path.join(args.output_dir)
config = ConfigClass.from_pretrained(os.path.join(args.model_name_or_path, 'ens-0', paths[0]))
else:
paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
if paths:
paths = paths[0]
args.model_name_or_path = paths
config = ConfigClass.from_pretrained(args.model_name_or_path)
args = update_args(args, config)
# Create TensorboardX writer
tb_writer = SummaryWriter(logdir=args.tensorboard_path)
# Create logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y')
fh = logging.FileHandler(args.logging_path)
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
logger.addHandler(fh)
# Get device
if torch.cuda.is_available() and args.n_gpu > 0:
device = torch.device('cuda')
else:
device = torch.device('cpu')
args.n_gpu = 0
if args.n_gpu == 0:
args.fp16 = False
# Initialise Model
transformers.utils.logging.set_verbosity_info()
model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
model = model.to(device)
if args.ensemble:
args.model_name_or_path = model._get_checkpoint_path(args.model_name_or_path, 0)
# Create Tokenizer and embedding model for Data Loaders and ontology
tokenizer = Tokenizer.from_pretrained(args.model_name_or_path)
encoder = OntologyEncoderModel.from_pretrained(config.candidate_embedding_model_name,
args=args, tokenizer=tokenizer)
transformers.utils.logging.set_verbosity_error()
if args.do_ensemble_setup:
# Build all dataloaders
train_dataloader = get_dataloader(args.dataset,
'train',
args.train_batch_size,
tokenizer,
encoder,
args.max_dialogue_len,
args.max_turn_len,
train_ratio=args.dataset_train_ratio,
seed=args.seed)
torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
dev_dataloader = get_dataloader(args.dataset,
'validation',
args.dev_batch_size,
tokenizer,
encoder,
args.max_dialogue_len,
args.max_turn_len,
train_ratio=args.dataset_train_ratio,
seed=args.seed)
torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
test_dataloader = get_dataloader(args.dataset,
'test',
args.test_batch_size,
tokenizer,
encoder,
args.max_dialogue_len,
args.max_turn_len,
train_ratio=args.dataset_train_ratio,
seed=args.seed)
torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
setup_ensemble(OUTPUT_DIR, args.ensemble_size)
logger.info(f'Building {args.ensemble_size} resampled dataloaders each of size {args.data_sampling_size}.')
dataloaders = [dataloader_sample_dialogues(deepcopy(train_dataloader), args.data_sampling_size)
for _ in tqdm(range(args.ensemble_size))]
logger.info('Dataloaders built.')
for i, loader in enumerate(dataloaders):
path = os.path.join(OUTPUT_DIR, 'ens-%i' % i)
if not os.path.exists(path):
os.mkdir(path)
path = os.path.join(path, 'dataloaders', 'train.dataloader')
torch.save(loader, path)
logger.info('Dataloaders saved.')
# Do not perform standard training after ensemble setup is created
return 0
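# For reference, an illustrative sketch of the directory layout produced by the ensemble
# setup branch above (the ens-* structure is created by setup_ensemble, which is defined in
# convlab.dst.setsumbt.utils):
#
#     <output_dir>/
#         dataloaders/{train,dev,test}.dataloader
#         ens-0/dataloaders/train.dataloader      # resampled subset of the training data
#         ...
#         ens-<ensemble_size - 1>/dataloaders/train.dataloader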
# Perform tasks
# TRAINING
if args.do_train:
if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')):
train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
if train_dataloader.batch_size != args.train_batch_size:
train_dataloader = change_batch_size(train_dataloader, args.train_batch_size)
else:
if args.data_sampling_size <= 0:
args.data_sampling_size = None
if 'distillation' not in config.loss_function:
train_dataloader = get_dataloader(args.dataset,
'train',
args.train_batch_size,
tokenizer,
encoder,
args.max_dialogue_len,
config.max_turn_len,
resampled_size=args.data_sampling_size,
train_ratio=args.dataset_train_ratio,
seed=args.seed)
else:
loader_args = {"ensemble_path": args.ensemble_model_path,
"set_type": "train",
"batch_size": args.train_batch_size,
"reduction": "mean" if config.loss_function == 'distillation' else "none"}
train_dataloader = get_distillation_dataloader(**loader_args)
torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
# Get development set batch loaders and ontology embeddings
if args.do_eval:
if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
if dev_dataloader.batch_size != args.dev_batch_size:
dev_dataloader = change_batch_size(dev_dataloader, args.dev_batch_size)
else:
if 'distillation' not in config.loss_function:
dev_dataloader = get_dataloader(args.dataset,
'validation',
args.dev_batch_size,
tokenizer,
encoder,
args.max_dialogue_len,
config.max_turn_len)
else:
loader_args = {"ensemble_path": args.ensemble_model_path,
"set_type": "dev",
"batch_size": args.dev_batch_size,
"reduction": "mean" if config.loss_function == 'distillation' else "none"}
dev_dataloader = get_distillation_dataloader(**loader_args)
torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
else:
dev_dataloader = None
# Initialise the trainer and run training
trainer = SetSUMBTTrainer(args, model, tokenizer, train_dataloader, dev_dataloader, logger, tb_writer,
device)
trainer.train()
# Copy final best model to the output dir
checkpoints = os.listdir(OUTPUT_DIR)
checkpoints = [p for p in checkpoints if 'checkpoint' in p]
checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints])
best_checkpoint = os.path.join(OUTPUT_DIR, f'checkpoint-{checkpoints[-1]}')
files = ['pytorch_model.bin', 'config.json', 'merges.txt', 'special_tokens_map.json',
'tokenizer_config.json', 'vocab.json']
for file in files:
copy(os.path.join(best_checkpoint, file), os.path.join(OUTPUT_DIR, file))
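# Illustrative example of the selection above: with 'checkpoint-500', 'checkpoint-1500' and
# 'checkpoint-1000' present in the output directory,
#     sorted(int(p.split('-')[-1]) for p in checkpoints)   # -> [500, 1000, 1500]
# so the files listed above are copied from <output_dir>/checkpoint-1500.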
# Load best model for evaluation
tokenizer = Tokenizer.from_pretrained(OUTPUT_DIR)
model = SetSumbtModel.from_pretrained(OUTPUT_DIR)
model = model.to(device)
# Evaluation on the training set
if args.do_eval_trainset:
if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')):
train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
if train_dataloader.batch_size != args.train_batch_size:
train_dataloader = change_batch_size(train_dataloader, args.train_batch_size)
else:
train_dataloader = get_dataloader(args.dataset, 'train', args.train_batch_size, tokenizer,
encoder, args.max_dialogue_len, config.max_turn_len)
torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
# EVALUATION
trainer = SetSUMBTTrainer(args, model, tokenizer, None, train_dataloader, logger, tb_writer, device)
trainer.eval_mode(load_slots=True)
if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
save_pred_dist_path = os.path.join(OUTPUT_DIR, 'predictions', 'train.data') if args.ensemble else None
metrics = trainer.evaluate(save_pred_dist_path=save_pred_dist_path)
trainer.log_info(metrics, logging_stage='dev')
# Evaluation on the development set
if args.do_eval:
if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
if dev_dataloader.batch_size != args.dev_batch_size:
dev_dataloader = change_batch_size(dev_dataloader, args.dev_batch_size)
else:
dev_dataloader = get_dataloader(args.dataset, 'validation', args.dev_batch_size, tokenizer,
encoder, args.max_dialogue_len, config.max_turn_len)
torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
# EVALUATION
trainer = SetSUMBTTrainer(args, model, tokenizer, None, dev_dataloader, logger, tb_writer, device)
trainer.eval_mode(load_slots=True)
if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
save_pred_dist_path = os.path.join(OUTPUT_DIR, 'predictions', 'dev.data') if args.ensemble else None
metrics = trainer.evaluate(save_eval_path=os.path.join(OUTPUT_DIR, 'predictions', 'dev.json'),
save_pred_dist_path=save_pred_dist_path)
trainer.log_info(metrics, logging_stage='dev')
# Evaluation on the test set
if args.do_test:
if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
if test_dataloader.batch_size != args.test_batch_size:
test_dataloader = change_batch_size(test_dataloader, args.test_batch_size)
else:
test_dataloader = get_dataloader(args.dataset, 'test', args.test_batch_size, tokenizer,
encoder, args.max_dialogue_len, config.max_turn_len)
torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
trainer = SetSUMBTTrainer(args, model, tokenizer, None, test_dataloader, logger, tb_writer, device)
trainer.eval_mode(load_slots=True)
# TESTING
if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))
save_pred_dist_path = os.path.join(OUTPUT_DIR, 'predictions', 'test.data') if args.ensemble else None
metrics = trainer.evaluate(save_eval_path=os.path.join(OUTPUT_DIR, 'predictions', 'test.json'),
save_pred_dist_path=save_pred_dist_path, draw_calibration_diagram=True)
trainer.log_info(metrics, logging_stage='test')
# Save final model for inference
if not args.ensemble:
trainer.model.save_pretrained(OUTPUT_DIR)
trainer.tokenizer.save_pretrained(OUTPUT_DIR)
tb_writer.close()
if __name__ == "__main__": if __name__ == "__main__":
......
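# Example invocations (illustrative; the flag names are assumed to mirror the args attributes
# used above and are defined by convlab.dst.setsumbt.utils.get_args, and 'run.py' stands for
# whatever this script is named in the repository):
#
#     # Train a single model, then evaluate it on the validation and test sets
#     python run.py --output_dir models/setsumbt-multiwoz21 --do_train --do_eval --do_test
#
#     # Build resampled dataloaders and the ens-* directory structure for an ensemble
#     python run.py --output_dir models/setsumbt-ensemble --ensemble --do_ensemble_setup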
from convlab.dst.setsumbt.utils.configuration import get_args, update_args, clear_checkpoints
from convlab.dst.setsumbt.utils.ensemble import setup_ensemble, EnsembleAggregator