diff --git a/convlab/dst/setsumbt/__init__.py b/convlab/dst/setsumbt/__init__.py index 9492faa9c9a20d1c476819bb995900ca71d56607..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/convlab/dst/setsumbt/__init__.py +++ b/convlab/dst/setsumbt/__init__.py @@ -1 +0,0 @@ -from convlab.dst.setsumbt.tracker import SetSUMBTTracker \ No newline at end of file diff --git a/convlab/dst/setsumbt/calibration_plots.py b/convlab/dst/setsumbt/calibration_plots.py deleted file mode 100644 index a41f280d3349164a2a67333d0ab176a37cbe50ea..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/calibration_plots.py +++ /dev/null @@ -1,112 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Calibration Plot plotting script""" - -import os -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser - -import torch -from matplotlib import pyplot as plt - - -def main(): - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--data_dir', help='Location of the belief states', required=True) - parser.add_argument('--output', help='Output image path', default='calibration_plot.png') - parser.add_argument('--n_bins', help='Number of bins', default=10, type=int) - args = parser.parse_args() - - if torch.cuda.is_available(): - device = torch.device('cuda') - else: - device = torch.device('cpu') - path = args.data_dir - - models = os.listdir(path) - models = [os.path.join(path, model, 'test.predictions') for model in models] - - fig = plt.figure(figsize=(14,8)) - font=20 - plt.tick_params(labelsize=font-2) - linestyle = ['-', ':', (0, (3, 5, 1, 5)), '-.', (0, (5, 10))] - for i, model in enumerate(models): - conf, acc = get_calibration(model, device, n_bins=args.n_bins) - name = model.split('/')[-2].strip() - print(name, conf, acc) - plt.plot(conf, acc, label=name, linestyle=linestyle[i], linewidth=3) - - plt.plot(torch.tensor([0,1]), torch.tensor([0,1]), linestyle='--', color='black', linewidth=3) - plt.xlabel('Confidence', fontsize=font) - plt.ylabel('Joint Goal Accuracy', fontsize=font) - plt.legend(fontsize=font) - - plt.savefig(args.output) - - -def get_calibration(path, device, n_bins=10, temperature=1.00): - probs = torch.load(path, map_location=device) - y_true = probs['state_labels'] - probs = probs['belief_states'] - - y_pred = {slot: probs[slot].reshape(-1, probs[slot].size(-1)).argmax(-1) for slot in probs} - goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred} - goal_acc = sum([goal_acc[slot] for slot in goal_acc]) - goal_acc = (goal_acc == len(y_true)).int() - - scores = [probs[slot].reshape(-1, probs[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in probs] - scores = torch.cat(scores, 0).min(0)[0] - - step = 1.0 / float(n_bins) - bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step) - bins = [] - for b in range(n_bins): - lower, upper = bin_ranges[b], bin_ranges[b + 1] - if b == 0: - ids = torch.where((scores >= lower) 
* (scores <= upper))[0] - else: - ids = torch.where((scores > lower) * (scores <= upper))[0] - bins.append(ids) - - conf = [0.0] - for b in bins: - if b.size(0) > 0: - l = scores[b] - conf.append(l.mean()) - else: - conf.append(-1) - conf = torch.tensor(conf) - - slot = [s for s in y_true][0] - acc = [0.0] - for b in bins: - if b.size(0) > 0: - acc_ = goal_acc[b] - acc_ = acc_[y_true[slot].reshape(-1)[b] >= 0] - if acc_.size(0) >= 0: - acc.append(acc_.float().mean()) - else: - acc.append(-1) - else: - acc.append(-1) - acc = torch.tensor(acc) - - conf = conf[acc != -1] - acc = acc[acc != -1] - - return conf, acc - - -if __name__ == '__main__': - main() diff --git a/convlab/dst/setsumbt/configs/end2_setsumbt_multiwoz21.json b/convlab/dst/setsumbt/configs/end2_setsumbt_multiwoz21.json new file mode 100644 index 0000000000000000000000000000000000000000..185d397c86a2c9cffc9abe1e826414acf84f4e1b --- /dev/null +++ b/convlab/dst/setsumbt/configs/end2_setsumbt_multiwoz21.json @@ -0,0 +1,11 @@ +{ + "model_type": "SetSUMBT", + "dataset": "multiwoz21", + "no_action_prediction": false, + "loss_function": "distribution_distillation", + "model_name_or_path": "roberta-base", + "candidate_embedding_model_name": "roberta-base", + "train_batch_size": 3, + "dev_batch_size": 12, + "test_batch_size": 16 +} \ No newline at end of file diff --git a/convlab/dst/setsumbt/configs/ensemble_setsumbt_multiwoz21.json b/convlab/dst/setsumbt/configs/ensemble_setsumbt_multiwoz21.json new file mode 100644 index 0000000000000000000000000000000000000000..cde2a77e25e68ad48f3d1c10ae3e965fc2b1b6dd --- /dev/null +++ b/convlab/dst/setsumbt/configs/ensemble_setsumbt_multiwoz21.json @@ -0,0 +1,12 @@ +{ + "model_type": "Ensemble-SetSUMBT", + "dataset": "multiwoz21", + "ensemble_size": 5, + "data_sampling_size": 7500, + "no_action_prediction": false, + "model_name_or_path": "roberta-base", + "candidate_embedding_model_name": "roberta-base", + "train_batch_size": 3, + "dev_batch_size": 3, + "test_batch_size": 3 +} \ No newline at end of file diff --git a/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json b/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json index 57a245518aae0a111f6220b1a088943b8b64ee4c..9463243b4ab22b2e096ec266f2b6922a584c5ad6 100644 --- a/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json +++ b/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json @@ -4,9 +4,7 @@ "no_action_prediction": true, "model_name_or_path": "roberta-base", "candidate_embedding_model_name": "roberta-base", - "transformers_local_files_only": false, "train_batch_size": 3, - "dev_batch_size": 16, - "test_batch_size": 16, - "run_nbt": true + "dev_batch_size": 12, + "test_batch_size": 16 } \ No newline at end of file diff --git a/convlab/dst/setsumbt/dataset/__init__.py b/convlab/dst/setsumbt/dataset/__init__.py deleted file mode 100644 index 17b1f93b3b39f95827cf6c09e8826383cd00b805..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/dataset/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from convlab.dst.setsumbt.dataset.unified_format import get_dataloader, change_batch_size -from convlab.dst.setsumbt.dataset.ontology import get_slot_candidate_embeddings diff --git a/convlab/dst/setsumbt/dataset/ontology.py b/convlab/dst/setsumbt/dataset/ontology.py deleted file mode 100644 index ce150a61077ad61ab9d7af2ae3537971ae925f55..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/dataset/ontology.py +++ /dev/null @@ -1,134 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, 
Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Create Ontology Embeddings""" - -import json -import os -import random -from copy import deepcopy - -import torch -import numpy as np -from tqdm import tqdm - - -def set_seed(args): - """ - Set random seeds - - Args: - args (Arguments class): Arguments class containing seed and number of gpus to use - """ - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - if args.n_gpu > 0: - torch.cuda.manual_seed_all(args.seed) - - -def encode_candidates(candidates: list, args, tokenizer, embedding_model) -> torch.tensor: - """ - Embed candidates - - Args: - candidates (list): List of candidate descriptions - args (argument class): Runtime arguments - tokenizer (transformers Tokenizer): Tokenizer for the embedding_model - embedding_model (transformer Model): Transformer model for embedding candidate descriptions - - Returns: - feats (torch.tensor): Embeddings of the candidate descriptions - """ - # Tokenize candidate descriptions - feats = [tokenizer.encode_plus(val, add_special_tokens=True,max_length=args.max_candidate_len, - padding='max_length', truncation='longest_first') - for val in candidates] - - # Encode tokenized descriptions - with torch.no_grad(): - feats = {key: torch.tensor([f[key] for f in feats]).to(embedding_model.device) for key in feats[0]} - embedded_feats = embedding_model(**feats) # [num_candidates, max_candidate_len, hidden_dim] - - # Reduce/pool descriptions embeddings if required - if args.set_similarity: - feats = embedded_feats.last_hidden_state.detach().cpu() # [num_candidates, max_candidate_len, hidden_dim] - elif args.candidate_pooling == 'cls': - feats = embedded_feats.pooler_output.detach().cpu() # [num_candidates, hidden_dim] - elif args.candidate_pooling == "mean": - feats = embedded_feats.last_hidden_state.detach().cpu() - feats = feats.sum(1) - feats = torch.nn.functional.layer_norm(feats, feats.size()) - feats = feats.detach().cpu() # [num_candidates, hidden_dim] - - return feats - - -def get_slot_candidate_embeddings(ontology: dict, set_type: str, args, tokenizer, embedding_model, save_to_file=True): - """ - Get embeddings for slots and candidates - - Args: - ontology (dict): Dictionary of domain-slot pair descriptions and possible value sets - set_type (str): Subset of the dataset being used (train/validation/test) - args (argument class): Runtime arguments - tokenizer (transformers Tokenizer): Tokenizer for the embedding_model - embedding_model (transformer Model): Transormer model for embedding candidate descriptions - save_to_file (bool): Indication of whether to save information to file - - Returns: - slots (dict): domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot - """ - # Set model to eval mode - embedding_model.eval() - - slots = dict() - for domain, subset in tqdm(ontology.items(), desc='Domains'): - for slot, slot_info in tqdm(subset.items(), desc='Slots'): - # Get 
description or use "domain-slot" - if args.use_descriptions: - desc = slot_info['description'] - else: - desc = f"{domain}-{slot}" - - # Encode domain-slot pair description - slot_emb = encode_candidates([desc], args, tokenizer, embedding_model)[0] - - # Obtain possible value set and discard requestable value - values = deepcopy(slot_info['possible_values']) - is_requestable = False - if '?' in values: - is_requestable = True - values.remove('?') - - # Encode value candidates - if values: - feats = encode_candidates(values, args, tokenizer, embedding_model) - else: - feats = None - - # Store domain-slot description embeddings, candidate embeddings and requestabke flag for each domain-slot - slots[f"{domain}-{slot}"] = (slot_emb, feats, is_requestable) - - # Dump tensors and ontology for use in training and evaluation - if save_to_file: - writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type) - torch.save(slots, writer) - - writer = open(os.path.join(args.output_dir, 'database', '%s.json' % set_type), 'w') - json.dump(ontology, writer, indent=2) - writer.close() - - return slots diff --git a/convlab/dst/setsumbt/datasets/__init__.py b/convlab/dst/setsumbt/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f78b84c53dabfea8fac71a5d7f99ef1a7ac69dd0 --- /dev/null +++ b/convlab/dst/setsumbt/datasets/__init__.py @@ -0,0 +1,4 @@ +from convlab.dst.setsumbt.datasets.unified_format import get_dataloader, change_batch_size, dataloader_sample_dialogues +from convlab.dst.setsumbt.datasets.metrics import (JointGoalAccuracy, BeliefStateUncertainty, + ActPredictionAccuracy, Metrics) +from convlab.dst.setsumbt.datasets.distillation import get_dataloader as get_distillation_dataloader diff --git a/convlab/dst/setsumbt/datasets/distillation.py b/convlab/dst/setsumbt/datasets/distillation.py new file mode 100644 index 0000000000000000000000000000000000000000..50697582cb220d6e009d21dbc57b03caa37841b3 --- /dev/null +++ b/convlab/dst/setsumbt/datasets/distillation.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Get ensemble predictions and build distillation dataloaders""" + +import os + +import torch +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler + +from convlab.dst.setsumbt.datasets.unified_format import UnifiedFormatDataset +from convlab.dst.setsumbt.datasets.utils import IdTensor + +DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' + + +def get_dataloader(ensemble_path:str, set_type: str = 'train', batch_size: int = 3, + reduction: str = 'none') -> DataLoader: + """ + Get dataloader for distillation of ensemble. + + Args: + ensemble_path: Path to ensemble model and predictive distributions. + set_type: Dataset split to load. + batch_size: Batch size. + reduction: Reduction to apply to ensemble predictive distributions. 
+ + Returns: + loader: Dataloader for distillation. + """ + # Load data and predictions from ensemble + path = os.path.join(ensemble_path, 'dataloaders', f"{set_type}.dataloader") + dataset = torch.load(path).dataset + + path = os.path.join(ensemble_path, 'predictions', f"{set_type}.data") + data = torch.load(path) + + dialogue_ids = data.pop('dialogue_ids') + + # Preprocess data + data = reduce_data(data, reduction=reduction) + data = flatten_data(data) + data = do_label_padding(data) + + # Build dataset and dataloader + data = UnifiedFormatDataset.from_datadict(set_type=set_type if set_type != 'dev' else 'validation', + data=data, + ontology=dataset.ontology, + ontology_embeddings=dataset.ontology_embeddings) + data.features['dialogue_ids'] = IdTensor(dialogue_ids) + + if set_type == 'train': + sampler = RandomSampler(data) + else: + sampler = SequentialSampler(data) + + loader = DataLoader(data, sampler=sampler, batch_size=batch_size) + return loader + + +def reduce_data(data: dict, reduction: str = 'none') -> dict: + """ + Reduce ensemble predictive distributions. + + Args: + data: Dictionary of ensemble predictive distributions. + reduction: Reduction to apply to ensemble predictive distributions. + + Returns: + data: Reduced ensemble predictive distributions. + """ + if reduction == 'mean': + data['belief_state'] = {slot: probs.mean(-2) for slot, probs in data['belief_state'].items()} + if 'request_probabilities' in data: + data['request_probabilities'] = {slot: probs.mean(-1) + for slot, probs in data['request_probabilities'].items()} + data['active_domain_probabilities'] = {domain: probs.mean(-1) + for domain, probs in data['active_domain_probabilities'].items()} + data['general_act_probabilities'] = data['general_act_probabilities'].mean(-2) + return data + + +def do_label_padding(data: dict) -> dict: + """ + Add padding to the ensemble predictions (used as labels in distillation) + + Args: + data: Dictionary of ensemble predictions + + Returns: + data: Padded ensemble predictions + """ + if 'attention_mask' in data: + dialogs, turns = torch.where(data['attention_mask'].sum(-1) == 0.0) + else: + dialogs, turns = torch.where(data['input_ids'].sum(-1) == 0.0) + + for key in data: + if key not in ['input_ids', 'attention_mask', 'token_type_ids']: + data[key][dialogs, turns] = -1 + + return data + + +def flatten_data(data: dict) -> dict: + """ + Map data to flattened feature format used in training + + Args: + data: Ensemble prediction data + + Returns: + data: Flattened ensemble prediction data + """ + data_new = dict() + for label, feats in data.items(): + if type(feats) == dict: + for label_, feats_ in feats.items(): + data_new[label + '-' + label_] = feats_ + else: + data_new[label] = feats + + return data_new diff --git a/convlab/dst/setsumbt/datasets/metrics.py b/convlab/dst/setsumbt/datasets/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..690d75baeaee4ef6ec5c1d963e8238db0044c57b --- /dev/null +++ b/convlab/dst/setsumbt/datasets/metrics.py @@ -0,0 +1,566 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Metrics for DST models.""" + +import json +import os + +import torch +from transformers.utils import ModelOutput +from matplotlib import pyplot as plt + +from convlab.util import load_dataset +from convlab.util import load_dst_data +from convlab.dst.setsumbt.datasets.value_maps import VALUE_MAP, QUANTITIES + + +class Metrics(ModelOutput): + """Metrics for DST models.""" + def __add__(self, other): + """Add two metrics objects.""" + for key, itm in other.items(): + assert key not in self + self[key] = itm + return self + + def compute_score(self, **weights): + """ + Compute the score for the metrics object. + + Args: + request (float): The weight for the request F1 score. + active_domain (float): The weight for the active domain F1 score. + general_act (float): The weight for the general act F1 score. + """ + assert 'joint_goal_accuracy' in self + self.score = 0.0 + if 'request_f1' in self and 'request' in weights: + self.score += self.request_f1 * weights['request'] + if 'active_domain_f1' in self and 'active_domain' in weights: + self.score += self.active_domain_f1 * weights['active_domain'] + if 'general_act_f1' in self and 'general_act' in weights: + self.score += self.general_act_f1 * weights['general_act'] + self.score += self.joint_goal_accuracy + + def __gt__(self, other): + """Compare two metrics objects.""" + assert isinstance(other, Metrics) + + if self.joint_goal_accuracy > other.joint_goal_accuracy: + return True + elif 'score' in self and 'score' in other and self.score > other.score: + return True + elif self.training_loss < other.training_loss: + return True + else: + return False + +class JointGoalAccuracy: + """Joint goal accuracy metric.""" + + def __init__(self, dataset_names, validation_split='test'): + """ + Initialize the joint goal accuracy metric. + + Args: + dataset_names (str): The name of the dataset(s) to use for computing the metric. + validation_split (str): The split of the dataset to use for computing the metric. + """ + self.dataset_names = [name for name in dataset_names.split('+')] + self.validation_split = validation_split + self._extract_data() + self._extract_states() + self._init_session() + + def _extract_data(self): + """Extract the data from the dataset.""" + dataset_dicts = [load_dataset(dataset_name=name) for name in self.dataset_names] + self.golden_states = dict() + for dataset_dict in dataset_dicts: + dataset = load_dst_data(dataset_dict, data_split=self.validation_split, speaker='all', dialogue_acts=True, + split_to_turn=False) + for dial in dataset[self.validation_split]: + self.golden_states[dial['dialogue_id']] = dial['turns'] + + @staticmethod + def _clean_state(state): + """ + Clean the state to remove pipe separated values and map values to the standard set. + + Args: + state (dict): The state to clean. + + Returns: + dict: The cleaned state. 
+ """ + clean_state = dict() + for domain, subset in state.items(): + clean_state[domain] = {} + for slot, value in subset.items(): + value = value.split('|') + + # Map values using value_map + for old, new in VALUE_MAP.items(): + value = [val.replace(old, new) for val in value] + value = '|'.join(value) + + # Map dontcare to "do not care" and empty to 'none' + value = value.replace('dontcare', 'do not care') + value = value if value else 'none' + + # Map quantity values to the integer quantity value + if 'people' in slot or 'duration' in slot or 'stay' in slot: + try: + if value not in ['do not care', 'none']: + value = int(value) + value = str(value) if value < 10 else QUANTITIES[-1] + except: + value = value + # Map time values to the most appropriate value in the standard time set + elif 'time' in slot or 'leave' in slot or 'arrive' in slot: + try: + if value not in ['do not care', 'none']: + # Strip after/before from time value + value = value.replace('after ', '').replace('before ', '') + # Extract hours and minutes from different possible formats + if ':' not in value and len(value) == 4: + h, m = value[:2], value[2:] + elif len(value) == 1: + h = int(value) + m = 0 + elif 'pm' in value: + h = int(value.replace('pm', '')) + 12 + m = 0 + elif 'am' in value: + h = int(value.replace('pm', '')) + m = 0 + elif ':' in value: + h, m = value.split(':') + elif ';' in value: + h, m = value.split(';') + # Map to closest 5 minutes + if int(m) % 5 != 0: + m = round(int(m) / 5) * 5 + h = int(h) + if m == 60: + m = 0 + h += 1 + if h >= 24: + h -= 24 + # Set in standard 24 hour format + h, m = int(h), int(m) + value = '%02i:%02i' % (h, m) + except: + value = value + # Map boolean slots to yes/no value + elif 'parking' in slot or 'internet' in slot: + if value not in ['do not care', 'none']: + if value == 'free': + value = 'yes' + elif True in [v in value.lower() for v in ['yes', 'no']]: + value = [v for v in ['yes', 'no'] if v in value][0] + + value = value if value != 'none' else '' + + clean_state[domain][slot] = value + + return clean_state + + def _extract_states(self): + """Extract the states from the dataset.""" + for dial_id, dial in self.golden_states.items(): + states = list() + for turn in dial: + if 'state' in turn: + state = self._clean_state(turn['state']) + states.append(state) + self.golden_states[dial_id] = states + + def _init_session(self): + """Initialize the session.""" + self.samples = dict() + + def add_dialogues(self, predictions): + """ + Add dialogues to the metric. + + Args: + predictions (dict): Dictionary of predicted dialogue belief states. + """ + for dial_id, dialogue in predictions.items(): + for turn_id, turn in enumerate(dialogue): + if dial_id in self.golden_states: + sample = {'dialogue_id': dial_id, + 'turn_id': turn_id, + 'state': self.golden_states[dial_id][turn_id], + 'predictions': turn['belief_state']} + self.samples[f"{dial_id}_{turn_id}"] = sample + + def save_dialogues(self, path): + """ + Save the dialogues and predictions to a file. + + Args: + path (str): The path to save the dialogues to. 
+ """ + dialogues = list() + for idx, turn in self.samples.items(): + predictions = dict() + for domain in turn['state']: + predictions[domain] = dict() + for slot in turn['state'][domain]: + predictions[domain][slot] = turn['predictions'].get(domain, dict()).get(slot, '') + dialogues.append({'dialogue_id': turn['dialogue_id'], + 'turn_id': turn['turn_id'], + 'state': turn['state'], + 'predictions': {'state': predictions}}) + + with open(path, 'w') as writer: + json.dump(dialogues, writer, indent=2) + writer.close() + + def evaluate(self): + """Evaluate the metric.""" + assert len(self.samples) > 0 + metrics = {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0, 'Correct': 0, 'N': 0} + for dial_id, sample in self.samples.items(): + correct = True + for domain in sample['state']: + for slot, values in sample['state'][domain].items(): + metrics['N'] += 1 + if domain not in sample['predictions'] or slot not in sample['predictions'][domain]: + predict_values = '' + else: + predict_values = ''.join(sample['predictions'][domain][slot].split()).lower() + if len(values) > 0: + if len(predict_values) > 0: + values = [''.join(value.split()).lower() for value in values.split('|')] + predict_values = [''.join(value.split()).lower() for value in predict_values.split('|')] + if any([value in values for value in predict_values]): + metrics['TP'] += 1 + else: + correct = False + metrics['FP'] += 1 + else: + metrics['FN'] += 1 + correct = False + else: + if len(predict_values) > 0: + metrics['FP'] += 1 + correct = False + else: + metrics['TN'] += 1 + + metrics['Correct'] += int(correct) + + TP = metrics.pop('TP') + FP = metrics.pop('FP') + FN = metrics.pop('FN') + TN = metrics.pop('TN') + Correct = metrics.pop('Correct') + N = metrics.pop('N') + precision = 1.0 * TP / (TP + FP) if TP + FP else 0. + recall = 1.0 * TP / (TP + FN) if TP + FN else 0. + f1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0. + slot_accuracy = (TP + TN) / N + joint_goal_accuracy = Correct / len(self.samples) + + metrics = Metrics(joint_goal_accuracy=joint_goal_accuracy * 100., + slot_accuracy=slot_accuracy * 100., + slot_f1=f1 * 100., + slot_precision=precision * 100., + slot_recall=recall * 100.) + + return metrics + + +class BeliefStateUncertainty: + """Compute the uncertainty of the belief state predictions.""" + + def __init__(self, n_confidence_bins=10): + """ + Initialize the metric. + + Args: + n_confidence_bins (int): Number of confidence bins. + """ + self._init_session() + self.n_confidence_bins = n_confidence_bins + + def _init_session(self): + """Initialize the session.""" + self.samples = {'belief_state': dict(), + 'golden_state': dict()} + self.bin_info = {'confidence': None, + 'accuracy': None} + + def add_dialogues(self, predictions, labels): + """ + Add dialogues to the metric. + + Args: + predictions (dict): Dictionary of predicted dialogue belief states. + labels (dict): Dictionary of golden dialogue belief states. 
+ """ + for slot, probs in predictions.items(): + if slot not in self.samples['belief_state']: + self.samples['belief_state'][slot] = probs.reshape(-1, probs.size(-1)).cpu() + self.samples['golden_state'][slot] = labels[slot].reshape(-1).cpu() + else: + self.samples['belief_state'][slot] = torch.cat((self.samples['belief_state'][slot], + probs.reshape(-1, probs.size(-1)).cpu()), 0) + self.samples['golden_state'][slot] = torch.cat((self.samples['golden_state'][slot], + labels[slot].reshape(-1).cpu()), 0) + + def _fill_bins(self, probs: torch.Tensor) -> list: + """ + Fill the bins with the relevant observation ids. + + Args: + probs (Tensor): Predicted probabilities. + + Returns: + list: List of bins. + """ + assert probs.dim() == 2 + probs = probs.max(-1)[0] + + step = 1.0 / self.n_confidence_bins + bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step) + bins = [] + # Compute the bin ranges + for b in range(self.n_confidence_bins): + lower, upper = bin_ranges[b], bin_ranges[b + 1] + if b == 0: + ids = torch.where((probs >= lower) * (probs <= upper))[0] + else: + ids = torch.where((probs > lower) * (probs <= upper))[0] + bins.append(ids) + + return bins + + @staticmethod + def _bin_confidence(bins: list, probs: torch.Tensor) -> torch.Tensor: + """ + Compute the average confidence score for each bin. + + Args: + bins (list): List of confidence bins. + probs (Tensor): Predicted probabilities. + + Returns: + scores: Confidence score for each bin. + """ + probs = probs.max(-1)[0] + + scores = [] + for b in bins: + if b is not None: + scores.append(probs[b].mean()) + else: + scores.append(-1) + scores = torch.tensor(scores) + return scores + + def _jg_ece(self) -> float: + """Compute the joint goal Expected Calibration Error.""" + y_pred = {slot: probs.argmax(-1) for slot, probs in self.samples['belief_state'].items()} + goal_acc = [(y_pred[slot] == y_true).int() for slot, y_true in self.samples['golden_state'].items()] + goal_acc = (sum(goal_acc) / len(goal_acc)).int() + + # Confidence score is minimum across slots as a single bad predictions leads to incorrect prediction in state + scores = [probs.max(-1)[0].unsqueeze(1) for slot, probs in self.samples["belief_state"].items()] + scores = torch.cat(scores, 1).min(1)[0] + + bins = self._fill_bins(scores.unsqueeze(-1)) + conf = self._bin_confidence(bins, scores.unsqueeze(-1)) + + slot_0 = list(self.samples['golden_state'].keys())[0] + acc = [] + for b in bins: + if b is not None: + acc_ = goal_acc[b] + acc_ = acc_[self.samples['golden_state'][slot_0][b] >= 0] + if acc_.size(0) >= 0: + acc.append(acc_.float().mean()) + else: + acc.append(-1) + else: + acc.append(-1) + acc = torch.tensor(acc) + + self.bin_info['confidence'] = conf + self.bin_info['accuracy'] = acc + + n = self.samples["belief_state"][slot_0].size(0) + bk = torch.tensor([b.size(0) for b in bins]) + + ece = torch.abs(conf - acc) * bk / n + ece = ece[acc >= 0.0] + ece = ece.sum().item() + + return ece + + def draw_calibration_diagram(self, save_path: str, validation_split=None): + """ + Draw the calibration diagram. + + Args: + save_path (str): Path to save the calibration diagram. + validation_split (str): Validation split. 
+ """ + if self.bin_info['confidence'] is None: + self._jg_ece() + + acc = self.bin_info['accuracy'] + conf = self.bin_info['confidence'] + conf = conf[acc >= 0.0] + acc = acc[acc >= 0.0] + + fig = plt.figure(figsize=(14,8)) + font = 20 + plt.tick_params(labelsize=font - 2) + linestyle = '-' + + plt.plot(torch.tensor([0, 1]), torch.tensor([0, 1]), linestyle='--', color='black', linewidth=3) + plt.plot(conf, acc, linestyle=linestyle, color='red', linewidth=3) + plt.xlabel('Confidence', fontsize=font) + plt.ylabel('Joint Goal Accuracy', fontsize=font) + + path = validation_split + '_calibration_diagram.json' if validation_split else 'calibration_diagram.json' + path = os.path.join(save_path, 'predictions', path) + with open(path, 'w') as f: + json.dump({'confidence': conf.tolist(), 'accuracy': acc.tolist()}, f) + + path = validation_split + '_calibration_diagram.png' if validation_split else 'calibration_diagram.png' + path = os.path.join(save_path, path) + plt.savefig(path) + + def _l2_err(self, remove_belief: bool = False) -> float: + """ + Compute the L2 error between the predicted and target distribution. + + Args: + remove_belief (bool): Remove the belief state and replace it with a 1 hot prediction. + + Returns: + l2_err: L2 error between the predicted and target distribution. + """ + # Get ids used for removing padding turns. + slot_0 = list(self.samples['golden_state'].keys())[0] + padding = torch.where(self.samples['golden_state'][slot_0] != -1)[0] + + distributions = [] + labels = [] + for slot, probs in self.samples['belief_state'].items(): + # Replace distribution by a 1 hot prediction + if remove_belief: + probs_ = torch.zeros(probs.shape).float() + probs_[range(probs.size(0)), probs.argmax(-1)] = 1.0 + probs = probs_ + del probs_ + # Remove padding turns + lab = self.samples['golden_state'][slot] + probs = probs[padding] + lab = lab[padding] + + # Target distribution + y = torch.zeros(probs.shape) + y[range(y.size(0)), lab] = 1.0 + + distributions.append(probs) + labels.append(y) + + # Concatenate all slots into a single belief state + distributions = torch.cat(distributions, -1) + labels = torch.cat(labels, -1) + + # Calculate L2-Error for each turn + err = torch.sqrt(((labels - distributions) ** 2).sum(-1)) + return err.mean().item() + + def evaluate(self): + """Evaluate the metrics.""" + l2_err = self._l2_err(remove_belief=False) + binary_l2_err = self._l2_err(remove_belief=True) + l2_err_ratio = (binary_l2_err - l2_err) / binary_l2_err + metrics = Metrics( + joint_goal_ece=self._jg_ece() * 100., + joint_l2_error=l2_err, + joint_l2_error_ratio=l2_err_ratio * 100. + ) + return metrics + + +class ActPredictionAccuracy: + """Calculate the accuracy of the action predictions.""" + + def __init__(self, act_type, binary=False): + """ + Args: + act_type (str): Type of action to evaluate. + binary (bool): Whether the action is binary or multilabel. + """ + self.act_type = act_type + self.binary = binary + self._init_session() + + def _init_session(self): + """Initialize the session.""" + self.samples = {'predictions': dict(), + 'labels': dict()} + + def add_dialogues(self, predictions, labels): + """ + Add dialogues to the session. + + Args: + predictions (dict): Action predictions. + labels (dict): Action labels. 
+ """ + for slot, probs in predictions.items(): + if slot in labels: + pred = probs.cpu().argmax(-1).reshape(-1) if not self.binary else probs.cpu().round().int().reshape(-1) + if slot not in self.samples['predictions']: + self.samples['predictions'][slot] = pred + self.samples['labels'][slot] = labels[slot].reshape(-1).cpu() + else: + self.samples['predictions'][slot] = torch.cat((self.samples['predictions'][slot], pred), 0) + self.samples['labels'][slot] = torch.cat((self.samples['labels'][slot], + labels[slot].reshape(-1).cpu()), 0) + + def evaluate(self): + """Evaluate the metrics.""" + metrics = {'TP': 0, 'FP': 0, 'FN': 0, 'Correct': 0, 'N': 0} + for slot, pred in self.samples['predictions'].items(): + metrics['N'] += pred.size(0) + metrics['Correct'] += (pred == self.samples['labels'][slot]).sum() + tp = (pred > 0) * (self.samples['labels'][slot] > 0) * (pred == self.samples['labels'][slot]) + metrics['TP'] += tp.sum() + metrics['FP'] += ((pred > 0) * (self.samples['labels'][slot] == 0)).sum() + metrics['FN'] += ((pred == 0) * (self.samples['labels'][slot] > 0)).sum() + + TP = metrics.pop('TP') + FP = metrics.pop('FP') + FN = metrics.pop('FN') + Correct = metrics.pop('Correct') + N = metrics.pop('N') + precision = 1.0 * TP / (TP + FP) if TP + FP else 0. + recall = 1.0 * TP / (TP + FN) if TP + FN else 0. + f1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0. + + metrics = {f'{self.act_type}_f1': f1 * 100.} + return Metrics(**metrics) diff --git a/convlab/dst/setsumbt/dataset/unified_format.py b/convlab/dst/setsumbt/datasets/unified_format.py similarity index 50% rename from convlab/dst/setsumbt/dataset/unified_format.py rename to convlab/dst/setsumbt/datasets/unified_format.py index 55483e0f4e3404e96d817395c53dd9a6fcd57c3e..68a371f7a8f54c8258b20e584abba6a9894a9e17 100644 --- a/convlab/dst/setsumbt/dataset/unified_format.py +++ b/convlab/dst/setsumbt/datasets/unified_format.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,260 +14,81 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Convlab3 Unified Format Dialogue Datasets""" -import pdb -from copy import deepcopy import torch import transformers from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler from transformers.tokenization_utils import PreTrainedTokenizer -from tqdm import tqdm from convlab.util import load_dataset -from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values, - get_values_from_data, ontology_add_requestable_slots, - get_requestable_slots, load_dst_data, extract_dialogues, - combine_value_sets, IdTensor) +from convlab.dst.setsumbt.datasets.utils import (get_ontology_slots, ontology_add_values, + get_values_from_data, ontology_add_requestable_slots, + get_requestable_slots, load_dst_data, extract_dialogues, + combine_value_sets) transformers.logging.set_verbosity_error() -def convert_examples_to_features(data: list, - ontology: dict, - tokenizer: PreTrainedTokenizer, - max_turns: int = 12, - max_seq_len: int = 64) -> dict: - """ - Convert dialogue examples to model input features and labels - - Args: - data (list): List of all extracted dialogues - ontology (dict): Ontology dictionary containing slots, slot descriptions and - possible value sets including requests - tokenizer (PreTrainedTokenizer): Tokenizer for the encoder model used - max_turns (int): Maximum numbers of turns in a dialogue - max_seq_len (int): Maximum number of tokens in a dialogue turn - - Returns: - features (dict): All inputs and labels required to train the model - """ - features = dict() - ontology = deepcopy(ontology) - - # Get encoder input for system, user utterance pairs - input_feats = [] - for dial in tqdm(data): - dial_feats = [] - for turn in dial: - if len(turn['system_utterance']) == 0: - usr = turn['user_utterance'] - dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens=True, - max_length=max_seq_len, padding='max_length', - truncation='longest_first')) - else: - usr = turn['user_utterance'] - sys = turn['system_utterance'] - dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens=True, - max_length=max_seq_len, padding='max_length', - truncation='longest_first')) - # Truncate - if len(dial_feats) >= max_turns: - break - input_feats.append(dial_feats) - del dial_feats - - # Perform turn level padding - dial_ids = list() - for dial in data: - _ids = [turn['dialogue_id'] for turn in dial][:max_turns] - _ids += [''] * (max_turns - len(_ids)) - dial_ids.append(_ids) - input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) - for dial in input_feats] - if 'token_type_ids' in input_feats[0][0]: - token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) - for dial in input_feats] - else: - token_type_ids = None - if 'attention_mask' in input_feats[0][0]: - attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) - for dial in input_feats] - else: - attention_mask = None - del input_feats - - # Create torch data tensors - features['dialogue_ids'] = IdTensor(dial_ids) - features['input_ids'] = torch.tensor(input_ids) - features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None - features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None - del input_ids, token_type_ids, attention_mask - - # Extract all informable and requestable slots from the ontology - informable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain] 
- if ontology[domain][slot]['possible_values'] - and ontology[domain][slot]['possible_values'] != ['?']] - requestable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain] - if '?' in ontology[domain][slot]['possible_values']] - for slot in requestable_slots: - domain, slot = slot.split('-', 1) - ontology[domain][slot]['possible_values'].remove('?') - - # Extract a list of domains from the ontology slots - domains = list(set(informable_slots + requestable_slots)) - domains = list(set([slot.split('-', 1)[0] for slot in domains])) - - # Create slot labels - for domslot in tqdm(informable_slots): - labels = [] - for dial in data: - labs = [] - for turn in dial: - value = [v for d, substate in turn['state'].items() for s, v in substate.items() - if f'{d}-{s}' == domslot] - domain, slot = domslot.split('-', 1) - if turn['dataset_name'] in ontology[domain][slot]['dataset_names']: - value = value[0] if value else 'none' - else: - value = -1 - if value in ontology[domain][slot]['possible_values'] and value != -1: - value = ontology[domain][slot]['possible_values'].index(value) - else: - value = -1 # If value is not in ontology then we do not penalise the model - labs.append(value) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['state_labels-' + domslot] = labels - - # Create requestable slot labels - for domslot in tqdm(requestable_slots): - labels = [] - for dial in data: - labs = [] - for turn in dial: - domain, slot = domslot.split('-', 1) - if turn['dataset_name'] in ontology[domain][slot]['dataset_names']: - acts = [act['intent'] for act in turn['dialogue_acts'] - if act['domain'] == domain and act['slot'] == slot] - if acts: - act_ = acts[0] - if act_ == 'request': - labs.append(1) - else: - labs.append(0) - else: - labs.append(0) - else: - labs.append(-1) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['request_labels-' + domslot] = labels - - # General act labels (1-goodbye, 2-thank you) - labels = [] - for dial in tqdm(data): - labs = [] - for turn in dial: - acts = [act['intent'] for act in turn['dialogue_acts'] if act['intent'] in ['bye', 'thank']] - if acts: - if 'bye' in acts: - labs.append(1) - else: - labs.append(2) - else: - labs.append(0) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['general_act_labels'] = labels - - # Create active domain labels - for domain in tqdm(domains): - labels = [] - for dial in data: - labs = [] - for turn in dial: - possible_domains = list() - for dom in ontology: - for slt in ontology[dom]: - if turn['dataset_name'] in ontology[dom][slt]['dataset_names']: - possible_domains.append(dom) - - if domain in turn['active_domains']: - labs.append(1) - elif domain in possible_domains: - labs.append(0) - else: - labs.append(-1) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['active_domain_labels-' + domain] = labels - - del labels - - return features - - class UnifiedFormatDataset(Dataset): """ Class for preprocessing, and storing data easily from the Convlab3 unified format. 
     Attributes:
-        dataset_dict (dict): Dictionary containing all the data in dataset
+        set_type (str): Subset of the dataset to load (train, validation or test)
+        dataset_dicts (dict): Dictionary containing all the data in dataset
         ontology (dict): Set of all domain-slot-value triplets in the ontology of the model
+        ontology_embeddings (dict): Embeddings of the ontology domain-slot descriptions and candidate values
         features (dict): Set of numeric features containing all inputs and labels formatted for the SetSUMBT model
     """
     def __init__(self,
                  dataset_name: str,
                  set_type: str,
                  tokenizer: PreTrainedTokenizer,
+                 ontology_encoder,
                  max_turns: int = 12,
                  max_seq_len: int = 64,
                  train_ratio: float = 1.0,
                  seed: int = 0,
                  data: dict = None,
-                 ontology: dict = None):
+                 ontology: dict = None,
+                 ontology_embeddings: dict = None):
         """
         Args:
             dataset_name (str): Name of the dataset/s to load (multiple to be seperated by +)
             set_type (str): Subset of the dataset to load (train, validation or test)
             tokenizer (transformers tokenizer): Tokenizer for the encoder model used
+            ontology_encoder (transformers model): Ontology encoder model
             max_turns (int): Maximum numbers of turns in a dialogue
             max_seq_len (int): Maximum number of tokens in a dialogue turn
             train_ratio (float): Fraction of training data to use during training
             seed (int): Seed governing random order of ids for subsampling
             data (dict): Dataset features for loading from dict
             ontology (dict): Ontology dict for loading from dict
+            ontology_embeddings (dict): Ontology embeddings for loading from dict
         """
+        # Load data from dict if provided
         if data is not None:
+            self.set_type = set_type
             self.ontology = ontology
+            self.ontology_embeddings = ontology_embeddings
             self.features = data
+        # Load data from dataset if data is not provided
         else:
             if '+' in dataset_name:
                 dataset_args = [{"dataset_name": name} for name in dataset_name.split('+')]
             else:
                 dataset_args = [{"dataset_name": dataset_name}]
             self.dataset_dicts = [load_dataset(**dataset_args_) for dataset_args_ in dataset_args]
+            self.set_type = set_type
+
             self.ontology = get_ontology_slots(dataset_name)
             values = [get_values_from_data(dataset, set_type) for dataset in self.dataset_dicts]
             self.ontology = ontology_add_values(self.ontology, combine_value_sets(values), set_type)
             self.ontology = ontology_add_requestable_slots(self.ontology, get_requestable_slots(self.dataset_dicts))
+
+            tokenizer.set_setsumbt_ontology(self.ontology)
+            self.ontology_embeddings = ontology_encoder.get_slot_candidate_embeddings()
+
             if train_ratio != 1.0:
                 for dataset_args_ in dataset_args:
                     dataset_args_['dial_ids_order'] = seed
@@ -282,7 +103,7 @@ class UnifiedFormatDataset(Dataset):
             data = []
             for idx, data_ in enumerate(data_list):
                 data += extract_dialogues(data_, dataset_args[idx]["dataset_name"])
-            self.features = convert_examples_to_features(data, self.ontology, tokenizer, max_turns, max_seq_len)
+            self.features = tokenizer.encode(data, max_turns, max_seq_len)

     def __getitem__(self, index: int) -> dict:
         """
@@ -350,14 +171,15 @@ class UnifiedFormatDataset(Dataset):
                 if self.features[label] is not None}

     @classmethod
-    def from_datadict(cls, data: dict, ontology: dict):
-        return cls(None, None, None, data=data, ontology=ontology)
+    def from_datadict(cls, set_type: str, data: dict, ontology: dict, ontology_embeddings: dict):
+        return cls(None, set_type, None, None, data=data, ontology=ontology, ontology_embeddings=ontology_embeddings)


 def get_dataloader(dataset_name: str,
                    set_type: str,
                    batch_size: int,
                    tokenizer: PreTrainedTokenizer,
+
ontology_encoder, max_turns: int = 12, max_seq_len: int = 64, device='cpu', @@ -372,6 +194,7 @@ def get_dataloader(dataset_name: str, set_type (str): Subset of the dataset to load (train, validation or test) batch_size (int): Batch size for the dataloader tokenizer (transformers tokenizer): Tokenizer for the encoder model used + ontology_encoder (OntologyEncoder): Ontology encoder object max_turns (int): Maximum numbers of turns in a dialogue max_seq_len (int): Maximum number of tokens in a dialogue turn device (torch device): Device to map data to @@ -382,8 +205,8 @@ def get_dataloader(dataset_name: str, Returns: loader (torch dataloader): Dataloader to train and evaluate the setsumbt model ''' - data = UnifiedFormatDataset(dataset_name, set_type, tokenizer, max_turns, max_seq_len, train_ratio=train_ratio, - seed=seed) + data = UnifiedFormatDataset(dataset_name, set_type, tokenizer, ontology_encoder, max_turns, max_seq_len, + train_ratio=train_ratio, seed=seed) data.to(device) if resampled_size: @@ -418,6 +241,7 @@ def change_batch_size(loader: DataLoader, batch_size: int) -> DataLoader: return loader + def dataloader_sample_dialogues(loader: DataLoader, sample_size: int) -> DataLoader: """ Sample a subset of the dialogues in a dataloader diff --git a/convlab/dst/setsumbt/dataset/utils.py b/convlab/dst/setsumbt/datasets/utils.py similarity index 99% rename from convlab/dst/setsumbt/dataset/utils.py rename to convlab/dst/setsumbt/datasets/utils.py index 96773d6b9b181925b3004e4971e440d9c7720bfb..f227a569c5cfc782dda1fbedb61b5afbffa17cf5 100644 --- a/convlab/dst/setsumbt/dataset/utils.py +++ b/convlab/dst/setsumbt/datasets/utils.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,10 +16,9 @@ """Convlab3 Unified dataset data processing utilities""" import numpy -import pdb from convlab.util import load_ontology, load_dst_data, load_nlu_data -from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME +from convlab.dst.setsumbt.datasets.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME def get_ontology_slots(dataset_name: str) -> dict: @@ -424,6 +423,7 @@ class IdTensor: def extract_dialogues(data: list, dataset_name: str) -> list: """ Extract all dialogues from dataset + Args: data (list): List of all dialogues in a subset of the data dataset_name (str): Name of the dataset to which the dialogues belongs diff --git a/convlab/dst/setsumbt/dataset/value_maps.py b/convlab/dst/setsumbt/datasets/value_maps.py similarity index 96% rename from convlab/dst/setsumbt/dataset/value_maps.py rename to convlab/dst/setsumbt/datasets/value_maps.py index 619600a7b0a57096918058ff117aa2ca5aac864a..d4ef64a0e21839e3ecade16ded4c03aea738fa98 100644 --- a/convlab/dst/setsumbt/dataset/value_maps.py +++ b/convlab/dst/setsumbt/datasets/value_maps.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -47,4 +47,4 @@ DOMAINS_MAP = {'Alarm_1': 'alarm', 'Banks_1': 'banks', 'Banks_2': 'banks', 'Buse # Generic value sets for quantity and time slots QUANTITIES = ['1', '2', '3', '4', '5', '6', '7', '8', 
'9', '10 or more'] TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)] -TIME = ['%02i:%02i' % t for l in TIME for t in l] \ No newline at end of file +TIME = ['%02i:%02i' % t for l in TIME for t in l] diff --git a/convlab/dst/setsumbt/distillation_setup.py b/convlab/dst/setsumbt/distillation_setup.py deleted file mode 100644 index 2279e22265ea417ebe9a13e63837a625f858e73d..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/distillation_setup.py +++ /dev/null @@ -1,277 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Get ensemble predictions and build distillation dataloaders""" - -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser -import os -import json - -import torch -from torch.utils.data import DataLoader, RandomSampler, SequentialSampler -from tqdm import tqdm - -from convlab.dst.setsumbt.dataset.unified_format import UnifiedFormatDataset, change_batch_size -from convlab.dst.setsumbt.modeling import EnsembleSetSUMBT -from convlab.dst.setsumbt.modeling import training - -DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' - - -def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size: int = 3) -> DataLoader: - """ - Build dataloader from ensemble prediction data - - Args: - data: Dictionary of ensemble predictions - ontology: Data ontology - set_type: Data subset (train/validation/test) - batch_size: Number of dialogues per batch - - Returns: - loader: Data loader object - """ - data = flatten_data(data) - data = do_label_padding(data) - data = UnifiedFormatDataset.from_datadict(data, ontology) - if set_type == 'train': - sampler = RandomSampler(data) - else: - sampler = SequentialSampler(data) - - loader = DataLoader(data, sampler=sampler, batch_size=batch_size) - return loader - - -def do_label_padding(data: dict) -> dict: - """ - Add padding to the ensemble predictions (used as labels in distillation) - - Args: - data: Dictionary of ensemble predictions - - Returns: - data: Padded ensemble predictions - """ - if 'attention_mask' in data: - dialogs, turns = torch.where(data['attention_mask'].sum(-1) == 0.0) - else: - dialogs, turns = torch.where(data['input_ids'].sum(-1) == 0.0) - - for key in data: - if key not in ['input_ids', 'attention_mask', 'token_type_ids']: - data[key][dialogs, turns] = -1 - - return data - - -def flatten_data(data: dict) -> dict: - """ - Map data to flattened feature format used in training - Args: - data: Ensemble prediction data - - Returns: - data: Flattened ensemble prediction data - """ - data_new = dict() - for label, feats in data.items(): - if type(feats) == dict: - for label_, feats_ in feats.items(): - data_new[label + '-' + label_] = feats_ - else: - data_new[label] = feats - - return data_new - - -def get_ensemble_distributions(args): - """ - Load data and get ensemble predictions - Args: - args: Runtime arguments - """ - device = 
DEVICE - - model = EnsembleSetSUMBT.from_pretrained(args.model_path) - model = model.to(device) - - print('Model Loaded!') - - dataloader = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader') - database = os.path.join(args.model_path, 'database', f'{args.set_type}.db') - - dataloader = torch.load(dataloader) - database = torch.load(database) - - if dataloader.batch_size != args.batch_size: - dataloader = change_batch_size(dataloader, args.batch_size) - - training.set_ontology_embeddings(model, database) - - print('Environment set up.') - - input_ids = [] - token_type_ids = [] - attention_mask = [] - state_labels = {slot: [] for slot in model.informable_slot_ids} - request_labels = {slot: [] for slot in model.requestable_slot_ids} - active_domain_labels = {domain: [] for domain in model.domain_ids} - general_act_labels = [] - - is_noisy = [] if 'is_noisy' in dataloader.dataset.features else None - - belief_state = {slot: [] for slot in model.informable_slot_ids} - request_probs = {slot: [] for slot in model.requestable_slot_ids} - active_domain_probs = {domain: [] for domain in model.domain_ids} - general_act_probs = [] - model.eval() - for batch in tqdm(dataloader, desc='Batch:'): - ids = batch['input_ids'] - tt_ids = batch['token_type_ids'] if 'token_type_ids' in batch else None - mask = batch['attention_mask'] if 'attention_mask' in batch else None - - if 'is_noisy' in batch: - is_noisy.append(batch['is_noisy']) - - input_ids.append(ids) - token_type_ids.append(tt_ids) - attention_mask.append(mask) - - ids = ids.to(device) - tt_ids = tt_ids.to(device) if tt_ids is not None else None - mask = mask.to(device) if mask is not None else None - - for slot in state_labels: - state_labels[slot].append(batch['state_labels-' + slot]) - if model.config.predict_actions: - for slot in request_labels: - request_labels[slot].append(batch['request_labels-' + slot]) - for domain in active_domain_labels: - active_domain_labels[domain].append(batch['active_domain_labels-' + domain]) - general_act_labels.append(batch['general_act_labels']) - - with torch.no_grad(): - p, p_req, p_dom, p_gen, _ = model(ids, mask, tt_ids, reduction=args.reduction) - - for slot in belief_state: - belief_state[slot].append(p[slot].cpu()) - if model.config.predict_actions: - for slot in request_probs: - request_probs[slot].append(p_req[slot].cpu()) - for domain in active_domain_probs: - active_domain_probs[domain].append(p_dom[domain].cpu()) - general_act_probs.append(p_gen.cpu()) - - input_ids = torch.cat(input_ids, 0) if input_ids[0] is not None else None - token_type_ids = torch.cat(token_type_ids, 0) if token_type_ids[0] is not None else None - attention_mask = torch.cat(attention_mask, 0) if attention_mask[0] is not None else None - is_noisy = torch.cat(is_noisy, 0) if is_noisy is not None else None - - state_labels = {slot: torch.cat(l, 0) for slot, l in state_labels.items()} - if model.config.predict_actions: - request_labels = {slot: torch.cat(l, 0) for slot, l in request_labels.items()} - active_domain_labels = {domain: torch.cat(l, 0) for domain, l in active_domain_labels.items()} - general_act_labels = torch.cat(general_act_labels, 0) - - belief_state = {slot: torch.cat(p, 0) for slot, p in belief_state.items()} - if model.config.predict_actions: - request_probs = {slot: torch.cat(p, 0) for slot, p in request_probs.items()} - active_domain_probs = {domain: torch.cat(p, 0) for domain, p in active_domain_probs.items()} - general_act_probs = torch.cat(general_act_probs, 0) - - data = {'input_ids': 
input_ids} - if token_type_ids is not None: - data['token_type_ids'] = token_type_ids - if attention_mask is not None: - data['attention_mask'] = attention_mask - if is_noisy is not None: - data['is_noisy'] = is_noisy - data['state_labels'] = state_labels - data['belief_state'] = belief_state - if model.config.predict_actions: - data['request_labels'] = request_labels - data['active_domain_labels'] = active_domain_labels - data['general_act_labels'] = general_act_labels - data['request_probs'] = request_probs - data['active_domain_probs'] = active_domain_probs - data['general_act_probs'] = general_act_probs - - file = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data') - torch.save(data, file) - - -def ensemble_distribution_data_to_predictions_format(model_path: str, set_type: str): - """ - Convert ensemble predictions to predictions file format. - - Args: - model_path: Path to ensemble location. - set_type: Evaluation dataset (train/dev/test). - """ - data = torch.load(os.path.join(model_path, 'dataloaders', f"{set_type}.data")) - - # Get oracle labels - if 'request_probs' in data: - data_new = {'state_labels': data['state_labels'], - 'request_labels': data['request_labels'], - 'active_domain_labels': data['active_domain_labels'], - 'general_act_labels': data['general_act_labels']} - else: - data_new = {'state_labels': data['state_labels']} - - # Marginalising across ensemble distributions - data_new['belief_states'] = {slot: distribution.mean(-2) for slot, distribution in data['belief_state'].items()} - if 'request_probs' in data: - data_new['request_probs'] = {slot: distribution.mean(-1) - for slot, distribution in data['request_probs'].items()} - data_new['active_domain_probs'] = {domain: distribution.mean(-1) - for domain, distribution in data['active_domain_probs'].items()} - data_new['general_act_probs'] = data['general_act_probs'].mean(-2) - - # Save predictions file - predictions_dir = os.path.join(model_path, 'predictions') - if not os.path.exists(predictions_dir): - os.mkdir(predictions_dir) - torch.save(data_new, os.path.join(predictions_dir, f"{set_type}.predictions")) - - -if __name__ == "__main__": - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--model_path', type=str) - parser.add_argument('--set_type', type=str) - parser.add_argument('--batch_size', type=int, default=3) - parser.add_argument('--reduction', type=str, default='none') - parser.add_argument('--get_ensemble_distributions', action='store_true') - parser.add_argument('--convert_distributions_to_predictions', action='store_true') - parser.add_argument('--build_dataloaders', action='store_true') - args = parser.parse_args() - - if args.get_ensemble_distributions: - get_ensemble_distributions(args) - if args.convert_distributions_to_predictions: - ensemble_distribution_data_to_predictions_format(args.model_path, args.set_type) - if args.build_dataloaders: - path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data') - data = torch.load(path) - - reader = open(os.path.join(args.model_path, 'database', f'{args.set_type}.json'), 'r') - ontology = json.load(reader) - reader.close() - - loader = get_loader(data, ontology, args.set_type, args.batch_size) - - path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader') - torch.save(loader, path) diff --git a/convlab/dst/setsumbt/do/evaluate.py b/convlab/dst/setsumbt/do/evaluate.py deleted file mode 100644 index 
2fe351b3d5c2af187da58ffcc46e8184013bbcdb..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/do/evaluate.py +++ /dev/null @@ -1,296 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Run SetSUMBT Calibration""" - -import logging -import os - -import torch -from transformers import (BertModel, BertConfig, BertTokenizer, - RobertaModel, RobertaConfig, RobertaTokenizer) - -from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT -from convlab.dst.setsumbt.dataset import unified_format -from convlab.dst.setsumbt.dataset import ontology as embeddings -from convlab.dst.setsumbt.utils import get_args, update_args -from convlab.dst.setsumbt.modeling import evaluation_utils -from convlab.dst.setsumbt.loss.uncertainty_measures import ece, jg_ece, l2_acc -from convlab.dst.setsumbt.modeling import training - - -# Available model -MODELS = { - 'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer), - 'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer) -} - - -def main(args=None, config=None): - # Get arguments - if args is None: - args, config = get_args(MODELS) - - if args.model_type in MODELS: - SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type] - else: - raise NameError('NotImplemented') - - # Set up output directory - OUTPUT_DIR = args.output_dir - args.output_dir = OUTPUT_DIR - if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')): - os.mkdir(os.path.join(OUTPUT_DIR, 'predictions')) - - # Set pretrained model path to the trained checkpoint - paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else [] - if 'pytorch_model.bin' in paths and 'config.json' in paths: - args.model_name_or_path = args.output_dir - config = ConfigClass.from_pretrained(args.model_name_or_path) - else: - paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p] - if paths: - paths = paths[0] - args.model_name_or_path = paths - config = ConfigClass.from_pretrained(args.model_name_or_path) - - args = update_args(args, config) - - # Create logger - global logger - logger = logging.getLogger(__name__) - logger.setLevel(logging.INFO) - - formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y') - - fh = logging.FileHandler(args.logging_path) - fh.setLevel(logging.INFO) - fh.setFormatter(formatter) - logger.addHandler(fh) - - # Get device - if torch.cuda.is_available() and args.n_gpu > 0: - device = torch.device('cuda') - else: - device = torch.device('cpu') - args.n_gpu = 0 - - if args.n_gpu == 0: - args.fp16 = False - - # Set up model training/evaluation - evaluation_utils.set_seed(args) - - # Perform tasks - if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')): - pred = torch.load(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')) - state_labels = pred['state_labels'] - belief_states = 
pred['belief_states'] - if 'request_labels' in pred: - request_labels = pred['request_labels'] - request_probs = pred['request_probs'] - active_domain_labels = pred['active_domain_labels'] - active_domain_probs = pred['active_domain_probs'] - general_act_labels = pred['general_act_labels'] - general_act_probs = pred['general_act_probs'] - else: - request_probs = None - del pred - else: - # Get training batch loaders and ontology embeddings - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): - test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - if test_dataloader.batch_size != args.test_batch_size: - test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size) - else: - tokenizer = Tokenizer(config.candidate_embedding_model_name) - test_dataloader = unified_format.get_dataloader(args.dataset, 'test', - args.test_batch_size, tokenizer, args.max_dialogue_len, - config.max_turn_len) - torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')): - test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db')) - else: - encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name) - test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, - 'test', args, tokenizer, encoder) - - # Initialise Model - model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config) - model = model.to(device) - - training.set_ontology_embeddings(model, test_slots) - - belief_states = evaluation_utils.get_predictions(args, model, device, test_dataloader) - state_labels = belief_states[1] - request_probs = belief_states[2] - request_labels = belief_states[3] - active_domain_probs = belief_states[4] - active_domain_labels = belief_states[5] - general_act_probs = belief_states[6] - general_act_labels = belief_states[7] - belief_states = belief_states[0] - out = {'belief_states': belief_states, 'state_labels': state_labels, 'request_probs': request_probs, - 'request_labels': request_labels, 'active_domain_probs': active_domain_probs, - 'active_domain_labels': active_domain_labels, 'general_act_probs': general_act_probs, - 'general_act_labels': general_act_labels} - torch.save(out, os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')) - - # Calculate calibration metrics - jg = jg_ece(belief_states, state_labels, 10) - logger.info('Joint Goal ECE: %f' % jg) - - jg_acc = 0.0 - padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0 - padding = (padding == len(state_labels)) - padding = padding.reshape(-1) - for slot in belief_states: - p_ = belief_states[slot] - gold = state_labels[slot] - - pred = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1) - acc = [lab in s for lab, s, pad in zip(gold.reshape(-1), pred, padding) if not pad] - acc = torch.tensor(acc).float() - - jg_acc += acc - - n_turns = jg_acc.size(0) - jg_acc = sum((jg_acc / len(belief_states)).int()).float() - - jg_acc /= n_turns - - logger.info(f'Joint Goal Accuracy: {jg_acc}') - - l2 = l2_acc(belief_states, state_labels, remove_belief=False) - logger.info(f'Model L2 Norm Goal Accuracy: {l2}') - l2 = l2_acc(belief_states, state_labels, remove_belief=True) - logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}') - - padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0 - padding = (padding == 
len(state_labels)) - padding = padding.reshape(-1) - - tp, fp, fn, tn, n = 0.0, 0.0, 0.0, 0.0, 0.0 - for slot in belief_states: - p_ = belief_states[slot] - gold = state_labels[slot].reshape(-1) - p_ = p_.reshape(-1, p_.size(-1)) - - p_ = p_[~padding].argmax(-1) - gold = gold[~padding] - - tp += (p_ == gold)[gold != 0].int().sum().item() - fp += (p_ != 0)[gold == 0].int().sum().item() - fp += (p_ != gold)[gold != 0].int().sum().item() - fp -= (p_ == 0)[gold != 0].int().sum().item() - fn += (p_ == 0)[gold != 0].int().sum().item() - tn += (p_ == 0)[gold == 0].int().sum().item() - n += p_.size(0) - - acc = (tp + tn) / n - prec = tp / (tp + fp) - rec = tp / (tp + fn) - f1 = 2 * (prec * rec) / (prec + rec) - - logger.info(f"Slot Accuracy: {acc}, Slot F1: {f1}, Slot Precision: {prec}, Slot Recall: {rec}") - - if request_probs is not None: - tp, fp, fn = 0.0, 0.0, 0.0 - for slot in request_probs: - p = request_probs[slot] - l = request_labels[slot] - - tp += (p.round().int() * (l == 1)).reshape(-1).float() - fp += (p.round().int() * (l == 0)).reshape(-1).float() - fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float() - tp /= len(request_probs) - fp /= len(request_probs) - fn /= len(request_probs) - f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum())) - logger.info('Request F1 Score: %f' % f1.item()) - - for slot in request_probs: - p = request_probs[slot] - p = p.unsqueeze(-1) - p = torch.cat((1 - p, p), -1) - request_probs[slot] = p - jg = jg_ece(request_probs, request_labels, 10) - logger.info('Request Joint Goal ECE: %f' % jg) - - tp, fp, fn = 0.0, 0.0, 0.0 - for dom in active_domain_probs: - p = active_domain_probs[dom] - l = active_domain_labels[dom] - - tp += (p.round().int() * (l == 1)).reshape(-1).float() - fp += (p.round().int() * (l == 0)).reshape(-1).float() - fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float() - tp /= len(active_domain_probs) - fp /= len(active_domain_probs) - fn /= len(active_domain_probs) - f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum())) - logger.info('Domain F1 Score: %f' % f1.item()) - - for dom in active_domain_probs: - p = active_domain_probs[dom] - p = p.unsqueeze(-1) - p = torch.cat((1 - p, p), -1) - active_domain_probs[dom] = p - jg = jg_ece(active_domain_probs, active_domain_labels, 10) - logger.info('Domain Joint Goal ECE: %f' % jg) - - tp = ((general_act_probs.argmax(-1) > 0) * - (general_act_labels > 0)).reshape(-1).float().sum() - fp = ((general_act_probs.argmax(-1) > 0) * - (general_act_labels == 0)).reshape(-1).float().sum() - fn = ((general_act_probs.argmax(-1) == 0) * - (general_act_labels > 0)).reshape(-1).float().sum() - f1 = tp / (tp + 0.5 * (fp + fn)) - logger.info('General Act F1 Score: %f' % f1.item()) - - err = ece(general_act_probs.reshape(-1, general_act_probs.size(-1)), - general_act_labels.reshape(-1), 10) - logger.info('General Act ECE: %f' % err) - - for slot in request_probs: - p = request_probs[slot].unsqueeze(-1) - request_probs[slot] = torch.cat((1 - p, p), -1) - - l2 = l2_acc(request_probs, request_labels, remove_belief=False) - logger.info(f'Model L2 Norm Request Accuracy: {l2}') - l2 = l2_acc(request_probs, request_labels, remove_belief=True) - logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}') - - for slot in active_domain_probs: - p = active_domain_probs[slot].unsqueeze(-1) - active_domain_probs[slot] = torch.cat((1 - p, p), -1) - - l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=False) - logger.info(f'Model L2 Norm Domain Accuracy: {l2}') - l2 = 
l2_acc(active_domain_probs, active_domain_labels, remove_belief=True) - logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}') - - general_act_labels = {'general': general_act_labels} - general_act_probs = {'general': general_act_probs} - - l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False) - logger.info(f'Model L2 Norm General Act Accuracy: {l2}') - l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=True) - logger.info(f'Binary Model L2 Norm General Act Accuracy: {l2}') - - -if __name__ == "__main__": - main() diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py deleted file mode 100644 index 21949e728aa03d261dbb901e64fbb73bfd662d13..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/do/nbt.py +++ /dev/null @@ -1,328 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Run SetSUMBT training/eval""" - -import logging -import os -from shutil import copy2 as copy -import json -from copy import deepcopy -import pdb - -import torch -import transformers -from transformers import (BertModel, BertConfig, BertTokenizer, - RobertaModel, RobertaConfig, RobertaTokenizer) -from tensorboardX import SummaryWriter -from tqdm import tqdm - -from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT -from convlab.dst.setsumbt.dataset import unified_format -from convlab.dst.setsumbt.modeling import training -from convlab.dst.setsumbt.dataset import ontology as embeddings -from convlab.dst.setsumbt.utils import get_args, update_args -from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble -from convlab.util.custom_util import model_downloader - - -# Available model -MODELS = { - 'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer), - 'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer) -} - - -def main(args=None, config=None): - # Get arguments - if args is None: - args, config = get_args(MODELS) - - if args.model_type in MODELS: - SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type] - else: - raise NameError('NotImplemented') - - # Set up output directory - OUTPUT_DIR = args.output_dir - - if not os.path.exists(OUTPUT_DIR): - if "http" not in OUTPUT_DIR: - os.makedirs(OUTPUT_DIR) - os.mkdir(os.path.join(OUTPUT_DIR, 'database')) - os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders')) - else: - # Get path /.../convlab/dst/setsumbt/multiwoz/models - download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - download_path = os.path.join(download_path, 'models') - if not os.path.exists(download_path): - os.mkdir(download_path) - model_downloader(download_path, OUTPUT_DIR) - # Downloadable model path format http://.../model_name.zip - OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '') - OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR) - - args.tensorboard_path = 
os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1]) - args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1]) - os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders')) - args.output_dir = OUTPUT_DIR - - # Set pretrained model path to the trained checkpoint - paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else [] - if 'pytorch_model.bin' in paths and 'config.json' in paths: - args.model_name_or_path = args.output_dir - config = ConfigClass.from_pretrained(args.model_name_or_path, - local_files_only=args.transformers_local_files_only) - else: - paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p] - if paths: - paths = paths[0] - args.model_name_or_path = paths - config = ConfigClass.from_pretrained(args.model_name_or_path, - local_files_only=args.transformers_local_files_only) - - args = update_args(args, config) - - # Create TensorboardX writer - tb_writer = SummaryWriter(logdir=args.tensorboard_path) - - # Create logger - global logger - logger = logging.getLogger(__name__) - logger.setLevel(logging.INFO) - - formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y') - - fh = logging.FileHandler(args.logging_path) - fh.setLevel(logging.INFO) - fh.setFormatter(formatter) - logger.addHandler(fh) - - # Get device - if torch.cuda.is_available() and args.n_gpu > 0: - device = torch.device('cuda') - else: - device = torch.device('cpu') - args.n_gpu = 0 - - if args.n_gpu == 0: - args.fp16 = False - - # Initialise Model - transformers.utils.logging.set_verbosity_info() - model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config, - local_files_only=args.transformers_local_files_only) - model = model.to(device) - - # Create Tokenizer and embedding model for Data Loaders and ontology - encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name, - local_files_only=args.transformers_local_files_only) - tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config, - local_files_only=args.transformers_local_files_only) - - # Set up model training/evaluation - training.set_logger(logger, tb_writer) - training.set_seed(args) - embeddings.set_seed(args) - - transformers.utils.logging.set_verbosity_error() - if args.ensemble_size > 1: - # Build all dataloaders - train_dataloader = unified_format.get_dataloader(args.dataset, - 'train', - args.train_batch_size, - tokenizer, - args.max_dialogue_len, - args.max_turn_len, - train_ratio=args.dataset_train_ratio, - seed=args.seed) - torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) - dev_dataloader = unified_format.get_dataloader(args.dataset, - 'validation', - args.dev_batch_size, - tokenizer, - args.max_dialogue_len, - args.max_turn_len, - train_ratio=args.dataset_train_ratio, - seed=args.seed) - torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - test_dataloader = unified_format.get_dataloader(args.dataset, - 'test', - args.test_batch_size, - tokenizer, - args.max_dialogue_len, - args.max_turn_len, - train_ratio=args.dataset_train_ratio, - seed=args.seed) - torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - - embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology, 'train', args, tokenizer, encoder) - embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, 'dev', args, tokenizer, encoder) - 
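The slot-candidate embedding calls here precompute one vector per slot description and candidate value using the candidate encoder; these tensors are what later get stored in the `{set_type}.db` files. A minimal sketch of the idea, assuming a HuggingFace tokenizer/encoder pair (the mean pooling shown is illustrative, not necessarily the exact convlab implementation):

import torch

def embed_candidates(values: list, tokenizer, encoder) -> torch.Tensor:
    # Mean-pooled encoder states for a list of candidate value strings.
    batch = tokenizer(values, padding=True, return_tensors='pt')
    with torch.no_grad():
        hidden = encoder(**batch).last_hidden_state       # [n_values, seq_len, dim]
    mask = batch['attention_mask'].unsqueeze(-1).float()
    return (hidden * mask).sum(1) / mask.sum(1)           # [n_values, dim]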
embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, 'test', args, tokenizer, encoder) - - setup_ensemble(OUTPUT_DIR, args.ensemble_size) - - logger.info(f'Building {args.ensemble_size} resampled dataloaders each of size {args.data_sampling_size}.') - dataloaders = [unified_format.dataloader_sample_dialogues(deepcopy(train_dataloader), args.data_sampling_size) - for _ in tqdm(range(args.ensemble_size))] - logger.info('Dataloaders built.') - - for i, loader in enumerate(dataloaders): - path = os.path.join(OUTPUT_DIR, 'ens-%i' % i) - if not os.path.exists(path): - os.mkdir(path) - path = os.path.join(path, 'dataloaders', 'train.dataloader') - torch.save(loader, path) - logger.info('Dataloaders saved.') - - # Do not perform standard training after ensemble setup is created - return 0 - - # Perform tasks - # TRAINING - if args.do_train: - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')): - train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) - if train_dataloader.batch_size != args.train_batch_size: - train_dataloader = unified_format.change_batch_size(train_dataloader, args.train_batch_size) - else: - if args.data_sampling_size <= 0: - args.data_sampling_size = None - train_dataloader = unified_format.get_dataloader(args.dataset, - 'train', - args.train_batch_size, - tokenizer, - args.max_dialogue_len, - config.max_turn_len, - resampled_size=args.data_sampling_size, - train_ratio=args.dataset_train_ratio, - seed=args.seed) - torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) - - # Get training batch loaders and ontology embeddings - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')): - train_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'train.db')) - else: - train_slots = embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology, - 'train', args, tokenizer, encoder) - - # Get development set batch loaders and ontology embeddings - if args.do_eval: - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')): - dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - if dev_dataloader.batch_size != args.dev_batch_size: - dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size) - else: - dev_dataloader = unified_format.get_dataloader(args.dataset, - 'validation', - args.dev_batch_size, - tokenizer, - args.max_dialogue_len, - config.max_turn_len) - torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')): - dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db')) - else: - dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, - 'dev', args, tokenizer, encoder) - else: - dev_dataloader = None - dev_slots = None - - # Load model ontology - training.set_ontology_embeddings(model, train_slots) - - # TRAINING
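The resampled dataloaders built in the ensemble branch above give each ensemble member its own bootstrapped view of the training dialogues. The core of dialogue-level resampling can be sketched as follows (helper name hypothetical; convlab's `dataloader_sample_dialogues` operates on the dataloader itself, and whether it samples with or without replacement follows that implementation):

import random

def sample_dialogues(dialogues: list, sample_size: int, seed: int = 0) -> list:
    # Bootstrap sample: draw dialogues with replacement.
    rng = random.Random(seed)
    return [rng.choice(dialogues) for _ in range(sample_size)]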
- training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots) - - # Copy final best model to the output dir - checkpoints = os.listdir(OUTPUT_DIR) - checkpoints = [p for p in checkpoints if 'checkpoint' in p] - checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints]) - best_checkpoint = os.path.join(OUTPUT_DIR, f'checkpoint-{checkpoints[-1]}') - copy(os.path.join(best_checkpoint, 'pytorch_model.bin'), os.path.join(OUTPUT_DIR, 'pytorch_model.bin')) - copy(os.path.join(best_checkpoint, 'config.json'), os.path.join(OUTPUT_DIR, 'config.json')) - - # Load best model for evaluation - model = SetSumbtModel.from_pretrained(OUTPUT_DIR) - model = model.to(device) - - # Evaluation on the development set - if args.do_eval: - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')): - dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - if dev_dataloader.batch_size != args.dev_batch_size: - dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size) - else: - dev_dataloader = unified_format.get_dataloader(args.dataset, - 'validation', - args.dev_batch_size, - tokenizer, - args.max_dialogue_len, - config.max_turn_len) - torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')): - dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db')) - else: - dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, - 'dev', args, tokenizer, encoder) - - # Load model ontology - training.set_ontology_embeddings(model, dev_slots) - - # EVALUATION - jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss = training.evaluate(args, model, device, dev_dataloader) - training.log_info('dev', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1) - - # Evaluation on the test set - if args.do_test: - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): - test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - if test_dataloader.batch_size != args.test_batch_size: - test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size) - else: - test_dataloader = unified_format.get_dataloader(args.dataset, 'test', - args.test_batch_size, tokenizer, args.max_dialogue_len, - config.max_turn_len) - torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')): - test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db')) - else: - test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, - 'test', args, tokenizer, encoder) - - # Load model ontology - training.set_ontology_embeddings(model, test_slots) - - # TESTING - jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, output = training.evaluate(args, model, device, test_dataloader, - return_eval_output=True) - - if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')): - os.mkdir(os.path.join(OUTPUT_DIR, 'predictions')) - writer = open(os.path.join(OUTPUT_DIR, 'predictions', 'test.json'), 'w') - json.dump(output, writer) - writer.close() - - training.log_info('test', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1) - - tb_writer.close() - - -if __name__ == "__main__": - main() diff --git a/convlab/dst/setsumbt/get_golden_labels.py b/convlab/dst/setsumbt/get_golden_labels.py deleted file mode 100644 index 
7fb2841d0d503181119c791a7046fd7e0025d236..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/get_golden_labels.py +++ /dev/null @@ -1,138 +0,0 @@ -import json -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser -import os - -from tqdm import tqdm - -from convlab.util import load_dataset -from convlab.util import load_dst_data -from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME - - -def extract_data(dataset_names: str) -> list: - dataset_dicts = [load_dataset(dataset_name=name) for name in dataset_names.split('+')] - data = [] - for dataset_dict in dataset_dicts: - dataset = load_dst_data(dataset_dict, data_split='test', speaker='all', dialogue_acts=True, split_to_turn=False) - for dial in dataset['test']: - data.append(dial) - - return data - -def clean_state(state): - clean_state = dict() - for domain, subset in state.items(): - clean_state[domain] = {} - for slot, value in subset.items(): - # Remove pipe separated values - value = value.split('|') - - # Map values using value_map - for old, new in VALUE_MAP.items(): - value = [val.replace(old, new) for val in value] - value = '|'.join(value) - - # Map dontcare to "do not care" and empty to 'none' - value = value.replace('dontcare', 'do not care') - value = value if value else 'none' - - # Map quantity values to the integer quantity value - if 'people' in slot or 'duration' in slot or 'stay' in slot: - try: - if value not in ['do not care', 'none']: - value = int(value) - value = str(value) if value < 10 else QUANTITIES[-1] - except: - value = value - # Map time values to the most appropriate value in the standard time set - elif 'time' in slot or 'leave' in slot or 'arrive' in slot: - try: - if value not in ['do not care', 'none']: - # Strip after/before from time value - value = value.replace('after ', '').replace('before ', '') - # Extract hours and minutes from different possible formats - if ':' not in value and len(value) == 4: - h, m = value[:2], value[2:] - elif len(value) == 1: - h = int(value) - m = 0 - elif 'pm' in value: - h = int(value.replace('pm', '')) + 12 - m = 0 - elif 'am' in value: - h = int(value.replace('am', '')) - m = 0 - elif ':' in value: - h, m = value.split(':') - elif ';' in value: - h, m = value.split(';') - # Map to closest 5 minutes - if int(m) % 5 != 0: - m = round(int(m) / 5) * 5 - h = int(h) - if m == 60: - m = 0 - h += 1 - if h >= 24: - h -= 24 - # Set in standard 24 hour format - h, m = int(h), int(m) - value = '%02i:%02i' % (h, m) - except: - value = value - # Map boolean slots to yes/no value - elif 'parking' in slot or 'internet' in slot: - if value not in ['do not care', 'none']: - if value == 'free': - value = 'yes' - elif True in [v in value.lower() for v in ['yes', 'no']]: - value = [v for v in ['yes', 'no'] if v in value][0] - - value = value if value != 'none' else '' - - clean_state[domain][slot] = value - - return clean_state - -def extract_states(data): - states_data = {} - for dial in data: - states = [] - for turn in dial['turns']: - if 'state' in turn: - state = clean_state(turn['state']) - states.append(state) - states_data[dial['dialogue_id']] = states - - return states_data - - -def get_golden_state(prediction, data): - state = data[prediction['dial_idx']][prediction['utt_idx']] - pred = prediction['predictions']['state'] - pred = {domain: {slot: pred.get(DOMAINS_MAP.get(domain, domain.lower()), dict()).get(slot, '') - for slot in state[domain]} for domain in state} - prediction['state'] = state -
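The time normalisation in `clean_state` above snaps free-form times onto a 24-hour grid of 5-minute steps. Isolated as a standalone helper (name hypothetical), the rounding step is:

def round_time(h: int, m: int) -> str:
    # Round to the nearest 5 minutes, carrying over into hours and wrapping at 24h.
    m = round(m / 5) * 5
    if m == 60:
        m, h = 0, h + 1
    return '%02i:%02i' % (h % 24, m)

assert round_time(9, 58) == '10:00'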
prediction['predictions']['state'] = pred - - return prediction - - -if __name__ == "__main__": - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21") - parser.add_argument('--model_path', type=str, help='Path to model dir') - args = parser.parse_args() - - data = extract_data(args.dataset_name) - data = extract_states(data) - - reader = open(os.path.join(args.model_path, "predictions", "test.json"), 'r') - predictions = json.load(reader) - reader.close() - - predictions = [get_golden_state(pred, data) for pred in tqdm(predictions)] - - writer = open(os.path.join(args.model_path, "predictions", f"test_{args.dataset_name}.json"), 'w') - json.dump(predictions, writer) - writer.close() diff --git a/convlab/dst/setsumbt/loss/__init__.py b/convlab/dst/setsumbt/loss/__init__.py deleted file mode 100644 index 475f7646126ea03b630efcbbc688f86c5a8ec16e..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/loss/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from convlab.dst.setsumbt.loss.bayesian_matching import BayesianMatchingLoss, BinaryBayesianMatchingLoss -from convlab.dst.setsumbt.loss.kl_distillation import KLDistillationLoss, BinaryKLDistillationLoss -from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss -from convlab.dst.setsumbt.loss.endd_loss import RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss diff --git a/convlab/dst/setsumbt/loss/uncertainty_measures.py b/convlab/dst/setsumbt/loss/uncertainty_measures.py deleted file mode 100644 index 87c89dd31c724cc7d599230c6d4a15faee9b680e..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/loss/uncertainty_measures.py +++ /dev/null @@ -1,222 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
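The module removed below implements the calibration metrics used by the evaluation script above. For orientation: expected calibration error over B confidence bins is ECE = sum_b (n_b / N) * |acc(b) - conf(b)|, where n_b is the number of predictions falling in bin b, acc(b) its accuracy and conf(b) its mean confidence. The joint-goal variant (jg_ece) applies the same binning to the minimum slot confidence per turn, since a single mispredicted slot already breaks the joint goal.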
-"""Uncertainty evaluation metrics for dialogue belief tracking""" - -import torch - - -def fill_bins(n_bins: int, probs: torch.Tensor) -> list: - """ - Function to split observations into bins based on predictive probabilities - - Args: - n_bins (int): Number of bins - probs (Tensor): Predictive probabilities for the observations - - Returns: - bins (list): List of observation ids for each bin - """ - assert probs.dim() == 2 - probs = probs.max(-1)[0] - - step = 1.0 / n_bins - bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step) - bins = [] - for b in range(n_bins): - lower, upper = bin_ranges[b], bin_ranges[b + 1] - if b == 0: - ids = torch.where((probs >= lower) * (probs <= upper))[0] - else: - ids = torch.where((probs > lower) * (probs <= upper))[0] - bins.append(ids) - return bins - - -def bin_confidence(bins: list, probs: torch.Tensor) -> torch.Tensor: - """ - Compute the confidence score within each bin - - Args: - bins (list): List of observation ids for each bin - probs (Tensor): Predictive probabilities for the observations - - Returns: - scores (Tensor): Average confidence score within each bin - """ - probs = probs.max(-1)[0] - - scores = [] - for b in bins: - if b is not None: - scores.append(probs[b].mean()) - else: - scores.append(-1) - scores = torch.tensor(scores) - return scores - - -def bin_accuracy(bins: list, probs: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: - """ - Compute the accuracy score for observations in each bin - - Args: - bins (list): List of observation ids for each bin - probs (Tensor): Predictive probabilities for the observations - y_true (Tensor): Labels for the observations - - Returns: - acc (Tensor): Accuracies for the observations in each bin - """ - y_pred = probs.argmax(-1) - - acc = [] - for b in bins: - if b is not None: - p = y_pred[b] - acc_ = (p == y_true[b]).float() - acc_ = acc_[y_true[b] >= 0] - if acc_.size(0) >= 0: - acc.append(acc_.mean()) - else: - acc.append(-1) - else: - acc.append(-1) - acc = torch.tensor(acc) - return acc - - -def ece(probs: torch.Tensor, y_true: torch.Tensor, n_bins: int) -> float: - """ - Expected calibration error calculation - - Args: - probs (Tensor): Predictive probabilities for the observations - y_true (Tensor): Labels for the observations - n_bins (int): Number of bins - - Returns: - ece (float): Expected calibration error - """ - bins = fill_bins(n_bins, probs) - - scores = bin_confidence(bins, probs) - acc = bin_accuracy(bins, probs, y_true) - - n = probs.size(0) - bk = torch.tensor([b.size(0) for b in bins]) - - ece = torch.abs(scores - acc) * bk / n - ece = ece[acc >= 0.0] - ece = ece.sum().item() - - return ece - - -def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float: - """ - Joint goal expected calibration error calculation - - Args: - belief_state (dict): Belief state probabilities for the dialogue turns - y_true (dict): Labels for the state in dialogue turns - n_bins (int): Number of bins - - Returns: - ece (float): Joint goal expected calibration error - """ - y_pred = {slot: bs.reshape(-1, bs.size(-1)).argmax(-1) for slot, bs in belief_state.items()} - goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred} - goal_acc = sum([goal_acc[slot] for slot in goal_acc]) - goal_acc = (goal_acc == len(y_true)).int() - - # Confidence score is minimum across slots as a single bad predictions leads to incorrect prediction in state - scores = [bs.reshape(-1, bs.size(-1)).max(-1)[0].unsqueeze(0) for slot, bs in belief_state.items()] - scores = 
torch.cat(scores, 0).min(0)[0] - - bins = fill_bins(n_bins, scores.unsqueeze(-1)) - - conf = bin_confidence(bins, scores.unsqueeze(-1)) - - slot = [s for s in y_true][0] - acc = [] - for b in bins: - if b is not None: - acc_ = goal_acc[b] - acc_ = acc_[y_true[slot].reshape(-1)[b] >= 0] - if acc_.size(0) > 0: - acc.append(acc_.float().mean()) - else: - acc.append(-1) - else: - acc.append(-1) - acc = torch.tensor(acc) - - n = belief_state[slot].reshape(-1, belief_state[slot].size(-1)).size(0) - bk = torch.tensor([b.size(0) for b in bins]) - - ece = torch.abs(conf - acc) * bk / n - ece = ece[acc >= 0.0] - ece = ece.sum().item() - - return ece - - -def l2_acc(belief_state: dict, labels: dict, remove_belief: bool = False) -> float: - """ - Compute L2 Error of belief state prediction - - Args: - belief_state (dict): Belief state probabilities for the dialogue turns - labels (dict): Labels for the state in dialogue turns - remove_belief (bool): Convert belief state to dialogue state - - Returns: - err (float): L2 Error of belief state prediction - """ - # Get ids used for removing padding turns. - padding = labels[list(labels.keys())[0]].reshape(-1) - padding = torch.where(padding != -1)[0] - - state = [] - labs = [] - for slot, bs in belief_state.items(): - # Predictive Distribution - bs = bs.reshape(-1, bs.size(-1)).cuda() - # Replace distribution by a 1 hot prediction - if remove_belief: - bs_ = torch.zeros(bs.shape).float().cuda() - bs_[range(bs.size(0)), bs.argmax(-1)] = 1.0 - bs = bs_ - del bs_ - # Remove padding turns - lab = labels[slot].reshape(-1).cuda() - bs = bs[padding] - lab = lab[padding] - - # Target distribution - y = torch.zeros(bs.shape).cuda() - y[range(y.size(0)), lab] = 1.0 - - state.append(bs) - labs.append(y) - - # Concatenate all slots into a single belief state - state = torch.cat(state, -1) - labs = torch.cat(labs, -1) - - # Calculate L2-Error for each turn - err = torch.sqrt(((labs - state) ** 2).sum(-1)) - return err.mean() diff --git a/convlab/dst/setsumbt/modeling/__init__.py b/convlab/dst/setsumbt/modeling/__init__.py index 59f1439948421ac365e4602b7800c94d3b8b32dd..502db2810b33262ab7edb40412a93dfcb7ba0786 100644 --- a/convlab/dst/setsumbt/modeling/__init__.py +++ b/convlab/dst/setsumbt/modeling/__init__.py @@ -1,5 +1,16 @@ -from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT -from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT -from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT +from transformers import BertConfig, RobertaConfig +from convlab.dst.setsumbt.modeling.setsumbt_nbt import BertSetSUMBT, RobertaSetSUMBT, EnsembleSetSUMBT +from convlab.dst.setsumbt.modeling.ontology_encoder import OntologyEncoder from convlab.dst.setsumbt.modeling.temperature_scheduler import LinearTemperatureScheduler +from convlab.dst.setsumbt.modeling.trainer import SetSUMBTTrainer +from convlab.dst.setsumbt.modeling.tokenization import SetSUMBTTokenizer + +class BertSetSUMBTTokenizer(SetSUMBTTokenizer('bert')): pass +class RobertaSetSUMBTTokenizer(SetSUMBTTokenizer('roberta')): pass + +SetSUMBTModels = { + 'bert': (BertSetSUMBT, OntologyEncoder('bert'), BertConfig, BertSetSUMBTTokenizer), + 'roberta': (RobertaSetSUMBT, OntologyEncoder('roberta'), RobertaConfig, RobertaSetSUMBTTokenizer), + 'ensemble': (EnsembleSetSUMBT, None, None, None) +} diff --git a/convlab/dst/setsumbt/modeling/bert_nbt.py b/convlab/dst/setsumbt/modeling/bert_nbt.py deleted file mode 100644 index 
6762fb3891b4720c3889d8c0809b8791f3bf7633..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/modeling/bert_nbt.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""BERT SetSUMBT""" - -import torch -from torch.autograd import Variable -from transformers import BertModel, BertPreTrainedModel - -from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead - - -class BertSetSUMBT(BertPreTrainedModel): - - def __init__(self, config): - super(BertSetSUMBT, self).__init__(config) - self.config = config - - # Turn Encoder - self.bert = BertModel(config) - if config.freeze_encoder: - for p in self.bert.parameters(): - p.requires_grad = False - - self.setsumbt = SetSUMBTHead(config) - self.add_slot_candidates = self.setsumbt.add_slot_candidates - self.add_value_candidates = self.setsumbt.add_value_candidates - - def forward(self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor = None, - hidden_state: torch.Tensor = None, - state_labels: torch.Tensor = None, - request_labels: torch.Tensor = None, - active_domain_labels: torch.Tensor = None, - general_act_labels: torch.Tensor = None, - get_turn_pooled_representation: bool = False, - calculate_state_mutual_info: bool = False): - """ - Args: - input_ids: Input token ids - attention_mask: Input padding mask - token_type_ids: Token type indicator - hidden_state: Latent internal dialogue belief state - state_labels: Dialogue state labels - request_labels: User request action labels - active_domain_labels: Current active domain labels - general_act_labels: General user action labels - get_turn_pooled_representation: Return pooled representation of the current dialogue turn - calculate_state_mutual_info: Return mutual information in the dialogue state - - Returns: - out: Tuple containing loss, predictive distributions, model statistics and state mutual information - """ - - # Encode Dialogues - batch_size, dialogue_size, turn_size = input_ids.size() - input_ids = input_ids.reshape(-1, turn_size) - token_type_ids = token_type_ids.reshape(-1, turn_size) - attention_mask = attention_mask.reshape(-1, turn_size) - - bert_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - - attention_mask = attention_mask.float().unsqueeze(2) - attention_mask = attention_mask.repeat((1, 1, bert_output.last_hidden_state.size(-1))) - turn_embeddings = bert_output.last_hidden_state * attention_mask - turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1) - - if get_turn_pooled_representation: - return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask, - batch_size, dialogue_size, hidden_state, state_labels, - request_labels, active_domain_labels, general_act_labels, - calculate_state_mutual_info) + (bert_output.pooler_output,) - return self.setsumbt(turn_embeddings, bert_output.pooler_output, 
attention_mask, batch_size, - dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels, - general_act_labels, calculate_state_mutual_info) diff --git a/convlab/dst/setsumbt/modeling/ensemble_nbt.py b/convlab/dst/setsumbt/modeling/ensemble_nbt.py deleted file mode 100644 index 6d3d8035a4d6f47f2ea8551050ca8da682ea0376..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/modeling/ensemble_nbt.py +++ /dev/null @@ -1,180 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Ensemble SetSUMBT""" - -import os -from shutil import copy2 as copy - -import torch -from torch.nn import Module -from transformers import RobertaConfig, BertConfig - -from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT -from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT - -MODELS = {'bert': BertSetSUMBT, 'roberta': RobertaSetSUMBT} - - -class EnsembleSetSUMBT(Module): - """Ensemble SetSUMBT Model for joint ensemble prediction""" - - def __init__(self, config): - """ - Args: - config (configuration): Model configuration class - """ - super(EnsembleSetSUMBT, self).__init__() - self.config = config - - # Initialise ensemble members - model_cls = MODELS[self.config.model_type] - for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]: - setattr(self, attr, model_cls(config)) - - def _load(self, path: str): - """ - Load parameters - Args: - path: Location of model parameters - """ - for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]: - idx = attr.split('_', 1)[-1] - state_dict = torch.load(os.path.join(path, f'ens-{idx}/pytorch_model.bin')) - getattr(self, attr).load_state_dict(state_dict) - - def add_slot_candidates(self, slot_candidates: tuple): - """ - Add slots to the model ontology, the tuples should contain the slot embedding, informable value embeddings - and a request indicator, if the informable value embeddings is None the slot is not informable and if - the request indicator is false the slot is not requestable. 
- - Args: - slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator - """ - for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]: - getattr(self, attr).add_slot_candidates(slot_candidates) - self.requestable_slot_ids = self.model_0.setsumbt.requestable_slot_ids - self.informable_slot_ids = self.model_0.setsumbt.informable_slot_ids - self.domain_ids = self.model_0.setsumbt.domain_ids - - def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False): - """ - Add value candidates for a slot - - Args: - slot: Slot name - value_candidates: Value candidate embeddings - replace: If true existing value candidates are replaced - """ - for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]: - getattr(self, attr).add_value_candidates(slot, value_candidates, replace) - - def forward(self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor = None, - reduction: str = 'mean') -> tuple: - """ - Args: - input_ids: Input token ids - attention_mask: Input padding mask - token_type_ids: Token type indicator - reduction: Reduction of ensemble member predictive distributions (mean, none) - - Returns: - - """ - belief_state_probs = {slot: [] for slot in self.informable_slot_ids} - request_probs = {slot: [] for slot in self.requestable_slot_ids} - active_domain_probs = {dom: [] for dom in self.domain_ids} - general_act_probs = [] - for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]: - # Prediction from each ensemble member - b, r, d, g, _ = getattr(self, attr)(input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask) - for slot in belief_state_probs: - belief_state_probs[slot].append(b[slot].unsqueeze(-2)) - if self.config.predict_actions: - for slot in request_probs: - request_probs[slot].append(r[slot].unsqueeze(-1)) - for dom in active_domain_probs: - active_domain_probs[dom].append(d[dom].unsqueeze(-1)) - general_act_probs.append(g.unsqueeze(-2)) - - belief_state_probs = {slot: torch.cat(l, -2) for slot, l in belief_state_probs.items()} - if self.config.predict_actions: - request_probs = {slot: torch.cat(l, -1) for slot, l in request_probs.items()} - active_domain_probs = {dom: torch.cat(l, -1) for dom, l in active_domain_probs.items()} - general_act_probs = torch.cat(general_act_probs, -2) - else: - request_probs = {} - active_domain_probs = {} - general_act_probs = torch.tensor(0.0) - - # Apply reduction of ensemble to single posterior - if reduction == 'mean': - belief_state_probs = {slot: l.mean(-2) for slot, l in belief_state_probs.items()} - request_probs = {slot: l.mean(-1) for slot, l in request_probs.items()} - active_domain_probs = {dom: l.mean(-1) for dom, l in active_domain_probs.items()} - general_act_probs = general_act_probs.mean(-2) - elif reduction != 'none': - raise(NameError('Not Implemented!')) - - return belief_state_probs, request_probs, active_domain_probs, general_act_probs, _ - - - @classmethod - def from_pretrained(cls, path): - config_path = os.path.join(path, 'ens-0', 'config.json') - if not os.path.exists(config_path): - raise(NameError('Could not find config.json in model path.')) - - try: - config = RobertaConfig.from_pretrained(config_path) - except: - config = BertConfig.from_pretrained(config_path) - - config.ensemble_size = len([dir for dir in os.listdir(path) if 'ens-' in dir]) - - model = cls(config) - model._load(path) - - return model - - -def setup_ensemble(model_path: str, 
ensemble_size: int): - """ - Setup ensemble model directory structure. - - Args: - model_path: Path to ensemble model directory - ensemble_size: Number of ensemble members - """ - for i in range(ensemble_size): - path = os.path.join(model_path, f'ens-{i}') - if not os.path.exists(path): - os.mkdir(path) - os.mkdir(os.path.join(path, 'dataloaders')) - os.mkdir(os.path.join(path, 'database')) - # Add development set dataloader to each ensemble member directory - for set_type in ['dev']: - copy(os.path.join(model_path, 'dataloaders', f'{set_type}.dataloader'), - os.path.join(path, 'dataloaders', f'{set_type}.dataloader')) - # Add training and development set ontologies to each ensemble member directory - for set_type in ['train', 'dev']: - copy(os.path.join(model_path, 'database', f'{set_type}.db'), - os.path.join(path, 'database', f'{set_type}.db')) diff --git a/convlab/dst/setsumbt/modeling/ensemble_utils.py b/convlab/dst/setsumbt/modeling/ensemble_utils.py deleted file mode 100644 index 19f6abf81a4070b9498310adfab93d50f5a692f5..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/modeling/ensemble_utils.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Discriminative models calibration""" - -import random -import os - -import torch -import numpy as np -from torch.distributions import Categorical -from torch.nn.functional import kl_div -from torch.nn import Module -from tqdm import tqdm - - -# Load logger and tensorboard summary writer -def set_logger(logger_, tb_writer_): - global logger, tb_writer - logger = logger_ - tb_writer = tb_writer_ - - -# Set seeds -def set_seed(args): - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - if args.n_gpu > 0: - torch.cuda.manual_seed_all(args.seed) - logger.info('Seed set to %d.' % args.seed) - - -def build_train_loaders(args, tokenizer, dataset): - dataloaders = [dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len, - args.max_turn_len, resampled_size=args.data_sampling_size) - for _ in range(args.ensemble_size)] - return dataloaders diff --git a/convlab/dst/setsumbt/modeling/evaluation_utils.py b/convlab/dst/setsumbt/modeling/evaluation_utils.py deleted file mode 100644 index c73d4b6d32a485a2cf2b5948dbd6a9a4d7f346cb..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/modeling/evaluation_utils.py +++ /dev/null @@ -1,112 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Evaluation Utilities""" - -import random - -import torch -import numpy as np -from tqdm import tqdm - - -def set_seed(args): - """ - Set random seeds - - Args: - args (Arguments class): Arguments class containing seed and number of gpus to use - """ - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - if args.n_gpu > 0: - torch.cuda.manual_seed_all(args.seed) - - -def get_predictions(args, model, device: torch.device, dataloader: torch.utils.data.DataLoader) -> tuple: - """ - Get model predictions - - Args: - args: Runtime arguments - model: SetSUMBT Model - device: Torch device - dataloader: Dataloader containing eval data - """ - model.eval() - - belief_states = {slot: [] for slot in model.setsumbt.informable_slot_ids} - request_probs = {slot: [] for slot in model.setsumbt.requestable_slot_ids} - active_domain_probs = {dom: [] for dom in model.setsumbt.domain_ids} - general_act_probs = [] - state_labels = {slot: [] for slot in model.setsumbt.informable_slot_ids} - request_labels = {slot: [] for slot in model.setsumbt.requestable_slot_ids} - active_domain_labels = {dom: [] for dom in model.setsumbt.domain_ids} - general_act_labels = [] - epoch_iterator = tqdm(dataloader, desc="Iteration") - for step, batch in enumerate(epoch_iterator): - with torch.no_grad(): - input_ids = batch['input_ids'].to(device) - token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None - attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None - - p, p_req, p_dom, p_gen, _ = model(input_ids=input_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask) - - for slot in belief_states: - p_ = p[slot] - labs = batch['state_labels-' + slot].to(device) - - belief_states[slot].append(p_) - state_labels[slot].append(labs) - - if p_req is not None: - for slot in request_probs: - p_ = p_req[slot] - labs = batch['request_labels-' + slot].to(device) - - request_probs[slot].append(p_) - request_labels[slot].append(labs) - - for domain in active_domain_probs: - p_ = p_dom[domain] - labs = batch['active_domain_labels-' + domain].to(device) - - active_domain_probs[domain].append(p_) - active_domain_labels[domain].append(labs) - - general_act_probs.append(p_gen) - general_act_labels.append(batch['general_act_labels'].to(device)) - - for slot in belief_states: - belief_states[slot] = torch.cat(belief_states[slot], 0) - state_labels[slot] = torch.cat(state_labels[slot], 0) - if p_req is not None: - for slot in request_probs: - request_probs[slot] = torch.cat(request_probs[slot], 0) - request_labels[slot] = torch.cat(request_labels[slot], 0) - for domain in active_domain_probs: - active_domain_probs[domain] = torch.cat(active_domain_probs[domain], 0) - active_domain_labels[domain] = torch.cat(active_domain_labels[domain], 0) - general_act_probs = torch.cat(general_act_probs, 0) - general_act_labels = torch.cat(general_act_labels, 0) - else: - request_probs, request_labels, active_domain_probs, active_domain_labels = [None] * 4 - general_act_probs, general_act_labels = [None] * 2 - - out = (belief_states, 
state_labels, request_probs, request_labels) - out += (active_domain_probs, active_domain_labels, general_act_probs, general_act_labels) - return out diff --git a/convlab/dst/setsumbt/modeling/loss/__init__.py b/convlab/dst/setsumbt/modeling/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a5f76cdccf266f8ac702ad8a58e223d97fbc2c --- /dev/null +++ b/convlab/dst/setsumbt/modeling/loss/__init__.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Loss functions for SetSUMBT""" + +from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss + +from convlab.dst.setsumbt.modeling.loss.bayesian_matching import (BayesianMatchingLoss, + BinaryBayesianMatchingLoss) +from convlab.dst.setsumbt.modeling.loss.kl_distillation import KLDistillationLoss, BinaryKLDistillationLoss +from convlab.dst.setsumbt.modeling.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss +from convlab.dst.setsumbt.modeling.loss.endd_loss import (RKLDirichletMediatorLoss, + BinaryRKLDirichletMediatorLoss) + +LOSS_MAP = { + 'crossentropy': {'non-binary': CrossEntropyLoss, + 'binary': BCEWithLogitsLoss, + 'args': list()}, + 'bayesianmatching': {'non-binary': BayesianMatchingLoss, + 'binary': BinaryBayesianMatchingLoss, + 'args': ['kl_scaling_factor']}, + 'labelsmoothing': {'non-binary': LabelSmoothingLoss, + 'binary': BinaryLabelSmoothingLoss, + 'args': ['label_smoothing']}, + 'distillation': {'non-binary': KLDistillationLoss, + 'binary': BinaryKLDistillationLoss, + 'args': ['ensemble_smoothing']}, + 'distribution_distillation': {'non-binary': RKLDirichletMediatorLoss, + 'binary': BinaryRKLDirichletMediatorLoss, + 'args': []} +} + +def load(loss_function, binary=False): + """ + Load loss function + + Args: + loss_function (str): Loss function name + binary (bool): Whether to use binary loss function + + Returns: + torch.nn.Module: Loss function + """ + assert loss_function in LOSS_MAP + args_list = LOSS_MAP[loss_function]['args'] + loss_function = LOSS_MAP[loss_function]['binary' if binary else 'non-binary'] + + def __init__(ignore_index=-1, **kwargs): + args = {'ignore_index': ignore_index} if loss_function != BCEWithLogitsLoss else dict() + for arg, val in kwargs.items(): + if arg in args_list: + args[arg] = val + + return loss_function(**args) + + return __init__ diff --git a/convlab/dst/setsumbt/loss/bayesian_matching.py b/convlab/dst/setsumbt/modeling/loss/bayesian_matching.py similarity index 87% rename from convlab/dst/setsumbt/loss/bayesian_matching.py rename to convlab/dst/setsumbt/modeling/loss/bayesian_matching.py index 3e91444d60afeeb6e2ca54192dd2283810fc5135..66e37e6eddcc99457535564a6f49b4c11190a49c 100644 --- a/convlab/dst/setsumbt/loss/bayesian_matching.py +++ b/convlab/dst/setsumbt/modeling/loss/bayesian_matching.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, 
Düsseldorf +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,15 +23,15 @@ from torch.nn import Module class BayesianMatchingLoss(Module): """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation""" - def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module: + def __init__(self, kl_scaling_factor: float = 0.001, ignore_index: int = -1) -> Module: """ Args: - lamb (float): Weighting factor for the KL Divergence component + kl_scaling_factor (float): Weighting factor for the KL Divergence component ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. """ super(BayesianMatchingLoss, self).__init__() - self.lamb = lamb + self.lamb = kl_scaling_factor self.ignore_index = ignore_index def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor: @@ -46,7 +46,7 @@ class BayesianMatchingLoss(Module): """ # Assert input sizes assert inputs.dim() == 2 # Observations, predictive distribution - assert labels.dim() == 1 # Label for each observation + assert labels.dim() == 1 # Label for each observation assert labels.size(0) == inputs.size(0) # Equal number of observation # Confirm predictive distribution dimension @@ -88,13 +88,13 @@ class BayesianMatchingLoss(Module): class BinaryBayesianMatchingLoss(BayesianMatchingLoss): """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation""" - def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module: + def __init__(self, kl_scaling_factor: float = 0.001, ignore_index: int = -1) -> Module: """ Args: - lamb (float): Weighting factor for the KL Divergence component + kl_scaling_factor (float): Weighting factor for the KL Divergence component ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. """ - super(BinaryBayesianMatchingLoss, self).__init__(lamb, ignore_index) + super(BinaryBayesianMatchingLoss, self).__init__(kl_scaling_factor, ignore_index) def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor: """ diff --git a/convlab/dst/setsumbt/loss/endd_loss.py b/convlab/dst/setsumbt/modeling/loss/endd_loss.py similarity index 93% rename from convlab/dst/setsumbt/loss/endd_loss.py rename to convlab/dst/setsumbt/modeling/loss/endd_loss.py index 9bd794bf4569f54f5896e1e88ed1edeadc0fe1e2..f979cc66b668473b8b620cfae1885eef9717c049 100644 --- a/convlab/dst/setsumbt/loss/endd_loss.py +++ b/convlab/dst/setsumbt/modeling/loss/endd_loss.py @@ -1,3 +1,20 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
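The `load` factory in loss/__init__.py above hides the differing loss constructor signatures behind a single initialiser that filters keyword arguments against the 'args' list of the selected registry entry. A minimal, self-contained sketch of the same pattern, using only torch.nn losses so none of the convlab imports are needed (TOY_LOSS_MAP and toy_load are illustrative names, not part of the patch):

import torch
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss

# Toy registry mirroring the structure of LOSS_MAP above.
TOY_LOSS_MAP = {'crossentropy': {'non-binary': CrossEntropyLoss,
                                 'binary': BCEWithLogitsLoss,
                                 'args': []}}

def toy_load(name, binary=False):
    accepted = TOY_LOSS_MAP[name]['args']
    cls = TOY_LOSS_MAP[name]['binary' if binary else 'non-binary']

    def initialiser(ignore_index=-1, **kwargs):
        # BCEWithLogitsLoss takes no ignore_index, so it is only passed to the others.
        args = {} if cls is BCEWithLogitsLoss else {'ignore_index': ignore_index}
        args.update({k: v for k, v in kwargs.items() if k in accepted})
        return cls(**args)

    return initialiser

# Unsupported kwargs (kl_scaling_factor here) are silently dropped:
loss_fn = toy_load('crossentropy')(ignore_index=-1, kl_scaling_factor=0.001)
logits, labels = torch.randn(4, 3), torch.tensor([0, 2, 1, -1])
print(loss_fn(logits, labels))  # the label -1 is ignored

This is why SetSUMBTHead below can pass one uniform loss_args dict for every loss type.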
+"""Ensemble Distribution Distillation Loss Function (see https://arxiv.org/pdf/2105.06987.pdf for details)""" + import torch from torch.nn import Module from torch.nn.functional import kl_div diff --git a/convlab/dst/setsumbt/loss/kl_distillation.py b/convlab/dst/setsumbt/modeling/loss/kl_distillation.py similarity index 88% rename from convlab/dst/setsumbt/loss/kl_distillation.py rename to convlab/dst/setsumbt/modeling/loss/kl_distillation.py index 9aee234ab68054f2b4a83d6feb5e453384d89e94..6f3971aaad98482a7136ee2507322161cba3ffd0 100644 --- a/convlab/dst/setsumbt/loss/kl_distillation.py +++ b/convlab/dst/setsumbt/modeling/loss/kl_distillation.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""KL Divergence Ensemble Distillation loss""" +"""KL Divergence Ensemble Distillation loss (See https://arxiv.org/pdf/1503.02531.pdf for details)""" import torch from torch.nn import Module @@ -23,7 +23,7 @@ from torch.nn.functional import kl_div class KLDistillationLoss(Module): """Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation""" - def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module: + def __init__(self, ensemble_smoothing: float = 1e-4, ignore_index: int = -1) -> Module: """ Args: lamb (float): Target smoothing parameter @@ -31,7 +31,7 @@ class KLDistillationLoss(Module): """ super(KLDistillationLoss, self).__init__() - self.lamb = lamb + self.lamb = ensemble_smoothing self.ignore_index = ignore_index def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor: @@ -71,13 +71,13 @@ class KLDistillationLoss(Module): class BinaryKLDistillationLoss(KLDistillationLoss): """Binary Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation""" - def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module: + def __init__(self, ensemble_smoothing: float = 1e-4, ignore_index: int = -1) -> Module: """ Args: lamb (float): Target smoothing parameter ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. 
""" - super(BinaryKLDistillationLoss, self).__init__(lamb, ignore_index) + super(BinaryKLDistillationLoss, self).__init__(ensemble_smoothing, ignore_index) def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor: """ @@ -101,4 +101,4 @@ class BinaryKLDistillationLoss(KLDistillationLoss): targets = targets.unsqueeze(-1) targets = torch.cat((1 - targets, targets), -1) - return super().forward(input, targets, temp) + return super().forward(inputs, targets, temp) diff --git a/convlab/dst/setsumbt/loss/labelsmoothing.py b/convlab/dst/setsumbt/modeling/loss/labelsmoothing.py similarity index 98% rename from convlab/dst/setsumbt/loss/labelsmoothing.py rename to convlab/dst/setsumbt/modeling/loss/labelsmoothing.py index 61d4b353303451eac7eb09592bdb2c5200328250..61a2eeeef3d47c9f9df5275e0316184b6048626e 100644 --- a/convlab/dst/setsumbt/loss/labelsmoothing.py +++ b/convlab/dst/setsumbt/modeling/loss/labelsmoothing.py @@ -17,7 +17,7 @@ import torch -from torch.nn import Softmax, Module, CrossEntropyLoss +from torch.nn import Module from torch.nn.functional import kl_div diff --git a/convlab/dst/setsumbt/modeling/ontology_encoder.py b/convlab/dst/setsumbt/modeling/ontology_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c68d131b584e20a86c8b37ee43a10e7ebce0dc --- /dev/null +++ b/convlab/dst/setsumbt/modeling/ontology_encoder.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ontology Encoder Model""" + +import random +from copy import deepcopy + +import torch +from transformers import RobertaModel, BertModel +import numpy as np +from tqdm import tqdm + +PARENT_CLASSES = {'bert': BertModel, + 'roberta': RobertaModel} + + +def OntologyEncoder(parent_name: str): + """ + Return the Ontology Encoder model based on the parent transformer model. + + Args: + parent_name (str): Name of the parent transformer model + + Returns: + OntologyEncoder (class): Ontology Encoder model + """ + parent_class = PARENT_CLASSES.get(parent_name.lower()) + + class OntologyEncoder(parent_class): + """Ontology Encoder model based on parent transformer model""" + def __init__(self, config, args, tokenizer): + """ + Initialize Ontology Encoder model. 
+
+        Args:
+            config (transformers.configuration_utils.PretrainedConfig): Configuration of the transformer model
+            args (argparse.Namespace): Arguments
+            tokenizer (transformers.tokenization_utils_base.PreTrainedTokenizer): Tokenizer
+
+        Returns:
+            OntologyEncoder (class): Ontology Encoder model
+        """
+        super().__init__(config)
+
+        # Set random seeds
+        random.seed(args.seed)
+        np.random.seed(args.seed)
+        torch.manual_seed(args.seed)
+        if args.n_gpu > 0:
+            torch.cuda.manual_seed_all(args.seed)
+
+        self.args = args
+        self.config = config
+        self.tokenizer = tokenizer
+
+    def _encode_candidates(self, candidates: list) -> torch.Tensor:
+        """
+        Embed candidates
+
+        Args:
+            candidates (list): List of candidate descriptions
+
+        Returns:
+            feats (torch.Tensor): Embeddings of the candidate descriptions
+        """
+        # Tokenize candidate descriptions
+        feats = [self.tokenizer.encode_plus(val, add_special_tokens=True, max_length=self.args.max_candidate_len,
+                                            padding='max_length', truncation='longest_first')
+                 for val in candidates]
+
+        # Encode tokenized descriptions
+        with torch.no_grad():
+            feats = {key: torch.tensor([f[key] for f in feats]).to(self.device) for key in feats[0]}
+            embedded_feats = self(**feats)  # [num_candidates, max_candidate_len, hidden_dim]
+
+        # Reduce/pool description embeddings if required
+        if self.args.set_similarity:
+            feats = embedded_feats.last_hidden_state.detach().cpu()  # [num_candidates, max_candidate_len, hidden_dim]
+        elif self.args.candidate_pooling == 'cls':
+            feats = embedded_feats.pooler_output.detach().cpu()  # [num_candidates, hidden_dim]
+        elif self.args.candidate_pooling == "mean":
+            feats = embedded_feats.last_hidden_state.detach().cpu()
+            feats = feats.sum(1)
+            feats = torch.nn.functional.layer_norm(feats, feats.size())
+            feats = feats.detach().cpu()  # [num_candidates, hidden_dim]
+
+        return feats
+
+    def get_slot_candidate_embeddings(self):
+        """
+        Get embeddings for slots and candidates
+
+        The slot and value descriptions are taken from the ontology stored on the
+        tokenizer (see set_setsumbt_ontology), so the ontology must be set before
+        this method is called.
+
+        Returns:
+            slots (dict): domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot
+        """
+        # Set model to eval mode
+        self.eval()
+
+        slots = dict()
+        for domain, subset in tqdm(self.tokenizer.ontology.items(), desc='Domains'):
+            for slot, slot_info in tqdm(subset.items(), desc='Slots'):
+                # Get description or use "domain-slot"
+                if self.args.use_descriptions:
+                    desc = slot_info['description']
+                else:
+                    desc = f"{domain}-{slot}"
+
+                # Encode domain-slot pair description
+                slot_emb = self._encode_candidates([desc])[0]
+
+                # Obtain possible value set and discard requestable value
+                values = deepcopy(slot_info['possible_values'])
+                is_requestable = False
+                if '?' 
in values: + is_requestable = True + values.remove('?') + + # Encode value candidates + if values: + feats = self._encode_candidates(values) + else: + feats = None + + # Store domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot + slots[f"{domain}-{slot}"] = (slot_emb, feats, is_requestable) + + return slots + + return OntologyEncoder diff --git a/convlab/dst/setsumbt/modeling/roberta_nbt.py b/convlab/dst/setsumbt/modeling/roberta_nbt.py deleted file mode 100644 index f72d17fafa50553434b6d4dcd20b8e53d143892f..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/modeling/roberta_nbt.py +++ /dev/null @@ -1,95 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""RoBERTa SetSUMBT""" - -import torch -from transformers import RobertaModel, RobertaPreTrainedModel - -from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead - - -class RobertaSetSUMBT(RobertaPreTrainedModel): - """Roberta based SetSUMBT model""" - - def __init__(self, config): - """ - Args: - config (configuration): Model configuration class - """ - super(RobertaSetSUMBT, self).__init__(config) - self.config = config - - # Turn Encoder - self.roberta = RobertaModel(config) - if config.freeze_encoder: - for p in self.roberta.parameters(): - p.requires_grad = False - - self.setsumbt = SetSUMBTHead(config) - self.add_slot_candidates = self.setsumbt.add_slot_candidates - self.add_value_candidates = self.setsumbt.add_value_candidates - - def forward(self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor = None, - hidden_state: torch.Tensor = None, - state_labels: torch.Tensor = None, - request_labels: torch.Tensor = None, - active_domain_labels: torch.Tensor = None, - general_act_labels: torch.Tensor = None, - get_turn_pooled_representation: bool = False, - calculate_state_mutual_info: bool = False): - """ - Args: - input_ids: Input token ids - attention_mask: Input padding mask - token_type_ids: Token type indicator - hidden_state: Latent internal dialogue belief state - state_labels: Dialogue state labels - request_labels: User request action labels - active_domain_labels: Current active domain labels - general_act_labels: General user action labels - get_turn_pooled_representation: Return pooled representation of the current dialogue turn - calculate_state_mutual_info: Return mutual information in the dialogue state - - Returns: - out: Tuple containing loss, predictive distributions, model statistics and state mutual information - """ - if token_type_ids is not None: - token_type_ids = None - - # Encode Dialogues - batch_size, dialogue_size, turn_size = input_ids.size() - input_ids = input_ids.reshape(-1, turn_size) - attention_mask = attention_mask.reshape(-1, turn_size) - - roberta_output = self.roberta(input_ids, attention_mask) - - # Apply mask and reshape the dialogue turn token 
embeddings - attention_mask = attention_mask.float().unsqueeze(2) - attention_mask = attention_mask.repeat((1, 1, roberta_output.last_hidden_state.size(-1))) - turn_embeddings = roberta_output.last_hidden_state * attention_mask - turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1) - - if get_turn_pooled_representation: - return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, - batch_size, dialogue_size, hidden_state, state_labels, - request_labels, active_domain_labels, general_act_labels, - calculate_state_mutual_info) + (roberta_output.pooler_output,) - return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, - dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels, - general_act_labels, calculate_state_mutual_info) diff --git a/convlab/dst/setsumbt/modeling/setsumbt.py b/convlab/dst/setsumbt/modeling/setsumbt.py index 4b67e35c3e8d2e0eaa5abcff2376d89ff67ae3cc..19f7408e95a03cec7bf76665da815eae85e216e7 100644 --- a/convlab/dst/setsumbt/modeling/setsumbt.py +++ b/convlab/dst/setsumbt/modeling/setsumbt.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,26 +16,22 @@ """SetSUMBT Prediction Head""" import torch -from torch.autograd import Variable from torch.nn import (Module, MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout, - CosineSimilarity, CrossEntropyLoss, PairwiseDistance, - Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss) + CosineSimilarity, PairwiseDistance, Sequential, ReLU, Conv1d, GELU, Parameter) from torch.nn.init import (xavier_normal_, constant_) +from transformers.utils import ModelOutput -from convlab.dst.setsumbt.loss import (BayesianMatchingLoss, BinaryBayesianMatchingLoss, - KLDistillationLoss, BinaryKLDistillationLoss, - LabelSmoothingLoss, BinaryLabelSmoothingLoss, - RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss) +from convlab.dst.setsumbt.modeling import loss class SlotUtteranceMatching(Module): - """Slot Utterance matching attention based information extractor""" + """Slot Utterance Matching module for information extraction from utterances""" def __init__(self, hidden_size: int = 768, attention_heads: int = 12): """ Args: - hidden_size (int): Dimension of token embeddings - attention_heads (int): Number of attention heads to use in attention module + hidden_size: Hidden size of the transformer + attention_heads: Number of attention heads """ super(SlotUtteranceMatching, self).__init__() @@ -47,12 +43,12 @@ class SlotUtteranceMatching(Module): slot_embeddings: torch.Tensor) -> torch.Tensor: """ Args: - turn_embeddings: Embeddings for each token in each turn [n_turns, turn_length, hidden_size] - attention_mask: Padding mask for each turn [n_turns, turn_length, hidden_size] - slot_embeddings: Embeddings for each token in the slot descriptions + turn_embeddings: Turn level embeddings for the dialogue + attention_mask: Mask for the attention related to turn embeddings + slot_embeddings: Slot level embeddings for the dialogue Returns: - hidden: Information extracted from turn related to slot descriptions + hidden: Turn level embeddings for the dialogue conditioned on the slot embeddings """ turn_embeddings = turn_embeddings.transpose(0, 1) @@ -69,7 +65,7 @@ class 
SlotUtteranceMatching(Module): class RecurrentNeuralBeliefTracker(Module): - """Recurrent latent neural belief tracking module""" + """Recurrent Neural Belief Tracker module for tracking the latent dialogue state""" def __init__(self, nbt_type: str = 'gru', @@ -80,15 +76,16 @@ class RecurrentNeuralBeliefTracker(Module): dropout_rate: float = 0.3): """ Args: - nbt_type: Type of recurrent neural network (gru/lstm) - rnn_zero_init: Use zero initialised state for the RNN - input_size: Embedding size of the inputs + nbt_type: Type of recurrent neural network to use (lstm/gru) + rnn_zero_init: Whether to initialise the hidden state of the RNN to zero + input_size: Input embedding size hidden_size: Hidden size of the RNN - hidden_layers: Number of RNN Layers + hidden_layers: Number of hidden layers of the RNN dropout_rate: Dropout rate """ super(RecurrentNeuralBeliefTracker, self).__init__() + # Initialise Initial Belief State Layer if rnn_zero_init: self.belief_init = Sequential(Linear(input_size, hidden_size), ReLU(), Dropout(dropout_rate)) else: @@ -126,12 +123,12 @@ class RecurrentNeuralBeliefTracker(Module): def forward(self, inputs: torch.Tensor, hidden_state: torch.Tensor = None) -> torch.Tensor: """ Args: - inputs: Latent turn level information - hidden_state: Latent internal belief state + inputs: Input embeddings + hidden_state: Hidden state of the RNN Returns: - belief_embedding: Belief state embeddings - context: Latent internal belief state + belief_embedding: Latent belief state embeddings + context: Hidden state of the RNN """ self.nbt.flatten_parameters() if hidden_state is None: @@ -155,13 +152,13 @@ class RecurrentNeuralBeliefTracker(Module): class SetPooler(Module): - """Token set pooler""" + """Set Pooler module for pooling the set of token embeddings""" def __init__(self, pooling_strategy: str = 'cnn', hidden_size: int = 768): """ Args: - pooling_strategy: Type of set pooler (cnn/dan/mean) - hidden_size: Token embedding size + pooling_strategy: Pooling strategy to use (mean/cnn/dan) + hidden_size: Hidden size of the set of token embeddings """ super(SetPooler, self).__init__() @@ -172,14 +169,14 @@ class SetPooler(Module): elif pooling_strategy == 'dan': self.pooler = Sequential(Linear(hidden_size, hidden_size), GELU(), Linear(2 * hidden_size, hidden_size)) - def forward(self, inputs, attention_mask): + def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor): """ Args: - inputs: Token set embeddings - attention_mask: Padding mask for the set of tokens + inputs: Set of token embeddings + attention_mask: Attention mask for the set of token embeddings Returns: - + hidden: Pooled embeddings """ if self.pooling_strategy == "mean": hidden = inputs.sum(1) / attention_mask.sum(1) @@ -192,13 +189,25 @@ class SetPooler(Module): return hidden +class SetSUMBTOutput(ModelOutput): + """SetSUMBT Output class""" + loss = None + belief_state = None + request_probabilities = None + active_domain_probabilities = None + general_act_probabilities = None + hidden_state = None + belief_state_summary = None + belief_state_mutual_information = None + + class SetSUMBTHead(Module): """SetSUMBT Prediction Head for Language Models""" def __init__(self, config): """ Args: - config (configuration): Model configuration class + config: Model configuration """ super(SetSUMBTHead, self).__init__() self.config = config @@ -214,11 +223,24 @@ class SetSUMBTHead(Module): self.set_pooler = SetPooler(config.set_pooling, config.hidden_size) # Model ontology placeholders - self.slot_embeddings = 
Variable(torch.zeros(0), requires_grad=False) - self.slot_ids = dict() - self.requestable_slot_ids = dict() - self.informable_slot_ids = dict() - self.domain_ids = dict() + if not hasattr(self.config, 'num_slots'): + self.config.num_slots = 1 + self.slot_embeddings = Parameter(torch.zeros(self.config.num_slots, self.config.max_candidate_len, + self.config.hidden_size), requires_grad=False) + if not hasattr(self.config, 'slot_ids'): + self.config.slot_ids = dict() + self.config.requestable_slot_ids = dict() + self.config.informable_slot_ids = dict() + self.config.domain_ids = dict() + if not hasattr(self.config, 'num_values'): + self.config.num_values = dict() + for slot in self.config.slot_ids: + if slot not in self.config.num_values: + self.config.num_values[slot] = 1 + setattr(self, slot + '_value_embeddings', Parameter(torch.zeros(self.config.num_values[slot], + self.config.max_candidate_len, + self.config.hidden_size), + requires_grad=False)) # Matching network similarity measure if config.distance_measure == 'cosine': @@ -229,19 +251,12 @@ class SetSUMBTHead(Module): raise NameError('NotImplemented') # User goal prediction loss function - if config.loss_function == 'crossentropy': - self.loss = CrossEntropyLoss(ignore_index=-1) - elif config.loss_function == 'bayesianmatching': - self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - elif config.loss_function == 'labelsmoothing': - self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing) - elif config.loss_function == 'distillation': - self.loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) - self.temp = 1.0 - elif config.loss_function == 'distribution_distillation': - self.loss = RKLDirichletMediatorLoss(ignore_index=-1) - else: - raise NameError('NotImplemented') + loss_args = {'ignore_index': -1, + 'kl_scaling_factor': config.to_dict().get('kl_scaling_factor', 0.0), + 'label_smoothing': config.to_dict().get('label_smoothing', 0.0), + 'ensemble_smoothing': config.to_dict().get('ensemble_smoothing', 0.0)} + self.loss = loss.load(config.loss_function)(**loss_args) + self.temp = 1.0 # Intent and domain prediction heads if config.predict_actions: @@ -253,26 +268,10 @@ class SetSUMBTHead(Module): self.request_weight = float(self.config.user_request_loss_weight) self.general_act_weight = float(self.config.user_general_act_loss_weight) self.active_domain_weight = float(self.config.active_domain_loss_weight) - if config.loss_function == 'crossentropy': - self.request_loss = BCEWithLogitsLoss() - self.general_act_loss = CrossEntropyLoss(ignore_index=-1) - self.active_domain_loss = BCEWithLogitsLoss() - elif config.loss_function == 'labelsmoothing': - self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing) - self.general_act_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing) - self.active_domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing) - elif config.loss_function == 'bayesianmatching': - self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - self.general_act_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - self.active_domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - elif config.loss_function == 'distillation': - self.request_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) - self.general_act_loss = 
KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) - self.active_domain_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) - elif config.loss_function == 'distribution_distillation': - self.request_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1) - self.general_act_loss = RKLDirichletMediatorLoss(ignore_index=-1) - self.active_domain_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1) + + self.request_loss = loss.load(config.loss_function, binary=True)(**loss_args) + self.general_act_loss = loss.load(config.loss_function)(**loss_args) + self.active_domain_loss = loss.load(config.loss_function, binary=True)(**loss_args) def add_slot_candidates(self, slot_candidates: tuple): """ @@ -281,7 +280,7 @@ class SetSUMBTHead(Module): the request indicator is false the slot is not requestable. Args: - slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator + slot_candidates: Tuples of slot embedding, informable value embeddings and request indicator """ if self.slot_embeddings.size(0) != 0: embeddings = self.slot_embeddings.detach() @@ -289,28 +288,33 @@ class SetSUMBTHead(Module): embeddings = torch.zeros(0) for slot in slot_candidates: - if slot in self.slot_ids: - index = self.slot_ids[slot] + if slot in self.config.slot_ids: + index = self.config.slot_ids[slot] embeddings[index, :] = slot_candidates[slot][0] else: index = embeddings.size(0) emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device) embeddings = torch.cat((embeddings, emb), 0) - self.slot_ids[slot] = index - setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False)) + self.config.slot_ids[slot] = index + self.config.num_values[slot] = 1 + setattr(self, slot + '_value_embeddings', Parameter(torch.zeros(self.config.num_values[slot], + self.config.max_candidate_len, + self.config.hidden_size), + requires_grad=False)) # Add slot to relevant requestable and informable slot lists if slot_candidates[slot][2]: - self.requestable_slot_ids[slot] = index + self.config.requestable_slot_ids[slot] = index if slot_candidates[slot][1] is not None: - self.informable_slot_ids[slot] = index + self.config.informable_slot_ids[slot] = index domain = slot.split('-', 1)[0] - if domain not in self.domain_ids: - self.domain_ids[domain] = [] - self.domain_ids[domain].append(index) - self.domain_ids[domain] = list(set(self.domain_ids[domain])) + if domain not in self.config.domain_ids: + self.config.domain_ids[domain] = [] + self.config.domain_ids[domain].append(index) + self.config.domain_ids[domain] = list(set(self.config.domain_ids[domain])) - self.slot_embeddings = Variable(embeddings, requires_grad=False) + self.config.num_slots = embeddings.size(0) + self.slot_embeddings = Parameter(embeddings, requires_grad=False) def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False): """ @@ -319,7 +323,7 @@ class SetSUMBTHead(Module): Args: slot: Slot name value_candidates: Value candidate embeddings - replace: If true existing value candidates are replaced + replace: Replace existing value candidates """ embeddings = getattr(self, slot + '_value_embeddings') @@ -328,7 +332,8 @@ class SetSUMBTHead(Module): else: embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0) - setattr(self, slot + '_value_embeddings', embeddings) + self.config.num_values[slot] = embeddings.size(0) + setattr(self, slot + '_value_embeddings', Parameter(embeddings, requires_grad=False)) def 
forward(self, turn_embeddings: torch.Tensor, @@ -344,20 +349,20 @@ class SetSUMBTHead(Module): calculate_state_mutual_info: bool = False): """ Args: - turn_embeddings: Token embeddings in the current turn - turn_pooled_representation: Pooled representation of the current dialogue turn - attention_mask: Padding mask for the current dialogue turn - batch_size: Number of dialogues in the batch - dialogue_size: Number of turns in each dialogue - hidden_state: Latent internal dialogue belief state - state_labels: Dialogue state labels - request_labels: User request action labels - active_domain_labels: Current active domain labels - general_act_labels: General user action labels - calculate_state_mutual_info: Return mutual information in the dialogue state + turn_embeddings: Turn embeddings for dialogue turns + turn_pooled_representation: Turn pooled representation for dialogue turns + attention_mask: Attention mask for dialogue turns + batch_size: Batch size + dialogue_size: Number of turns in dialogue + hidden_state: RNN Hidden state / Latent Belief State for dialogue turns + state_labels: State labels for dialogue turns + request_labels: Request labels for dialogue turns + active_domain_labels: Active domain labels for dialogue turns + general_act_labels: General action labels for dialogue turns + calculate_state_mutual_info: Calculate state mutual information Returns: - out: Tuple containing loss, predictive distributions, model statistics and state mutual information + output: Model output containing loss, state, request, active domain predictions, etc. """ hidden_size = turn_embeddings.size(-1) # Initialise loss @@ -432,7 +437,7 @@ class SetSUMBTHead(Module): if self.config.predict_actions: # User request prediction request_probs = dict() - for slot, slot_id in self.requestable_slot_ids.items(): + for slot, slot_id in self.config.requestable_slot_ids.items(): request_logits = self.request_gate(belief_embedding[:, :, slot_id, :]) # Store output probabilities @@ -441,7 +446,7 @@ class SetSUMBTHead(Module): request_logits[batches, dialogues] = 0.0 request_probs[slot] = torch.sigmoid(request_logits) - if request_labels is not None: + if request_labels is not None and slot in request_labels: # Compute request gate loss request_logits = request_logits.reshape(-1) if self.config.loss_function == 'distillation': @@ -457,7 +462,7 @@ class SetSUMBTHead(Module): # Active domain prediction active_domain_probs = dict() - for domain, slot_ids in self.domain_ids.items(): + for domain, slot_ids in self.config.domain_ids.items(): belief = belief_embedding[:, :, slot_ids, :] if len(slot_ids) > 1: # SqrtN reduction across all slots within a domain @@ -490,7 +495,7 @@ class SetSUMBTHead(Module): belief_state_probs = dict() belief_state_mutual_info = dict() belief_state_stats = dict() - for slot, slot_id in self.informable_slot_ids.items(): + for slot, slot_id in self.config.informable_slot_ids.items(): # Get slot belief embedding and value candidates candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device) belief = belief_embedding[:, :, slot_id, :] @@ -556,9 +561,17 @@ class SetSUMBTHead(Module): loss += self.loss(logits, state_labels[slot].reshape(-1)) # Return model outputs - out = belief_state_probs, request_probs, active_domain_probs, general_act_probs, hidden_state + output = SetSUMBTOutput(belief_state=belief_state_probs, + request_probabilities=request_probs, + active_domain_probabilities=active_domain_probs, + general_act_probabilities=general_act_probs, + 
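Returning a SetSUMBTOutput (a transformers ModelOutput) in place of the old positional tuple lets downstream code address fields by name instead of by position. A small standalone illustration with a simplified stand-in class (ToyOutput is hypothetical, not the SetSUMBTOutput defined above):

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils import ModelOutput

@dataclass
class ToyOutput(ModelOutput):  # simplified stand-in for SetSUMBTOutput
    loss: Optional[torch.Tensor] = None
    belief_state: Optional[dict] = None

out = ToyOutput(loss=torch.tensor(0.5), belief_state={'hotel-area': torch.rand(2, 3, 7)})
print(out.loss, list(out.belief_state))         # named access replaces positional unpacking
print(out['belief_state'] is out.belief_state)  # dict-style access also works
print(len(out.to_tuple()))                      # populated fields still convert to a tuple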
hidden_state=hidden_state,
+                                loss=None,
+                                belief_state_summary=None,
+                                belief_state_mutual_information=None)
         if state_labels is not None or request_labels is not None:
-            out = (loss,) + out + (belief_state_stats,)
+            output.loss = loss
+            output.belief_state_summary = belief_state_stats
         if calculate_state_mutual_info:
-            out = out + (belief_state_mutual_info,)
-        return out
+            output.belief_state_mutual_information = belief_state_mutual_info
+        return output
diff --git a/convlab/dst/setsumbt/modeling/setsumbt_nbt.py b/convlab/dst/setsumbt/modeling/setsumbt_nbt.py
new file mode 100644
index 0000000000000000000000000000000000000000..a03983e75580fdadb05225b4e69f89ee5b813759
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/setsumbt_nbt.py
@@ -0,0 +1,339 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SetSUMBT Models"""
+
+import os
+from copy import deepcopy
+
+import torch
+from torch.nn import Module
+from transformers import (BertModel, BertPreTrainedModel, BertConfig,
+                          RobertaModel, RobertaPreTrainedModel, RobertaConfig)
+
+from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead, SetSUMBTOutput
+
+
+class BertSetSUMBT(BertPreTrainedModel):
+    """Bert based SetSUMBT model"""
+
+    def __init__(self, config):
+        """
+        Args:
+            config (configuration): Model configuration class
+        """
+        super(BertSetSUMBT, self).__init__(config)
+        self.config = config
+
+        # Turn Encoder
+        self.bert = BertModel(config)
+        if config.freeze_encoder:
+            for p in self.bert.parameters():
+                p.requires_grad = False
+
+        self.setsumbt = SetSUMBTHead(config)
+        self.add_slot_candidates = self.setsumbt.add_slot_candidates
+        self.add_value_candidates = self.setsumbt.add_value_candidates
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                hidden_state: torch.Tensor = None,
+                state_labels: torch.Tensor = None,
+                request_labels: torch.Tensor = None,
+                active_domain_labels: torch.Tensor = None,
+                general_act_labels: torch.Tensor = None,
+                get_turn_pooled_representation: bool = False,
+                calculate_state_mutual_info: bool = False):
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator
+            hidden_state: Latent internal dialogue belief state
+            state_labels: Dialogue state labels
+            request_labels: User request action labels
+            active_domain_labels: Current active domain labels
+            general_act_labels: General user action labels
+            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
+            calculate_state_mutual_info: Return mutual information in the dialogue state
+
+        Returns:
+            output: SetSUMBTOutput containing loss, predictive distributions, model statistics and state mutual information
+        """
+
+        # Encode Dialogues
+        batch_size, dialogue_size, turn_size = input_ids.size()
+        input_ids = input_ids.reshape(-1, turn_size)
+        token_type_ids = token_type_ids.reshape(-1, turn_size)
+        attention_mask = attention_mask.reshape(-1, turn_size)
+
+        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+
+        attention_mask = attention_mask.float().unsqueeze(2)
+        attention_mask = attention_mask.repeat((1, 1, bert_output.last_hidden_state.size(-1)))
+        turn_embeddings = bert_output.last_hidden_state * attention_mask
+        turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
+
+        output = self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask,
+                               batch_size, dialogue_size, hidden_state, state_labels,
+                               request_labels, active_domain_labels, general_act_labels,
+                               calculate_state_mutual_info)
+        output.turn_pooled_representation = bert_output.pooler_output if get_turn_pooled_representation else None
+        return output
+
+
+class RobertaSetSUMBT(RobertaPreTrainedModel):
+    """Roberta based SetSUMBT model"""
+
+    def __init__(self, config):
+        """
+        Args:
+            config (configuration): Model configuration class
+        """
+        super(RobertaSetSUMBT, self).__init__(config)
+        self.config = config
+
+        # Turn Encoder
+        self.roberta = RobertaModel(config)
+        if config.freeze_encoder:
+            for p in self.roberta.parameters():
+                p.requires_grad = False
+
+        self.setsumbt = SetSUMBTHead(config)
+        self.add_slot_candidates = self.setsumbt.add_slot_candidates
+        self.add_value_candidates = self.setsumbt.add_value_candidates
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                hidden_state: torch.Tensor = None,
+                state_labels: torch.Tensor = None,
+                request_labels: torch.Tensor = None,
+                active_domain_labels: torch.Tensor = None,
+                general_act_labels: torch.Tensor = None,
+                get_turn_pooled_representation: bool = False,
+                calculate_state_mutual_info: bool = False):
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator (ignored by RoBERTa)
+            hidden_state: Latent internal dialogue belief state
+            state_labels: Dialogue state labels
+            request_labels: User request action labels
+            active_domain_labels: Current active domain labels
+            general_act_labels: General user action labels
+            get_turn_pooled_representation: Return pooled representation of the current dialogue turn
+            calculate_state_mutual_info: Return mutual information in the dialogue state
+
+        Returns:
+            output: SetSUMBTOutput containing loss, predictive distributions, model statistics and state mutual information
+        """
+        # RoBERTa does not use token type ids
+        token_type_ids = None
+
+        # Encode Dialogues
+        batch_size, dialogue_size, turn_size = input_ids.size()
+        input_ids = input_ids.reshape(-1, turn_size)
+        attention_mask = attention_mask.reshape(-1, turn_size)
+
+        roberta_output = self.roberta(input_ids, attention_mask)
+
+        # Apply mask and reshape the dialogue turn token embeddings
+        attention_mask = attention_mask.float().unsqueeze(2)
+        attention_mask = attention_mask.repeat((1, 1, roberta_output.last_hidden_state.size(-1)))
+        turn_embeddings = roberta_output.last_hidden_state * attention_mask
+        turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
+
+        output = self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask,
+                               batch_size, dialogue_size, hidden_state, state_labels,
+                               request_labels, active_domain_labels, general_act_labels,
+                               calculate_state_mutual_info)
+        output.turn_pooled_representation = roberta_output.pooler_output if get_turn_pooled_representation else None
+        return output
+
+
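Both wrappers share the same pre-head recipe: flatten [batch, dialogue, turn] token ids to [batch*dialogue, turn], run the turn encoder, zero out padding positions, and hand [batch*dialogue, turn, hidden] embeddings to the shared SetSUMBTHead. The mask-and-reshape step in isolation, with made-up sizes:

import torch

batch_size, dialogue_size, turn_size, hidden = 2, 3, 8, 16
last_hidden_state = torch.randn(batch_size * dialogue_size, turn_size, hidden)
attention_mask = torch.randint(0, 2, (batch_size * dialogue_size, turn_size))

mask = attention_mask.float().unsqueeze(2).repeat(1, 1, hidden)
turn_embeddings = (last_hidden_state * mask).reshape(batch_size * dialogue_size, turn_size, -1)
print(turn_embeddings.shape)  # torch.Size([6, 8, 16])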
+MODELS = {'bert': BertSetSUMBT, 'roberta': RobertaSetSUMBT}
+class EnsembleSetSUMBT(Module):
+    """Ensemble SetSUMBT Model for joint ensemble prediction"""
+
+    def __init__(self, config):
+        """
+        Args:
+            config (configuration): Model configuration class
+        """
+        super(EnsembleSetSUMBT, self).__init__()
+        self.config = config
+
+        # Initialise ensemble members
+        model_cls = MODELS[self.config.model_type]
+        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
+            setattr(self, attr, model_cls(self.get_clean_config(config)))
+
+    @staticmethod
+    def get_clean_config(config):
+        config = deepcopy(config)
+        config.slot_ids = dict()
+        config.requestable_slot_ids = dict()
+        config.informable_slot_ids = dict()
+        config.domain_ids = dict()
+        config.num_values = dict()
+
+        return config
+
+    def _load(self, path: str):
+        """
+        Load parameters
+
+        Args:
+            path: Location of model parameters
+        """
+        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
+            idx = int(attr.split('_', 1)[-1])
+            state_dict = torch.load(os.path.join(self._get_checkpoint_path(path, idx), 'pytorch_model.bin'))
+            state_dict = {key: itm for key, itm in state_dict.items() if '_value_embeddings' not in key}
+            getattr(self, attr).load_state_dict(state_dict)
+
+    def add_slot_candidates(self, slot_candidates: dict):
+        """
+        Add slots to the model ontology. Each entry in slot_candidates maps a slot to a tuple of the slot
+        embedding, the informable value embeddings and a request indicator. If the informable value embeddings
+        are None the slot is not informable, and if the request indicator is false the slot is not requestable.
+
+        Args:
+            slot_candidates: Dictionary mapping slots to (slot embedding, value embeddings, request indicator)
+        """
+        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
+            getattr(self, attr).add_slot_candidates(slot_candidates)
+        self.setsumbt = self.model_0.setsumbt
+
+    def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
+        """
+        Add value candidates for a slot
+
+        Args:
+            slot: Slot name
+            value_candidates: Value candidate embeddings
+            replace: If true existing value candidates are replaced
+        """
+        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
+            getattr(self, attr).add_value_candidates(slot, value_candidates, replace)
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                attention_mask: torch.Tensor,
+                token_type_ids: torch.Tensor = None,
+                reduction: str = 'mean',
+                **kwargs) -> SetSUMBTOutput:
+        """
+        Args:
+            input_ids: Input token ids
+            attention_mask: Input padding mask
+            token_type_ids: Token type indicator
+            reduction: Reduction of ensemble member predictive distributions (mean, none)
+
+        Returns:
+            output: SetSUMBTOutput containing the ensemble member (or mean-reduced) predictive distributions
+        """
+        belief_state_probs = {slot: [] for slot in self.setsumbt.config.informable_slot_ids}
+        request_probs = {slot: [] for slot in self.setsumbt.config.requestable_slot_ids}
+        active_domain_probs = {dom: [] for dom in self.setsumbt.config.domain_ids}
+        general_act_probs = []
+        loss = 0.0 if 'state_labels' in kwargs else None
+        for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
+            # Prediction from each ensemble member
+            with torch.no_grad():
+                _out = getattr(self, attr)(input_ids=input_ids,
+                                           token_type_ids=token_type_ids,
+                                           attention_mask=attention_mask,
+                                           **kwargs)
+            if loss is not None:
+                loss += _out.loss
+            for slot in belief_state_probs:
+                belief_state_probs[slot].append(_out.belief_state[slot].unsqueeze(-2).detach().cpu())
+            if self.config.predict_actions:
+                for slot in request_probs:
+                    request_probs[slot].append(_out.request_probabilities[slot].unsqueeze(-1).detach().cpu())
+                for dom in active_domain_probs:
+                    active_domain_probs[dom].append(_out.active_domain_probabilities[dom].unsqueeze(-1).detach().cpu())
+                general_act_probs.append(_out.general_act_probabilities.unsqueeze(-2).detach().cpu())
+
+        belief_state_probs = {slot: torch.cat(l, -2) for slot, l in belief_state_probs.items()}
+        if self.config.predict_actions:
+            request_probs = {slot: torch.cat(l, -1) for slot, l in request_probs.items()}
+            active_domain_probs = {dom: torch.cat(l, -1) for dom, l in active_domain_probs.items()}
+            general_act_probs = torch.cat(general_act_probs, -2)
+        else:
+            request_probs = {}
+            active_domain_probs = {}
+            general_act_probs = torch.tensor(0.0)
+
+        # Apply reduction of ensemble to single posterior
+        if reduction == 'mean':
+            belief_state_probs = {slot: l.mean(-2) for slot, l in belief_state_probs.items()}
+            request_probs = {slot: l.mean(-1) for slot, l in request_probs.items()}
+            active_domain_probs = {dom: l.mean(-1) for dom, l in active_domain_probs.items()}
+            general_act_probs = general_act_probs.mean(-2)
+        elif reduction != 'none':
+            raise NameError('Not Implemented!')
+
+        if loss is not None:
+            loss /= self.config.ensemble_size
+
+        output = SetSUMBTOutput(loss=loss,
+                                belief_state=belief_state_probs,
+                                request_probabilities=request_probs,
+                                active_domain_probabilities=active_domain_probs,
+                                general_act_probabilities=general_act_probs)
+
+        return output
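In the ensemble forward pass, each member contributes one distribution per slot; the lists are concatenated along a new member dimension, and reduction='mean' averages them into a single posterior. The tensor mechanics in isolation, with hypothetical sizes:

import torch

ensemble_size, batch, turns, n_values = 5, 2, 3, 7
members = [torch.softmax(torch.randn(batch, turns, n_values), -1).unsqueeze(-2)
           for _ in range(ensemble_size)]        # each [batch, turns, 1, n_values]
stacked = torch.cat(members, -2)                 # [batch, turns, ensemble_size, n_values]
posterior = stacked.mean(-2)                     # reduction='mean'
print(stacked.shape, posterior.shape)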
+
+    @staticmethod
+    def _get_checkpoint_path(path: str, idx: int):
+        """
+        Get checkpoint path for ensemble member
+
+        Args:
+            path: Location of ensemble
+            idx: Ensemble member index
+
+        Returns:
+            Checkpoint path
+        """
+        checkpoints = os.listdir(os.path.join(path, f'ens-{idx}'))
+        checkpoints = [int(p.split('-', 1)[-1]) for p in checkpoints if 'checkpoint-' in p]
+        checkpoint = f"checkpoint-{max(checkpoints)}"
+        return os.path.join(path, f'ens-{idx}', checkpoint)
+
+    @classmethod
+    def from_pretrained(cls, path, config=None):
+        config_path = os.path.join(cls._get_checkpoint_path(path, 0), 'config.json')
+        if not os.path.exists(config_path):
+            raise NameError('Could not find config.json in model path.')
+
+        if config is None:
+            try:
+                config = RobertaConfig.from_pretrained(config_path)
+            except Exception:
+                config = BertConfig.from_pretrained(config_path)
+
+        config.ensemble_size = len([d for d in os.listdir(path) if 'ens-' in d])
+
+        model = cls(config)
+        model._load(path)
+
+        return model
diff --git a/convlab/dst/setsumbt/modeling/temperature_scheduler.py b/convlab/dst/setsumbt/modeling/temperature_scheduler.py
index 654e83c5d1ad9dc908213cca8967a84893395b04..549ae76a13af09b681eb0d1fb893bdff361fb014 100644
--- a/convlab/dst/setsumbt/modeling/temperature_scheduler.py
+++ b/convlab/dst/setsumbt/modeling/temperature_scheduler.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +16,6 @@
 """Linear Temperature Scheduler Class"""
 
-# Temp scheduler class for ensemble distillation
 class LinearTemperatureScheduler:
     """
     Temperature scheduler object used for distribution temperature scheduling in distillation
diff --git a/convlab/dst/setsumbt/modeling/tokenization.py b/convlab/dst/setsumbt/modeling/tokenization.py
new file mode 100644
index 
0000000000000000000000000000000000000000..dbee3baa5188023a2c07dca96ae02ddfe5c7d298 --- /dev/null +++ b/convlab/dst/setsumbt/modeling/tokenization.py @@ -0,0 +1,401 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SetSUMBT Tokenizer""" + +import json +import os + +import torch +from transformers import RobertaTokenizer, BertTokenizer +from tqdm import tqdm + +from convlab.dst.setsumbt.datasets.utils import IdTensor + +PARENT_CLASSES = {'bert': BertTokenizer, + 'roberta': RobertaTokenizer} + + +def SetSUMBTTokenizer(parent_name): + """SetSUMBT Tokenizer Class Factory""" + parent_class = PARENT_CLASSES.get(parent_name.lower()) + + class SetSUMBTTokenizer(parent_class): + """SetSUMBT Tokenizer Class""" + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + bos_token="<s>", + eos_token="</s>", + sep_token="</s>", + cls_token="<s>", + unk_token="<unk>", + pad_token="<pad>", + mask_token="<mask>", + add_prefix_space=False, + **kwargs, + ): + """ + Initialize the tokenizer. + + Args: + vocab_file (str): Path to the vocabulary file. + merges_file (str): Path to the merges file. + errors (str): Error handling for the tokenizer. + bos_token (str): Beginning of sentence token. + eos_token (str): End of sentence token. + sep_token (str): Separator token. + cls_token (str): Classification token. + unk_token (str): Unknown token. + pad_token (str): Padding token. + mask_token (str): Masking token. + add_prefix_space (bool): Whether to add a space before the first token. + **kwargs: Additional arguments for the tokenizer. + """ + + # Load ontology and tokenizer vocab + with open(vocab_file, 'r', encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + vocab_handle.close() + self.ontology = self.encoder['SETSUMBT_ONTOLOGY'] if 'SETSUMBT_ONTOLOGY' in self.encoder else dict() + self.encoder = {k: v for k, v in self.encoder.items() if 'SETSUMBT_ONTOLOGY' not in k} + vocab_dir = os.path.dirname(vocab_file) + vocab_file = os.path.basename(vocab_file).split('.') + vocab_file = vocab_file[0] + "_base." + vocab_file[-1] + vocab_file = os.path.join(vocab_dir, vocab_file) + with open(vocab_file, 'w', encoding="utf-8") as vocab_handle: + json.dump(self.encoder, vocab_handle) + vocab_handle.close() + + super().__init__(vocab_file, merges_file, errors, bos_token, eos_token, sep_token, cls_token, unk_token, + pad_token, mask_token, add_prefix_space, **kwargs) + + def set_setsumbt_ontology(self, ontology): + """ + Set the ontology for the tokenizer. + + Args: + ontology (dict): The dialogue system ontology to use. + """ + self.ontology = ontology + + def save_vocabulary(self, save_directory: str, filename_prefix: str = None) -> tuple: + """ + Save the tokenizer vocabulary and merges files to a directory. + + Args: + save_directory (str): Directory to which to save. + filename_prefix (str): Optional prefix to add to the files. 
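The tokenizer persists its ontology by writing a 'SETSUMBT_ONTOLOGY' entry into the vocabulary JSON on save and splitting it back out in __init__. The round trip in isolation, using plain dicts and json only (the example vocab and ontology are made up):

import json

vocab = {'<pad>': 0, '<s>': 1, 'hello': 2}
ontology = {'hotel': {'area': {'possible_values': ['north', 'south']}}}

saved = dict(vocab, SETSUMBT_ONTOLOGY=ontology)   # embed the ontology in the vocab file
blob = json.dumps(saved)

loaded = json.loads(blob)                         # split the two apart again on load
restored = loaded.pop('SETSUMBT_ONTOLOGY', {})
assert loaded == vocab and restored == ontology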
+ + Returns: + vocab_file (str): Path to the saved vocabulary file. + merge_file (str): Path to the saved merges file. + """ + self.encoder['SETSUMBT_ONTOLOGY'] = self.ontology + vocab_file, merge_file = super().save_vocabulary(save_directory, filename_prefix) + self.encoder = {k: v for k, v in self.encoder.items() if 'SETSUMBT_ONTOLOGY' not in k} + + return vocab_file, merge_file + + def decode_state(self, belief_state, request_probs=None, active_domain_probs=None, general_act_probs=None): + """ + Decode a belief state, request, active domain and general action distributions into a dialogue state. + + Args: + belief_state (dict): The belief state distributions. + request_probs (dict): The request distributions. + active_domain_probs (dict): The active domain distributions. + general_act_probs (dict): The general action distributions. + + Returns: + dialogue_state (dict): The decoded dialogue state. + """ + dialogue_state = {domain: {slot: '' for slot, slot_info in domain_info.items() + if slot_info['possible_values'] != ["?"] and slot_info['possible_values']} + for domain, domain_info in self.ontology.items()} + + for slot, probs in belief_state.items(): + dom, slot = slot.split('-', 1) + val = self.ontology.get(dom, dict()).get(slot, dict()).get('possible_values', []) + val = val[probs.argmax().item()] if val else 'none' + if val != 'none': + if dom in dialogue_state: + if slot in dialogue_state[dom]: + dialogue_state[dom][slot] = val + + request_acts = list() + if request_probs is not None: + request_acts = [slot for slot, p in request_probs.items() if p.item() > 0.5] + request_acts = [slot.split('-', 1) for slot in request_acts] + request_acts = [[dom, slt] for dom, slt in request_acts + if '?' in self.ontology.get(dom, dict()).get(slt, dict()).get('possible_values', [])] + request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts] + + # Construct active domain set + active_domains = dict() + if active_domain_probs is not None: + active_domains = {dom: active_domain_probs.get(dom, torch.tensor(0.0)).item() > 0.5 + for dom in self.ontology} + + # Construct general domain action + general_acts = list() + if general_act_probs is not None: + general_acts = general_act_probs.argmax(-1).item() + general_acts = [[], ['bye'], ['thank']][general_acts] + general_acts = [[act, 'general', 'none', 'none'] for act in general_acts] + + user_acts = request_acts + general_acts + dialogue_state = {'belief_state': dialogue_state, + 'user_action': user_acts, + 'active_domains': active_domains} + + return dialogue_state + + def decode_state_batch(self, + belief_state, + request_probs=None, + active_domain_probs=None, + general_act_probs=None, + dialogue_ids=None): + """ + Decode a batch of belief state, request, active domain and general action distributions. + + Args: + belief_state (dict): The belief state distributions. + request_probs (dict): The request distributions. + active_domain_probs (dict): The active domain distributions. + general_act_probs (dict): The general action distributions. + dialogue_ids (list): The dialogue IDs. + + Returns: + data (dict): The decoded dialogue states. 
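decode_state above turns distributions into discrete decisions: an argmax over a slot's candidate values, a 0.5 threshold on the binary request and active-domain probabilities, and an argmax over the three general acts (none/bye/thank). Those decision rules in miniature, with made-up slots and values:

import torch

possible_values = ['none', 'north', 'south']
belief = torch.tensor([0.1, 0.7, 0.2])
value = possible_values[belief.argmax().item()]              # 'north'

request_probs = {'hotel-area': torch.tensor(0.8), 'hotel-stars': torch.tensor(0.2)}
requested = [slot for slot, p in request_probs.items() if p.item() > 0.5]

general_acts = [[], ['bye'], ['thank']][torch.tensor([0.2, 0.1, 0.7]).argmax().item()]
print(value, requested, general_acts)  # north ['hotel-area'] ['thank']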
+ """ + + data = dict() + slot_0 = [key for key in belief_state.keys()][0] + + if dialogue_ids is None: + dialogue_ids = [["{:06d}".format(i) for i in range(belief_state[slot_0].size(0))]] + + for dial_idx in range(belief_state[slot_0].size(0)): + dialogue = list() + for turn_idx in range(belief_state[slot_0].size(1)): + if belief_state[slot_0][dial_idx, turn_idx].sum() != 0.0: + belief = {slot: p[dial_idx, turn_idx] for slot, p in belief_state.items()} + req = {slot: p[dial_idx, turn_idx] + for slot, p in request_probs.items()} if request_probs is not None else None + dom = {dom: p[dial_idx, turn_idx] + for dom, p in active_domain_probs.items()} if active_domain_probs is not None else None + gen = general_act_probs[dial_idx, turn_idx] if general_act_probs is not None else None + + state = self.decode_state(belief, req, dom, gen) + dialogue.append(state) + data[dialogue_ids[0][dial_idx]] = dialogue + + return data + + def encode(self, dialogues: list, max_turns: int = 12, max_seq_len: int = 64) -> dict: + """ + Convert dialogue examples to model input features and labels + + Args: + dialogues (list): List of all extracted dialogues + max_turns (int): Maximum numbers of turns in a dialogue + max_seq_len (int): Maximum number of tokens in a dialogue turn + + Returns: + features (dict): All inputs and labels required to train the model + """ + features = dict() + + # Get encoder input for system, user utterance pairs + input_feats = [] + if len(dialogues) > 5: + iterator = tqdm(dialogues) + else: + iterator = dialogues + for dial in iterator: + dial_feats = [] + for turn in dial: + if len(turn['system_utterance']) == 0: + usr = turn['user_utterance'] + dial_feats.append(super().encode_plus(usr, add_special_tokens=True, max_length=max_seq_len, + padding='max_length', truncation='longest_first')) + else: + usr = turn['user_utterance'] + sys = turn['system_utterance'] + dial_feats.append(super().encode_plus(usr, sys, add_special_tokens=True, + max_length=max_seq_len, padding='max_length', + truncation='longest_first')) + # Truncate + if len(dial_feats) >= max_turns: + break + input_feats.append(dial_feats) + del dial_feats + + # Perform turn level padding + if 'dialogue_id' in dialogues[0][0]: + dial_ids = list() + for dial in dialogues: + _ids = [turn['dialogue_id'] for turn in dial][:max_turns] + _ids += [''] * (max_turns - len(_ids)) + dial_ids.append(_ids) + input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) + for dial in input_feats] + if 'token_type_ids' in input_feats[0][0]: + token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) + for dial in input_feats] + else: + token_type_ids = None + if 'attention_mask' in input_feats[0][0]: + attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) + for dial in input_feats] + else: + attention_mask = None + del input_feats + + # Create torch data tensors + if 'dialogue_id' in dialogues[0][0]: + features['dialogue_ids'] = IdTensor(dial_ids) + features['input_ids'] = torch.tensor(input_ids) + features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None + features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None + del input_ids, token_type_ids, attention_mask + + # Extract all informable and requestable slots from the ontology + informable_slots = [f"{domain}-{slot}" for domain in self.ontology for slot in self.ontology[domain] + if 
self.ontology[domain][slot]['possible_values'] + and self.ontology[domain][slot]['possible_values'] != ['?']] + requestable_slots = [f"{domain}-{slot}" for domain in self.ontology for slot in self.ontology[domain] + if '?' in self.ontology[domain][slot]['possible_values']] + + # Extract a list of domains from the ontology slots + domains = [domain for domain in self.ontology] + + # Create slot labels + if 'state' in dialogues[0][0]: + for domslot in tqdm(informable_slots): + labels = [] + for dial in dialogues: + labs = [] + for turn in dial: + value = [v for d, substate in turn['state'].items() for s, v in substate.items() + if f'{d}-{s}' == domslot] + domain, slot = domslot.split('-', 1) + if turn['dataset_name'] in self.ontology[domain][slot]['dataset_names']: + value = value[0] if value else 'none' + else: + value = -1 + if value in self.ontology[domain][slot]['possible_values'] and value != -1: + value = self.ontology[domain][slot]['possible_values'].index(value) + else: + value = -1 # If value is not in ontology then we do not penalise the model + labs.append(value) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['state_labels-' + domslot] = labels + + # Create requestable slot labels + if 'dialogue_acts' in dialogues[0][0]: + for domslot in tqdm(requestable_slots): + labels = [] + for dial in dialogues: + labs = [] + for turn in dial: + domain, slot = domslot.split('-', 1) + if turn['dataset_name'] in self.ontology[domain][slot]['dataset_names']: + acts = [act['intent'] for act in turn['dialogue_acts'] + if act['domain'] == domain and act['slot'] == slot] + if acts: + act_ = acts[0] + if act_ == 'request': + labs.append(1) + else: + labs.append(0) + else: + labs.append(0) + else: + labs.append(-1) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['request_labels-' + domslot] = labels + + # General act labels (1-goodbye, 2-thank you) + labels = [] + for dial in tqdm(dialogues): + labs = [] + for turn in dial: + acts = [act['intent'] for act in turn['dialogue_acts'] if act['intent'] in ['bye', 'thank']] + if acts: + if 'bye' in acts: + labs.append(1) + else: + labs.append(2) + else: + labs.append(0) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['general_act_labels'] = labels + + # Create active domain labels + if 'active_domains' in dialogues[0][0]: + for domain in tqdm(domains): + labels = [] + for dial in dialogues: + labs = [] + for turn in dial: + possible_domains = list() + for dom in self.ontology: + for slt in self.ontology[dom]: + if turn['dataset_name'] in self.ontology[dom][slt]['dataset_names']: + possible_domains.append(dom) + + if domain in turn['active_domains']: + labs.append(1) + elif domain in possible_domains: + labs.append(0) + else: + labs.append(-1) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['active_domain_labels-' + domain] = labels + + try: + del labels + except: + labels = None + + return features + + return SetSUMBTTokenizer diff --git a/convlab/dst/setsumbt/modeling/trainer.py b/convlab/dst/setsumbt/modeling/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..50f0ac9d2e6acbe9e8ffa6634f47d395cece9c0d --- /dev/null +++ 
b/convlab/dst/setsumbt/modeling/trainer.py @@ -0,0 +1,681 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SetSUMBT Trainer Class""" + +import random +import os +from copy import deepcopy +import pdb + +import torch +from torch.nn import DataParallel +import numpy as np +from transformers import get_linear_schedule_with_warmup +from torch.optim import AdamW +from tqdm import tqdm, trange +try: + from apex import amp +except ModuleNotFoundError: + print('Apex not used') + +from convlab.dst.setsumbt.utils import clear_checkpoints +from convlab.dst.setsumbt.datasets import JointGoalAccuracy, BeliefStateUncertainty, ActPredictionAccuracy, Metrics +from convlab.dst.setsumbt.modeling import LinearTemperatureScheduler +from convlab.dst.setsumbt.utils import EnsembleAggregator + + +class SetSUMBTTrainer: + """Trainer class for SetSUMBT Model""" + + def __init__(self, + args, + model, + tokenizer, + train_dataloader, + validation_dataloader, + logger, + tb_writer, + device='cpu'): + """ + Initialise the trainer class. + + Args: + args (argparse.Namespace): Arguments passed to the script + model (torch.nn.Module): SetSUMBT to be trained/evaluated + tokenizer (transformers.PreTrainedTokenizer): Tokenizer used to encode the data + train_dataloader (torch.utils.data.DataLoader): Dataloader for training data + validation_dataloader (torch.utils.data.DataLoader): Dataloader for validation data + logger (logging.Logger): Logger to log training progress + tb_writer (tensorboardX.SummaryWriter): Tensorboard writer to log training progress + device (str): Device to use for training + """ + self.args = args + self.model = model + self.tokenizer = tokenizer + self.train_dataloader = train_dataloader + self.validation_dataloader = validation_dataloader + self.logger = logger + self.tb_writer = tb_writer + self.device = device + + # Initialise metrics + if self.validation_dataloader is not None: + self.joint_goal_accuracy = JointGoalAccuracy(self.args.dataset, validation_dataloader.dataset.set_type) + self.belief_state_uncertainty_metrics = BeliefStateUncertainty() + self.ensemble_aggregator = EnsembleAggregator() + if self.args.predict_actions: + self.request_accuracy = ActPredictionAccuracy('request', binary=True) + self.active_domain_accuracy = ActPredictionAccuracy('active_domain', binary=True) + self.general_act_accuracy = ActPredictionAccuracy('general_act', binary=False) + + self._set_seed() + + if train_dataloader is not None: + self.training_mode(load_slots=True) + self._configure_optimiser() + self._configure_schedulers() + + # Set up fp16 and multi gpu usage + if self.args.fp16: + self.model, self.optimizer = amp.initialize(self.model, self.optimizer, + opt_level=self.args.fp16_opt_level) + if self.args.n_gpu > 1: + self.model = DataParallel(self.model) + + # Initialise training parameters + self.best_model = Metrics(joint_goal_accuracy=0.0, + 
training_loss=np.inf) + self.global_step = 0 + self.epochs_trained = 0 + self.steps_trained_in_current_epoch = 0 + + logger.info(f"Device: {device}, Number of GPUs: {args.n_gpu}, FP16 training: {args.fp16}") + + def _configure_optimiser(self): + """Configure the optimiser for training.""" + assert self.train_dataloader is not None + # Group weight decay and no decay parameters in the model + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay) + and 'value_embeddings' not in n], + "weight_decay": self.args.weight_decay, + "lr": self.args.learning_rate + }, + { + "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay) + and 'value_embeddings' not in n], + "weight_decay": 0.0, + "lr": self.args.learning_rate + }, + ] + + # Initialise the optimizer + self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate) + + def _configure_schedulers(self): + """Configure the learning rate and temperature schedulers for training.""" + assert self.train_dataloader is not None + # Calculate the total number of training steps to be performed + if self.args.max_training_steps > 0: + self.args.num_train_epochs = (len(self.train_dataloader) // self.args.gradient_accumulation_steps) + 1 + self.args.num_train_epochs = self.args.max_training_steps // self.args.num_train_epochs + else: + self.args.max_training_steps = len(self.train_dataloader) // self.args.gradient_accumulation_steps + self.args.max_training_steps *= self.args.num_train_epochs + + if self.args.save_steps <= 0: + self.args.save_steps = len(self.train_dataloader) // self.args.gradient_accumulation_steps + + # Initialise linear lr scheduler + self.args.num_warmup_steps = int(self.args.max_training_steps * self.args.warmup_proportion) + self.lr_scheduler = get_linear_schedule_with_warmup(self.optimizer, + num_warmup_steps=self.args.num_warmup_steps, + num_training_steps=self.args.max_training_steps) + + # Initialise distillation temp scheduler + if self.model.config.loss_function in ['distillation']: + self.temp_scheduler = LinearTemperatureScheduler(total_steps=self.args.max_training_steps, + base_temp=self.args.annealing_base_temp, + cycle_len=self.args.annealing_cycle_len) + else: + self.temp_scheduler = None + + def _set_seed(self): + """Set the seed for reproducibility.""" + random.seed(self.args.seed) + np.random.seed(self.args.seed) + torch.manual_seed(self.args.seed) + if self.args.n_gpu > 0: + torch.cuda.manual_seed_all(self.args.seed) + self.logger.info('Seed set to %d.' % self.args.seed) + + @staticmethod + def _set_ontology_embeddings(model, slots, load_slots=True): + """ + Set the ontology embeddings for the model. + + Args: + model (torch.nn.Module): Model to set the ontology embeddings for. + slots (dict): Dictionary of slot names and their corresponding information. + load_slots (bool): Whether to load/reload the slot embeddings. 
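As a toy illustration of the slot format `_set_ontology_embeddings` expects, each `slots` entry carries the value candidate embeddings in position 1 (`slots[slot][1]`). Embedding sizes and the slot name below are made up, and `model` stands for an already loaded SetSUMBT model:

    import torch

    slot_emb = torch.randn(768)        # slot description embedding (size illustrative)
    value_embs = torch.randn(3, 768)   # one row per candidate value
    slots = {'hotel-area': (slot_emb, value_embs)}
    SetSUMBTTrainer._set_ontology_embeddings(model, slots, load_slots=True)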
+ """ + # Get slot and value embeddings + values = {slot: slots[slot][1] for slot in slots} + + # Load model ontology + if load_slots: + slots = {slot: embs for slot, embs in slots.items()} + model.add_slot_candidates(slots) + try: + informable_slot_ids = model.setsumbt.config.informable_slot_ids + except AttributeError: + informable_slot_ids = model.config.informable_slot_ids + for slot in informable_slot_ids: + model.add_value_candidates(slot, values[slot], replace=True) + + def set_ontology_embeddings(self, slots, load_slots=True): + """ + Set the ontology embeddings for the model. + + Args: + slots (dict): Dictionary of slot names and their corresponding information. + load_slots (bool): Whether to load/reload the slot embeddings. + """ + self._set_ontology_embeddings(self.model, slots, load_slots=load_slots) + + def load_state(self): + """Load the model, optimiser and schedulers state from a previous run.""" + if os.path.isfile(os.path.join(self.args.model_name_or_path, 'optimizer.pt')): + self.logger.info("Optimizer loaded from previous run.") + self.optimizer.load_state_dict(torch.load(os.path.join(self.args.model_name_or_path, 'optimizer.pt'))) + self.lr_scheduler.load_state_dict(torch.load(os.path.join(self.args.model_name_or_path, 'lr_scheduler.pt'))) + if self.temp_scheduler is not None: + self.temp_scheduler.load_state_dict(torch.load(os.path.join(self.args.model_name_or_path, + 'temp_scheduler.pt'))) + if self.args.fp16 and os.path.isfile(os.path.join(self.args.model_name_or_path, 'amp.pt')): + self.logger.info("FP16 Apex Amp loaded from previous run.") + amp.load_state_dict(torch.load(os.path.join(self.args.model_name_or_path, 'amp.pt'))) + + # Evaluate initialised model + if self.args.do_eval: + self.eval_mode() + metrics = self.evaluate(is_train=True) + self.training_mode() + + best_model = metrics + best_model.training_loss = np.inf + + def save_state(self): + """Save the model, optimiser and schedulers state for future runs.""" + output_dir = os.path.join(self.args.output_dir, f"checkpoint-{self.global_step}") + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + self.tokenizer.save_pretrained(output_dir) + if self.args.n_gpu > 1: + self.model.module.save_pretrained(output_dir) + else: + self.model.save_pretrained(output_dir) + + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "lr_scheduler.pt")) + if self.temp_scheduler is not None: + torch.save(self.temp_scheduler.state_dict(), os.path.join(output_dir, 'temp_scheduler.pt')) + if self.args.fp16: + torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt")) + + # Remove older training checkpoints + clear_checkpoints(self.args.output_dir, self.args.keep_models) + + def training_mode(self, load_slots=False): + """ + Set the model and trainer to training mode. + + Args: + load_slots (bool): Whether to load/reload the slot embeddings. + """ + assert self.train_dataloader is not None + self.model.train() + self.tokenizer.set_setsumbt_ontology(self.train_dataloader.dataset.ontology) + self.model.zero_grad() + self.set_ontology_embeddings(self.train_dataloader.dataset.ontology_embeddings, load_slots=load_slots) + + def eval_mode(self, load_slots=False): + """ + Set the model and trainer to evaluation mode. + + Args: + load_slots (bool): Whether to load/reload the slot embeddings. 
+ """ + self.model.eval() + self.model.zero_grad() + self.tokenizer.set_setsumbt_ontology(self.validation_dataloader.dataset.ontology) + self.set_ontology_embeddings(self.validation_dataloader.dataset.ontology_embeddings, load_slots=load_slots) + + def log_info(self, metrics, logging_stage='update'): + """ + Log information about the training/evaluation. + + Args: + metrics (Metrics): Metrics object containing the relevant information. + logging_stage (str): The stage of the training/evaluation to log. + """ + if logging_stage == "update": + info = f"{self.global_step} steps complete, " + info += f"Loss since last update: {metrics.training_loss}." + self.logger.info(info) + self.logger.info("Validation set statistics:") + elif logging_stage == 'training_complete': + self.logger.info("Training Complete.") + self.logger.info("Validation set statistics:") + elif logging_stage == 'dev': + self.logger.info("Validation set statistics:") + self.logger.info(f"\tLoss: {metrics.validation_loss}") + elif logging_stage == 'test': + self.logger.info("Test set statistics:") + self.logger.info(f"\tLoss: {metrics.validation_loss}") + self.logger.info(f"\tJoint Goal Accuracy: {metrics.joint_goal_accuracy}") + self.logger.info(f"\tGoal Slot F1 Score: {metrics.slot_f1}") + self.logger.info(f"\tGoal Slot Precision: {metrics.slot_precision}") + self.logger.info(f"\tGoal Slot Recall: {metrics.slot_recall}") + self.logger.info(f"\tJoint Goal ECE: {metrics.joint_goal_ece}") + self.logger.info(f"\tJoint Goal L2-Error: {metrics.joint_l2_error}") + self.logger.info(f"\tJoint Goal L2-Error Ratio: {metrics.joint_l2_error_ratio}") + if 'request_f1' in metrics: + self.logger.info(f"\tRequest Action F1 Score: {metrics.request_f1}") + self.logger.info(f"\tActive Domain F1 Score: {metrics.active_domain_f1}") + self.logger.info(f"\tGeneral Action F1 Score: {metrics.general_act_f1}") + + # Log to tensorboard + if logging_stage == "update": + self.tb_writer.add_scalar('JointGoalAccuracy/Dev', metrics.joint_goal_accuracy, self.global_step) + self.tb_writer.add_scalar('SlotAccuracy/Dev', metrics.slot_accuracy, self.global_step) + self.tb_writer.add_scalar('SlotF1/Dev', metrics.slot_f1, self.global_step) + self.tb_writer.add_scalar('SlotPrecision/Dev', metrics.slot_precision, self.global_step) + self.tb_writer.add_scalar('JointGoalECE/Dev', metrics.joint_goal_ece, self.global_step) + self.tb_writer.add_scalar('JointGoalL2ErrorRatio/Dev', metrics.joint_l2_error_ratio, self.global_step) + if 'request_f1' in metrics: + self.tb_writer.add_scalar('RequestF1Score/Dev', metrics.request_f1, self.global_step) + self.tb_writer.add_scalar('ActiveDomainF1Score/Dev', metrics.active_domain_f1, self.global_step) + self.tb_writer.add_scalar('GeneralActionF1Score/Dev', metrics.general_act_f1, self.global_step) + self.tb_writer.add_scalar('Loss/Dev', metrics.validation_loss, self.global_step) + + if 'belief_state_summary' in metrics: + for slot, stats_slot in metrics.belief_state_summary.items(): + for key, item in stats_slot.items(): + self.tb_writer.add_scalar(f'{key}_{slot}/Dev', item, self.global_step) + + def get_input_dict(self, batch: dict) -> dict: + """ + Get the input dictionary for the model. + + Args: + batch (dict): The batch of data to be passed to the model. + + Returns: + input_dict (dict): The input dictionary for the model. 
+ """ + input_dict = dict() + + # Add the input ids, token type ids, and attention mask + input_dict['input_ids'] = batch['input_ids'].to(self.device) + input_dict['token_type_ids'] = batch['token_type_ids'].to(self.device) if 'token_type_ids' in batch else None + input_dict['attention_mask'] = batch['attention_mask'].to(self.device) if 'attention_mask' in batch else None + + # Add the labels + if any('belief_state' in key for key in batch): + input_dict['state_labels'] = {slot: batch['belief_state-' + slot].to(self.device) + for slot in self.model.setsumbt.config.informable_slot_ids + if ('belief_state-' + slot) in batch} + if self.args.predict_actions: + input_dict['request_labels'] = {slot: batch['request_probabilities-' + slot].to(self.device) + for slot in self.model.setsumbt.config.requestable_slot_ids + if ('request_probabilities-' + slot) in batch} + input_dict['active_domain_labels'] = {domain: batch['active_domain_probabilities-' + domain].to(self.device) + for domain in self.model.setsumbt.config.domain_ids + if ('active_domain_probabilities-' + domain) in batch} + input_dict['general_act_labels'] = batch['general_act_probabilities'].to(self.device) + else: + input_dict['state_labels'] = {slot: batch['state_labels-' + slot].to(self.device) + for slot in self.model.setsumbt.config.informable_slot_ids + if ('state_labels-' + slot) in batch} + if self.args.predict_actions: + input_dict['request_labels'] = {slot: batch['request_labels-' + slot].to(self.device) + for slot in self.model.setsumbt.config.requestable_slot_ids + if ('request_labels-' + slot) in batch} + input_dict['active_domain_labels'] = {domain: batch['active_domain_labels-' + domain].to(self.device) + for domain in self.model.setsumbt.config.domain_ids + if ('active_domain_labels-' + domain) in batch} + input_dict['general_act_labels'] = batch['general_act_labels'].to(self.device) + + return input_dict + + def train(self): + """Train the SetSUMBT model.""" + # Set the model to training mode + self.training_mode(load_slots=True) + self.load_state() + + # Log training set up + self.logger.info("***** Running training *****") + self.logger.info(f"\tNum Batches = {len(self.train_dataloader)}") + self.logger.info(f"\tNum Epochs = {self.args.num_train_epochs}") + self.logger.info(f"\tGradient Accumulation steps = {self.args.gradient_accumulation_steps}") + self.logger.info(f"\tTotal optimization steps = {self.args.max_training_steps}") + + # Check if continuing training from a checkpoint + if os.path.exists(self.args.model_name_or_path): + try: + # set global_step to gobal_step of last saved checkpoint from model path + checkpoint_suffix = self.args.model_name_or_path.split("-")[-1].split("/")[0] + self.global_step = int(checkpoint_suffix) + self.epochs_trained = len(self.train_dataloader) // self.args.gradient_accumulation_steps + self.steps_trained_in_current_epoch = self.global_step % self.epochs_trained + self.epochs_trained = self.global_step // self.epochs_trained + + self.logger.info("\tContinuing training from checkpoint, will skip to saved global_step") + self.logger.info(f"\tContinuing training from epoch {self.epochs_trained}") + self.logger.info(f"\tContinuing training from global step {self.global_step}") + self.logger.info(f"\tWill skip the first {self.steps_trained_in_current_epoch} steps in the first epoch") + except ValueError: + self.logger.info(f"\tStarting fine-tuning.") + + # Prepare iterator for training + tr_loss, logging_loss = 0.0, 0.0 + train_iterator = trange(self.epochs_trained, 
int(self.args.num_train_epochs), desc="Epoch") + + steps_since_last_update = 0 + # Perform training + for e in train_iterator: + epoch_iterator = tqdm(self.train_dataloader, desc="Iteration") + # Iterate over all batches + for step, batch in enumerate(epoch_iterator): + # Skip batches already trained on + if step < self.steps_trained_in_current_epoch: + continue + + # Extract all label dictionaries from the batch + input_dict = self.get_input_dict(batch) + + # Set up temperature scaling for the model + if self.temp_scheduler is not None: + self.model.setsumbt.temp = self.temp_scheduler.temp() + + # Forward pass to obtain loss + output = self.model(**input_dict) + + if self.args.n_gpu > 1: + output.loss = output.loss.mean() + + # Update step + if step % self.args.gradient_accumulation_steps == 0: + output.loss = output.loss / self.args.gradient_accumulation_steps + if self.temp_scheduler is not None: + self.tb_writer.add_scalar('Temp', self.temp_scheduler.temp(), self.global_step) + self.tb_writer.add_scalar('Loss/train', output.loss, self.global_step) + # Backpropogate accumulated loss + if self.args.fp16: + with amp.scale_loss(output.loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm) + self.tb_writer.add_scalar('Scaled_Loss/train', scaled_loss, self.global_step) + else: + output.loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) + + # Get learning rate + self.tb_writer.add_scalar('LearningRate', self.optimizer.param_groups[0]['lr'], self.global_step) + + if output.belief_state_summary: + for slot, stats_slot in output.belief_state_summary.items(): + for key, item in stats_slot.items(): + self.tb_writer.add_scalar(f'{key}_{slot}/Train', item, self.global_step) + + # Update model parameters + self.optimizer.step() + self.lr_scheduler.step() + self.model.zero_grad() + if self.temp_scheduler is not None: + self.temp_scheduler.step() + + tr_loss += output.loss.float().item() + epoch_iterator.set_postfix(loss=output.loss.float().item()) + self.global_step += 1 + + # Save model checkpoint + if self.global_step % self.args.save_steps == 0: + logging_loss = tr_loss - logging_loss + + # Evaluate model + if self.args.do_eval: + self.eval_mode() + metrics = self.evaluate(is_train=True) + metrics.training_loss = logging_loss / self.args.save_steps + # Log model eval information + self.log_info(metrics) + self.training_mode() + else: + metrics = Metrics(training_loss=logging_loss / self.args.save_steps, + joint_goal_accuracy=0.0) + self.log_info(metrics) + + logging_loss = tr_loss + + try: + # Compute the score of the best model + self.best_model.compute_score(request=self.model.config.user_request_loss_weight, + active_domain=self.model.config.active_domain_loss_weight, + general_act=self.model.config.user_general_act_loss_weight) + + # Compute the score of the current model + metrics.compute_score(request=self.model.config.user_request_loss_weight, + active_domain=self.model.config.active_domain_loss_weight, + general_act=self.model.config.user_general_act_loss_weight) + except AttributeError: + self.best_model.compute_score() + metrics.compute_score() + + metrics.training_loss = tr_loss / self.global_step + + if metrics > self.best_model: + steps_since_last_update = 0 + self.logger.info('Model saved.') + self.best_model = deepcopy(metrics) + + self.save_state() + else: + steps_since_last_update += 1 + self.logger.info('Model not saved.') + + # Stop 
training after max training steps or if the model has not updated for too long + if self.args.max_training_steps > 0 and self.global_step > self.args.max_training_steps: + epoch_iterator.close() + break + if self.args.patience > 0 and steps_since_last_update >= self.args.patience: + epoch_iterator.close() + break + + self.steps_trained_in_current_epoch = 0 + self.logger.info(f'Epoch {e + 1} complete, average training loss = {tr_loss / self.global_step}') + + if self.args.max_training_steps > 0 and self.global_step > self.args.max_training_steps: + train_iterator.close() + break + if self.args.patience > 0 and steps_since_last_update >= self.args.patience: + train_iterator.close() + self.logger.info(f'Model has not improved for at least {self.args.patience} steps. Training stopped!') + break + + # Evaluate final model + if self.args.do_eval: + self.eval_mode() + metrics = self.evaluate(is_train=True) + metrics.training_loss = tr_loss / self.global_step + self.log_info(metrics, logging_stage='training_complete') + else: + self.logger.info('Training complete!') + + # Store final model + try: + self.best_model.compute_score(request=self.model.config.user_request_loss_weight, + active_domain=self.model.config.active_domain_loss_weight, + general_act=self.model.config.user_general_act_loss_weight) + + metrics.compute_score(request=self.model.config.user_request_loss_weight, + active_domain=self.model.config.active_domain_loss_weight, + general_act=self.model.config.user_general_act_loss_weight) + except AttributeError: + self.best_model.compute_score() + metrics.compute_score() + + metrics.training_loss = tr_loss / self.global_step + + if metrics > self.best_model: + self.logger.info('Final model saved.') + self.save_state() + else: + self.logger.info('Final model not saved, as it is not the best performing model.') + + def evaluate(self, save_eval_path=None, is_train=False, save_pred_dist_path=None, draw_calibration_diagram=False): + """ + Evaluates the model on the validation set. + + Args: + save_eval_path (str): Path to save the evaluation results. + is_train (bool): Whether the evaluation is performed during training. + save_pred_dist_path (str): Path to save the predicted distribution. + draw_calibration_diagram (bool): Whether to draw the calibration diagram. + Returns: + Metrics: The evaluation metrics. 
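A typical post-training call, with illustrative paths; note that `evaluate` switches the trainer to evaluation mode itself, so no separate `eval_mode` call is needed:

    metrics = trainer.evaluate(save_eval_path='predictions/test.json',
                               draw_calibration_diagram=True)
    print(metrics.joint_goal_accuracy, metrics.joint_goal_ece)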
+ """ + save_eval_path = None if is_train else save_eval_path + save_pred_dist_path = None if is_train else save_pred_dist_path + draw_calibration_diagram = False if is_train else draw_calibration_diagram + if not is_train: + self.logger.info("***** Running evaluation *****") + self.logger.info(" Num Batches = %d", len(self.validation_dataloader)) + + eval_loss = 0.0 + belief_state_summary = dict() + self.joint_goal_accuracy._init_session() + self.belief_state_uncertainty_metrics._init_session() + self.eval_mode(load_slots=True) + + if not is_train: + epoch_iterator = tqdm(self.validation_dataloader, desc="Iteration") + else: + epoch_iterator = self.validation_dataloader + for batch in epoch_iterator: + with torch.no_grad(): + input_dict = self.get_input_dict(batch) + if not is_train and 'distillation' in self.model.config.loss_function: + input_dict = {key: input_dict[key] for key in ['input_ids', 'attention_mask', 'token_type_ids']} + if self.args.ensemble and save_pred_dist_path is not None: + input_dict['reduction'] = 'none' + output = self.model(**input_dict) + output.loss = output.loss if output.loss is not None else 0.0 + + eval_loss += output.loss + + if self.args.ensemble and save_pred_dist_path is not None: + self.ensemble_aggregator.add_batch(input_dict, output, batch['dialogue_ids']) + output.belief_state = {slot: probs.mean(-2) for slot, probs in output.belief_state.items()} + if self.args.predict_actions: + output.request_probabilities = {slot: probs.mean(-1) + for slot, probs in output.request_probabilities.items()} + output.active_domain_probabilities = {domain: probs.mean(-1) + for domain, probs in output.active_domain_probabilities.items()} + output.general_act_probabilities = output.general_act_probabilities.mean(-2) + + # Accumulate belief state summary across batches + if output.belief_state_summary is not None: + for slot, slot_summary in output.belief_state_summary.items(): + if slot not in belief_state_summary: + belief_state_summary[slot] = dict() + for key, item in slot_summary.items(): + if key not in belief_state_summary[slot]: + belief_state_summary[slot][key] = item + else: + if 'min' in key: + belief_state_summary[slot][key] = min(belief_state_summary[slot][key], item) + elif 'max' in key: + belief_state_summary[slot][key] = max(belief_state_summary[slot][key], item) + elif 'mean' in key: + belief_state_summary[slot][key] = (belief_state_summary[slot][key] + item) / 2 + + slot_0 = [slot for slot in input_dict['state_labels'].keys()] if 'state_labels' in input_dict else list() + slot_0 = slot_0[0] if slot_0 else None + if slot_0 is not None: + pad_dials, pad_turns = torch.where(input_dict['input_ids'][:, :, 0] == -1) + if len(input_dict['state_labels'][slot_0].size()) == 4: + for slot in input_dict['state_labels']: + input_dict['state_labels'][slot] = input_dict['state_labels'][slot].mean(-2).argmax(-1) + input_dict['state_labels'][slot][pad_dials, pad_turns] = -1 + if self.args.predict_actions: + for slot in input_dict['request_labels']: + input_dict['request_labels'][slot] = input_dict['request_labels'][slot].mean(-1).round().int() + input_dict['request_labels'][slot][pad_dials, pad_turns] = -1 + for domain in input_dict['active_domain_labels']: + input_dict['active_domain_labels'][domain] = input_dict['active_domain_labels'][domain].mean(-1).round().int() + input_dict['active_domain_labels'][domain][pad_dials, pad_turns] = -1 + input_dict['general_act_labels'] = input_dict['general_act_labels'].mean(-2).argmax(-1) + 
input_dict['general_act_labels'][pad_dials, pad_turns] = -1 + else: + input_dict = self.get_input_dict(batch) + + # Add batch to metrics + self.belief_state_uncertainty_metrics.add_dialogues(output.belief_state, input_dict['state_labels']) + + predicted_states = self.tokenizer.decode_state_batch(output.belief_state, + output.request_probabilities, + output.active_domain_probabilities, + output.general_act_probabilities, + batch['dialogue_ids']) + + self.joint_goal_accuracy.add_dialogues(predicted_states) + + if self.args.predict_actions: + self.request_accuracy.add_dialogues(output.request_probabilities, input_dict['request_labels']) + self.active_domain_accuracy.add_dialogues(output.active_domain_probabilities, + input_dict['active_domain_labels']) + self.general_act_accuracy.add_dialogues({'gen': output.general_act_probabilities}, + {'gen': input_dict['general_act_labels']}) + + # Compute metrics + metrics = self.joint_goal_accuracy.evaluate() + metrics += self.belief_state_uncertainty_metrics.evaluate() + if self.args.predict_actions: + metrics += self.request_accuracy.evaluate() + metrics += self.active_domain_accuracy.evaluate() + metrics += self.general_act_accuracy.evaluate() + metrics.validation_loss = eval_loss + if belief_state_summary: + metrics.belief_state_summary = belief_state_summary + + # Save model predictions + if save_eval_path is not None: + self.joint_goal_accuracy.save_dialogues(save_eval_path) + if save_pred_dist_path is not None: + self.ensemble_aggregator.save(save_pred_dist_path) + if draw_calibration_diagram: + self.belief_state_uncertainty_metrics.draw_calibration_diagram( + save_path=self.args.output_dir, + validation_split=self.joint_goal_accuracy.validation_split + ) + + return metrics diff --git a/convlab/dst/setsumbt/modeling/training.py b/convlab/dst/setsumbt/modeling/training.py deleted file mode 100644 index 590b2ac7372b26262625d08691a8528ffddd82d2..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/modeling/training.py +++ /dev/null @@ -1,715 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
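The joint goal ECE reported by the uncertainty metrics above is an expected calibration error over binned confidences. A minimal stand-alone sketch of the quantity, for 1-D tensors of confidences and 0/1 correctness indicators (not the BeliefStateUncertainty implementation):

    import torch

    def ece(confidence, correct, n_bins=10):
        """Expected calibration error with equal-width confidence bins."""
        edges = torch.linspace(0.0, 1.0, n_bins + 1)
        error = torch.tensor(0.0)
        for lower, upper in zip(edges[:-1], edges[1:]):
            in_bin = (confidence > lower) & (confidence <= upper)
            if in_bin.any():
                # Weight each bin's |confidence - accuracy| gap by its relative size
                gap = (confidence[in_bin].mean() - correct[in_bin].float().mean()).abs()
                error += in_bin.float().mean() * gap
        return error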
-"""Training and evaluation utils""" - -import random -import os -import logging -from copy import deepcopy - -import torch -from torch.nn import DataParallel -from torch.distributions import Categorical -import numpy as np -from transformers import get_linear_schedule_with_warmup -from torch.optim import AdamW -from tqdm import tqdm, trange -try: - from apex import amp -except: - print('Apex not used') - -from convlab.dst.setsumbt.utils import clear_checkpoints -from convlab.dst.setsumbt.modeling import LinearTemperatureScheduler - - -# Load logger and tensorboard summary writer -def set_logger(logger_, tb_writer_): - global logger, tb_writer - logger = logger_ - tb_writer = tb_writer_ - - -# Set seeds -def set_seed(args): - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - if args.n_gpu > 0: - torch.cuda.manual_seed_all(args.seed) - logger.info('Seed set to %d.' % args.seed) - - -def set_ontology_embeddings(model, slots, load_slots=True): - # Get slot and value embeddings - values = {slot: slots[slot][1] for slot in slots} - - # Load model ontology - if load_slots: - slots = {slot: embs for slot, embs in slots.items()} - model.add_slot_candidates(slots) - try: - informable_slot_ids = model.setsumbt.informable_slot_ids - except: - informable_slot_ids = model.informable_slot_ids - for slot in informable_slot_ids: - model.add_value_candidates(slot, values[slot], replace=True) - - -def log_info(global_step, loss, jg_acc=None, sl_acc=None, req_f1=None, dom_f1=None, gen_f1=None, stats=None): - """ - Log training statistics. - - Args: - global_step: Number of global training steps completed - loss: Training loss - jg_acc: Joint goal accuracy - sl_acc: Slot accuracy - req_f1: Request prediction F1 score - dom_f1: Active domain prediction F1 score - gen_f1: General action prediction F1 score - stats: Uncertainty measure statistics of model - """ - if type(global_step) == int: - info = f"{global_step} steps complete, " - info += f"Loss since last update: {loss}. 
Validation set stats: " - elif global_step == 'training_complete': - info = f"Training Complete, " - info += f"Validation set stats: " - elif global_step == 'dev': - info = f"Validation set stats: Loss: {loss}, " - elif global_step == 'test': - info = f"Test set stats: Loss: {loss}, " - info += f"Joint Goal Acc: {jg_acc}, Slot Acc: {sl_acc}, " - if req_f1 is not None: - info += f"Request F1 Score: {req_f1}, Active Domain F1 Score: {dom_f1}, " - info += f"General Action F1 Score: {gen_f1}" - logger.info(info) - - if type(global_step) == int: - tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step) - tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step) - if req_f1 is not None: - tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step) - tb_writer.add_scalar('ActiveDomainF1Score/Dev', dom_f1, global_step) - tb_writer.add_scalar('GeneralActionF1Score/Dev', gen_f1, global_step) - tb_writer.add_scalar('Loss/Dev', loss, global_step) - - if stats: - for slot, stats_slot in stats.items(): - for key, item in stats_slot.items(): - tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step) - - -def get_input_dict(batch: dict, - predict_actions: bool, - model_informable_slot_ids: list, - model_requestable_slot_ids: list = None, - model_domain_ids: list = None, - device = 'cpu') -> dict: - """ - Produce model input arguments - - Args: - batch: Batch of data from the dataloader - predict_actions: Model should predict user actions if set true - model_informable_slot_ids: List of model dialogue state slots - model_requestable_slot_ids: List of model requestable slots - model_domain_ids: List of model domains - device: Current torch device in use - - Returns: - input_dict: Dictrionary containing model inputs for the batch - """ - input_dict = dict() - - input_dict['input_ids'] = batch['input_ids'].to(device) - input_dict['token_type_ids'] = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None - input_dict['attention_mask'] = batch['attention_mask'].to(device) if 'attention_mask' in batch else None - - if any('belief_state' in key for key in batch): - input_dict['state_labels'] = {slot: batch['belief_state-' + slot].to(device) - for slot in model_informable_slot_ids - if ('belief_state-' + slot) in batch} - if predict_actions: - input_dict['request_labels'] = {slot: batch['request_probs-' + slot].to(device) - for slot in model_requestable_slot_ids - if ('request_probs-' + slot) in batch} - input_dict['active_domain_labels'] = {domain: batch['active_domain_probs-' + domain].to(device) - for domain in model_domain_ids - if ('active_domain_probs-' + domain) in batch} - input_dict['general_act_labels'] = batch['general_act_probs'].to(device) - else: - input_dict['state_labels'] = {slot: batch['state_labels-' + slot].to(device) - for slot in model_informable_slot_ids if ('state_labels-' + slot) in batch} - if predict_actions: - input_dict['request_labels'] = {slot: batch['request_labels-' + slot].to(device) - for slot in model_requestable_slot_ids - if ('request_labels-' + slot) in batch} - input_dict['active_domain_labels'] = {domain: batch['active_domain_labels-' + domain].to(device) - for domain in model_domain_ids - if ('active_domain_labels-' + domain) in batch} - input_dict['general_act_labels'] = batch['general_act_labels'].to(device) - - return input_dict - - -def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, slots_dev: dict): - """ - Train the SetSUMBT model. 
- - Args: - args: Runtime arguments - model: SetSUMBT Model instance to train - device: Torch device to use during training - train_dataloader: Dataloader containing the training data - dev_dataloader: Dataloader containing the validation set data - slots: Model ontology used for training - slots_dev: Model ontology used for evaluating on the validation set - """ - - # Calculate the total number of training steps to be performed - if args.max_training_steps > 0: - t_total = args.max_training_steps - args.num_train_epochs = (len(train_dataloader) // args.gradient_accumulation_steps) + 1 - args.num_train_epochs = args.max_training_steps // args.num_train_epochs - else: - t_total = (len(train_dataloader) // args.gradient_accumulation_steps) * args.num_train_epochs - - if args.save_steps <= 0: - args.save_steps = len(train_dataloader) // args.gradient_accumulation_steps - - # Group weight decay and no decay parameters in the model - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": args.weight_decay, - "lr": args.learning_rate - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - "lr": args.learning_rate - }, - ] - - # Initialise the optimizer - optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) - - # Initialise linear lr scheduler - num_warmup_steps = int(t_total * args.warmup_proportion) - scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, - num_training_steps=t_total) - - # Initialise distillation temp scheduler - if model.config.loss_function in ['distillation']: - temp_scheduler = TemperatureScheduler(total_steps=t_total, base_temp=args.annealing_base_temp, - cycle_len=args.annealing_cycle_len) - else: - temp_scheduler = None - - # Set up fp16 and multi gpu usage - if args.fp16: - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - if args.n_gpu > 1: - model = DataParallel(model) - - # Load optimizer checkpoint if available - best_model = {'joint goal accuracy': 0.0, - 'request f1 score': 0.0, - 'active domain f1 score': 0.0, - 'general act f1 score': 0.0, - 'train loss': np.inf} - if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')): - logger.info("Optimizer loaded from previous run.") - optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt'))) - scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt'))) - if temp_scheduler is not None: - temp_scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'temp_scheduler.pt'))) - if args.fp16 and os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')): - logger.info("FP16 Apex Amp loaded from previous run.") - amp.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'amp.pt'))) - - # Evaluate initialised model - if args.do_eval: - # Set up model for evaluation - model.eval() - set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False) - - jg_acc, sl_acc, req_f1, dom_f1, gen_f1, _, _ = evaluate(args, model, device, dev_dataloader, is_train=True) - - # Set model back to training mode - model.train() - model.zero_grad() - set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False) - else: - jg_acc, req_f1, dom_f1, gen_f1 = 0.0, 0.0, 0.0, 0.0 - - 
best_model['joint goal accuracy'] = jg_acc - best_model['request f1 score'] = req_f1 - best_model['active domain f1 score'] = dom_f1 - best_model['general act f1 score'] = gen_f1 - - # Log training set up - logger.info(f"Device: {device}, Number of GPUs: {args.n_gpu}, FP16 training: {args.fp16}") - logger.info("***** Running training *****") - logger.info(f" Num Batches = {len(train_dataloader)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {t_total}") - - # Initialise training parameters - global_step = 0 - epochs_trained = 0 - steps_trained_in_current_epoch = 0 - - # Check if continuing training from a checkpoint - if os.path.exists(args.model_name_or_path): - try: - # set global_step to gobal_step of last saved checkpoint from model path - checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0] - global_step = int(checkpoint_suffix) - epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) - steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) - - logger.info(" Continuing training from checkpoint, will skip to saved global_step") - logger.info(f" Continuing training from epoch {epochs_trained}") - logger.info(f" Continuing training from global step {global_step}") - logger.info(f" Will skip the first {steps_trained_in_current_epoch} steps in the first epoch") - except ValueError: - logger.info(f" Starting fine-tuning.") - - # Prepare model for training - tr_loss, logging_loss = 0.0, 0.0 - model.train() - model.zero_grad() - train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch") - - steps_since_last_update = 0 - # Perform training - for e in train_iterator: - epoch_iterator = tqdm(train_dataloader, desc="Iteration") - # Iterate over all batches - for step, batch in enumerate(epoch_iterator): - # Skip batches already trained on - if step < steps_trained_in_current_epoch: - continue - - # Extract all label dictionaries from the batch - input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids, - model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device) - - # Set up temperature scaling for the model - if temp_scheduler is not None: - model.setsumbt.temp = temp_scheduler.temp() - - # Forward pass to obtain loss - loss, _, _, _, _, _, stats = model(**input_dict) - - if args.n_gpu > 1: - loss = loss.mean() - - # Update step - if step % args.gradient_accumulation_steps == 0: - loss = loss / args.gradient_accumulation_steps - if temp_scheduler is not None: - tb_writer.add_scalar('Temp', temp_scheduler.temp(), global_step) - tb_writer.add_scalar('Loss/train', loss, global_step) - # Backpropogate accumulated loss - if args.fp16: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) - tb_writer.add_scalar('Scaled_Loss/train', scaled_loss, global_step) - else: - loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) - - # Get learning rate - lr = optimizer.param_groups[0]['lr'] - tb_writer.add_scalar('LearningRate', lr, global_step) - - if stats: - for slot, stats_slot in stats.items(): - for key, item in stats_slot.items(): - tb_writer.add_scalar(f'{key}_{slot}/Train', item, global_step) - - # Update model parameters - optimizer.step() - 
scheduler.step() - model.zero_grad() - - if temp_scheduler is not None: - temp_scheduler.step() - - tr_loss += loss.float().item() - epoch_iterator.set_postfix(loss=loss.float().item()) - global_step += 1 - - # Save model checkpoint - if global_step % args.save_steps == 0: - logging_loss = tr_loss - logging_loss - - # Evaluate model - if args.do_eval: - # Set up model for evaluation - model.eval() - set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False) - - jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader, - is_train=True) - # Log model eval information - log_info(global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, gen_f1, stats) - - # Set model back to training mode - model.train() - model.zero_grad() - set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False) - else: - log_info(global_step, logging_loss / args.save_steps) - - logging_loss = tr_loss - - # Compute the score of the best model - try: - best_score = best_model['request f1 score'] * model.config.user_request_loss_weight - best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight - best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight - except AttributeError: - best_score = 0.0 - best_score += best_model['joint goal accuracy'] - - # Compute the score of the current model - try: - current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0 - current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0 - current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0 - except AttributeError: - current_score = 0.0 - current_score += jg_acc - - # Decide whether to update the model - if best_model['joint goal accuracy'] < jg_acc and jg_acc > 0.0: - update = True - elif current_score > best_score and current_score > 0.0: - update = True - elif best_model['train loss'] > (tr_loss / global_step) and best_model['joint goal accuracy'] == 0.0: - update = True - else: - update = False - - if update: - steps_since_last_update = 0 - logger.info('Model saved.') - best_model['joint goal accuracy'] = jg_acc - if req_f1: - best_model['request f1 score'] = req_f1 - best_model['active domain f1 score'] = dom_f1 - best_model['general act f1 score'] = gen_f1 - best_model['train loss'] = tr_loss / global_step - - output_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}") - if not os.path.exists(output_dir): - os.makedirs(output_dir, exist_ok=True) - - if args.n_gpu > 1: - model.module.save_pretrained(output_dir) - else: - model.save_pretrained(output_dir) - - torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) - torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) - if temp_scheduler is not None: - torch.save(temp_scheduler.state_dict(), os.path.join(output_dir, 'temp_scheduler.pt')) - if args.fp16: - torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt")) - - # Remove older training checkpoints - clear_checkpoints(args.output_dir, args.keep_models) - else: - steps_since_last_update += 1 - logger.info('Model not saved.') - - # Stop training after max training steps or if the model has not updated for too long - if args.max_training_steps > 0 and global_step > args.max_training_steps: - epoch_iterator.close() - break - if args.patience > 0 and 
steps_since_last_update >= args.patience: - epoch_iterator.close() - break - - steps_trained_in_current_epoch = 0 - logger.info(f'Epoch {e + 1} complete, average training loss = {tr_loss / global_step}') - - if args.max_training_steps > 0 and global_step > args.max_training_steps: - train_iterator.close() - break - if args.patience > 0 and steps_since_last_update >= args.patience: - train_iterator.close() - logger.info(f'Model has not improved for at least {args.patience} steps. Training stopped!') - break - - # Evaluate final model - if args.do_eval: - model.eval() - set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False) - - jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader, - is_train=True) - - log_info('training_complete', tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, gen_f1) - else: - logger.info('Training complete!') - - # Store final model - try: - best_score = best_model['request f1 score'] * model.config.user_request_loss_weight - best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight - best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight - except AttributeError: - best_score = 0.0 - best_score += best_model['joint goal accuracy'] - try: - current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0 - current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0 - current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0 - except AttributeError: - current_score = 0.0 - current_score += jg_acc - if best_model['joint goal accuracy'] < jg_acc and jg_acc > 0.0: - update = True - elif current_score > best_score and current_score > 0.0: - update = True - elif best_model['train loss'] > (tr_loss / global_step) and best_model['joint goal accuracy'] == 0.0: - update = True - else: - update = False - - if update: - logger.info('Final model saved.') - output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - if args.n_gpu > 1: - model.module.save_pretrained(output_dir) - else: - model.save_pretrained(output_dir) - - torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) - torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) - if temp_scheduler is not None: - torch.save(temp_scheduler.state_dict(), os.path.join(output_dir, 'temp_scheduler.pt')) - if args.fp16: - torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt")) - clear_checkpoints(args.output_dir) - else: - logger.info('Final model not saved, as it is not the best performing model.') - - -def evaluate(args, model, device, dataloader, return_eval_output=False, is_train=False): - """ - Evaluate model - - Args: - args: Runtime arguments - model: SetSUMBT model instance - device: Torch device in use - dataloader: Dataloader of data to evaluate on - return_eval_output: If true return predicted and true states for all dialogues evaluated in semantic format - is_train: If true model is training and no logging is performed - - Returns: - out: Evaluated model statistics - """ - return_eval_output = False if is_train else return_eval_output - if not is_train: - logger.info("***** Running evaluation *****") - logger.info(" Num Batches = %d", len(dataloader)) - - tr_loss = 0.0 - model.eval() - if return_eval_output: - ontology = 
dataloader.dataset.ontology - - accuracy_jg = [] - accuracy_sl = [] - truepos_req, falsepos_req, falseneg_req = [], [], [] - truepos_dom, falsepos_dom, falseneg_dom = [], [], [] - truepos_gen, falsepos_gen, falseneg_gen = [], [], [] - turns = [] - if return_eval_output: - evaluation_output = [] - epoch_iterator = tqdm(dataloader, desc="Iteration") if not is_train else dataloader - for batch in epoch_iterator: - with torch.no_grad(): - input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids, - model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device) - - loss, p, p_req, p_dom, p_gen, _, stats = model(**input_dict) - - jg_acc = 0.0 - num_inform_slots = 0.0 - req_acc = 0.0 - req_tp, req_fp, req_fn = 0.0, 0.0, 0.0 - dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0 - dom_acc = 0.0 - - if return_eval_output: - eval_output_batch = [] - for dial_id, dial in enumerate(input_dict['input_ids']): - for turn_id, turn in enumerate(dial): - if turn.sum() != 0: - eval_output_batch.append({'dial_idx': dial_id, - 'utt_idx': turn_id, - 'state': dict(), - 'predictions': {'state': dict()} - }) - - for slot in model.setsumbt.informable_slot_ids: - p_ = p[slot] - state_labels = batch['state_labels-' + slot].to(device) - - if return_eval_output: - prediction = p_.argmax(-1) - - for sample in eval_output_batch: - dom, slt = slot.split('-', 1) - lab = state_labels[sample['dial_idx']][sample['utt_idx']].item() - lab = ontology[dom][slt]['possible_values'][lab] if lab != -1 else 'NOT_IN_ONTOLOGY' - pred = prediction[sample['dial_idx']][sample['utt_idx']].item() - pred = ontology[dom][slt]['possible_values'][pred] - - if dom not in sample['state']: - sample['state'][dom] = dict() - sample['predictions']['state'][dom] = dict() - - sample['state'][dom][slt] = lab if lab != 'none' else '' - sample['predictions']['state'][dom][slt] = pred if pred != 'none' else '' - - if args.temp_scaling > 0.0: - p_ = torch.log(p_ + 1e-10) / args.temp_scaling - p_ = torch.softmax(p_, -1) - else: - p_ = torch.log(p_ + 1e-10) / 1.0 - p_ = torch.softmax(p_, -1) - - acc = (p_.argmax(-1) == state_labels).reshape(-1).float() - - jg_acc += acc - num_inform_slots += (state_labels != -1).float().reshape(-1) - - if return_eval_output: - for sample in eval_output_batch: - sample['dial_idx'] = batch['dialogue_ids'][sample['utt_idx']][sample['dial_idx']] - evaluation_output.append(deepcopy(sample)) - eval_output_batch = [] - - if model.config.predict_actions: - for slot in model.setsumbt.requestable_slot_ids: - p_req_ = p_req[slot] - request_labels = batch['request_labels-' + slot].to(device) - - acc = (p_req_.round().int() == request_labels).reshape(-1).float() - tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float() - fp = (p_req_.round().int() * (request_labels == 0)).reshape(-1).float() - fn = ((1 - p_req_.round().int()) * (request_labels == 1)).reshape(-1).float() - req_acc += acc - req_tp += tp - req_fp += fp - req_fn += fn - - domains = [domain for domain in model.setsumbt.domain_ids if f'active_domain_labels-{domain}' in batch] - for domain in domains: - p_dom_ = p_dom[domain] - active_domain_labels = batch['active_domain_labels-' + domain].to(device) - - acc = (p_dom_.round().int() == active_domain_labels).reshape(-1).float() - tp = (p_dom_.round().int() * (active_domain_labels == 1)).reshape(-1).float() - fp = (p_dom_.round().int() * (active_domain_labels == 0)).reshape(-1).float() - fn = ((1 - p_dom_.round().int()) * (active_domain_labels == 1)).reshape(-1).float() - dom_acc += acc - 
dom_tp += tp - dom_fp += fp - dom_fn += fn - - general_act_labels = batch['general_act_labels'].to(device) - gen_tp = ((p_gen.argmax(-1) > 0) * (general_act_labels > 0)).reshape(-1).float().sum() - gen_fp = ((p_gen.argmax(-1) > 0) * (general_act_labels == 0)).reshape(-1).float().sum() - gen_fn = ((p_gen.argmax(-1) == 0) * (general_act_labels > 0)).reshape(-1).float().sum() - else: - req_tp, req_fp, req_fn = None, None, None - dom_tp, dom_fp, dom_fn = None, None, None - gen_tp, gen_fp, gen_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0) - - jg_acc = jg_acc[num_inform_slots > 0] - num_inform_slots = num_inform_slots[num_inform_slots > 0] - sl_acc = sum(jg_acc / num_inform_slots).float() - jg_acc = sum((jg_acc == num_inform_slots).int()).float() - if req_tp is not None and model.setsumbt.requestable_slot_ids: - req_tp = sum(req_tp / len(model.setsumbt.requestable_slot_ids)).float() - req_fp = sum(req_fp / len(model.setsumbt.requestable_slot_ids)).float() - req_fn = sum(req_fn / len(model.setsumbt.requestable_slot_ids)).float() - else: - req_tp, req_fp, req_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0) - dom_tp = sum(dom_tp / len(model.setsumbt.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0) - dom_fp = sum(dom_fp / len(model.setsumbt.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0) - dom_fn = sum(dom_fn / len(model.setsumbt.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0) - n_turns = num_inform_slots.size(0) - - accuracy_jg.append(jg_acc.item()) - accuracy_sl.append(sl_acc.item()) - truepos_req.append(req_tp.item()) - falsepos_req.append(req_fp.item()) - falseneg_req.append(req_fn.item()) - truepos_dom.append(dom_tp.item()) - falsepos_dom.append(dom_fp.item()) - falseneg_dom.append(dom_fn.item()) - truepos_gen.append(gen_tp.item()) - falsepos_gen.append(gen_fp.item()) - falseneg_gen.append(gen_fn.item()) - turns.append(n_turns) - tr_loss += loss.item() - - # Global accuracy reduction across batches - turns = sum(turns) - jg_acc = sum(accuracy_jg) / turns - sl_acc = sum(accuracy_sl) / turns - if model.config.predict_actions: - req_tp = sum(truepos_req) - req_fp = sum(falsepos_req) - req_fn = sum(falseneg_req) - req_f1 = req_tp + 0.5 * (req_fp + req_fn) - req_f1 = req_tp / req_f1 if req_f1 != 0.0 else 0.0 - dom_tp = sum(truepos_dom) - dom_fp = sum(falsepos_dom) - dom_fn = sum(falseneg_dom) - dom_f1 = dom_tp + 0.5 * (dom_fp + dom_fn) - dom_f1 = dom_tp / dom_f1 if dom_f1 != 0.0 else 0.0 - gen_tp = sum(truepos_gen) - gen_fp = sum(falsepos_gen) - gen_fn = sum(falseneg_gen) - gen_f1 = gen_tp + 0.5 * (gen_fp + gen_fn) - gen_f1 = gen_tp / gen_f1 if gen_f1 != 0.0 else 0.0 - else: - req_f1, dom_f1, gen_f1 = None, None, None - - if return_eval_output: - return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), evaluation_output - if is_train: - return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), stats - return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader) diff --git a/convlab/dst/setsumbt/predict_user_actions.py b/convlab/dst/setsumbt/predict_user_actions.py deleted file mode 100644 index 2c304a569cb5e29920332ed21c8f862dd00c1e48..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/predict_user_actions.py +++ /dev/null @@ -1,178 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may 
diff --git a/convlab/dst/setsumbt/predict_user_actions.py b/convlab/dst/setsumbt/predict_user_actions.py
deleted file mode 100644
index 2c304a569cb5e29920332ed21c8f862dd00c1e48..0000000000000000000000000000000000000000
--- a/convlab/dst/setsumbt/predict_user_actions.py
+++ /dev/null
@@ -1,178 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
-# Authors: Carel van Niekerk (niekerk@hhu.de)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Predict dataset user action using SetSUMBT Model"""
-
-from copy import deepcopy
-import os
-import json
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-
-from convlab.util.custom_util import flatten_acts as flatten
-from convlab.util import load_dataset, load_policy_data
-from convlab.dst.setsumbt import SetSUMBTTracker
-
-
-def flatten_acts(acts: dict) -> list:
-    """
-    Flatten dictionary actions.
-
-    Args:
-        acts: Dictionary acts
-
-    Returns:
-        flat_acts: Flattened actions
-    """
-    acts = flatten(acts)
-    flat_acts = []
-    for intent, domain, slot, value in acts:
-        flat_acts.append([intent,
-                          domain,
-                          slot if slot != 'none' else '',
-                          value.lower() if value != 'none' else ''])
-
-    return flat_acts
-
-
-def get_user_actions(context: list, system_acts: list) -> list:
-    """
-    Extract user actions from the data.
-
-    Args:
-        context: Previous dialogue turns.
-        system_acts: List of flattened system actions.
-
-    Returns:
-        user_acts: List of flattened user actions.
-    """
-    user_acts = context[-1]['dialogue_acts']
-    user_acts = flatten_acts(user_acts)
-    if len(context) == 3:
-        prev_state = context[-3]['state']
-        cur_state = context[-1]['state']
-        for domain, substate in cur_state.items():
-            for slot, value in substate.items():
-                if prev_state[domain][slot] != value:
-                    act = ['inform', domain, slot, value]
-                    if act not in user_acts and act not in system_acts:
-                        user_acts.append(act)
-
-    return user_acts
-
-
-def extract_dataset(dataset: str = 'multiwoz21') -> list:
-    """
-    Extract acts and utterances from the dataset.
-
-    Args:
-        dataset: Dataset name
-
-    Returns:
-        data: Extracted data
-    """
-    data = load_dataset(dataset_name=dataset)
-    raw_data = load_policy_data(data, data_split='test', context_window_size=3)['test']
-
-    dialogue = list()
-    data = list()
-    for turn in raw_data:
-        state = dict()
-        state['system_utterance'] = turn['context'][-2]['utterance'] if len(turn['context']) > 1 else ''
-        state['utterance'] = turn['context'][-1]['utterance']
-        state['system_actions'] = turn['context'][-2]['dialogue_acts'] if len(turn['context']) > 1 else {}
-        state['system_actions'] = flatten_acts(state['system_actions'])
-        state['user_actions'] = get_user_actions(turn['context'], state['system_actions'])
-        dialogue.append(state)
-        if turn['terminated']:
-            data.append(dialogue)
-            dialogue = list()
-
-    return data
-
-
-def unflatten_acts(acts: list) -> dict:
-    """
-    Convert acts from flat list format to dict format.
-
-    Args:
-        acts: List of flat actions.
-
-    Returns:
-        unflat_acts: Dictionary of acts.
-    """
-    binary_acts = []
-    cat_acts = []
-    for intent, domain, slot, value in acts:
-        include = True if (domain == 'general') or (slot != 'none') else False
-        if include and (value == '' or value == 'none' or intent == 'request'):
-            binary_acts.append({'intent': intent,
-                                'domain': domain,
-                                'slot': slot if slot != 'none' else ''})
-        elif include:
-            cat_acts.append({'intent': intent,
-                             'domain': domain,
-                             'slot': slot if slot != 'none' else '',
-                             'value': value})
-
-    unflat_acts = {'categorical': cat_acts, 'binary': binary_acts, 'non-categorical': list()}
-
-    return unflat_acts
-
-
-def predict_user_acts(data: list, tracker: SetSUMBTTracker) -> list:
-    """
-    Predict the user actions using the SetSUMBT Tracker.
-
-    Args:
-        data: List of dialogues.
-        tracker: SetSUMBT Tracker
-
-    Returns:
-        predict_result: List of turns containing predictions and true user actions.
-    """
-    tracker.init_session()
-    predict_result = []
-    for dial_idx, dialogue in enumerate(data):
-        for turn_idx, state in enumerate(dialogue):
-            sample = {'dial_idx': dial_idx, 'turn_idx': turn_idx}
-
-            tracker.state['history'].append(['sys', state['system_utterance']])
-            predicted_state = deepcopy(tracker.update(state['utterance']))
-            tracker.state['history'].append(['usr', state['utterance']])
-            tracker.state['system_action'] = state['system_actions']
-
-            sample['predictions'] = {'dialogue_acts': unflatten_acts(predicted_state['user_action'])}
-            sample['dialogue_acts'] = unflatten_acts(state['user_actions'])
-
-            predict_result.append(sample)
-
-        tracker.init_session()
-
-    return predict_result
-
-
-if __name__ =="__main__":
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
-    parser.add_argument('--model_path', type=str, help='Path to model dir')
-    args = parser.parse_args()
-
-    dataset = extract_dataset(args.dataset_name)
-    tracker = SetSUMBTTracker(args.model_path)
-    predict_results = predict_user_acts(dataset, tracker)
-
-    with open(os.path.join(args.model_path, 'predictions', 'test_nlu.json'), 'w') as writer:
-        json.dump(predict_results, writer, indent=2)
-        writer.close()
diff --git a/convlab/dst/setsumbt/run.py b/convlab/dst/setsumbt/run.py
index e45bf129f0c9f2c5c1fba01d4b5eb80e29a5a1f0..d017fd8e824884df11548d4b466fdcae8f97e925 100644
--- a/convlab/dst/setsumbt/run.py
+++ b/convlab/dst/setsumbt/run.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,12 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Run""" +"""Run SetSUMBT belief tracker training and evaluation.""" -from transformers import BertConfig, RobertaConfig +import logging +import os +from shutil import copy2 as copy +from copy import deepcopy -from convlab.dst.setsumbt.utils import get_args +import torch +import transformers +from transformers import BertConfig, RobertaConfig +from tensorboardX import SummaryWriter +from tqdm import tqdm +from convlab.dst.setsumbt.modeling import SetSUMBTModels, SetSUMBTTrainer +from convlab.dst.setsumbt.datasets import (get_dataloader, change_batch_size, dataloader_sample_dialogues, + get_distillation_dataloader) +from convlab.dst.setsumbt.utils import get_args, update_args, setup_ensemble MODELS = { 'bert': (BertConfig, "BertTokenizer"), @@ -27,15 +38,277 @@ MODELS = { def main(): - # Get arguments args, config = get_args(MODELS) - if args.run_nbt: - from convlab.dst.setsumbt.do.nbt import main - main(args, config) - if args.run_evaluation: - from convlab.dst.setsumbt.do.evaluate import main - main(args, config) + if args.model_type in SetSUMBTModels: + SetSumbtModel, OntologyEncoderModel, ConfigClass, Tokenizer = SetSUMBTModels[args.model_type] + if args.ensemble: + SetSumbtModel, _, _, _ = SetSUMBTModels['ensemble'] + else: + raise NameError('NotImplemented') + + # Set up output directory + OUTPUT_DIR = args.output_dir + + if not os.path.exists(OUTPUT_DIR): + os.makedirs(OUTPUT_DIR) + os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders')) + args.output_dir = OUTPUT_DIR + + # Set pretrained model path to the trained checkpoint + paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else [] + if 'pytorch_model.bin' in paths and 'config.json' in paths: + args.model_name_or_path = args.output_dir + config = ConfigClass.from_pretrained(args.model_name_or_path) + elif 'ens-0' in paths: + paths = [p for p in os.listdir(os.path.join(args.output_dir, 'ens-0')) if 'checkpoint-' in p] + if paths: + args.model_name_or_path = os.path.join(args.output_dir) + config = ConfigClass.from_pretrained(os.path.join(args.model_name_or_path, 'ens-0', paths[0])) + else: + paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p] + if paths: + paths = paths[0] + args.model_name_or_path = paths + config = ConfigClass.from_pretrained(args.model_name_or_path) + + args = update_args(args, config) + + # Create TensorboardX writer + tb_writer = SummaryWriter(logdir=args.tensorboard_path) + + # Create logger + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + + formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y') + + fh = logging.FileHandler(args.logging_path) + fh.setLevel(logging.INFO) + fh.setFormatter(formatter) + logger.addHandler(fh) + + # Get device + if torch.cuda.is_available() and args.n_gpu > 0: + device = torch.device('cuda') + else: + device = torch.device('cpu') + args.n_gpu = 0 + + if args.n_gpu == 0: + args.fp16 = False + + # Initialise Model + transformers.utils.logging.set_verbosity_info() + model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config) + model = model.to(device) + + if args.ensemble: + args.model_name_or_path = model._get_checkpoint_path(args.model_name_or_path, 0) + + # Create Tokenizer and embedding model for Data Loaders and ontology + tokenizer = Tokenizer.from_pretrained(args.model_name_or_path) + encoder = OntologyEncoderModel.from_pretrained(config.candidate_embedding_model_name, + args=args, tokenizer=tokenizer) + + transformers.utils.logging.set_verbosity_error() + if 
args.do_ensemble_setup: + # Build all dataloaders + train_dataloader = get_dataloader(args.dataset, + 'train', + args.train_batch_size, + tokenizer, + encoder, + args.max_dialogue_len, + args.max_turn_len, + train_ratio=args.dataset_train_ratio, + seed=args.seed) + torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) + dev_dataloader = get_dataloader(args.dataset, + 'validation', + args.dev_batch_size, + tokenizer, + encoder, + args.max_dialogue_len, + args.max_turn_len, + train_ratio=args.dataset_train_ratio, + seed=args.seed) + torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + test_dataloader = get_dataloader(args.dataset, + 'test', + args.test_batch_size, + tokenizer, + encoder, + args.max_dialogue_len, + args.max_turn_len, + train_ratio=args.dataset_train_ratio, + seed=args.seed) + torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + + setup_ensemble(OUTPUT_DIR, args.ensemble_size) + + logger.info(f'Building {args.ensemble_size} resampled dataloaders each of size {args.data_sampling_size}.') + dataloaders = [dataloader_sample_dialogues(deepcopy(train_dataloader), args.data_sampling_size) + for _ in tqdm(range(args.ensemble_size))] + logger.info('Dataloaders built.') + + for i, loader in enumerate(dataloaders): + path = os.path.join(OUTPUT_DIR, 'ens-%i' % i) + if not os.path.exists(path): + os.mkdir(path) + path = os.path.join(path, 'dataloaders', 'train.dataloader') + torch.save(loader, path) + logger.info('Dataloaders saved.') + + # Do not perform standard training after ensemble setup is created + return 0 + + # Perform tasks + # TRAINING + if args.do_train: + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')): + train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) + if train_dataloader.batch_size != args.train_batch_size: + train_dataloader = change_batch_size(train_dataloader, args.train_batch_size) + else: + if args.data_sampling_size <= 0: + args.data_sampling_size = None + if 'distillation' not in config.loss_function: + train_dataloader = get_dataloader(args.dataset, + 'train', + args.train_batch_size, + tokenizer, + encoder, + args.max_dialogue_len, + config.max_turn_len, + resampled_size=args.data_sampling_size, + train_ratio=args.dataset_train_ratio, + seed=args.seed) + else: + loader_args = {"ensemble_path": args.ensemble_model_path, + "set_type": "train", + "batch_size": args.train_batch_size, + "reduction": "mean" if config.loss_function == 'distillation' else "none"} + train_dataloader = get_distillation_dataloader(**loader_args) + torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) + + # Get development set batch loaders= and ontology embeddings + if args.do_eval: + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')): + dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + if dev_dataloader.batch_size != args.dev_batch_size: + dev_dataloader = change_batch_size(dev_dataloader, args.dev_batch_size) + else: + if 'distillation' not in config.loss_function: + dev_dataloader = get_dataloader(args.dataset, + 'validation', + args.dev_batch_size, + tokenizer, + encoder, + args.max_dialogue_len, + config.max_turn_len) + else: + loader_args = {"ensemble_path": args.ensemble_model_path, + "set_type": "dev", + "batch_size": args.dev_batch_size, + "reduction": "mean" if config.loss_function == 'distillation' 
else "none"} + dev_dataloader = get_distillation_dataloader(**loader_args) + torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + else: + dev_dataloader = None + + # TRAINING !!!!!!!!!!!!!!!!!! + trainer = SetSUMBTTrainer(args, model, tokenizer, train_dataloader, dev_dataloader, logger, tb_writer, + device) + trainer.train() + + # Copy final best model to the output dir + checkpoints = os.listdir(OUTPUT_DIR) + checkpoints = [p for p in checkpoints if 'checkpoint' in p] + checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints]) + best_checkpoint = os.path.join(OUTPUT_DIR, f'checkpoint-{checkpoints[-1]}') + files = ['pytorch_model.bin', 'config.json', 'merges.txt', 'special_tokens_map.json', + 'tokenizer_config.json', 'vocab.json'] + for file in files: + copy(os.path.join(best_checkpoint, file), os.path.join(OUTPUT_DIR, file)) + + # Load best model for evaluation + tokenizer = Tokenizer.from_pretrained(OUTPUT_DIR) + model = SetSumbtModel.from_pretrained(OUTPUT_DIR) + model = model.to(device) + + # Evaluation on the training set + if args.do_eval_trainset: + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')): + train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) + if train_dataloader.batch_size != args.train_batch_size: + train_dataloader = change_batch_size(train_dataloader, args.train_batch_size) + else: + train_dataloader = get_dataloader(args.dataset, 'train', args.train_batch_size, tokenizer, + encoder, args.max_dialogue_len, config.max_turn_len) + torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) + + # EVALUATION + trainer = SetSUMBTTrainer(args, model, tokenizer, None, train_dataloader, logger, tb_writer, device) + trainer.eval_mode(load_slots=True) + + if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')): + os.mkdir(os.path.join(OUTPUT_DIR, 'predictions')) + save_pred_dist_path = os.path.join(OUTPUT_DIR, 'predictions', 'train.data') if args.ensemble else None + metrics = trainer.evaluate(save_pred_dist_path=save_pred_dist_path) + trainer.log_info(metrics, logging_stage='dev') + + # Evaluation on the development set + if args.do_eval: + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')): + dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + if dev_dataloader.batch_size != args.dev_batch_size: + dev_dataloader = change_batch_size(dev_dataloader, args.dev_batch_size) + else: + dev_dataloader = get_dataloader(args.dataset, 'validation', args.dev_batch_size, tokenizer, + encoder, args.max_dialogue_len, config.max_turn_len) + torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + + # EVALUATION + trainer = SetSUMBTTrainer(args, model, tokenizer, None, dev_dataloader, logger, tb_writer, device) + trainer.eval_mode(load_slots=True) + + if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')): + os.mkdir(os.path.join(OUTPUT_DIR, 'predictions')) + save_pred_dist_path = os.path.join(OUTPUT_DIR, 'predictions', 'dev.data') if args.ensemble else None + metrics = trainer.evaluate(save_eval_path=os.path.join(OUTPUT_DIR, 'predictions', 'dev.json'), + save_pred_dist_path=save_pred_dist_path) + trainer.log_info(metrics, logging_stage='dev') + + # Evaluation on the test set + if args.do_test: + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): + test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 
'test.dataloader')) + if test_dataloader.batch_size != args.test_batch_size: + test_dataloader = change_batch_size(test_dataloader, args.test_batch_size) + else: + test_dataloader = get_dataloader(args.dataset, 'test', args.test_batch_size, tokenizer, + encoder, args.max_dialogue_len, config.max_turn_len) + torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + + trainer = SetSUMBTTrainer(args, model, tokenizer, None, test_dataloader, logger, tb_writer, device) + trainer.eval_mode(load_slots=True) + + # TESTING + if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')): + os.mkdir(os.path.join(OUTPUT_DIR, 'predictions')) + + save_pred_dist_path = os.path.join(OUTPUT_DIR, 'predictions', 'test.data') if args.ensemble else None + metrics = trainer.evaluate(save_eval_path=os.path.join(OUTPUT_DIR, 'predictions', 'test.json'), + save_pred_dist_path=save_pred_dist_path, draw_calibration_diagram=True) + trainer.log_info(metrics, logging_stage='test') + + # Save final model for inference + if not args.ensemble: + trainer.model.save_pretrained(OUTPUT_DIR) + trainer.tokenizer.save_pretrained(OUTPUT_DIR) + + tb_writer.close() if __name__ == "__main__": diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py index f56bbadc2f4d8fdca102b2bbc996acb0ae5a4a58..5126fd3439c4dd77a2f27720bdd68f7c0f7e947a 100644 --- a/convlab/dst/setsumbt/tracker.py +++ b/convlab/dst/setsumbt/tracker.py @@ -1,16 +1,28 @@ -import os -import json +# -*- coding: utf-8 -*- +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
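The rewritten run.py above folds the old run_nbt/run_evaluation dispatch into a single entry point controlled by the action flags defined in utils/configuration.py further below. A hedged sketch of driving it programmatically; the config name is a placeholder and assumes a matching file in the package's configs directory:

    import sys

    from convlab.dst.setsumbt.run import main

    # Equivalent to: python convlab/dst/setsumbt/run.py --starting_config_name <name> --do_train --do_eval --do_test
    sys.argv = ['run.py', '--starting_config_name', 'setsumbt_multiwoz21',
                '--do_train', '--do_eval', '--do_test']
    main()

Ensemble training is staged the same way, with --do_ensemble_setup building the shared and resampled dataloaders and the ens-i member directories before the members are trained individually.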
+"""Run SetSUMBT belief tracker training and evaluation.""" + import copy import logging import torch import transformers -from transformers import BertModel, BertConfig, BertTokenizer, RobertaModel, RobertaConfig, RobertaTokenizer -from convlab.dst.setsumbt.modeling import RobertaSetSUMBT, BertSetSUMBT -from convlab.dst.setsumbt.modeling.training import set_ontology_embeddings +from convlab.dst.setsumbt.modeling import SetSUMBTModels from convlab.dst.dst import DST -from convlab.util.custom_util import model_downloader USE_CUDA = torch.cuda.is_available() transformers.logging.set_verbosity_error() @@ -20,7 +32,7 @@ class SetSUMBTTracker(DST): """SetSUMBT Tracker object for Convlab dialogue system""" def __init__(self, - model_path: str = "", + model_name_or_path: str = "", model_type: str = "roberta", return_turn_pooled_representation: bool = False, return_confidence_scores: bool = False, @@ -30,7 +42,7 @@ class SetSUMBTTracker(DST): store_full_belief_state: bool = True): """ Args: - model_path: Model path or download URL + model_name_or_path: Path to pretrained model or name of pretrained model model_type: Transformer type (roberta/bert) return_turn_pooled_representation: If true a turn level pooled representation is returned return_confidence_scores: If true act confidence scores are included in the state @@ -42,7 +54,7 @@ class SetSUMBTTracker(DST): super(SetSUMBTTracker, self).__init__() self.model_type = model_type - self.model_path = model_path + self.model_name_or_path = model_name_or_path self.return_turn_pooled_representation = return_turn_pooled_representation self.return_confidence_scores = return_confidence_scores self.confidence_threshold = confidence_threshold @@ -53,41 +65,24 @@ class SetSUMBTTracker(DST): self.full_belief_state = {} self.info_dict = {} - # Download model if needed - if not os.path.exists(self.model_path): - # Get path /.../convlab/dst/setsumbt/multiwoz/models - download_path = os.path.dirname(os.path.abspath(__file__)) - download_path = os.path.join(download_path, 'models') - if not os.path.exists(download_path): - os.mkdir(download_path) - model_downloader(download_path, self.model_path) - # Downloadable model path format http://.../model_name.zip - self.model_path = self.model_path.split('/')[-1].replace('.zip', '') - self.model_path = os.path.join(download_path, self.model_path) + if self.model_type in SetSUMBTModels: + self.model, _, self.config, self.tokenizer = SetSUMBTModels[self.model_type] + else: + raise NameError('NotImplemented') # Select model type based on the encoder - if model_type == "roberta": - self.config = RobertaConfig.from_pretrained(self.model_path) - self.tokenizer = RobertaTokenizer - self.model = RobertaSetSUMBT - elif model_type == "bert": - self.config = BertConfig.from_pretrained(self.model_path) - self.tokenizer = BertTokenizer - self.model = BertSetSUMBT - else: - logging.debug("Name Error: Not Implemented") + self.config = self.config.from_pretrained(self.model_name_or_path) self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu') - self.load_weights() def load_weights(self): """Load model weights and model ontology""" logging.info('Loading SetSUMBT pretrained model.') - self.tokenizer = self.tokenizer.from_pretrained(self.config.tokenizer_name) - logging.info(f'Model tokenizer loaded from {self.config.tokenizer_name}.') - self.model = self.model.from_pretrained(self.model_path, config=self.config) - logging.info(f'Model loaded from {self.model_path}.') + self.tokenizer = 
self.tokenizer.from_pretrained(self.model_name_or_path) + logging.info(f'Model tokenizer loaded from {self.model_name_or_path}.') + self.model = self.model.from_pretrained(self.model_name_or_path, config=self.config) + logging.info(f'Model loaded from {self.model_name_or_path}.') # Transfer model to compute device and setup eval environment self.model = self.model.to(self.device) @@ -95,12 +90,7 @@ class SetSUMBTTracker(DST): logging.info(f'Model transferred to device: {self.device}') logging.info('Loading model ontology') - f = open(os.path.join(self.model_path, 'database', 'test.json'), 'r') - self.ontology = json.load(f) - f.close() - - db = torch.load(os.path.join(self.model_path, 'database', 'test.db')) - set_ontology_embeddings(self.model, db) + self.ontology = self.tokenizer.ontology if self.return_confidence_scores: logging.info('Model returns user action and belief state confidence scores.') @@ -138,6 +128,7 @@ class SetSUMBTTracker(DST): return self.confidence_thresholds def init_session(self): + """Initialize dialogue state""" self.state = dict() self.state['belief_state'] = dict() self.state['booked'] = dict() @@ -157,21 +148,21 @@ class SetSUMBTTracker(DST): def update(self, user_act: str = '') -> dict: """ - Update user actions and dialogue and belief states. + Update dialogue state based on user utterance. Args: - user_act: + user_act: User utterance Returns: - + state: Dialogue state """ prev_state = self.state - _output = self.predict(self.get_features(user_act)) + outputs = self.predict(self.get_features(user_act)) # Format state entropy - if _output[5] is not None: + if outputs.state_entropy is not None: state_entropy = dict() - for slot, e in _output[5].items(): + for slot, e in outputs.state_entropy.items(): domain, slot = slot.split('-', 1) if domain not in state_entropy: state_entropy[domain] = dict() @@ -180,9 +171,9 @@ class SetSUMBTTracker(DST): state_entropy = None # Format state mutual information - if _output[6] is not None: + if outputs.belief_state_mutual_information is not None: state_mutual_info = dict() - for slot, mi in _output[6].items(): + for slot, mi in outputs.belief_state_mutual_information.items(): domain, slot = slot.split('-', 1) if domain not in state_mutual_info: state_mutual_info[domain] = dict() @@ -192,9 +183,9 @@ class SetSUMBTTracker(DST): # Format all confidence scores belief_state_confidence = None - if _output[4] is not None: + if outputs.confidence_scores is not None: belief_state_confidence = dict() - belief_state_conf, request_probs, active_domain_probs, general_act_probs = _output[4] + belief_state_conf, request_probs, active_domain_probs, general_act_probs = outputs.confidence_scores for slot, p in belief_state_conf.items(): domain, slot = slot.split('-', 1) if domain not in belief_state_confidence: @@ -221,16 +212,16 @@ class SetSUMBTTracker(DST): belief_state_confidence['general']['none'] = general_act_probs # Get new domain activation actions - new_domains = [d for d, active in _output[1].items() if active] + new_domains = [d for d, active in outputs.state['active_domains'].items() if active] new_domains = [d for d in new_domains if not self.active_domains.get(d, False)] - self.active_domains = _output[1] + self.active_domains = outputs.state['active_domains'] - user_acts = _output[2] + user_acts = outputs.state['user_action'] for domain in new_domains: user_acts.append(['inform', domain, 'none', 'none']) new_belief_state = copy.deepcopy(prev_state['belief_state']) - for domain, substate in _output[0].items(): + for domain, 
+        for domain, substate in outputs.state['belief_state'].items():
             for slot, value in substate.items():
                 value = '' if value == 'none' else value
                 value = 'dontcare' if value == 'do not care' else value
@@ -268,17 +259,17 @@
         user_acts = [act for act in user_acts if act not in new_state['system_action']]
         new_state['user_action'] = user_acts
 
-        if _output[3] is not None:
-            new_state['turn_pooled_representation'] = _output[3]
+        if outputs.turn_pooled_representation is not None:
+            new_state['turn_pooled_representation'] = outputs.turn_pooled_representation.reshape(-1)
 
         self.state = new_state
-        self.info_dict = copy.deepcopy(dict(new_state))
+        self.info_dict['belief_state'] = copy.deepcopy(dict(new_state))
 
         return self.state
 
     def predict(self, features: dict) -> tuple:
         """
-        Model forward pass and prediction post processing.
+        Model forward pass and prediction post-processing.
 
         Args:
             features: Dictionary of model input features
@@ -288,96 +279,51 @@
         """
         state_mutual_info = None
         with torch.no_grad():
-            turn_pooled_representation = None
-            if self.return_turn_pooled_representation:
-                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
-                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
-                                      get_turn_pooled_representation=True)
-                belief_state = _outputs[0]
-                request_probs = _outputs[1]
-                active_domain_probs = _outputs[2]
-                general_act_probs = _outputs[3]
-                self.hidden_states = _outputs[4]
-                turn_pooled_representation = _outputs[5]
-            elif self.return_belief_state_mutual_info:
-                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
-                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
-                                      get_turn_pooled_representation=True, calculate_state_mutual_info=True)
-                belief_state = _outputs[0]
-                request_probs = _outputs[1]
-                active_domain_probs = _outputs[2]
-                general_act_probs = _outputs[3]
-                self.hidden_states = _outputs[4]
-                state_mutual_info = _outputs[5]
-            else:
-                _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
-                                      attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
-                                      get_turn_pooled_representation=False)
-                belief_state, request_probs, active_domain_probs, general_act_probs, self.hidden_states = _outputs
+            features['hidden_state'] = self.hidden_states
+            features['get_turn_pooled_representation'] = self.return_turn_pooled_representation
+            features['calculate_state_mutual_info'] = self.return_belief_state_mutual_info
+            outputs = self.model(**features)
+            self.hidden_states = outputs.hidden_state
 
         # Convert belief state into dialog state
-        dialogue_state = dict()
-        for slot, probs in belief_state.items():
-            dom, slot = slot.split('-', 1)
-            if dom not in dialogue_state:
-                dialogue_state[dom] = dict()
-            val = self.ontology[dom][slot]['possible_values'][probs[0, 0, :].argmax().item()]
-            if val != 'none':
-                dialogue_state[dom][slot] = val
+        state = self.tokenizer.decode_state_batch(outputs.belief_state, outputs.request_probabilities,
+                                                  outputs.active_domain_probabilities,
+                                                  outputs.general_act_probabilities)
+        state = state['000000'][0]
 
         if self.store_full_belief_state:
-            self.info_dict['belief_state_distributions'] = belief_state
+            self.info_dict['belief_state_distributions'] = outputs.belief_state
             if state_mutual_info is not None:
-                self.info_dict['belief_state_knowledge_uncertainty'] = state_mutual_info
+                self.info_dict['belief_state_knowledge_uncertainty'] = outputs.belief_state_mutual_information
 
         # Obtain model output probabilities
         if self.return_confidence_scores:
             state_entropy = None
             if self.return_belief_state_entropy:
-                state_entropy = {slot: probs[0, 0, :] for slot, probs in belief_state.items()}
+                state_entropy = {slot: probs[0, 0, :] for slot, probs in outputs.belief_state.items()}
                 state_entropy = {slot: self.relative_entropy(p).item() for slot, p in state_entropy.items()}
 
             # Confidence score is the max probability across all not "none" values candidates.
-            belief_state_conf = {slot: probs[0, 0, 1:].max().item() for slot, probs in belief_state.items()}
-            _request_probs = {slot: p[0, 0].item() for slot, p in request_probs.items()}
-            _active_domain_probs = {domain: p[0, 0].item() for domain, p in active_domain_probs.items()}
-            _general_act_probs = {'bye': general_act_probs[0, 0, 1].item(), 'thank': general_act_probs[0, 0, 2].item()}
+            belief_state_conf = {slot: probs[0, 0, 1:].max().item() for slot, probs in outputs.belief_state.items()}
+            _request_probs = {slot: p[0, 0].item() for slot, p in outputs.request_probabilities.items()}
+            _active_domain_probs = {domain: p[0, 0].item() for domain, p in outputs.active_domain_probabilities.items()}
+            _general_act_probs = {'bye': outputs.general_act_probabilities[0, 0, 1].item(),
+                                  'thank': outputs.general_act_probabilities[0, 0, 2].item()}
             confidence_scores = (belief_state_conf, _request_probs, _active_domain_probs, _general_act_probs)
         else:
             confidence_scores = None
             state_entropy = None
 
-        # Construct request action prediction
-        if request_probs is not None:
-            request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5]
-            request_acts = [slot.split('-', 1) for slot in request_acts]
-            request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts]
-        else:
-            request_acts = list()
-
-        # Construct active domain set
-        if active_domain_probs is not None:
-            active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()}
-        else:
-            active_domains = dict()
-
-        # Construct general domain action
-        if general_act_probs is not None:
-            general_acts = general_act_probs[0, 0, :].argmax(-1).item()
-            general_acts = [[], ['bye'], ['thank']][general_acts]
-            general_acts = [[act, 'general', 'none', 'none'] for act in general_acts]
-        else:
-            general_acts = list()
+        outputs.confidence_scores = confidence_scores
+        outputs.state_entropy = state_entropy
+        outputs.state = state
+        outputs.belief_state = None
+        return outputs
 
-        user_acts = request_acts + general_acts
-
-        out = (dialogue_state, active_domains, user_acts, turn_pooled_representation, confidence_scores)
-        out += (state_entropy, state_mutual_info)
-        return out
-
-    def relative_entropy(self, probs: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def relative_entropy(probs: torch.Tensor) -> torch.Tensor:
         """
-        Compute relative entrop for a probability distribution
+        Compute relative entropy for a probability distribution
 
         Args:
             probs: Probability distributions
@@ -412,18 +358,17 @@
         else:
             system_act = ''
 
+        dialogue = [[{
+            'user_utterance': user_act,
+            'system_utterance': system_act
+        }]]
+
         # Tokenize dialog
-        features = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True,
-                                              max_length=self.config.max_turn_len, padding='max_length',
-                                              truncation='longest_first')
-
-        input_ids = torch.tensor(features['input_ids']).reshape(
-            1, 1, -1).to(self.device) if 'input_ids' in features else None
-        token_type_ids = torch.tensor(features['token_type_ids']).reshape(
-            1, 1, -1).to(self.device) if 'token_type_ids' in features else None
-        attention_mask = torch.tensor(features['attention_mask']).reshape(
-            1, 1, -1).to(self.device) if 'attention_mask' in features else None
-        features = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask}
+        features = self.tokenizer.encode(dialogue, max_seq_len=self.config.max_turn_len, max_turns=1)
+
+        for key in features:
+            if features[key] is not None:
+                features[key] = features[key].to(self.device)
 
         return features
 
@@ -431,7 +376,7 @@
 # if __name__ == "__main__":
 #     from convlab.policy.vector.vector_uncertainty import VectorUncertainty
 #     # from convlab.policy.vector.vector_binary import VectorBinary
-#     tracker = SetSUMBTTracker(model_path='/gpfs/project/niekerk/src/SetSUMBT/models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-labelsmoothing-Seed0-10-08-22-12-42',
+#     tracker = SetSUMBTTracker(model_name_or_path='setsumbt_multiwoz21',
 #                               return_confidence_scores=True, confidence_threshold='auto',
 #                               return_belief_state_entropy=True)
 #     vector = VectorUncertainty(use_state_total_uncertainty=True, confidence_thresholds=tracker.confidence_thresholds,
diff --git a/convlab/dst/setsumbt/utils/__init__.py b/convlab/dst/setsumbt/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a96b9f8878cc81ea1e30e77deaf37366ee6348bc
--- /dev/null
+++ b/convlab/dst/setsumbt/utils/__init__.py
@@ -0,0 +1,2 @@
+from convlab.dst.setsumbt.utils.configuration import get_args, update_args, clear_checkpoints
+from convlab.dst.setsumbt.utils.ensemble import setup_ensemble, EnsembleAggregator
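The new utils package re-exports the configuration helpers alongside the ensemble utilities, so both are importable from convlab.dst.setsumbt.utils. A rough sketch of the aggregation flow implemented in ensemble.py further below, using two synthetic batches; the dictionary keys match the aggregator's input_items/output_items, but the tensor shapes and the output path are illustrative:

    import torch

    from convlab.dst.setsumbt.utils import EnsembleAggregator

    aggregator = EnsembleAggregator()
    for _ in range(2):
        # Synthetic stand-ins for a dataloader batch and an ensemble member's outputs
        model_input = {'input_ids': torch.zeros(4, 12, 64, dtype=torch.long)}
        model_output = {'belief_state': {'hotel-area': torch.rand(4, 12, 7)}}
        aggregator.add_batch(model_input, model_output)
    aggregator.save('train.data')  # concatenates each item along the batch dimension, then torch.save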
diff --git a/convlab/dst/setsumbt/utils.py b/convlab/dst/setsumbt/utils/configuration.py
similarity index 88%
rename from convlab/dst/setsumbt/utils.py
rename to convlab/dst/setsumbt/utils/configuration.py
index ff374116a3f8e88e6219fdc8b134d40b0bee7caf..bd318a1bcbd76200d120e942f65a33294caab569 100644
--- a/convlab/dst/setsumbt/utils.py
+++ b/convlab/dst/setsumbt/utils/configuration.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""SetSUMBT utils"""
+"""SetSUMBT configuration utilities."""
 
 import os
 import json
@@ -25,6 +25,12 @@ from git import Repo
 
 
 def get_args(base_models: dict):
+    """
+    Get arguments from command line and config file.
+
+    Args:
+        base_models: Dictionary of base models to use for ensemble training
+    """
     # Get arguments
     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
 
@@ -46,7 +52,6 @@ def get_args(base_models: dict):
     parser.add_argument('--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int)
     parser.add_argument('--max_candidate_len', help='Maximum number of tokens per value candidate', default=12,
                         type=int)
-    parser.add_argument('--force_processing', action='store_true', help='Force preprocessing of data.')
     parser.add_argument('--data_sampling_size', help='Resampled dataset size', default=-1, type=int)
     parser.add_argument('--no_descriptions', help='Do not use slot descriptions rather than slot names for embeddings',
                         action='store_true')
@@ -56,6 +61,7 @@ def get_args(base_models: dict):
     parser.add_argument('--output_dir', help='Output storage directory', default=None)
     parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta', default='roberta')
     parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.', default=None)
+    parser.add_argument('--ensemble_model_path', help='Path to ensemble model', default=None)
     parser.add_argument('--candidate_embedding_model_name', default=None,
                         help='Name of the pretrained candidate embedding model.')
     parser.add_argument('--transformers_local_files_only', help='Use local files only for huggingface transformers',
@@ -140,13 +146,11 @@ def get_args(base_models: dict):
                         "See details at https://nvidia.github.io/apex/amp.html")
 
     # ACTIONS
-    parser.add_argument('--run_nbt', help='Run NBT script', action='store_true')
-    parser.add_argument('--run_evaluation', help='Run evaluation script', action='store_true')
-
-    # RUN_NBT ACTIONS
     parser.add_argument('--do_train', help='Perform training', action='store_true')
     parser.add_argument('--do_eval', help='Perform model evaluation during training', action='store_true')
+    parser.add_argument('--do_eval_trainset', help='Evaluate model on training data', action='store_true')
     parser.add_argument('--do_test', help='Evaluate model on test data', action='store_true')
+    parser.add_argument('--do_ensemble_setup', help='Setup the dataloaders for ensemble training', action='store_true')
 
     args = parser.parse_args()
 
     if args.starting_config_name:
@@ -162,10 +166,13 @@ def get_args(base_models: dict):
 
     # Setup default directories
     if not args.output_dir:
-        args.output_dir = os.path.dirname(os.path.abspath(__file__))
+        args.output_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
         args.output_dir = os.path.join(args.output_dir, 'models')
 
-        name = 'SetSUMBT' if args.set_similarity else 'SUMBT'
+        name = 'Ensemble-' if args.do_ensemble_setup else ''
+        name += 'EnD-' if args.loss_function == 'distillation' else ''
+        name += 'EnD2-' if args.loss_function == 'distribution_distillation' else ''
+        name += 'SetSUMBT' if args.set_similarity else 'SUMBT'
         name += '+ActPrediction' if args.predict_actions else ''
         name += '-' + args.dataset
        name += '-' + str(round(args.dataset_train_ratio*100)) + '%' if args.dataset_train_ratio != 1.0 else ''
@@ -230,13 +237,25 @@ def get_args(base_models: dict):
 
 
 def get_starting_config(args):
-    path = os.path.dirname(os.path.realpath(__file__))
+    """
+    Load a config file and update the args with the values from the config file.
+
+    Args:
+        args: The args object to update.
+    """
+    path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
     path = os.path.join(path, 'configs', f"{args.starting_config_name}.json")
     reader = open(path, 'r')
     config = json.load(reader)
     reader.close()
 
     if "model_type" in config:
+        if 'ensemble-' in config["model_type"].lower():
+            args.ensemble = True
+            config["model_type"] = config["model_type"].lower().replace('ensemble-', '')
+        else:
+            args.ensemble = False
+
         if config["model_type"].lower() == 'setsumbt':
             config["model_type"] = 'roberta'
             config["no_set_similarity"] = False
@@ -255,6 +274,7 @@ def get_starting_config(args):
 
 
 def get_git_info():
+    """Get the git info of the current branch and commit hash"""
     repo = Repo(os.path.dirname(os.path.realpath(__file__)), search_parent_directories=True)
     branch_name = repo.active_branch.name
     commit_hex = repo.head.object.hexsha
@@ -264,17 +284,21 @@ def get_git_info():
 
 
 def build_config(config_class, args):
+    """
+    Build a config object from the args.
+
+    Args:
+        config_class: The config class to use.
+        args: The args object to use.
+
+    Returns:
+        The config object.
+    """
     config = config_class.from_pretrained(args.model_name_or_path)
     config.code_version = get_git_info()
-    if not os.path.exists(args.model_name_or_path):
-        config.tokenizer_name = args.model_name_or_path
-    try:
-        config.tokenizer_name = config.tokenizer_name
-    except AttributeError:
-        config.tokenizer_name = args.model_name_or_path
     try:
         config.candidate_embedding_model_name = config.candidate_embedding_model_name
-    except:
+    except AttributeError:
         if args.candidate_embedding_model_name:
             config.candidate_embedding_model_name = args.candidate_embedding_model_name
     config.max_dialogue_len = args.max_dialogue_len
@@ -302,8 +326,6 @@ def build_config(config_class, args):
     if config.loss_function == 'bayesianmatching':
         config.kl_scaling_factor = args.kl_scaling_factor
         config.prior_constant = args.prior_constant
-    if config.loss_function == 'inhibitedce':
-        config.inhibiting_factor = args.inhibiting_factor
     if config.loss_function == 'labelsmoothing':
         config.label_smoothing = args.label_smoothing
     if config.loss_function == 'distillation':
@@ -320,6 +342,16 @@ def build_config(config_class, args):
 
 
 def update_args(args, config):
+    """
+    Update the args with the values from the config file.
+
+    Args:
+        args: The args object to update.
+        config: The config object to use.
+
+    Returns:
+        The updated args object.
+    """
     try:
         args.candidate_embedding_model_name = config.candidate_embedding_model_name
     except AttributeError:
@@ -342,6 +374,13 @@ def update_args(args, config):
 
 
 def clear_checkpoints(path, topn=1):
+    """
+    Clear all checkpoints except the top n.
+
+    Args:
+        path: The path to the checkpoints.
+        topn: The number of checkpoints to keep.
+    """
     checkpoints = os.listdir(path)
     checkpoints = [p for p in checkpoints if 'checkpoint' in p]
     checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints])
diff --git a/convlab/dst/setsumbt/utils/ensemble.py b/convlab/dst/setsumbt/utils/ensemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc59cbf4751c5e46b8efa4e86165e9eb4416d65
--- /dev/null
+++ b/convlab/dst/setsumbt/utils/ensemble.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Ensemble setup and inference utils."""
+
+import os
+from shutil import copy2 as copy
+
+import torch
+import numpy as np
+
+def setup_ensemble(model_path: str, ensemble_size: int):
+    """
+    Setup ensemble model directory structure.
+
+    Args:
+        model_path: Path to ensemble model directory
+        ensemble_size: Number of ensemble members
+    """
+    for i in range(ensemble_size):
+        path = os.path.join(model_path, f'ens-{i}')
+        if not os.path.exists(path):
+            os.mkdir(path)
+            os.mkdir(os.path.join(path, 'dataloaders'))
+            # Add development set dataloader to each ensemble member directory
+            for set_type in ['dev']:
+                copy(os.path.join(model_path, 'dataloaders', f'{set_type}.dataloader'),
+                     os.path.join(path, 'dataloaders', f'{set_type}.dataloader'))
+
+
+class EnsembleAggregator:
+    """Aggregator for ensemble model outputs."""
+
+    def __init__(self):
+        self.init_session()
+        self.input_items = ['input_ids', 'attention_mask', 'token_type_ids']
+        self.output_items = ['belief_state', 'request_probabilities', 'active_domain_probabilities',
+                             'general_act_probabilities']
+
+    def init_session(self):
+        """Initialize aggregator for new session."""
+        self.features = dict()
+
+    def add_batch(self, model_input: dict, model_output: dict, dialogue_ids=None):
+        """
+        Add batch of model outputs to aggregator.
+
+        Args:
+            model_input: Model input dictionary
+            model_output: Model output dictionary
+            dialogue_ids: List of dialogue ids
+        """
+        for key in self.input_items:
+            if key in model_input:
+                if key not in self.features:
+                    self.features[key] = list()
+                self.features[key].append(model_input[key])
+
+        for key in self.output_items:
+            if key in model_output:
+                if key not in self.features:
+                    self.features[key] = list()
+                self.features[key].append(model_output[key])
+
+        if dialogue_ids is not None:
+            if 'dialogue_ids' not in self.features:
+                self.features['dialogue_ids'] = [np.array([list(itm) for itm in dialogue_ids]).T]
+            else:
+                self.features['dialogue_ids'].append(np.array([list(itm) for itm in dialogue_ids]).T)
+
+    def _aggregate(self):
+        """Aggregate model outputs."""
+        for key in self.features:
+            self.features[key] = self._aggregate_item(self.features[key])
+
+    @staticmethod
+    def _aggregate_item(item):
+        """
+        Aggregate single model output item.
+
+        Args:
+            item: Model output item
+
+        Returns:
+            Aggregated model output item
+        """
+        if item[0] is None:
+            return None
+        elif type(item[0]) == dict:
+            return {k: EnsembleAggregator._aggregate_item([i[k] for i in item]) for k in item[0]}
+        elif type(item[0]) == np.ndarray:
+            return np.concatenate(item, 0)
+        else:
+            return torch.cat(item, 0)
+
+    def save(self, path):
+        """
+        Save aggregated model outputs to file.
+
+        Args:
+            path: Path to save file
+        """
+        self._aggregate()
+        torch.save(self.features, path)
diff --git a/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json
index a80c04c9656dbdb361de1ca74e3ca24db028b1cf..03220dd3d3331b9d8324e5800b580c044b225bb2 100644
--- a/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json
+++ b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json
@@ -26,9 +26,9 @@
 		"nlu_sys": {},
 		"dst_sys": {
 			"setsumbt-mul": {
-				"class_path": "convlab.dst.setsumbt.SetSUMBTTracker",
+				"class_path": "convlab.dst.setsumbt.tracker.SetSUMBTTracker",
 				"ini_params": {
-					"model_path": "https://huggingface.co/ConvLab/setsumbt-dst_nlu-multiwoz21-EnD2/resolve/main/SetSUMBT-nlu-multiwoz21-roberta-gru-cosine-distribution_distillation-Seed0.zip",
+					"model_name_or_path": "ConvLab/setsumbt-dst_nlu-multiwoz21-EnD2",
 					"return_confidence_scores": true,
 					"return_belief_state_mutual_info": true
 				}
			}
diff --git a/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json
index bf9211006b6e2623016acfec18573768f73558fd..36f2d46d0c2ae32ca65c04e303c3967a9a56e53a 100644
--- a/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json
+++ b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json
@@ -24,9 +24,9 @@
 		"nlu_sys": {},
 		"dst_sys": {
 			"setsumbt-mul": {
-				"class_path": "convlab.dst.setsumbt.SetSUMBTTracker",
+				"class_path": "convlab.dst.setsumbt.tracker.SetSUMBTTracker",
 				"ini_params": {
-					"model_path": "https://huggingface.co/ConvLab/setsumbt-dst_nlu-multiwoz21-EnD2/resolve/main/SetSUMBT-nlu-multiwoz21-roberta-gru-cosine-distribution_distillation-Seed0.zip"
+					"model_name_or_path": "ConvLab/setsumbt-dst_nlu-multiwoz21-EnD2"
 				}
 			}
 		},
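Taken together with the PPO config changes above, the tracker is now initialised from a Hugging Face model id or a local directory rather than a download URL. A hedged usage sketch; the model id is taken from the updated configs, and the utterance is illustrative:

    from convlab.dst.setsumbt.tracker import SetSUMBTTracker

    # Weights and ontology are loaded in __init__ via load_weights(); CUDA is used automatically if available
    tracker = SetSUMBTTracker(model_name_or_path='ConvLab/setsumbt-dst_nlu-multiwoz21-EnD2',
                              return_confidence_scores=True)
    tracker.init_session()
    tracker.state['history'].append(['sys', ''])  # no system turn yet in a fresh session
    state = tracker.update('i am looking for a cheap hotel in the north')
    print(state['belief_state'], state['user_action'])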