From e4d716f52c9df83a47c63c2d70b0ee7b9d5ace45 Mon Sep 17 00:00:00 2001 From: Carel van Niekerk <vniekerk.carel@gmail.com> Date: Wed, 16 Nov 2022 11:41:47 +0100 Subject: [PATCH] Merge new setsumbt code into github repo copy --- convlab/dst/setsumbt/__init__.py | 1 + convlab/dst/setsumbt/calibration_plots.py | 12 +- convlab/dst/setsumbt/dataset/__init__.py | 2 + convlab/dst/setsumbt/dataset/ontology.py | 133 + .../dst/setsumbt/dataset/unified_format.py | 423 +++ convlab/dst/setsumbt/dataset/utils.py | 409 +++ convlab/dst/setsumbt/dataset/value_maps.py | 50 + convlab/dst/setsumbt/distillation_setup.py | 253 +- convlab/dst/setsumbt/do/calibration.py | 481 --- convlab/dst/setsumbt/do/evaluate.py | 296 ++ convlab/dst/setsumbt/do/nbt.py | 352 +- convlab/dst/setsumbt/loss/__init__.py | 4 + convlab/dst/setsumbt/loss/bayesian.py | 144 - .../dst/setsumbt/loss/bayesian_matching.py | 115 + convlab/dst/setsumbt/loss/distillation.py | 201 -- convlab/dst/setsumbt/loss/endd_loss.py | 314 +- convlab/dst/setsumbt/loss/kl_distillation.py | 104 + convlab/dst/setsumbt/loss/labelsmoothing.py | 115 +- .../loss/{ece.py => uncertainty_measures.py} | 136 +- convlab/dst/setsumbt/modeling/__init__.py | 4 +- convlab/dst/setsumbt/modeling/bert_nbt.py | 102 +- .../setsumbt/modeling/calibration_utils.py | 134 - convlab/dst/setsumbt/modeling/ensemble_nbt.py | 242 +- .../dst/setsumbt/modeling/evaluation_utils.py | 112 + convlab/dst/setsumbt/modeling/functional.py | 456 --- convlab/dst/setsumbt/modeling/roberta_nbt.py | 105 +- convlab/dst/setsumbt/modeling/setsumbt.py | 564 +++ .../modeling/temperature_scheduler.py | 68 +- convlab/dst/setsumbt/modeling/training.py | 663 ++-- convlab/dst/setsumbt/multiwoz/Tracker.py | 455 --- convlab/dst/setsumbt/multiwoz/__init__.py | 2 - .../setsumbt/multiwoz/dataset/mapping.pair | 83 - .../setsumbt/multiwoz/dataset/multiwoz21.py | 502 --- .../setsumbt/multiwoz/dataset/mwoz21_ont.json | 2990 ---------------- .../multiwoz/dataset/mwoz21_ont_request.json | 3128 ----------------- .../dataset/mwoz21_slot_descriptions.json | 57 - .../dst/setsumbt/multiwoz/dataset/ontology.py | 168 - .../dst/setsumbt/multiwoz/dataset/utils.py | 446 --- convlab/dst/setsumbt/predict_user_actions.py | 178 + convlab/dst/setsumbt/process_mwoz_data.py | 99 - convlab/dst/setsumbt/run.py | 4 +- convlab/dst/setsumbt/tracker.py | 446 +++ convlab/dst/setsumbt/utils.py | 234 +- convlab/policy/mle/loader.py | 37 +- convlab/policy/mle/train.py | 37 +- ...eline_config.json => setsumbt_config.json} | 30 +- convlab/policy/ppo/setsumbt_unc_config.json | 65 + convlab/policy/ppo/train.py | 12 +- convlab/policy/vector/dataset.py | 20 - convlab/policy/vector/vector_base.py | 13 +- convlab/policy/vector/vector_binary.py | 2 +- .../vector/vector_multiwoz_uncertainty.py | 238 -- convlab/policy/vector/vector_nodes.py | 56 +- convlab/policy/vector/vector_uncertainty.py | 166 + convlab/util/custom_util.py | 105 +- 55 files changed, 4524 insertions(+), 11044 deletions(-) create mode 100644 convlab/dst/setsumbt/dataset/__init__.py create mode 100644 convlab/dst/setsumbt/dataset/ontology.py create mode 100644 convlab/dst/setsumbt/dataset/unified_format.py create mode 100644 convlab/dst/setsumbt/dataset/utils.py create mode 100644 convlab/dst/setsumbt/dataset/value_maps.py delete mode 100644 convlab/dst/setsumbt/do/calibration.py create mode 100644 convlab/dst/setsumbt/do/evaluate.py create mode 100644 convlab/dst/setsumbt/loss/__init__.py delete mode 100644 convlab/dst/setsumbt/loss/bayesian.py create mode 100644 convlab/dst/setsumbt/loss/bayesian_matching.py delete mode 100644 convlab/dst/setsumbt/loss/distillation.py create mode 100644 convlab/dst/setsumbt/loss/kl_distillation.py rename convlab/dst/setsumbt/loss/{ece.py => uncertainty_measures.py} (50%) delete mode 100644 convlab/dst/setsumbt/modeling/calibration_utils.py create mode 100644 convlab/dst/setsumbt/modeling/evaluation_utils.py delete mode 100644 convlab/dst/setsumbt/modeling/functional.py create mode 100644 convlab/dst/setsumbt/modeling/setsumbt.py delete mode 100644 convlab/dst/setsumbt/multiwoz/Tracker.py delete mode 100644 convlab/dst/setsumbt/multiwoz/__init__.py delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/mapping.pair delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/ontology.py delete mode 100644 convlab/dst/setsumbt/multiwoz/dataset/utils.py create mode 100644 convlab/dst/setsumbt/predict_user_actions.py delete mode 100755 convlab/dst/setsumbt/process_mwoz_data.py create mode 100644 convlab/dst/setsumbt/tracker.py rename convlab/policy/ppo/{setsumbt_end_baseline_config.json => setsumbt_config.json} (53%) create mode 100644 convlab/policy/ppo/setsumbt_unc_config.json delete mode 100644 convlab/policy/vector/vector_multiwoz_uncertainty.py create mode 100644 convlab/policy/vector/vector_uncertainty.py diff --git a/convlab/dst/setsumbt/__init__.py b/convlab/dst/setsumbt/__init__.py index e69de29b..9492faa9 100644 --- a/convlab/dst/setsumbt/__init__.py +++ b/convlab/dst/setsumbt/__init__.py @@ -0,0 +1 @@ +from convlab.dst.setsumbt.tracker import SetSUMBTTracker \ No newline at end of file diff --git a/convlab/dst/setsumbt/calibration_plots.py b/convlab/dst/setsumbt/calibration_plots.py index 379057e6..a41f280d 100644 --- a/convlab/dst/setsumbt/calibration_plots.py +++ b/convlab/dst/setsumbt/calibration_plots.py @@ -35,7 +35,7 @@ def main(): path = args.data_dir models = os.listdir(path) - models = [os.path.join(path, model, 'test.belief') for model in models] + models = [os.path.join(path, model, 'test.predictions') for model in models] fig = plt.figure(figsize=(14,8)) font=20 @@ -56,16 +56,16 @@ def main(): def get_calibration(path, device, n_bins=10, temperature=1.00): - logits = torch.load(path, map_location=device) - y_true = logits['labels'] - logits = logits['belief_states'] + probs = torch.load(path, map_location=device) + y_true = probs['state_labels'] + probs = probs['belief_states'] - y_pred = {slot: logits[slot].reshape(-1, logits[slot].size(-1)).argmax(-1) for slot in logits} + y_pred = {slot: probs[slot].reshape(-1, probs[slot].size(-1)).argmax(-1) for slot in probs} goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred} goal_acc = sum([goal_acc[slot] for slot in goal_acc]) goal_acc = (goal_acc == len(y_true)).int() - scores = [logits[slot].reshape(-1, logits[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in logits] + scores = [probs[slot].reshape(-1, probs[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in probs] scores = torch.cat(scores, 0).min(0)[0] step = 1.0 / float(n_bins) diff --git a/convlab/dst/setsumbt/dataset/__init__.py b/convlab/dst/setsumbt/dataset/__init__.py new file mode 100644 index 00000000..17b1f93b --- /dev/null +++ b/convlab/dst/setsumbt/dataset/__init__.py @@ -0,0 +1,2 @@ +from convlab.dst.setsumbt.dataset.unified_format import get_dataloader, change_batch_size +from convlab.dst.setsumbt.dataset.ontology import get_slot_candidate_embeddings diff --git a/convlab/dst/setsumbt/dataset/ontology.py b/convlab/dst/setsumbt/dataset/ontology.py new file mode 100644 index 00000000..81e20780 --- /dev/null +++ b/convlab/dst/setsumbt/dataset/ontology.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Create Ontology Embeddings""" + +import json +import os +import random +from copy import deepcopy + +import torch +import numpy as np + + +def set_seed(args): + """ + Set random seeds + + Args: + args (Arguments class): Arguments class containing seed and number of gpus to use + """ + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def encode_candidates(candidates: list, args, tokenizer, embedding_model) -> torch.tensor: + """ + Embed candidates + + Args: + candidates (list): List of candidate descriptions + args (argument class): Runtime arguments + tokenizer (transformers Tokenizer): Tokenizer for the embedding_model + embedding_model (transformer Model): Transformer model for embedding candidate descriptions + + Returns: + feats (torch.tensor): Embeddings of the candidate descriptions + """ + # Tokenize candidate descriptions + feats = [tokenizer.encode_plus(val, add_special_tokens=True,max_length=args.max_candidate_len, + padding='max_length', truncation='longest_first') + for val in candidates] + + # Encode tokenized descriptions + with torch.no_grad(): + feats = {key: torch.tensor([f[key] for f in feats]).to(embedding_model.device) for key in feats[0]} + embedded_feats = embedding_model(**feats) # [num_candidates, max_candidate_len, hidden_dim] + + # Reduce/pool descriptions embeddings if required + if args.set_similarity: + feats = embedded_feats.last_hidden_state.detach().cpu() # [num_candidates, max_candidate_len, hidden_dim] + elif args.candidate_pooling == 'cls': + feats = embedded_feats.pooler_output.detach().cpu() # [num_candidates, hidden_dim] + elif args.candidate_pooling == "mean": + feats = embedded_feats.last_hidden_state.detach().cpu() + feats = feats.sum(1) + feats = torch.nn.functional.layer_norm(feats, feats.size()) + feats = feats.detach().cpu() # [num_candidates, hidden_dim] + + return feats + + +def get_slot_candidate_embeddings(ontology: dict, set_type: str, args, tokenizer, embedding_model, save_to_file=True): + """ + Get embeddings for slots and candidates + + Args: + ontology (dict): Dictionary of domain-slot pair descriptions and possible value sets + set_type (str): Subset of the dataset being used (train/validation/test) + args (argument class): Runtime arguments + tokenizer (transformers Tokenizer): Tokenizer for the embedding_model + embedding_model (transformer Model): Transormer model for embedding candidate descriptions + save_to_file (bool): Indication of whether to save information to file + + Returns: + slots (dict): domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot + """ + # Set model to eval mode + embedding_model.eval() + + slots = dict() + for domain, subset in ontology.items(): + for slot, slot_info in subset.items(): + # Get description or use "domain-slot" + if args.use_descriptions: + desc = slot_info['description'] + else: + desc = f"{domain}-{slot}" + + # Encode domain-slot pair description + slot_emb = encode_candidates([desc], args, tokenizer, embedding_model)[0] + + # Obtain possible value set and discard requestable value + values = deepcopy(slot_info['possible_values']) + is_requestable = False + if '?' in values: + is_requestable = True + values.remove('?') + + # Encode value candidates + if values: + feats = encode_candidates(values, args, tokenizer, embedding_model) + else: + feats = None + + # Store domain-slot description embeddings, candidate embeddings and requestabke flag for each domain-slot + slots[f"{domain}-{slot}"] = (slot_emb, feats, is_requestable) + + # Dump tensors and ontology for use in training and evaluation + if save_to_file: + writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type) + torch.save(slots, writer) + + writer = open(os.path.join(args.output_dir, 'database', '%s.json' % set_type), 'w') + json.dump(ontology, writer, indent=2) + writer.close() + + return slots diff --git a/convlab/dst/setsumbt/dataset/unified_format.py b/convlab/dst/setsumbt/dataset/unified_format.py new file mode 100644 index 00000000..26b67268 --- /dev/null +++ b/convlab/dst/setsumbt/dataset/unified_format.py @@ -0,0 +1,423 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convlab3 Unified Format Dialogue Datasets""" + +from copy import deepcopy + +import torch +import transformers +from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler +from transformers.tokenization_utils import PreTrainedTokenizer +from tqdm import tqdm + +from convlab.util import load_dataset +from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values, + get_values_from_data, ontology_add_requestable_slots, + get_requestable_slots, load_dst_data, extract_dialogues, + combine_value_sets) + +transformers.logging.set_verbosity_error() + + +def convert_examples_to_features(data: list, + ontology: dict, + tokenizer: PreTrainedTokenizer, + max_turns: int = 12, + max_seq_len: int = 64) -> dict: + """ + Convert dialogue examples to model input features and labels + + Args: + data (list): List of all extracted dialogues + ontology (dict): Ontology dictionary containing slots, slot descriptions and + possible value sets including requests + tokenizer (PreTrainedTokenizer): Tokenizer for the encoder model used + max_turns (int): Maximum numbers of turns in a dialogue + max_seq_len (int): Maximum number of tokens in a dialogue turn + + Returns: + features (dict): All inputs and labels required to train the model + """ + features = dict() + ontology = deepcopy(ontology) + + # Get encoder input for system, user utterance pairs + input_feats = [] + for dial in tqdm(data): + dial_feats = [] + for turn in dial: + if len(turn['system_utterance']) == 0: + usr = turn['user_utterance'] + dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens=True, + max_length=max_seq_len, padding='max_length', + truncation='longest_first')) + else: + usr = turn['user_utterance'] + sys = turn['system_utterance'] + dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens=True, + max_length=max_seq_len, padding='max_length', + truncation='longest_first')) + # Truncate + if len(dial_feats) >= max_turns: + break + input_feats.append(dial_feats) + del dial_feats + + # Perform turn level padding + input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) + for dial in input_feats] + if 'token_type_ids' in input_feats[0][0]: + token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) + for dial in input_feats] + else: + token_type_ids = None + if 'attention_mask' in input_feats[0][0]: + attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) + for dial in input_feats] + else: + attention_mask = None + del input_feats + + # Create torch data tensors + features['input_ids'] = torch.tensor(input_ids) + features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None + features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None + del input_ids, token_type_ids, attention_mask + + # Extract all informable and requestable slots from the ontology + informable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain] + if ontology[domain][slot]['possible_values'] + and ontology[domain][slot]['possible_values'] != ['?']] + requestable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain] + if '?' in ontology[domain][slot]['possible_values']] + for slot in requestable_slots: + domain, slot = slot.split('-', 1) + ontology[domain][slot]['possible_values'].remove('?') + + # Extract a list of domains from the ontology slots + domains = list(set(informable_slots + requestable_slots)) + domains = list(set([slot.split('-', 1)[0] for slot in domains])) + + # Create slot labels + for domslot in tqdm(informable_slots): + labels = [] + for dial in data: + labs = [] + for turn in dial: + value = [v for d, substate in turn['state'].items() for s, v in substate.items() + if f'{d}-{s}' == domslot] + domain, slot = domslot.split('-', 1) + if turn['dataset_name'] in ontology[domain][slot]['dataset_names']: + value = value[0] if value else 'none' + else: + value = -1 + if value in ontology[domain][slot]['possible_values'] and value != -1: + value = ontology[domain][slot]['possible_values'].index(value) + else: + value = -1 # If value is not in ontology then we do not penalise the model + labs.append(value) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['state_labels-' + domslot] = labels + + # Create requestable slot labels + for domslot in tqdm(requestable_slots): + labels = [] + for dial in data: + labs = [] + for turn in dial: + domain, slot = domslot.split('-', 1) + if turn['dataset_name'] in ontology[domain][slot]['dataset_names']: + acts = [act['intent'] for act in turn['dialogue_acts'] + if act['domain'] == domain and act['slot'] == slot] + if acts: + act_ = acts[0] + if act_ == 'request': + labs.append(1) + else: + labs.append(0) + else: + labs.append(0) + else: + labs.append(-1) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['request_labels-' + domslot] = labels + + # General act labels (1-goodbye, 2-thank you) + labels = [] + for dial in tqdm(data): + labs = [] + for turn in dial: + acts = [act['intent'] for act in turn['dialogue_acts'] if act['intent'] in ['bye', 'thank']] + if acts: + if 'bye' in acts: + labs.append(1) + else: + labs.append(2) + else: + labs.append(0) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['general_act_labels'] = labels + + # Create active domain labels + for domain in tqdm(domains): + labels = [] + for dial in data: + labs = [] + for turn in dial: + possible_domains = list() + for dom in ontology: + for slt in ontology[dom]: + if turn['dataset_name'] in ontology[dom][slt]['dataset_names']: + possible_domains.append(dom) + + if domain in turn['active_domains']: + labs.append(1) + elif domain in possible_domains: + labs.append(0) + else: + labs.append(-1) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['active_domain_labels-' + domain] = labels + + del labels + + return features + + +class UnifiedFormatDataset(Dataset): + """ + Class for preprocessing, and storing data easily from the Convlab3 unified format. + + Attributes: + dataset_dict (dict): Dictionary containing all the data in dataset + ontology (dict): Set of all domain-slot-value triplets in the ontology of the model + features (dict): Set of numeric features containing all inputs and labels formatted for the SetSUMBT model + """ + def __init__(self, + dataset_name: str, + set_type: str, + tokenizer: PreTrainedTokenizer, + max_turns: int = 12, + max_seq_len: int = 64, + train_ratio: float = 1.0, + seed: int = 0, + data: dict = None, + ontology: dict = None): + """ + Args: + dataset_name (str): Name of the dataset/s to load (multiple to be seperated by +) + set_type (str): Subset of the dataset to load (train, validation or test) + tokenizer (transformers tokenizer): Tokenizer for the encoder model used + max_turns (int): Maximum numbers of turns in a dialogue + max_seq_len (int): Maximum number of tokens in a dialogue turn + train_ratio (float): Fraction of training data to use during training + seed (int): Seed governing random order of ids for subsampling + data (dict): Dataset features for loading from dict + ontology (dict): Ontology dict for loading from dict + """ + if data is not None: + self.ontology = ontology + self.features = data + else: + if '+' in dataset_name: + dataset_args = [{"dataset_name": name} for name in dataset_name.split('+')] + else: + dataset_args = [{"dataset_name": dataset_name}] + self.dataset_dicts = [load_dataset(**dataset_args_) for dataset_args_ in dataset_args] + self.ontology = get_ontology_slots(dataset_name) + values = [get_values_from_data(dataset) for dataset in self.dataset_dicts] + self.ontology = ontology_add_values(self.ontology, combine_value_sets(values)) + self.ontology = ontology_add_requestable_slots(self.ontology, get_requestable_slots(self.dataset_dicts)) + + if train_ratio != 1.0: + for dataset_args_ in dataset_args: + dataset_args_['dial_ids_order'] = seed + dataset_args_['split2ratio'] = {'train': train_ratio, 'validation': train_ratio} + self.dataset_dicts = [load_dataset(**dataset_args_) for dataset_args_ in dataset_args] + + data = [load_dst_data(dataset_dict, data_split=set_type, speaker='all', + dialogue_acts=True, split_to_turn=False) + for dataset_dict in self.dataset_dicts] + data_list = [data_[set_type] for data_ in data] + + data = [] + for idx, data_ in enumerate(data_list): + data += extract_dialogues(data_, dataset_args[idx]["dataset_name"]) + self.features = convert_examples_to_features(data, self.ontology, tokenizer, max_turns, max_seq_len) + + def __getitem__(self, index: int) -> dict: + """ + Obtain dialogues with specific ids from dataset + + Args: + index (int/list/tensor): Index/indices of dialogues to get + + Returns: + features (dict): All inputs and labels required to train the model + """ + return {label: self.features[label][index] for label in self.features + if self.features[label] is not None} + + def __len__(self): + """ + Get number of dialogues in the dataset + + Returns: + len (int): Number of dialogues in the dataset object + """ + return self.features['input_ids'].size(0) + + def resample(self, size: int = None) -> Dataset: + """ + Resample subset of the dataset + + Args: + size (int): Number of dialogues to sample + + Returns: + self (Dataset): Dataset object + """ + # If no subset size is specified we resample a set with the same size as the full dataset + n_dialogues = self.__len__() + if not size: + size = n_dialogues + + dialogues = torch.randint(low=0, high=n_dialogues, size=(size,)) + self.features = self.__getitem__(dialogues) + + return self + + def to(self, device): + """ + Map all data to a device + + Args: + device (torch device): Device to map data to + """ + self.device = device + self.features = {label: self.features[label].to(device) for label in self.features + if self.features[label] is not None} + + @classmethod + def from_datadict(cls, data: dict, ontology: dict): + return cls(None, None, None, data=data, ontology=ontology) + + +def get_dataloader(dataset_name: str, + set_type: str, + batch_size: int, + tokenizer: PreTrainedTokenizer, + max_turns: int = 12, + max_seq_len: int = 64, + device='cpu', + resampled_size: int = None, + train_ratio: float = 1.0, + seed: int = 0) -> DataLoader: + ''' + Module to create torch dataloaders + + Args: + dataset_name (str): Name of the dataset to load + set_type (str): Subset of the dataset to load (train, validation or test) + batch_size (int): Batch size for the dataloader + tokenizer (transformers tokenizer): Tokenizer for the encoder model used + max_turns (int): Maximum numbers of turns in a dialogue + max_seq_len (int): Maximum number of tokens in a dialogue turn + device (torch device): Device to map data to + resampled_size (int): Number of dialogues to sample + train_ratio (float): Ratio of training data to use for training + seed (int): Seed governing random order of ids for subsampling + + Returns: + loader (torch dataloader): Dataloader to train and evaluate the setsumbt model + ''' + data = UnifiedFormatDataset(dataset_name, set_type, tokenizer, max_turns, max_seq_len, train_ratio=train_ratio, + seed=seed) + data.to(device) + + if resampled_size: + data = data.resample(resampled_size) + + if set_type in ['test', 'validation']: + sampler = SequentialSampler(data) + else: + sampler = RandomSampler(data) + loader = DataLoader(data, sampler=sampler, batch_size=batch_size) + + return loader + + +def change_batch_size(loader: DataLoader, batch_size: int) -> DataLoader: + """ + Change the batch size of a preloaded loader + + Args: + loader (DataLoader): Dataloader to train and evaluate the setsumbt model + batch_size (int): Batch size for the dataloader + + Returns: + loader (DataLoader): Dataloader to train and evaluate the setsumbt model + """ + + if 'SequentialSampler' in str(loader.sampler): + sampler = SequentialSampler(loader.dataset) + else: + sampler = RandomSampler(loader.dataset) + loader = DataLoader(loader.dataset, sampler=sampler, batch_size=batch_size) + + return loader + +def dataloader_sample_dialogues(loader: DataLoader, sample_size: int) -> DataLoader: + """ + Sample a subset of the dialogues in a dataloader + + Args: + loader (DataLoader): Dataloader to train and evaluate the setsumbt model + sample_size (int): Number of dialogues to sample + + Returns: + loader (DataLoader): Dataloader to train and evaluate the setsumbt model + """ + + dataset = loader.dataset.resample(sample_size) + + if 'SequentialSampler' in str(loader.sampler): + sampler = SequentialSampler(dataset) + else: + sampler = RandomSampler(dataset) + loader = DataLoader(loader.dataset, sampler=sampler, batch_size=loader.batch_size) + + return loader diff --git a/convlab/dst/setsumbt/dataset/utils.py b/convlab/dst/setsumbt/dataset/utils.py new file mode 100644 index 00000000..088480c4 --- /dev/null +++ b/convlab/dst/setsumbt/dataset/utils.py @@ -0,0 +1,409 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convlab3 Unified dataset data processing utilities""" + +from convlab.util import load_ontology, load_dst_data, load_nlu_data +from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME + + +def get_ontology_slots(dataset_name: str) -> dict: + """ + Function to extract slots, slot descriptions and categorical slot values from the dataset ontology. + + Args: + dataset_name (str): Dataset name + + Returns: + ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values + """ + dataset_names = dataset_name.split('+') if '+' in dataset_name else [dataset_name] + ontology_slots = dict() + for dataset_name in dataset_names: + ontology = load_ontology(dataset_name) + domains = [domain for domain in ontology['domains'] if domain not in ['booking', 'general']] + for domain in domains: + domain_name = DOMAINS_MAP.get(domain, domain.lower()) + if domain_name not in ontology_slots: + ontology_slots[domain_name] = dict() + for slot, slot_info in ontology['domains'][domain]['slots'].items(): + if slot not in ontology_slots[domain_name]: + ontology_slots[domain_name][slot] = {'description': slot_info['description'], + 'possible_values': list(), + 'dataset_names': list()} + if slot_info['is_categorical']: + ontology_slots[domain_name][slot]['possible_values'] += slot_info['possible_values'] + + ontology_slots[domain_name][slot]['possible_values'] = list(set(ontology_slots[domain_name][slot]['possible_values'])) + ontology_slots[domain_name][slot]['dataset_names'].append(dataset_name) + + return ontology_slots + + +def get_values_from_data(dataset: dict) -> dict: + """ + Function to extract slots, slot descriptions and categorical slot values from the dataset ontology. + + Args: + dataset (dict): Dataset dictionary obtained using the load_dataset function + + Returns: + value_sets (dict): Dictionary containing possible values obtained from dataset + """ + data = load_dst_data(dataset, data_split='all', speaker='user') + value_sets = {} + for set_type, dataset in data.items(): + for turn in dataset: + for domain, substate in turn['state'].items(): + domain_name = DOMAINS_MAP.get(domain, domain.lower()) + if domain not in value_sets: + value_sets[domain_name] = {} + for slot, value in substate.items(): + if slot not in value_sets[domain_name]: + value_sets[domain_name][slot] = [] + if value and value not in value_sets[domain_name][slot]: + value_sets[domain_name][slot].append(value) + + return clean_values(value_sets) + + +def combine_value_sets(value_sets: list) -> dict: + """ + Function to combine value sets extracted from different datasets + + Args: + value_sets (list): List of value sets extracted using the get_values_from_data function + + Returns: + value_set (dict): Dictionary containing possible values obtained from datasets + """ + value_set = value_sets[0] + for _value_set in value_sets[1:]: + for domain, domain_info in _value_set.items(): + for slot, possible_values in domain_info.items(): + if domain not in value_set: + value_set[domain] = dict() + if slot not in value_set[domain]: + value_set[domain][slot] = list() + value_set[domain][slot] += _value_set[domain][slot] + value_set[domain][slot] = list(set(value_set[domain][slot])) + + return value_set + + +def clean_values(value_sets: dict, value_map: dict = VALUE_MAP) -> dict: + """ + Function to clean up the possible value sets extracted from the states in the dataset + + Args: + value_sets (dict): Dictionary containing possible values obtained from dataset + value_map (dict): Label map to avoid duplication and typos in values + + Returns: + clean_vals (dict): Cleaned Dictionary containing possible values obtained from dataset + """ + clean_vals = {} + for domain, subset in value_sets.items(): + clean_vals[domain] = {} + for slot, values in subset.items(): + # Remove pipe separated values + values = list(set([val.split('|', 1)[0] for val in values])) + + # Map values using value_map + for old, new in value_map.items(): + values = list(set([val.replace(old, new) for val in values])) + + # Remove empty and dontcare from possible value sets + values = [val for val in values if val not in ['', 'dontcare']] + + # MultiWOZ specific value sets for quantity, time and boolean slots + if 'people' in slot or 'duration' in slot or 'stay' in slot: + values = QUANTITIES + elif 'time' in slot or 'leave' in slot or 'arrive' in slot: + values = TIME + elif 'parking' in slot or 'internet' in slot: + values = ['yes', 'no'] + + clean_vals[domain][slot] = values + + return clean_vals + + +def ontology_add_values(ontology_slots: dict, value_sets: dict) -> dict: + """ + Add value sets obtained from the dataset to the ontology + Args: + ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values + value_sets (dict): Cleaned Dictionary containing possible values obtained from dataset + + Returns: + ontology_slots (dict): Ontology dictionary containing slots, slot descriptions and possible value sets + """ + ontology = {} + for domain in sorted(ontology_slots): + ontology[domain] = {} + for slot in sorted(ontology_slots[domain]): + if not ontology_slots[domain][slot]['possible_values']: + if domain in value_sets: + if slot in value_sets[domain]: + ontology_slots[domain][slot]['possible_values'] = value_sets[domain][slot] + if ontology_slots[domain][slot]['possible_values']: + values = sorted(ontology_slots[domain][slot]['possible_values']) + ontology_slots[domain][slot]['possible_values'] = ['none', 'do not care'] + values + + ontology[domain][slot] = ontology_slots[domain][slot] + + return ontology + + +def get_requestable_slots(datasets: list) -> dict: + """ + Function to get set of requestable slots from the dataset action labels. + Args: + dataset (dict): Dataset dictionary obtained using the load_dataset function + + Returns: + slots (dict): Dictionary containing requestable domain-slot pairs + """ + datasets = [load_nlu_data(dataset, data_split='all', speaker='user') for dataset in datasets] + + slots = {} + for data in datasets: + for set_type, subset in data.items(): + for turn in subset: + requests = [act for act in turn['dialogue_acts']['categorical'] if act['intent'] == 'request'] + requests += [act for act in turn['dialogue_acts']['non-categorical'] if act['intent'] == 'request'] + requests += [act for act in turn['dialogue_acts']['binary'] if act['intent'] == 'request'] + requests = [(act['domain'], act['slot']) for act in requests] + for domain, slot in requests: + domain_name = DOMAINS_MAP.get(domain, domain.lower()) + if domain_name not in slots: + slots[domain_name] = [] + slots[domain_name].append(slot) + + slots = {domain: list(set(slot_list)) for domain, slot_list in slots.items()} + + return slots + + +def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict) -> dict: + """ + Add requestable slots obtained from the dataset to the ontology + Args: + ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values + requestable_slots (dict): Dictionary containing requestable domain-slot pairs + + Returns: + ontology_slots (dict): Ontology dictionary containing slots, slot descriptions and + possible value sets including requests + """ + for domain in ontology_slots: + for slot in ontology_slots[domain]: + if domain in requestable_slots: + if slot in requestable_slots[domain]: + ontology_slots[domain][slot]['possible_values'].append('?') + + return ontology_slots + + +def extract_turns(dialogue: list, dataset_name: str) -> list: + """ + Extract the required information from the data provided by unified loader + Args: + dialogue (list): List of turns within a dialogue + dataset_name (str): Name of the dataset to which the dialogue belongs + + Returns: + turns (list): List of turns within a dialogue + """ + turns = [] + turn_info = {} + for turn in dialogue: + if turn['speaker'] == 'system': + turn_info['system_utterance'] = turn['utterance'] + + # System utterance in the first turn is always empty as conversation is initiated by the user + if turn['utt_idx'] == 1: + turn_info['system_utterance'] = '' + + if turn['speaker'] == 'user': + turn_info['user_utterance'] = turn['utterance'] + + # Inform acts not required by model + turn_info['dialogue_acts'] = [act for act in turn['dialogue_acts']['categorical'] + if act['intent'] not in ['inform']] + turn_info['dialogue_acts'] += [act for act in turn['dialogue_acts']['non-categorical'] + if act['intent'] not in ['inform']] + turn_info['dialogue_acts'] += [act for act in turn['dialogue_acts']['binary'] + if act['intent'] not in ['inform']] + + turn_info['state'] = turn['state'] + turn_info['dataset_name'] = dataset_name + + if 'system_utterance' in turn_info and 'user_utterance' in turn_info: + turns.append(turn_info) + turn_info = {} + + return turns + + +def clean_states(turns: list) -> list: + """ + Clean the state within each turn of a dialogue (cleaning values and mapping to options used in ontology) + Args: + turns (list): List of turns within a dialogue + + Returns: + clean_turns (list): List of turns within a dialogue + """ + clean_turns = [] + for turn in turns: + clean_state = {} + clean_acts = [] + for act in turn['dialogue_acts']: + domain = act['domain'] + act['domain'] = DOMAINS_MAP.get(domain, domain.lower()) + clean_acts.append(act) + for domain, subset in turn['state'].items(): + domain_name = DOMAINS_MAP.get(domain, domain.lower()) + clean_state[domain_name] = {} + for slot, value in subset.items(): + # Remove pipe separated values + value = value.split('|', 1)[0] + + # Map values using value_map + for old, new in VALUE_MAP.items(): + value = value.replace(old, new) + + # Map dontcare to "do not care" and empty to 'none' + value = value.replace('dontcare', 'do not care') + value = value if value else 'none' + + # Map quantity values to the integer quantity value + if 'people' in slot or 'duration' in slot or 'stay' in slot: + try: + if value not in ['do not care', 'none']: + value = int(value) + value = str(value) if value < 10 else QUANTITIES[-1] + except: + value = value + # Map time values to the most appropriate value in the standard time set + elif 'time' in slot or 'leave' in slot or 'arrive' in slot: + try: + if value not in ['do not care', 'none']: + # Strip after/before from time value + value = value.replace('after ', '').replace('before ', '') + # Extract hours and minutes from different possible formats + if ':' not in value and len(value) == 4: + h, m = value[:2], value[2:] + elif len(value) == 1: + h = int(value) + m = 0 + elif 'pm' in value: + h = int(value.replace('pm', '')) + 12 + m = 0 + elif 'am' in value: + h = int(value.replace('pm', '')) + m = 0 + elif ':' in value: + h, m = value.split(':') + elif ';' in value: + h, m = value.split(';') + # Map to closest 5 minutes + if int(m) % 5 != 0: + m = round(int(m) / 5) * 5 + h = int(h) + if m == 60: + m = 0 + h += 1 + if h >= 24: + h -= 24 + # Set in standard 24 hour format + h, m = int(h), int(m) + value = '%02i:%02i' % (h, m) + except: + value = value + # Map boolean slots to yes/no value + elif 'parking' in slot or 'internet' in slot: + if value not in ['do not care', 'none']: + if value == 'free': + value = 'yes' + elif True in [v in value.lower() for v in ['yes', 'no']]: + value = [v for v in ['yes', 'no'] if v in value][0] + + clean_state[domain_name][slot] = value + turn['state'] = clean_state + turn['dialogue_acts'] = clean_acts + clean_turns.append(turn) + + return clean_turns + + +def get_active_domains(turns: list) -> list: + """ + Get active domains at each turn in a dialogue + Args: + turns (list): List of turns within a dialogue + + Returns: + turns (list): List of turns within a dialogue + """ + for turn_id in range(len(turns)): + # At first turn all domains with not none values in the state are active + if turn_id == 0: + domains = [d for d, substate in turns[turn_id]['state'].items() for s, v in substate.items() if v != 'none'] + domains += [act['domain'] for act in turns[turn_id]['dialogue_acts'] if act['domain'] in turns[turn_id]['state']] + domains = [DOMAINS_MAP.get(domain, domain.lower()) for domain in domains] + turns[turn_id]['active_domains'] = list(set(domains)) + else: + # Use changes in domains to identify active domains + domains = [] + for domain, substate in turns[turn_id]['state'].items(): + domain_name = DOMAINS_MAP.get(domain, domain.lower()) + for slot, value in substate.items(): + if value != turns[turn_id - 1]['state'][domain][slot]: + val = value + else: + val = 'none' + if value == 'none': + val = 'none' + if val != 'none': + domains.append(domain_name) + # Add all domains activated by a user action + domains += [act['domain'] for act in turns[turn_id]['dialogue_acts'] + if act['domain'] in turns[turn_id]['state']] + turns[turn_id]['active_domains'] = list(set(domains)) + + return turns + + +def extract_dialogues(data: list, dataset_name: str) -> list: + """ + Extract all dialogues from dataset + Args: + data (list): List of all dialogues in a subset of the data + dataset_name (str): Name of the dataset to which the dialogues belongs + + Returns: + dialogues (list): List of all extracted dialogues + """ + dialogues = [] + for dial in data: + turns = extract_turns(dial['turns'], dataset_name) + turns = clean_states(turns) + turns = get_active_domains(turns) + dialogues.append(turns) + + return dialogues diff --git a/convlab/dst/setsumbt/dataset/value_maps.py b/convlab/dst/setsumbt/dataset/value_maps.py new file mode 100644 index 00000000..619600a7 --- /dev/null +++ b/convlab/dst/setsumbt/dataset/value_maps.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convlab3 Unified dataset value maps""" + + +# MultiWOZ specific label map to avoid duplication and typos in values +VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast', + 'cityroomz': 'city roomz', ' ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott', + 'worth house': 'the worth house', 'alesbray lodge guest house': 'aylesbray lodge', + 'huntingdon hotel': 'huntingdon marriott hotel', 'huntingd': 'huntingdon marriott hotel', + 'jamaicanchinese': 'chinese', 'barbequemodern european': 'modern european', + 'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"} + + +# Domain map for SGD and TM Data +DOMAINS_MAP = {'Alarm_1': 'alarm', 'Banks_1': 'banks', 'Banks_2': 'banks', 'Buses_1': 'bus', 'Buses_2': 'bus', + 'Buses_3': 'bus', 'Calendar_1': 'calendar', 'Events_1': 'events', 'Events_2': 'events', + 'Events_3': 'events', 'Flights_1': 'flights', 'Flights_2': 'flights', 'Flights_3': 'flights', + 'Flights_4': 'flights', 'Homes_1': 'homes', 'Homes_2': 'homes', 'Hotels_1': 'hotel', + 'Hotels_2': 'hotel', 'Hotels_3': 'hotel', 'Hotels_4': 'hotel', 'Media_1': 'media', + 'Media_2': 'media', 'Media_3': 'media', 'Messaging_1': 'messaging', 'Movies_1': 'movies', + 'Movies_2': 'movies', 'Movies_3': 'movies', 'Music_1': 'music', 'Music_2': 'music', 'Music_3': 'music', + 'Payment_1': 'payment', 'RentalCars_1': 'rentalcars', 'RentalCars_2': 'rentalcars', + 'RentalCars_3': 'rentalcars', 'Restaurants_1': 'restaurant', 'Restaurants_2': 'restaurant', + 'RideSharing_1': 'ridesharing', 'RideSharing_2': 'ridesharing', 'Services_1': 'services', + 'Services_2': 'services', 'Services_3': 'services', 'Services_4': 'services', 'Trains_1': 'train', + 'Travel_1': 'travel', 'Weather_1': 'weather', 'movie_ticket': 'movies', + 'restaurant_reservation': 'restaurant', 'coffee_ordering': 'coffee', 'pizza_ordering': 'takeout', + 'auto_repair': 'car_repairs', 'flights': 'flights', 'food-ordering': 'takeout', 'hotels': 'hotel', + 'movies': 'movies', 'music': 'music', 'restaurant-search': 'restaurant', 'sports': 'sports', + 'movie': 'movies'} + + +# Generic value sets for quantity and time slots +QUANTITIES = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10 or more'] +TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)] +TIME = ['%02i:%02i' % t for l in TIME for t in l] \ No newline at end of file diff --git a/convlab/dst/setsumbt/distillation_setup.py b/convlab/dst/setsumbt/distillation_setup.py index e0d87bb9..2279e222 100644 --- a/convlab/dst/setsumbt/distillation_setup.py +++ b/convlab/dst/setsumbt/distillation_setup.py @@ -1,53 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Get ensemble predictions and build distillation dataloaders""" + from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser import os +import json import torch -import transformers -from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler -from transformers import RobertaConfig, BertConfig +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from tqdm import tqdm -import convlab -from convlab.dst.setsumbt.multiwoz.dataset.multiwoz21 import EnsembleMultiWoz21 +from convlab.dst.setsumbt.dataset.unified_format import UnifiedFormatDataset, change_batch_size from convlab.dst.setsumbt.modeling import EnsembleSetSUMBT +from convlab.dst.setsumbt.modeling import training -DEVICE = 'cuda' - - -def args_parser(): - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--model_path', type=str) - parser.add_argument('--model_type', type=str) - parser.add_argument('--set_type', type=str) - parser.add_argument('--batch_size', type=int) - parser.add_argument('--ensemble_size', type=int) - parser.add_argument('--reduction', type=str, default='mean') - parser.add_argument('--get_ensemble_distributions', action='store_true') - parser.add_argument('--build_dataloaders', action='store_true') - - return parser.parse_args() - - -def main(): - args = args_parser() +DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' - if args.get_ensemble_distributions: - get_ensemble_distributions(args) - elif args.build_dataloaders: - path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data') - data = torch.load(path) - loader = get_loader(data, args.set_type, args.batch_size) - path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader') - torch.save(loader, path) - else: - raise NameError("NotImplemented") +def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size: int = 3) -> DataLoader: + """ + Build dataloader from ensemble prediction data + Args: + data: Dictionary of ensemble predictions + ontology: Data ontology + set_type: Data subset (train/validation/test) + batch_size: Number of dialogues per batch -def get_loader(data, set_type='train', batch_size=3): + Returns: + loader: Data loader object + """ data = flatten_data(data) data = do_label_padding(data) - data = EnsembleMultiWoz21(data) + data = UnifiedFormatDataset.from_datadict(data, ontology) if set_type == 'train': sampler = RandomSampler(data) else: @@ -57,7 +55,16 @@ def get_loader(data, set_type='train', batch_size=3): return loader -def do_label_padding(data): +def do_label_padding(data: dict) -> dict: + """ + Add padding to the ensemble predictions (used as labels in distillation) + + Args: + data: Dictionary of ensemble predictions + + Returns: + data: Padded ensemble predictions + """ if 'attention_mask' in data: dialogs, turns = torch.where(data['attention_mask'].sum(-1) == 0.0) else: @@ -70,13 +77,17 @@ def do_label_padding(data): return data -map_dict = {'belief_state': 'belief', 'greeting_act_belief': 'goodbye_belief', - 'state_labels': 'labels', 'request_labels': 'request', - 'domain_labels': 'active', 'greeting_labels': 'goodbye'} -def flatten_data(data): +def flatten_data(data: dict) -> dict: + """ + Map data to flattened feature format used in training + Args: + data: Ensemble prediction data + + Returns: + data: Flattened ensemble prediction data + """ data_new = dict() for label, feats in data.items(): - label = map_dict.get(label, label) if type(feats) == dict: for label_, feats_ in feats.items(): data_new[label + '-' + label_] = feats_ @@ -87,13 +98,11 @@ def flatten_data(data): def get_ensemble_distributions(args): - if args.model_type == 'roberta': - config = RobertaConfig - elif args.model_type == 'bert': - config = BertConfig - config = config.from_pretrained(args.model_path) - config.ensemble_size = args.ensemble_size - + """ + Load data and get ensemble predictions + Args: + args: Runtime arguments + """ device = DEVICE model = EnsembleSetSUMBT.from_pretrained(args.model_path) @@ -107,16 +116,10 @@ def get_ensemble_distributions(args): dataloader = torch.load(dataloader) database = torch.load(database) - # Get slot and value embeddings - slots = {slot: val for slot, val in database.items()} - values = {slot: val[1] for slot, val in database.items()} - del database + if dataloader.batch_size != args.batch_size: + dataloader = change_batch_size(dataloader, args.batch_size) - # Load model ontology - model.add_slot_candidates(slots) - for slot in model.informable_slot_ids: - model.add_value_candidates(slot, values[slot], replace=True) - del slots, values + training.set_ontology_embeddings(model, database) print('Environment set up.') @@ -125,18 +128,24 @@ def get_ensemble_distributions(args): attention_mask = [] state_labels = {slot: [] for slot in model.informable_slot_ids} request_labels = {slot: [] for slot in model.requestable_slot_ids} - domain_labels = {domain: [] for domain in model.domain_ids} - greeting_labels = [] + active_domain_labels = {domain: [] for domain in model.domain_ids} + general_act_labels = [] + + is_noisy = [] if 'is_noisy' in dataloader.dataset.features else None + belief_state = {slot: [] for slot in model.informable_slot_ids} - request_belief = {slot: [] for slot in model.requestable_slot_ids} - domain_belief = {domain: [] for domain in model.domain_ids} - greeting_act_belief = [] + request_probs = {slot: [] for slot in model.requestable_slot_ids} + active_domain_probs = {domain: [] for domain in model.domain_ids} + general_act_probs = [] model.eval() for batch in tqdm(dataloader, desc='Batch:'): ids = batch['input_ids'] tt_ids = batch['token_type_ids'] if 'token_type_ids' in batch else None mask = batch['attention_mask'] if 'attention_mask' in batch else None + if 'is_noisy' in batch: + is_noisy.append(batch['is_noisy']) + input_ids.append(ids) token_type_ids.append(tt_ids) attention_mask.append(mask) @@ -146,61 +155,123 @@ def get_ensemble_distributions(args): mask = mask.to(device) if mask is not None else None for slot in state_labels: - state_labels[slot].append(batch['labels-' + slot]) - if model.config.predict_intents: + state_labels[slot].append(batch['state_labels-' + slot]) + if model.config.predict_actions: for slot in request_labels: - request_labels[slot].append(batch['request-' + slot]) - for domain in domain_labels: - domain_labels[domain].append(batch['active-' + domain]) - greeting_labels.append(batch['goodbye']) + request_labels[slot].append(batch['request_labels-' + slot]) + for domain in active_domain_labels: + active_domain_labels[domain].append(batch['active_domain_labels-' + domain]) + general_act_labels.append(batch['general_act_labels']) with torch.no_grad(): - p, p_req, p_dom, p_bye, _ = model(ids, mask, tt_ids, - reduction=args.reduction) + p, p_req, p_dom, p_gen, _ = model(ids, mask, tt_ids, reduction=args.reduction) for slot in belief_state: belief_state[slot].append(p[slot].cpu()) - if model.config.predict_intents: - for slot in request_belief: - request_belief[slot].append(p_req[slot].cpu()) - for domain in domain_belief: - domain_belief[domain].append(p_dom[domain].cpu()) - greeting_act_belief.append(p_bye.cpu()) + if model.config.predict_actions: + for slot in request_probs: + request_probs[slot].append(p_req[slot].cpu()) + for domain in active_domain_probs: + active_domain_probs[domain].append(p_dom[domain].cpu()) + general_act_probs.append(p_gen.cpu()) input_ids = torch.cat(input_ids, 0) if input_ids[0] is not None else None token_type_ids = torch.cat(token_type_ids, 0) if token_type_ids[0] is not None else None attention_mask = torch.cat(attention_mask, 0) if attention_mask[0] is not None else None + is_noisy = torch.cat(is_noisy, 0) if is_noisy is not None else None state_labels = {slot: torch.cat(l, 0) for slot, l in state_labels.items()} - if model.config.predict_intents: + if model.config.predict_actions: request_labels = {slot: torch.cat(l, 0) for slot, l in request_labels.items()} - domain_labels = {domain: torch.cat(l, 0) for domain, l in domain_labels.items()} - greeting_labels = torch.cat(greeting_labels, 0) + active_domain_labels = {domain: torch.cat(l, 0) for domain, l in active_domain_labels.items()} + general_act_labels = torch.cat(general_act_labels, 0) belief_state = {slot: torch.cat(p, 0) for slot, p in belief_state.items()} - if model.config.predict_intents: - request_belief = {slot: torch.cat(p, 0) for slot, p in request_belief.items()} - domain_belief = {domain: torch.cat(p, 0) for domain, p in domain_belief.items()} - greeting_act_belief = torch.cat(greeting_act_belief, 0) + if model.config.predict_actions: + request_probs = {slot: torch.cat(p, 0) for slot, p in request_probs.items()} + active_domain_probs = {domain: torch.cat(p, 0) for domain, p in active_domain_probs.items()} + general_act_probs = torch.cat(general_act_probs, 0) data = {'input_ids': input_ids} if token_type_ids is not None: data['token_type_ids'] = token_type_ids if attention_mask is not None: data['attention_mask'] = attention_mask + if is_noisy is not None: + data['is_noisy'] = is_noisy data['state_labels'] = state_labels data['belief_state'] = belief_state - if model.config.predict_intents: + if model.config.predict_actions: data['request_labels'] = request_labels - data['domain_labels'] = domain_labels - data['greeting_labels'] = greeting_labels - data['request_belief'] = request_belief - data['domain_belief'] = domain_belief - data['greeting_act_belief'] = greeting_act_belief + data['active_domain_labels'] = active_domain_labels + data['general_act_labels'] = general_act_labels + data['request_probs'] = request_probs + data['active_domain_probs'] = active_domain_probs + data['general_act_probs'] = general_act_probs file = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data') torch.save(data, file) +def ensemble_distribution_data_to_predictions_format(model_path: str, set_type: str): + """ + Convert ensemble predictions to predictions file format. + + Args: + model_path: Path to ensemble location. + set_type: Evaluation dataset (train/dev/test). + """ + data = torch.load(os.path.join(model_path, 'dataloaders', f"{set_type}.data")) + + # Get oracle labels + if 'request_probs' in data: + data_new = {'state_labels': data['state_labels'], + 'request_labels': data['request_labels'], + 'active_domain_labels': data['active_domain_labels'], + 'general_act_labels': data['general_act_labels']} + else: + data_new = {'state_labels': data['state_labels']} + + # Marginalising across ensemble distributions + data_new['belief_states'] = {slot: distribution.mean(-2) for slot, distribution in data['belief_state'].items()} + if 'request_probs' in data: + data_new['request_probs'] = {slot: distribution.mean(-1) + for slot, distribution in data['request_probs'].items()} + data_new['active_domain_probs'] = {domain: distribution.mean(-1) + for domain, distribution in data['active_domain_probs'].items()} + data_new['general_act_probs'] = data['general_act_probs'].mean(-2) + + # Save predictions file + predictions_dir = os.path.join(model_path, 'predictions') + if not os.path.exists(predictions_dir): + os.mkdir(predictions_dir) + torch.save(data_new, os.path.join(predictions_dir, f"{set_type}.predictions")) + + if __name__ == "__main__": - main() + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument('--model_path', type=str) + parser.add_argument('--set_type', type=str) + parser.add_argument('--batch_size', type=int, default=3) + parser.add_argument('--reduction', type=str, default='none') + parser.add_argument('--get_ensemble_distributions', action='store_true') + parser.add_argument('--convert_distributions_to_predictions', action='store_true') + parser.add_argument('--build_dataloaders', action='store_true') + args = parser.parse_args() + + if args.get_ensemble_distributions: + get_ensemble_distributions(args) + if args.convert_distributions_to_predictions: + ensemble_distribution_data_to_predictions_format(args.model_path, args.set_type) + if args.build_dataloaders: + path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data') + data = torch.load(path) + + reader = open(os.path.join(args.model_path, 'database', f'{args.set_type}.json'), 'r') + ontology = json.load(reader) + reader.close() + + loader = get_loader(data, ontology, args.set_type, args.batch_size) + + path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader') + torch.save(loader, path) diff --git a/convlab/dst/setsumbt/do/calibration.py b/convlab/dst/setsumbt/do/calibration.py deleted file mode 100644 index 27ee058e..00000000 --- a/convlab/dst/setsumbt/do/calibration.py +++ /dev/null @@ -1,481 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Run SetSUMBT Calibration""" - -import logging -import random -import os -from shutil import copy2 as copy - -import torch -from transformers import (BertModel, BertConfig, BertTokenizer, - RobertaModel, RobertaConfig, RobertaTokenizer, - AdamW, get_linear_schedule_with_warmup) -from tqdm import tqdm, trange -from tensorboardX import SummaryWriter -from torch.distributions import Categorical - -from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT -from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT -from convlab.dst.setsumbt.multiwoz import multiwoz21 -from convlab.dst.setsumbt.multiwoz import ontology as embeddings -from convlab.dst.setsumbt.utils import get_args, upload_local_directory_to_gcs, update_args -from convlab.dst.setsumbt.modeling import calibration_utils -from convlab.dst.setsumbt.modeling import ensemble_utils -from convlab.dst.setsumbt.loss.ece import ece, jg_ece, l2_acc - - -# Datasets -DATASETS = { - 'multiwoz21': multiwoz21 -} - -MODELS = { - 'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer), - 'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer) -} - - -def main(args=None, config=None): - # Get arguments - if args is None: - args, config = get_args(MODELS) - - # Select Dataset object - if args.dataset in DATASETS: - Dataset = DATASETS[args.dataset] - else: - raise NameError('NotImplemented') - - if args.model_type in MODELS: - SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type] - else: - raise NameError('NotImplemented') - - # Set up output directory - OUTPUT_DIR = args.output_dir - if not os.path.exists(OUTPUT_DIR): - os.mkdir(OUTPUT_DIR) - args.output_dir = OUTPUT_DIR - if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')): - os.mkdir(os.path.join(OUTPUT_DIR, 'predictions')) - - paths = os.listdir(args.output_dir) if os.path.exists( - args.output_dir) else [] - if 'pytorch_model.bin' in paths and 'config.json' in paths: - args.model_name_or_path = args.output_dir - config = ConfigClass.from_pretrained(args.model_name_or_path) - else: - paths = os.listdir(args.output_dir) if os.path.exists( - args.output_dir) else [] - paths = [os.path.join(args.output_dir, p) - for p in paths if 'checkpoint-' in p] - if paths: - paths = paths[0] - args.model_name_or_path = paths - config = ConfigClass.from_pretrained(args.model_name_or_path) - - if args.ensemble_size > 0: - paths = os.listdir(args.output_dir) if os.path.exists( - args.output_dir) else [] - paths = [os.path.join(args.output_dir, p) - for p in paths if 'ensemble_' in p] - if paths: - args.model_name_or_path = args.output_dir - config = ConfigClass.from_pretrained(args.model_name_or_path) - - args = update_args(args, config) - - # Set up data directory - DATA_DIR = args.data_dir - Dataset.set_datadir(DATA_DIR) - embeddings.set_datadir(DATA_DIR) - - if args.shrink_active_domains and args.dataset == 'multiwoz21': - Dataset.set_active_domains( - ['attraction', 'hotel', 'restaurant', 'taxi', 'train']) - - # Download and preprocess - Dataset.create_examples( - args.max_turn_len, args.predict_intents, args.force_processing) - - # Create logger - global logger - logger = logging.getLogger(__name__) - logger.setLevel(logging.INFO) - - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - if 'stream' not in args.logging_path: - fh = logging.FileHandler(args.logging_path) - fh.setLevel(logging.INFO) - fh.setFormatter(formatter) - logger.addHandler(fh) - else: - ch = logging.StreamHandler() - ch.setLevel(level=logging.INFO) - ch.setFormatter(formatter) - logger.addHandler(ch) - - # Get device - if torch.cuda.is_available() and args.n_gpu > 0: - device = torch.device('cuda') - else: - device = torch.device('cpu') - args.n_gpu = 0 - - if args.n_gpu == 0: - args.fp16 = False - - # Set up model training/evaluation - calibration.set_logger(logger, None) - calibration.set_seed(args) - - if args.ensemble_size > 0: - ensemble.set_logger(logger, tb_writer) - ensemble_utils.set_seed(args) - - # Perform tasks - - if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')): - pred = torch.load(os.path.join( - OUTPUT_DIR, 'predictions', 'test.predictions')) - labels = pred['labels'] - belief_states = pred['belief_states'] - if 'request_labels' in pred: - request_labels = pred['request_labels'] - request_belief = pred['request_belief'] - domain_labels = pred['domain_labels'] - domain_belief = pred['domain_belief'] - greeting_labels = pred['greeting_labels'] - greeting_belief = pred['greeting_belief'] - else: - request_belief = None - del pred - elif args.ensemble_size > 0: - # Get training batch loaders and ontology embeddings - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')): - test_slots = torch.load(os.path.join( - OUTPUT_DIR, 'database', 'test.db')) - else: - # Create Tokenizer and embedding model for Data Loaders and ontology - encoder = CandidateEncoderModel.from_pretrained( - config.candidate_embedding_model_name) - tokenizer = Tokenizer(config.candidate_embedding_model_name) - embeddings.get_slot_candidate_embeddings( - 'test', args, tokenizer, encoder) - test_slots = torch.load(os.path.join( - OUTPUT_DIR, 'database', 'test.db')) - - exists = False - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): - test_dataloader = torch.load(os.path.join( - OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - if test_dataloader.batch_size == args.test_batch_size: - exists = True - if not exists: - tokenizer = Tokenizer(config.candidate_embedding_model_name) - test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len, - config.max_turn_len) - torch.save(test_dataloader, os.path.join( - OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - - config, models = ensemble.get_models( - args.model_name_or_path, device, ConfigClass, SetSumbtModel) - - belief_states, labels = ensemble_utils.get_predictions( - args, models, device, test_dataloader, test_slots) - torch.save({'belief_states': belief_states, 'labels': labels}, - os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')) - else: - # Get training batch loaders and ontology embeddings - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')): - test_slots = torch.load(os.path.join( - OUTPUT_DIR, 'database', 'test.db')) - else: - # Create Tokenizer and embedding model for Data Loaders and ontology - encoder = CandidateEncoderModel.from_pretrained( - config.candidate_embedding_model_name) - tokenizer = Tokenizer(config.candidate_embedding_model_name) - embeddings.get_slot_candidate_embeddings( - 'test', args, tokenizer, encoder) - test_slots = torch.load(os.path.join( - OUTPUT_DIR, 'database', 'test.db')) - - exists = False - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): - test_dataloader = torch.load(os.path.join( - OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - if test_dataloader.batch_size == args.test_batch_size: - exists = True - if not exists: - tokenizer = Tokenizer(config.candidate_embedding_model_name) - test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len, - config.max_turn_len) - torch.save(test_dataloader, os.path.join( - OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - - # Initialise Model - model = SetSumbtModel.from_pretrained( - args.model_name_or_path, config=config) - model = model.to(device) - - # Get slot and value embeddings - slots = {slot: test_slots[slot] for slot in test_slots} - values = {slot: test_slots[slot][1] for slot in test_slots} - - # Load model ontology - model.add_slot_candidates(slots) - for slot in model.informable_slot_ids: - model.add_value_candidates(slot, values[slot], replace=True) - - belief_states = calibration.get_predictions( - args, model, device, test_dataloader) - belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = belief_states - out = {'belief_states': belief_states, 'labels': labels, - 'request_belief': request_belief, 'request_labels': request_labels, - 'domain_belief': domain_belief, 'domain_labels': domain_labels, - 'greeting_belief': greeting_belief, 'greeting_labels': greeting_labels} - torch.save(out, os.path.join( - OUTPUT_DIR, 'predictions', 'test.predictions')) - - # err = [ece(belief_states[slot].reshape(-1, belief_states[slot].size(-1)), labels[slot].reshape(-1), 10) - # for slot in belief_states] - # err = max(err) - # logger.info('ECE: %f' % err) - - # Calculate calibration metrics - - jg = jg_ece(belief_states, labels, 10) - logger.info('Joint Goal ECE: %f' % jg) - - binary_states = {} - for slot, p in belief_states.items(): - shp = p.shape - p = p.reshape(-1, p.size(-1)) - p_ = torch.ones(p.shape).to(p.device) * 1e-8 - p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8 - binary_states[slot] = p_.reshape(shp) - jg = jg_ece(binary_states, labels, 10) - logger.info('Joint Goal Binary ECE: %f' % jg) - - bs = {slot: torch.cat((p[:, :, 0].unsqueeze(-1), p[:, :, 1:].max(-1) - [0].unsqueeze(-1)), -1) for slot, p in belief_states.items()} - ls = {} - for slot, l in labels.items(): - y = torch.zeros((l.size(0), l.size(1))).to(l.device) - dials, turns = torch.where(l > 0) - y[dials, turns] = 1.0 - dials, turns = torch.where(l < 0) - y[dials, turns] = -1.0 - ls[slot] = y - - jg = jg_ece(bs, ls, 10) - logger.info('Slot presence ECE: %f' % jg) - - binary_states = {} - for slot, p in bs.items(): - shp = p.shape - p = p.reshape(-1, p.size(-1)) - p_ = torch.ones(p.shape).to(p.device) * 1e-8 - p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8 - binary_states[slot] = p_.reshape(shp) - jg = jg_ece(binary_states, ls, 10) - logger.info('Slot presence Binary ECE: %f' % jg) - - jg_acc = 0.0 - padding = torch.cat([item.unsqueeze(-1) - for _, item in labels.items()], -1).sum(-1) * -1.0 - padding = (padding == len(labels)) - padding = padding.reshape(-1) - for slot in belief_states: - topn = args.accuracy_topn - p_ = belief_states[slot] - gold = labels[slot] - - if p_.size(-1) <= topn: - topn = p_.size(-1) - 1 - if topn <= 0: - topn = 1 - - if topn > 1: - labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True) - labs = labs[:, :topn] - else: - labs = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1) - acc = [lab in s for lab, s, pad in zip( - gold.reshape(-1), labs, padding) if not pad] - acc = torch.tensor(acc).float() - - jg_acc += acc - - n_turns = jg_acc.size(0) - sl_acc = sum(jg_acc / len(belief_states)).float() - jg_acc = sum((jg_acc / len(belief_states)).int()).float() - - sl_acc /= n_turns - jg_acc /= n_turns - - logger.info('Joint Goal Accuracy: %f, Slot Accuracy %f' % (jg_acc, sl_acc)) - - l2 = l2_acc(belief_states, labels, remove_belief=False) - logger.info(f'Model L2 Norm Goal Accuracy: {l2}') - l2 = l2_acc(belief_states, labels, remove_belief=True) - logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}') - - for slot in belief_states: - p = belief_states[slot] - p = p.reshape(-1, p.size(-1)) - p = torch.cat( - (p[:, 0].unsqueeze(-1), p[:, 1:].max(-1)[0].unsqueeze(-1)), -1) - belief_states[slot] = p - - l = labels[slot].reshape(-1) - l[l > 0] = 1 - labels[slot] = l - - f1 = 0.0 - for slot in belief_states: - prd = belief_states[slot].argmax(-1) - tp = ((prd == 1) * (labels[slot] == 1)).sum() - fp = ((prd == 1) * (labels[slot] == 0)).sum() - fn = ((prd == 0) * (labels[slot] == 1)).sum() - if tp > 0: - f1 += tp / (tp + 0.5 * (fp + fn)) - f1 /= len(belief_states) - logger.info(f'Trucated Goal F1 Score: {f1}') - - l2 = l2_acc(belief_states, labels, remove_belief=False) - logger.info(f'Model L2 Norm Trucated Goal Accuracy: {l2}') - l2 = l2_acc(belief_states, labels, remove_belief=True) - logger.info(f'Binary Model L2 Norm Trucated Goal Accuracy: {l2}') - - if request_belief is not None: - tp, fp, fn = 0.0, 0.0, 0.0 - for slot in request_belief: - p = request_belief[slot] - l = request_labels[slot] - - tp += (p.round().int() * (l == 1)).reshape(-1).float() - fp += (p.round().int() * (l == 0)).reshape(-1).float() - fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float() - tp /= len(request_belief) - fp /= len(request_belief) - fn /= len(request_belief) - f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum())) - logger.info('Request F1 Score: %f' % f1.item()) - - for slot in request_belief: - p = request_belief[slot] - p = p.unsqueeze(-1) - p = torch.cat((1 - p, p), -1) - request_belief[slot] = p - jg = jg_ece(request_belief, request_labels, 10) - logger.info('Request Joint Goal ECE: %f' % jg) - - binary_states = {} - for slot, p in request_belief.items(): - shp = p.shape - p = p.reshape(-1, p.size(-1)) - p_ = torch.ones(p.shape).to(p.device) * 1e-8 - p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8 - binary_states[slot] = p_.reshape(shp) - jg = jg_ece(binary_states, request_labels, 10) - logger.info('Request Joint Goal Binary ECE: %f' % jg) - - tp, fp, fn = 0.0, 0.0, 0.0 - for dom in domain_belief: - p = domain_belief[dom] - l = domain_labels[dom] - - tp += (p.round().int() * (l == 1)).reshape(-1).float() - fp += (p.round().int() * (l == 0)).reshape(-1).float() - fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float() - tp /= len(domain_belief) - fp /= len(domain_belief) - fn /= len(domain_belief) - f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum())) - logger.info('Domain F1 Score: %f' % f1.item()) - - for dom in domain_belief: - p = domain_belief[dom] - p = p.unsqueeze(-1) - p = torch.cat((1 - p, p), -1) - domain_belief[dom] = p - jg = jg_ece(domain_belief, domain_labels, 10) - logger.info('Domain Joint Goal ECE: %f' % jg) - - binary_states = {} - for slot, p in domain_belief.items(): - shp = p.shape - p = p.reshape(-1, p.size(-1)) - p_ = torch.ones(p.shape).to(p.device) * 1e-8 - p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8 - binary_states[slot] = p_.reshape(shp) - jg = jg_ece(binary_states, domain_labels, 10) - logger.info('Domain Joint Goal Binary ECE: %f' % jg) - - tp = ((greeting_belief.argmax(-1) > 0) * - (greeting_labels > 0)).reshape(-1).float().sum() - fp = ((greeting_belief.argmax(-1) > 0) * - (greeting_labels == 0)).reshape(-1).float().sum() - fn = ((greeting_belief.argmax(-1) == 0) * - (greeting_labels > 0)).reshape(-1).float().sum() - f1 = tp / (tp + 0.5 * (fp + fn)) - logger.info('Greeting F1 Score: %f' % f1.item()) - - err = ece(greeting_belief.reshape(-1, greeting_belief.size(-1)), - greeting_labels.reshape(-1), 10) - logger.info('Greetings ECE: %f' % err) - - greeting_belief = greeting_belief.reshape(-1, greeting_belief.size(-1)) - binary_states = torch.ones(greeting_belief.shape).to( - greeting_belief.device) * 1e-8 - binary_states[range(greeting_belief.size(0)), - greeting_belief.argmax(-1)] = 1.0 - 1e-8 - err = ece(binary_states, greeting_labels.reshape(-1), 10) - logger.info('Greetings Binary ECE: %f' % err) - - for slot in request_belief: - p = request_belief[slot].unsqueeze(-1) - request_belief[slot] = torch.cat((1 - p, p), -1) - - l2 = l2_acc(request_belief, request_labels, remove_belief=False) - logger.info(f'Model L2 Norm Request Accuracy: {l2}') - l2 = l2_acc(request_belief, request_labels, remove_belief=True) - logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}') - - for slot in domain_belief: - p = domain_belief[slot].unsqueeze(-1) - domain_belief[slot] = torch.cat((1 - p, p), -1) - - l2 = l2_acc(domain_belief, domain_labels, remove_belief=False) - logger.info(f'Model L2 Norm Domain Accuracy: {l2}') - l2 = l2_acc(domain_belief, domain_labels, remove_belief=True) - logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}') - - greeting_labels = {'bye': greeting_labels} - greeting_belief = {'bye': greeting_belief} - - l2 = l2_acc(greeting_belief, greeting_labels, remove_belief=False) - logger.info(f'Model L2 Norm Greeting Accuracy: {l2}') - l2 = l2_acc(greeting_belief, greeting_labels, remove_belief=False) - logger.info(f'Binary Model L2 Norm Greeting Accuracy: {l2}') - - -if __name__ == "__main__": - main() diff --git a/convlab/dst/setsumbt/do/evaluate.py b/convlab/dst/setsumbt/do/evaluate.py new file mode 100644 index 00000000..2fe351b3 --- /dev/null +++ b/convlab/dst/setsumbt/do/evaluate.py @@ -0,0 +1,296 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Run SetSUMBT Calibration""" + +import logging +import os + +import torch +from transformers import (BertModel, BertConfig, BertTokenizer, + RobertaModel, RobertaConfig, RobertaTokenizer) + +from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT +from convlab.dst.setsumbt.dataset import unified_format +from convlab.dst.setsumbt.dataset import ontology as embeddings +from convlab.dst.setsumbt.utils import get_args, update_args +from convlab.dst.setsumbt.modeling import evaluation_utils +from convlab.dst.setsumbt.loss.uncertainty_measures import ece, jg_ece, l2_acc +from convlab.dst.setsumbt.modeling import training + + +# Available model +MODELS = { + 'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer), + 'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer) +} + + +def main(args=None, config=None): + # Get arguments + if args is None: + args, config = get_args(MODELS) + + if args.model_type in MODELS: + SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type] + else: + raise NameError('NotImplemented') + + # Set up output directory + OUTPUT_DIR = args.output_dir + args.output_dir = OUTPUT_DIR + if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')): + os.mkdir(os.path.join(OUTPUT_DIR, 'predictions')) + + # Set pretrained model path to the trained checkpoint + paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else [] + if 'pytorch_model.bin' in paths and 'config.json' in paths: + args.model_name_or_path = args.output_dir + config = ConfigClass.from_pretrained(args.model_name_or_path) + else: + paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p] + if paths: + paths = paths[0] + args.model_name_or_path = paths + config = ConfigClass.from_pretrained(args.model_name_or_path) + + args = update_args(args, config) + + # Create logger + global logger + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + + formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y') + + fh = logging.FileHandler(args.logging_path) + fh.setLevel(logging.INFO) + fh.setFormatter(formatter) + logger.addHandler(fh) + + # Get device + if torch.cuda.is_available() and args.n_gpu > 0: + device = torch.device('cuda') + else: + device = torch.device('cpu') + args.n_gpu = 0 + + if args.n_gpu == 0: + args.fp16 = False + + # Set up model training/evaluation + evaluation_utils.set_seed(args) + + # Perform tasks + if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')): + pred = torch.load(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')) + state_labels = pred['state_labels'] + belief_states = pred['belief_states'] + if 'request_labels' in pred: + request_labels = pred['request_labels'] + request_probs = pred['request_probs'] + active_domain_labels = pred['active_domain_labels'] + active_domain_probs = pred['active_domain_probs'] + general_act_labels = pred['general_act_labels'] + general_act_probs = pred['general_act_probs'] + else: + request_probs = None + del pred + else: + # Get training batch loaders and ontology embeddings + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): + test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + if test_dataloader.batch_size != args.test_batch_size: + test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size) + else: + tokenizer = Tokenizer(config.candidate_embedding_model_name) + test_dataloader = unified_format.get_dataloader(args.dataset, 'test', + args.test_batch_size, tokenizer, args.max_dialogue_len, + config.max_turn_len) + torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + + if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')): + test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db')) + else: + encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name) + test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, + 'test', args, tokenizer, encoder) + + # Initialise Model + model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config) + model = model.to(device) + + training.set_ontology_embeddings(model, test_slots) + + belief_states = evaluation_utils.get_predictions(args, model, device, test_dataloader) + state_labels = belief_states[1] + request_probs = belief_states[2] + request_labels = belief_states[3] + active_domain_probs = belief_states[4] + active_domain_labels = belief_states[5] + general_act_probs = belief_states[6] + general_act_labels = belief_states[7] + belief_states = belief_states[0] + out = {'belief_states': belief_states, 'state_labels': state_labels, 'request_probs': request_probs, + 'request_labels': request_labels, 'active_domain_probs': active_domain_probs, + 'active_domain_labels': active_domain_labels, 'general_act_probs': general_act_probs, + 'general_act_labels': general_act_labels} + torch.save(out, os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')) + + # Calculate calibration metrics + jg = jg_ece(belief_states, state_labels, 10) + logger.info('Joint Goal ECE: %f' % jg) + + jg_acc = 0.0 + padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0 + padding = (padding == len(state_labels)) + padding = padding.reshape(-1) + for slot in belief_states: + p_ = belief_states[slot] + gold = state_labels[slot] + + pred = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1) + acc = [lab in s for lab, s, pad in zip(gold.reshape(-1), pred, padding) if not pad] + acc = torch.tensor(acc).float() + + jg_acc += acc + + n_turns = jg_acc.size(0) + jg_acc = sum((jg_acc / len(belief_states)).int()).float() + + jg_acc /= n_turns + + logger.info(f'Joint Goal Accuracy: {jg_acc}') + + l2 = l2_acc(belief_states, state_labels, remove_belief=False) + logger.info(f'Model L2 Norm Goal Accuracy: {l2}') + l2 = l2_acc(belief_states, state_labels, remove_belief=True) + logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}') + + padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0 + padding = (padding == len(state_labels)) + padding = padding.reshape(-1) + + tp, fp, fn, tn, n = 0.0, 0.0, 0.0, 0.0, 0.0 + for slot in belief_states: + p_ = belief_states[slot] + gold = state_labels[slot].reshape(-1) + p_ = p_.reshape(-1, p_.size(-1)) + + p_ = p_[~padding].argmax(-1) + gold = gold[~padding] + + tp += (p_ == gold)[gold != 0].int().sum().item() + fp += (p_ != 0)[gold == 0].int().sum().item() + fp += (p_ != gold)[gold != 0].int().sum().item() + fp -= (p_ == 0)[gold != 0].int().sum().item() + fn += (p_ == 0)[gold != 0].int().sum().item() + tn += (p_ == 0)[gold == 0].int().sum().item() + n += p_.size(0) + + acc = (tp + tn) / n + prec = tp / (tp + fp) + rec = tp / (tp + fn) + f1 = 2 * (prec * rec) / (prec + rec) + + logger.info(f"Slot Accuracy: {acc}, Slot F1: {f1}, Slot Precision: {prec}, Slot Recall: {rec}") + + if request_probs is not None: + tp, fp, fn = 0.0, 0.0, 0.0 + for slot in request_probs: + p = request_probs[slot] + l = request_labels[slot] + + tp += (p.round().int() * (l == 1)).reshape(-1).float() + fp += (p.round().int() * (l == 0)).reshape(-1).float() + fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float() + tp /= len(request_probs) + fp /= len(request_probs) + fn /= len(request_probs) + f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum())) + logger.info('Request F1 Score: %f' % f1.item()) + + for slot in request_probs: + p = request_probs[slot] + p = p.unsqueeze(-1) + p = torch.cat((1 - p, p), -1) + request_probs[slot] = p + jg = jg_ece(request_probs, request_labels, 10) + logger.info('Request Joint Goal ECE: %f' % jg) + + tp, fp, fn = 0.0, 0.0, 0.0 + for dom in active_domain_probs: + p = active_domain_probs[dom] + l = active_domain_labels[dom] + + tp += (p.round().int() * (l == 1)).reshape(-1).float() + fp += (p.round().int() * (l == 0)).reshape(-1).float() + fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float() + tp /= len(active_domain_probs) + fp /= len(active_domain_probs) + fn /= len(active_domain_probs) + f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum())) + logger.info('Domain F1 Score: %f' % f1.item()) + + for dom in active_domain_probs: + p = active_domain_probs[dom] + p = p.unsqueeze(-1) + p = torch.cat((1 - p, p), -1) + active_domain_probs[dom] = p + jg = jg_ece(active_domain_probs, active_domain_labels, 10) + logger.info('Domain Joint Goal ECE: %f' % jg) + + tp = ((general_act_probs.argmax(-1) > 0) * + (general_act_labels > 0)).reshape(-1).float().sum() + fp = ((general_act_probs.argmax(-1) > 0) * + (general_act_labels == 0)).reshape(-1).float().sum() + fn = ((general_act_probs.argmax(-1) == 0) * + (general_act_labels > 0)).reshape(-1).float().sum() + f1 = tp / (tp + 0.5 * (fp + fn)) + logger.info('General Act F1 Score: %f' % f1.item()) + + err = ece(general_act_probs.reshape(-1, general_act_probs.size(-1)), + general_act_labels.reshape(-1), 10) + logger.info('General Act ECE: %f' % err) + + for slot in request_probs: + p = request_probs[slot].unsqueeze(-1) + request_probs[slot] = torch.cat((1 - p, p), -1) + + l2 = l2_acc(request_probs, request_labels, remove_belief=False) + logger.info(f'Model L2 Norm Request Accuracy: {l2}') + l2 = l2_acc(request_probs, request_labels, remove_belief=True) + logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}') + + for slot in active_domain_probs: + p = active_domain_probs[slot].unsqueeze(-1) + active_domain_probs[slot] = torch.cat((1 - p, p), -1) + + l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=False) + logger.info(f'Model L2 Norm Domain Accuracy: {l2}') + l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=True) + logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}') + + general_act_labels = {'general': general_act_labels} + general_act_probs = {'general': general_act_probs} + + l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False) + logger.info(f'Model L2 Norm General Act Accuracy: {l2}') + l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False) + logger.info(f'Binary Model L2 Norm General Act Accuracy: {l2}') + + +if __name__ == "__main__": + main() diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py index 821dca59..276d13f2 100644 --- a/convlab/dst/setsumbt/do/nbt.py +++ b/convlab/dst/setsumbt/do/nbt.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,33 +16,27 @@ """Run SetSUMBT training/eval""" import logging -import random import os from shutil import copy2 as copy +import json +from copy import deepcopy import torch -from torch.nn import DataParallel +import transformers from transformers import (BertModel, BertConfig, BertTokenizer, - RobertaModel, RobertaConfig, RobertaTokenizer, - AdamW, get_linear_schedule_with_warmup) -from tqdm import tqdm, trange -import numpy as np + RobertaModel, RobertaConfig, RobertaTokenizer) from tensorboardX import SummaryWriter +from tqdm import tqdm -from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT -from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT -from convlab.dst.setsumbt.multiwoz import multiwoz21 +from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT +from convlab.dst.setsumbt.dataset import unified_format from convlab.dst.setsumbt.modeling import training -from convlab.dst.setsumbt.multiwoz import ontology as embeddings +from convlab.dst.setsumbt.dataset import ontology as embeddings from convlab.dst.setsumbt.utils import get_args, update_args -from convlab.dst.setsumbt.modeling import ensemble_utils +from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble -# Datasets -DATASETS = { - 'multiwoz21': multiwoz21 -} - +# Available model MODELS = { 'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer), 'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer) @@ -54,12 +48,6 @@ def main(args=None, config=None): if args is None: args, config = get_args(MODELS) - # Select Dataset object - if args.dataset in DATASETS: - Dataset = DATASETS[args.dataset] - else: - raise NameError('NotImplemented') - if args.model_type in MODELS: SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type] else: @@ -74,53 +62,19 @@ def main(args=None, config=None): args.output_dir = OUTPUT_DIR # Set pretrained model path to the trained checkpoint - if args.do_train: - paths = os.listdir(args.output_dir) if os.path.exists( - args.output_dir) else [] - paths = [os.path.join(args.output_dir, p) - for p in paths if 'checkpoint-' in p] + paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else [] + if 'pytorch_model.bin' in paths and 'config.json' in paths: + args.model_name_or_path = args.output_dir + config = ConfigClass.from_pretrained(args.model_name_or_path) + else: + paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p] if paths: paths = paths[0] args.model_name_or_path = paths config = ConfigClass.from_pretrained(args.model_name_or_path) - else: - paths = os.listdir(args.output_dir) if os.path.exists( - args.output_dir) else [] - if 'pytorch_model.bin' in paths and 'config.json' in paths: - args.model_name_or_path = args.output_dir - config = ConfigClass.from_pretrained(args.model_name_or_path) - else: - paths = os.listdir(args.output_dir) if os.path.exists( - args.output_dir) else [] - if 'pytorch_model.bin' in paths and 'config.json' in paths: - args.model_name_or_path = args.output_dir - config = ConfigClass.from_pretrained(args.model_name_or_path) - else: - paths = os.listdir(args.output_dir) if os.path.exists( - args.output_dir) else [] - paths = [os.path.join(args.output_dir, p) - for p in paths if 'checkpoint-' in p] - if paths: - paths = paths[0] - args.model_name_or_path = paths - config = ConfigClass.from_pretrained(args.model_name_or_path) args = update_args(args, config) - # Set up data directory - DATA_DIR = args.data_dir - Dataset.set_datadir(DATA_DIR) - embeddings.set_datadir(DATA_DIR) - - # If use shrinked domains, remove bus and hospital domains from the training data and model ontology - if args.shrink_active_domains and args.dataset == 'multiwoz21': - Dataset.set_active_domains( - ['attraction', 'hotel', 'restaurant', 'taxi', 'train']) - - # Download and preprocess - Dataset.create_examples( - args.max_turn_len, args.predict_actions, args.force_processing) - # Create TensorboardX writer tb_writer = SummaryWriter(logdir=args.tensorboard_path) @@ -129,19 +83,12 @@ def main(args=None, config=None): logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y') - if 'stream' not in args.logging_path: - fh = logging.FileHandler(args.logging_path) - fh.setLevel(logging.INFO) - fh.setFormatter(formatter) - logger.addHandler(fh) - else: - ch = logging.StreamHandler() - ch.setLevel(level=logging.INFO) - ch.setFormatter(formatter) - logger.addHandler(ch) + fh = logging.FileHandler(args.logging_path) + fh.setLevel(logging.INFO) + fh.setFormatter(formatter) + logger.addHandler(fh) # Get device if torch.cuda.is_available() and args.n_gpu > 0: @@ -154,14 +101,12 @@ def main(args=None, config=None): args.fp16 = False # Initialise Model - model = SetSumbtModel.from_pretrained( - args.model_name_or_path, config=config) + transformers.utils.logging.set_verbosity_info() + model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config) model = model.to(device) # Create Tokenizer and embedding model for Data Loaders and ontology - encoder = model.roberta if args.model_type == 'roberta' else None - encoder = model.bert if args.model_type == 'bert' else encoder - + encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name) tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config) # Set up model training/evaluation @@ -169,88 +114,107 @@ def main(args=None, config=None): training.set_seed(args) embeddings.set_seed(args) + transformers.utils.logging.set_verbosity_error() if args.ensemble_size > 1: - ensemble_utils.set_logger(logger, tb_writer) - ensemble.set_seed(args) - logger.info('Building %i resampled dataloaders each of size %i' % (args.ensemble_size, - args.data_sampling_size)) - dataloaders = ensemble_utils.build_train_loaders(args, tokenizer, Dataset) + # Build all dataloaders + train_dataloader = unified_format.get_dataloader(args.dataset, + 'train', + args.train_batch_size, + tokenizer, + args.max_dialogue_len, + args.max_turn_len, + train_ratio=args.dataset_train_ratio, + seed=args.seed) + torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) + dev_dataloader = unified_format.get_dataloader(args.dataset, + 'validation', + args.dev_batch_size, + tokenizer, + args.max_dialogue_len, + args.max_turn_len, + train_ratio=args.dataset_train_ratio, + seed=args.seed) + torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + test_dataloader = unified_format.get_dataloader(args.dataset, + 'test', + args.test_batch_size, + tokenizer, + args.max_dialogue_len, + args.max_turn_len, + train_ratio=args.dataset_train_ratio, + seed=args.seed) + torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + + embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology, 'train', args, tokenizer, encoder) + embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, 'dev', args, tokenizer, encoder) + embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, 'test', args, tokenizer, encoder) + + setup_ensemble(OUTPUT_DIR, args.ensemble_size) + + logger.info(f'Building {args.ensemble_size} resampled dataloaders each of size {args.data_sampling_size}.') + dataloaders = [unified_format.dataloader_sample_dialogues(deepcopy(train_dataloader), args.data_sampling_size) + for _ in tqdm(range(args.ensemble_size))] logger.info('Dataloaders built.') + for i, loader in enumerate(dataloaders): - path = os.path.join(OUTPUT_DIR, 'ensemble-%i' % i) + path = os.path.join(OUTPUT_DIR, 'ens-%i' % i) if not os.path.exists(path): os.mkdir(path) - path = os.path.join(path, 'train.dataloader') + path = os.path.join(path, 'dataloaders', 'train.dataloader') torch.save(loader, path) logger.info('Dataloaders saved.') - train_slots = embeddings.get_slot_candidate_embeddings( - 'train', args, tokenizer, encoder) - dev_slots = embeddings.get_slot_candidate_embeddings( - 'dev', args, tokenizer, encoder) - test_slots = embeddings.get_slot_candidate_embeddings( - 'test', args, tokenizer, encoder) - - train_dataloader = Dataset.get_dataloader( - 'train', args.train_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len) - torch.save(dev_dataloader, os.path.join( - OUTPUT_DIR, 'dataloaders', 'train.dataloader')) - dev_dataloader = Dataset.get_dataloader( - 'dev', args.dev_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len) - torch.save(dev_dataloader, os.path.join( - OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - test_dataloader = Dataset.get_dataloader( - 'test', args.test_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len) - torch.save(test_dataloader, os.path.join( - OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - # Do not perform standard training after ensemble setup is created return 0 # Perform tasks # TRAINING if args.do_train: - # Get training batch loaders and ontology embeddings - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')): - train_slots = torch.load(os.path.join( - OUTPUT_DIR, 'database', 'train.db')) - else: - train_slots = embeddings.get_slot_candidate_embeddings( - 'train', args, tokenizer, encoder) - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')): - dev_slots = torch.load(os.path.join( - OUTPUT_DIR, 'database', 'dev.db')) - else: - dev_slots = embeddings.get_slot_candidate_embeddings( - 'dev', args, tokenizer, encoder) - - exists = False if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')): - train_dataloader = torch.load(os.path.join( - OUTPUT_DIR, 'dataloaders', 'train.dataloader')) - if train_dataloader.batch_size == args.train_batch_size: - exists = True - if not exists: + train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) + if train_dataloader.batch_size != args.train_batch_size: + train_dataloader = unified_format.change_batch_size(train_dataloader, args.train_batch_size) + else: if args.data_sampling_size <= 0: args.data_sampling_size = None - train_dataloader = Dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len, - config.max_turn_len, resampled_size=args.data_sampling_size) - torch.save(train_dataloader, os.path.join( - OUTPUT_DIR, 'dataloaders', 'train.dataloader')) + train_dataloader = unified_format.get_dataloader(args.dataset, + 'train', + args.train_batch_size, + tokenizer, + args.max_dialogue_len, + config.max_turn_len, + resampled_size=args.data_sampling_size, + train_ratio=args.dataset_train_ratio, + seed=args.seed) + torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) + + # Get training batch loaders and ontology embeddings + if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')): + train_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'train.db')) + else: + train_slots = embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology, + 'train', args, tokenizer, encoder) # Get development set batch loaders= and ontology embeddings if args.do_eval: - exists = False if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')): - dev_dataloader = torch.load(os.path.join( - OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - if dev_dataloader.batch_size == args.dev_batch_size: - exists = True - if not exists: - dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len, - config.max_turn_len) - torch.save(dev_dataloader, os.path.join( - OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + if dev_dataloader.batch_size != args.dev_batch_size: + dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size) + else: + dev_dataloader = unified_format.get_dataloader(args.dataset, + 'validation', + args.dev_batch_size, + tokenizer, + args.max_dialogue_len, + config.max_turn_len) + torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + + if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')): + dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db')) + else: + dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, + 'dev', args, tokenizer, encoder) else: dev_dataloader = None dev_slots = None @@ -259,94 +223,80 @@ def main(args=None, config=None): training.set_ontology_embeddings(model, train_slots) # TRAINING !!!!!!!!!!!!!!!!!! - training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots, - embeddings=embeddings, tokenizer=tokenizer) + training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots) # Copy final best model to the output dir checkpoints = os.listdir(OUTPUT_DIR) checkpoints = [p for p in checkpoints if 'checkpoint' in p] checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints]) - best_checkpoint = checkpoints[-1] - best_checkpoint = os.path.join( - OUTPUT_DIR, f'checkpoint-{best_checkpoint}') - copy(os.path.join(best_checkpoint, 'pytorch_model.bin'), - os.path.join(OUTPUT_DIR, 'pytorch_model.bin')) - copy(os.path.join(best_checkpoint, 'config.json'), - os.path.join(OUTPUT_DIR, 'config.json')) + best_checkpoint = os.path.join(OUTPUT_DIR, f'checkpoint-{checkpoints[-1]}') + copy(os.path.join(best_checkpoint, 'pytorch_model.bin'), os.path.join(OUTPUT_DIR, 'pytorch_model.bin')) + copy(os.path.join(best_checkpoint, 'config.json'), os.path.join(OUTPUT_DIR, 'config.json')) # Load best model for evaluation - model = SumbtModel.from_pretrained(OUTPUT_DIR) + model = SetSumbtModel.from_pretrained(OUTPUT_DIR) model = model.to(device) # Evaluation on the development set if args.do_eval: - # Get development set batch loaders= and ontology embeddings - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')): - dev_slots = torch.load(os.path.join( - OUTPUT_DIR, 'database', 'dev.db')) + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')): + dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + if dev_dataloader.batch_size != args.dev_batch_size: + dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size) else: - dev_slots = embeddings.get_slot_candidate_embeddings( - 'dev', args, tokenizer, encoder) + dev_dataloader = unified_format.get_dataloader(args.dataset, + 'validation', + args.dev_batch_size, + tokenizer, + args.max_dialogue_len, + config.max_turn_len) + torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - exists = False - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')): - dev_dataloader = torch.load(os.path.join( - OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - if dev_dataloader.batch_size == args.dev_batch_size: - exists = True - if not exists: - dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len, - config.max_turn_len) - torch.save(dev_dataloader, os.path.join( - OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')): + dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db')) + else: + dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, + 'dev', args, tokenizer, encoder) # Load model ontology training.set_ontology_embeddings(model, dev_slots) # EVALUATION - jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate( - args, model, device, dev_dataloader) - if req_f1: - logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f' - % (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1)) - else: - logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f' - % (loss, jg_acc, sl_acc)) + jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss = training.evaluate(args, model, device, dev_dataloader) + training.log_info('dev', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1) # Evaluation on the test set if args.do_test: - # Get test set batch loaders= and ontology embeddings - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')): - test_slots = torch.load(os.path.join( - OUTPUT_DIR, 'database', 'test.db')) + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): + test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + if test_dataloader.batch_size != args.test_batch_size: + test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size) else: - test_slots = embeddings.get_slot_candidate_embeddings( - 'test', args, tokenizer, encoder) + test_dataloader = unified_format.get_dataloader(args.dataset, 'test', + args.test_batch_size, tokenizer, args.max_dialogue_len, + config.max_turn_len) + torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - exists = False - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): - test_dataloader = torch.load(os.path.join( - OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - if test_dataloader.batch_size == args.test_batch_size: - exists = True - if not exists: - test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len, - config.max_turn_len) - torch.save(test_dataloader, os.path.join( - OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')): + test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db')) + else: + test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, + 'test', args, tokenizer, encoder) # Load model ontology training.set_ontology_embeddings(model, test_slots) # TESTING - jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate( - args, model, device, test_dataloader) - if req_f1: - logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f' - % (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1)) - else: - logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f' - % (loss, jg_acc, sl_acc)) + jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, output = training.evaluate(args, model, device, test_dataloader, + return_eval_output=True) + + if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')): + os.mkdir(os.path.join(OUTPUT_DIR, 'predictions')) + writer = open(os.path.join(OUTPUT_DIR, 'predictions', 'test.json'), 'w') + json.dump(output, writer) + writer.close() + + training.log_info('test', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1) tb_writer.close() diff --git a/convlab/dst/setsumbt/loss/__init__.py b/convlab/dst/setsumbt/loss/__init__.py new file mode 100644 index 00000000..475f7646 --- /dev/null +++ b/convlab/dst/setsumbt/loss/__init__.py @@ -0,0 +1,4 @@ +from convlab.dst.setsumbt.loss.bayesian_matching import BayesianMatchingLoss, BinaryBayesianMatchingLoss +from convlab.dst.setsumbt.loss.kl_distillation import KLDistillationLoss, BinaryKLDistillationLoss +from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss +from convlab.dst.setsumbt.loss.endd_loss import RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss diff --git a/convlab/dst/setsumbt/loss/bayesian.py b/convlab/dst/setsumbt/loss/bayesian.py deleted file mode 100644 index e52d8d07..00000000 --- a/convlab/dst/setsumbt/loss/bayesian.py +++ /dev/null @@ -1,144 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Bayesian Matching Activation and Loss Functions""" - -import torch -from torch import digamma, lgamma -from torch.nn import Module - - -# Inverse Linear activation function -def invlinear(x): - z = (1.0 / (1.0 - x)) * (x < 0) - z += (1.0 + x) * (x >= 0) - return z - -# Exponential activation function -def exponential(x): - return torch.exp(x) - - -# Dirichlet activation function for the model -def dirichlet(a): - p = exponential(a) - repeat_dim = (1,)*(len(p.shape)-1) + (p.size(-1),) - p = p / p.sum(-1).unsqueeze(-1).repeat(repeat_dim) - return p - - -# Pytorch BayesianMatchingLoss nn.Module -class BayesianMatchingLoss(Module): - - def __init__(self, lamb=0.01, ignore_index=-1): - super(BayesianMatchingLoss, self).__init__() - - self.lamb = lamb - self.ignore_index = ignore_index - - def forward(self, alpha, labels, prior=None): - # Assert input sizes - assert alpha.dim() == 2 # Observations, predictive distribution - assert labels.dim() == 1 # Label for each observation - assert labels.size(0) == alpha.size(0) # Equal number of observation - - # Confirm predictive distribution dimension - if labels.max() <= alpha.size(-1): - dimension = alpha.size(-1) - else: - raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.max(), alpha.size(-1))) - - # Remove observations with no labels - if prior is not None: - prior = prior[labels != self.ignore_index] - alpha = exponential(alpha[labels != self.ignore_index]) - labels = labels[labels != self.ignore_index] - - # Initialise and reshape prior parameters - if prior is None: - prior = torch.ones(dimension) - prior = prior.to(alpha.device) - - # KL divergence term - lb = lgamma(alpha.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(alpha)).sum(-1) - e = digamma(alpha) - digamma(alpha.sum(-1)).unsqueeze(-1).repeat((1, alpha.size(-1))) - e = ((alpha - prior) * e).sum(-1) - kl = lb + e - kl *= self.lamb - del lb, e, prior - - # Expected log likelihood - expected_likelihood = digamma(alpha[range(labels.size(0)), labels]) - digamma(alpha.sum(1)) - del alpha, labels - - # Apply ELBO loss and mean reduction - loss = (kl - expected_likelihood).mean() - del kl, expected_likelihood - - return loss - - -# Pytorch BayesianMatchingLoss nn.Module -class BinaryBayesianMatchingLoss(Module): - - def __init__(self, lamb=0.01, ignore_index=-1): - super(BinaryBayesianMatchingLoss, self).__init__() - - self.lamb = lamb - self.ignore_index = ignore_index - - def forward(self, alpha, labels, prior=None): - # Assert input sizes - assert alpha.dim() == 1 # Observations, predictive distribution - assert labels.dim() == 1 # Label for each observation - assert labels.size(0) == alpha.size(0) # Equal number of observation - - # Confirm predictive distribution dimension - if labels.max() <= 2: - dimension = 2 - else: - raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.max(), alpha.size(-1))) - - # Remove observations with no labels - if prior is not None: - prior = prior[labels != self.ignore_index] - alpha = alpha[labels != self.ignore_index] - alpha_sum = 1 + (1 / self.lamb) - alpha = (torch.sigmoid(alpha) * alpha_sum).reshape(-1, 1) - alpha = torch.cat((alpha_sum - alpha, alpha), 1) - labels = labels[labels != self.ignore_index] - - # Initialise and reshape prior parameters - if prior is None: - prior = torch.ones(dimension) - prior = prior.to(alpha.device) - - # KL divergence term - lb = lgamma(alpha.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(alpha)).sum(-1) - e = digamma(alpha) - digamma(alpha.sum(-1)).unsqueeze(-1).repeat((1, alpha.size(-1))) - e = ((alpha - prior) * e).sum(-1) - kl = lb + e - kl *= self.lamb - del lb, e, prior - - # Expected log likelihood - expected_likelihood = digamma(alpha[range(labels.size(0)), labels.long()]) - digamma(alpha.sum(1)) - del alpha, labels - - # Apply ELBO loss and mean reduction - loss = (kl - expected_likelihood).mean() - del kl, expected_likelihood - - return loss diff --git a/convlab/dst/setsumbt/loss/bayesian_matching.py b/convlab/dst/setsumbt/loss/bayesian_matching.py new file mode 100644 index 00000000..3e91444d --- /dev/null +++ b/convlab/dst/setsumbt/loss/bayesian_matching.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Bayesian Matching Activation and Loss Functions (see https://arxiv.org/pdf/2002.07965.pdf for details)""" + +import torch +from torch import digamma, lgamma +from torch.nn import Module + + +class BayesianMatchingLoss(Module): + """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation""" + + def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module: + """ + Args: + lamb (float): Weighting factor for the KL Divergence component + ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. + """ + super(BayesianMatchingLoss, self).__init__() + + self.lamb = lamb + self.ignore_index = ignore_index + + def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor: + """ + Args: + inputs (Tensor): Predictive distribution + labels (Tensor): Label indices + prior (Tensor): Prior distribution over label classes + + Returns: + loss (Tensor): Loss value + """ + # Assert input sizes + assert inputs.dim() == 2 # Observations, predictive distribution + assert labels.dim() == 1 # Label for each observation + assert labels.size(0) == inputs.size(0) # Equal number of observation + + # Confirm predictive distribution dimension + if labels.max() <= inputs.size(-1): + dimension = inputs.size(-1) + else: + raise NameError(f'Label dimension {labels.max()} is larger than prediction dimension {inputs.size(-1)}.') + + # Remove observations to be ignored in loss calculation + if prior is not None: + prior = prior[labels != self.ignore_index] + inputs = torch.exp(inputs[labels != self.ignore_index]) + labels = labels[labels != self.ignore_index] + + # Initialise and reshape prior parameters + if prior is None: + prior = torch.ones(dimension).to(inputs.device) + prior = prior.to(inputs.device) + + # KL divergence term (divergence of predictive distribution from prior over label classes - regularisation term) + log_gamma_term = lgamma(inputs.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(inputs)).sum(-1) + div_term = digamma(inputs) - digamma(inputs.sum(-1)).unsqueeze(-1).repeat((1, inputs.size(-1))) + div_term = ((inputs - prior) * div_term).sum(-1) + kl_term = log_gamma_term + div_term + kl_term *= self.lamb + del log_gamma_term, div_term, prior + + # Expected log likelihood + expected_likelihood = digamma(inputs[range(labels.size(0)), labels]) - digamma(inputs.sum(-1)) + del inputs, labels + + # Apply ELBO loss and mean reduction + loss = (kl_term - expected_likelihood).mean() + del kl_term, expected_likelihood + + return loss + + +class BinaryBayesianMatchingLoss(BayesianMatchingLoss): + """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation""" + + def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module: + """ + Args: + lamb (float): Weighting factor for the KL Divergence component + ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. + """ + super(BinaryBayesianMatchingLoss, self).__init__(lamb, ignore_index) + + def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor: + """ + Args: + inputs (Tensor): Predictive distribution + labels (Tensor): Label indices + prior (Tensor): Prior distribution over label classes + + Returns: + loss (Tensor): Loss value + """ + + # Create 2D input dirichlet distribution + input_sum = 1 + (1 / self.lamb) + inputs = (torch.sigmoid(inputs) * input_sum).reshape(-1, 1) + inputs = torch.cat((input_sum - inputs, inputs), 1) + + return super().forward(inputs, labels, prior=prior) diff --git a/convlab/dst/setsumbt/loss/distillation.py b/convlab/dst/setsumbt/loss/distillation.py deleted file mode 100644 index 3cf13f10..00000000 --- a/convlab/dst/setsumbt/loss/distillation.py +++ /dev/null @@ -1,201 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Bayesian Matching Activation and Loss Functions""" - -import torch -from torch import lgamma, log -from torch.nn import Module -from torch.nn.functional import kl_div - -from convlab.dst.setsumbt.loss.bayesian import BayesianMatchingLoss - - -# Pytorch BayesianMatchingLoss nn.Module -class DistillationKL(Module): - - def __init__(self, lamb=1e-4, ignore_index=-1): - super(DistillationKL, self).__init__() - - self.lamb = lamb - self.ignore_index = ignore_index - - def forward(self, alpha, labels, temp=1.0): - # Assert input sizes - assert alpha.dim() == 2 # Observations, predictive distribution - assert labels.dim() == 2 # Label for each observation - assert labels.size(0) == alpha.size(0) # Equal number of observation - - # Confirm predictive distribution dimension - if labels.size(-1) == alpha.size(-1): - dimension = alpha.size(-1) - else: - raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1))) - - alpha = torch.log(torch.softmax(alpha / temp, -1)) - ids = torch.where(labels[:, 0] != self.ignore_index)[0] - alpha = alpha[ids] - labels = labels[ids] - - labels = ((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1))) - - kl = kl_div(alpha, labels, reduction='none').sum(-1).mean() - return kl - - -# Pytorch BayesianMatchingLoss nn.Module -class BinaryDistillationKL(Module): - - def __init__(self, lamb=1e-4, ignore_index=-1): - super(BinaryDistillationKL, self).__init__() - - self.lamb = lamb - self.ignore_index = ignore_index - - def forward(self, alpha, labels, temp=0.0): - # Assert input sizes - assert alpha.dim() == 1 # Observations, predictive distribution - assert labels.dim() == 1 # Label for each observation - assert labels.size(0) == alpha.size(0) # Equal number of observation - - # Confirm predictive distribution dimension - # if labels.size(-1) == alpha.size(-1): - # dimension = alpha.size(-1) - # else: - # raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1))) - - alpha = torch.sigmoid(alpha / temp).unsqueeze(-1) - ids = torch.where(labels != self.ignore_index)[0] - alpha = alpha[ids] - labels = labels[ids] - - alpha = torch.log(torch.cat((1 - alpha, alpha), 1)) - - labels = labels.unsqueeze(-1) - labels = torch.cat((1 - labels, labels), -1) - labels = ((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1))) - - kl = kl_div(alpha, labels, reduction='none').sum(-1).mean() - return kl - - -# def smart_sort(x, permutation): -# assert x.dim() == permutation.dim() -# if x.dim() == 3: -# d1, d2, d3 = x.size() -# ret = x[torch.arange(d1).unsqueeze(-1).unsqueeze(-1).repeat((1, d2, d3)).flatten(), -# torch.arange(d2).unsqueeze(0).unsqueeze(-1).repeat((d1, 1, d3)).flatten(), -# permutation.flatten()].view(d1, d2, d3) -# return ret -# elif x.dim() == 2: -# d1, d2 = x.size() -# ret = x[torch.arange(d1).unsqueeze(-1).repeat((1, d2)).flatten(), -# permutation.flatten()].view(d1, d2) -# return ret - - -# # Pytorch BayesianMatchingLoss nn.Module -# class DistillationNLL(Module): - -# def __init__(self, lamb=1e-4, ignore_index=-1): -# super(DistillationNLL, self).__init__() - -# self.lamb = lamb -# self.ignore_index = ignore_index -# self.loss_add = BayesianMatchingLoss(lamb=0.001) - -# def forward(self, alpha, labels, temp=1.0): -# # Assert input sizes -# assert alpha.dim() == 2 # Observations, predictive distribution -# assert labels.dim() == 3 # Label for each observation -# assert labels.size(0) == alpha.size(0) # Equal number of observation - -# # Confirm predictive distribution dimension -# if labels.size(-1) == alpha.size(-1): -# dimension = alpha.size(-1) -# else: -# raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1))) - -# alpha = torch.exp(alpha / temp) -# ids = torch.where(labels[:, 0, 0] != self.ignore_index)[0] -# alpha = alpha[ids] -# labels = labels[ids] - -# best_labels = labels.mean(-2).argmax(-1) -# loss2 = self.loss_add(alpha, best_labels) - -# topn = labels.mean(-2).argsort(-1, descending=True) -# n = 10 -# alpha = smart_sort(alpha, topn)[:, :n] -# labels = smart_sort(labels, topn.unsqueeze(-2).repeat((1, labels.size(-2), 1))) -# labels = labels[:, :, :n] -# labels = labels / labels.sum(-1).unsqueeze(-1).repeat((1, 1, labels.size(-1))) - -# labels = log(((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1)))) - -# loss = (alpha - 1) * labels.mean(-2) -# # loss = (alpha - 1) * labels -# loss = lgamma(alpha.sum(-1)) - lgamma(alpha).sum(-1) + loss.sum(-1) -# loss = -1.0 * loss.mean() -# # loss = -1.0 * loss.mean() / alpha.size(-1) - -# return loss - - -# # Pytorch BayesianMatchingLoss nn.Module -# class BinaryDistillationNLL(Module): - -# def __init__(self, lamb=1e-4, ignore_index=-1): -# super(BinaryDistillationNLL, self).__init__() - -# self.lamb = lamb -# self.ignore_index = ignore_index - -# def forward(self, alpha, labels, temp=0.0): -# # Assert input sizes -# assert alpha.dim() == 1 # Observations, predictive distribution -# assert labels.dim() == 2 # Label for each observation -# assert labels.size(0) == alpha.size(0) # Equal number of observation - -# # Confirm predictive distribution dimension -# # if labels.size(-1) == alpha.size(-1): -# # dimension = alpha.size(-1) -# # else: -# # raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1))) - -# # Remove observations with no labels -# ids = torch.where(labels[:, 0] != self.ignore_index)[0] -# # alpha_sum = 1 + (1 / self.lamb) -# alpha_sum = 10.0 -# alpha = (torch.sigmoid(alpha) * alpha_sum).reshape(-1, 1) -# alpha = alpha[ids] -# labels = labels[ids] - -# if temp != 1.0: -# alpha = torch.log(alpha + 1e-4) -# alpha = torch.exp(alpha / temp) - -# alpha = torch.cat((alpha_sum - alpha, alpha), 1) - -# labels = labels.unsqueeze(-1) -# labels = torch.cat((1 - labels, labels), -1) -# # labels[labels[:, 0, 0] == self.ignore_index] = 1 -# labels = log(((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1)))) - -# loss = (alpha - 1) * labels.mean(-2) -# loss = lgamma(alpha.sum(-1)) - lgamma(alpha).sum(-1) + loss.sum(-1) -# loss = -1.0 * loss.mean() - -# return loss diff --git a/convlab/dst/setsumbt/loss/endd_loss.py b/convlab/dst/setsumbt/loss/endd_loss.py index d84c3f72..9bd794bf 100644 --- a/convlab/dst/setsumbt/loss/endd_loss.py +++ b/convlab/dst/setsumbt/loss/endd_loss.py @@ -1,30 +1,46 @@ import torch +from torch.nn import Module +from torch.nn.functional import kl_div EPS = torch.finfo(torch.float32).eps + @torch.no_grad() -def compute_mkl(ensemble_probs, ensemble_mean_probs, ensemble_logprobs): - mkl = torch.nn.functional.kl_div(ensemble_logprobs, ensemble_mean_probs.unsqueeze(1).expand_as(ensemble_probs), - reduction='none').sum(-1).mean(1) - return mkl +def compute_mkl(ensemble_mean_probs: torch.Tensor, ensemble_logprobs: torch.Tensor) -> torch.Tensor: + """ + Computing MKL in ensemble. + + Args: + ensemble_mean_probs (Tensor): Marginal predictive distribution of the ensemble + ensemble_logprobs (Tensor): Log predictive distributions of individual ensemble members + + Returns: + mkl (Tensor): MKL + """ + mkl = kl_div(ensemble_logprobs, ensemble_mean_probs.unsqueeze(1).expand_as(ensemble_logprobs),reduction='none') + return mkl.sum(-1).mean(1) + @torch.no_grad() -def compute_ensemble_stats(ensemble_logits): - # ensemble_probs = torch.softmax(ensemble_logits, dim=-1) - # ensemble_mean_probs = ensemble_probs.mean(dim=1) - # ensemble_logprobs = torch.log_softmax(ensemble_logits, dim=-1) - ensemble_probs = ensemble_logits +def compute_ensemble_stats(ensemble_probs: torch.Tensor) -> dict: + """ + Compute a range of ensemble uncertainty measures + + Args: + ensemble_probs (Tensor): Predictive distributions of the ensemble members + + Returns: + stats (dict): Dictionary of ensemble uncertainty measures + """ ensemble_mean_probs = ensemble_probs.mean(dim=1) - num_classes = ensemble_logits.size(-1) - ensemble_logprobs = torch.log(ensemble_logits + (1e-4 / num_classes)) + num_classes = ensemble_probs.size(-1) + ensemble_logprobs = torch.log(ensemble_probs + (1e-4 / num_classes)) entropy_of_expected = torch.distributions.Categorical(probs=ensemble_mean_probs).entropy() expected_entropy = torch.distributions.Categorical(probs=ensemble_probs).entropy().mean(dim=1) mutual_info = entropy_of_expected - expected_entropy - mkl = compute_mkl(ensemble_probs, ensemble_mean_probs, ensemble_logprobs) - - # num_classes = ensemble_logits.size(-1) + mkl = compute_mkl(ensemble_mean_probs, ensemble_logprobs) ensemble_precision = (num_classes - 1) / (2 * mkl.unsqueeze(1) + EPS) @@ -39,108 +55,226 @@ def compute_ensemble_stats(ensemble_logits): } return stats -def entropy(probs, dim: int = -1): + +def entropy(probs: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Compute entropy in a predictive distribution + + Args: + probs (Tensor): Predictive distributions + dim (int): Dimension representing the predictive probabilities for a single prediction + + Returns: + entropy (Tensor): Entropy + """ return -(probs * (probs + EPS).log()).sum(dim=dim) -def compute_dirichlet_uncertainties(dirichlet_params, precisions, expected_dirichlet): +def compute_dirichlet_uncertainties(dirichlet_params: torch.Tensor, + precisions: torch.Tensor, + expected_dirichlet: torch.Tensor) -> tuple: """ Function which computes measures of uncertainty for Dirichlet model. - :param dirichlet_params: Tensor of size [batch_size, n_classes] of Dirichlet concentration parameters. - :param precisions: Tensor of size [batch_size, 1] of Dirichlet Precisions - :param expected_dirichlet: Tensor of size [batch_size, n_classes] of probablities of expected categorical under Dirichlet. - :return: Tensors of token level uncertainties of size [batch_size] + + Args: + dirichlet_params (Tensor): Dirichlet concentration parameters. + precisions (Tensor): Dirichlet Precisions + expected_dirichlet (Tensor): Probabities of expected categorical under Dirichlet. + + Returns: + stats (tuple): Token level uncertainties """ batch_size, n_classes = dirichlet_params.size() entropy_of_expected = entropy(expected_dirichlet) - expected_entropy = ( - -expected_dirichlet * (torch.digamma(dirichlet_params + 1) - torch.digamma(precisions + 1))).sum(dim=-1) + expected_entropy = -expected_dirichlet * (torch.digamma(dirichlet_params + 1) - torch.digamma(precisions + 1)) + expected_entropy = expected_entropy.sum(dim=-1) - mutual_information = -((expected_dirichlet + EPS) * ( - torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + 1 + EPS) + torch.digamma( - precisions + 1 + EPS))).sum(dim=-1) - # assert torch.allclose(mutual_information, entropy_of_expected - expected_entropy, atol=1e-4, rtol=0) + mutual_information = torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + 1 + EPS) + mutual_information += torch.digamma(precisions + 1 + EPS) + mutual_information *= -(expected_dirichlet + EPS) + mutual_information = mutual_information.sum(dim=-1) epkl = (n_classes - 1) / precisions.squeeze(-1) - mkl = (expected_dirichlet * ( - torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + EPS) + torch.digamma( - precisions + EPS))).sum(dim=-1) + mkl = torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + EPS) + mkl += torch.digamma(precisions + EPS) + mkl *= expected_dirichlet + mkl = mkl.sum(dim=-1) + + stats = (entropy_of_expected.clamp(min=0), expected_entropy.clamp(min=0), mutual_information.clamp(min=0)) + stats += (epkl.clamp(min=0), mkl.clamp(min=0)) + + return stats + + +def get_dirichlet_parameters(logits: torch.Tensor, + parametrization, + add_to_alphas: float = 0, + dtype=torch.double) -> tuple: + """ + Get dirichlet parameters from model logits - return entropy_of_expected.clamp(min=0), \ - expected_entropy.clamp(min=0), \ - mutual_information.clamp(min=0), \ - epkl.clamp(min=0), \ - mkl.clamp(min=0) + Args: + logits (Tensor): Model logits + parametrization (function): Mapping from logits to concentration parameters + add_to_alphas (float): Addition constant for stability + dtype (data type): Data type of the parameters -def get_dirichlet_parameters(logits, parametrization, add_to_alphas=0, dtype=torch.double): + Return: + params (tuple): Concentration and precision parameters of the model Dirichlet + """ max_val = torch.finfo(dtype).max / logits.size(-1) - 1 alphas = torch.clip(parametrization(logits.to(dtype=dtype)) + add_to_alphas, max=max_val) precision = torch.sum(alphas, dim=-1, dtype=dtype) return alphas, precision -def logits_to_mutual_info(logits): - alphas, precision = get_dirichlet_parameters(logits, torch.exp, 1.0) +def logits_to_mutual_info(logits: torch.Tensor) -> torch.Tensor: + """ + Map modfel logits to mutual information of model Dirichlet - unsqueezed_precision = precision.unsqueeze(1) - normalized_probs = alphas / unsqueezed_precision + Args: + logits (Tensor): Model logits - entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = compute_dirichlet_uncertainties(alphas, - unsqueezed_precision, - normalized_probs) - - # Max entropy is log(K) for K classes. Hence relative MI is calculated as MI/log(K) - # mutual_information /= torch.log(torch.tensor(logits.size(-1))) - - return mutual_information + Returns: + mutual_information (Tensor): Mutual information of the model Dirichlet + """ + alphas, precision = get_dirichlet_parameters(logits, torch.exp, 1.0) + normalized_probs = alphas / precision.unsqueeze(1) -def rkl_dirichlet_mediator_loss(logits, ensemble_stats, model_offset, target_offset, parametrization=torch.exp): - turns = torch.where(ensemble_stats[:, 0, 0] != -1)[0] - logits = logits[turns] - ensemble_stats = ensemble_stats[turns] + _, _, mutual_information, _, _ = compute_dirichlet_uncertainties(alphas, precision.unsqueeze(1), normalized_probs) - ensemble_stats = compute_ensemble_stats(ensemble_stats) - - alphas, precision = get_dirichlet_parameters(logits, parametrization, model_offset) - - unsqueezed_precision = precision.unsqueeze(1) - normalized_probs = alphas / unsqueezed_precision - - entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = compute_dirichlet_uncertainties(alphas, - unsqueezed_precision, - normalized_probs) - - stats = { - 'alpha_min': alphas.min(), - 'alpha_mean': alphas.mean(), - 'precision': precision, - 'entropy_of_expected': entropy_of_expected, - 'mutual_info': mutual_information, - 'mkl': mkl, - } - - num_classes = alphas.size(-1) - - ensemble_precision = ensemble_stats['precision'] - - ensemble_precision += target_offset * num_classes - ensemble_probs = ensemble_stats['mean_probs'] - - expected_KL_term = -1.0 * torch.sum(ensemble_probs * (torch.digamma(alphas + EPS) - - torch.digamma(precision.unsqueeze(-1) + EPS)), dim=-1) - assert torch.isfinite(expected_KL_term).all(), (torch.max(alphas), torch.max(precision), alphas.dtype) - - differential_negentropy_term = torch.sum(torch.lgamma(alphas + EPS), dim=-1) - torch.lgamma(precision + EPS) \ - - torch.sum( - (alphas - 1) * (torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS)), dim=-1) - assert torch.isfinite(differential_negentropy_term).all() - - cost = expected_KL_term - differential_negentropy_term / ensemble_precision.squeeze(-1) + return mutual_information - assert torch.isfinite(cost).all() - return torch.mean(cost), stats, ensemble_stats +class RKLDirichletMediatorLoss(Module): + """Reverse KL Dirichlet Mediator Loss (https://arxiv.org/abs/2105.06987)""" + + def __init__(self, + model_offset: float = 1.0, + target_offset: float = 1, + ignore_index: int = -1, + parameterization=torch.exp): + """ + Args: + model_offset (float): Offset of model Dirichlet for stability + target_offset (float): Offset of target Dirichlet for stability + ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. + parameterization (function): Mapping from logits to concentration parameters + """ + super(RKLDirichletMediatorLoss, self).__init__() + + self.model_offset = model_offset + self.target_offset = target_offset + self.ignore_index = ignore_index + self.parameterization = parameterization + + def logits_to_mutual_info(self, logits: torch.Tensor) -> torch.Tensor: + """ + Map modfel logits to mutual information of model Dirichlet + + Args: + logits (Tensor): Model logits + + Returns: + mutual_information (Tensor): Mutual information of the model Dirichlet + """ + return logits_to_mutual_info(logits) + + def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + Args: + logits (Tensor): Model logits + targets (Tensor): Ensemble predictive distributions + + Returns: + loss (Tensor): RKL dirichlet mediator loss value + """ + + # Remove padding + turns = torch.where(targets[:, 0, 0] != self.ignore_index)[0] + logits = logits[turns] + targets = targets[turns] + + ensemble_stats = compute_ensemble_stats(targets) + + alphas, precision = get_dirichlet_parameters(logits, self.parameterization, self.model_offset) + + normalized_probs = alphas / precision.unsqueeze(1) + + stats = compute_dirichlet_uncertainties(alphas, precision.unsqueeze(1), normalized_probs) + entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = stats + + stats = { + 'alpha_min': alphas.min(), + 'alpha_mean': alphas.mean(), + 'precision': precision, + 'entropy_of_expected': entropy_of_expected, + 'mutual_info': mutual_information, + 'mkl': mkl, + } + + num_classes = alphas.size(-1) + + ensemble_precision = ensemble_stats['precision'] + + ensemble_precision += self.target_offset * num_classes + ensemble_probs = ensemble_stats['mean_probs'] + + expected_kl_term = torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS) + expected_kl_term = -1.0 * torch.sum(ensemble_probs * expected_kl_term, dim=-1) + assert torch.isfinite(expected_kl_term).all(), (torch.max(alphas), torch.max(precision), alphas.dtype) + + differential_negentropy_term_ = torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS) + differential_negentropy_term_ *= alphas - 1.0 + differential_negentropy_term = torch.sum(torch.lgamma(alphas + EPS), dim=-1) - torch.lgamma(precision + EPS) + differential_negentropy_term -= torch.sum(differential_negentropy_term_, dim=-1) + assert torch.isfinite(differential_negentropy_term).all() + + loss = expected_kl_term - differential_negentropy_term / ensemble_precision.squeeze(-1) + assert torch.isfinite(loss).all() + + return torch.mean(loss), stats, ensemble_stats + + +class BinaryRKLDirichletMediatorLoss(RKLDirichletMediatorLoss): + """Reverse KL Dirichlet Mediator Loss (https://arxiv.org/abs/2105.06987)""" + + def __init__(self, + model_offset: float = 1.0, + target_offset: float = 1, + ignore_index: int = -1, + parameterization=torch.exp): + """ + Args: + model_offset (float): Offset of model Dirichlet for stability + target_offset (float): Offset of target Dirichlet for stability + ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. + parameterization (function): Mapping from logits to concentration parameters + """ + super(BinaryRKLDirichletMediatorLoss, self).__init__(model_offset, target_offset, + ignore_index, parameterization) + + def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + Args: + logits (Tensor): Model logits + targets (Tensor): Ensemble predictive distributions + + Returns: + loss (Tensor): RKL dirichlet mediator loss value + """ + # Convert single target probability p to distribution [1-p, p] + targets = targets.reshape(-1, targets.size(-1), 1) + targets = torch.cat([1 - targets, targets], -1) + targets[targets[:, 0, 1] == self.ignore_index] = self.ignore_index + + # Convert input logits into predictive distribution [1-z, z] + logits = torch.sigmoid(logits).unsqueeze(1) + logits = torch.cat((1 - logits, logits), 1) + logits = -1.0 * torch.log((1 / (logits + 1e-8)) - 1) # Inverse sigmoid + + return super().forward(logits, targets) diff --git a/convlab/dst/setsumbt/loss/kl_distillation.py b/convlab/dst/setsumbt/loss/kl_distillation.py new file mode 100644 index 00000000..9aee234a --- /dev/null +++ b/convlab/dst/setsumbt/loss/kl_distillation.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""KL Divergence Ensemble Distillation loss""" + +import torch +from torch.nn import Module +from torch.nn.functional import kl_div + + +class KLDistillationLoss(Module): + """Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation""" + + def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module: + """ + Args: + lamb (float): Target smoothing parameter + ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. + """ + super(KLDistillationLoss, self).__init__() + + self.lamb = lamb + self.ignore_index = ignore_index + + def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor: + """ + Args: + inputs (Tensor): Predictive distribution + targets (Tensor): Target distribution (ensemble marginal) + temp (float): Temperature scaling coefficient for predictive distribution + + Returns: + loss (Tensor): Loss value + """ + # Assert input sizes + assert inputs.dim() == 2 # Observations, predictive distribution + assert targets.dim() == 2 # Label for each observation + assert targets.size(0) == inputs.size(0) # Equal number of observation + + # Confirm predictive distribution dimension + if targets.size(-1) != inputs.size(-1): + name_error = f'Target dimension {targets.size(-1)} is not the same as the prediction dimension ' + name_error += f'{inputs.size(-1)}.' + raise NameError(name_error) + + # Remove observations to be ignored in loss calculation + inputs = torch.log(torch.softmax(inputs / temp, -1)) + ids = torch.where(targets[:, 0] != self.ignore_index)[0] + inputs = inputs[ids] + targets = targets[ids] + + # Target smoothing + targets = ((1 - self.lamb) * targets) + (self.lamb / targets.size(-1)) + + return kl_div(inputs, targets, reduction='none').sum(-1).mean() + + +# Pytorch BayesianMatchingLoss nn.Module +class BinaryKLDistillationLoss(KLDistillationLoss): + """Binary Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation""" + + def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module: + """ + Args: + lamb (float): Target smoothing parameter + ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. + """ + super(BinaryKLDistillationLoss, self).__init__(lamb, ignore_index) + + def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor: + """ + Args: + inputs (Tensor): Predictive distribution + targets (Tensor): Target distribution (ensemble marginal) + temp (float): Temperature scaling coefficient for predictive distribution + + Returns: + loss (Tensor): Loss value + """ + # Assert input sizes + assert inputs.dim() == 1 # Observations, predictive distribution + assert targets.dim() == 1 # Label for each observation + assert targets.size(0) == inputs.size(0) # Equal number of observation + + # Convert input and target to 2D binary distribution for KL divergence computation + inputs = torch.sigmoid(inputs / temp).unsqueeze(-1) + inputs = torch.log(torch.cat((1 - inputs, inputs), 1)) + + targets = targets.unsqueeze(-1) + targets = torch.cat((1 - targets, targets), -1) + + return super().forward(input, targets, temp) diff --git a/convlab/dst/setsumbt/loss/labelsmoothing.py b/convlab/dst/setsumbt/loss/labelsmoothing.py index 8fcc60af..61d4b353 100644 --- a/convlab/dst/setsumbt/loss/labelsmoothing.py +++ b/convlab/dst/setsumbt/loss/labelsmoothing.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inhibited Softmax Activation and Loss Functions""" +"""Label smoothing loss function""" import torch @@ -23,66 +23,97 @@ from torch.nn.functional import kl_div class LabelSmoothingLoss(Module): """ - With label smoothing, - KL-divergence between q_{smoothed ground truth prob.}(w) - and p_{prob. computed by model}(w) is minimized. + Label smoothing loss minimises the KL-divergence between q_{smoothed ground truth prob}(w) + and p_{prob. computed by model}(w). """ - def __init__(self, label_smoothing=0.05, ignore_index=-1): + + def __init__(self, label_smoothing: float = 0.05, ignore_index: int = -1) -> Module: + """ + Args: + label_smoothing (float): Label smoothing constant + ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. + """ super(LabelSmoothingLoss, self).__init__() assert 0.0 < label_smoothing <= 1.0 self.ignore_index = ignore_index self.label_smoothing = float(label_smoothing) - def forward(self, logits, targets): + def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: """ - output (FloatTensor): batch_size x n_classes - target (LongTensor): batch_size + Args: + input (Tensor): Predictive distribution + labels (Tensor): Label indices + + Returns: + loss (Tensor): Loss value """ - assert logits.dim() == 2 - assert targets.dim() == 1 - assert self.label_smoothing <= ((logits.size(-1) - 1) / logits.size(-1)) + # Assert input sizes + assert inputs.dim() == 2 + assert labels.dim() == 1 + assert self.label_smoothing <= ((inputs.size(-1) - 1) / inputs.size(-1)) - logits = logits[targets != self.ignore_index] - targets = targets[targets != self.ignore_index] + # Confirm predictive distribution dimension + if labels.max() <= inputs.size(-1): + dimension = inputs.size(-1) + else: + raise NameError(f'Label dimension {labels.max()} is larger than prediction dimension {inputs.size(-1)}.') - logits = torch.log(torch.softmax(logits, -1)) - labels = torch.ones(logits.size()).float().to(logits.device) - labels *= self.label_smoothing / (logits.size(-1) - 1) - labels[range(labels.size(0)), targets] = 1.0 - self.label_smoothing + # Remove observations to be ignored in loss calculation + inputs = inputs[labels != self.ignore_index] + labels = labels[labels != self.ignore_index] - kl = kl_div(logits, labels, reduction='none').sum(-1).mean() - del logits, targets, labels - return kl + if labels.size(0) == 0.0: + return torch.zeros(1).float().to(labels.device).mean() + # Create target distribution + inputs = torch.log(torch.softmax(inputs, -1)) + targets = torch.ones(inputs.size()).float().to(inputs.device) + targets *= self.label_smoothing / (dimension - 1) + targets[range(labels.size(0)), labels] = 1.0 - self.label_smoothing -class BinaryLabelSmoothingLoss(Module): + return kl_div(inputs, targets, reduction='none').sum(-1).mean() + + +class BinaryLabelSmoothingLoss(LabelSmoothingLoss): """ - With label smoothing, - KL-divergence between q_{smoothed ground truth prob.}(w) - and p_{prob. computed by model}(w) is minimized. + Label smoothing loss minimises the KL-divergence between q_{smoothed ground truth prob}(w) + and p_{prob. computed by model}(w). """ - def __init__(self, label_smoothing=0.05): - super(BinaryLabelSmoothingLoss, self).__init__() - assert 0.0 < label_smoothing <= 1.0 - self.label_smoothing = float(label_smoothing) + def __init__(self, label_smoothing: float = 0.05, ignore_index: int = -1) -> Module: + """ + Args: + label_smoothing (float): Label smoothing constant + ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. + """ + super(BinaryLabelSmoothingLoss, self).__init__(label_smoothing, ignore_index) - def forward(self, logits, targets): + def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: """ - output (FloatTensor): batch_size x n_classes - target (LongTensor): batch_size + Args: + input (Tensor): Predictive distribution + labels (Tensor): Label indices + + Returns: + loss (Tensor): Loss value """ - assert logits.dim() == 1 - assert targets.dim() == 1 + # Assert input sizes + assert inputs.dim() == 1 + assert labels.dim() == 1 assert self.label_smoothing <= 0.5 - logits = torch.sigmoid(logits).reshape(-1, 1) - logits = torch.log(torch.cat((1 - logits, logits), 1)) - labels = torch.ones(logits.size()).float().to(logits.device) - labels *= self.label_smoothing - labels[range(labels.size(0)), targets.long()] = 1.0 - self.label_smoothing + # Remove observations to be ignored in loss calculation + inputs = inputs[labels != self.ignore_index] + labels = labels[labels != self.ignore_index] + + if labels.size(0) == 0.0: + return torch.zeros(1).float().to(labels.device).mean() + + inputs = torch.sigmoid(inputs).reshape(-1, 1) + inputs = torch.log(torch.cat((1 - inputs, inputs), 1)) + targets = torch.ones(inputs.size()).float().to(inputs.device) + targets *= self.label_smoothing + targets[range(labels.size(0)), labels.long()] = 1.0 - self.label_smoothing - kl = kl_div(logits, labels, reduction='none').sum(-1).mean() - del logits, targets - return kl + return kl_div(inputs, targets, reduction='none').sum(-1).mean() diff --git a/convlab/dst/setsumbt/loss/ece.py b/convlab/dst/setsumbt/loss/uncertainty_measures.py similarity index 50% rename from convlab/dst/setsumbt/loss/ece.py rename to convlab/dst/setsumbt/loss/uncertainty_measures.py index 034b9aa0..87c89dd3 100644 --- a/convlab/dst/setsumbt/loss/ece.py +++ b/convlab/dst/setsumbt/loss/uncertainty_measures.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,14 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Expected calibration error""" +"""Uncertainty evaluation metrics for dialogue belief tracking""" import torch -def fill_bins(n_bins, logits): - assert logits.dim() == 2 - logits = logits.max(-1)[0] +def fill_bins(n_bins: int, probs: torch.Tensor) -> list: + """ + Function to split observations into bins based on predictive probabilities + + Args: + n_bins (int): Number of bins + probs (Tensor): Predictive probabilities for the observations + + Returns: + bins (list): List of observation ids for each bin + """ + assert probs.dim() == 2 + probs = probs.max(-1)[0] step = 1.0 / n_bins bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step) @@ -28,29 +38,49 @@ def fill_bins(n_bins, logits): for b in range(n_bins): lower, upper = bin_ranges[b], bin_ranges[b + 1] if b == 0: - ids = torch.where((logits >= lower) * (logits <= upper))[0] + ids = torch.where((probs >= lower) * (probs <= upper))[0] else: - ids = torch.where((logits > lower) * (logits <= upper))[0] + ids = torch.where((probs > lower) * (probs <= upper))[0] bins.append(ids) return bins -def bin_confidence(bins, logits): - logits = logits.max(-1)[0] +def bin_confidence(bins: list, probs: torch.Tensor) -> torch.Tensor: + """ + Compute the confidence score within each bin + + Args: + bins (list): List of observation ids for each bin + probs (Tensor): Predictive probabilities for the observations + + Returns: + scores (Tensor): Average confidence score within each bin + """ + probs = probs.max(-1)[0] scores = [] for b in bins: if b is not None: - l = logits[b] - scores.append(l.mean()) + scores.append(probs[b].mean()) else: scores.append(-1) scores = torch.tensor(scores) return scores -def bin_accuracy(bins, logits, y_true): - y_pred = logits.argmax(-1) +def bin_accuracy(bins: list, probs: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + """ + Compute the accuracy score for observations in each bin + + Args: + bins (list): List of observation ids for each bin + probs (Tensor): Predictive probabilities for the observations + y_true (Tensor): Labels for the observations + + Returns: + acc (Tensor): Accuracies for the observations in each bin + """ + y_pred = probs.argmax(-1) acc = [] for b in bins: @@ -68,13 +98,24 @@ def bin_accuracy(bins, logits, y_true): return acc -def ece(logits, y_true, n_bins): - bins = fill_bins(n_bins, logits) +def ece(probs: torch.Tensor, y_true: torch.Tensor, n_bins: int) -> float: + """ + Expected calibration error calculation - scores = bin_confidence(bins, logits) - acc = bin_accuracy(bins, logits, y_true) + Args: + probs (Tensor): Predictive probabilities for the observations + y_true (Tensor): Labels for the observations + n_bins (int): Number of bins - n = logits.size(0) + Returns: + ece (float): Expected calibration error + """ + bins = fill_bins(n_bins, probs) + + scores = bin_confidence(bins, probs) + acc = bin_accuracy(bins, probs, y_true) + + n = probs.size(0) bk = torch.tensor([b.size(0) for b in bins]) ece = torch.abs(scores - acc) * bk / n @@ -84,34 +125,30 @@ def ece(logits, y_true, n_bins): return ece -def jg_ece(logits, y_true, n_bins): - y_pred = {slot: logits[slot].reshape(-1, logits[slot].size(-1)).argmax(-1) for slot in logits} +def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float: + """ + Joint goal expected calibration error calculation + + Args: + belief_state (dict): Belief state probabilities for the dialogue turns + y_true (dict): Labels for the state in dialogue turns + n_bins (int): Number of bins + + Returns: + ece (float): Joint goal expected calibration error + """ + y_pred = {slot: bs.reshape(-1, bs.size(-1)).argmax(-1) for slot, bs in belief_state.items()} goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred} goal_acc = sum([goal_acc[slot] for slot in goal_acc]) goal_acc = (goal_acc == len(y_true)).int() - scores = [logits[slot].reshape(-1, logits[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in logits] + # Confidence score is minimum across slots as a single bad predictions leads to incorrect prediction in state + scores = [bs.reshape(-1, bs.size(-1)).max(-1)[0].unsqueeze(0) for slot, bs in belief_state.items()] scores = torch.cat(scores, 0).min(0)[0] - step = 1.0 / n_bins - bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step) - bins = [] - for b in range(n_bins): - lower, upper = bin_ranges[b], bin_ranges[b + 1] - if b == 0: - ids = torch.where((scores >= lower) * (scores <= upper))[0] - else: - ids = torch.where((scores > lower) * (scores <= upper))[0] - bins.append(ids) + bins = fill_bins(n_bins, scores.unsqueeze(-1)) - conf = [] - for b in bins: - if b is not None: - l = scores[b] - conf.append(l.mean()) - else: - conf.append(-1) - conf = torch.tensor(conf) + conf = bin_confidence(bins, scores.unsqueeze(-1)) slot = [s for s in y_true][0] acc = [] @@ -127,7 +164,7 @@ def jg_ece(logits, y_true, n_bins): acc.append(-1) acc = torch.tensor(acc) - n = logits[slot].reshape(-1, logits[slot].size(-1)).size(0) + n = belief_state[slot].reshape(-1, belief_state[slot].size(-1)).size(0) bk = torch.tensor([b.size(0) for b in bins]) ece = torch.abs(conf - acc) * bk / n @@ -137,12 +174,22 @@ def jg_ece(logits, y_true, n_bins): return ece -def l2_acc(belief_state, labels, remove_belief=False): +def l2_acc(belief_state: dict, labels: dict, remove_belief: bool = False) -> float: + """ + Compute L2 Error of belief state prediction + + Args: + belief_state (dict): Belief state probabilities for the dialogue turns + labels (dict): Labels for the state in dialogue turns + remove_belief (bool): Convert belief state to dialogue state + + Returns: + err (float): L2 Error of belief state prediction + """ # Get ids used for removing padding turns. padding = labels[list(labels.keys())[0]].reshape(-1) padding = torch.where(padding != -1)[0] - # l2 = [] state = [] labs = [] for slot, bs in belief_state.items(): @@ -163,13 +210,8 @@ def l2_acc(belief_state, labels, remove_belief=False): y = torch.zeros(bs.shape).cuda() y[range(y.size(0)), lab] = 1.0 - # err = torch.sqrt(((y - bs) ** 2).sum(-1)) - # l2.append(err.unsqueeze(-1)) - state.append(bs) labs.append(y) - - # err = torch.cat(l2, -1).max(-1)[0] # Concatenate all slots into a single belief state state = torch.cat(state, -1) diff --git a/convlab/dst/setsumbt/modeling/__init__.py b/convlab/dst/setsumbt/modeling/__init__.py index 011a1a77..59f14399 100644 --- a/convlab/dst/setsumbt/modeling/__init__.py +++ b/convlab/dst/setsumbt/modeling/__init__.py @@ -1,3 +1,5 @@ from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT -from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT, DropoutEnsembleSetSUMBT +from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT + +from convlab.dst.setsumbt.modeling.temperature_scheduler import LinearTemperatureScheduler diff --git a/convlab/dst/setsumbt/modeling/bert_nbt.py b/convlab/dst/setsumbt/modeling/bert_nbt.py index 8b402b6b..6762fb38 100644 --- a/convlab/dst/setsumbt/modeling/bert_nbt.py +++ b/convlab/dst/setsumbt/modeling/bert_nbt.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,11 +16,10 @@ """BERT SetSUMBT""" import torch -import transformers from torch.autograd import Variable from transformers import BertModel, BertPreTrainedModel -from convlab.dst.setsumbt.modeling.functional import _initialise, _nbt_forward +from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead class BertSetSUMBT(BertPreTrainedModel): @@ -35,59 +34,37 @@ class BertSetSUMBT(BertPreTrainedModel): for p in self.bert.parameters(): p.requires_grad = False - _initialise(self, config) - - # Add new slot candidates to the model - def add_slot_candidates(self, slot_candidates): - """slot_candidates is a list of tuples for each slot. - - The tuples contains the slot embedding, informable value embeddings and a request indicator. - - If the informable value embeddings is None the slot is not informable - - If the request indicator is false the slot is not requestable""" - if self.slot_embeddings.size(0) != 0: - embeddings = self.slot_embeddings.detach() - else: - embeddings = torch.zeros(0) - - for slot in slot_candidates: - if slot in self.slot_ids: - index = self.slot_ids[slot] - embeddings[index, :] = slot_candidates[slot][0] - else: - index = embeddings.size(0) - emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device) - embeddings = torch.cat((embeddings, emb), 0) - self.slot_ids[slot] = index - setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False)) - # Add slot to relevant requestable and informable slot lists - if slot_candidates[slot][2]: - self.requestable_slot_ids[slot] = index - if slot_candidates[slot][1] is not None: - self.informable_slot_ids[slot] = index - - domain = slot.split('-', 1)[0] - if domain not in self.domain_ids: - self.domain_ids[domain] = [] - self.domain_ids[domain].append(index) - self.domain_ids[domain] = list(set(self.domain_ids[domain])) - - self.slot_embeddings = Variable(embeddings, requires_grad=False) - - - # Add new value candidates to the model - def add_value_candidates(self, slot, value_candidates, replace=False): - embeddings = getattr(self, slot + '_value_embeddings') - - if embeddings.size(0) == 0 or replace: - embeddings = value_candidates - else: - embeddings = torch.cat((embeddings, value_candidates), 0) - - setattr(self, slot + '_value_embeddings', embeddings) - - - def forward(self, input_ids, token_type_ids, attention_mask, hidden_state=None, inform_labels=None, - request_labels=None, domain_labels=None, goodbye_labels=None, - get_turn_pooled_representation=False, calculate_inform_mutual_info=False): + self.setsumbt = SetSUMBTHead(config) + self.add_slot_candidates = self.setsumbt.add_slot_candidates + self.add_value_candidates = self.setsumbt.add_value_candidates + + def forward(self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor = None, + hidden_state: torch.Tensor = None, + state_labels: torch.Tensor = None, + request_labels: torch.Tensor = None, + active_domain_labels: torch.Tensor = None, + general_act_labels: torch.Tensor = None, + get_turn_pooled_representation: bool = False, + calculate_state_mutual_info: bool = False): + """ + Args: + input_ids: Input token ids + attention_mask: Input padding mask + token_type_ids: Token type indicator + hidden_state: Latent internal dialogue belief state + state_labels: Dialogue state labels + request_labels: User request action labels + active_domain_labels: Current active domain labels + general_act_labels: General user action labels + get_turn_pooled_representation: Return pooled representation of the current dialogue turn + calculate_state_mutual_info: Return mutual information in the dialogue state + + Returns: + out: Tuple containing loss, predictive distributions, model statistics and state mutual information + """ # Encode Dialogues batch_size, dialogue_size, turn_size = input_ids.size() @@ -103,9 +80,10 @@ class BertSetSUMBT(BertPreTrainedModel): turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1) if get_turn_pooled_representation: - return _nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size, - dialogue_size, turn_size, hidden_state, inform_labels, request_labels, domain_labels, - goodbye_labels, calculate_inform_mutual_info) + (bert_output.pooler_output,) - return _nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size, dialogue_size, - turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels, - calculate_inform_mutual_info) + return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask, + batch_size, dialogue_size, hidden_state, state_labels, + request_labels, active_domain_labels, general_act_labels, + calculate_state_mutual_info) + (bert_output.pooler_output,) + return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask, batch_size, + dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels, + general_act_labels, calculate_state_mutual_info) diff --git a/convlab/dst/setsumbt/modeling/calibration_utils.py b/convlab/dst/setsumbt/modeling/calibration_utils.py deleted file mode 100644 index 8514ac8d..00000000 --- a/convlab/dst/setsumbt/modeling/calibration_utils.py +++ /dev/null @@ -1,134 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Discriminative models calibration""" - -import random - -import torch -import numpy as np -from tqdm import tqdm - - -# Load logger and tensorboard summary writer -def set_logger(logger_, tb_writer_): - global logger, tb_writer - logger = logger_ - tb_writer = tb_writer_ - - -# Set seeds -def set_seed(args): - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - if args.n_gpu > 0: - torch.cuda.manual_seed_all(args.seed) - logger.info('Seed set to %d.' % args.seed) - - -def get_predictions(args, model, device, dataloader): - logger.info(" Num Batches = %d", len(dataloader)) - - model.eval() - if args.dropout_iterations > 1: - model.train() - - belief_states = {slot: [] for slot in model.informable_slot_ids} - request_belief = {slot: [] for slot in model.requestable_slot_ids} - domain_belief = {dom: [] for dom in model.domain_ids} - greeting_belief = [] - labels = {slot: [] for slot in model.informable_slot_ids} - request_labels = {slot: [] for slot in model.requestable_slot_ids} - domain_labels = {dom: [] for dom in model.domain_ids} - greeting_labels = [] - epoch_iterator = tqdm(dataloader, desc="Iteration") - for step, batch in enumerate(epoch_iterator): - with torch.no_grad(): - input_ids = batch['input_ids'].to(device) - token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None - attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None - - if args.dropout_iterations > 1: - p = {slot: [] for slot in model.informable_slot_ids} - for _ in range(args.dropout_iterations): - p_, p_req_, p_dom_, p_bye_, _ = model(input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask) - for slot in model.informable_slot_ids: - p[slot].append(p_[slot].unsqueeze(0)) - - mu = {slot: torch.cat(p[slot], 0).mean(0) for slot in model.informable_slot_ids} - sig = {slot: torch.cat(p[slot], 0).var(0) for slot in model.informable_slot_ids} - p = {slot: mu[slot] / torch.sqrt(1 + sig[slot]) for slot in model.informable_slot_ids} - p = {slot: normalise(p[slot]) for slot in model.informable_slot_ids} - else: - p, p_req, p_dom, p_bye, _ = model(input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask) - - for slot in model.informable_slot_ids: - p_ = p[slot] - labs = batch['labels-' + slot].to(device) - - belief_states[slot].append(p_) - labels[slot].append(labs) - - if p_req is not None: - for slot in model.requestable_slot_ids: - p_ = p_req[slot] - labs = batch['request-' + slot].to(device) - - request_belief[slot].append(p_) - request_labels[slot].append(labs) - - for domain in model.domain_ids: - p_ = p_dom[domain] - labs = batch['active-' + domain].to(device) - - domain_belief[domain].append(p_) - domain_labels[domain].append(labs) - - greeting_belief.append(p_bye) - greeting_labels.append(batch['goodbye'].to(device)) - - for slot in belief_states: - belief_states[slot] = torch.cat(belief_states[slot], 0) - labels[slot] = torch.cat(labels[slot], 0) - if p_req is not None: - for slot in request_belief: - request_belief[slot] = torch.cat(request_belief[slot], 0) - request_labels[slot] = torch.cat(request_labels[slot], 0) - for domain in domain_belief: - domain_belief[domain] = torch.cat(domain_belief[domain], 0) - domain_labels[domain] = torch.cat(domain_labels[domain], 0) - greeting_belief = torch.cat(greeting_belief, 0) - greeting_labels = torch.cat(greeting_labels, 0) - else: - request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = [None]*6 - - return belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels - - -def normalise(p): - p_shape = p.size() - - p = p.reshape(-1, p_shape[-1]) + 1e-10 - p_sum = p.sum(-1).unsqueeze(1).repeat((1, p_shape[-1])) - p /= p_sum - - p = p.reshape(p_shape) - - return p diff --git a/convlab/dst/setsumbt/modeling/ensemble_nbt.py b/convlab/dst/setsumbt/modeling/ensemble_nbt.py index 9f101d12..6d3d8035 100644 --- a/convlab/dst/setsumbt/modeling/ensemble_nbt.py +++ b/convlab/dst/setsumbt/modeling/ensemble_nbt.py @@ -16,9 +16,9 @@ """Ensemble SetSUMBT""" import os +from shutil import copy2 as copy import torch -import transformers from torch.nn import Module from transformers import RobertaConfig, BertConfig @@ -29,8 +29,13 @@ MODELS = {'bert': BertSetSUMBT, 'roberta': RobertaSetSUMBT} class EnsembleSetSUMBT(Module): + """Ensemble SetSUMBT Model for joint ensemble prediction""" def __init__(self, config): + """ + Args: + config (configuration): Model configuration class + """ super(EnsembleSetSUMBT, self).__init__() self.config = config @@ -38,175 +43,138 @@ class EnsembleSetSUMBT(Module): model_cls = MODELS[self.config.model_type] for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]: setattr(self, attr, model_cls(config)) - - # Load all ensemble memeber parameters - def load(self, path, config=None): - if config is None: - config = self.config - + def _load(self, path: str): + """ + Load parameters + Args: + path: Location of model parameters + """ for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]: idx = attr.split('_', 1)[-1] - state_dict = torch.load(os.path.join(path, f'pytorch_model_{idx}.bin')) + state_dict = torch.load(os.path.join(path, f'ens-{idx}/pytorch_model.bin')) getattr(self, attr).load_state_dict(state_dict) - - # Add new slot candidates to the ensemble members - def add_slot_candidates(self, slot_candidates): + def add_slot_candidates(self, slot_candidates: tuple): + """ + Add slots to the model ontology, the tuples should contain the slot embedding, informable value embeddings + and a request indicator, if the informable value embeddings is None the slot is not informable and if + the request indicator is false the slot is not requestable. + + Args: + slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator + """ for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]: getattr(self, attr).add_slot_candidates(slot_candidates) - self.requestable_slot_ids = self.model_0.requestable_slot_ids - self.informable_slot_ids = self.model_0.informable_slot_ids - self.domain_ids = self.model_0.domain_ids - - - # Add new value candidates to the ensemble members - def add_value_candidates(self, slot, value_candidates, replace=False): + self.requestable_slot_ids = self.model_0.setsumbt.requestable_slot_ids + self.informable_slot_ids = self.model_0.setsumbt.informable_slot_ids + self.domain_ids = self.model_0.setsumbt.domain_ids + + def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False): + """ + Add value candidates for a slot + + Args: + slot: Slot name + value_candidates: Value candidate embeddings + replace: If true existing value candidates are replaced + """ for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]: getattr(self, attr).add_value_candidates(slot, value_candidates, replace) - - # Forward pass of full ensemble - def forward(self, input_ids, attention_mask, token_type_ids=None, reduction='mean'): - logits, request_logits, domain_logits, goodbye_scores = [], [], [], [] - logits = {slot: [] for slot in self.model_0.informable_slot_ids} - request_logits = {slot: [] for slot in self.model_0.requestable_slot_ids} - domain_logits = {dom: [] for dom in self.model_0.domain_ids} - goodbye_scores = [] + def forward(self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor = None, + reduction: str = 'mean') -> tuple: + """ + Args: + input_ids: Input token ids + attention_mask: Input padding mask + token_type_ids: Token type indicator + reduction: Reduction of ensemble member predictive distributions (mean, none) + + Returns: + + """ + belief_state_probs = {slot: [] for slot in self.informable_slot_ids} + request_probs = {slot: [] for slot in self.requestable_slot_ids} + active_domain_probs = {dom: [] for dom in self.domain_ids} + general_act_probs = [] for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]: # Prediction from each ensemble member - l, r, d, g, _ = getattr(self, attr)(input_ids=input_ids, + b, r, d, g, _ = getattr(self, attr)(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - for slot in logits: - logits[slot].append(l[slot].unsqueeze(-2)) - if self.config.predict_intents: - for slot in request_logits: - request_logits[slot].append(r[slot].unsqueeze(-1)) - for dom in domain_logits: - domain_logits[dom].append(d[dom].unsqueeze(-1)) - goodbye_scores.append(g.unsqueeze(-2)) + for slot in belief_state_probs: + belief_state_probs[slot].append(b[slot].unsqueeze(-2)) + if self.config.predict_actions: + for slot in request_probs: + request_probs[slot].append(r[slot].unsqueeze(-1)) + for dom in active_domain_probs: + active_domain_probs[dom].append(d[dom].unsqueeze(-1)) + general_act_probs.append(g.unsqueeze(-2)) - logits = {slot: torch.cat(l, -2) for slot, l in logits.items()} - if self.config.predict_intents: - request_logits = {slot: torch.cat(l, -1) for slot, l in request_logits.items()} - domain_logits = {dom: torch.cat(l, -1) for dom, l in domain_logits.items()} - goodbye_scores = torch.cat(goodbye_scores, -2) + belief_state_probs = {slot: torch.cat(l, -2) for slot, l in belief_state_probs.items()} + if self.config.predict_actions: + request_probs = {slot: torch.cat(l, -1) for slot, l in request_probs.items()} + active_domain_probs = {dom: torch.cat(l, -1) for dom, l in active_domain_probs.items()} + general_act_probs = torch.cat(general_act_probs, -2) else: - request_logits = {} - domain_logits = {} - goodbye_scores = torch.tensor(0.0) + request_probs = {} + active_domain_probs = {} + general_act_probs = torch.tensor(0.0) # Apply reduction of ensemble to single posterior if reduction == 'mean': - logits = {slot: l.mean(-2) for slot, l in logits.items()} - request_logits = {slot: l.mean(-1) for slot, l in request_logits.items()} - domain_logits = {dom: l.mean(-1) for dom, l in domain_logits.items()} - goodbye_scores = goodbye_scores.mean(-2) + belief_state_probs = {slot: l.mean(-2) for slot, l in belief_state_probs.items()} + request_probs = {slot: l.mean(-1) for slot, l in request_probs.items()} + active_domain_probs = {dom: l.mean(-1) for dom, l in active_domain_probs.items()} + general_act_probs = general_act_probs.mean(-2) elif reduction != 'none': raise(NameError('Not Implemented!')) - return logits, request_logits, domain_logits, goodbye_scores, _ + return belief_state_probs, request_probs, active_domain_probs, general_act_probs, _ @classmethod def from_pretrained(cls, path): - if not os.path.exists(os.path.join(path, 'config.json')): + config_path = os.path.join(path, 'ens-0', 'config.json') + if not os.path.exists(config_path): raise(NameError('Could not find config.json in model path.')) - if not os.path.exists(os.path.join(path, 'pytorch_model_0.bin')): - raise(NameError('Could not find a model binary in the model path.')) try: - config = RobertaConfig.from_pretrained(path) + config = RobertaConfig.from_pretrained(config_path) except: - config = BertConfig.from_pretrained(path) + config = BertConfig.from_pretrained(config_path) + + config.ensemble_size = len([dir for dir in os.listdir(path) if 'ens-' in dir]) model = cls(config) - model.load(path) + model._load(path) return model -class DropoutEnsembleSetSUMBT(Module): - - def __init__(self, config): - super(DropoutEnsembleBeliefTracker, self).__init__() - self.config = config - - model_cls = MODELS[self.config.model_type] - self.model = model_cls(config) - self.model.train() - - - def load(self, path, config=None): - if config is None: - config = self.config - state_dict = torch.load(os.path.join(path, f'pytorch_model.bin')) - self.model.load_state_dict(state_dict) - - - # Add new slot candidates to the model - def add_slot_candidates(self, slot_candidates): - self.model.add_slot_candidates(slot_candidates) - self.requestable_slot_ids = self.model.requestable_slot_ids - self.informable_slot_ids = self.model.informable_slot_ids - self.domain_ids = self.model.domain_ids - - - # Add new value candidates to the model - def add_value_candidates(self, slot, value_candidates, replace=False): - self.model.add_value_candidates(slot, value_candidates, replace) - - - def forward(self, input_ids, attention_mask, token_type_ids=None, reduction='mean'): - - input_ids = input_ids.unsqueeze(0).repeat((self.config.ensemble_size, 1, 1, 1)) - input_ids = input_ids.reshape(-1, input_ids.size(-2), input_ids.size(-1)) - if attention_mask is not None: - attention_mask = attention_mask.unsqueeze(0).repeat((10, 1, 1, 1)) - attention_mask = attention_mask.reshape(-1, attention_mask.size(-2), attention_mask.size(-1)) - if token_type_ids is not None: - token_type_ids = token_type_ids.unsqueeze(0).repeat((10, 1, 1, 1)) - token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-2), token_type_ids.size(-1)) - - self.model.train() - logits, request_logits, domain_logits, goodbye_scores, _ = self.model(input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids) - - logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-2), l.size(-1)).transpose(0, 1).transpose(1, 2) - for s, l in logits.items()} - request_logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-1)).transpose(0, 1).transpose(1, 2) - for s, l in request_logits.items()} - domain_logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-1)).transpose(0, 1).transpose(1, 2) - for s, l in domain_logits.items()} - goodbye_scores = goodbye_scores.reshape(self.config.ensemble_size, -1, goodbye_scores.size(-2), goodbye_scores.size(-1)) - goodbye_scores = goodbye_scores.transpose(0, 1).transpose(1, 2) - - if reduction == 'mean': - logits = {slot: l.mean(-2) for slot, l in logits.items()} - request_logits = {slot: l.mean(-1) for slot, l in request_logits.items()} - domain_logits = {dom: l.mean(-1) for dom, l in domain_logits.items()} - goodbye_scores = goodbye_scores.mean(-2) - elif reduction != 'none': - raise(NameError('Not Implemented!')) - - return logits, request_logits, domain_logits, goodbye_scores, _ - - - @classmethod - def from_pretrained(cls, path): - if not os.path.exists(os.path.join(path, 'config.json')): - raise(NameError('Could not find config.json in model path.')) - if not os.path.exists(os.path.join(path, 'pytorch_model.bin')): - raise(NameError('Could not find a model binary in the model path.')) - - try: - config = RobertaConfig.from_pretrained(path) - except: - config = BertConfig.from_pretrained(path) - - model = cls(config) - model.load(path) - - return model +def setup_ensemble(model_path: str, ensemble_size: int): + """ + Setup ensemble model directory structure. + + Args: + model_path: Path to ensemble model directory + ensemble_size: Number of ensemble members + """ + for i in range(ensemble_size): + path = os.path.join(model_path, f'ens-{i}') + if not os.path.exists(path): + os.mkdir(path) + os.mkdir(os.path.join(path, 'dataloaders')) + os.mkdir(os.path.join(path, 'database')) + # Add development set dataloader to each ensemble member directory + for set_type in ['dev']: + copy(os.path.join(model_path, 'dataloaders', f'{set_type}.dataloader'), + os.path.join(path, 'dataloaders', f'{set_type}.dataloader')) + # Add training and development set ontologies to each ensemble member directory + for set_type in ['train', 'dev']: + copy(os.path.join(model_path, 'database', f'{set_type}.db'), + os.path.join(path, 'database', f'{set_type}.db')) diff --git a/convlab/dst/setsumbt/modeling/evaluation_utils.py b/convlab/dst/setsumbt/modeling/evaluation_utils.py new file mode 100644 index 00000000..c73d4b6d --- /dev/null +++ b/convlab/dst/setsumbt/modeling/evaluation_utils.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluation Utilities""" + +import random + +import torch +import numpy as np +from tqdm import tqdm + + +def set_seed(args): + """ + Set random seeds + + Args: + args (Arguments class): Arguments class containing seed and number of gpus to use + """ + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def get_predictions(args, model, device: torch.device, dataloader: torch.utils.data.DataLoader) -> tuple: + """ + Get model predictions + + Args: + args: Runtime arguments + model: SetSUMBT Model + device: Torch device + dataloader: Dataloader containing eval data + """ + model.eval() + + belief_states = {slot: [] for slot in model.setsumbt.informable_slot_ids} + request_probs = {slot: [] for slot in model.setsumbt.requestable_slot_ids} + active_domain_probs = {dom: [] for dom in model.setsumbt.domain_ids} + general_act_probs = [] + state_labels = {slot: [] for slot in model.setsumbt.informable_slot_ids} + request_labels = {slot: [] for slot in model.setsumbt.requestable_slot_ids} + active_domain_labels = {dom: [] for dom in model.setsumbt.domain_ids} + general_act_labels = [] + epoch_iterator = tqdm(dataloader, desc="Iteration") + for step, batch in enumerate(epoch_iterator): + with torch.no_grad(): + input_ids = batch['input_ids'].to(device) + token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None + attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None + + p, p_req, p_dom, p_gen, _ = model(input_ids=input_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask) + + for slot in belief_states: + p_ = p[slot] + labs = batch['state_labels-' + slot].to(device) + + belief_states[slot].append(p_) + state_labels[slot].append(labs) + + if p_req is not None: + for slot in request_probs: + p_ = p_req[slot] + labs = batch['request_labels-' + slot].to(device) + + request_probs[slot].append(p_) + request_labels[slot].append(labs) + + for domain in active_domain_probs: + p_ = p_dom[domain] + labs = batch['active_domain_labels-' + domain].to(device) + + active_domain_probs[domain].append(p_) + active_domain_labels[domain].append(labs) + + general_act_probs.append(p_gen) + general_act_labels.append(batch['general_act_labels'].to(device)) + + for slot in belief_states: + belief_states[slot] = torch.cat(belief_states[slot], 0) + state_labels[slot] = torch.cat(state_labels[slot], 0) + if p_req is not None: + for slot in request_probs: + request_probs[slot] = torch.cat(request_probs[slot], 0) + request_labels[slot] = torch.cat(request_labels[slot], 0) + for domain in active_domain_probs: + active_domain_probs[domain] = torch.cat(active_domain_probs[domain], 0) + active_domain_labels[domain] = torch.cat(active_domain_labels[domain], 0) + general_act_probs = torch.cat(general_act_probs, 0) + general_act_labels = torch.cat(general_act_labels, 0) + else: + request_probs, request_labels, active_domain_probs, active_domain_labels = [None] * 4 + general_act_probs, general_act_labels = [None] * 2 + + out = (belief_states, state_labels, request_probs, request_labels) + out += (active_domain_probs, active_domain_labels, general_act_probs, general_act_labels) + return out diff --git a/convlab/dst/setsumbt/modeling/functional.py b/convlab/dst/setsumbt/modeling/functional.py deleted file mode 100644 index 0dd083d0..00000000 --- a/convlab/dst/setsumbt/modeling/functional.py +++ /dev/null @@ -1,456 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""SetSUMBT functionals""" - -import torch -import transformers -from torch.autograd import Variable -from torch.nn import (MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout, - CosineSimilarity, CrossEntropyLoss, PairwiseDistance, - Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss) -from torch.nn.init import (xavier_normal_, constant_) - -from convlab.dst.setsumbt.loss.bayesian import BayesianMatchingLoss, BinaryBayesianMatchingLoss, dirichlet -from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss -from convlab.dst.setsumbt.loss.distillation import DistillationKL, BinaryDistillationKL -from convlab.dst.setsumbt.loss.endd_loss import rkl_dirichlet_mediator_loss, logits_to_mutual_info - - -# Default belief tracker model intialisation function -def _initialise(self, config): - # Slot Utterance matching attention - self.slot_attention = MultiheadAttention( - config.hidden_size, config.slot_attention_heads) - - # Latent context tracker - # Initial state prediction - if not config.rnn_zero_init and config.nbt_type in ['gru', 'lstm']: - self.belief_init = Sequential(Linear(config.hidden_size, config.nbt_hidden_size), - ReLU(), Dropout(config.dropout_rate)) - - # Recurrent context tracker setup - if config.nbt_type == 'gru': - self.nbt = GRU(input_size=config.hidden_size, - hidden_size=config.nbt_hidden_size, - num_layers=config.nbt_layers, - dropout=0.0 if config.nbt_layers == 1 else config.dropout_rate, - batch_first=True) - # Initialise Parameters - xavier_normal_(self.nbt.weight_ih_l0) - xavier_normal_(self.nbt.weight_hh_l0) - constant_(self.nbt.bias_ih_l0, 0.0) - constant_(self.nbt.bias_hh_l0, 0.0) - elif config.nbt_type == 'lstm': - self.nbt = LSTM(input_size=config.hidden_size, - hidden_size=config.nbt_hidden_size, - num_layers=config.nbt_layers, - dropout=0.0 if config.nbt_layers == 1 else config.dropout_rate, - batch_first=True) - # Initialise Parameters - xavier_normal_(self.nbt.weight_ih_l0) - xavier_normal_(self.nbt.weight_hh_l0) - constant_(self.nbt.bias_ih_l0, 0.0) - constant_(self.nbt.bias_hh_l0, 0.0) - else: - raise NameError('Not Implemented') - - # Feature decoder and layer norm - self.intermediate = Linear(config.nbt_hidden_size, config.hidden_size) - self.layer_norm = LayerNorm(config.hidden_size) - - # Dropout - self.dropout = Dropout(config.dropout_rate) - - # Set pooler for set similarity model - if self.config.set_similarity: - # 1D convolutional set pooler - if self.config.set_pooling == 'cnn': - self.conv_pooler = Conv1d( - self.config.hidden_size, self.config.hidden_size, 3) - # Deep averaging network set pooler - elif self.config.set_pooling == 'dan': - self.avg_net = Sequential(Linear(self.config.hidden_size, 2 * self.config.hidden_size), GELU(), - Linear(2 * self.config.hidden_size, self.config.hidden_size)) - - # Model ontology placeholders - self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False) - self.slot_ids = dict() - self.requestable_slot_ids = dict() - self.informable_slot_ids = dict() - self.domain_ids = dict() - - # Matching network similarity measure - if config.distance_measure == 'cosine': - self.distance = CosineSimilarity(dim=-1, eps=1e-8) - elif config.distance_measure == 'euclidean': - self.distance = PairwiseDistance(p=2.0, eps=1e-06, keepdim=False) - else: - raise NameError('NotImplemented') - - # Belief state loss function - if config.loss_function == 'crossentropy': - self.loss = CrossEntropyLoss(ignore_index=-1) - elif config.loss_function == 'bayesianmatching': - self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - elif config.loss_function == 'labelsmoothing': - self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing) - elif config.loss_function == 'distillation': - self.loss = DistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing) - self.temp = 1.0 - elif config.loss_function == 'distribution_distillation': - self.loss = rkl_dirichlet_mediator_loss - self.temp = 1.0 - else: - raise NameError('NotImplemented') - - # Intent and domain prediction heads - if config.predict_actions: - self.request_gate = Linear(config.hidden_size, 1) - self.goodbye_gate = Linear(config.hidden_size, 3) - self.domain_gate = Linear(config.hidden_size, 1) - - # Intent and domain loss function - self.request_weight = float(self.config.user_request_loss_weight) - self.goodbye_weight = float(self.config.user_general_act_loss_weight) - self.domain_weight = float(self.config.active_domain_loss_weight) - if config.loss_function == 'crossentropy': - self.request_loss = BCEWithLogitsLoss() - self.goodbye_loss = CrossEntropyLoss(ignore_index=-1) - self.domain_loss = BCEWithLogitsLoss() - elif config.loss_function == 'labelsmoothing': - self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing) - self.goodbye_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing) - self.domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing) - elif config.loss_function == 'bayesianmatching': - self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - self.goodbye_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - self.domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - elif config.loss_function == 'distillation': - self.request_loss = BinaryDistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing) - self.goodbye_loss = DistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing) - self.domain_loss = BinaryDistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing) - - -# Default belief tracker forward pass. -def _nbt_forward(self, turn_embeddings, - turn_pooled_representation, - attention_mask, - batch_size, - dialogue_size, - turn_size, - hidden_state, - inform_labels, - request_labels, - domain_labels, - goodbye_labels, - calculate_inform_mutual_info): - hidden_size = turn_embeddings.size(-1) - # Initialise loss - loss = 0.0 - - # Goodbye predictions - goodbye_probs = None - if self.config.predict_actions: - # General action prediction - goodbye_scores = self.goodbye_gate( - turn_pooled_representation.reshape(batch_size * dialogue_size, hidden_size)) - - # Compute loss for general action predictions (weighted loss) - if goodbye_labels is not None: - if self.config.loss_function == 'distillation': - goodbye_labels = goodbye_labels.reshape(-1, goodbye_labels.size(-1)) - loss += self.goodbye_loss(goodbye_scores, goodbye_labels, self.temp) * self.goodbye_weight - elif self.config.loss_function == 'distribution_distillation': - goodbye_labels = goodbye_labels.reshape(-1, goodbye_labels.size(-2), goodbye_labels.size(-1)) - loss += self.loss(goodbye_scores, goodbye_labels, 1.0, 1.0)[0] * self.goodbye_weight - else: - goodbye_labels = goodbye_labels.reshape(-1) - loss += self.goodbye_loss(goodbye_scores, goodbye_labels) * self.request_weight - - # Compute general action probabilities - if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'distillation', 'distribution_distillation']: - goodbye_probs = torch.softmax(goodbye_scores, -1).reshape(batch_size, dialogue_size, -1) - elif self.config.loss_function in ['bayesianmatching']: - goodbye_probs = dirichlet(goodbye_scores.reshape(batch_size, dialogue_size, -1)) - - # Slot utterance matching - num_slots = self.slot_embeddings.size(0) - slot_embeddings = self.slot_embeddings.reshape(-1, hidden_size) - slot_embeddings = slot_embeddings.unsqueeze(1).repeat((1, batch_size * dialogue_size, 1)).to(turn_embeddings.device) - - if self.config.set_similarity: - # Slot mask shape [num_slots * slot_len, batch_size * dialogue_size, 768] - slot_mask = (slot_embeddings != 0.0).float() - - # Turn embeddings shape [turn_size, batch_size * dialogue_size, 768] - turn_embeddings = turn_embeddings.transpose(0, 1) - # Compute key padding mask - key_padding_mask = (attention_mask[:, :, 0] == 0.0) - key_padding_mask[key_padding_mask[:, 0] == True, :] = False - # Multi head attention of slot over tokens - hidden, _ = self.slot_attention(query=slot_embeddings, - key=turn_embeddings, - value=turn_embeddings, - key_padding_mask=key_padding_mask) # [num_slots, batch_size * dialogue_size, 768] - - # Set embeddings for all masked tokens to 0 - attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1)) - hidden = hidden * attention_mask - if self.config.set_similarity: - hidden = hidden * slot_mask - # Hidden layer shape [num_dials, num_slots, num_turns, 768] - hidden = hidden.transpose(0, 1).reshape(batch_size, dialogue_size, slot_embeddings.size(0), -1).transpose(1, 2) - - # Latent context tracking - # [batch_size * num_slots, dialogue_size, 768] - hidden = hidden.reshape(batch_size * slot_embeddings.size(0), dialogue_size, -1) - - if self.config.nbt_type == 'gru': - self.nbt.flatten_parameters() - if hidden_state is None: - if self.config.rnn_zero_init: - context = torch.zeros(self.config.nbt_layers, batch_size * slot_embeddings.size(0), - self.config.nbt_hidden_size) - context = context.to(turn_embeddings.device) - else: - context = self.belief_init(hidden[:, 0, :]).unsqueeze(0).repeat((self.config.nbt_layers, 1, 1)) - else: - context = hidden_state.to(hidden.device) - - # [batch_size, dialogue_size, nbt_hidden_size] - belief_embedding, context = self.nbt(hidden, context) - elif self.config.nbt_type == 'lstm': - self.nbt.flatten_parameters() - if self.config.rnn_zero_init: - context = (torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size), - torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size)) - context = (context[0].to(turn_embeddings.device), - context[1].to(turn_embeddings.device)) - else: - context = (self.belief_init(hidden[:, 0, :]).unsqueeze(0).repeat((self.config.nbt_layers, 1, 1)), - torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size)) - context = (context[0], context[1].to(turn_embeddings.device)) - - # [batch_size, dialogue_size, nbt_hidden_size] - belief_embedding, context = self.nbt(hidden, context) - - # Decode features - belief_embedding = belief_embedding.reshape(batch_size, slot_embeddings.size(0), dialogue_size, -1).transpose(1, 2) - if self.config.set_similarity: - belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1, - self.config.nbt_hidden_size) - # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768] - # Normalisation and regularisation - belief_embedding = self.layer_norm(self.intermediate(belief_embedding)) - belief_embedding = self.dropout(belief_embedding) - - # Pooling of the set of latent context representation - if self.config.set_similarity: - slot_mask = slot_mask.transpose(0, 1).reshape(batch_size, dialogue_size, num_slots, -1, hidden_size) - belief_embedding = belief_embedding * slot_mask - - # Apply pooler to latent context sequence - if self.config.set_pooling == 'mean': - belief_embedding = belief_embedding.sum(-2) / slot_mask.sum(-2) - belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1) - elif self.config.set_pooling == 'cnn': - belief_embedding = belief_embedding.reshape(-1, slot_mask.size(-2), hidden_size).transpose(1, 2) - belief_embedding = self.conv_pooler(belief_embedding) - # Mean pooling after CNN - belief_embedding = belief_embedding.mean(-1).reshape(batch_size, dialogue_size, num_slots, -1) - elif self.config.set_pooling == 'dan': - # sqrt N reduction - belief_embedding = belief_embedding.sum(-2) / torch.sqrt(torch.tensor(slot_mask.sum(-2))) - # Deep averaging feature extractor - belief_embedding = self.avg_net(belief_embedding) - belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1) - - # Perform classification - if self.config.predict_actions: - # User request prediction - request_probs = dict() - for slot, slot_id in self.requestable_slot_ids.items(): - request_scores = self.request_gate(belief_embedding[:, :, slot_id, :]) - - # Store output probabilities - request_scores = request_scores.reshape(batch_size, dialogue_size) - mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size) - batches, dialogues = torch.where(mask == 0.0) - # Set request scores to 0.0 for padded turns - request_scores[batches, dialogues] = 0.0 - if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'bayesianmatching', - 'distillation', 'distribution_distillation']: - request_probs[slot] = torch.sigmoid(request_scores) - - if request_labels is not None: - # Compute request gate loss - request_scores = request_scores.reshape(-1) - if self.config.loss_function == 'distillation': - loss += self.request_loss(request_scores, request_labels[slot].reshape(-1), - self.temp) * self.request_weight - elif self.config.loss_function == 'distribution_distillation': - scores, labs = convert_probs_to_logits(request_scores, request_labels[slot]) - loss += self.loss(scores, labs, 1.0, 1.0)[0] * self.request_weight - else: - labs = request_labels[slot].reshape(-1) - request_scores = request_scores[labs != -1] - labs = labs[labs != -1].float() - loss += self.request_loss(request_scores, labs) * self.request_weight - - # Active domain prediction - domain_probs = dict() - for domain, slot_ids in self.domain_ids.items(): - belief = belief_embedding[:, :, slot_ids, :] - if len(slot_ids) > 1: - # SqrtN reduction across all slots within a domain - belief = belief.sum(2) / ((belief != 0.0).float().sum(2) ** 0.5) - domain_scores = self.domain_gate(belief) - - # Store output probabilities - domain_scores = domain_scores.reshape(batch_size, dialogue_size) - mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size) - batches, dialogues = torch.where(mask == 0.0) - domain_scores[batches, dialogues] = 0.0 - if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'bayesianmatching', 'distillation', - 'distribution_distillation']: - domain_probs[domain] = torch.sigmoid(domain_scores) - - if domain_labels is not None: - # Compute domain prediction loss - domain_scores = domain_scores.reshape(-1) - if self.config.loss_function == 'distillation': - loss += self.domain_loss(domain_scores, domain_labels[domain].reshape(-1), - self.temp) * self.domain_weight - elif self.config.loss_function == 'distribution_distillation': - scores, labs = convert_probs_to_logits(domain_scores, domain_labels[domain]) - loss += self.loss(scores, labs, 1.0, 1.0)[0] * self.request_weight - else: - labs = domain_labels[domain].reshape(-1) - domain_scores = domain_scores[labs != -1] - labs = labs[labs != -1].float() - loss += self.domain_loss(domain_scores, labs) * self.domain_weight - else: - request_probs, domain_probs = None, None - - # Informable slot predictions - inform_probs = dict() - out_dict = dict() - mutual_info = dict() - stats = dict() - for slot, slot_id in self.informable_slot_ids.items(): - # Get slot belief embedding and value candidates - candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device) - belief = belief_embedding[:, :, slot_id, :] - slot_size = candidate_embeddings.size(0) - - # Use similaroty matching to produce belief state - if self.config.distance_measure in ['cosine', 'euclidean']: - belief = belief.unsqueeze(2).repeat((1, 1, slot_size, 1)) - belief = belief.reshape(-1, self.config.hidden_size) - - # Pooling of set of value candidate description representation - if self.config.set_similarity and self.config.set_pooling == 'mean': - candidate_mask = (candidate_embeddings != 0.0).float() - candidate_embeddings = candidate_embeddings.sum(1) / candidate_mask.sum(1) - elif self.config.set_similarity and self.config.set_pooling == 'cnn': - candidate_embeddings = candidate_embeddings.transpose(1, 2) - candidate_embeddings = self.conv_pooler(candidate_embeddings).mean(-1) - elif self.config.set_similarity and self.config.set_pooling == 'dan': - candidate_mask = (candidate_embeddings != 0.0).float() - candidate_embeddings = candidate_embeddings.sum(1) / torch.sqrt(torch.tensor(candidate_mask.sum(1))) - candidate_embeddings = self.avg_net(candidate_embeddings) - - candidate_embeddings = candidate_embeddings.unsqueeze(0).unsqueeze(0).repeat((batch_size, - dialogue_size, 1, 1)) - candidate_embeddings = candidate_embeddings.reshape(-1, self.config.hidden_size) - - # Score value candidates - if self.config.distance_measure == 'cosine': - scores = self.distance(belief, candidate_embeddings) - # *27 here rescales the cosine similarity for better learning - scores = scores.reshape(batch_size * dialogue_size, -1) * 27.0 - elif self.config.distance_measure == 'euclidean': - scores = -1.0 * self.distance(belief, candidate_embeddings) - scores = scores.reshape(batch_size * dialogue_size, -1) - - # Calculate belief state - if self.config.loss_function in ['crossentropy', 'inhibitedce', - 'labelsmoothing', 'distillation', 'distribution_distillation']: - probs_ = torch.softmax(scores.reshape(batch_size, dialogue_size, -1), -1) - elif self.config.loss_function in ['bayesianmatching']: - probs_ = dirichlet(scores.reshape(batch_size, dialogue_size, -1)) - - # Compute knowledge uncertainty in the beleif states - if calculate_inform_mutual_info and self.config.loss_function == 'distribution_distillation': - mutual_info[slot] = logits_to_mutual_info(scores).reshape(batch_size, dialogue_size) - - # Set padded turn probabilities to zero - mask = attention_mask[self.slot_ids[slot],:, 0].reshape(batch_size, dialogue_size) - batches, dialogues = torch.where(mask == 0.0) - probs_[batches, dialogues, :] = 0.0 - inform_probs[slot] = probs_ - - # Calculate belief state loss - if inform_labels is not None and slot in inform_labels: - if self.config.loss_function == 'bayesianmatching': - prior = torch.ones(scores.size(-1)).float().to(scores.device) - prior = prior * self.config.prior_constant - prior = prior.unsqueeze(0).repeat((scores.size(0), 1)) - - loss += self.loss(scores, inform_labels[slot].reshape(-1), prior=prior) - elif self.config.loss_function == 'distillation': - labels = inform_labels[slot] - labels = labels.reshape(-1, labels.size(-1)) - loss += self.loss(scores, labels, self.temp) - elif self.config.loss_function == 'distribution_distillation': - labels = inform_labels[slot] - labels = labels.reshape(-1, labels.size(-2), labels.size(-1)) - loss_, model_stats, ensemble_stats = self.loss(scores, labels, 1.0, 1.0) - loss += loss_ - - # Calculate stats regarding model precisions - precision = model_stats['precision'] - ensemble_precision = ensemble_stats['precision'] - stats[slot] = {'model_precision_min': precision.min(), - 'model_precision_max': precision.max(), - 'model_precision_mean': precision.mean(), - 'ensemble_precision_min': ensemble_precision.min(), - 'ensemble_precision_max': ensemble_precision.max(), - 'ensemble_precision_mean': ensemble_precision.mean()} - else: - loss += self.loss(scores, inform_labels[slot].reshape(-1)) - - # Return model outputs - out = inform_probs, request_probs, domain_probs, goodbye_probs, context - if inform_labels is not None or request_labels is not None or domain_labels is not None or goodbye_labels is not None: - out = (loss,) + out + (stats,) - if calculate_inform_mutual_info: - out = out + (mutual_info,) - return out - - -# Convert binary scores and labels to 2 class classification problem for distribution distillation -def convert_probs_to_logits(scores, labels): - # Convert single target probability p to distribution [1-p, p] - labels = labels.reshape(-1, labels.size(-1), 1) - labels = torch.cat([1 - labels, labels], -1) - - # Convert input scores into predictive distribution [1-z, z] - scores = torch.sigmoid(scores).unsqueeze(1) - scores = torch.cat((1 - scores, scores), 1) - scores = -1.0 * torch.log((1 / (scores + 1e-8)) - 1) # Inverse sigmoid - - return scores, labels diff --git a/convlab/dst/setsumbt/modeling/roberta_nbt.py b/convlab/dst/setsumbt/modeling/roberta_nbt.py index 36920c5c..f72d17fa 100644 --- a/convlab/dst/setsumbt/modeling/roberta_nbt.py +++ b/convlab/dst/setsumbt/modeling/roberta_nbt.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,16 +16,19 @@ """RoBERTa SetSUMBT""" import torch -import transformers -from torch.autograd import Variable from transformers import RobertaModel, RobertaPreTrainedModel -from convlab.dst.setsumbt.modeling.functional import _initialise, _nbt_forward +from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead class RobertaSetSUMBT(RobertaPreTrainedModel): + """Roberta based SetSUMBT model""" def __init__(self, config): + """ + Args: + config (configuration): Model configuration class + """ super(RobertaSetSUMBT, self).__init__(config) self.config = config @@ -35,60 +38,37 @@ class RobertaSetSUMBT(RobertaPreTrainedModel): for p in self.roberta.parameters(): p.requires_grad = False - _initialise(self, config) + self.setsumbt = SetSUMBTHead(config) + self.add_slot_candidates = self.setsumbt.add_slot_candidates + self.add_value_candidates = self.setsumbt.add_value_candidates + def forward(self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor = None, + hidden_state: torch.Tensor = None, + state_labels: torch.Tensor = None, + request_labels: torch.Tensor = None, + active_domain_labels: torch.Tensor = None, + general_act_labels: torch.Tensor = None, + get_turn_pooled_representation: bool = False, + calculate_state_mutual_info: bool = False): + """ + Args: + input_ids: Input token ids + attention_mask: Input padding mask + token_type_ids: Token type indicator + hidden_state: Latent internal dialogue belief state + state_labels: Dialogue state labels + request_labels: User request action labels + active_domain_labels: Current active domain labels + general_act_labels: General user action labels + get_turn_pooled_representation: Return pooled representation of the current dialogue turn + calculate_state_mutual_info: Return mutual information in the dialogue state - # Add new slot candidates to the model - def add_slot_candidates(self, slot_candidates): - """slot_candidates is a list of tuples for each slot. - - The tuples contains the slot embedding, informable value embeddings and a request indicator. - - If the informable value embeddings is None the slot is not informable - - If the request indicator is false the slot is not requestable""" - if self.slot_embeddings.size(0) != 0: - embeddings = self.slot_embeddings.detach() - else: - embeddings = torch.zeros(0) - - for slot in slot_candidates: - if slot in self.slot_ids: - index = self.slot_ids[slot] - embeddings[index, :] = slot_candidates[slot][0] - else: - index = embeddings.size(0) - emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device) - embeddings = torch.cat((embeddings, emb), 0) - self.slot_ids[slot] = index - setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False)) - # Add slot to relevant requestable and informable slot lists - if slot_candidates[slot][2]: - self.requestable_slot_ids[slot] = index - if slot_candidates[slot][1] is not None: - self.informable_slot_ids[slot] = index - - domain = slot.split('-', 1)[0] - if domain not in self.domain_ids: - self.domain_ids[domain] = [] - self.domain_ids[domain].append(index) - self.domain_ids[domain] = list(set(self.domain_ids[domain])) - - self.slot_embeddings = Variable(embeddings, requires_grad=False) - - - # Add new value candidates to the model - def add_value_candidates(self, slot, value_candidates, replace=False): - embeddings = getattr(self, slot + '_value_embeddings') - - if embeddings.size(0) == 0 or replace: - embeddings = value_candidates - else: - embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0) - - setattr(self, slot + '_value_embeddings', embeddings) - - - def forward(self, input_ids, attention_mask, token_type_ids=None, hidden_state=None, inform_labels=None, - request_labels=None, domain_labels=None, goodbye_labels=None, - get_turn_pooled_representation=False, calculate_inform_mutual_info=False): + Returns: + out: Tuple containing loss, predictive distributions, model statistics and state mutual information + """ if token_type_ids is not None: token_type_ids = None @@ -106,9 +86,10 @@ class RobertaSetSUMBT(RobertaPreTrainedModel): turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1) if get_turn_pooled_representation: - return _nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, dialogue_size, - turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels, - calculate_inform_mutual_info) + (roberta_output.pooler_output,) - return _nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, dialogue_size, - turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels, - calculate_inform_mutual_info) + return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, + batch_size, dialogue_size, hidden_state, state_labels, + request_labels, active_domain_labels, general_act_labels, + calculate_state_mutual_info) + (roberta_output.pooler_output,) + return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, + dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels, + general_act_labels, calculate_state_mutual_info) diff --git a/convlab/dst/setsumbt/modeling/setsumbt.py b/convlab/dst/setsumbt/modeling/setsumbt.py new file mode 100644 index 00000000..0249649f --- /dev/null +++ b/convlab/dst/setsumbt/modeling/setsumbt.py @@ -0,0 +1,564 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SetSUMBT Prediction Head""" + +import torch +from torch.autograd import Variable +from torch.nn import (Module, MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout, + CosineSimilarity, CrossEntropyLoss, PairwiseDistance, + Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss) +from torch.nn.init import (xavier_normal_, constant_) + +from convlab.dst.setsumbt.loss import (BayesianMatchingLoss, BinaryBayesianMatchingLoss, + KLDistillationLoss, BinaryKLDistillationLoss, + LabelSmoothingLoss, BinaryLabelSmoothingLoss, + RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss) + + +class SlotUtteranceMatching(Module): + """Slot Utterance matching attention based information extractor""" + + def __init__(self, hidden_size: int = 768, attention_heads: int = 12): + """ + Args: + hidden_size (int): Dimension of token embeddings + attention_heads (int): Number of attention heads to use in attention module + """ + super(SlotUtteranceMatching, self).__init__() + + self.attention = MultiheadAttention(hidden_size, attention_heads) + + def forward(self, + turn_embeddings: torch.Tensor, + attention_mask: torch.Tensor, + slot_embeddings: torch.Tensor) -> torch.Tensor: + """ + Args: + turn_embeddings: Embeddings for each token in each turn [n_turns, turn_length, hidden_size] + attention_mask: Padding mask for each turn [n_turns, turn_length, hidden_size] + slot_embeddings: Embeddings for each token in the slot descriptions + + Returns: + hidden: Information extracted from turn related to slot descriptions + """ + turn_embeddings = turn_embeddings.transpose(0, 1) + + key_padding_mask = (attention_mask[:, :, 0] == 0.0) + key_padding_mask[key_padding_mask[:, 0], :] = False + + hidden, _ = self.attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings, + key_padding_mask=key_padding_mask) + + attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1)) + hidden = hidden * attention_mask + + return hidden + + +class RecurrentNeuralBeliefTracker(Module): + """Recurrent latent neural belief tracking module""" + + def __init__(self, + nbt_type: str = 'gru', + rnn_zero_init: bool = False, + input_size: int = 768, + hidden_size: int = 300, + hidden_layers: int = 1, + dropout_rate: float = 0.3): + """ + Args: + nbt_type: Type of recurrent neural network (gru/lstm) + rnn_zero_init: Use zero initialised state for the RNN + input_size: Embedding size of the inputs + hidden_size: Hidden size of the RNN + hidden_layers: Number of RNN Layers + dropout_rate: Dropout rate + """ + super(RecurrentNeuralBeliefTracker, self).__init__() + + if rnn_zero_init: + self.belief_init = Sequential(Linear(input_size, hidden_size), ReLU(), Dropout(dropout_rate)) + else: + self.belief_init = None + + self.nbt_type = nbt_type + self.hidden_layers = hidden_layers + self.hidden_size = hidden_size + if nbt_type == 'gru': + self.nbt = GRU(input_size=input_size, + hidden_size=hidden_size, + num_layers=hidden_layers, + dropout=0.0 if hidden_layers == 1 else dropout_rate, + batch_first=True) + elif nbt_type == 'lstm': + self.nbt = LSTM(input_size=input_size, + hidden_size=hidden_size, + num_layers=hidden_layers, + dropout=0.0 if hidden_layers == 1 else dropout_rate, + batch_first=True) + else: + raise NameError('Not Implemented') + + # Initialise Parameters + xavier_normal_(self.nbt.weight_ih_l0) + xavier_normal_(self.nbt.weight_hh_l0) + constant_(self.nbt.bias_ih_l0, 0.0) + constant_(self.nbt.bias_hh_l0, 0.0) + + # Intermediate feature mapping and layer normalisation + self.intermediate = Linear(hidden_size, input_size) + self.layer_norm = LayerNorm(input_size) + self.dropout = Dropout(dropout_rate) + + def forward(self, inputs: torch.Tensor, hidden_state: torch.Tensor = None) -> torch.Tensor: + """ + Args: + inputs: Latent turn level information + hidden_state: Latent internal belief state + + Returns: + belief_embedding: Belief state embeddings + context: Latent internal belief state + """ + self.nbt.flatten_parameters() + if hidden_state is None: + if self.belief_init is None: + context = torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device) + else: + context = self.belief_init(inputs[:, 0, :]).unsqueeze(0).repeat((self.hidden_layers, 1, 1)) + if self.nbt_type == "lstm": + context = (context, torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device)) + else: + context = hidden_state.to(inputs.device) + + # [batch_size, dialogue_size, nbt_hidden_size] + belief_embedding, context = self.nbt(inputs, context) + + # Normalisation and regularisation + belief_embedding = self.layer_norm(self.intermediate(belief_embedding)) + belief_embedding = self.dropout(belief_embedding) + + return belief_embedding, context + + +class SetPooler(Module): + """Token set pooler""" + + def __init__(self, pooling_strategy: str = 'cnn', hidden_size: int = 768): + """ + Args: + pooling_strategy: Type of set pooler (cnn/dan/mean) + hidden_size: Token embedding size + """ + super(SetPooler, self).__init__() + + self.pooling_strategy = pooling_strategy + if pooling_strategy == 'cnn': + self.cnn_filter_size = 3 + self.pooler = Conv1d(hidden_size, hidden_size, self.cnn_filter_size) + elif pooling_strategy == 'dan': + self.pooler = Sequential(Linear(hidden_size, hidden_size), GELU(), Linear(2 * hidden_size, hidden_size)) + + def forward(self, inputs, attention_mask): + """ + Args: + inputs: Token set embeddings + attention_mask: Padding mask for the set of tokens + + Returns: + + """ + if self.pooling_strategy == "mean": + hidden = inputs.sum(1) / attention_mask.sum(1) + elif self.pooling_strategy == "cnn": + hidden = self.pooler(inputs.transpose(1, 2)).mean(-1) + elif self.pooling_strategy == 'dan': + hidden = inputs.sum(1) / torch.sqrt(torch.tensor(attention_mask.sum(1))) + hidden = self.pooler(hidden) + + return hidden + + +class SetSUMBTHead(Module): + """SetSUMBT Prediction Head for Language Models""" + + def __init__(self, config): + """ + Args: + config (configuration): Model configuration class + """ + super(SetSUMBTHead, self).__init__() + self.config = config + # Slot Utterance matching attention + self.slot_utterance_matching = SlotUtteranceMatching(config.hidden_size, config.slot_attention_heads) + + # Latent context tracker + self.nbt = RecurrentNeuralBeliefTracker(config.nbt_type, config.rnn_zero_init, config.hidden_size, + config.nbt_hidden_size, config.nbt_layers, config.dropout_rate) + + # Set pooler for set similarity model + if self.config.set_similarity: + self.set_pooler = SetPooler(config.set_pooling, config.hidden_size) + + # Model ontology placeholders + self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False) + self.slot_ids = dict() + self.requestable_slot_ids = dict() + self.informable_slot_ids = dict() + self.domain_ids = dict() + + # Matching network similarity measure + if config.distance_measure == 'cosine': + self.distance = CosineSimilarity(dim=-1, eps=1e-8) + elif config.distance_measure == 'euclidean': + self.distance = PairwiseDistance(p=2.0, eps=1e-6, keepdim=False) + else: + raise NameError('NotImplemented') + + # User goal prediction loss function + if config.loss_function == 'crossentropy': + self.loss = CrossEntropyLoss(ignore_index=-1) + elif config.loss_function == 'bayesianmatching': + self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) + elif config.loss_function == 'labelsmoothing': + self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing) + elif config.loss_function == 'distillation': + self.loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) + self.temp = 1.0 + elif config.loss_function == 'distribution_distillation': + self.loss = RKLDirichletMediatorLoss(ignore_index=-1) + else: + raise NameError('NotImplemented') + + # Intent and domain prediction heads + if config.predict_actions: + self.request_gate = Linear(config.hidden_size, 1) + self.general_act_gate = Linear(config.hidden_size, 3) + self.active_domain_gate = Linear(config.hidden_size, 1) + + # Intent and domain loss function + self.request_weight = float(self.config.user_request_loss_weight) + self.general_act_weight = float(self.config.user_general_act_loss_weight) + self.active_domain_weight = float(self.config.active_domain_loss_weight) + if config.loss_function == 'crossentropy': + self.request_loss = BCEWithLogitsLoss() + self.general_act_loss = CrossEntropyLoss(ignore_index=-1) + self.active_domain_loss = BCEWithLogitsLoss() + elif config.loss_function == 'labelsmoothing': + self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing) + self.general_act_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing) + self.active_domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing) + elif config.loss_function == 'bayesianmatching': + self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) + self.general_act_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) + self.active_domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) + elif config.loss_function == 'distillation': + self.request_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) + self.general_act_loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) + self.active_domain_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) + elif config.loss_function == 'distribution_distillation': + self.request_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1) + self.general_act_loss = RKLDirichletMediatorLoss(ignore_index=-1) + self.active_domain_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1) + + def add_slot_candidates(self, slot_candidates: tuple): + """ + Add slots to the model ontology, the tuples should contain the slot embedding, informable value embeddings + and a request indicator, if the informable value embeddings is None the slot is not informable and if + the request indicator is false the slot is not requestable. + + Args: + slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator + """ + if self.slot_embeddings.size(0) != 0: + embeddings = self.slot_embeddings.detach() + else: + embeddings = torch.zeros(0) + + for slot in slot_candidates: + if slot in self.slot_ids: + index = self.slot_ids[slot] + embeddings[index, :] = slot_candidates[slot][0] + else: + index = embeddings.size(0) + emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device) + embeddings = torch.cat((embeddings, emb), 0) + self.slot_ids[slot] = index + setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False)) + # Add slot to relevant requestable and informable slot lists + if slot_candidates[slot][2]: + self.requestable_slot_ids[slot] = index + if slot_candidates[slot][1] is not None: + self.informable_slot_ids[slot] = index + + domain = slot.split('-', 1)[0] + if domain not in self.domain_ids: + self.domain_ids[domain] = [] + self.domain_ids[domain].append(index) + self.domain_ids[domain] = list(set(self.domain_ids[domain])) + + self.slot_embeddings = Variable(embeddings, requires_grad=False) + + def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False): + """ + Add value candidates for a slot + + Args: + slot: Slot name + value_candidates: Value candidate embeddings + replace: If true existing value candidates are replaced + """ + embeddings = getattr(self, slot + '_value_embeddings') + + if embeddings.size(0) == 0 or replace: + embeddings = value_candidates + else: + embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0) + + setattr(self, slot + '_value_embeddings', embeddings) + + def forward(self, + turn_embeddings: torch.Tensor, + turn_pooled_representation: torch.Tensor, + attention_mask: torch.Tensor, + batch_size: int, + dialogue_size: int, + hidden_state: torch.Tensor = None, + state_labels: torch.Tensor = None, + request_labels: torch.Tensor = None, + active_domain_labels: torch.Tensor = None, + general_act_labels: torch.Tensor = None, + calculate_state_mutual_info: bool = False): + """ + Args: + turn_embeddings: Token embeddings in the current turn + turn_pooled_representation: Pooled representation of the current dialogue turn + attention_mask: Padding mask for the current dialogue turn + batch_size: Number of dialogues in the batch + dialogue_size: Number of turns in each dialogue + hidden_state: Latent internal dialogue belief state + state_labels: Dialogue state labels + request_labels: User request action labels + active_domain_labels: Current active domain labels + general_act_labels: General user action labels + calculate_state_mutual_info: Return mutual information in the dialogue state + + Returns: + out: Tuple containing loss, predictive distributions, model statistics and state mutual information + """ + hidden_size = turn_embeddings.size(-1) + # Initialise loss + loss = 0.0 + + # General Action predictions + general_act_probs = None + if self.config.predict_actions: + # General action prediction + general_act_logits = self.general_act_gate(turn_pooled_representation.reshape(batch_size * dialogue_size, + hidden_size)) + + # Compute loss for general action predictions (weighted loss) + if general_act_labels is not None: + if self.config.loss_function == 'distillation': + general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-1)) + loss += self.general_act_loss(general_act_logits, general_act_labels, + self.temp) * self.general_act_weight + elif self.config.loss_function == 'distribution_distillation': + general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-2), + general_act_labels.size(-1)) + loss += self.general_act_loss(general_act_logits, general_act_labels)[0] * self.general_act_weight + else: + general_act_labels = general_act_labels.reshape(-1) + loss += self.general_act_loss(general_act_logits, general_act_labels) * self.general_act_weight + + # Compute general action probabilities + general_act_probs = torch.softmax(general_act_logits, -1).reshape(batch_size, dialogue_size, -1) + + # Slot utterance matching + num_slots = self.slot_embeddings.size(0) + slot_embeddings = self.slot_embeddings.reshape(-1, hidden_size) + slot_embeddings = slot_embeddings.unsqueeze(1).repeat((1, batch_size * dialogue_size, 1)) + slot_embeddings = slot_embeddings.to(turn_embeddings.device) + + if self.config.set_similarity: + # Slot mask shape [num_slots * slot_len, batch_size * dialogue_size, 768] + slot_mask = (slot_embeddings != 0.0).float() + + hidden = self.slot_utterance_matching(turn_embeddings, attention_mask, slot_embeddings) + + if self.config.set_similarity: + hidden = hidden * slot_mask + # Hidden layer shape [num_dials, num_slots, num_turns, 768] + hidden = hidden.transpose(0, 1).reshape(batch_size, dialogue_size, slot_embeddings.size(0), -1).transpose(1, 2) + + # Latent context tracking + # [batch_size * num_slots, dialogue_size, 768] + hidden = hidden.reshape(batch_size * slot_embeddings.size(0), dialogue_size, -1) + belief_embedding, hidden_state = self.nbt(hidden, hidden_state) + + belief_embedding = belief_embedding.reshape(batch_size, slot_embeddings.size(0), + dialogue_size, -1).transpose(1, 2) + if self.config.set_similarity: + belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1, + self.config.hidden_size) + # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768] + + # Pooling of the set of latent context representation + if self.config.set_similarity: + slot_mask = slot_mask.transpose(0, 1).reshape(batch_size, dialogue_size, num_slots, -1, hidden_size) + belief_embedding = belief_embedding * slot_mask + + belief_embedding = self.set_pooler(belief_embedding.reshape(-1, slot_mask.size(-2), hidden_size), + slot_mask.reshape(-1, slot_mask.size(-2), hidden_size)) + belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1) + + # Perform classification + # Get padded batch, dialogue idx pairs + batches, dialogues = torch.where(attention_mask[:, 0, 0].reshape(batch_size, dialogue_size) == 0.0) + + if self.config.predict_actions: + # User request prediction + request_probs = dict() + for slot, slot_id in self.requestable_slot_ids.items(): + request_logits = self.request_gate(belief_embedding[:, :, slot_id, :]) + + # Store output probabilities + request_logits = request_logits.reshape(batch_size, dialogue_size) + # Set request scores to 0.0 for padded turns + request_logits[batches, dialogues] = 0.0 + request_probs[slot] = torch.sigmoid(request_logits) + + if request_labels is not None: + # Compute request gate loss + request_logits = request_logits.reshape(-1) + if self.config.loss_function == 'distillation': + loss += self.request_loss(request_logits, request_labels[slot].reshape(-1), + self.temp) * self.request_weight + elif self.config.loss_function == 'distribution_distillation': + loss += self.request_loss(request_logits, request_labels[slot])[0] * self.request_weight + else: + labs = request_labels[slot].reshape(-1) + request_logits = request_logits[labs != -1] + labs = labs[labs != -1].float() + loss += self.request_loss(request_logits, labs) * self.request_weight + + # Active domain prediction + active_domain_probs = dict() + for domain, slot_ids in self.domain_ids.items(): + belief = belief_embedding[:, :, slot_ids, :] + if len(slot_ids) > 1: + # SqrtN reduction across all slots within a domain + belief = belief.sum(2) / ((belief != 0.0).float().sum(2) ** 0.5) + active_domain_logits = self.active_domain_gate(belief) + + # Store output probabilities + active_domain_logits = active_domain_logits.reshape(batch_size, dialogue_size) + active_domain_logits[batches, dialogues] = 0.0 + active_domain_probs[domain] = torch.sigmoid(active_domain_logits) + + if active_domain_labels is not None and domain in active_domain_labels: + # Compute domain prediction loss + active_domain_logits = active_domain_logits.reshape(-1) + if self.config.loss_function == 'distillation': + loss += self.active_domain_loss(active_domain_logits, active_domain_labels[domain].reshape(-1), + self.temp) * self.active_domain_weight + elif self.config.loss_function == 'distribution_distillation': + loss += self.active_domain_loss(active_domain_logits, + active_domain_labels[domain])[0] * self.active_domain_weight + else: + labs = active_domain_labels[domain].reshape(-1) + active_domain_logits = active_domain_logits[labs != -1] + labs = labs[labs != -1].float() + loss += self.active_domain_loss(active_domain_logits, labs) * self.active_domain_weight + else: + request_probs, active_domain_probs = None, None + + # Dialogue state predictions + belief_state_probs = dict() + belief_state_mutual_info = dict() + belief_state_stats = dict() + for slot, slot_id in self.informable_slot_ids.items(): + # Get slot belief embedding and value candidates + candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device) + belief = belief_embedding[:, :, slot_id, :] + slot_size = candidate_embeddings.size(0) + + belief = belief.unsqueeze(2).repeat((1, 1, slot_size, 1)) + belief = belief.reshape(-1, self.config.hidden_size) + + if self.config.set_similarity: + candidate_embeddings = self.set_pooler(candidate_embeddings, (candidate_embeddings != 0.0).float()) + candidate_embeddings = candidate_embeddings.unsqueeze(0).unsqueeze(0).repeat((batch_size, + dialogue_size, 1, 1)) + candidate_embeddings = candidate_embeddings.reshape(-1, self.config.hidden_size) + + # Score value candidates + if self.config.distance_measure == 'cosine': + logits = self.distance(belief, candidate_embeddings) + # *27 here rescales the cosine similarity for better learning + logits = logits.reshape(batch_size * dialogue_size, -1) * 27.0 + elif self.config.distance_measure == 'euclidean': + logits = -1.0 * self.distance(belief, candidate_embeddings) + logits = logits.reshape(batch_size * dialogue_size, -1) + + # Calculate belief state + probs_ = torch.softmax(logits.reshape(batch_size, dialogue_size, -1), -1) + + # Compute knowledge uncertainty in the beleif states + if calculate_state_mutual_info and self.config.loss_function == 'distribution_distillation': + belief_state_mutual_info[slot] = self.loss.logits_to_mutual_info(logits).reshape(batch_size, dialogue_size) + + # Set padded turn probabilities to zero + probs_[batches, dialogues, :] = 0.0 + belief_state_probs[slot] = probs_ + + # Calculate belief state loss + if state_labels is not None and slot in state_labels: + if self.config.loss_function == 'bayesianmatching': + prior = torch.ones(logits.size(-1)).float().to(logits.device) + prior = prior * self.config.prior_constant + prior = prior.unsqueeze(0).repeat((logits.size(0), 1)) + + loss += self.loss(logits, state_labels[slot].reshape(-1), prior=prior) + elif self.config.loss_function == 'distillation': + labels = state_labels[slot] + labels = labels.reshape(-1, labels.size(-1)) + loss += self.loss(logits, labels, self.temp) + elif self.config.loss_function == 'distribution_distillation': + labels = state_labels[slot] + labels = labels.reshape(-1, labels.size(-2), labels.size(-1)) + loss_, model_stats, ensemble_stats = self.loss(logits, labels) + loss += loss_ + + # Calculate stats regarding model precisions + precision = model_stats['precision'] + ensemble_precision = ensemble_stats['precision'] + belief_state_stats[slot] = {'model_precision_min': precision.min(), + 'model_precision_max': precision.max(), + 'model_precision_mean': precision.mean(), + 'ensemble_precision_min': ensemble_precision.min(), + 'ensemble_precision_max': ensemble_precision.max(), + 'ensemble_precision_mean': ensemble_precision.mean()} + else: + loss += self.loss(logits, state_labels[slot].reshape(-1)) + + # Return model outputs + out = belief_state_probs, request_probs, active_domain_probs, general_act_probs, hidden_state + if state_labels is not None or request_labels is not None: + out = (loss,) + out + (belief_state_stats,) + if calculate_state_mutual_info: + out = out + (belief_state_mutual_info,) + return out diff --git a/convlab/dst/setsumbt/modeling/temperature_scheduler.py b/convlab/dst/setsumbt/modeling/temperature_scheduler.py index fab205be..654e83c5 100644 --- a/convlab/dst/setsumbt/modeling/temperature_scheduler.py +++ b/convlab/dst/setsumbt/modeling/temperature_scheduler.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,50 +13,70 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Temperature Scheduler Class""" -import torch +"""Linear Temperature Scheduler Class""" + # Temp scheduler class for ensemble distillation -class TemperatureScheduler: +class LinearTemperatureScheduler: + """ + Temperature scheduler object used for distribution temperature scheduling in distillation - def __init__(self, total_steps, base_temp=2.5, cycle_len=0.1): - self.state = {} + Attributes: + state (dict): Internal state of scheduler + """ + def __init__(self, + total_steps: int, + base_temp: float = 2.5, + cycle_len: float = 0.1): + """ + Args: + total_steps (int): Total number of training steps + base_temp (float): Starting temperature + cycle_len (float): Fraction of total steps used for scheduling cycle + """ + self.state = dict() self.state['total_steps'] = total_steps self.state['current_step'] = 0 self.state['base_temp'] = base_temp self.state['current_temp'] = base_temp self.state['cycles'] = [int(total_steps * cycle_len / 2), int(total_steps * cycle_len)] + self.state['rate'] = (self.state['base_temp'] - 1.0) / (self.state['cycles'][1] - self.state['cycles'][0]) def step(self): + """ + Update temperature based on the schedule + """ self.state['current_step'] += 1 assert self.state['current_step'] <= self.state['total_steps'] if self.state['current_step'] > self.state['cycles'][0]: if self.state['current_step'] < self.state['cycles'][1]: - rate = (self.state['base_temp'] - 1.0) / (self.state['cycles'][1] - self.state['cycles'][0]) - self.state['current_temp'] -= rate + self.state['current_temp'] -= self.state['rate'] else: self.state['current_temp'] = 1.0 def temp(self): + """ + Get current temperature + + Returns: + temp (float): Current temperature for distribution scaling + """ return float(self.state['current_temp']) def state_dict(self): - return self.state - - def load_state_dict(self, sd): - self.state = sd + """ + Return scheduler state - -# if __name__ == "__main__": -# temp_scheduler = TemperatureScheduler(100) -# print(temp_scheduler.state_dict()) - -# temp = [] -# for i in range(100): -# temp.append(temp_scheduler.temp()) -# temp_scheduler.step() + Returns: + state (dict): Dictionary format state of the scheduler + """ + return self.state -# temp_scheduler.load_state_dict(temp_scheduler.state_dict()) -# print(temp_scheduler.state_dict()) + def load_state_dict(self, state_dict: dict): + """ + Load scheduler state from dictionary -# print(temp) + Args: + state_dict (dict): Dictionary format state of the scheduler + """ + self.state = state_dict diff --git a/convlab/dst/setsumbt/modeling/training.py b/convlab/dst/setsumbt/modeling/training.py index 259c6e1d..77f41dc3 100644 --- a/convlab/dst/setsumbt/modeling/training.py +++ b/convlab/dst/setsumbt/modeling/training.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,17 +13,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Training utils""" +"""Training and evaluation utils""" import random import os import logging +from copy import deepcopy import torch from torch.nn import DataParallel from torch.distributions import Categorical import numpy as np -from transformers import AdamW, get_linear_schedule_with_warmup +from transformers import get_linear_schedule_with_warmup +from torch.optim import AdamW from tqdm import tqdm, trange try: from apex import amp @@ -31,7 +33,7 @@ except: print('Apex not used') from convlab.dst.setsumbt.utils import clear_checkpoints -from convlab.dst.setsumbt.modeling.temperature_scheduler import TemperatureScheduler +from convlab.dst.setsumbt.modeling import LinearTemperatureScheduler # Load logger and tensorboard summary writer @@ -59,18 +61,131 @@ def set_ontology_embeddings(model, slots, load_slots=True): if load_slots: slots = {slot: embs for slot, embs in slots.items()} model.add_slot_candidates(slots) - for slot in model.informable_slot_ids: + try: + informable_slot_ids = model.setsumbt.informable_slot_ids + except: + informable_slot_ids = model.informable_slot_ids + for slot in informable_slot_ids: model.add_value_candidates(slot, values[slot], replace=True) -def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_dev, embeddings=None, tokenizer=None): - """Train model!""" +def log_info(global_step, loss, jg_acc=None, sl_acc=None, req_f1=None, dom_f1=None, gen_f1=None, stats=None): + """ + Log training statistics. + + Args: + global_step: Number of global training steps completed + loss: Training loss + jg_acc: Joint goal accuracy + sl_acc: Slot accuracy + req_f1: Request prediction F1 score + dom_f1: Active domain prediction F1 score + gen_f1: General action prediction F1 score + stats: Uncertainty measure statistics of model + """ + if type(global_step) == int: + info = f"{global_step} steps complete, " + info += f"Loss since last update: {loss}. Validation set stats: " + elif global_step == 'training_complete': + info = f"Training Complete, " + info += f"Validation set stats: " + elif global_step == 'dev': + info = f"Validation set stats: Loss: {loss}, " + elif global_step == 'test': + info = f"Test set stats: Loss: {loss}, " + info += f"Joint Goal Acc: {jg_acc}, Slot Acc: {sl_acc}, " + if req_f1 is not None: + info += f"Request F1 Score: {req_f1}, Active Domain F1 Score: {dom_f1}, " + info += f"General Action F1 Score: {gen_f1}" + logger.info(info) + + if type(global_step) == int: + tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step) + tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step) + if req_f1 is not None: + tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step) + tb_writer.add_scalar('ActiveDomainF1Score/Dev', dom_f1, global_step) + tb_writer.add_scalar('GeneralActionF1Score/Dev', gen_f1, global_step) + tb_writer.add_scalar('Loss/Dev', loss, global_step) + + if stats: + for slot, stats_slot in stats.items(): + for key, item in stats_slot.items(): + tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step) + + +def get_input_dict(batch: dict, + predict_actions: bool, + model_informable_slot_ids: list, + model_requestable_slot_ids: list = None, + model_domain_ids: list = None, + device = 'cpu') -> dict: + """ + Produce model input arguments + + Args: + batch: Batch of data from the dataloader + predict_actions: Model should predict user actions if set true + model_informable_slot_ids: List of model dialogue state slots + model_requestable_slot_ids: List of model requestable slots + model_domain_ids: List of model domains + device: Current torch device in use + + Returns: + input_dict: Dictrionary containing model inputs for the batch + """ + input_dict = dict() + + input_dict['input_ids'] = batch['input_ids'].to(device) + input_dict['token_type_ids'] = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None + input_dict['attention_mask'] = batch['attention_mask'].to(device) if 'attention_mask' in batch else None + + if any('belief_state' in key for key in batch): + input_dict['state_labels'] = {slot: batch['belief_state-' + slot].to(device) + for slot in model_informable_slot_ids + if ('belief_state-' + slot) in batch} + if predict_actions: + input_dict['request_labels'] = {slot: batch['request_probs-' + slot].to(device) + for slot in model_requestable_slot_ids + if ('request_probs-' + slot) in batch} + input_dict['active_domain_labels'] = {domain: batch['active_domain_probs-' + domain].to(device) + for domain in model_domain_ids + if ('active_domain_probs-' + domain) in batch} + input_dict['general_act_labels'] = batch['general_act_probs'].to(device) + else: + input_dict['state_labels'] = {slot: batch['state_labels-' + slot].to(device) + for slot in model_informable_slot_ids if ('state_labels-' + slot) in batch} + if predict_actions: + input_dict['request_labels'] = {slot: batch['request_labels-' + slot].to(device) + for slot in model_requestable_slot_ids + if ('request_labels-' + slot) in batch} + input_dict['active_domain_labels'] = {domain: batch['active_domain_labels-' + domain].to(device) + for domain in model_domain_ids + if ('active_domain_labels-' + domain) in batch} + input_dict['general_act_labels'] = batch['general_act_labels'].to(device) + + return input_dict + + +def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, slots_dev: dict): + """ + Train the SetSUMBT model. + + Args: + args: Runtime arguments + model: SetSUMBT Model instance to train + device: Torch device to use during training + train_dataloader: Dataloader containing the training data + dev_dataloader: Dataloader containing the validation set data + slots: Model ontology used for training + slots_dev: Model ontology used for evaluating on the validation set + """ # Calculate the total number of training steps to be performed if args.max_training_steps > 0: t_total = args.max_training_steps - args.num_train_epochs = args.max_training_steps // ( - (len(train_dataloader) // args.gradient_accumulation_steps) + 1) + args.num_train_epochs = (len(train_dataloader) // args.gradient_accumulation_steps) + 1 + args.num_train_epochs = args.max_training_steps // args.num_train_epochs else: t_total = (len(train_dataloader) // args.gradient_accumulation_steps) * args.num_train_epochs @@ -88,12 +203,12 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de { "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0, - "lr":args.learning_rate + "lr": args.learning_rate }, ] # Initialise the optimizer - optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False) + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) # Initialise linear lr scheduler num_warmup_steps = int(t_total * args.warmup_proportion) @@ -109,8 +224,7 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de # Set up fp16 and multi gpu usage if args.fp16: - model, optimizer = amp.initialize( - model, optimizer, opt_level=args.fp16_opt_level) + model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) if args.n_gpu > 1: model = DataParallel(model) @@ -118,7 +232,7 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de best_model = {'joint goal accuracy': 0.0, 'request f1 score': 0.0, 'active domain f1 score': 0.0, - 'goodbye act f1 score': 0.0, + 'general act f1 score': 0.0, 'train loss': np.inf} if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')): logger.info("Optimizer loaded from previous run.") @@ -136,27 +250,27 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de model.eval() set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False) - jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader) + jg_acc, sl_acc, req_f1, dom_f1, gen_f1, _, _ = evaluate(args, model, device, dev_dataloader, is_train=True) # Set model back to training mode model.train() model.zero_grad() set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False) else: - jg_acc, req_f1, dom_f1, bye_f1 = 0.0, 0.0, 0.0, 0.0 + jg_acc, req_f1, dom_f1, gen_f1 = 0.0, 0.0, 0.0, 0.0 best_model['joint goal accuracy'] = jg_acc best_model['request f1 score'] = req_f1 best_model['active domain f1 score'] = dom_f1 - best_model['goodbye act f1 score'] = bye_f1 + best_model['general act f1 score'] = gen_f1 # Log training set up - logger.info("Device: %s, Number of GPUs: %s, FP16 training: %s" % (device, args.n_gpu, args.fp16)) + logger.info(f"Device: {device}, Number of GPUs: {args.n_gpu}, FP16 training: {args.fp16}") logger.info("***** Running training *****") - logger.info(" Num Batches = %d" % len(train_dataloader)) - logger.info(" Num Epochs = %d" % args.num_train_epochs) - logger.info(" Gradient Accumulation steps = %d" % args.gradient_accumulation_steps) - logger.info(" Total optimization steps = %d" % t_total) + logger.info(f" Num Batches = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {t_total}") # Initialise training parameters global_step = 0 @@ -173,11 +287,11 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) logger.info(" Continuing training from checkpoint, will skip to saved global_step") - logger.info(" Continuing training from epoch %d" % epochs_trained) - logger.info(" Continuing training from global step %d" % global_step) - logger.info(" Will skip the first %d steps in the first epoch" % steps_trained_in_current_epoch) + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {global_step}") + logger.info(f" Will skip the first {steps_trained_in_current_epoch} steps in the first epoch") except ValueError: - logger.info(" Starting fine-tuning.") + logger.info(f" Starting fine-tuning.") # Prepare model for training tr_loss, logging_loss = 0.0, 0.0 @@ -196,43 +310,15 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de continue # Extract all label dictionaries from the batch - if 'goodbye_belief' in batch: - labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids - if ('belief-' + slot) in batch} - request_labels = {slot: batch['request_belief-' + slot].to(device) - for slot in model.requestable_slot_ids - if ('request_belief-' + slot) in batch} if args.predict_actions else None - domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids - if ('domain_belief-' + domain) in batch} if args.predict_actions else None - goodbye_labels = batch['goodbye_belief'].to( - device) if args.predict_actions else None - else: - labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids - if ('labels-' + slot) in batch} - request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids - if ('request-' + slot) in batch} if args.predict_actions else None - domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids - if ('active-' + domain) in batch} if args.predict_actions else None - goodbye_labels = batch['goodbye'].to( - device) if args.predict_actions else None - - # Extract all model inputs from batch - input_ids = batch['input_ids'].to(device) - token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None - attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None + input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids, + model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device) # Set up temperature scaling for the model if temp_scheduler is not None: - model.temp = temp_scheduler.temp() + model.setsumbt.temp = temp_scheduler.temp() # Forward pass to obtain loss - loss, _, _, _, _, _, stats = model(input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - inform_labels=labels, - request_labels=request_labels, - domain_labels=domain_labels, - goodbye_labels=goodbye_labels) + loss, _, _, _, _, _, stats = model(**input_dict) if args.n_gpu > 1: loss = loss.mean() @@ -258,7 +344,6 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de tb_writer.add_scalar('LearningRate', lr, global_step) if stats: - # print(stats.keys()) for slot, stats_slot in stats.items(): for key, item in stats_slot.items(): tb_writer.add_scalar(f'{key}_{slot}/Train', item, global_step) @@ -273,7 +358,6 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de tr_loss += loss.float().item() epoch_iterator.set_postfix(loss=loss.float().item()) - loss = 0.0 global_step += 1 # Save model checkpoint @@ -286,52 +370,34 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de model.eval() set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False) - jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader) + jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader, + is_train=True) # Log model eval information - if req_f1 is not None: - logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f, Dev Request F1 Score = %f, Dev Domain F1 Score = %f, Dev Goodbye F1 Score = %f' - % (global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, bye_f1)) - tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step) - tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step) - tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step) - tb_writer.add_scalar('DomainF1Score/Dev', dom_f1, global_step) - tb_writer.add_scalar('GoodbyeF1Score/Dev', bye_f1, global_step) - else: - logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f' - % (global_step, logging_loss / args.save_steps, jg_acc, sl_acc)) - tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step) - tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step) - tb_writer.add_scalar('Loss/Dev', loss, global_step) - if stats: - for slot, stats_slot in stats.items(): - for key, item in stats_slot.items(): - tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step) + log_info(global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, gen_f1, stats) # Set model back to training mode model.train() model.zero_grad() set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False) else: - jg_acc, req_f1 = 0.0, None - logger.info('%i steps complete, Loss since last update = %f' % (global_step, logging_loss / args.save_steps)) + log_info(global_step, logging_loss / args.save_steps) logging_loss = tr_loss # Compute the score of the best model try: - best_score = (best_model['request f1 score'] * model.config.user_request_loss_weight) + \ - (best_model['active domain f1 score'] * model.config.active_domain_loss_weight) + \ - (best_model['goodbye act f1 score'] * - model.config.user_general_act_loss_weight) + best_score = best_model['request f1 score'] * model.config.user_request_loss_weight + best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight + best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight except AttributeError: best_score = 0.0 best_score += best_model['joint goal accuracy'] # Compute the score of the current model try: - current_score = (req_f1 * model.config.user_request_loss_weight) + \ - (dom_f1 * model.config.active_domain_loss_weight) + \ - (bye_f1 * model.config.user_general_act_loss_weight) if req_f1 is not None else 0.0 + current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0 + current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0 + current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0 except AttributeError: current_score = 0.0 current_score += jg_acc @@ -353,10 +419,10 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de if req_f1: best_model['request f1 score'] = req_f1 best_model['active domain f1 score'] = dom_f1 - best_model['goodbye act f1 score'] = bye_f1 + best_model['general act f1 score'] = gen_f1 best_model['train loss'] = tr_loss / global_step - output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) + output_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}") if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) @@ -386,14 +452,15 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de epoch_iterator.close() break - logger.info('Epoch %i complete, average training loss = %f' % (e + 1, tr_loss / global_step)) + steps_trained_in_current_epoch = 0 + logger.info(f'Epoch {e + 1} complete, average training loss = {tr_loss / global_step}') if args.max_training_steps > 0 and global_step > args.max_training_steps: train_iterator.close() break if args.patience > 0 and steps_since_last_update >= args.patience: train_iterator.close() - logger.info('Model has not improved for at least %i steps. Training stopped!' % args.patience) + logger.info(f'Model has not improved for at least {args.patience} steps. Training stopped!') break # Evaluate final model @@ -401,30 +468,25 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de model.eval() set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False) - jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader) - if req_f1 is not None: - logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f, Dev Request F1 Score = %f, Dev Domain F1 Score = %f, Dev Goodbye F1 Score = %f' - % (tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, bye_f1)) - else: - logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f' - % (tr_loss / global_step, jg_acc, sl_acc)) + jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader, + is_train=True) + + log_info('training_complete', tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, gen_f1) else: - jg_acc = 0.0 logger.info('Training complete!') # Store final model try: - best_score = (best_model['request f1 score'] * model.config.user_request_loss_weight) + \ - (best_model['active domain f1 score'] * model.config.active_domain_loss_weight) + \ - (best_model['goodbye act f1 score'] * - model.config.user_general_act_loss_weight) + best_score = best_model['request f1 score'] * model.config.user_request_loss_weight + best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight + best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight except AttributeError: best_score = 0.0 best_score += best_model['joint goal accuracy'] try: - current_score = (req_f1 * model.config.user_request_loss_weight) + \ - (dom_f1 * model.config.active_domain_loss_weight) + \ - (bye_f1 * model.config.user_general_act_loss_weight) if req_f1 is not None else 0.0 + current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0 + current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0 + current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0 except AttributeError: current_score = 0.0 current_score += jg_acc @@ -456,225 +518,89 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_de torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt")) clear_checkpoints(args.output_dir) else: - logger.info( - 'Final model not saved, since it is not the best performing model.') + logger.info('Final model not saved, as it is not the best performing model.') -# Function for validation -def train_eval(args, model, device, dev_dataloader): - """Evaluate Model during training!""" - accuracy_jg = [] - accuracy_sl = [] - accuracy_req = [] - truepos_req, falsepos_req, falseneg_req = [], [], [] - truepos_dom, falsepos_dom, falseneg_dom = [], [], [] - truepos_bye, falsepos_bye, falseneg_bye = [], [], [] - accuracy_dom = [] - accuracy_bye = [] - turns = [] - for batch in dev_dataloader: - # Perform with no gradients stored - with torch.no_grad(): - if 'goodbye_belief' in batch: - labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids - if ('belief-' + slot) in batch} - request_labels = {slot: batch['request_belief-' + slot].to(device) for slot in model.requestable_slot_ids - if ('request_belief-' + slot) in batch} if args.predict_actions else None - domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids - if ('domain_belief-' + domain) in batch} if args.predict_actions else None - goodbye_labels = batch['goodbye_belief'].to( - device) if args.predict_actions else None - else: - labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids - if ('labels-' + slot) in batch} - request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids - if ('request-' + slot) in batch} if args.predict_actions else None - domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids - if ('active-' + domain) in batch} if args.predict_actions else None - goodbye_labels = batch['goodbye'].to( - device) if args.predict_actions else None - - input_ids = batch['input_ids'].to(device) - token_type_ids = batch['token_type_ids'].to( - device) if 'token_type_ids' in batch else None - attention_mask = batch['attention_mask'].to( - device) if 'attention_mask' in batch else None - - loss, p, p_req, p_dom, p_bye, _, stats = model(input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - inform_labels=labels, - request_labels=request_labels, - domain_labels=domain_labels, - goodbye_labels=goodbye_labels) +def evaluate(args, model, device, dataloader, return_eval_output=False, is_train=False): + """ + Evaluate model - jg_acc = 0.0 - req_acc = 0.0 - req_tp, req_fp, req_fn = 0.0, 0.0, 0.0 - dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0 - dom_acc = 0.0 - for slot in model.informable_slot_ids: - labels = batch['labels-' + slot].to(device) - p_ = p[slot] - - acc = (p_.argmax(-1) == labels).reshape(-1).float() - jg_acc += acc - - if model.config.predict_actions: - for slot in model.requestable_slot_ids: - p_req_ = p_req[slot] - request_labels = batch['request-' + slot].to(device) + Args: + args: Runtime arguments + model: SetSUMBT model instance + device: Torch device in use + dataloader: Dataloader of data to evaluate on + return_eval_output: If true return predicted and true states for all dialogues evaluated in semantic format + is_train: If true model is training and no logging is performed - acc = (p_req_.round().int() == request_labels).reshape(-1).float() - tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float() - fp = (p_req_.round().int() * (request_labels == 0)).reshape(-1).float() - fn = ((1 - p_req_.round().int()) * (request_labels == 1)).reshape(-1).float() - req_acc += acc - req_tp += tp - req_fp += fp - req_fn += fn - - for domain in model.domain_ids: - p_dom_ = p_dom[domain] - domain_labels = batch['active-' + domain].to(device) - - acc = (p_dom_.round().int() == domain_labels).reshape(-1).float() - tp = (p_dom_.round().int() * (domain_labels == 1)).reshape(-1).float() - fp = (p_dom_.round().int() * (domain_labels == 0)).reshape(-1).float() - fn = ((1 - p_dom_.round().int()) * (domain_labels == 1)).reshape(-1).float() - dom_acc += acc - dom_tp += tp - dom_fp += fp - dom_fn += fn - - goodbye_labels = batch['goodbye'].to(device) - bye_acc = (p_bye.argmax(-1) == goodbye_labels).reshape(-1).float().sum() - bye_tp = ((p_bye.argmax(-1) > 0) * (goodbye_labels > 0)).reshape(-1).float().sum() - bye_fp = ((p_bye.argmax(-1) > 0) * (goodbye_labels == 0)).reshape(-1).float().sum() - bye_fn = ((p_bye.argmax(-1) == 0) * (goodbye_labels > 0)).reshape(-1).float().sum() - else: - req_acc, dom_acc, bye_acc = None, None, torch.tensor(0.0) - req_tp, req_fp, req_fn = None, None, None - dom_tp, dom_fp, dom_fn = None, None, None - bye_tp, bye_fp, bye_fn = torch.tensor( - 0.0), torch.tensor(0.0), torch.tensor(0.0) - - sl_acc = sum(jg_acc / len(model.informable_slot_ids)).float() - jg_acc = sum((jg_acc / len(model.informable_slot_ids)).int()).float() - req_acc = sum(req_acc / len(model.requestable_slot_ids)).float() if req_acc is not None else torch.tensor(0.0) - req_tp = sum(req_tp / len(model.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0) - req_fp = sum(req_fp / len(model.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0) - req_fn = sum(req_fn / len(model.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0) - dom_tp = sum(dom_tp / len(model.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0) - dom_fp = sum(dom_fp / len(model.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0) - dom_fn = sum(dom_fn / len(model.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0) - dom_acc = sum(dom_acc / len(model.domain_ids)).float() if dom_acc is not None else torch.tensor(0.0) - n_turns = (labels >= 0).reshape(-1).sum().float().item() - - accuracy_jg.append(jg_acc.item()) - accuracy_sl.append(sl_acc.item()) - accuracy_req.append(req_acc.item()) - truepos_req.append(req_tp.item()) - falsepos_req.append(req_fp.item()) - falseneg_req.append(req_fn.item()) - accuracy_dom.append(dom_acc.item()) - truepos_dom.append(dom_tp.item()) - falsepos_dom.append(dom_fp.item()) - falseneg_dom.append(dom_fn.item()) - accuracy_bye.append(bye_acc.item()) - truepos_bye.append(bye_tp.item()) - falsepos_bye.append(bye_fp.item()) - falseneg_bye.append(bye_fn.item()) - turns.append(n_turns) - - # Global accuracy reduction across batches - turns = sum(turns) - jg_acc = sum(accuracy_jg) / turns - sl_acc = sum(accuracy_sl) / turns - if model.config.predict_actions: - req_acc = sum(accuracy_req) / turns - req_tp = sum(truepos_req) - req_fp = sum(falsepos_req) - req_fn = sum(falseneg_req) - req_f1 = req_tp / (req_tp + 0.5 * (req_fp + req_fn)) - dom_acc = sum(accuracy_dom) / turns - dom_tp = sum(truepos_dom) - dom_fp = sum(falsepos_dom) - dom_fn = sum(falseneg_dom) - dom_f1 = dom_tp / (dom_tp + 0.5 * (dom_fp + dom_fn)) - bye_tp = sum(truepos_bye) - bye_fp = sum(falsepos_bye) - bye_fn = sum(falseneg_bye) - bye_f1 = bye_tp / (bye_tp + 0.5 * (bye_fp + bye_fn)) - bye_acc = sum(accuracy_bye) / turns - else: - req_acc, dom_acc, bye_acc = None, None, None - req_f1, dom_f1, bye_f1 = None, None, None - - return jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats - - -def evaluate(args, model, device, dataloader): - """Evaluate Model!""" - # Evaluate! - logger.info("***** Running evaluation *****") - logger.info(" Num Batches = %d", len(dataloader)) + Returns: + out: Evaluated model statistics + """ + return_eval_output = False if is_train else return_eval_output + if not is_train: + logger.info("***** Running evaluation *****") + logger.info(" Num Batches = %d", len(dataloader)) tr_loss = 0.0 model.eval() + if return_eval_output: + ontology = dataloader.dataset.ontology - # logits = {slot: [] for slot in model.informable_slot_ids} accuracy_jg = [] accuracy_sl = [] - accuracy_req = [] truepos_req, falsepos_req, falseneg_req = [], [], [] truepos_dom, falsepos_dom, falseneg_dom = [], [], [] - truepos_bye, falsepos_bye, falseneg_bye = [], [], [] - accuracy_dom = [] - accuracy_bye = [] + truepos_gen, falsepos_gen, falseneg_gen = [], [], [] turns = [] - epoch_iterator = tqdm(dataloader, desc="Iteration") + if return_eval_output: + evaluation_output = [] + epoch_iterator = tqdm(dataloader, desc="Iteration") if not is_train else dataloader for batch in epoch_iterator: with torch.no_grad(): - if 'goodbye_belief' in batch: - labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids - if ('belief-' + slot) in batch} - request_labels = {slot: batch['request_belief-' + slot].to(device) for slot in model.requestable_slot_ids - if ('request_belief-' + slot) in batch} if args.predict_actions else None - domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids - if ('domain_belief-' + domain) in batch} if args.predict_actions else None - goodbye_labels = batch['goodbye_belief'].to( - device) if args.predict_actions else None - else: - labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids - if ('labels-' + slot) in batch} - request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids - if ('request-' + slot) in batch} if args.predict_actions else None - domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids - if ('active-' + domain) in batch} if args.predict_actions else None - goodbye_labels = batch['goodbye'].to( - device) if args.predict_actions else None - - input_ids = batch['input_ids'].to(device) - token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None - attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None - - loss, p, p_req, p_dom, p_bye, _, _ = model(input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - inform_labels=labels, - request_labels=request_labels, - domain_labels=domain_labels, - goodbye_labels=goodbye_labels) + input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids, + model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device) + + loss, p, p_req, p_dom, p_gen, _, stats = model(**input_dict) jg_acc = 0.0 + num_inform_slots = 0.0 req_acc = 0.0 req_tp, req_fp, req_fn = 0.0, 0.0, 0.0 dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0 dom_acc = 0.0 - for slot in model.informable_slot_ids: + + if return_eval_output: + eval_output_batch = [] + for dial_id, dial in enumerate(input_dict['input_ids']): + for turn_id, turn in enumerate(dial): + if turn.sum() != 0: + eval_output_batch.append({'dial_idx': dial_id, + 'utt_idx': turn_id, + 'state': dict(), + 'predictions': {'state': dict()} + }) + + for slot in model.setsumbt.informable_slot_ids: p_ = p[slot] - labels = batch['labels-' + slot].to(device) + state_labels = batch['state_labels-' + slot].to(device) + + if return_eval_output: + prediction = p_.argmax(-1) + + for sample in eval_output_batch: + dom, slt = slot.split('-', 1) + lab = state_labels[sample['dial_idx']][sample['utt_idx']].item() + if lab != -1: + lab = ontology[dom][slt]['possible_values'][lab] + pred = prediction[sample['dial_idx']][sample['utt_idx']].item() + pred = ontology[dom][slt]['possible_values'][pred] + + if dom not in sample['state']: + sample['state'][dom] = dict() + sample['predictions']['state'][dom] = dict() + + sample['state'][dom][slt] = lab if lab != 'none' else '' + sample['predictions']['state'][dom][slt] = pred if pred != 'none' else '' if args.temp_scaling > 0.0: p_ = torch.log(p_ + 1e-10) / args.temp_scaling @@ -683,28 +609,19 @@ def evaluate(args, model, device, dataloader): p_ = torch.log(p_ + 1e-10) / 1.0 p_ = torch.softmax(p_, -1) - # logits[slot].append(p_) - - if args.accuracy_samples > 0: - dist = Categorical(probs=p_.reshape(-1, p_.size(-1))) - lab_sample = dist.sample((args.accuracy_samples,)) - lab_sample = lab_sample.transpose(0, 1) - acc = [lab in s for lab, s in zip(labels.reshape(-1), lab_sample)] - acc = torch.tensor(acc).float() - elif args.accuracy_topn > 0: - labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True) - labs = labs[:, :args.accuracy_topn] - acc = [lab in s for lab, s in zip(labels.reshape(-1), labs)] - acc = torch.tensor(acc).float() - else: - acc = (p_.argmax(-1) == labels).reshape(-1).float() + acc = (p_.argmax(-1) == state_labels).reshape(-1).float() jg_acc += acc + num_inform_slots += (state_labels != -1).float().reshape(-1) + + if return_eval_output: + evaluation_output += deepcopy(eval_output_batch) + eval_output_batch = [] if model.config.predict_actions: - for slot in model.requestable_slot_ids: + for slot in model.setsumbt.requestable_slot_ids: p_req_ = p_req[slot] - request_labels = batch['request-' + slot].to(device) + request_labels = batch['request_labels-' + slot].to(device) acc = (p_req_.round().int() == request_labels).reshape(-1).float() tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float() @@ -715,85 +632,93 @@ def evaluate(args, model, device, dataloader): req_fp += fp req_fn += fn - for domain in model.domain_ids: + domains = [domain for domain in model.setsumbt.domain_ids if f'active_domain_labels-{domain}' in batch] + for domain in domains: p_dom_ = p_dom[domain] - domain_labels = batch['active-' + domain].to(device) + active_domain_labels = batch['active_domain_labels-' + domain].to(device) - acc = (p_dom_.round().int() == domain_labels).reshape(-1).float() - tp = (p_dom_.round().int() * (domain_labels == 1)).reshape(-1).float() - fp = (p_dom_.round().int() * (domain_labels == 0)).reshape(-1).float() - fn = ((1 - p_dom_.round().int()) * (domain_labels == 1)).reshape(-1).float() + acc = (p_dom_.round().int() == active_domain_labels).reshape(-1).float() + tp = (p_dom_.round().int() * (active_domain_labels == 1)).reshape(-1).float() + fp = (p_dom_.round().int() * (active_domain_labels == 0)).reshape(-1).float() + fn = ((1 - p_dom_.round().int()) * (active_domain_labels == 1)).reshape(-1).float() dom_acc += acc dom_tp += tp dom_fp += fp dom_fn += fn - goodbye_labels = batch['goodbye'].to(device) - bye_acc = (p_bye.argmax(-1) == goodbye_labels).reshape(-1).float().sum() - bye_tp = ((p_bye.argmax(-1) > 0) * (goodbye_labels > 0)).reshape(-1).float().sum() - bye_fp = ((p_bye.argmax(-1) > 0) * (goodbye_labels == 0)).reshape(-1).float().sum() - bye_fn = ((p_bye.argmax(-1) == 0) * (goodbye_labels > 0)).reshape(-1).float().sum() + general_act_labels = batch['general_act_labels'].to(device) + gen_tp = ((p_gen.argmax(-1) > 0) * (general_act_labels > 0)).reshape(-1).float().sum() + gen_fp = ((p_gen.argmax(-1) > 0) * (general_act_labels == 0)).reshape(-1).float().sum() + gen_fn = ((p_gen.argmax(-1) == 0) * (general_act_labels > 0)).reshape(-1).float().sum() else: - req_acc, dom_acc, bye_acc = None, None, torch.tensor(0.0) req_tp, req_fp, req_fn = None, None, None dom_tp, dom_fp, dom_fn = None, None, None - bye_tp, bye_fp, bye_fn = torch.tensor( - 0.0), torch.tensor(0.0), torch.tensor(0.0) - - sl_acc = sum(jg_acc / len(model.informable_slot_ids)).float() - jg_acc = sum((jg_acc / len(model.informable_slot_ids)).int()).float() - req_acc = sum(req_acc / len(model.requestable_slot_ids)).float() if req_acc is not None else torch.tensor(0.0) - req_tp = sum(req_tp / len(model.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0) - req_fp = sum(req_fp / len(model.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0) - req_fn = sum(req_fn / len(model.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0) - dom_tp = sum(dom_tp / len(model.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0) - dom_fp = sum(dom_fp / len(model.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0) - dom_fn = sum(dom_fn / len(model.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0) - dom_acc = sum(dom_acc / len(model.domain_ids)).float() if dom_acc is not None else torch.tensor(0.0) - n_turns = (labels >= 0).reshape(-1).sum().float().item() + gen_tp, gen_fp, gen_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0) + + jg_acc = jg_acc[num_inform_slots > 0] + num_inform_slots = num_inform_slots[num_inform_slots > 0] + sl_acc = sum(jg_acc / num_inform_slots).float() + jg_acc = sum((jg_acc == num_inform_slots).int()).float() + if req_tp is not None and model.setsumbt.requestable_slot_ids: + req_tp = sum(req_tp / len(model.setsumbt.requestable_slot_ids)).float() + req_fp = sum(req_fp / len(model.setsumbt.requestable_slot_ids)).float() + req_fn = sum(req_fn / len(model.setsumbt.requestable_slot_ids)).float() + else: + req_tp, req_fp, req_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0) + dom_tp = sum(dom_tp / len(model.setsumbt.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0) + dom_fp = sum(dom_fp / len(model.setsumbt.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0) + dom_fn = sum(dom_fn / len(model.setsumbt.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0) + n_turns = num_inform_slots.size(0) accuracy_jg.append(jg_acc.item()) accuracy_sl.append(sl_acc.item()) - accuracy_req.append(req_acc.item()) truepos_req.append(req_tp.item()) falsepos_req.append(req_fp.item()) falseneg_req.append(req_fn.item()) - accuracy_dom.append(dom_acc.item()) truepos_dom.append(dom_tp.item()) falsepos_dom.append(dom_fp.item()) falseneg_dom.append(dom_fn.item()) - accuracy_bye.append(bye_acc.item()) - truepos_bye.append(bye_tp.item()) - falsepos_bye.append(bye_fp.item()) - falseneg_bye.append(bye_fn.item()) + truepos_gen.append(gen_tp.item()) + falsepos_gen.append(gen_fp.item()) + falseneg_gen.append(gen_fn.item()) turns.append(n_turns) tr_loss += loss.item() - # for slot in logits: - # logits[slot] = torch.cat(logits[slot], 0) - # Global accuracy reduction across batches turns = sum(turns) jg_acc = sum(accuracy_jg) / turns sl_acc = sum(accuracy_sl) / turns if model.config.predict_actions: - req_acc = sum(accuracy_req) / turns req_tp = sum(truepos_req) req_fp = sum(falsepos_req) req_fn = sum(falseneg_req) - req_f1 = req_tp / (req_tp + 0.5 * (req_fp + req_fn)) - dom_acc = sum(accuracy_dom) / turns + req_f1 = req_tp + 0.5 * (req_fp + req_fn) + req_f1 = req_tp / req_f1 if req_f1 != 0.0 else 0.0 dom_tp = sum(truepos_dom) dom_fp = sum(falsepos_dom) dom_fn = sum(falseneg_dom) - dom_f1 = dom_tp / (dom_tp + 0.5 * (dom_fp + dom_fn)) - bye_tp = sum(truepos_bye) - bye_fp = sum(falsepos_bye) - bye_fn = sum(falseneg_bye) - bye_f1 = bye_tp / (bye_tp + 0.5 * (bye_fp + bye_fn)) - bye_acc = sum(accuracy_bye) / turns + dom_f1 = dom_tp + 0.5 * (dom_fp + dom_fn) + dom_f1 = dom_tp / dom_f1 if dom_f1 != 0.0 else 0.0 + gen_tp = sum(truepos_gen) + gen_fp = sum(falsepos_gen) + gen_fn = sum(falseneg_gen) + gen_f1 = gen_tp + 0.5 * (gen_fp + gen_fn) + gen_f1 = gen_tp / gen_f1 if gen_f1 != 0.0 else 0.0 else: - req_acc, dom_acc, bye_acc = None, None, None - req_f1, dom_f1, bye_f1 = None, None, None - - return jg_acc, sl_acc, req_f1, dom_f1, bye_f1, tr_loss / len(dataloader) + req_f1, dom_f1, gen_f1 = None, None, None + + if return_eval_output: + dial_idx = 0 + for sample in evaluation_output: + if dial_idx == 0 and sample['dial_idx'] == 0 and sample['utt_idx'] == 0: + dial_idx = 0 + elif dial_idx == 0 and sample['dial_idx'] != 0 and sample['utt_idx'] == 0: + dial_idx += 1 + elif sample['utt_idx'] == 0: + dial_idx += 1 + sample['dial_idx'] = dial_idx + + return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), evaluation_output + if is_train: + return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), stats + return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader) diff --git a/convlab/dst/setsumbt/multiwoz/Tracker.py b/convlab/dst/setsumbt/multiwoz/Tracker.py deleted file mode 100644 index fed1a1a6..00000000 --- a/convlab/dst/setsumbt/multiwoz/Tracker.py +++ /dev/null @@ -1,455 +0,0 @@ -import os -import json -import copy -import logging - -import torch -import transformers -from transformers import (BertModel, BertConfig, BertTokenizer, - RobertaModel, RobertaConfig, RobertaTokenizer) -from convlab.dst.setsumbt.modeling import (RobertaSetSUMBT, - BertSetSUMBT) - -from convlab.dst.dst import DST -from convlab.util.multiwoz.state import default_state -from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA -from convlab.dst.rule.multiwoz import normalize_value -from convlab.util.custom_util import model_downloader - -USE_CUDA = torch.cuda.is_available() - -# Map from SetSUMBT slot names to Convlab slot names -SLOT_MAP = {'arrive by': 'arriveBy', - 'leave at': 'leaveAt', - 'price range': 'pricerange', - 'trainid': 'trainID', - 'reference': 'Ref', - 'taxi types': 'car type'} - - -class SetSUMBTTracker(DST): - - def __init__(self, model_path="", model_type="roberta", - get_turn_pooled_representation=False, - get_confidence_scores=False, - threshold='auto', - return_entropy=False, - return_mutual_info=False, - store_full_belief_state=False): - super(SetSUMBTTracker, self).__init__() - - self.model_type = model_type - self.model_path = model_path - self.get_turn_pooled_representation = get_turn_pooled_representation - self.get_confidence_scores = get_confidence_scores - self.threshold = threshold - self.return_entropy = return_entropy - self.return_mutual_info = return_mutual_info - self.store_full_belief_state = store_full_belief_state - if self.store_full_belief_state: - self.full_belief_state = {} - self.info_dict = {} - - # Download model if needed - if not os.path.exists(self.model_path): - # Get path /.../convlab/dst/setsumbt/multiwoz/models - download_path = os.path.dirname(os.path.abspath(__file__)) - download_path = os.path.join(download_path, 'models') - if not os.path.exists(download_path): - os.mkdir(download_path) - model_downloader(download_path, self.model_path) - # Downloadable model path format http://.../setsumbt_model_name.zip - self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '') - self.model_path = os.path.join(download_path, self.model_path) - - # Select model type based on the encoder - if model_type == "roberta": - self.config = RobertaConfig.from_pretrained(self.model_path) - self.tokenizer = RobertaTokenizer - self.model = RobertaSetSUMBT - elif model_type == "bert": - self.config = BertConfig.from_pretrained(self.model_path) - self.tokenizer = BertTokenizer - self.model = BertSetSUMBT - else: - logging.debug("Name Error: Not Implemented") - - self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu') - - # Value dict for value normalisation - path = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) - path = os.path.join(path, 'data/multiwoz/value_dict.json') - self.value_dict = json.load(open(path)) - - self.load_weights() - - def load_weights(self): - # Load tokenizer and model checkpoints - logging.info('Loading SetSUMBT pretrained model.') - self.tokenizer = self.tokenizer.from_pretrained( - self.config.tokenizer_name) - logging.info( - f'Model tokenizer loaded from {self.config.tokenizer_name}.') - self.model = self.model.from_pretrained( - self.model_path, config=self.config) - logging.info(f'Model loaded from {self.model_path}.') - - # Transfer model to compute device and setup eval environment - self.model = self.model.to(self.device) - self.model.eval() - logging.info(f'Model transferred to device: {self.device}') - - logging.info('Loading model ontology') - f = open(os.path.join(self.model_path, 'ontology.json'), 'r') - self.ontology = json.load(f) - f.close() - - db = torch.load(os.path.join(self.model_path, 'ontology.db')) - # Get slot and value embeddings - slots = {slot: db[slot] for slot in db} - values = {slot: db[slot][1] for slot in db} - del db - - # Load model ontology - self.model.add_slot_candidates(slots) - for slot in values: - self.model.add_value_candidates(slot, values[slot], replace=True) - - if self.get_confidence_scores: - logging.info('Model will output action and state confidence scores.') - if self.get_confidence_scores: - self.get_thresholds(self.threshold) - logging.info('Uncertain Querying set up and thresholds set up at:') - logging.info(self.thresholds) - if self.return_entropy: - logging.info('Model will output state distribution entropy.') - if self.return_mutual_info: - logging.info('Model will output state distribution mutual information.') - logging.info('Ontology loaded successfully.') - - self.det_dic = {} - for domain, dic in REF_USR_DA.items(): - for key, value in dic.items(): - assert '-' not in key - self.det_dic[key.lower()] = key + '-' + domain - self.det_dic[value.lower()] = key + '-' + domain - - def get_thresholds(self, threshold='auto'): - self.thresholds = {} - for slot, value_candidates in self.ontology.items(): - domain, slot = slot.split('-', 1) - slot = REF_SYS_DA[domain.capitalize()].get(slot, slot) - slot = slot.strip().split()[1] if 'book ' in slot else slot - slot = SLOT_MAP.get(slot, slot) - - # Auto thresholds are set based on the number of value candidates per slot - if domain not in self.thresholds: - self.thresholds[domain] = {} - if threshold == 'auto': - thres = 1.0 / (float(len(value_candidates)) - 2.1) - self.thresholds[domain][slot] = max(0.05, thres) - else: - self.thresholds[domain][slot] = max(0.05, threshold) - - return self.thresholds - - def init_session(self): - self.state = default_state() - self.active_domains = {} - self.hidden_states = None - self.info_dict = {} - - def update(self, user_act=''): - prev_state = self.state - - # Convert dialogs into transformer input features (token_ids, masks, etc) - features = self.get_features(user_act) - # Model forward pass - pred_states, active_domains, user_acts, turn_pooled_representation, belief_state, entropy_, mutual_info_ = self.predict( - features) - - if entropy_ is not None: - entropy = {} - for slot, e in entropy_.items(): - domain, slot = slot.split('-', 1) - if domain not in entropy: - entropy[domain] = {} - if 'book' in slot: - assert slot.startswith('book ') - slot = slot.strip().split()[1] - slot = SLOT_MAP.get(slot, slot) - entropy[domain][slot] = e - del entropy_ - else: - entropy = None - - if mutual_info_ is not None: - mutual_info = {} - for slot, mi in mutual_info_.items(): - domain, slot = slot.split('-', 1) - if domain not in mutual_info: - mutual_info[domain] = {} - if 'book' in slot: - assert slot.startswith('book ') - slot = slot.strip().split()[1] - slot = SLOT_MAP.get(slot, slot) - mutual_info[domain][slot] = mi[0, 0] - else: - mutual_info = None - - if belief_state is not None: - bs_probs = {} - belief_state, request_dist, domain_dist, greeting_dist = belief_state - for slot, p in belief_state.items(): - domain, slot = slot.split('-', 1) - if domain not in bs_probs: - bs_probs[domain] = {} - if 'book' in slot: - assert slot.startswith('book ') - slot = slot.strip().split()[1] - slot = SLOT_MAP.get(slot, slot) - if slot not in bs_probs[domain]: - bs_probs[domain][slot] = {} - bs_probs[domain][slot]['inform'] = p - - for slot, p in request_dist.items(): - domain, slot = slot.split('-', 1) - if domain not in bs_probs: - bs_probs[domain] = {} - slot = SLOT_MAP.get(slot, slot) - if slot not in bs_probs[domain]: - bs_probs[domain][slot] = {} - bs_probs[domain][slot]['request'] = p - - for domain, p in domain_dist.items(): - if domain not in bs_probs: - bs_probs[domain] = {} - bs_probs[domain]['none'] = {'inform': p} - - if 'general' not in bs_probs: - bs_probs['general'] = {} - bs_probs['general']['none'] = greeting_dist - - new_domains = [d for d, active in active_domains.items() if active] - new_domains = [ - d for d in new_domains if not self.active_domains.get(d, False)] - self.active_domains = active_domains - - for domain in new_domains: - user_acts.append(['Inform', domain.capitalize(), 'none', 'none']) - - new_belief_state = copy.deepcopy(prev_state['belief_state']) - # user_acts = [] - for state, value in pred_states.items(): - domain, slot = state.split('-', 1) - value = '' if value == 'none' else value - value = 'dontcare' if value == 'do not care' else value - value = 'guesthouse' if value == 'guest house' else value - if slot not in ['name', 'book']: - if domain not in new_belief_state: - if domain == 'bus': - continue - else: - logging.debug( - 'Error: domain <{}> not in belief state'.format(domain)) - slot = REF_SYS_DA[domain.capitalize()].get(slot, slot) - assert 'semi' in new_belief_state[domain] - assert 'book' in new_belief_state[domain] - if 'book' in slot: - assert slot.startswith('book ') - slot = slot.strip().split()[1] - slot = SLOT_MAP.get(slot, slot) - - # Uncertainty clipping of state - if belief_state is not None: - if bs_probs[domain][slot].get('inform', 1.0) < self.thresholds[domain][slot]: - value = '' - - domain_dic = new_belief_state[domain] - value = normalize_value(self.value_dict, domain, slot, value) - if slot in domain_dic['semi']: - new_belief_state[domain]['semi'][slot] = value - if prev_state['belief_state'][domain]['semi'][slot] != value: - user_acts.append(['Inform', domain.capitalize( - ), REF_USR_DA[domain.capitalize()].get(slot, slot), value]) - elif slot in domain_dic['book']: - new_belief_state[domain]['book'][slot] = value - if prev_state['belief_state'][domain]['book'][slot] != value: - user_acts.append(['Inform', domain.capitalize( - ), REF_USR_DA[domain.capitalize()].get(slot, slot), value]) - elif slot.lower() in domain_dic['book']: - new_belief_state[domain]['book'][slot.lower()] = value - if prev_state['belief_state'][domain]['book'][slot.lower()] != value: - user_acts.append(['Inform', domain.capitalize( - ), REF_USR_DA[domain.capitalize()].get(slot.lower(), slot.lower()), value]) - else: - logging.debug( - 'unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format( - slot, value, domain, state) - ) - - new_state = copy.deepcopy(dict(prev_state)) - new_state['belief_state'] = new_belief_state - new_state['active_domains'] = self.active_domains - if belief_state is not None: - new_state['belief_state_probs'] = bs_probs - if entropy is not None: - new_state['entropy'] = entropy - if mutual_info is not None: - new_state['mutual_information'] = mutual_info - - new_state['user_action'] = user_acts - - user_requests = [[a, d, s, v] - for a, d, s, v in user_acts if a == 'Request'] - for act, domain, slot, value in user_requests: - k = REF_SYS_DA[domain].get(slot, slot) - domain = domain.lower() - if domain not in new_state['request_state']: - new_state['request_state'][domain] = {} - if k not in new_state['request_state'][domain]: - new_state['request_state'][domain][k] = 0 - - if turn_pooled_representation is not None: - new_state['turn_pooled_representation'] = turn_pooled_representation - - self.state = new_state - self.info_dict = copy.deepcopy(dict(new_state)) - - return self.state - - # Model prediction function - - def predict(self, features): - # Forward Pass - mutual_info = None - with torch.no_grad(): - turn_pooled_representation = None - if self.get_turn_pooled_representation: - belief_state, request, domain, goodbye, self.hidden_states, turn_pooled_representation = self.model(input_ids=features['input_ids'], - token_type_ids=features[ - 'token_type_ids'], - attention_mask=features[ - 'attention_mask'], - hidden_state=self.hidden_states, - get_turn_pooled_representation=True) - elif self.return_mutual_info: - belief_state, request, domain, goodbye, self.hidden_states, mutual_info = self.model(input_ids=features['input_ids'], - token_type_ids=features[ - 'token_type_ids'], - attention_mask=features[ - 'attention_mask'], - hidden_state=self.hidden_states, - get_turn_pooled_representation=False, - calculate_inform_mutual_info=True) - else: - belief_state, request, domain, goodbye, self.hidden_states = self.model(input_ids=features['input_ids'], - token_type_ids=features['token_type_ids'], - attention_mask=features['attention_mask'], - hidden_state=self.hidden_states, - get_turn_pooled_representation=False) - - # Convert belief state into dialog state - predictions = {slot: state[0, 0, :].argmax().item() - for slot, state in belief_state.items()} - predictions = {slot: self.ontology[slot][idx] - for slot, idx in predictions.items()} - predictions = {s: v for s, v in predictions.items() if v != 'none'} - - if self.store_full_belief_state: - self.full_belief_state = belief_state - - # Obtain model output probabilities - if self.get_confidence_scores: - entropy = None - if self.return_entropy: - entropy = {slot: state[0, 0, :] - for slot, state in belief_state.items()} - entropy = {slot: self.relative_entropy( - p).item() for slot, p in entropy.items()} - - # Confidence score is the max probability across all not "none" values candidates. - belief_state = {slot: state[0, 0, 1:].max().item() - for slot, state in belief_state.items()} - request_dist = {SLOT_MAP.get( - slot, slot): p[0, 0].item() for slot, p in request.items()} - domain_dist = {domain: p[0, 0].item() - for domain, p in domain.items()} - greeting_dist = {'bye': goodbye[0, 0, 1].item( - ), 'thank': goodbye[0, 0, 2].item()} - belief_state = (belief_state, request_dist, - domain_dist, greeting_dist) - else: - belief_state = None - entropy = None - - # Construct request action prediction - request = [slot for slot, p in request.items() if p[0, 0].item() > 0.5] - request = [slot.split('-', 1) for slot in request] - request = [[domain, SLOT_MAP.get(slot, slot)] - for domain, slot in request] - request = [['Request', domain.capitalize(), REF_USR_DA[domain.capitalize()].get( - slot, slot), '?'] for domain, slot in request] - - # Construct active domain set - domain = {domain: p[0, 0].item() > 0.5 for domain, p in domain.items()} - - # Construct general domain action - goodbye = goodbye[0, 0, :].argmax(-1).item() - goodbye = [[], ['bye'], ['thank']][goodbye] - goodbye = [[act, 'general', 'none', 'none'] for act in goodbye] - - user_acts = request + goodbye - - return predictions, domain, user_acts, turn_pooled_representation, belief_state, entropy, mutual_info - - def relative_entropy(self, probs): - entropy = probs * torch.log(probs + 1e-8) - entropy = -entropy.sum() - # Maximum entropy of a K dimentional distribution is ln(K) - entropy /= torch.log(torch.tensor(probs.size(-1)).float()) - - return entropy - - # Convert dialog turns into model features - def get_features(self, user_act): - # Extract system utterance from dialog history - context = self.state['history'] - if context: - if context[-1][0] != 'sys': - system_act = '' - else: - system_act = context[-1][-1] - else: - system_act = '' - - # Tokenize dialog - features = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True, max_length=self.config.max_turn_len, - padding='max_length', truncation='longest_first') - - input_ids = torch.tensor(features['input_ids']).reshape( - 1, 1, -1).to(self.device) if 'input_ids' in features else None - token_type_ids = torch.tensor(features['token_type_ids']).reshape( - 1, 1, -1).to(self.device) if 'token_type_ids' in features else None - attention_mask = torch.tensor(features['attention_mask']).reshape( - 1, 1, -1).to(self.device) if 'attention_mask' in features else None - features = {'input_ids': input_ids, - 'token_type_ids': token_type_ids, 'attention_mask': attention_mask} - - return features - - -# if __name__ == "__main__": -# tracker = SetSUMBTTracker(model_type='roberta', model_path='/gpfs/project/niekerk/results/nbt/convlab_setsumbt_acts') -# # nlu_path='/gpfs/project/niekerk/data/bert_multiwoz_all_context.zip') -# tracker.init_session() -# state = tracker.update('hey. I need a cheap restaurant.') -# # tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.']) -# # tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?']) -# # state = tracker.update('If you have something Asian that would be great.') -# # tracker.state['history'].append(['usr', 'If you have something Asian that would be great.']) -# # tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.']) -# # state = tracker.update('Great. Where are they located?') -# # tracker.state['history'].append(['usr', 'Great. Where are they located?']) -# print(tracker.state) diff --git a/convlab/dst/setsumbt/multiwoz/__init__.py b/convlab/dst/setsumbt/multiwoz/__init__.py deleted file mode 100644 index a1f1fb89..00000000 --- a/convlab/dst/setsumbt/multiwoz/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from convlab.dst.setsumbt.multiwoz.dataset import multiwoz21, ontology -from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker \ No newline at end of file diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair b/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair deleted file mode 100644 index 34df41d0..00000000 --- a/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair +++ /dev/null @@ -1,83 +0,0 @@ -it's it is -don't do not -doesn't does not -didn't did not -you'd you would -you're you are -you'll you will -i'm i am -they're they are -that's that is -what's what is -couldn't could not -i've i have -we've we have -can't cannot -i'd i would -i'd i would -aren't are not -isn't is not -wasn't was not -weren't were not -won't will not -there's there is -there're there are -. . . -restaurants restaurant -s -hotels hotel -s -laptops laptop -s -cheaper cheap -er -dinners dinner -s -lunches lunch -s -breakfasts breakfast -s -expensively expensive -ly -moderately moderate -ly -cheaply cheap -ly -prices price -s -places place -s -venues venue -s -ranges range -s -meals meal -s -locations location -s -areas area -s -policies policy -s -children child -s -kids kid -s -kidfriendly kid friendly -cards card -s -upmarket expensive -inpricey cheap -inches inch -s -uses use -s -dimensions dimension -s -driverange drive range -includes include -s -computers computer -s -machines machine -s -families family -s -ratings rating -s -constraints constraint -s -pricerange price range -batteryrating battery rating -requirements requirement -s -drives drive -s -specifications specification -s -weightrange weight range -harddrive hard drive -batterylife battery life -businesses business -s -hours hour -s -one 1 -two 2 -three 3 -four 4 -five 5 -six 6 -seven 7 -eight 8 -nine 9 -ten 10 -eleven 11 -twelve 12 -anywhere any where -good bye goodbye diff --git a/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py b/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py deleted file mode 100644 index 2c8e98f3..00000000 --- a/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py +++ /dev/null @@ -1,502 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""MultiWOZ 2.1/2.3 Dialogue Dataset""" - -import os -import json -import requests -import zipfile -import io -from shutil import copy2 as copy - -import torch -from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler -from tqdm import tqdm - -from convlab.dst.setsumbt.multiwoz.dataset.utils import (clean_text, ACTIVE_DOMAINS, get_domains, set_util_domains, - fix_delexicalisation, extract_dialogue, PRICERANGE, - BOOLEAN, DAYS, QUANTITIES, TIME, VALUE_MAP, map_values) - - -# Set up global data_directory -def set_datadir(dir): - global DATA_DIR - DATA_DIR = dir - - -def set_active_domains(domains): - global ACTIVE_DOMAINS - ACTIVE_DOMAINS = [d for d in domains if d in ACTIVE_DOMAINS] - set_util_domains(ACTIVE_DOMAINS) - - -# MultiWOZ2.1 download link -URL = 'https://github.com/budzianowski/multiwoz/raw/master/data/MultiWOZ_2.1.zip' -def set_url(url): - global URL - URL = url - - -# Create Dialogue examples from the dataset -def create_examples(max_utt_len, get_requestable_slots=False, force_processing=False): - - # Load or download Raw Data - if not os.path.exists(DATA_DIR): - os.mkdir(DATA_DIR) - if not os.path.exists(os.path.join(DATA_DIR, 'data_raw.json')): - # Download data archive and extract - archive = _download() - data = _extract(archive) - - writer = open(os.path.join(DATA_DIR, 'data_raw.json'), 'w') - json.dump(data, writer, indent = 2) - del archive, writer - else: - reader = open(os.path.join(DATA_DIR, 'data_raw.json'), 'r') - data = json.load(reader) - - if force_processing or not os.path.exists(os.path.join(DATA_DIR, 'data_train.json')): - # Preprocess all dialogues - data_processed = _process(data['data'], data['system_acts']) - # Format data and split train, test and devlopment sets - train, dev, test = _split_data(data_processed, data['testListFile'], - data['valListFile'], max_utt_len) - - # Write data - writer = open(os.path.join(DATA_DIR, 'data_train.json'), 'w') - json.dump(train, writer, indent = 2) - writer = open(os.path.join(DATA_DIR, 'data_test.json'), 'w') - json.dump(test, writer, indent = 2) - writer = open(os.path.join(DATA_DIR, 'data_dev.json'), 'w') - json.dump(dev, writer, indent = 2) - writer.flush() - writer.close() - del writer - - # Extract slots and slot value candidates from the dataset - for set_type in ['train', 'dev', 'test']: - _get_ontology(set_type, get_requestable_slots) - - script_path = os.path.abspath(__file__).replace('/multiwoz21.py', '') - file_name = 'mwoz21_ont_request.json' if get_requestable_slots else 'mwoz21_ont.json' - copy(os.path.join(script_path, file_name), os.path.join(DATA_DIR, 'ontology_test.json')) - copy(os.path.join(script_path, 'mwoz21_slot_descriptions.json'), os.path.join(DATA_DIR, 'slot_descriptions.json')) - - -# Extract slots and slot value candidates from the dataset -def _get_ontology(set_type, get_requestable_slots=False): - - datasets = ['train'] - if set_type in ['test', 'dev']: - datasets.append('dev') - datasets.append('test') - - # Load examples - data = [] - for dataset in datasets: - reader = open(os.path.join(DATA_DIR, 'data_%s.json' % dataset), 'r') - data += json.load(reader) - - ontology = dict() - for dial in data: - for turn in dial['dialogue']: - for state in turn['dialogue_state']: - slot, value = state - value = map_values(value) - if slot not in ontology: - ontology[slot] = [value] - else: - ontology[slot].append(value) - - requestable_slots = [] - if get_requestable_slots: - for dial in data: - for turn in dial['dialogue']: - for act, dom, slot, val in turn['user_acts']: - if act == 'request': - requestable_slots.append(f'{dom}-{slot}') - requestable_slots = list(set(requestable_slots)) - - for slot in ontology: - if 'price' in slot: - ontology[slot] = PRICERANGE - if 'parking' in slot or 'internet' in slot: - ontology[slot] = BOOLEAN - if 'day' in slot: - ontology[slot] = DAYS - if 'people' in slot or 'duration' in slot or 'stay' in slot: - ontology[slot] = QUANTITIES - if 'time' in slot or 'leave' in slot or 'arrive' in slot: - ontology[slot] = TIME - if 'stars' in slot: - ontology[slot] += [str(i) for i in range(5)] - - # Sort slot values and add none and dontcare values - for slot in ontology: - ontology[slot] = list(set(ontology[slot])) - ontology[slot] = ['none', 'do not care'] + sorted([s for s in ontology[slot] if s not in ['none', 'do not care']]) - for slot in requestable_slots: - if slot in ontology: - ontology[slot].append('request') - else: - ontology[slot] = ['request'] - - writer = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'w') - json.dump(ontology, writer, indent=2) - writer.close() - - -# Convert dialogue examples to model input features and labels -def convert_examples_to_features(set_type, tokenizer, max_turns=12, max_seq_len=64): - - features = dict() - - # Load examples - reader = open(os.path.join(DATA_DIR, 'data_%s.json' % set_type), 'r') - data = json.load(reader) - - # Get encoder input for system, user utterance pairs - input_feats = [] - for dial in data: - dial_feats = [] - for turn in dial['dialogue']: - if len(turn['system_transcript']) == 0: - usr = turn['transcript'] - dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens = True, - max_length = max_seq_len, padding='max_length', - truncation = 'longest_first')) - else: - usr = turn['transcript'] - sys = turn['system_transcript'] - dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens = True, - max_length = max_seq_len, padding='max_length', - truncation = 'longest_first')) - if len(dial_feats) >= max_turns: - break - input_feats.append(dial_feats) - del dial_feats - - # Perform turn level padding - input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats] - if 'token_type_ids' in input_feats[0][0]: - token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats] - else: - token_type_ids = None - if 'attention_mask' in input_feats[0][0]: - attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats] - else: - attention_mask = None - del input_feats - - # Create torch data tensors - features['input_ids'] = torch.tensor(input_ids) - features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None - features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None - del input_ids, token_type_ids, attention_mask - - # Load ontology - reader = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r') - ontology = json.load(reader) - reader.close() - - informable_slots = [slot for slot, values in ontology.items() if values != ['request']] - requestable_slots = [slot for slot, values in ontology.items() if 'request' in values] - for slot in requestable_slots: - ontology[slot].remove('request') - - domains = list(set(informable_slots + requestable_slots)) - domains = list(set([slot.split('-', 1)[0] for slot in domains])) - - # Create slot labels - for slot in informable_slots: - labels = [] - for dial in data: - labs = [] - for turn in dial['dialogue']: - slots_active = [s for s, v in turn['dialogue_state']] - if slot in slots_active: - value = [v for s, v in turn['dialogue_state'] if s == slot][0] - else: - value = 'none' - if value in ontology[slot]: - value = ontology[slot].index(value) - else: - value = map_values(value) - if value in ontology[slot]: - value = ontology[slot].index(value) - else: - value = -1 # If value is not in ontology then we do not penalise the model - labs.append(value) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['labels-' + slot] = labels - - for slot in requestable_slots: - labels = [] - for dial in data: - labs = [] - for turn in dial['dialogue']: - slots_active = [[d, s] for i, d, s, v in turn['user_acts']] - if slot.split('-', 1) in slots_active: - act_ = [i for i, d, s, v in turn['user_acts'] if f"{d}-{s}" == slot][0] - if act_ == 'request': - labs.append(1) - else: - labs.append(0) - else: - labs.append(0) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['request-' + slot] = labels - - # Greeting act labels (0-no greeting, 1-goodbye, 2-thank you) - labels = [] - for dial in data: - labs = [] - for turn in dial['dialogue']: - greeting_active = [i for i, d, s, v in turn['user_acts'] if i in ['bye', 'thank']] - if greeting_active: - if 'bye' in greeting_active: - labs.append(1) - else : - labs.append(2) - else: - labs.append(0) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['goodbye'] = labels - - for domain in domains: - labels = [] - for dial in data: - labs = [] - for turn in dial['dialogue']: - if domain == turn['domain']: - labs.append(1) - else: - labs.append(0) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['active-' + domain] = labels - - del labels - - return features - - -# MultiWOZ2.1 Dataset object -class MultiWoz21(Dataset): - - def __init__(self, set_type, tokenizer, max_turns=12, max_seq_len=64): - self.features = convert_examples_to_features(set_type, tokenizer, max_turns, max_seq_len) - - def __getitem__(self, index): - return {label: self.features[label][index] for label in self.features - if self.features[label] is not None} - - def __len__(self): - return self.features['input_ids'].size(0) - - def resample(self, size=None): - n_dialogues = self.__len__() - if not size: - size = n_dialogues - - dialogues = torch.randint(low=0, high=n_dialogues, size=(size,)) - self.features = {label: self.features[label][dialogues] for label in self.features - if self.features[label] is not None} - - return self - - def to(self, device): - self.device = device - self.features = {label: self.features[label].to(device) for label in self.features - if self.features[label] is not None} - - -# MultiWOZ2.1 Dataset object -class EnsembleMultiWoz21(Dataset): - def __init__(self, data): - self.features = data - - def __getitem__(self, index): - return {label: self.features[label][index] for label in self.features - if self.features[label] is not None} - - def __len__(self): - return self.features['input_ids'].size(0) - - def resample(self, size=None): - n_dialogues = self.__len__() - if not size: - size = n_dialogues - - dialogues = torch.randint(low=0, high=n_dialogues, size=(size,)) - self.features = {label: self.features[label][dialogues] for label in self.features - if self.features[label] is not None} - - def to(self, device): - self.device = device - self.features = {label: self.features[label].to(device) for label in self.features - if self.features[label] is not None} - - -# Module to create torch dataloaders -def get_dataloader(set_type, batch_size, tokenizer, max_turns=12, max_seq_len=64, device=None, resampled_size=None): - data = MultiWoz21(set_type, tokenizer, max_turns, max_seq_len) - data.to('cpu') - - if resampled_size: - data.resample(resampled_size) - - if set_type in ['test', 'dev']: - sampler = SequentialSampler(data) - else: - sampler = RandomSampler(data) - loader = DataLoader(data, sampler=sampler, batch_size=batch_size) - - return loader - - -def _download(chunk_size=1048576): - """Download data archive. - - Parameters: - chunk_size (int): Download chunk size. (default=1048576) - Returns: - archive: ZipFile archive object. - """ - # Download the archive byte string - req = requests.get(URL, stream=True) - archive = b'' - for n_chunks, chunk in tqdm(enumerate(req.iter_content(chunk_size=chunk_size)), desc='Download Chunk'): - if chunk: - archive += chunk - - # Convert the bytestring into a zipfile object - archive = io.BytesIO(archive) - archive = zipfile.ZipFile(archive) - - return archive - - -def _extract(archive): - """Extract the json dictionaries from the archive. - - Parameters: - archive: ZipFile archive object. - Returns: - data: Data dictionary. - """ - files = [file for file in archive.filelist if ('.json' in file.filename or '.txt' in file.filename) - and 'MACOSX' not in file.filename] - objects = [] - for file in tqdm(files, desc='File'): - data = archive.open(file).read() - # Get data objects from the files - try: - data = json.loads(data) - except json.decoder.JSONDecodeError: - data = data.decode().split('\n') - objects.append(data) - - files = [file.filename.split('/')[-1].split('.')[0] for file in files] - - data = {file: data for file, data in zip(files, objects)} - return data - - -# Process files -def _process(dialogue_data, acts_data): - print('Processing Dialogues') - out = {} - for dial_name in tqdm(dialogue_data): - dialogue = dialogue_data[dial_name] - - prev_dom = '' - for turn_id, turn in enumerate(dialogue['log']): - dialogue['log'][turn_id]['text'] = clean_text(turn['text']) - if len(turn['metadata']) != 0: - crnt_dom = get_domains(dialogue['log'], turn_id, prev_dom) - prev_dom = crnt_dom - dialogue['log'][turn_id - 1]['domain'] = crnt_dom - - dialogue['log'][turn_id] = fix_delexicalisation(turn) - - out[dial_name] = dialogue - - return out - - -# Split data (train, dev, test) -def _split_data(dial_data, test, dev, max_utt_len): - train_dials, test_dials, dev_dials = [], [], [] - print('Formatting and Splitting Data') - for name in tqdm(dial_data): - dialogue = dial_data[name] - domains = [] - - dial = extract_dialogue(dialogue, max_utt_len) - if dial: - dialogue = dict() - dialogue['dialogue_idx'] = name - dialogue['domains'] = [] - dialogue['dialogue'] = [] - - for turn_id, turn in enumerate(dial): - turn_dialog = dict() - turn_dialog['system_transcript'] = dial[turn_id - 1]['sys'] if turn_id > 0 else '' - turn_dialog['turn_idx'] = turn_id - turn_dialog['dialogue_state'] = turn['ds'] - turn_dialog['transcript'] = turn['usr'] - # turn_dialog['system_acts'] = dial[turn_id - 1]['sys_a'] if turn_id > 0 else [] - turn_dialog['user_acts'] = turn['usr_a'] - turn_dialog['domain'] = turn['domain'] - dialogue['domains'].append(turn['domain']) - dialogue['dialogue'].append(turn_dialog) - - dialogue['domains'] = [d for d in list(set(dialogue['domains'])) if d != ''] - if True in [dom not in ACTIVE_DOMAINS for dom in dialogue['domains']]: - dialogue['domains'] = [] - dialogue['domains'] = [dom for dom in dialogue['domains'] if dom in ACTIVE_DOMAINS] - - if dialogue['domains']: - if name in test: - test_dials.append(dialogue) - elif name in dev: - dev_dials.append(dialogue) - else: - train_dials.append(dialogue) - - print('Number of Dialogues:\nTrain: %i\nDev: %i\nTest: %i' % (len(train_dials), len(dev_dials), len(test_dials))) - - return train_dials, dev_dials, test_dials diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json deleted file mode 100644 index b703793d..00000000 --- a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json +++ /dev/null @@ -1,2990 +0,0 @@ -{ - "hotel-price range": [ - "none", - "do not care", - "cheap", - "expensive", - "moderate" - ], - "hotel-type": [ - "none", - "do not care", - "bed and breakfast", - "guest house", - "hotel" - ], - "hotel-parking": [ - "none", - "do not care", - "no", - "yes" - ], - "hotel-book day": [ - "none", - "do not care", - "friday", - "monday", - "saterday", - "sunday", - "thursday", - "tuesday", - "wednesday" - ], - "hotel-book people": [ - "none", - "do not care", - "1", - "10 or more", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9" - ], - "hotel-book stay": [ - "none", - "do not care", - "1", - "10 or more", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9" - ], - "train-destination": [ - "none", - "do not care", - "bishops stortford", - "kings lynn", - "london liverpool street", - "centre", - "bishop stortford", - "liverpool", - "leicester", - "broxbourne", - "gourmet burger kitchen", - "copper kettle", - "bournemouth", - "stevenage", - "liverpool street", - "norwich", - "huntingdon marriott hotel", - "city centre north", - "taj tandoori", - "the copper kettle", - "peterborough", - "ely", - "lecester", - "london", - "willi", - "stansted airport", - "huntington marriott", - "cambridge", - "gonv", - "glastonbury", - "hol", - "north", - "birmingham new street", - "norway", - "petersborough", - "london kings cross", - "curry prince", - "bishops storford" - ], - "train-arrive by": [ - "none", - "do not care", - "00:00", - "00:05", - "00:10", - "00:15", - "00:20", - "00:25", - "00:30", - "00:35", - "00:40", - "00:45", - "00:50", - "00:55", - "01:00", - "01:05", - "01:10", - "01:15", - "01:20", - "01:25", - "01:30", - "01:35", - "01:40", - "01:45", - "01:50", - "01:55", - "02:00", - "02:05", - "02:10", - "02:15", - "02:20", - "02:25", - "02:30", - "02:35", - "02:40", - "02:45", - "02:50", - "02:55", - "03:00", - "03:05", - "03:10", - "03:15", - "03:20", - "03:25", - "03:30", - "03:35", - "03:40", - "03:45", - "03:50", - "03:55", - "04:00", - "04:05", - "04:10", - "04:15", - "04:20", - "04:25", - "04:30", - "04:35", - "04:40", - "04:45", - "04:50", - "04:55", - "05:00", - "05:05", - "05:10", - "05:15", - "05:20", - "05:25", - "05:30", - "05:35", - "05:40", - "05:45", - "05:50", - "05:55", - "06:00", - "06:05", - "06:10", - "06:15", - "06:20", - "06:25", - "06:30", - "06:35", - "06:40", - "06:45", - "06:50", - "06:55", - "07:00", - "07:05", - "07:10", - "07:15", - "07:20", - "07:25", - "07:30", - "07:35", - "07:40", - "07:45", - "07:50", - "07:55", - "08:00", - "08:05", - "08:10", - "08:15", - "08:20", - "08:25", - "08:30", - "08:35", - "08:40", - "08:45", - "08:50", - "08:55", - "09:00", - "09:05", - "09:10", - "09:15", - "09:20", - "09:25", - "09:30", - "09:35", - "09:40", - "09:45", - "09:50", - "09:55", - "10:00", - "10:05", - "10:10", - "10:15", - "10:20", - "10:25", - "10:30", - "10:35", - "10:40", - "10:45", - "10:50", - "10:55", - "11:00", - "11:05", - "11:10", - "11:15", - "11:20", - "11:25", - "11:30", - "11:35", - "11:40", - "11:45", - "11:50", - "11:55", - "12:00", - "12:05", - "12:10", - "12:15", - "12:20", - "12:25", - "12:30", - "12:35", - "12:40", - "12:45", - "12:50", - "12:55", - "13:00", - "13:05", - "13:10", - "13:15", - "13:20", - "13:25", - "13:30", - "13:35", - "13:40", - "13:45", - "13:50", - "13:55", - "14:00", - "14:05", - "14:10", - "14:15", - "14:20", - "14:25", - "14:30", - "14:35", - "14:40", - "14:45", - "14:50", - "14:55", - "15:00", - "15:05", - "15:10", - "15:15", - "15:20", - "15:25", - "15:30", - "15:35", - "15:40", - "15:45", - "15:50", - "15:55", - "16:00", - "16:05", - "16:10", - "16:15", - "16:20", - "16:25", - "16:30", - "16:35", - "16:40", - "16:45", - "16:50", - "16:55", - "17:00", - "17:05", - "17:10", - "17:15", - "17:20", - "17:25", - "17:30", - "17:35", - "17:40", - "17:45", - "17:50", - "17:55", - "18:00", - "18:05", - "18:10", - "18:15", - "18:20", - "18:25", - "18:30", - "18:35", - "18:40", - "18:45", - "18:50", - "18:55", - "19:00", - "19:05", - "19:10", - "19:15", - "19:20", - "19:25", - "19:30", - "19:35", - "19:40", - "19:45", - "19:50", - "19:55", - "20:00", - "20:05", - "20:10", - "20:15", - "20:20", - "20:25", - "20:30", - "20:35", - "20:40", - "20:45", - "20:50", - "20:55", - "21:00", - "21:05", - "21:10", - "21:15", - "21:20", - "21:25", - "21:30", - "21:35", - "21:40", - "21:45", - "21:50", - "21:55", - "22:00", - "22:05", - "22:10", - "22:15", - "22:20", - "22:25", - "22:30", - "22:35", - "22:40", - "22:45", - "22:50", - "22:55", - "23:00", - "23:05", - "23:10", - "23:15", - "23:20", - "23:25", - "23:30", - "23:35", - "23:40", - "23:45", - "23:50", - "23:55" - ], - "train-departure": [ - "none", - "do not care", - "bishops stortford", - "kings lynn", - "brookshite", - "london liverpool street", - "cam", - "liverpool", - "bro", - "leicester", - "broxbourne", - "norwhich", - "saint johns", - "stevenage", - "stansted", - "london liverpool", - "cambrid", - "city hall", - "rosas bed and breakfast", - "alpha-milton", - "wandlebury country park", - "norwich", - "liecester", - "stratford", - "peterborough", - "duxford", - "ely", - "london", - "stansted airport", - "lon", - "cambridge", - "panahar", - "cineworld", - "leicaster", - "birmingham", - "cafe uno", - "camboats", - "huntingdon", - "birmingham new street", - "arbu", - "alpha milton", - "east london", - "london kings cross", - "hamilton lodge", - "aylesbray lodge guest", - "el shaddai" - ], - "train-day": [ - "none", - "do not care", - "friday", - "monday", - "saterday", - "sunday", - "thursday", - "tuesday", - "wednesday" - ], - "train-book people": [ - "none", - "do not care", - "1", - "10 or more", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9" - ], - "hotel-stars": [ - "none", - "do not care", - "0", - "1", - "2", - "3", - "4", - "5" - ], - "hotel-internet": [ - "none", - "do not care", - "no", - "yes" - ], - "hotel-name": [ - "a and b guest house", - "city roomz", - "carolina bed and breakfast", - "limehouse", - "anatolia", - "hamilton lodge", - "the lensfield hotel", - "rosa's bed and breakfast", - "gall", - "aylesbray lodge", - "kirkwood", - "cambridge belfry", - "warkworth house", - "gonville", - "belfy hotel", - "nus", - "alexander", - "super 5", - "aylesbray lodge guest house", - "the gonvile hotel", - "allenbell", - "nothamilton lodge", - "ashley hotel", - "autumn house", - "hobsons house", - "hotel", - "ashely hotel", - "caridge belfrey", - "el shaddia guest house", - "avalon", - "cote", - "city centre north bed and breakfast", - "the cambridge belfry", - "home from home", - "wandlebury coutn", - "wankworth house", - "city stop rest", - "the worth house", - "cityroomz", - "huntingdon marriottt hotel", - "none", - "lensfield", - "rosas bed and breakfast", - "leverton house", - "gonville hotel", - "holiday inn cambridge", - "do not care", - "archway house", - "lan hon", - "levert", - "acorn guest house", - "cambridge", - "the ashley hotel", - "el shaddai", - "sleeperz", - "alpha milton guest house", - "doubletree by hilton cambridge", - "tandoori palace", - "express by", - "express by holiday inn cambridge", - "north bed and breakfast", - "holiday inn", - "arbury lodge guest house", - "alexander bed and breakfast", - "huntingdon marriott hotel", - "royal spice", - "sou", - "finches bed and breakfast", - "the alpha milton", - "bridge guest house", - "the acorn guest house", - "kirkwood house", - "eraina", - "la margherit", - "lensfield hotel", - "marriott hotel", - "nusha", - "city centre bed and breakfast", - "the allenbell", - "university arms hotel", - "clare", - "cherr", - "wartworth", - "acorn place", - "lovell lodge", - "whale" - ], - "train-leave at": [ - "none", - "do not care", - "00:00", - "00:05", - "00:10", - "00:15", - "00:20", - "00:25", - "00:30", - "00:35", - "00:40", - "00:45", - "00:50", - "00:55", - "01:00", - "01:05", - "01:10", - "01:15", - "01:20", - "01:25", - "01:30", - "01:35", - "01:40", - "01:45", - "01:50", - "01:55", - "02:00", - "02:05", - "02:10", - "02:15", - "02:20", - "02:25", - "02:30", - "02:35", - "02:40", - "02:45", - "02:50", - "02:55", - "03:00", - "03:05", - "03:10", - "03:15", - "03:20", - "03:25", - "03:30", - "03:35", - "03:40", - "03:45", - "03:50", - "03:55", - "04:00", - "04:05", - "04:10", - "04:15", - "04:20", - "04:25", - "04:30", - "04:35", - "04:40", - "04:45", - "04:50", - "04:55", - "05:00", - "05:05", - "05:10", - "05:15", - "05:20", - "05:25", - "05:30", - "05:35", - "05:40", - "05:45", - "05:50", - "05:55", - "06:00", - "06:05", - "06:10", - "06:15", - "06:20", - "06:25", - "06:30", - "06:35", - "06:40", - "06:45", - "06:50", - "06:55", - "07:00", - "07:05", - "07:10", - "07:15", - "07:20", - "07:25", - "07:30", - "07:35", - "07:40", - "07:45", - "07:50", - "07:55", - "08:00", - "08:05", - "08:10", - "08:15", - "08:20", - "08:25", - "08:30", - "08:35", - "08:40", - "08:45", - "08:50", - "08:55", - "09:00", - "09:05", - "09:10", - "09:15", - "09:20", - "09:25", - "09:30", - "09:35", - "09:40", - "09:45", - "09:50", - "09:55", - "10:00", - "10:05", - "10:10", - "10:15", - "10:20", - "10:25", - "10:30", - "10:35", - "10:40", - "10:45", - "10:50", - "10:55", - "11:00", - "11:05", - "11:10", - "11:15", - "11:20", - "11:25", - "11:30", - "11:35", - "11:40", - "11:45", - "11:50", - "11:55", - "12:00", - "12:05", - "12:10", - "12:15", - "12:20", - "12:25", - "12:30", - "12:35", - "12:40", - "12:45", - "12:50", - "12:55", - "13:00", - "13:05", - "13:10", - "13:15", - "13:20", - "13:25", - "13:30", - "13:35", - "13:40", - "13:45", - "13:50", - "13:55", - "14:00", - "14:05", - "14:10", - "14:15", - "14:20", - "14:25", - "14:30", - "14:35", - "14:40", - "14:45", - "14:50", - "14:55", - "15:00", - "15:05", - "15:10", - "15:15", - "15:20", - "15:25", - "15:30", - "15:35", - "15:40", - "15:45", - "15:50", - "15:55", - "16:00", - "16:05", - "16:10", - "16:15", - "16:20", - "16:25", - "16:30", - "16:35", - "16:40", - "16:45", - "16:50", - "16:55", - "17:00", - "17:05", - "17:10", - "17:15", - "17:20", - "17:25", - "17:30", - "17:35", - "17:40", - "17:45", - "17:50", - "17:55", - "18:00", - "18:05", - "18:10", - "18:15", - "18:20", - "18:25", - "18:30", - "18:35", - "18:40", - "18:45", - "18:50", - "18:55", - "19:00", - "19:05", - "19:10", - "19:15", - "19:20", - "19:25", - "19:30", - "19:35", - "19:40", - "19:45", - "19:50", - "19:55", - "20:00", - "20:05", - "20:10", - "20:15", - "20:20", - "20:25", - "20:30", - "20:35", - "20:40", - "20:45", - "20:50", - "20:55", - "21:00", - "21:05", - "21:10", - "21:15", - "21:20", - "21:25", - "21:30", - "21:35", - "21:40", - "21:45", - "21:50", - "21:55", - "22:00", - "22:05", - "22:10", - "22:15", - "22:20", - "22:25", - "22:30", - "22:35", - "22:40", - "22:45", - "22:50", - "22:55", - "23:00", - "23:05", - "23:10", - "23:15", - "23:20", - "23:25", - "23:30", - "23:35", - "23:40", - "23:45", - "23:50", - "23:55" - ], - "restaurant-price range": [ - "none", - "do not care", - "cheap", - "expensive", - "moderate" - ], - "restaurant-food": [ - "british food", - "steakhouse", - "turkish", - "sushi", - "north american", - "scottish", - "french", - "austrian", - "korean", - "eastern european", - "swedish", - "gastro pub", - "modern eclectic", - "afternoon tea", - "welsh", - "christmas", - "tuscan", - "gastropub", - "sri lankan", - "molecular gastronomy", - "traditional american", - "italian", - "pizza", - "thai", - "south african", - "creative", - "english", - "asian", - "lebanese", - "hungarian", - "halal", - "portugese", - "modern english", - "african", - "light bites", - "malaysian", - "venetian", - "traditional", - "chinese", - "vegetarian", - "persian", - "thai and chinese", - "scandinavian", - "catalan", - "polynesian", - "crossover", - "canapes", - "cantonese", - "north african", - "seafood", - "brazilian", - "south indian", - "australasian", - "belgian", - "barbeque", - "the americas", - "indonesian", - "singaporean", - "irish", - "middle eastern", - "dojo noodle bar", - "caribbean", - "vietnamese", - "modern european", - "russian", - "none", - "german", - "world", - "japanese", - "moroccan", - "modern global", - "do not care", - "indian", - "british", - "american", - "danish", - "panasian", - "swiss", - "basque", - "north indian", - "modern american", - "australian", - "european", - "corsica", - "greek", - "northern european", - "mediterranean", - "portuguese", - "romanian", - "jamaican", - "polish", - "international", - "unusual", - "latin american", - "asian oriental", - "mexican", - "bistro", - "cuban", - "fusion", - "new zealand", - "spanish", - "eritrean", - "afghan", - "kosher" - ], - "attraction-name": [ - "downing college", - "fitzwilliam", - "clare college", - "ruskin gallery", - "sidney sussex college", - "great saint mary's church", - "cherry hinton water play park", - "wandlebury country park", - "cafe uno", - "place", - "broughton", - "cineworld cinema", - "jesus college", - "vue cinema", - "history of science museum", - "mumford theatre", - "whale of time", - "fitzbillies", - "christs church", - "churchill college", - "museum of classical archaeology", - "gonville and caius college", - "pizza", - "kirkwood", - "saint catharines college", - "kings college", - "parkside", - "by", - "st catharines college", - "saint john's college", - "cherry hinton water park", - "st christs college", - "christ's college", - "bangkok city", - "scudamores punti co", - "free", - "great saint marys church", - "milton country park", - "the fez club", - "soultree", - "autu", - "whipple museum of the history of science", - "aylesbray lodge guest house", - "broughton house gallery", - "peoples portraits exhibition", - "primavera", - "kettles yard", - "all saint's church", - "cinema cinema", - "regency gallery", - "corpus christi", - "corn cambridge exchange", - "da vinci pizzeria", - "school", - "hobsons house", - "cambride and country folk museum", - "north", - "da v", - "cambridge corn exchange", - "soul tree nightclub", - "cambridge arts theater", - "saint catharine's college", - "byard art", - "cambridge punter", - "cambridge university botanic gardens", - "castle galleries", - "museum of archaelogy and anthropogy", - "no specific location", - "cherry hinton hall", - "gallery at 12 a high street", - "parkside pools", - "queen's college", - "little saint mary's church", - "gallery", - "home from home", - "tenpin", - "the wandlebury", - "county folk museum", - "swimming pool", - "christs college", - "cafe jello museum", - "scott polar", - "christ college", - "cambridge museum of technology", - "abbey pool and astroturf pitch", - "king hedges learner pool", - "the cambridge arts theatre", - "the castle galleries", - "cambridge and country folk museum", - "kohinoor", - "scudamores punting co", - "sidney sussex", - "the man on the moon", - "little saint marys church", - "queens", - "the place", - "old school", - "churchill", - "churchills college", - "hughes hall", - "churchhill college", - "riverboat georgina", - "none", - "belf", - "cambridge temporary art", - "abc theatre", - "cambridge contemporary art museum", - "man on the moon", - "the junction", - "cherry hinton water play", - "adc theatre", - "gonville hotel", - "magdalene college", - "peoples portraits exhibition at girton college", - "boat", - "centre", - "sheep's green and lammas land park fen causeway", - "do not care", - "the mumford theatre", - "archway house", - "queens' college", - "williams art and antiques", - "funky fun house", - "cherry hinton village centre", - "camboats", - "cambridge", - "old schools", - "kettle's yard", - "whale of a time", - "the churchill college", - "cafe jello gallery", - "aut", - "salsa", - "city", - "clare hall", - "boating", - "pembroke college", - "kings hedges learner pool", - "caffe uno", - "lammas land park", - "museum", - "the fitzwilliam museum", - "the cherry hinton village centre", - "the cambridge corn exchange", - "fitzwilliam museum", - "museum of archaelogy and anthropology", - "fez club", - "the cambridge punter", - "saint johns college", - "emmanuel college", - "cambridge belf", - "scudamore", - "lynne strover gallery", - "king's college", - "whippple museum", - "trinity college", - "college in the north", - "sheep's green", - "kambar", - "museum of archaelogy", - "adc", - "garde", - "club salsa", - "people's portraits exhibition at girton college", - "botanic gardens", - "carol", - "college", - "gallery at twelve a high street", - "abbey pool and astroturf", - "cambridge book and print gallery", - "jesus green outdoor pool", - "scott polar museum", - "saint barnabas press gallery", - "cambridge artworks", - "older churches", - "cambridge contemporary art", - "cherry hinton hall and grounds", - "univ", - "jesus green", - "ballare", - "abbey pool", - "cambridge botanic gardens", - "nusha", - "worth house", - "thanh", - "university arms hotel", - "cambridge arts theatre", - "cafe jello", - "cambridge and county folk museum", - "the cambridge artworks", - "all saints church", - "holy trinity church", - "contemporary art museum", - "architectural churches", - "queens college", - "trinity street college" - ], - "restaurant-name": [ - "none", - "do not care", - "hotel du vin and bistro", - "ask", - "gourmet formal kitchen", - "the meze bar", - "lan hong house", - "cow pizza", - "one seven", - "prezzo", - "maharajah tandoori restaurant", - "alex", - "shanghai", - "golden wok", - "restaurant", - "fitzbillies", - "nil", - "copper kettle", - "meghna", - "hk fusion", - "bangkok city", - "hobsons house", - "tang chinese", - "anatolia", - "ugly duckling", - "anatolia and efes restaurant", - "sitar tandoori", - "city stop", - "ashley", - "pizza express fen ditton", - "molecular gastronomy", - "autumn house", - "el shaddia guesthouse", - "the grafton hotel", - "limehouse", - "gardenia", - "not metioned", - "hakka", - "michaelhouse cafe", - "pipasha", - "meze bar", - "archway", - "molecular gastonomy", - "yipee noodle bar", - "the peking", - "curry prince", - "midsummer house restaurant", - "pizza hut cherry hinton", - "the lucky star", - "stazione restaurant and coffee bar", - "shanghi family restaurant", - "good luck", - "j restaurant", - "bedouin", - "cott", - "little seoul", - "south", - "thanh binh", - "el", - "efes restaurant", - "kohinoor", - "clowns", - "india", - "the slug and lettuce", - "shiraz", - "barbakan", - "zizzi cambridge", - "restaurant one seven", - "slug and lettuce", - "travellers rest", - "binh", - "worth house", - "broughton house gallery", - "chiquito", - "the river bar steakhouse and grill", - "ros", - "golden house", - "india west", - "cam", - "panahar", - "restaurant 22", - "adden", - "indian", - "hu", - "jinling noodle bar", - "darrys cookhouse and wine shop", - "hobson house", - "cambridge be", - "el shaddai", - "ac", - "nandos", - "cambridge lodge", - "the cow pizza kitchen and bar", - "charlie", - "rajmahal", - "kymmoy", - "cambri", - "backstreet bistro", - "galleria", - "restaurant 2 two", - "chiquito restaurant bar", - "royal standard", - "lucky star", - "curry king", - "grafton hotel restaurant", - "mahal of cambridge", - "the bedouin", - "nus", - "the kohinoor", - "pizza hut fenditton", - "camboats", - "the gardenia", - "de luca cucina and bar", - "nusha", - "european", - "taj tandoori", - "tandoori palace", - "golden curry", - "efes", - "loch fyne", - "the maharajah tandoor", - "lovel", - "restaurant 17", - "clowns cafe", - "cambridge punter", - "bloomsbury restaurant", - "la mimosa", - "the cambridge chop house", - "funky", - "cotto", - "oak bistro", - "restaurant two two", - "pipasha restaurant", - "river bar steakhouse and grill", - "royal spice", - "the copper kettle", - "graffiti", - "nandos city centre", - "saffron brasserie", - "cambridge chop house", - "sitar", - "kitchen and bar", - "the good luck chinese food takeaway", - "clu", - "la tasca", - "cafe uno", - "cote", - "the varsity restaurant", - "bri", - "eraina", - "bridge", - "fin", - "cambridge lodge restaurant", - "grafton", - "hotpot", - "sala thong", - "margherita", - "wise buddha", - "the missing sock", - "seasame restaurant and bar", - "the dojo noodle bar", - "restaurant alimentum", - "gastropub", - "saigon city", - "la margherita", - "pizza hut", - "curry garden", - "ashley hotel", - "eraina and michaelhouse cafe", - "the golden curry", - "curry queen", - "cow pizza kitchen and bar", - "the peking restaurant:", - "hamilton lodge", - "alimentum", - "yippee noodle bar", - "2 two and cote", - "shanghai family restaurant", - "grafton hotel", - "yes", - "ali baba", - "dif", - "fitzbillies restaurant", - "peking restaurant", - "lev", - "nirala", - "the alex", - "tandoori", - "city stop restaurant", - "rice house", - "cityr", - "yu garden", - "meze bar restaurant", - "the", - "don pasquale pizzeria", - "rice boat", - "the hotpot", - "old school", - "the oak bistro", - "sesame restaurant and bar", - "pizza express", - "the gandhi", - "pizza hut fen ditton", - "charlie chan", - "da vinci pizzeria", - "dojo noodle bar", - "gourmet burger kitchen", - "the golden house", - "india house", - "hobso", - "missing sock", - "pizza hut city centre", - "parkside pools", - "riverside brasserie", - "caffe uno", - "primavera", - "the nirala", - "wagamama", - "au", - "ian hong house", - "frankie and bennys", - "4 kings parade city centre", - "shiraz restaurant", - "scudamores punt", - "mahal", - "saint johns chop house", - "de luca cucina and bar riverside brasserie", - "cocum", - "la raza" - ], - "attraction-type": [ - "none", - "do not care", - "architecture", - "boat", - "boating", - "camboats", - "church", - "churchills college", - "cinema", - "college", - "concert", - "concerthall", - "entertainment", - "gallery", - "gastropub", - "hiking", - "hotel", - "multiple sports", - "museum", - "museum kettles yard", - "night club", - "outdoor", - "park", - "pool", - "special", - "sports", - "swimming pool", - "theater", - "theatre", - "concert hall", - "local site", - "nightclub", - "hotspot" - ], - "taxi-leave at": [ - "none", - "do not care", - "00:00", - "00:05", - "00:10", - "00:15", - "00:20", - "00:25", - "00:30", - "00:35", - "00:40", - "00:45", - "00:50", - "00:55", - "01:00", - "01:05", - "01:10", - "01:15", - "01:20", - "01:25", - "01:30", - "01:35", - "01:40", - "01:45", - "01:50", - "01:55", - "02:00", - "02:05", - "02:10", - "02:15", - "02:20", - "02:25", - "02:30", - "02:35", - "02:40", - "02:45", - "02:50", - "02:55", - "03:00", - "03:05", - "03:10", - "03:15", - "03:20", - "03:25", - "03:30", - "03:35", - "03:40", - "03:45", - "03:50", - "03:55", - "04:00", - "04:05", - "04:10", - "04:15", - "04:20", - "04:25", - "04:30", - "04:35", - "04:40", - "04:45", - "04:50", - "04:55", - "05:00", - "05:05", - "05:10", - "05:15", - "05:20", - "05:25", - "05:30", - "05:35", - "05:40", - "05:45", - "05:50", - "05:55", - "06:00", - "06:05", - "06:10", - "06:15", - "06:20", - "06:25", - "06:30", - "06:35", - "06:40", - "06:45", - "06:50", - "06:55", - "07:00", - "07:05", - "07:10", - "07:15", - "07:20", - "07:25", - "07:30", - "07:35", - "07:40", - "07:45", - "07:50", - "07:55", - "08:00", - "08:05", - "08:10", - "08:15", - "08:20", - "08:25", - "08:30", - "08:35", - "08:40", - "08:45", - "08:50", - "08:55", - "09:00", - "09:05", - "09:10", - "09:15", - "09:20", - "09:25", - "09:30", - "09:35", - "09:40", - "09:45", - "09:50", - "09:55", - "10:00", - "10:05", - "10:10", - "10:15", - "10:20", - "10:25", - "10:30", - "10:35", - "10:40", - "10:45", - "10:50", - "10:55", - "11:00", - "11:05", - "11:10", - "11:15", - "11:20", - "11:25", - "11:30", - "11:35", - "11:40", - "11:45", - "11:50", - "11:55", - "12:00", - "12:05", - "12:10", - "12:15", - "12:20", - "12:25", - "12:30", - "12:35", - "12:40", - "12:45", - "12:50", - "12:55", - "13:00", - "13:05", - "13:10", - "13:15", - "13:20", - "13:25", - "13:30", - "13:35", - "13:40", - "13:45", - "13:50", - "13:55", - "14:00", - "14:05", - "14:10", - "14:15", - "14:20", - "14:25", - "14:30", - "14:35", - "14:40", - "14:45", - "14:50", - "14:55", - "15:00", - "15:05", - "15:10", - "15:15", - "15:20", - "15:25", - "15:30", - "15:35", - "15:40", - "15:45", - "15:50", - "15:55", - "16:00", - "16:05", - "16:10", - "16:15", - "16:20", - "16:25", - "16:30", - "16:35", - "16:40", - "16:45", - "16:50", - "16:55", - "17:00", - "17:05", - "17:10", - "17:15", - "17:20", - "17:25", - "17:30", - "17:35", - "17:40", - "17:45", - "17:50", - "17:55", - "18:00", - "18:05", - "18:10", - "18:15", - "18:20", - "18:25", - "18:30", - "18:35", - "18:40", - "18:45", - "18:50", - "18:55", - "19:00", - "19:05", - "19:10", - "19:15", - "19:20", - "19:25", - "19:30", - "19:35", - "19:40", - "19:45", - "19:50", - "19:55", - "20:00", - "20:05", - "20:10", - "20:15", - "20:20", - "20:25", - "20:30", - "20:35", - "20:40", - "20:45", - "20:50", - "20:55", - "21:00", - "21:05", - "21:10", - "21:15", - "21:20", - "21:25", - "21:30", - "21:35", - "21:40", - "21:45", - "21:50", - "21:55", - "22:00", - "22:05", - "22:10", - "22:15", - "22:20", - "22:25", - "22:30", - "22:35", - "22:40", - "22:45", - "22:50", - "22:55", - "23:00", - "23:05", - "23:10", - "23:15", - "23:20", - "23:25", - "23:30", - "23:35", - "23:40", - "23:45", - "23:50", - "23:55" - ], - "taxi-destination": [ - "none", - "do not care", - "a and b guest house", - "abbey pool and astroturf pitch", - "acorn guest house", - "adc theatre", - "addenbrookes hospital", - "alexander bed and breakfast", - "ali baba", - "all saints church", - "allenbell", - "alpha milton guest house", - "anatolia", - "arbury lodge guesthouse", - "archway house", - "ashley hotel", - "ask", - "attraction", - "autumn house", - "avalon", - "aylesbray lodge guest house", - "backstreet bistro", - "ballare", - "bangkok city", - "bedouin", - "birmingham new street train station", - "bishops stortford train station", - "bloomsbury restaurant", - "bridge guest house", - "broughton house gallery", - "broxbourne train station", - "byard art", - "cafe jello gallery", - "cafe uno", - "camboats", - "cambridge", - "cambridge and county folk museum", - "cambridge arts theatre", - "cambridge artworks", - "cambridge belfry", - "cambridge book and print gallery", - "cambridge chop house", - "cambridge contemporary art", - "cambridge county fair next to the city tourist museum", - "cambridge lodge restaurant", - "cambridge museum of technology", - "cambridge punter", - "cambridge road church of christ", - "cambridge train station", - "cambridge university botanic gardens", - "carolina bed and breakfast", - "castle galleries", - "charlie chan", - "cherry hinton hall and grounds", - "cherry hinton village centre", - "cherry hinton water park", - "cherry hinton water play", - "chiquito restaurant bar", - "christ college", - "churchills college", - "cineworld cinema", - "city centre north bed and breakfast", - "city stop restaurant", - "cityroomz", - "clare college", - "clare hall", - "clowns cafe", - "club salsa", - "cocum", - "copper kettle", - "corpus christi", - "cote", - "cotto", - "cow pizza kitchen and bar", - "curry garden", - "curry king", - "curry prince", - "da vinci pizzeria", - "darrys cookhouse and wine shop", - "de luca cucina and bar", - "dojo noodle bar", - "don pasquale pizzeria", - "downing college", - "efes restaurant", - "el shaddia guesthouse", - "ely train station", - "emmanuel college", - "eraina", - "express by holiday inn cambridge", - "finches bed and breakfast", - "finders corner newmarket road", - "fitzbillies restaurant", - "fitzwilliam museum", - "frankie and bennys", - "funky fun house", - "galleria", - "gallery at 12 a high street", - "gastropub", - "golden curry", - "golden house", - "golden wok", - "gonville and caius college", - "gonville hotel", - "good luck", - "gourmet burger kitchen", - "graffiti", - "grafton hotel restaurant", - "great saint marys church", - "hakka", - "hamilton lodge", - "hk fusion", - "hobsons house", - "holy trinity church", - "home from home", - "hotel du vin and bistro", - "hughes hall", - "huntingdon marriott hotel", - "ian hong", - "india house", - "j restaurant", - "jesus college", - "jesus green outdoor pool", - "jinling noodle bar", - "kambar", - "kettles yard", - "kings college", - "kings hedges learner pool", - "kirkwood house", - "kohinoor", - "kymmoy", - "la margherita", - "la mimosa", - "la raza", - "la tasca", - "lan hong house", - "leicester train station", - "lensfield hotel", - "limehouse", - "little saint marys church", - "little seoul", - "loch fyne", - "london kings cross train station", - "london liverpool street train station", - "lovell lodge", - "lynne strover gallery", - "magdalene college", - "mahal of cambridge", - "maharajah tandoori restaurant", - "meghna", - "meze bar", - "michaelhouse cafe", - "midsummer house restaurant", - "milton country park", - "mumford theatre", - "museum of archaelogy and anthropology", - "museum of classical archaeology", - "nandos", - "nandos city centre", - "nil", - "nirala", - "norwich train station", - "nusha", - "old schools", - "panahar", - "parkside police station", - "parkside pools", - "peking restaurant", - "pembroke college", - "peoples portraits exhibition at girton college", - "peterborough train station", - "pipasha restaurant", - "pizza express", - "pizza hut cherry hinton", - "pizza hut city centre", - "pizza hut fenditton", - "prezzo", - "primavera", - "queens college", - "rajmahal", - "regency gallery", - "restaurant 17", - "restaurant 2 two", - "restaurant alimentum", - "rice boat", - "rice house", - "riverboat georgina", - "riverside brasserie", - "rosas bed and breakfast", - "royal spice", - "royal standard", - "ruskin gallery", - "saffron brasserie", - "saigon city", - "saint barnabas", - "saint barnabas press gallery", - "saint catharines college", - "saint johns chop house", - "saint johns college", - "sala thong", - "scott polar museum", - "scudamores punting co", - "sesame restaurant and bar", - "shanghai family restaurant", - "sheeps green and lammas land park fen causeway", - "shiraz", - "sidney sussex college", - "sitar tandoori", - "sleeperz hotel", - "soul tree nightclub", - "st johns chop house", - "stansted airport train station", - "station road", - "stazione restaurant and coffee bar", - "stevenage train station", - "taj tandoori", - "tall monument", - "tandoori palace", - "tang chinese", - "tenpin", - "thanh binh", - "the anatolia", - "the cambridge corn exchange", - "the cambridge shop", - "the fez club", - "the gandhi", - "the gardenia", - "the hotpot", - "the junction", - "the lucky star", - "the man on the moon", - "the missing sock", - "the oak bistro", - "the place", - "the regent street city center", - "the river bar steakhouse and grill", - "the slug and lettuce", - "the varsity restaurant", - "travellers rest", - "trinity college", - "ugly duckling", - "university arms hotel", - "vue cinema", - "wagamama", - "wandlebury country park", - "wankworth hotel", - "warkworth house", - "whale of a time", - "whipple museum of the history of science", - "williams art and antiques", - "worth house", - "yippee noodle bar", - "yu garden", - "zizzi cambridge", - "leverton house", - "the cambridge chop house", - "saint john's college", - "churchill college", - "the nirala", - "the cow pizza kitchen and bar", - "christ's college", - "el shaddai", - "saint catharine's college", - "camb", - "the golden curry", - "little saint mary's church", - "country folk museum", - "meze bar restaurant", - "the cambridge belfry", - "the fitzwilliam museum", - "the lensfield hotel", - "pizza express fen ditton", - "the cambridge punter", - "king's college", - "the cherry hinton village centre", - "shiraz restaurant", - "sheep's green and lammas land park fen causeway", - "caffe uno", - "the ghandi", - "the copper kettle", - "man on the moon concert hall", - "alpha-milton guest house", - "queen's college", - "restaurant one seven", - "restaurant two two", - "city centre north b and b", - "rosa's bed and breakfast", - "the good luck chinese food takeaway", - "not museum of archaeology and anthropologymentioned", - "tandori in cambridge", - "kettle's yard", - "megna", - "grou", - "gallery at twelve a high street", - "maharajah tandoori restaurant", - "pizza hut fen ditton", - "gandhi", - "tranh binh", - "kambur", - "people's portraits exhibition at girton college", - "hotel", - "restaurant", - "the galleria", - "queens' college", - "great saint mary's church", - "theathre", - "cambridge artworks", - "acorn house", - "shiraz", - "riverboat georginawd", - "mic", - "the gallery at twelve", - "the soul tree", - "finches" - ], - "taxi-departure": [ - "none", - "do not care", - "172 chestertown road", - "4455 woodbridge road", - "a and b guest house", - "abbey pool and astroturf pitch", - "acorn guest house", - "adc theatre", - "addenbrookes hospital", - "alexander bed and breakfast", - "ali baba", - "all saints church", - "allenbell", - "alpha milton guest house", - "alyesbray lodge hotel", - "ambridge", - "anatolia", - "arbury lodge guesthouse", - "archway house", - "ashley hotel", - "ask", - "autumn house", - "avalon", - "aylesbray lodge guest house", - "backstreet bistro", - "ballare", - "bangkok city", - "bedouin", - "birmingham new street train station", - "bishops stortford train station", - "bloomsbury restaurant", - "bridge guest house", - "broughton house gallery", - "broxbourne train station", - "byard art", - "cafe jello gallery", - "cafe uno", - "caffee uno", - "camboats", - "cambridge", - "cambridge and county folk museum", - "cambridge arts theatre", - "cambridge artworks", - "cambridge belfry", - "cambridge book and print gallery", - "cambridge chop house", - "cambridge contemporary art", - "cambridge lodge restaurant", - "cambridge museum of technology", - "cambridge punter", - "cambridge towninfo centre", - "cambridge train station", - "cambridge university botanic gardens", - "carolina bed and breakfast", - "castle galleries", - "centre of town at my hotel", - "charlie chan", - "cherry hinton hall and grounds", - "cherry hinton village center", - "cherry hinton village centre", - "cherry hinton water play", - "chiquito restaurant bar", - "christ college", - "churchills college", - "cineworld cinema", - "citiroomz", - "city centre north bed and breakfast", - "city stop restaurant", - "cityroomz", - "clair hall", - "clare college", - "clare hall", - "clowns cafe", - "club salsa", - "cocum", - "copper kettle", - "corpus christi", - "cote", - "cotto", - "cow pizza kitchen and bar", - "curry garden", - "curry king", - "curry prince", - "curry queen", - "da vinci pizzeria", - "darrys cookhouse and wine shop", - "de luca cucina and bar", - "dojo noodle bar", - "don pasquale pizzeria", - "downing college", - "downing street", - "el shaddia guesthouse", - "ely", - "ely train station", - "emmanuel college", - "eraina", - "express by holiday inn cambridge", - "finches bed and breakfast", - "fitzbillies restaurant", - "fitzwilliam museum", - "frankie and bennys", - "funky fun house", - "galleria", - "gallery at 12 a high street", - "girton college", - "golden curry", - "golden house", - "golden wok", - "gonville and caius college", - "gonville hotel", - "good luck", - "gourmet burger kitchen", - "graffiti", - "grafton hotel restaurant", - "great saint marys church", - "hakka", - "hamilton lodge", - "hobsons house", - "holy trinity church", - "home", - "home from home", - "hotel", - "hotel du vin and bistro", - "hughes hall", - "huntingdon marriott hotel", - "india house", - "j restaurant", - "jesus college", - "jesus green outdoor pool", - "jinling noodle bar", - "junction theatre", - "kambar", - "kettles yard", - "kings college", - "kings hedges learner pool", - "kings lynn train station", - "kirkwood house", - "kohinoor", - "kymmoy", - "la margherita", - "la mimosa", - "la raza", - "la tasca", - "lan hong house", - "lensfield hotel", - "leverton house", - "limehouse", - "little saint marys church", - "little seoul", - "loch fyne", - "london kings cross train station", - "london liverpool street", - "london liverpool street train station", - "lovell lodge", - "lynne strover gallery", - "magdalene college", - "mahal of cambridge", - "maharajah tandoori restaurant", - "meghna", - "meze bar", - "michaelhouse cafe", - "milton country park", - "mumford theatre", - "museum", - "museum of archaelogy and anthropology", - "museum of classical archaeology", - "nandos", - "nandos city centre", - "new england", - "nirala", - "norwich train station", - "nstaot mentioned", - "nusha", - "old schools", - "panahar", - "parkside police station", - "parkside pools", - "peking restaurant", - "pembroke college", - "peoples portraits exhibition at girton college", - "peterborough train station", - "pizza express", - "pizza hut cherry hinton", - "pizza hut city centre", - "pizza hut fenditton", - "prezzo", - "primavera", - "queens college", - "rajmahal", - "regency gallery", - "restaurant 17", - "restaurant 2 two", - "restaurant alimentum", - "rice boat", - "rice house", - "riverboat georgina", - "riverside brasserie", - "rosas bed and breakfast", - "royal spice", - "royal standard", - "ruskin gallery", - "saffron brasserie", - "saigon city", - "saint barnabas press gallery", - "saint catharines college", - "saint johns chop house", - "saint johns college", - "sala thong", - "scott polar museum", - "scudamores punting co", - "sesame restaurant and bar", - "sheeps green and lammas land park", - "sheeps green and lammas land park fen causeway", - "shiraz", - "sidney sussex college", - "sitar tandoori", - "soul tree nightclub", - "st johns college", - "stazione restaurant and coffee bar", - "stevenage train station", - "taj tandoori", - "tandoori palace", - "tang chinese", - "tenpin", - "thanh binh", - "the cambridge corn exchange", - "the fez club", - "the gallery at 12", - "the gandhi", - "the gardenia", - "the hotpot", - "the junction", - "the lucky star", - "the man on the moon", - "the missing sock", - "the oak bistro", - "the place", - "the river bar steakhouse and grill", - "the slug and lettuce", - "the varsity restaurant", - "travellers rest", - "trinity college", - "ugly duckling", - "university arms hotel", - "vue cinema", - "wagamama", - "wandlebury country park", - "warkworth house", - "whale of a time", - "whipple museum of the history of science", - "williams art and antiques", - "worth house", - "yippee noodle bar", - "yu garden", - "zizzi cambridge", - "christ's college", - "city centre north b and b", - "the lensfield hotel", - "alpha-milton guest house", - "el shaddai", - "churchill college", - "the cambridge belfry", - "king's college", - "great saint mary's church", - "restaurant two two", - "queens' college", - "little saint mary's church", - "chinese city centre", - "kettle's yard", - "pizza hut", - "the golden curry", - "rosa's bed and breakfast", - "the cambridge punter", - "the byard art museum", - "saint catharine's college", - "meze bar restaurant", - "the good luck chinese food takeaway", - "restaurant one seven", - "pizza hut fen ditton", - "the nirala", - "the fitzwilliam museum", - "st. john's college", - "gallery at twelve a high street", - "sheep's green and lammas land park fen causeway", - "the cherry hinton village centre", - "pizza express fen ditton", - "corpus cristi", - "cas", - "acorn house", - "lens", - "the cambridge chop house", - "the copper kettle", - "the avalon", - "saint john's college", - "aylesbray lodge", - "the alexander bed and breakfast", - "cambridge belfy", - "people's portraits exhibition at girton college", - "gonville", - "caffe uno", - "the cow pizza kitchen and bar", - "lovell ldoge", - "cinema", - "shiraz restaurant", - "park", - "the allenbell" - ], - "restaurant-book day": [ - "none", - "do not care", - "friday", - "monday", - "saterday", - "sunday", - "thursday", - "tuesday", - "wednesday" - ], - "restaurant-book people": [ - "none", - "do not care", - "1", - "10 or more", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9" - ], - "restaurant-book time": [ - "none", - "do not care", - "00:00", - "00:05", - "00:10", - "00:15", - "00:20", - "00:25", - "00:30", - "00:35", - "00:40", - "00:45", - "00:50", - "00:55", - "01:00", - "01:05", - "01:10", - "01:15", - "01:20", - "01:25", - "01:30", - "01:35", - "01:40", - "01:45", - "01:50", - "01:55", - "02:00", - "02:05", - "02:10", - "02:15", - "02:20", - "02:25", - "02:30", - "02:35", - "02:40", - "02:45", - "02:50", - "02:55", - "03:00", - "03:05", - "03:10", - "03:15", - "03:20", - "03:25", - "03:30", - "03:35", - "03:40", - "03:45", - "03:50", - "03:55", - "04:00", - "04:05", - "04:10", - "04:15", - "04:20", - "04:25", - "04:30", - "04:35", - "04:40", - "04:45", - "04:50", - "04:55", - "05:00", - "05:05", - "05:10", - "05:15", - "05:20", - "05:25", - "05:30", - "05:35", - "05:40", - "05:45", - "05:50", - "05:55", - "06:00", - "06:05", - "06:10", - "06:15", - "06:20", - "06:25", - "06:30", - "06:35", - "06:40", - "06:45", - "06:50", - "06:55", - "07:00", - "07:05", - "07:10", - "07:15", - "07:20", - "07:25", - "07:30", - "07:35", - "07:40", - "07:45", - "07:50", - "07:55", - "08:00", - "08:05", - "08:10", - "08:15", - "08:20", - "08:25", - "08:30", - "08:35", - "08:40", - "08:45", - "08:50", - "08:55", - "09:00", - "09:05", - "09:10", - "09:15", - "09:20", - "09:25", - "09:30", - "09:35", - "09:40", - "09:45", - "09:50", - "09:55", - "10:00", - "10:05", - "10:10", - "10:15", - "10:20", - "10:25", - "10:30", - "10:35", - "10:40", - "10:45", - "10:50", - "10:55", - "11:00", - "11:05", - "11:10", - "11:15", - "11:20", - "11:25", - "11:30", - "11:35", - "11:40", - "11:45", - "11:50", - "11:55", - "12:00", - "12:05", - "12:10", - "12:15", - "12:20", - "12:25", - "12:30", - "12:35", - "12:40", - "12:45", - "12:50", - "12:55", - "13:00", - "13:05", - "13:10", - "13:15", - "13:20", - "13:25", - "13:30", - "13:35", - "13:40", - "13:45", - "13:50", - "13:55", - "14:00", - "14:05", - "14:10", - "14:15", - "14:20", - "14:25", - "14:30", - "14:35", - "14:40", - "14:45", - "14:50", - "14:55", - "15:00", - "15:05", - "15:10", - "15:15", - "15:20", - "15:25", - "15:30", - "15:35", - "15:40", - "15:45", - "15:50", - "15:55", - "16:00", - "16:05", - "16:10", - "16:15", - "16:20", - "16:25", - "16:30", - "16:35", - "16:40", - "16:45", - "16:50", - "16:55", - "17:00", - "17:05", - "17:10", - "17:15", - "17:20", - "17:25", - "17:30", - "17:35", - "17:40", - "17:45", - "17:50", - "17:55", - "18:00", - "18:05", - "18:10", - "18:15", - "18:20", - "18:25", - "18:30", - "18:35", - "18:40", - "18:45", - "18:50", - "18:55", - "19:00", - "19:05", - "19:10", - "19:15", - "19:20", - "19:25", - "19:30", - "19:35", - "19:40", - "19:45", - "19:50", - "19:55", - "20:00", - "20:05", - "20:10", - "20:15", - "20:20", - "20:25", - "20:30", - "20:35", - "20:40", - "20:45", - "20:50", - "20:55", - "21:00", - "21:05", - "21:10", - "21:15", - "21:20", - "21:25", - "21:30", - "21:35", - "21:40", - "21:45", - "21:50", - "21:55", - "22:00", - "22:05", - "22:10", - "22:15", - "22:20", - "22:25", - "22:30", - "22:35", - "22:40", - "22:45", - "22:50", - "22:55", - "23:00", - "23:05", - "23:10", - "23:15", - "23:20", - "23:25", - "23:30", - "23:35", - "23:40", - "23:45", - "23:50", - "23:55" - ], - "taxi-arrive by": [ - "none", - "do not care", - "00:00", - "00:05", - "00:10", - "00:15", - "00:20", - "00:25", - "00:30", - "00:35", - "00:40", - "00:45", - "00:50", - "00:55", - "01:00", - "01:05", - "01:10", - "01:15", - "01:20", - "01:25", - "01:30", - "01:35", - "01:40", - "01:45", - "01:50", - "01:55", - "02:00", - "02:05", - "02:10", - "02:15", - "02:20", - "02:25", - "02:30", - "02:35", - "02:40", - "02:45", - "02:50", - "02:55", - "03:00", - "03:05", - "03:10", - "03:15", - "03:20", - "03:25", - "03:30", - "03:35", - "03:40", - "03:45", - "03:50", - "03:55", - "04:00", - "04:05", - "04:10", - "04:15", - "04:20", - "04:25", - "04:30", - "04:35", - "04:40", - "04:45", - "04:50", - "04:55", - "05:00", - "05:05", - "05:10", - "05:15", - "05:20", - "05:25", - "05:30", - "05:35", - "05:40", - "05:45", - "05:50", - "05:55", - "06:00", - "06:05", - "06:10", - "06:15", - "06:20", - "06:25", - "06:30", - "06:35", - "06:40", - "06:45", - "06:50", - "06:55", - "07:00", - "07:05", - "07:10", - "07:15", - "07:20", - "07:25", - "07:30", - "07:35", - "07:40", - "07:45", - "07:50", - "07:55", - "08:00", - "08:05", - "08:10", - "08:15", - "08:20", - "08:25", - "08:30", - "08:35", - "08:40", - "08:45", - "08:50", - "08:55", - "09:00", - "09:05", - "09:10", - "09:15", - "09:20", - "09:25", - "09:30", - "09:35", - "09:40", - "09:45", - "09:50", - "09:55", - "10:00", - "10:05", - "10:10", - "10:15", - "10:20", - "10:25", - "10:30", - "10:35", - "10:40", - "10:45", - "10:50", - "10:55", - "11:00", - "11:05", - "11:10", - "11:15", - "11:20", - "11:25", - "11:30", - "11:35", - "11:40", - "11:45", - "11:50", - "11:55", - "12:00", - "12:05", - "12:10", - "12:15", - "12:20", - "12:25", - "12:30", - "12:35", - "12:40", - "12:45", - "12:50", - "12:55", - "13:00", - "13:05", - "13:10", - "13:15", - "13:20", - "13:25", - "13:30", - "13:35", - "13:40", - "13:45", - "13:50", - "13:55", - "14:00", - "14:05", - "14:10", - "14:15", - "14:20", - "14:25", - "14:30", - "14:35", - "14:40", - "14:45", - "14:50", - "14:55", - "15:00", - "15:05", - "15:10", - "15:15", - "15:20", - "15:25", - "15:30", - "15:35", - "15:40", - "15:45", - "15:50", - "15:55", - "16:00", - "16:05", - "16:10", - "16:15", - "16:20", - "16:25", - "16:30", - "16:35", - "16:40", - "16:45", - "16:50", - "16:55", - "17:00", - "17:05", - "17:10", - "17:15", - "17:20", - "17:25", - "17:30", - "17:35", - "17:40", - "17:45", - "17:50", - "17:55", - "18:00", - "18:05", - "18:10", - "18:15", - "18:20", - "18:25", - "18:30", - "18:35", - "18:40", - "18:45", - "18:50", - "18:55", - "19:00", - "19:05", - "19:10", - "19:15", - "19:20", - "19:25", - "19:30", - "19:35", - "19:40", - "19:45", - "19:50", - "19:55", - "20:00", - "20:05", - "20:10", - "20:15", - "20:20", - "20:25", - "20:30", - "20:35", - "20:40", - "20:45", - "20:50", - "20:55", - "21:00", - "21:05", - "21:10", - "21:15", - "21:20", - "21:25", - "21:30", - "21:35", - "21:40", - "21:45", - "21:50", - "21:55", - "22:00", - "22:05", - "22:10", - "22:15", - "22:20", - "22:25", - "22:30", - "22:35", - "22:40", - "22:45", - "22:50", - "22:55", - "23:00", - "23:05", - "23:10", - "23:15", - "23:20", - "23:25", - "23:30", - "23:35", - "23:40", - "23:45", - "23:50", - "23:55" - ], - "restaurant-area": [ - "none", - "do not care", - "centre", - "east", - "north", - "south", - "west" - ], - "hotel-area": [ - "none", - "do not care", - "centre", - "east", - "north", - "south", - "west" - ], - "attraction-area": [ - "none", - "do not care", - "centre", - "east", - "north", - "south", - "west" - ] -} \ No newline at end of file diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json deleted file mode 100644 index b0dd00fd..00000000 --- a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json +++ /dev/null @@ -1,3128 +0,0 @@ -{ - "hotel-price range": [ - "none", - "do not care", - "cheap", - "expensive", - "moderate", - "request" - ], - "hotel-type": [ - "none", - "do not care", - "bed and breakfast", - "guest house", - "hotel", - "request" - ], - "hotel-parking": [ - "none", - "do not care", - "no", - "yes", - "request" - ], - "hotel-book day": [ - "none", - "do not care", - "friday", - "monday", - "saterday", - "sunday", - "thursday", - "tuesday", - "wednesday" - ], - "hotel-book people": [ - "none", - "do not care", - "1", - "10 or more", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9" - ], - "hotel-book stay": [ - "none", - "do not care", - "1", - "10 or more", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9" - ], - "train-destination": [ - "none", - "do not care", - "bishops stortford", - "kings lynn", - "london liverpool street", - "centre", - "bishop stortford", - "liverpool", - "leicester", - "broxbourne", - "gourmet burger kitchen", - "copper kettle", - "bournemouth", - "stevenage", - "liverpool street", - "norwich", - "huntingdon marriott hotel", - "city centre north", - "taj tandoori", - "the copper kettle", - "peterborough", - "ely", - "lecester", - "london", - "willi", - "stansted airport", - "huntington marriott", - "cambridge", - "gonv", - "glastonbury", - "hol", - "north", - "birmingham new street", - "norway", - "petersborough", - "london kings cross", - "curry prince", - "bishops storford" - ], - "train-arrive by": [ - "none", - "do not care", - "00:00", - "00:05", - "00:10", - "00:15", - "00:20", - "00:25", - "00:30", - "00:35", - "00:40", - "00:45", - "00:50", - "00:55", - "01:00", - "01:05", - "01:10", - "01:15", - "01:20", - "01:25", - "01:30", - "01:35", - "01:40", - "01:45", - "01:50", - "01:55", - "02:00", - "02:05", - "02:10", - "02:15", - "02:20", - "02:25", - "02:30", - "02:35", - "02:40", - "02:45", - "02:50", - "02:55", - "03:00", - "03:05", - "03:10", - "03:15", - "03:20", - "03:25", - "03:30", - "03:35", - "03:40", - "03:45", - "03:50", - "03:55", - "04:00", - "04:05", - "04:10", - "04:15", - "04:20", - "04:25", - "04:30", - "04:35", - "04:40", - "04:45", - "04:50", - "04:55", - "05:00", - "05:05", - "05:10", - "05:15", - "05:20", - "05:25", - "05:30", - "05:35", - "05:40", - "05:45", - "05:50", - "05:55", - "06:00", - "06:05", - "06:10", - "06:15", - "06:20", - "06:25", - "06:30", - "06:35", - "06:40", - "06:45", - "06:50", - "06:55", - "07:00", - "07:05", - "07:10", - "07:15", - "07:20", - "07:25", - "07:30", - "07:35", - "07:40", - "07:45", - "07:50", - "07:55", - "08:00", - "08:05", - "08:10", - "08:15", - "08:20", - "08:25", - "08:30", - "08:35", - "08:40", - "08:45", - "08:50", - "08:55", - "09:00", - "09:05", - "09:10", - "09:15", - "09:20", - "09:25", - "09:30", - "09:35", - "09:40", - "09:45", - "09:50", - "09:55", - "10:00", - "10:05", - "10:10", - "10:15", - "10:20", - "10:25", - "10:30", - "10:35", - "10:40", - "10:45", - "10:50", - "10:55", - "11:00", - "11:05", - "11:10", - "11:15", - "11:20", - "11:25", - "11:30", - "11:35", - "11:40", - "11:45", - "11:50", - "11:55", - "12:00", - "12:05", - "12:10", - "12:15", - "12:20", - "12:25", - "12:30", - "12:35", - "12:40", - "12:45", - "12:50", - "12:55", - "13:00", - "13:05", - "13:10", - "13:15", - "13:20", - "13:25", - "13:30", - "13:35", - "13:40", - "13:45", - "13:50", - "13:55", - "14:00", - "14:05", - "14:10", - "14:15", - "14:20", - "14:25", - "14:30", - "14:35", - "14:40", - "14:45", - "14:50", - "14:55", - "15:00", - "15:05", - "15:10", - "15:15", - "15:20", - "15:25", - "15:30", - "15:35", - "15:40", - "15:45", - "15:50", - "15:55", - "16:00", - "16:05", - "16:10", - "16:15", - "16:20", - "16:25", - "16:30", - "16:35", - "16:40", - "16:45", - "16:50", - "16:55", - "17:00", - "17:05", - "17:10", - "17:15", - "17:20", - "17:25", - "17:30", - "17:35", - "17:40", - "17:45", - "17:50", - "17:55", - "18:00", - "18:05", - "18:10", - "18:15", - "18:20", - "18:25", - "18:30", - "18:35", - "18:40", - "18:45", - "18:50", - "18:55", - "19:00", - "19:05", - "19:10", - "19:15", - "19:20", - "19:25", - "19:30", - "19:35", - "19:40", - "19:45", - "19:50", - "19:55", - "20:00", - "20:05", - "20:10", - "20:15", - "20:20", - "20:25", - "20:30", - "20:35", - "20:40", - "20:45", - "20:50", - "20:55", - "21:00", - "21:05", - "21:10", - "21:15", - "21:20", - "21:25", - "21:30", - "21:35", - "21:40", - "21:45", - "21:50", - "21:55", - "22:00", - "22:05", - "22:10", - "22:15", - "22:20", - "22:25", - "22:30", - "22:35", - "22:40", - "22:45", - "22:50", - "22:55", - "23:00", - "23:05", - "23:10", - "23:15", - "23:20", - "23:25", - "23:30", - "23:35", - "23:40", - "23:45", - "23:50", - "23:55", - "request" - ], - "train-departure": [ - "none", - "do not care", - "bishops stortford", - "kings lynn", - "brookshite", - "london liverpool street", - "cam", - "liverpool", - "bro", - "leicester", - "broxbourne", - "norwhich", - "saint johns", - "stevenage", - "stansted", - "london liverpool", - "cambrid", - "city hall", - "rosas bed and breakfast", - "alpha-milton", - "wandlebury country park", - "norwich", - "liecester", - "stratford", - "peterborough", - "duxford", - "ely", - "london", - "stansted airport", - "lon", - "cambridge", - "panahar", - "cineworld", - "leicaster", - "birmingham", - "cafe uno", - "camboats", - "huntingdon", - "birmingham new street", - "arbu", - "alpha milton", - "east london", - "london kings cross", - "hamilton lodge", - "aylesbray lodge guest", - "el shaddai" - ], - "train-day": [ - "none", - "do not care", - "friday", - "monday", - "saterday", - "sunday", - "thursday", - "tuesday", - "wednesday" - ], - "train-book people": [ - "none", - "do not care", - "1", - "10 or more", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9" - ], - "hotel-stars": [ - "none", - "do not care", - "0", - "1", - "2", - "3", - "4", - "5", - "request" - ], - "hotel-internet": [ - "none", - "do not care", - "no", - "yes", - "request" - ], - "hotel-name": [ - "none", - "do not care", - "a and b guest house", - "city roomz", - "carolina bed and breakfast", - "limehouse", - "anatolia", - "hamilton lodge", - "the lensfield hotel", - "rosa's bed and breakfast", - "gall", - "aylesbray lodge", - "kirkwood", - "cambridge belfry", - "warkworth house", - "gonville", - "belfy hotel", - "nus", - "alexander", - "super 5", - "aylesbray lodge guest house", - "the gonvile hotel", - "allenbell", - "nothamilton lodge", - "ashley hotel", - "autumn house", - "hobsons house", - "hotel", - "ashely hotel", - "caridge belfrey", - "el shaddia guest house", - "avalon", - "cote", - "city centre north bed and breakfast", - "the cambridge belfry", - "home from home", - "wandlebury coutn", - "wankworth house", - "city stop rest", - "the worth house", - "cityroomz", - "huntingdon marriottt hotel", - "lensfield", - "rosas bed and breakfast", - "leverton house", - "gonville hotel", - "holiday inn cambridge", - "archway house", - "lan hon", - "levert", - "acorn guest house", - "cambridge", - "the ashley hotel", - "el shaddai", - "sleeperz", - "alpha milton guest house", - "doubletree by hilton cambridge", - "tandoori palace", - "express by", - "express by holiday inn cambridge", - "north bed and breakfast", - "holiday inn", - "arbury lodge guest house", - "alexander bed and breakfast", - "huntingdon marriott hotel", - "royal spice", - "sou", - "finches bed and breakfast", - "the alpha milton", - "bridge guest house", - "the acorn guest house", - "kirkwood house", - "eraina", - "la margherit", - "lensfield hotel", - "marriott hotel", - "nusha", - "city centre bed and breakfast", - "the allenbell", - "university arms hotel", - "clare", - "cherr", - "wartworth", - "acorn place", - "lovell lodge", - "whale" - ], - "train-leave at": [ - "none", - "do not care", - "00:00", - "00:05", - "00:10", - "00:15", - "00:20", - "00:25", - "00:30", - "00:35", - "00:40", - "00:45", - "00:50", - "00:55", - "01:00", - "01:05", - "01:10", - "01:15", - "01:20", - "01:25", - "01:30", - "01:35", - "01:40", - "01:45", - "01:50", - "01:55", - "02:00", - "02:05", - "02:10", - "02:15", - "02:20", - "02:25", - "02:30", - "02:35", - "02:40", - "02:45", - "02:50", - "02:55", - "03:00", - "03:05", - "03:10", - "03:15", - "03:20", - "03:25", - "03:30", - "03:35", - "03:40", - "03:45", - "03:50", - "03:55", - "04:00", - "04:05", - "04:10", - "04:15", - "04:20", - "04:25", - "04:30", - "04:35", - "04:40", - "04:45", - "04:50", - "04:55", - "05:00", - "05:05", - "05:10", - "05:15", - "05:20", - "05:25", - "05:30", - "05:35", - "05:40", - "05:45", - "05:50", - "05:55", - "06:00", - "06:05", - "06:10", - "06:15", - "06:20", - "06:25", - "06:30", - "06:35", - "06:40", - "06:45", - "06:50", - "06:55", - "07:00", - "07:05", - "07:10", - "07:15", - "07:20", - "07:25", - "07:30", - "07:35", - "07:40", - "07:45", - "07:50", - "07:55", - "08:00", - "08:05", - "08:10", - "08:15", - "08:20", - "08:25", - "08:30", - "08:35", - "08:40", - "08:45", - "08:50", - "08:55", - "09:00", - "09:05", - "09:10", - "09:15", - "09:20", - "09:25", - "09:30", - "09:35", - "09:40", - "09:45", - "09:50", - "09:55", - "10:00", - "10:05", - "10:10", - "10:15", - "10:20", - "10:25", - "10:30", - "10:35", - "10:40", - "10:45", - "10:50", - "10:55", - "11:00", - "11:05", - "11:10", - "11:15", - "11:20", - "11:25", - "11:30", - "11:35", - "11:40", - "11:45", - "11:50", - "11:55", - "12:00", - "12:05", - "12:10", - "12:15", - "12:20", - "12:25", - "12:30", - "12:35", - "12:40", - "12:45", - "12:50", - "12:55", - "13:00", - "13:05", - "13:10", - "13:15", - "13:20", - "13:25", - "13:30", - "13:35", - "13:40", - "13:45", - "13:50", - "13:55", - "14:00", - "14:05", - "14:10", - "14:15", - "14:20", - "14:25", - "14:30", - "14:35", - "14:40", - "14:45", - "14:50", - "14:55", - "15:00", - "15:05", - "15:10", - "15:15", - "15:20", - "15:25", - "15:30", - "15:35", - "15:40", - "15:45", - "15:50", - "15:55", - "16:00", - "16:05", - "16:10", - "16:15", - "16:20", - "16:25", - "16:30", - "16:35", - "16:40", - "16:45", - "16:50", - "16:55", - "17:00", - "17:05", - "17:10", - "17:15", - "17:20", - "17:25", - "17:30", - "17:35", - "17:40", - "17:45", - "17:50", - "17:55", - "18:00", - "18:05", - "18:10", - "18:15", - "18:20", - "18:25", - "18:30", - "18:35", - "18:40", - "18:45", - "18:50", - "18:55", - "19:00", - "19:05", - "19:10", - "19:15", - "19:20", - "19:25", - "19:30", - "19:35", - "19:40", - "19:45", - "19:50", - "19:55", - "20:00", - "20:05", - "20:10", - "20:15", - "20:20", - "20:25", - "20:30", - "20:35", - "20:40", - "20:45", - "20:50", - "20:55", - "21:00", - "21:05", - "21:10", - "21:15", - "21:20", - "21:25", - "21:30", - "21:35", - "21:40", - "21:45", - "21:50", - "21:55", - "22:00", - "22:05", - "22:10", - "22:15", - "22:20", - "22:25", - "22:30", - "22:35", - "22:40", - "22:45", - "22:50", - "22:55", - "23:00", - "23:05", - "23:10", - "23:15", - "23:20", - "23:25", - "23:30", - "23:35", - "23:40", - "23:45", - "23:50", - "23:55", - "request" - ], - "restaurant-price range": [ - "none", - "do not care", - "cheap", - "expensive", - "moderate", - "request" - ], - "restaurant-food": [ - "none", - "do not care", - "british food", - "steakhouse", - "turkish", - "sushi", - "north american", - "scottish", - "french", - "austrian", - "korean", - "eastern european", - "swedish", - "gastro pub", - "modern eclectic", - "afternoon tea", - "welsh", - "christmas", - "tuscan", - "gastropub", - "sri lankan", - "molecular gastronomy", - "traditional american", - "italian", - "pizza", - "thai", - "south african", - "creative", - "english", - "asian", - "lebanese", - "hungarian", - "halal", - "portugese", - "modern english", - "african", - "light bites", - "malaysian", - "venetian", - "traditional", - "chinese", - "vegetarian", - "persian", - "thai and chinese", - "scandinavian", - "catalan", - "polynesian", - "crossover", - "canapes", - "cantonese", - "north african", - "seafood", - "brazilian", - "south indian", - "australasian", - "belgian", - "barbeque", - "the americas", - "indonesian", - "singaporean", - "irish", - "middle eastern", - "dojo noodle bar", - "caribbean", - "vietnamese", - "modern european", - "russian", - "german", - "world", - "japanese", - "moroccan", - "modern global", - "indian", - "british", - "american", - "danish", - "panasian", - "swiss", - "basque", - "north indian", - "modern american", - "australian", - "european", - "corsica", - "greek", - "northern european", - "mediterranean", - "portuguese", - "romanian", - "jamaican", - "polish", - "international", - "unusual", - "latin american", - "asian oriental", - "mexican", - "bistro", - "cuban", - "fusion", - "new zealand", - "spanish", - "eritrean", - "afghan", - "kosher", - "request" - ], - "attraction-name": [ - "none", - "do not care", - "downing college", - "fitzwilliam", - "clare college", - "ruskin gallery", - "sidney sussex college", - "great saint mary's church", - "cherry hinton water play park", - "wandlebury country park", - "cafe uno", - "place", - "broughton", - "cineworld cinema", - "jesus college", - "vue cinema", - "history of science museum", - "mumford theatre", - "whale of time", - "fitzbillies", - "christs church", - "churchill college", - "museum of classical archaeology", - "gonville and caius college", - "pizza", - "kirkwood", - "saint catharines college", - "kings college", - "parkside", - "by", - "st catharines college", - "saint john's college", - "cherry hinton water park", - "st christs college", - "christ's college", - "bangkok city", - "scudamores punti co", - "free", - "great saint marys church", - "milton country park", - "the fez club", - "soultree", - "autu", - "whipple museum of the history of science", - "aylesbray lodge guest house", - "broughton house gallery", - "peoples portraits exhibition", - "primavera", - "kettles yard", - "all saint's church", - "cinema cinema", - "regency gallery", - "corpus christi", - "corn cambridge exchange", - "da vinci pizzeria", - "school", - "hobsons house", - "cambride and country folk museum", - "north", - "da v", - "cambridge corn exchange", - "soul tree nightclub", - "cambridge arts theater", - "saint catharine's college", - "byard art", - "cambridge punter", - "cambridge university botanic gardens", - "castle galleries", - "museum of archaelogy and anthropogy", - "no specific location", - "cherry hinton hall", - "gallery at 12 a high street", - "parkside pools", - "queen's college", - "little saint mary's church", - "gallery", - "home from home", - "tenpin", - "the wandlebury", - "county folk museum", - "swimming pool", - "christs college", - "cafe jello museum", - "scott polar", - "christ college", - "cambridge museum of technology", - "abbey pool and astroturf pitch", - "king hedges learner pool", - "the cambridge arts theatre", - "the castle galleries", - "cambridge and country folk museum", - "kohinoor", - "scudamores punting co", - "sidney sussex", - "the man on the moon", - "little saint marys church", - "queens", - "the place", - "old school", - "churchill", - "churchills college", - "hughes hall", - "churchhill college", - "riverboat georgina", - "belf", - "cambridge temporary art", - "abc theatre", - "cambridge contemporary art museum", - "man on the moon", - "the junction", - "cherry hinton water play", - "adc theatre", - "gonville hotel", - "magdalene college", - "peoples portraits exhibition at girton college", - "boat", - "centre", - "sheep's green and lammas land park fen causeway", - "the mumford theatre", - "archway house", - "queens' college", - "williams art and antiques", - "funky fun house", - "cherry hinton village centre", - "camboats", - "cambridge", - "old schools", - "kettle's yard", - "whale of a time", - "the churchill college", - "cafe jello gallery", - "aut", - "salsa", - "city", - "clare hall", - "boating", - "pembroke college", - "kings hedges learner pool", - "caffe uno", - "lammas land park", - "museum", - "the fitzwilliam museum", - "the cherry hinton village centre", - "the cambridge corn exchange", - "fitzwilliam museum", - "museum of archaelogy and anthropology", - "fez club", - "the cambridge punter", - "saint johns college", - "emmanuel college", - "cambridge belf", - "scudamore", - "lynne strover gallery", - "king's college", - "whippple museum", - "trinity college", - "college in the north", - "sheep's green", - "kambar", - "museum of archaelogy", - "adc", - "garde", - "club salsa", - "people's portraits exhibition at girton college", - "botanic gardens", - "carol", - "college", - "gallery at twelve a high street", - "abbey pool and astroturf", - "cambridge book and print gallery", - "jesus green outdoor pool", - "scott polar museum", - "saint barnabas press gallery", - "cambridge artworks", - "older churches", - "cambridge contemporary art", - "cherry hinton hall and grounds", - "univ", - "jesus green", - "ballare", - "abbey pool", - "cambridge botanic gardens", - "nusha", - "worth house", - "thanh", - "university arms hotel", - "cambridge arts theatre", - "cafe jello", - "cambridge and county folk museum", - "the cambridge artworks", - "all saints church", - "holy trinity church", - "contemporary art museum", - "architectural churches", - "queens college", - "trinity street college" - ], - "restaurant-name": [ - "none", - "do not care", - "hotel du vin and bistro", - "ask", - "gourmet formal kitchen", - "the meze bar", - "lan hong house", - "cow pizza", - "one seven", - "prezzo", - "maharajah tandoori restaurant", - "alex", - "shanghai", - "golden wok", - "restaurant", - "fitzbillies", - "nil", - "copper kettle", - "meghna", - "hk fusion", - "bangkok city", - "hobsons house", - "tang chinese", - "anatolia", - "ugly duckling", - "anatolia and efes restaurant", - "sitar tandoori", - "city stop", - "ashley", - "pizza express fen ditton", - "molecular gastronomy", - "autumn house", - "el shaddia guesthouse", - "the grafton hotel", - "limehouse", - "gardenia", - "not metioned", - "hakka", - "michaelhouse cafe", - "pipasha", - "meze bar", - "archway", - "molecular gastonomy", - "yipee noodle bar", - "the peking", - "curry prince", - "midsummer house restaurant", - "pizza hut cherry hinton", - "the lucky star", - "stazione restaurant and coffee bar", - "shanghi family restaurant", - "good luck", - "j restaurant", - "bedouin", - "cott", - "little seoul", - "south", - "thanh binh", - "el", - "efes restaurant", - "kohinoor", - "clowns", - "india", - "the slug and lettuce", - "shiraz", - "barbakan", - "zizzi cambridge", - "restaurant one seven", - "slug and lettuce", - "travellers rest", - "binh", - "worth house", - "broughton house gallery", - "chiquito", - "the river bar steakhouse and grill", - "ros", - "golden house", - "india west", - "cam", - "panahar", - "restaurant 22", - "adden", - "indian", - "hu", - "jinling noodle bar", - "darrys cookhouse and wine shop", - "hobson house", - "cambridge be", - "el shaddai", - "ac", - "nandos", - "cambridge lodge", - "the cow pizza kitchen and bar", - "charlie", - "rajmahal", - "kymmoy", - "cambri", - "backstreet bistro", - "galleria", - "restaurant 2 two", - "chiquito restaurant bar", - "royal standard", - "lucky star", - "curry king", - "grafton hotel restaurant", - "mahal of cambridge", - "the bedouin", - "nus", - "the kohinoor", - "pizza hut fenditton", - "camboats", - "the gardenia", - "de luca cucina and bar", - "nusha", - "european", - "taj tandoori", - "tandoori palace", - "golden curry", - "efes", - "loch fyne", - "the maharajah tandoor", - "lovel", - "restaurant 17", - "clowns cafe", - "cambridge punter", - "bloomsbury restaurant", - "la mimosa", - "the cambridge chop house", - "funky", - "cotto", - "oak bistro", - "restaurant two two", - "pipasha restaurant", - "river bar steakhouse and grill", - "royal spice", - "the copper kettle", - "graffiti", - "nandos city centre", - "saffron brasserie", - "cambridge chop house", - "sitar", - "kitchen and bar", - "the good luck chinese food takeaway", - "clu", - "la tasca", - "cafe uno", - "cote", - "the varsity restaurant", - "bri", - "eraina", - "bridge", - "fin", - "cambridge lodge restaurant", - "grafton", - "hotpot", - "sala thong", - "margherita", - "wise buddha", - "the missing sock", - "seasame restaurant and bar", - "the dojo noodle bar", - "restaurant alimentum", - "gastropub", - "saigon city", - "la margherita", - "pizza hut", - "curry garden", - "ashley hotel", - "eraina and michaelhouse cafe", - "the golden curry", - "curry queen", - "cow pizza kitchen and bar", - "the peking restaurant:", - "hamilton lodge", - "alimentum", - "yippee noodle bar", - "2 two and cote", - "shanghai family restaurant", - "grafton hotel", - "yes", - "ali baba", - "dif", - "fitzbillies restaurant", - "peking restaurant", - "lev", - "nirala", - "the alex", - "tandoori", - "city stop restaurant", - "rice house", - "cityr", - "yu garden", - "meze bar restaurant", - "the", - "don pasquale pizzeria", - "rice boat", - "the hotpot", - "old school", - "the oak bistro", - "sesame restaurant and bar", - "pizza express", - "the gandhi", - "pizza hut fen ditton", - "charlie chan", - "da vinci pizzeria", - "dojo noodle bar", - "gourmet burger kitchen", - "the golden house", - "india house", - "hobso", - "missing sock", - "pizza hut city centre", - "parkside pools", - "riverside brasserie", - "caffe uno", - "primavera", - "the nirala", - "wagamama", - "au", - "ian hong house", - "frankie and bennys", - "4 kings parade city centre", - "shiraz restaurant", - "scudamores punt", - "mahal", - "saint johns chop house", - "de luca cucina and bar riverside brasserie", - "cocum", - "la raza" - ], - "attraction-type": [ - "none", - "do not care", - "architecture", - "boat", - "boating", - "camboats", - "church", - "churchills college", - "cinema", - "college", - "concert", - "concerthall", - "entertainment", - "gallery", - "gastropub", - "hiking", - "hotel", - "multiple sports", - "museum", - "museum kettles yard", - "night club", - "outdoor", - "park", - "pool", - "special", - "sports", - "swimming pool", - "theater", - "theatre", - "concert hall", - "local site", - "nightclub", - "hotspot", - "request" - ], - "taxi-leave at": [ - "none", - "do not care", - "00:00", - "00:05", - "00:10", - "00:15", - "00:20", - "00:25", - "00:30", - "00:35", - "00:40", - "00:45", - "00:50", - "00:55", - "01:00", - "01:05", - "01:10", - "01:15", - "01:20", - "01:25", - "01:30", - "01:35", - "01:40", - "01:45", - "01:50", - "01:55", - "02:00", - "02:05", - "02:10", - "02:15", - "02:20", - "02:25", - "02:30", - "02:35", - "02:40", - "02:45", - "02:50", - "02:55", - "03:00", - "03:05", - "03:10", - "03:15", - "03:20", - "03:25", - "03:30", - "03:35", - "03:40", - "03:45", - "03:50", - "03:55", - "04:00", - "04:05", - "04:10", - "04:15", - "04:20", - "04:25", - "04:30", - "04:35", - "04:40", - "04:45", - "04:50", - "04:55", - "05:00", - "05:05", - "05:10", - "05:15", - "05:20", - "05:25", - "05:30", - "05:35", - "05:40", - "05:45", - "05:50", - "05:55", - "06:00", - "06:05", - "06:10", - "06:15", - "06:20", - "06:25", - "06:30", - "06:35", - "06:40", - "06:45", - "06:50", - "06:55", - "07:00", - "07:05", - "07:10", - "07:15", - "07:20", - "07:25", - "07:30", - "07:35", - "07:40", - "07:45", - "07:50", - "07:55", - "08:00", - "08:05", - "08:10", - "08:15", - "08:20", - "08:25", - "08:30", - "08:35", - "08:40", - "08:45", - "08:50", - "08:55", - "09:00", - "09:05", - "09:10", - "09:15", - "09:20", - "09:25", - "09:30", - "09:35", - "09:40", - "09:45", - "09:50", - "09:55", - "10:00", - "10:05", - "10:10", - "10:15", - "10:20", - "10:25", - "10:30", - "10:35", - "10:40", - "10:45", - "10:50", - "10:55", - "11:00", - "11:05", - "11:10", - "11:15", - "11:20", - "11:25", - "11:30", - "11:35", - "11:40", - "11:45", - "11:50", - "11:55", - "12:00", - "12:05", - "12:10", - "12:15", - "12:20", - "12:25", - "12:30", - "12:35", - "12:40", - "12:45", - "12:50", - "12:55", - "13:00", - "13:05", - "13:10", - "13:15", - "13:20", - "13:25", - "13:30", - "13:35", - "13:40", - "13:45", - "13:50", - "13:55", - "14:00", - "14:05", - "14:10", - "14:15", - "14:20", - "14:25", - "14:30", - "14:35", - "14:40", - "14:45", - "14:50", - "14:55", - "15:00", - "15:05", - "15:10", - "15:15", - "15:20", - "15:25", - "15:30", - "15:35", - "15:40", - "15:45", - "15:50", - "15:55", - "16:00", - "16:05", - "16:10", - "16:15", - "16:20", - "16:25", - "16:30", - "16:35", - "16:40", - "16:45", - "16:50", - "16:55", - "17:00", - "17:05", - "17:10", - "17:15", - "17:20", - "17:25", - "17:30", - "17:35", - "17:40", - "17:45", - "17:50", - "17:55", - "18:00", - "18:05", - "18:10", - "18:15", - "18:20", - "18:25", - "18:30", - "18:35", - "18:40", - "18:45", - "18:50", - "18:55", - "19:00", - "19:05", - "19:10", - "19:15", - "19:20", - "19:25", - "19:30", - "19:35", - "19:40", - "19:45", - "19:50", - "19:55", - "20:00", - "20:05", - "20:10", - "20:15", - "20:20", - "20:25", - "20:30", - "20:35", - "20:40", - "20:45", - "20:50", - "20:55", - "21:00", - "21:05", - "21:10", - "21:15", - "21:20", - "21:25", - "21:30", - "21:35", - "21:40", - "21:45", - "21:50", - "21:55", - "22:00", - "22:05", - "22:10", - "22:15", - "22:20", - "22:25", - "22:30", - "22:35", - "22:40", - "22:45", - "22:50", - "22:55", - "23:00", - "23:05", - "23:10", - "23:15", - "23:20", - "23:25", - "23:30", - "23:35", - "23:40", - "23:45", - "23:50", - "23:55", - "request" - ], - "taxi-destination": [ - "none", - "do not care", - "a and b guest house", - "abbey pool and astroturf pitch", - "acorn guest house", - "adc theatre", - "addenbrookes hospital", - "alexander bed and breakfast", - "ali baba", - "all saints church", - "allenbell", - "alpha milton guest house", - "anatolia", - "arbury lodge guesthouse", - "archway house", - "ashley hotel", - "ask", - "attraction", - "autumn house", - "avalon", - "aylesbray lodge guest house", - "backstreet bistro", - "ballare", - "bangkok city", - "bedouin", - "birmingham new street train station", - "bishops stortford train station", - "bloomsbury restaurant", - "bridge guest house", - "broughton house gallery", - "broxbourne train station", - "byard art", - "cafe jello gallery", - "cafe uno", - "camboats", - "cambridge", - "cambridge and county folk museum", - "cambridge arts theatre", - "cambridge artworks", - "cambridge belfry", - "cambridge book and print gallery", - "cambridge chop house", - "cambridge contemporary art", - "cambridge county fair next to the city tourist museum", - "cambridge lodge restaurant", - "cambridge museum of technology", - "cambridge punter", - "cambridge road church of christ", - "cambridge train station", - "cambridge university botanic gardens", - "carolina bed and breakfast", - "castle galleries", - "charlie chan", - "cherry hinton hall and grounds", - "cherry hinton village centre", - "cherry hinton water park", - "cherry hinton water play", - "chiquito restaurant bar", - "christ college", - "churchills college", - "cineworld cinema", - "city centre north bed and breakfast", - "city stop restaurant", - "cityroomz", - "clare college", - "clare hall", - "clowns cafe", - "club salsa", - "cocum", - "copper kettle", - "corpus christi", - "cote", - "cotto", - "cow pizza kitchen and bar", - "curry garden", - "curry king", - "curry prince", - "da vinci pizzeria", - "darrys cookhouse and wine shop", - "de luca cucina and bar", - "dojo noodle bar", - "don pasquale pizzeria", - "downing college", - "efes restaurant", - "el shaddia guesthouse", - "ely train station", - "emmanuel college", - "eraina", - "express by holiday inn cambridge", - "finches bed and breakfast", - "finders corner newmarket road", - "fitzbillies restaurant", - "fitzwilliam museum", - "frankie and bennys", - "funky fun house", - "galleria", - "gallery at 12 a high street", - "gastropub", - "golden curry", - "golden house", - "golden wok", - "gonville and caius college", - "gonville hotel", - "good luck", - "gourmet burger kitchen", - "graffiti", - "grafton hotel restaurant", - "great saint marys church", - "hakka", - "hamilton lodge", - "hk fusion", - "hobsons house", - "holy trinity church", - "home from home", - "hotel du vin and bistro", - "hughes hall", - "huntingdon marriott hotel", - "ian hong", - "india house", - "j restaurant", - "jesus college", - "jesus green outdoor pool", - "jinling noodle bar", - "kambar", - "kettles yard", - "kings college", - "kings hedges learner pool", - "kirkwood house", - "kohinoor", - "kymmoy", - "la margherita", - "la mimosa", - "la raza", - "la tasca", - "lan hong house", - "leicester train station", - "lensfield hotel", - "limehouse", - "little saint marys church", - "little seoul", - "loch fyne", - "london kings cross train station", - "london liverpool street train station", - "lovell lodge", - "lynne strover gallery", - "magdalene college", - "mahal of cambridge", - "maharajah tandoori restaurant", - "meghna", - "meze bar", - "michaelhouse cafe", - "midsummer house restaurant", - "milton country park", - "mumford theatre", - "museum of archaelogy and anthropology", - "museum of classical archaeology", - "nandos", - "nandos city centre", - "nil", - "nirala", - "norwich train station", - "nusha", - "old schools", - "panahar", - "parkside police station", - "parkside pools", - "peking restaurant", - "pembroke college", - "peoples portraits exhibition at girton college", - "peterborough train station", - "pipasha restaurant", - "pizza express", - "pizza hut cherry hinton", - "pizza hut city centre", - "pizza hut fenditton", - "prezzo", - "primavera", - "queens college", - "rajmahal", - "regency gallery", - "restaurant 17", - "restaurant 2 two", - "restaurant alimentum", - "rice boat", - "rice house", - "riverboat georgina", - "riverside brasserie", - "rosas bed and breakfast", - "royal spice", - "royal standard", - "ruskin gallery", - "saffron brasserie", - "saigon city", - "saint barnabas", - "saint barnabas press gallery", - "saint catharines college", - "saint johns chop house", - "saint johns college", - "sala thong", - "scott polar museum", - "scudamores punting co", - "sesame restaurant and bar", - "shanghai family restaurant", - "sheeps green and lammas land park fen causeway", - "shiraz", - "sidney sussex college", - "sitar tandoori", - "sleeperz hotel", - "soul tree nightclub", - "st johns chop house", - "stansted airport train station", - "station road", - "stazione restaurant and coffee bar", - "stevenage train station", - "taj tandoori", - "tall monument", - "tandoori palace", - "tang chinese", - "tenpin", - "thanh binh", - "the anatolia", - "the cambridge corn exchange", - "the cambridge shop", - "the fez club", - "the gandhi", - "the gardenia", - "the hotpot", - "the junction", - "the lucky star", - "the man on the moon", - "the missing sock", - "the oak bistro", - "the place", - "the regent street city center", - "the river bar steakhouse and grill", - "the slug and lettuce", - "the varsity restaurant", - "travellers rest", - "trinity college", - "ugly duckling", - "university arms hotel", - "vue cinema", - "wagamama", - "wandlebury country park", - "wankworth hotel", - "warkworth house", - "whale of a time", - "whipple museum of the history of science", - "williams art and antiques", - "worth house", - "yippee noodle bar", - "yu garden", - "zizzi cambridge", - "leverton house", - "the cambridge chop house", - "saint john's college", - "churchill college", - "the nirala", - "the cow pizza kitchen and bar", - "christ's college", - "el shaddai", - "saint catharine's college", - "camb", - "the golden curry", - "little saint mary's church", - "country folk museum", - "meze bar restaurant", - "the cambridge belfry", - "the fitzwilliam museum", - "the lensfield hotel", - "pizza express fen ditton", - "the cambridge punter", - "king's college", - "the cherry hinton village centre", - "shiraz restaurant", - "sheep's green and lammas land park fen causeway", - "caffe uno", - "the ghandi", - "the copper kettle", - "man on the moon concert hall", - "alpha-milton guest house", - "queen's college", - "restaurant one seven", - "restaurant two two", - "city centre north b and b", - "rosa's bed and breakfast", - "the good luck chinese food takeaway", - "not museum of archaeology and anthropologymentioned", - "tandori in cambridge", - "kettle's yard", - "megna", - "grou", - "gallery at twelve a high street", - "maharajah tandoori restaurant", - "pizza hut fen ditton", - "gandhi", - "tranh binh", - "kambur", - "people's portraits exhibition at girton college", - "hotel", - "restaurant", - "the galleria", - "queens' college", - "great saint mary's church", - "theathre", - "cambridge artworks", - "acorn house", - "shiraz", - "riverboat georginawd", - "mic", - "the gallery at twelve", - "the soul tree", - "finches" - ], - "taxi-departure": [ - "none", - "do not care", - "172 chestertown road", - "4455 woodbridge road", - "a and b guest house", - "abbey pool and astroturf pitch", - "acorn guest house", - "adc theatre", - "addenbrookes hospital", - "alexander bed and breakfast", - "ali baba", - "all saints church", - "allenbell", - "alpha milton guest house", - "alyesbray lodge hotel", - "ambridge", - "anatolia", - "arbury lodge guesthouse", - "archway house", - "ashley hotel", - "ask", - "autumn house", - "avalon", - "aylesbray lodge guest house", - "backstreet bistro", - "ballare", - "bangkok city", - "bedouin", - "birmingham new street train station", - "bishops stortford train station", - "bloomsbury restaurant", - "bridge guest house", - "broughton house gallery", - "broxbourne train station", - "byard art", - "cafe jello gallery", - "cafe uno", - "caffee uno", - "camboats", - "cambridge", - "cambridge and county folk museum", - "cambridge arts theatre", - "cambridge artworks", - "cambridge belfry", - "cambridge book and print gallery", - "cambridge chop house", - "cambridge contemporary art", - "cambridge lodge restaurant", - "cambridge museum of technology", - "cambridge punter", - "cambridge towninfo centre", - "cambridge train station", - "cambridge university botanic gardens", - "carolina bed and breakfast", - "castle galleries", - "centre of town at my hotel", - "charlie chan", - "cherry hinton hall and grounds", - "cherry hinton village center", - "cherry hinton village centre", - "cherry hinton water play", - "chiquito restaurant bar", - "christ college", - "churchills college", - "cineworld cinema", - "citiroomz", - "city centre north bed and breakfast", - "city stop restaurant", - "cityroomz", - "clair hall", - "clare college", - "clare hall", - "clowns cafe", - "club salsa", - "cocum", - "copper kettle", - "corpus christi", - "cote", - "cotto", - "cow pizza kitchen and bar", - "curry garden", - "curry king", - "curry prince", - "curry queen", - "da vinci pizzeria", - "darrys cookhouse and wine shop", - "de luca cucina and bar", - "dojo noodle bar", - "don pasquale pizzeria", - "downing college", - "downing street", - "el shaddia guesthouse", - "ely", - "ely train station", - "emmanuel college", - "eraina", - "express by holiday inn cambridge", - "finches bed and breakfast", - "fitzbillies restaurant", - "fitzwilliam museum", - "frankie and bennys", - "funky fun house", - "galleria", - "gallery at 12 a high street", - "girton college", - "golden curry", - "golden house", - "golden wok", - "gonville and caius college", - "gonville hotel", - "good luck", - "gourmet burger kitchen", - "graffiti", - "grafton hotel restaurant", - "great saint marys church", - "hakka", - "hamilton lodge", - "hobsons house", - "holy trinity church", - "home", - "home from home", - "hotel", - "hotel du vin and bistro", - "hughes hall", - "huntingdon marriott hotel", - "india house", - "j restaurant", - "jesus college", - "jesus green outdoor pool", - "jinling noodle bar", - "junction theatre", - "kambar", - "kettles yard", - "kings college", - "kings hedges learner pool", - "kings lynn train station", - "kirkwood house", - "kohinoor", - "kymmoy", - "la margherita", - "la mimosa", - "la raza", - "la tasca", - "lan hong house", - "lensfield hotel", - "leverton house", - "limehouse", - "little saint marys church", - "little seoul", - "loch fyne", - "london kings cross train station", - "london liverpool street", - "london liverpool street train station", - "lovell lodge", - "lynne strover gallery", - "magdalene college", - "mahal of cambridge", - "maharajah tandoori restaurant", - "meghna", - "meze bar", - "michaelhouse cafe", - "milton country park", - "mumford theatre", - "museum", - "museum of archaelogy and anthropology", - "museum of classical archaeology", - "nandos", - "nandos city centre", - "new england", - "nirala", - "norwich train station", - "nstaot mentioned", - "nusha", - "old schools", - "panahar", - "parkside police station", - "parkside pools", - "peking restaurant", - "pembroke college", - "peoples portraits exhibition at girton college", - "peterborough train station", - "pizza express", - "pizza hut cherry hinton", - "pizza hut city centre", - "pizza hut fenditton", - "prezzo", - "primavera", - "queens college", - "rajmahal", - "regency gallery", - "restaurant 17", - "restaurant 2 two", - "restaurant alimentum", - "rice boat", - "rice house", - "riverboat georgina", - "riverside brasserie", - "rosas bed and breakfast", - "royal spice", - "royal standard", - "ruskin gallery", - "saffron brasserie", - "saigon city", - "saint barnabas press gallery", - "saint catharines college", - "saint johns chop house", - "saint johns college", - "sala thong", - "scott polar museum", - "scudamores punting co", - "sesame restaurant and bar", - "sheeps green and lammas land park", - "sheeps green and lammas land park fen causeway", - "shiraz", - "sidney sussex college", - "sitar tandoori", - "soul tree nightclub", - "st johns college", - "stazione restaurant and coffee bar", - "stevenage train station", - "taj tandoori", - "tandoori palace", - "tang chinese", - "tenpin", - "thanh binh", - "the cambridge corn exchange", - "the fez club", - "the gallery at 12", - "the gandhi", - "the gardenia", - "the hotpot", - "the junction", - "the lucky star", - "the man on the moon", - "the missing sock", - "the oak bistro", - "the place", - "the river bar steakhouse and grill", - "the slug and lettuce", - "the varsity restaurant", - "travellers rest", - "trinity college", - "ugly duckling", - "university arms hotel", - "vue cinema", - "wagamama", - "wandlebury country park", - "warkworth house", - "whale of a time", - "whipple museum of the history of science", - "williams art and antiques", - "worth house", - "yippee noodle bar", - "yu garden", - "zizzi cambridge", - "christ's college", - "city centre north b and b", - "the lensfield hotel", - "alpha-milton guest house", - "el shaddai", - "churchill college", - "the cambridge belfry", - "king's college", - "great saint mary's church", - "restaurant two two", - "queens' college", - "little saint mary's church", - "chinese city centre", - "kettle's yard", - "pizza hut", - "the golden curry", - "rosa's bed and breakfast", - "the cambridge punter", - "the byard art museum", - "saint catharine's college", - "meze bar restaurant", - "the good luck chinese food takeaway", - "restaurant one seven", - "pizza hut fen ditton", - "the nirala", - "the fitzwilliam museum", - "st. john's college", - "gallery at twelve a high street", - "sheep's green and lammas land park fen causeway", - "the cherry hinton village centre", - "pizza express fen ditton", - "corpus cristi", - "cas", - "acorn house", - "lens", - "the cambridge chop house", - "the copper kettle", - "the avalon", - "saint john's college", - "aylesbray lodge", - "the alexander bed and breakfast", - "cambridge belfy", - "people's portraits exhibition at girton college", - "gonville", - "caffe uno", - "the cow pizza kitchen and bar", - "lovell ldoge", - "cinema", - "shiraz restaurant", - "park", - "the allenbell" - ], - "restaurant-book day": [ - "none", - "do not care", - "friday", - "monday", - "saterday", - "sunday", - "thursday", - "tuesday", - "wednesday" - ], - "restaurant-book people": [ - "none", - "do not care", - "1", - "10 or more", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9" - ], - "restaurant-book time": [ - "none", - "do not care", - "00:00", - "00:05", - "00:10", - "00:15", - "00:20", - "00:25", - "00:30", - "00:35", - "00:40", - "00:45", - "00:50", - "00:55", - "01:00", - "01:05", - "01:10", - "01:15", - "01:20", - "01:25", - "01:30", - "01:35", - "01:40", - "01:45", - "01:50", - "01:55", - "02:00", - "02:05", - "02:10", - "02:15", - "02:20", - "02:25", - "02:30", - "02:35", - "02:40", - "02:45", - "02:50", - "02:55", - "03:00", - "03:05", - "03:10", - "03:15", - "03:20", - "03:25", - "03:30", - "03:35", - "03:40", - "03:45", - "03:50", - "03:55", - "04:00", - "04:05", - "04:10", - "04:15", - "04:20", - "04:25", - "04:30", - "04:35", - "04:40", - "04:45", - "04:50", - "04:55", - "05:00", - "05:05", - "05:10", - "05:15", - "05:20", - "05:25", - "05:30", - "05:35", - "05:40", - "05:45", - "05:50", - "05:55", - "06:00", - "06:05", - "06:10", - "06:15", - "06:20", - "06:25", - "06:30", - "06:35", - "06:40", - "06:45", - "06:50", - "06:55", - "07:00", - "07:05", - "07:10", - "07:15", - "07:20", - "07:25", - "07:30", - "07:35", - "07:40", - "07:45", - "07:50", - "07:55", - "08:00", - "08:05", - "08:10", - "08:15", - "08:20", - "08:25", - "08:30", - "08:35", - "08:40", - "08:45", - "08:50", - "08:55", - "09:00", - "09:05", - "09:10", - "09:15", - "09:20", - "09:25", - "09:30", - "09:35", - "09:40", - "09:45", - "09:50", - "09:55", - "10:00", - "10:05", - "10:10", - "10:15", - "10:20", - "10:25", - "10:30", - "10:35", - "10:40", - "10:45", - "10:50", - "10:55", - "11:00", - "11:05", - "11:10", - "11:15", - "11:20", - "11:25", - "11:30", - "11:35", - "11:40", - "11:45", - "11:50", - "11:55", - "12:00", - "12:05", - "12:10", - "12:15", - "12:20", - "12:25", - "12:30", - "12:35", - "12:40", - "12:45", - "12:50", - "12:55", - "13:00", - "13:05", - "13:10", - "13:15", - "13:20", - "13:25", - "13:30", - "13:35", - "13:40", - "13:45", - "13:50", - "13:55", - "14:00", - "14:05", - "14:10", - "14:15", - "14:20", - "14:25", - "14:30", - "14:35", - "14:40", - "14:45", - "14:50", - "14:55", - "15:00", - "15:05", - "15:10", - "15:15", - "15:20", - "15:25", - "15:30", - "15:35", - "15:40", - "15:45", - "15:50", - "15:55", - "16:00", - "16:05", - "16:10", - "16:15", - "16:20", - "16:25", - "16:30", - "16:35", - "16:40", - "16:45", - "16:50", - "16:55", - "17:00", - "17:05", - "17:10", - "17:15", - "17:20", - "17:25", - "17:30", - "17:35", - "17:40", - "17:45", - "17:50", - "17:55", - "18:00", - "18:05", - "18:10", - "18:15", - "18:20", - "18:25", - "18:30", - "18:35", - "18:40", - "18:45", - "18:50", - "18:55", - "19:00", - "19:05", - "19:10", - "19:15", - "19:20", - "19:25", - "19:30", - "19:35", - "19:40", - "19:45", - "19:50", - "19:55", - "20:00", - "20:05", - "20:10", - "20:15", - "20:20", - "20:25", - "20:30", - "20:35", - "20:40", - "20:45", - "20:50", - "20:55", - "21:00", - "21:05", - "21:10", - "21:15", - "21:20", - "21:25", - "21:30", - "21:35", - "21:40", - "21:45", - "21:50", - "21:55", - "22:00", - "22:05", - "22:10", - "22:15", - "22:20", - "22:25", - "22:30", - "22:35", - "22:40", - "22:45", - "22:50", - "22:55", - "23:00", - "23:05", - "23:10", - "23:15", - "23:20", - "23:25", - "23:30", - "23:35", - "23:40", - "23:45", - "23:50", - "23:55" - ], - "taxi-arrive by": [ - "none", - "do not care", - "00:00", - "00:05", - "00:10", - "00:15", - "00:20", - "00:25", - "00:30", - "00:35", - "00:40", - "00:45", - "00:50", - "00:55", - "01:00", - "01:05", - "01:10", - "01:15", - "01:20", - "01:25", - "01:30", - "01:35", - "01:40", - "01:45", - "01:50", - "01:55", - "02:00", - "02:05", - "02:10", - "02:15", - "02:20", - "02:25", - "02:30", - "02:35", - "02:40", - "02:45", - "02:50", - "02:55", - "03:00", - "03:05", - "03:10", - "03:15", - "03:20", - "03:25", - "03:30", - "03:35", - "03:40", - "03:45", - "03:50", - "03:55", - "04:00", - "04:05", - "04:10", - "04:15", - "04:20", - "04:25", - "04:30", - "04:35", - "04:40", - "04:45", - "04:50", - "04:55", - "05:00", - "05:05", - "05:10", - "05:15", - "05:20", - "05:25", - "05:30", - "05:35", - "05:40", - "05:45", - "05:50", - "05:55", - "06:00", - "06:05", - "06:10", - "06:15", - "06:20", - "06:25", - "06:30", - "06:35", - "06:40", - "06:45", - "06:50", - "06:55", - "07:00", - "07:05", - "07:10", - "07:15", - "07:20", - "07:25", - "07:30", - "07:35", - "07:40", - "07:45", - "07:50", - "07:55", - "08:00", - "08:05", - "08:10", - "08:15", - "08:20", - "08:25", - "08:30", - "08:35", - "08:40", - "08:45", - "08:50", - "08:55", - "09:00", - "09:05", - "09:10", - "09:15", - "09:20", - "09:25", - "09:30", - "09:35", - "09:40", - "09:45", - "09:50", - "09:55", - "10:00", - "10:05", - "10:10", - "10:15", - "10:20", - "10:25", - "10:30", - "10:35", - "10:40", - "10:45", - "10:50", - "10:55", - "11:00", - "11:05", - "11:10", - "11:15", - "11:20", - "11:25", - "11:30", - "11:35", - "11:40", - "11:45", - "11:50", - "11:55", - "12:00", - "12:05", - "12:10", - "12:15", - "12:20", - "12:25", - "12:30", - "12:35", - "12:40", - "12:45", - "12:50", - "12:55", - "13:00", - "13:05", - "13:10", - "13:15", - "13:20", - "13:25", - "13:30", - "13:35", - "13:40", - "13:45", - "13:50", - "13:55", - "14:00", - "14:05", - "14:10", - "14:15", - "14:20", - "14:25", - "14:30", - "14:35", - "14:40", - "14:45", - "14:50", - "14:55", - "15:00", - "15:05", - "15:10", - "15:15", - "15:20", - "15:25", - "15:30", - "15:35", - "15:40", - "15:45", - "15:50", - "15:55", - "16:00", - "16:05", - "16:10", - "16:15", - "16:20", - "16:25", - "16:30", - "16:35", - "16:40", - "16:45", - "16:50", - "16:55", - "17:00", - "17:05", - "17:10", - "17:15", - "17:20", - "17:25", - "17:30", - "17:35", - "17:40", - "17:45", - "17:50", - "17:55", - "18:00", - "18:05", - "18:10", - "18:15", - "18:20", - "18:25", - "18:30", - "18:35", - "18:40", - "18:45", - "18:50", - "18:55", - "19:00", - "19:05", - "19:10", - "19:15", - "19:20", - "19:25", - "19:30", - "19:35", - "19:40", - "19:45", - "19:50", - "19:55", - "20:00", - "20:05", - "20:10", - "20:15", - "20:20", - "20:25", - "20:30", - "20:35", - "20:40", - "20:45", - "20:50", - "20:55", - "21:00", - "21:05", - "21:10", - "21:15", - "21:20", - "21:25", - "21:30", - "21:35", - "21:40", - "21:45", - "21:50", - "21:55", - "22:00", - "22:05", - "22:10", - "22:15", - "22:20", - "22:25", - "22:30", - "22:35", - "22:40", - "22:45", - "22:50", - "22:55", - "23:00", - "23:05", - "23:10", - "23:15", - "23:20", - "23:25", - "23:30", - "23:35", - "23:40", - "23:45", - "23:50", - "23:55", - "request" - ], - "restaurant-area": [ - "none", - "do not care", - "centre", - "east", - "north", - "south", - "west", - "request" - ], - "hotel-area": [ - "none", - "do not care", - "centre", - "east", - "north", - "south", - "west", - "request" - ], - "attraction-area": [ - "none", - "do not care", - "centre", - "east", - "north", - "south", - "west", - "request" - ], - "hospital-department": [ - "none", - "do not care", - "acute medical assessment unit", - "acute medicine for the elderly", - "antenatal", - "cambridge eye unit", - "cardiology", - "cardiology and coronary care unit", - "childrens oncology and haematology", - "childrens surgical and medicine", - "clinical decisions unit", - "clinical research facility", - "coronary care unit", - "diabetes and endocrinology", - "emergency department", - "gastroenterology", - "gynaecology", - "haematology", - "haematology and haematological oncology", - "haematology day unit", - "hepatobillary and gastrointestinal surgery regional referral centre", - "hepatology", - "infectious diseases", - "infusion services", - "inpatient occupational therapy", - "intermediate dependancy area", - "john farman intensive care unit", - "medical decisions unit", - "medicine for the elderly", - "neonatal unit", - "neurology", - "neurology neurosurgery", - "neurosciences", - "neurosciences critical care unit", - "oncology", - "oral and maxillofacial surgery and ent", - "paediatric clinic", - "paediatric day unit", - "paediatric intensive care unit", - "plastic and vascular surgery plastics", - "psychiatry", - "respiratory medicine", - "surgery", - "teenage cancer trust unit", - "transitional care", - "transplant high dependency unit", - "trauma and orthopaedics", - "trauma high dependency unit", - "urology" - ], - "police-postcode": [ - "request" - ], - "restaurant-postcode": [ - "request" - ], - "train-duration": [ - "request" - ], - "train-trainid": [ - "request" - ], - "hospital-address": [ - "request" - ], - "restaurant-phone": [ - "request" - ], - "hotel-phone": [ - "request" - ], - "restaurant-address": [ - "request" - ], - "hotel-postcode": [ - "request" - ], - "attraction-phone": [ - "request" - ], - "attraction-entrance fee": [ - "request" - ], - "hotel-reference": [ - "request" - ], - "taxi-taxi types": [ - "request" - ], - "attraction-address": [ - "request" - ], - "hospital-phone": [ - "request" - ], - "attraction-postcode": [ - "request" - ], - "police-address": [ - "request" - ], - "taxi-taxi phone": [ - "request" - ], - "train-price": [ - "request" - ], - "hospital-postcode": [ - "request" - ], - "police-phone": [ - "request" - ], - "hotel-address": [ - "request" - ], - "restaurant-reference": [ - "request" - ], - "train-reference": [ - "request" - ] -} \ No newline at end of file diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json deleted file mode 100644 index 87e31536..00000000 --- a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json +++ /dev/null @@ -1,57 +0,0 @@ -{ - "hotel-price range": "preferred cost or price of the hotel", - "hotel-type": "what is the type of the hotel", - "hotel-parking": "does the hotel have parking", - "hotel-book stay": "number of nights for the hotel reservation", - "hotel-book day": "starting day of the hotel booking", - "hotel-book people": "number of people for the hotel booking", - "hotel-area": "area or place of the hotel", - "hotel-stars": "star rating of the hotel", - "hotel-internet": "does the hotel have internet or wifi", - "hotel-name": "name of the hotel", - "hotel-phone": "phone number of the hotel", - "hotel-postcode": "postcode of the hotel", - "hotel-reference": "booking reference of the hotel booking", - "hotel-address": "street address of the hotel", - "train-destination": "train station you want to travel to", - "train-day": "day of the train booking", - "train-departure": "train station you want to leave from", - "train-arrive by": "arrival time of the train", - "train-book people": "number of people for the train booking", - "train-leave at": "departure time for the train", - "train-duration": "duration of the train journey", - "train-trainid": "train identifier or number", - "train-price": "how much does the train trip cost", - "train-reference": "booking reference of the train booking", - "attraction-type": "type of attraction or point of interest", - "attraction-area": "area or place of the attraction", - "attraction-name": "name of the attraction", - "attraction-phone": "phone number of the attraction", - "attraction-entrance fee": "entrace fee at the attraction", - "attraction-address": "street address of the attraction", - "attraction-postcode": "postcode of the attraction", - "restaurant-book people": "number of people for the restaurant booking", - "restaurant-book day": "weekday for the restaurant booking", - "restaurant-book time": "time of the restaurant booking", - "restaurant-food": "type of food served at the restaurant", - "restaurant-price range": "preferred cost or price of the restaurant", - "restaurant-name": "name of the restaurant", - "restaurant-area": "area or place of the restaurant", - "restaurant-postcode": "postcode of the restaurant", - "restaurant-phone": "phone number of the restaurant", - "restaurant-address": "street address of the restaurant", - "restaurant-reference": "booking reference of the hotel booking", - "taxi-leave at": "what time you want the taxi to leave by", - "taxi-destination": "where you want the taxi to drop you off", - "taxi-departure": "where you want the taxi to pick you up", - "taxi-arrive by": "what time you to arrive at your destination", - "taxi-taxi types": "vehicle type of the taxi", - "taxi-taxi phone": "phone number of the taxi", - "hospital-department": "name of hospital department", - "hospital-address": "street address of the hospital", - "hospital-phone": "phone number of the hospital", - "hospital-postcode": "postcode of the hospital", - "police-postcode": "postcode of the police station", - "police-address": "street address of the police station", - "police-phone": "phone number of the police station" -} \ No newline at end of file diff --git a/convlab/dst/setsumbt/multiwoz/dataset/ontology.py b/convlab/dst/setsumbt/multiwoz/dataset/ontology.py deleted file mode 100644 index c6b9c336..00000000 --- a/convlab/dst/setsumbt/multiwoz/dataset/ontology.py +++ /dev/null @@ -1,168 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Create Ontology Embeddings""" - -import json -import os -import random - -import torch -import numpy as np - - -# Slot mapping table for description extractions -# SLOT_NAME_MAPPINGS = { -# 'arrive at': 'arriveAt', -# 'arrive by': 'arriveBy', -# 'leave at': 'leaveAt', -# 'leave by': 'leaveBy', -# 'arriveby': 'arriveBy', -# 'arriveat': 'arriveAt', -# 'leaveat': 'leaveAt', -# 'leaveby': 'leaveBy', -# 'price range': 'pricerange' -# } - -# Set up global data directory -def set_datadir(dir): - global DATA_DIR - DATA_DIR = dir - - -# Set seeds -def set_seed(args): - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - if args.n_gpu > 0: - torch.cuda.manual_seed_all(args.seed) - - -# Get embeddings for slots and candidates -def get_slot_candidate_embeddings(set_type, args, tokenizer, embedding_model, save_to_file=True): - # Get set alots and candidates - reader = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r') - ontology = json.load(reader) - reader.close() - - reader = open(os.path.join(DATA_DIR, 'slot_descriptions.json'), 'r') - slot_descriptions = json.load(reader) - reader.close() - - embedding_model.eval() - - slots = dict() - for slot in ontology: - if args.use_descriptions: - # d, s = slot.split('-', 1) - # s = SLOT_NAME_MAPPINGS[s] if s in SLOT_NAME_MAPPINGS else s - # s = d + '-' + s - # if slot in slot_descriptions: - desc = slot_descriptions[slot] - # elif slot.lower() in slot_descriptions: - # desc = slot_descriptions[s.lower()] - # else: - # desc = slot.replace('-', ' ') - else: - desc = slot - - # Tokenize slot and get embeddings - feats = tokenizer.encode_plus(desc, add_special_tokens = True, - max_length = args.max_slot_len, padding='max_length', - truncation = 'longest_first') - - with torch.no_grad(): - input_ids = torch.tensor([feats['input_ids']]).to(embedding_model.device) # [1, max_slot_len] - if 'token_type_ids' in feats: - token_type_ids = torch.tensor([feats['token_type_ids']]).to(embedding_model.device) # [1, max_slot_len] - if 'attention_mask' in feats: - attention_mask = torch.tensor([feats['attention_mask']]).to(embedding_model.device) # [1, max_slot_len] - feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask) - attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1))) - feats = feats * attention_mask # [1, max_slot_len, hidden_dim] - else: - feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids) - else: - if 'attention_mask' in feats: - attention_mask = torch.tensor([feats['attention_mask']]).to(embedding_model.device) - feats, pooled_feats = embedding_model(input_ids=input_ids, attention_mask=attention_mask) - attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1))) - feats = feats * attention_mask # [1, max_slot_len, hidden_dim] - else: - feats, pooled_feats = embedding_model(input_ids=input_ids) # [1, max_slot_len, hidden_dim] - - if args.set_similarity: - slot_emb = feats[0, :, :].detach().cpu() # [seq_len, hidden_dim] - else: - if args.candidate_pooling == 'cls' and pooled_feats is not None: - slot_emb = pooled_feats[0, :].detach().cpu() # [hidden_dim] - elif args.candidate_pooling == 'mean': - feats = feats.sum(1) - feats = torch.nn.functional.layer_norm(feats, feats.size()) - slot_emb = feats[0, :].detach().cpu() # [hidden_dim] - - # Tokenize value candidates and get embeddings - values = ontology[slot] - is_requestable = False - if 'request' in values: - is_requestable = True - values.remove('request') - if values: - feats = [tokenizer.encode_plus(val, add_special_tokens = True, - max_length = args.max_candidate_len, padding='max_length', - truncation = 'longest_first') - for val in values] - with torch.no_grad(): - input_ids = torch.tensor([f['input_ids'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len] - if 'token_type_ids' in feats[0]: - token_type_ids = torch.tensor([f['token_type_ids'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len] - if 'attention_mask' in feats[0]: - attention_mask = torch.tensor([f['attention_mask'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len] - feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask) - attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1))) - feats = feats * attention_mask # [num_candidates, max_candidate_len, hidden_dim] - else: - feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids) # [num_candidates, max_candidate_len, hidden_dim] - else: - if 'attention_mask' in feats[0]: - attention_mask = torch.tensor([f['attention_mask'] for f in feats]).to(embedding_model.device) - feats, pooled_feats = embedding_model(input_ids=input_ids, attention_mask=attention_mask) - attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1))) - feats = feats * attention_mask # [num_candidates, max_candidate_len, hidden_dim] - else: - feats, pooled_feats = embedding_model(input_ids=input_ids) # [num_candidates, max_candidate_len, hidden_dim] - - if args.set_similarity: - feats = feats.detach().cpu() # [num_candidates, max_candidate_len, hidden_dim] - else: - if args.candidate_pooling == 'cls' and pooled_feats is not None: - feats = pooled_feats.detach().cpu() - elif args.candidate_pooling == "mean": - feats = feats.sum(1) - feats = torch.nn.functional.layer_norm(feats, feats.size()) - feats = feats.detach().cpu() - else: - feats = None - slots[slot] = (slot_emb, feats, is_requestable) - - # Dump tensors for use in training - if save_to_file: - writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type) - torch.save(slots, writer) - - return slots diff --git a/convlab/dst/setsumbt/multiwoz/dataset/utils.py b/convlab/dst/setsumbt/multiwoz/dataset/utils.py deleted file mode 100644 index 485dee64..00000000 --- a/convlab/dst/setsumbt/multiwoz/dataset/utils.py +++ /dev/null @@ -1,446 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf -# Code adapted from the TRADE preprocessing code (https://github.com/jasonwu0731/trade-dst) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""MultiWOZ2.1/3 data processing utilities""" - -import re -import os - -from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA -from convlab.dst.rule.multiwoz import normalize_value - -# ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train'] -ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train', 'hospital', 'police'] -def set_util_domains(domains): - global ACTIVE_DOMAINS - ACTIVE_DOMAINS = [d for d in domains if d in ACTIVE_DOMAINS] - -MAPPING_PATH = os.path.abspath(__file__).replace('utils.py', 'mapping.pair') -# Read replacement pairs from the mapping.pair file -REPLACEMENTS = [] -for line in open(MAPPING_PATH).readlines(): - tok_from, tok_to = line.replace('\n', '').split('\t') - REPLACEMENTS.append((' ' + tok_from + ' ', ' ' + tok_to + ' ')) - -# Extract belief state from mturk annotations -def build_dialoguestate(metadata, get_domains=False): - domains_list = [dom for dom in ACTIVE_DOMAINS if dom in metadata] - dialogue_state, domains = [], [] - for domain in domains_list: - active = False - # Extract booking information - booking = [] - for slot in sorted(metadata[domain]['book'].keys()): - if slot != 'booked': - if metadata[domain]['book'][slot] == 'not mentioned': - continue - if metadata[domain]['book'][slot] != '': - val = ['%s-book %s' % (domain, slot.strip().lower()), clean_text(metadata[domain]['book'][slot])] - dialogue_state.append(val) - active = True - - for slot in metadata[domain]['semi']: - if metadata[domain]['semi'][slot] == 'not mentioned': - continue - elif metadata[domain]['semi'][slot] in ['dont care', 'dontcare', "don't care", 'don not care', - 'do not care', 'does not care']: - dialogue_state.append(['%s-%s' % (domain, slot.strip().lower()), 'do not care']) - active = True - elif metadata[domain]['semi'][slot]: - dialogue_state.append(['%s-%s' % (domain, slot.strip().lower()), clean_text(metadata[domain]['semi'][slot])]) - active = True - - if active: - domains.append(domain) - - if get_domains: - return domains - return clean_dialoguestate(dialogue_state) - - -PRICERANGE = ['do not care', 'cheap', 'moderate', 'expensive'] -BOOLEAN = ['do not care', 'yes', 'no'] -DAYS = ['do not care', 'monday', 'tuesday', 'wednesday', 'thursday', - 'friday', 'saterday', 'sunday'] -QUANTITIES = ['do not care', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10 or more'] -TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)] -TIME = ['do not care'] + ['%02i:%02i' % t for l in TIME for t in l] - -VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast', - 'cityroomz': 'city roomz', ' ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott', - 'worth house': 'the worth house', 'alesbray lodge guest house': 'aylesbray lodge', - 'huntingdon hotel': 'huntingdon marriott hotel', 'huntingd': 'huntingdon marriott hotel', - 'jamaicanchinese': 'chinese', 'barbequemodern european': 'modern european', - 'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"} - -def map_values(value): - for old, new in VALUE_MAP.items(): - value = value.replace(old, new) - return value - -def clean_dialoguestate(states, is_acts=False): - # path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))) - # path = os.path.join(path, 'data/multiwoz/value_dict.json') - # value_dict = json.load(open(path)) - clean_state = [] - for slot, value in states: - if 'pricerange' in slot: - d, s = slot.split('-', 1) - s = 'price range' - slot = f'{d}-{s}' - if value in PRICERANGE: - clean_state.append([slot, value]) - elif True in [v in value for v in PRICERANGE]: - value = [v for v in PRICERANGE if v in value][0] - clean_state.append([slot, value]) - elif value == '?' and is_acts: - clean_state.append([slot, value]) - else: - continue - elif 'parking' in slot or 'internet' in slot: - if value in BOOLEAN: - clean_state.append([slot, value]) - if value == 'free': - value = 'yes' - clean_state.append([slot, value]) - elif True in [v in value for v in BOOLEAN]: - value = [v for v in BOOLEAN if v in value][0] - clean_state.append([slot, value]) - elif value == '?' and is_acts: - clean_state.append([slot, value]) - else: - continue - elif 'day' in slot: - if value in DAYS: - clean_state.append([slot, value]) - elif True in [v in value for v in DAYS]: - value = [v for v in DAYS if v in value][0] - clean_state.append([slot, value]) - else: - continue - elif 'people' in slot or 'duration' in slot or 'stay' in slot: - if value in QUANTITIES: - clean_state.append([slot, value]) - elif True in [v in value for v in QUANTITIES]: - value = [v for v in QUANTITIES if v in value][0] - clean_state.append([slot, value]) - elif value == '?' and is_acts: - clean_state.append([slot, value]) - else: - try: - value = int(value) - if value >= 10: - value = '10 or more' - clean_state.append([slot, value]) - else: - continue - except: - continue - elif 'time' in slot or 'leaveat' in slot or 'arriveby' in slot: - if 'leaveat' in slot: - d, s = slot.split('-', 1) - s = 'leave at' - slot = f'{d}-{s}' - if 'arriveby' in slot: - d, s = slot.split('-', 1) - s = 'arrive by' - slot = f'{d}-{s}' - if value in TIME: - if value == 'do not care': - clean_state.append([slot, value]) - else: - h, m = value.split(':') - if int(m) % 5 == 0: - clean_state.append([slot, value]) - else: - m = round(int(m) / 5) * 5 - h = int(h) - if m == 60: - m = 0 - h += 1 - if h >= 24: - h -= 24 - value = '%02i:%02i' % (h, m) - clean_state.append([slot, value]) - elif True in [v in value for v in TIME]: - value = [v for v in TIME if v in value][0] - h, m = value.split(':') - if int(m) % 5 == 0: - clean_state.append([slot, value]) - else: - m = round(int(m) / 5) * 5 - h = int(h) - if m == 60: - m = 0 - h += 1 - if h >= 24: - h -= 24 - value = '%02i:%02i' % (h, m) - clean_state.append([slot, value]) - elif value == '?' and is_acts: - clean_state.append([slot, value]) - else: - continue - elif 'stars' in slot: - if len(value) == 1 or value == 'do not care': - clean_state.append([slot, value]) - elif value == '?' and is_acts: - clean_state.append([slot, value]) - elif len(value) > 1: - try: - value = int(value[0]) - value = str(value) - clean_state.append([slot, value]) - except: - continue - elif 'area' in slot: - if '|' in value: - value = value.split('|', 1)[0] - clean_state.append([slot, value]) - else: - if '|' in value: - value = value.split('|', 1)[0] - value = map_values(value) - # d, s = slot.split('-', 1) - # value = normalize_value(value_dict, d, s, value) - clean_state.append([slot, value]) - - return clean_state - - -# Module to process a dialogue and check its validity -def process_dialogue(dialogue, max_utt_len=128): - if len(dialogue['log']) % 2 != 0: - return None - - # Extract user and system utterances - usr_utts, sys_utts = [], [] - avg_len = sum(len(utt['text'].split(' ')) for utt in dialogue['log']) - avg_len = avg_len / len(dialogue['log']) - if avg_len > max_utt_len: - return None - - # If the first term is a system turn then ignore dialogue - if dialogue['log'][0]['metadata']: - return None - - usr, sys = None, None - for turn in dialogue['log']: - if not is_ascii(turn['text']): - return None - - if not usr or not sys: - if len(turn['metadata']) == 0: - usr = turn - else: - sys = turn - - if usr and sys: - states = build_dialoguestate(sys['metadata'], get_domains = False) - sys['dialogue_states'] = states - - usr_utts.append(usr) - sys_utts.append(sys) - usr, sys = None, None - - dial_clean = dict() - dial_clean['usr_log'] = usr_utts - dial_clean['sys_log'] = sys_utts - return dial_clean - - -# Get new domains -def get_act_domains(prev, crnt): - diff = {} - if not prev or not crnt: - return diff - - for ((prev_dom, prev_val), (crnt_dom, crnt_val)) in zip(prev.items(), crnt.items()): - assert prev_dom == crnt_dom - if prev_val != crnt_val: - diff[crnt_dom] = crnt_val - return diff - - -# Get current domains -def get_domains(dial_log, turn_id, prev_domain): - if turn_id == 1: - active = build_dialoguestate(dial_log[turn_id]['metadata'], get_domains=True) - acts = format_acts(dial_log[turn_id].get('dialog_act', {})) if not active else [] - acts = [domain for intent, domain, slot, value in acts if domain not in ['', 'general']] - active += acts - crnt = active[0] if active else '' - else: - active = get_act_domains(dial_log[turn_id - 2]['metadata'], dial_log[turn_id]['metadata']) - active = list(active.keys()) - acts = format_acts(dial_log[turn_id].get('dialog_act', {})) if not active else [] - acts = [domain for intent, domain, slot, value in acts if domain not in ['', 'general']] - active += acts - crnt = [prev_domain] if not active else active - crnt = crnt[0] - - return crnt - - -# Function to extract dialogue info from data -def extract_dialogue(dialogue, max_utt_len=50): - dialogue = process_dialogue(dialogue, max_utt_len) - if not dialogue: - return None - - usr_utts = [turn['text'] for turn in dialogue['usr_log']] - sys_utts = [turn['text'] for turn in dialogue['sys_log']] - # sys_acts = [format_acts(turn['dialog_act']) if 'dialog_act' in turn else [] for turn in dialogue['sys_log']] - usr_acts = [format_acts(turn['dialog_act']) if 'dialog_act' in turn else [] for turn in dialogue['usr_log']] - dialogue_states = [turn['dialogue_states'] for turn in dialogue['sys_log']] - domains = [turn['domain'] for turn in dialogue['usr_log']] - - # dial = [{'usr': u,'sys': s, 'usr_a': ua, 'sys_a': a, 'domain': d, 'ds': v} - # for u, s, ua, a, d, v in zip(usr_utts, sys_utts, usr_acts, sys_acts, domains, dialogue_states)] - dial = [{'usr': u,'sys': s, 'usr_a': ua, 'domain': d, 'ds': v} - for u, s, ua, d, v in zip(usr_utts, sys_utts, usr_acts, domains, dialogue_states)] - return dial - - -def format_acts(acts): - new_acts = [] - for key, item in acts.items(): - domain, intent = key.split('-', 1) - if domain.lower() in ACTIVE_DOMAINS + ['general']: - state = [] - for slot, value in item: - slot = str(REF_SYS_DA[domain].get(slot, slot)).lower() if domain in REF_SYS_DA else slot - value = clean_text(value) - slot = slot.replace('_', ' ').replace('ref', 'reference') - state.append([f'{domain.lower()}-{slot}', value]) - state = clean_dialoguestate(state, is_acts=True) - if domain == 'general': - if intent in ['thank', 'bye']: - state = [['general-none', 'none']] - else: - state = [] - for slot, value in state: - if slot not in ['train-people']: - slot = slot.split('-', 1)[-1] - new_acts.append([intent.lower(), domain.lower(), slot, value]) - - return new_acts - - -# Fix act labels -def fix_delexicalisation(turn): - if 'dialog_act' in turn: - for dom, act in turn['dialog_act'].items(): - if 'Attraction' in dom: - if 'restaurant_' in turn['text']: - turn['text'] = turn['text'].replace("restaurant", "attraction") - if 'hotel_' in turn['text']: - turn['text'] = turn['text'].replace("hotel", "attraction") - if 'Hotel' in dom: - if 'attraction_' in turn['text']: - turn['text'] = turn['text'].replace("attraction", "hotel") - if 'restaurant_' in turn['text']: - turn['text'] = turn['text'].replace("restaurant", "hotel") - if 'Restaurant' in dom: - if 'attraction_' in turn['text']: - turn['text'] = turn['text'].replace("attraction", "restaurant") - if 'hotel_' in turn['text']: - turn['text'] = turn['text'].replace("hotel", "restaurant") - - return turn - - -# Check if a character is an ascii character -def is_ascii(s): - return all(ord(c) < 128 for c in s) - - -# Insert white space -def separate_token(token, text): - sidx = 0 - while True: - # Find next instance of token - sidx = text.find(token, sidx) - if sidx == -1: - break - # If the token is already seperated continue to next - if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \ - re.match('[0-9]', text[sidx + 1]): - sidx += 1 - continue - # Create white space separation around token - if text[sidx - 1] != ' ': - text = text[:sidx] + ' ' + text[sidx:] - sidx += 1 - if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ': - text = text[:sidx + 1] + ' ' + text[sidx + 1:] - sidx += 1 - return text - - -def clean_text(text): - # Replace white spaces in front and end - text = re.sub(r'^\s*|\s*$', '', text.strip().lower()) - - # Replace b&v or 'b and b' with 'bed and breakfast' - text = re.sub(r"b&b", "bed and breakfast", text) - text = re.sub(r"b and b", "bed and breakfast", text) - - # Fix apostrophies - text = re.sub(u"(\u2018|\u2019)", "'", text) - - # Correct punctuation - text = text.replace(';', ',') - text = re.sub('$\/', '', text) - text = text.replace('/', ' and ') - - # Replace special characters - text = text.replace('-', ' ') - text = re.sub('[\"\<>@\(\)]', '', text) - - # Insert white space around special tokens: - for token in ['?', '.', ',', '!']: - text = separate_token(token, text) - - # insert white space for 's - text = separate_token('\'s', text) - - # replace it's, does't, you'd ... etc - text = re.sub('^\'', '', text) - text = re.sub('\'$', '', text) - text = re.sub('\'\s', ' ', text) - text = re.sub('\s\'', ' ', text) - - # Perform pair replacements listed in the mapping.pair file - for fromx, tox in REPLACEMENTS: - text = ' ' + text + ' ' - text = text.replace(fromx, tox)[1:-1] - - # Remove multiple spaces - text = re.sub(' +', ' ', text) - - # Concatenate numbers eg '1 3' -> '13' - tokens = text.split() - i = 1 - while i < len(tokens): - if re.match(u'^\d+$', tokens[i]) and \ - re.match(u'\d+$', tokens[i - 1]): - tokens[i - 1] += tokens[i] - del tokens[i] - else: - i += 1 - text = ' '.join(tokens) - - return text diff --git a/convlab/dst/setsumbt/predict_user_actions.py b/convlab/dst/setsumbt/predict_user_actions.py new file mode 100644 index 00000000..2c304a56 --- /dev/null +++ b/convlab/dst/setsumbt/predict_user_actions.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Predict dataset user action using SetSUMBT Model""" + +from copy import deepcopy +import os +import json +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser + +from convlab.util.custom_util import flatten_acts as flatten +from convlab.util import load_dataset, load_policy_data +from convlab.dst.setsumbt import SetSUMBTTracker + + +def flatten_acts(acts: dict) -> list: + """ + Flatten dictionary actions. + + Args: + acts: Dictionary acts + + Returns: + flat_acts: Flattened actions + """ + acts = flatten(acts) + flat_acts = [] + for intent, domain, slot, value in acts: + flat_acts.append([intent, + domain, + slot if slot != 'none' else '', + value.lower() if value != 'none' else '']) + + return flat_acts + + +def get_user_actions(context: list, system_acts: list) -> list: + """ + Extract user actions from the data. + + Args: + context: Previous dialogue turns. + system_acts: List of flattened system actions. + + Returns: + user_acts: List of flattened user actions. + """ + user_acts = context[-1]['dialogue_acts'] + user_acts = flatten_acts(user_acts) + if len(context) == 3: + prev_state = context[-3]['state'] + cur_state = context[-1]['state'] + for domain, substate in cur_state.items(): + for slot, value in substate.items(): + if prev_state[domain][slot] != value: + act = ['inform', domain, slot, value] + if act not in user_acts and act not in system_acts: + user_acts.append(act) + + return user_acts + + +def extract_dataset(dataset: str = 'multiwoz21') -> list: + """ + Extract acts and utterances from the dataset. + + Args: + dataset: Dataset name + + Returns: + data: Extracted data + """ + data = load_dataset(dataset_name=dataset) + raw_data = load_policy_data(data, data_split='test', context_window_size=3)['test'] + + dialogue = list() + data = list() + for turn in raw_data: + state = dict() + state['system_utterance'] = turn['context'][-2]['utterance'] if len(turn['context']) > 1 else '' + state['utterance'] = turn['context'][-1]['utterance'] + state['system_actions'] = turn['context'][-2]['dialogue_acts'] if len(turn['context']) > 1 else {} + state['system_actions'] = flatten_acts(state['system_actions']) + state['user_actions'] = get_user_actions(turn['context'], state['system_actions']) + dialogue.append(state) + if turn['terminated']: + data.append(dialogue) + dialogue = list() + + return data + + +def unflatten_acts(acts: list) -> dict: + """ + Convert acts from flat list format to dict format. + + Args: + acts: List of flat actions. + + Returns: + unflat_acts: Dictionary of acts. + """ + binary_acts = [] + cat_acts = [] + for intent, domain, slot, value in acts: + include = True if (domain == 'general') or (slot != 'none') else False + if include and (value == '' or value == 'none' or intent == 'request'): + binary_acts.append({'intent': intent, + 'domain': domain, + 'slot': slot if slot != 'none' else ''}) + elif include: + cat_acts.append({'intent': intent, + 'domain': domain, + 'slot': slot if slot != 'none' else '', + 'value': value}) + + unflat_acts = {'categorical': cat_acts, 'binary': binary_acts, 'non-categorical': list()} + + return unflat_acts + + +def predict_user_acts(data: list, tracker: SetSUMBTTracker) -> list: + """ + Predict the user actions using the SetSUMBT Tracker. + + Args: + data: List of dialogues. + tracker: SetSUMBT Tracker + + Returns: + predict_result: List of turns containing predictions and true user actions. + """ + tracker.init_session() + predict_result = [] + for dial_idx, dialogue in enumerate(data): + for turn_idx, state in enumerate(dialogue): + sample = {'dial_idx': dial_idx, 'turn_idx': turn_idx} + + tracker.state['history'].append(['sys', state['system_utterance']]) + predicted_state = deepcopy(tracker.update(state['utterance'])) + tracker.state['history'].append(['usr', state['utterance']]) + tracker.state['system_action'] = state['system_actions'] + + sample['predictions'] = {'dialogue_acts': unflatten_acts(predicted_state['user_action'])} + sample['dialogue_acts'] = unflatten_acts(state['user_actions']) + + predict_result.append(sample) + + tracker.init_session() + + return predict_result + + +if __name__ =="__main__": + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21") + parser.add_argument('--model_path', type=str, help='Path to model dir') + args = parser.parse_args() + + dataset = extract_dataset(args.dataset_name) + tracker = SetSUMBTTracker(args.model_path) + predict_results = predict_user_acts(dataset, tracker) + + with open(os.path.join(args.model_path, 'predictions', 'test_nlu.json'), 'w') as writer: + json.dump(predict_results, writer, indent=2) + writer.close() diff --git a/convlab/dst/setsumbt/process_mwoz_data.py b/convlab/dst/setsumbt/process_mwoz_data.py deleted file mode 100755 index 701a5236..00000000 --- a/convlab/dst/setsumbt/process_mwoz_data.py +++ /dev/null @@ -1,99 +0,0 @@ -import os -import json -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser - -import torch -from tqdm import tqdm - -from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker -from convlab.util.multiwoz.lexicalize import deflat_da, flat_da - - -def load_data(path): - with open(path, 'r') as reader: - data = json.load(reader) - reader.close() - - return data - - -def load_tracker(model_checkpoint): - model = SetSUMBTTracker(model_path=model_checkpoint) - model.init_session() - - return model - - -def process_dialogue(dial, model, get_full_belief_state): - model.store_full_belief_state = get_full_belief_state - model.init_session() - - model.state['history'].append(['sys', '']) - processed_dial = [] - belief_state = {} - for turn in dial: - if not turn['metadata']: - state = model.update(turn['text']) - model.state['history'].append(['usr', turn['text']]) - - acts = model.state['user_action'] - acts = [[val.replace('-', ' ') for val in act] for act in acts] - acts = flat_da(acts) - acts = deflat_da(acts) - turn['dialog_act'] = acts - else: - model.state['history'].append(['sys', turn['text']]) - turn['metadata'] = model.state['belief_state'] - - if get_full_belief_state: - for slot, probs in model.full_belief_state.items(): - if slot not in belief_state: - belief_state[slot] = [probs[0]] - else: - belief_state[slot].append(probs[0]) - - processed_dial.append(turn) - - if get_full_belief_state: - belief_state = {slot: torch.cat(probs, 0).cpu() for slot, probs in belief_state.items()} - - return processed_dial, belief_state - - -def process_dialogues(data, model, get_full_belief_state=False): - processed_data = {} - belief_states = {} - for dial_id, dial in tqdm(data.items()): - dial['log'], bs = process_dialogue(dial['log'], model, get_full_belief_state) - processed_data[dial_id] = dial - if get_full_belief_state: - belief_states[dial_id] = bs - - return processed_data, belief_states - - -def get_arguments(): - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - - parser.add_argument('--model_path') - parser.add_argument('--data_path') - parser.add_argument('--get_full_belief_state', action='store_true') - - return parser.parse_args() - - -if __name__ == "__main__": - args = get_arguments() - - print('Loading data and model...') - data = load_data(os.path.join(args.data_path, 'data.json')) - model = load_tracker(args.model_path) - - print('Processing data...\n') - data, belief_states = process_dialogues(data, model, get_full_belief_state=args.get_full_belief_state) - - print('Saving results...\n') - torch.save(belief_states, os.path.join(args.data_path, 'setsumbt_belief_states.bin')) - with open(os.path.join(args.data_path, 'setsumbt_data.json'), 'w') as writer: - json.dump(data, writer, indent=2) - writer.close() diff --git a/convlab/dst/setsumbt/run.py b/convlab/dst/setsumbt/run.py index b9c9a75b..e45bf129 100644 --- a/convlab/dst/setsumbt/run.py +++ b/convlab/dst/setsumbt/run.py @@ -33,8 +33,8 @@ def main(): if args.run_nbt: from convlab.dst.setsumbt.do.nbt import main main(args, config) - if args.run_calibration: - from convlab.dst.setsumbt.do.calibration import main + if args.run_evaluation: + from convlab.dst.setsumbt.do.evaluate import main main(args, config) diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py new file mode 100644 index 00000000..6b620247 --- /dev/null +++ b/convlab/dst/setsumbt/tracker.py @@ -0,0 +1,446 @@ +import os +import json +import copy +import logging + +import torch +import transformers +from transformers import BertModel, BertConfig, BertTokenizer, RobertaModel, RobertaConfig, RobertaTokenizer + +from convlab.dst.setsumbt.modeling import RobertaSetSUMBT, BertSetSUMBT +from convlab.dst.setsumbt.modeling.training import set_ontology_embeddings +from convlab.dst.dst import DST +from convlab.util.custom_util import model_downloader + +USE_CUDA = torch.cuda.is_available() +transformers.logging.set_verbosity_error() + + +class SetSUMBTTracker(DST): + """SetSUMBT Tracker object for Convlab dialogue system""" + + def __init__(self, + model_path: str = "", + model_type: str = "roberta", + return_turn_pooled_representation: bool = False, + return_confidence_scores: bool = False, + confidence_threshold='auto', + return_belief_state_entropy: bool = False, + return_belief_state_mutual_info: bool = False, + store_full_belief_state: bool = False): + """ + Args: + model_path: Model path or download URL + model_type: Transformer type (roberta/bert) + return_turn_pooled_representation: If true a turn level pooled representation is returned + return_confidence_scores: If true act confidence scores are included in the state + confidence_threshold: Confidence threshold value for constraints or option auto + return_belief_state_entropy: If true belief state distribution entropies are included in the state + return_belief_state_mutual_info: If true belief state distribution mutual infos are included in the state + store_full_belief_state: If true full belief state is stored within tracker object + """ + super(SetSUMBTTracker, self).__init__() + + self.model_type = model_type + self.model_path = model_path + self.return_turn_pooled_representation = return_turn_pooled_representation + self.return_confidence_scores = return_confidence_scores + self.confidence_threshold = confidence_threshold + self.return_belief_state_entropy = return_belief_state_entropy + self.return_belief_state_mutual_info = return_belief_state_mutual_info + self.store_full_belief_state = store_full_belief_state + if self.store_full_belief_state: + self.full_belief_state = {} + self.info_dict = {} + + # Download model if needed + if not os.path.exists(self.model_path): + # Get path /.../convlab/dst/setsumbt/multiwoz/models + download_path = os.path.dirname(os.path.abspath(__file__)) + download_path = os.path.join(download_path, 'models') + if not os.path.exists(download_path): + os.mkdir(download_path) + model_downloader(download_path, self.model_path) + # Downloadable model path format http://.../setsumbt_model_name.zip + self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '') + self.model_path = os.path.join(download_path, self.model_path) + + # Select model type based on the encoder + if model_type == "roberta": + self.config = RobertaConfig.from_pretrained(self.model_path) + self.tokenizer = RobertaTokenizer + self.model = RobertaSetSUMBT + elif model_type == "bert": + self.config = BertConfig.from_pretrained(self.model_path) + self.tokenizer = BertTokenizer + self.model = BertSetSUMBT + else: + logging.debug("Name Error: Not Implemented") + + self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu') + + self.load_weights() + + def load_weights(self): + """Load model weights and model ontology""" + logging.info('Loading SetSUMBT pretrained model.') + self.tokenizer = self.tokenizer.from_pretrained(self.config.tokenizer_name) + logging.info(f'Model tokenizer loaded from {self.config.tokenizer_name}.') + self.model = self.model.from_pretrained(self.model_path, config=self.config) + logging.info(f'Model loaded from {self.model_path}.') + + # Transfer model to compute device and setup eval environment + self.model = self.model.to(self.device) + self.model.eval() + logging.info(f'Model transferred to device: {self.device}') + + logging.info('Loading model ontology') + f = open(os.path.join(self.model_path, 'database', 'test.json'), 'r') + self.ontology = json.load(f) + f.close() + + db = torch.load(os.path.join(self.model_path, 'database', 'test.db')) + set_ontology_embeddings(self.model, db) + + if self.return_confidence_scores: + logging.info('Model returns user action and belief state confidence scores.') + self.get_thresholds(self.confidence_threshold) + logging.info('Uncertain Querying set up and thresholds set up at:') + logging.info(self.confidence_thresholds) + if self.return_belief_state_entropy: + logging.info('Model returns belief state distribution entropy scores (Total uncertainty).') + if self.return_belief_state_mutual_info: + logging.info('Model returns belief state distribution mutual information scores (Knowledge uncertainty).') + logging.info('Ontology loaded successfully.') + + def get_thresholds(self, threshold='auto') -> dict: + """ + Setup dictionary of domain specific confidence thresholds + + Args: + threshold: Threshold value or option auto + + Returns: + confidence_thresholds: Domain specific confidence thresholds + """ + self.confidence_thresholds = dict() + for domain, substate in self.ontology.items(): + for slot, slot_info in substate.items(): + # Auto thresholds are set based on the number of value candidates per slot + if domain not in self.confidence_thresholds: + self.confidence_thresholds[domain] = dict() + if threshold == 'auto': + thres = 1.0 / (float(len(slot_info['possible_values'])) - 2.1) + self.confidence_thresholds[domain][slot] = max(0.05, thres) + else: + self.confidence_thresholds[domain][slot] = max(0.05, threshold) + + return self.confidence_thresholds + + def init_session(self): + self.state = dict() + self.state['belief_state'] = dict() + self.state['booked'] = dict() + for domain, substate in self.ontology.items(): + self.state['belief_state'][domain] = dict() + for slot, slot_info in substate.items(): + if slot_info['possible_values'] and slot_info['possible_values'] != ['?']: + self.state['belief_state'][domain][slot] = '' + self.state['booked'][domain] = list() + self.state['history'] = [] + self.state['system_action'] = [] + self.state['user_action'] = [] + self.state['terminated'] = False + self.active_domains = {} + self.hidden_states = None + self.info_dict = {} + + def update(self, user_act: str = '') -> dict: + """ + Update user actions and dialogue and belief states. + + Args: + user_act: + + Returns: + + """ + prev_state = self.state + _output = self.predict(self.get_features(user_act)) + + # Format state entropy + if _output[5] is not None: + state_entropy = dict() + for slot, e in _output[5].items(): + domain, slot = slot.split('-', 1) + if domain not in state_entropy: + state_entropy[domain] = dict() + state_entropy[domain][slot] = e + else: + state_entropy = None + + # Format state mutual information + if _output[6] is not None: + state_mutual_info = dict() + for slot, mi in _output[6].items(): + domain, slot = slot.split('-', 1) + if domain not in state_mutual_info: + state_mutual_info[domain] = dict() + state_mutual_info[domain][slot] = mi[0, 0] + else: + state_mutual_info = None + + # Format all confidence scores + belief_state_confidence = None + if _output[4] is not None: + belief_state_confidence = dict() + belief_state_conf, request_probs, active_domain_probs, general_act_probs = _output[4] + for slot, p in belief_state_conf.items(): + domain, slot = slot.split('-', 1) + if domain not in belief_state_confidence: + belief_state_confidence[domain] = dict() + if slot not in belief_state_confidence[domain]: + belief_state_confidence[domain][slot] = dict() + belief_state_confidence[domain][slot]['inform'] = p + + for slot, p in request_probs.items(): + domain, slot = slot.split('-', 1) + if domain not in belief_state_confidence: + belief_state_confidence[domain] = dict() + if slot not in belief_state_confidence[domain]: + belief_state_confidence[domain][slot] = dict() + belief_state_confidence[domain][slot]['request'] = p + + for domain, p in active_domain_probs.items(): + if domain not in belief_state_confidence: + belief_state_confidence[domain] = dict() + belief_state_confidence[domain]['none'] = {'inform': p} + + if 'general' not in belief_state_confidence: + belief_state_confidence['general'] = dict() + belief_state_confidence['general']['none'] = general_act_probs + + # Get new domain activation actions + new_domains = [d for d, active in _output[1].items() if active] + new_domains = [d for d in new_domains if not self.active_domains.get(d, False)] + self.active_domains = _output[1] + + user_acts = _output[2] + for domain in new_domains: + user_acts.append(['inform', domain, 'none', 'none']) + + new_belief_state = copy.deepcopy(prev_state['belief_state']) + for domain, substate in _output[0].items(): + for slot, value in substate.items(): + value = '' if value == 'none' else value + value = 'dontcare' if value == 'do not care' else value + value = 'guesthouse' if value == 'guest house' else value + + if domain not in new_belief_state: + if domain == 'bus': + continue + else: + logging.debug('Error: domain <{}> not in belief state'.format(domain)) + + # Uncertainty clipping of state + if belief_state_confidence is not None: + threshold = self.confidence_thresholds[domain][slot] + if belief_state_confidence[domain][slot].get('inform', 1.0) < threshold: + value = '' + + new_belief_state[domain][slot] = value + if prev_state['belief_state'][domain][slot] != value: + user_acts.append(['inform', domain, slot, value]) + else: + bug = f'Unknown slot name <{slot}> with value <{value}> of domain <{domain}>' + logging.debug(bug) + + new_state = copy.deepcopy(dict(prev_state)) + new_state['belief_state'] = new_belief_state + new_state['active_domains'] = self.active_domains + if belief_state_confidence is not None: + new_state['belief_state_probs'] = belief_state_confidence + if state_entropy is not None: + new_state['entropy'] = state_entropy + if state_mutual_info is not None: + new_state['mutual_information'] = state_mutual_info + + user_acts = [act for act in user_acts if act not in new_state['system_action']] + new_state['user_action'] = user_acts + + if _output[3] is not None: + new_state['turn_pooled_representation'] = _output[3] + + self.state = new_state + self.info_dict = copy.deepcopy(dict(new_state)) + + return self.state + + def predict(self, features: dict) -> tuple: + """ + Model forward pass and prediction post processing. + + Args: + features: Dictionary of model input features + + Returns: + out: Model predictions and uncertainty features + """ + state_mutual_info = None + with torch.no_grad(): + turn_pooled_representation = None + if self.return_turn_pooled_representation: + _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'], + attention_mask=features['attention_mask'], hidden_state=self.hidden_states, + get_turn_pooled_representation=True) + belief_state = _outputs[0] + request_probs = _outputs[1] + active_domain_probs = _outputs[2] + general_act_probs = _outputs[3] + self.hidden_states = _outputs[4] + turn_pooled_representation = _outputs[5] + elif self.return_belief_state_mutual_info: + _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'], + attention_mask=features['attention_mask'], hidden_state=self.hidden_states, + get_turn_pooled_representation=True, calculate_state_mutual_info=True) + belief_state = _outputs[0] + request_probs = _outputs[1] + active_domain_probs = _outputs[2] + general_act_probs = _outputs[3] + self.hidden_states = _outputs[4] + state_mutual_info = _outputs[5] + else: + _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'], + attention_mask=features['attention_mask'], hidden_state=self.hidden_states, + get_turn_pooled_representation=False) + belief_state, request_probs, active_domain_probs, general_act_probs, self.hidden_states = _outputs + + # Convert belief state into dialog state + dialogue_state = dict() + for slot, probs in belief_state.items(): + dom, slot = slot.split('-', 1) + if dom not in dialogue_state: + dialogue_state[dom] = dict() + val = self.ontology[dom][slot]['possible_values'][probs[0, 0, :].argmax().item()] + if val != 'none': + dialogue_state[dom][slot] = val + + if self.store_full_belief_state: + self.full_belief_state = belief_state + + # Obtain model output probabilities + if self.return_confidence_scores: + state_entropy = None + if self.return_belief_state_entropy: + state_entropy = {slot: probs[0, 0, :] for slot, probs in belief_state.items()} + state_entropy = {slot: self.relative_entropy(p).item() for slot, p in state_entropy.items()} + + # Confidence score is the max probability across all not "none" values candidates. + belief_state_conf = {slot: probs[0, 0, 1:].max().item() for slot, probs in belief_state.items()} + _request_probs = {slot: p[0, 0].item() for slot, p in request_probs.items()} + _active_domain_probs = {domain: p[0, 0].item() for domain, p in active_domain_probs.items()} + _general_act_probs = {'bye': general_act_probs[0, 0, 1].item(), 'thank': general_act_probs[0, 0, 2].item()} + confidence_scores = (belief_state_conf, _request_probs, _active_domain_probs, _general_act_probs) + else: + confidence_scores = None + state_entropy = None + + # Construct request action prediction + request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5] + request_acts = [slot.split('-', 1) for slot in request_acts] + request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts] + + # Construct active domain set + active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()} + + # Construct general domain action + general_acts = general_act_probs[0, 0, :].argmax(-1).item() + general_acts = [[], ['bye'], ['thank']][general_acts] + general_acts = [[act, 'general', 'none', 'none'] for act in general_acts] + + user_acts = request_acts + general_acts + + out = (dialogue_state, active_domains, user_acts, turn_pooled_representation, confidence_scores) + out += (state_entropy, state_mutual_info) + return out + + def relative_entropy(self, probs: torch.Tensor) -> torch.Tensor: + """ + Compute relative entrop for a probability distribution + + Args: + probs: Probability distributions + + Returns: + entropy: Relative entropy + """ + entropy = probs * torch.log(probs + 1e-8) + entropy = -entropy.sum() + # Maximum entropy of a K dimentional distribution is ln(K) + entropy /= torch.log(torch.tensor(probs.size(-1)).float()) + + return entropy + + def get_features(self, user_act: str) -> dict: + """ + Tokenize utterances and construct model input features + + Args: + user_act: User action string + + Returns: + features: Model input features + """ + # Extract system utterance from dialog history + context = self.state['history'] + if context: + if context[-1][0] != 'sys': + system_act = '' + else: + system_act = context[-1][-1] + else: + system_act = '' + + # Tokenize dialog + features = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True, + max_length=self.config.max_turn_len, padding='max_length', + truncation='longest_first') + + input_ids = torch.tensor(features['input_ids']).reshape( + 1, 1, -1).to(self.device) if 'input_ids' in features else None + token_type_ids = torch.tensor(features['token_type_ids']).reshape( + 1, 1, -1).to(self.device) if 'token_type_ids' in features else None + attention_mask = torch.tensor(features['attention_mask']).reshape( + 1, 1, -1).to(self.device) if 'attention_mask' in features else None + features = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask} + + return features + + +# if __name__ == "__main__": +# from convlab.policy.vector.vector_uncertainty import VectorUncertainty +# # from convlab.policy.vector.vector_binary import VectorBinary +# tracker = SetSUMBTTracker(model_path='/gpfs/project/niekerk/src/SetSUMBT/models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-labelsmoothing-Seed0-10-08-22-12-42', +# return_confidence_scores=True, confidence_threshold='auto', +# return_belief_state_entropy=True) +# vector = VectorUncertainty(use_state_total_uncertainty=True, confidence_thresholds=tracker.confidence_thresholds, +# use_masking=True) +# # vector = VectorBinary() +# tracker.init_session() +# +# state = tracker.update('hey. I need a cheap restaurant.') +# tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.']) +# tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?']) +# state = tracker.update('If you have something Asian that would be great.') +# tracker.state['history'].append(['usr', 'If you have something Asian that would be great.']) +# tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.']) +# tracker.state['system_action'] = [['inform', 'restaurant', 'food', 'chinese'], +# ['inform', 'restaurant', 'name', 'the golden wok']] +# state = tracker.update('Great. Where are they located?') +# tracker.state['history'].append(['usr', 'Great. Where are they located?']) +# state = tracker.state +# state['terminated'] = False +# state['booked'] = {} +# +# print(state) +# print(vector.state_vectorize(state)) diff --git a/convlab/dst/setsumbt/utils.py b/convlab/dst/setsumbt/utils.py index 75a6a1fe..51839552 100644 --- a/convlab/dst/setsumbt/utils.py +++ b/convlab/dst/setsumbt/utils.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,57 +15,43 @@ # limitations under the License. """SetSUMBT utils""" -import re import os import shutil from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser -from glob import glob from datetime import datetime -from google.cloud import storage +from git import Repo -def get_args(MODELS): +def get_args(base_models: dict): # Get arguments parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) # Optional - parser.add_argument('--tensorboard_path', - help='Path to tensorboard', default='') + parser.add_argument('--tensorboard_path', help='Path to tensorboard', default='') parser.add_argument('--logging_path', help='Path for log file', default='') - parser.add_argument( - '--seed', help='Seed value for reproducability', default=0, type=int) + parser.add_argument('--seed', help='Seed value for reproducibility', default=0, type=int) # DATASET (Optional) - parser.add_argument( - '--dataset', help='Dataset Name: multiwoz21/simr', default='multiwoz21') - parser.add_argument('--shrink_active_domains', help='Shrink active domains to only well represented test set domains', - action='store_true') - parser.add_argument( - '--data_dir', help='Data storage directory', default=None) - parser.add_argument( - '--max_dialogue_len', help='Maximum number of turns per dialogue', default=12, type=int) - parser.add_argument( - '--max_turn_len', help='Maximum number of tokens per turn', default=64, type=int) - parser.add_argument( - '--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int) - parser.add_argument('--max_candidate_len', - help='Maximum number of tokens per value candidate', default=12, type=int) - parser.add_argument('--force_processing', action='store_true', - help='Force preprocessing of data.') - parser.add_argument('--data_sampling_size', - help='Resampled dataset size', default=-1, type=int) - parser.add_argument('--use_descriptions', help='Use slot descriptions rather than slot names for embeddings', + parser.add_argument('--dataset', help='Dataset Name (See Convlab 3 unified format for possible datasets', + default='multiwoz21') + parser.add_argument('--dataset_train_ratio', help='Fraction of training set to use in training', default=1.0, + type=float) + parser.add_argument('--max_dialogue_len', help='Maximum number of turns per dialogue', default=12, type=int) + parser.add_argument('--max_turn_len', help='Maximum number of tokens per turn', default=64, type=int) + parser.add_argument('--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int) + parser.add_argument('--max_candidate_len', help='Maximum number of tokens per value candidate', default=12, + type=int) + parser.add_argument('--force_processing', action='store_true', help='Force preprocessing of data.') + parser.add_argument('--data_sampling_size', help='Resampled dataset size', default=-1, type=int) + parser.add_argument('--no_descriptions', help='Do not use slot descriptions rather than slot names for embeddings', action='store_true') # MODEL # Environment - parser.add_argument( - '--output_dir', help='Output storage directory', default=None) - parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta', - default='roberta') - parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.', - default=None) + parser.add_argument('--output_dir', help='Output storage directory', default=None) + parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta', default='roberta') + parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.', default=None) parser.add_argument('--candidate_embedding_model_name', default=None, help='Name of the pretrained candidate embedding model.') @@ -74,92 +60,73 @@ def get_args(MODELS): action='store_true') parser.add_argument('--slot_attention_heads', help='Number of attention heads for slot conditioning', default=12, type=int) - parser.add_argument('--dropout_rate', help='Dropout Rate', - default=0.3, type=float) - parser.add_argument( - '--nbt_type', help='Belief Tracker type: gru/lstm', default='gru') + parser.add_argument('--dropout_rate', help='Dropout Rate', default=0.3, type=float) + parser.add_argument('--nbt_type', help='Belief Tracker type: gru/lstm', default='gru') parser.add_argument('--nbt_hidden_size', help='Hidden embedding size for the Neural Belief Tracker', default=300, type=int) - parser.add_argument( - '--nbt_layers', help='Number of RNN layers in the NBT', default=1, type=int) - parser.add_argument( - '--rnn_zero_init', help='Zero Initialise RNN hidden states', action='store_true') + parser.add_argument('--nbt_layers', help='Number of RNN layers in the NBT', default=1, type=int) + parser.add_argument('--rnn_zero_init', help='Zero Initialise RNN hidden states', action='store_true') parser.add_argument('--distance_measure', default='cosine', help='Similarity measure for candidate scoring: cosine/euclidean') - parser.add_argument( - '--ensemble_size', help='Number of models in ensemble', default=-1, type=int) - parser.add_argument('--set_similarity', action='store_true', - help='Set True to not use set similarity (Model tracks latent belief state as sequence and performs semantic similarity of sets)') - parser.add_argument('--set_pooling', help='Set pooling method for set similarity model using single embedding distances', + parser.add_argument('--ensemble_size', help='Number of models in ensemble', default=-1, type=int) + parser.add_argument('--no_set_similarity', action='store_true', help='Set True to not use set similarity') + parser.add_argument('--set_pooling', + help='Set pooling method for set similarity model using single embedding distances', default='cnn') - parser.add_argument('--candidate_pooling', help='Pooling approach for non set based candidate representations: cls/mean', + parser.add_argument('--candidate_pooling', + help='Pooling approach for non set based candidate representations: cls/mean', default='mean') - parser.add_argument('--predict_actions', help='Model predicts user actions and active domain', + parser.add_argument('--no_action_prediction', help='Model does not predicts user actions and active domain', action='store_true') # Loss - parser.add_argument('--loss_function', help='Loss Function for training: crossentropy/bayesianmatching/labelsmoothing/distillation/distribution_distillation', + parser.add_argument('--loss_function', + help='Loss Function for training: crossentropy/bayesianmatching/labelsmoothing/...', default='labelsmoothing') parser.add_argument('--kl_scaling_factor', help='Scaling factor for KL divergence in bayesian matching loss', type=float) parser.add_argument('--prior_constant', help='Constant parameter for prior in bayesian matching loss', type=float) - parser.add_argument('--ensemble_smoothing', - help='Ensemble distribution smoothing constant', type=float) - parser.add_argument('--annealing_base_temp', help='Ensemble Distribution destillation temp annealing base temp', + parser.add_argument('--ensemble_smoothing', help='Ensemble distribution smoothing constant', type=float) + parser.add_argument('--annealing_base_temp', help='Ensemble Distribution distillation temp annealing base temp', type=float) - parser.add_argument('--annealing_cycle_len', help='Ensemble Distribution destillation temp annealing cycle length', + parser.add_argument('--annealing_cycle_len', help='Ensemble Distribution distillation temp annealing cycle length', type=float) - parser.add_argument('--inhibiting_factor', - help='Inhibiting factor for Inhibited Softmax CE', type=float) - parser.add_argument('--label_smoothing', - help='Label smoothing coefficient.', type=float) - parser.add_argument( - '--user_goal_loss_weight', help='Weight of the user goal prediction loss. 0.0<weight<=1.0', type=float) - parser.add_argument( - '--user_request_loss_weight', help='Weight of the user request prediction loss. 0.0<weight<=1.0', type=float) - parser.add_argument( - '--user_general_act_loss_weight', help='Weight of the user general act prediction loss. 0.0<weight<=1.0', type=float) - parser.add_argument( - '--active_domain_loss_weight', help='Weight of the active domain prediction loss. 0.0<weight<=1.0', type=float) + parser.add_argument('--label_smoothing', help='Label smoothing coefficient.', type=float) + parser.add_argument('--user_goal_loss_weight', help='Weight of the user goal prediction loss. 0.0<weight<=1.0', + type=float) + parser.add_argument('--user_request_loss_weight', + help='Weight of the user request prediction loss. 0.0<weight<=1.0', type=float) + parser.add_argument('--user_general_act_loss_weight', + help='Weight of the user general act prediction loss. 0.0<weight<=1.0', type=float) + parser.add_argument('--active_domain_loss_weight', + help='Weight of the active domain prediction loss. 0.0<weight<=1.0', type=float) # TRAINING - parser.add_argument('--train_batch_size', - help='Training Set Batch Size', default=4, type=int) - parser.add_argument('--max_training_steps', help='Maximum number of training update steps', - default=-1, type=int) + parser.add_argument('--train_batch_size', help='Training Set Batch Size', default=8, type=int) + parser.add_argument('--max_training_steps', help='Maximum number of training update steps', default=-1, type=int) parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help='Number of batches accumulated for one update step') - parser.add_argument('--num_train_epochs', - help='Number of training epochs', default=50, type=int) + parser.add_argument('--num_train_epochs', help='Number of training epochs', default=50, type=int) parser.add_argument('--patience', help='Number of training steps without improving model before stopping.', - default=25, type=int) - parser.add_argument( - '--weight_decay', help='Weight decay rate', default=0.01, type=float) - parser.add_argument('--learning_rate', - help='Initial Learning Rate', default=5e-5, type=float) - parser.add_argument('--warmup_proportion', help='Warmup proportion for linear scheduler', - default=0.2, type=float) - parser.add_argument( - '--max_grad_norm', help='Maximum norm of the loss gradients', default=1.0, type=float) - parser.add_argument( - '--save_steps', help='Number of update steps between saving model', default=-1, type=int) - parser.add_argument( - '--keep_models', help='How many model checkpoints should be kept during training', default=1, type=int) + default=20, type=int) + parser.add_argument('--weight_decay', help='Weight decay rate', default=0.01, type=float) + parser.add_argument('--learning_rate', help='Initial Learning Rate', default=5e-5, type=float) + parser.add_argument('--warmup_proportion', help='Warmup proportion for linear scheduler', default=0.2, type=float) + parser.add_argument('--max_grad_norm', help='Maximum norm of the loss gradients', default=1.0, type=float) + parser.add_argument('--save_steps', help='Number of update steps between saving model', default=-1, type=int) + parser.add_argument('--keep_models', help='How many model checkpoints should be kept during training', + default=1, type=int) # CALIBRATION - parser.add_argument( - '--temp_scaling', help='Temperature scaling coefficient', default=1.0, type=float) + parser.add_argument('--temp_scaling', help='Temperature scaling coefficient', default=1.0, type=float) # EVALUATION - parser.add_argument('--dev_batch_size', - help='Dev Set Batch Size', default=16, type=int) - parser.add_argument('--test_batch_size', - help='Test Set Batch Size', default=16, type=int) + parser.add_argument('--dev_batch_size', help='Dev Set Batch Size', default=16, type=int) + parser.add_argument('--test_batch_size', help='Test Set Batch Size', default=16, type=int) # COMPUTING - parser.add_argument( - '--n_gpu', help='Number of GPUs to use', default=1, type=int) + parser.add_argument('--n_gpu', help='Number of GPUs to use', default=1, type=int) parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") parser.add_argument('--fp16_opt_level', type=str, default='O1', @@ -167,32 +134,29 @@ def get_args(MODELS): "See details at https://nvidia.github.io/apex/amp.html") # ACTIONS - parser.add_argument('--run_nbt', help='Run NBT script', - action='store_true') - parser.add_argument('--run_calibration', - help='Run calibration', action='store_true') + parser.add_argument('--run_nbt', help='Run NBT script', action='store_true') + parser.add_argument('--run_evaluation', help='Run evaluation script', action='store_true') # RUN_NBT ACTIONS - parser.add_argument( - '--do_train', help='Perform training', action='store_true') - parser.add_argument( - '--do_eval', help='Perform model evaluation during training', action='store_true') - parser.add_argument( - '--do_test', help='Evaulate model on test data', action='store_true') + parser.add_argument('--do_train', help='Perform training', action='store_true') + parser.add_argument('--do_eval', help='Perform model evaluation during training', action='store_true') + parser.add_argument('--do_test', help='Evaluate model on test data', action='store_true') args = parser.parse_args() - # Setup default directories - if not args.data_dir: - args.data_dir = os.path.dirname(os.path.abspath(__file__)) - args.data_dir = os.path.join(args.data_dir, 'data') - os.makedirs(args.data_dir, exist_ok=True) + # Simplify args + args.set_similarity = not args.no_set_similarity + args.use_descriptions = not args.no_descriptions + args.predict_actions = not args.no_action_prediction + # Setup default directories if not args.output_dir: args.output_dir = os.path.dirname(os.path.abspath(__file__)) args.output_dir = os.path.join(args.output_dir, 'models') - name = 'SetSUMBT' - name += '-Acts' if args.predict_actions else '' + name = 'SetSUMBT' if args.set_similarity else 'SUMBT' + name += '+ActPrediction' if args.predict_actions else '' + name += '-' + args.dataset + name += '-' + str(round(args.dataset_train_ratio*100)) + '%' if args.dataset_train_ratio != 1.0 else '' name += '-' + args.model_type name += '-' + args.nbt_type name += '-' + args.distance_measure @@ -208,9 +172,6 @@ def get_args(MODELS): args.kl_scaling_factor = 0.001 if not args.prior_constant: args.prior_constant = 1.0 - if args.loss_function == 'inhibitedce': - if not args.inhibiting_factor: - args.inhibiting_factor = 1.0 if args.loss_function == 'labelsmoothing': if not args.label_smoothing: args.label_smoothing = 0.05 @@ -233,10 +194,8 @@ def get_args(MODELS): if not args.active_domain_loss_weight: args.active_domain_loss_weight = 0.2 - args.tensorboard_path = args.tensorboard_path if args.tensorboard_path else os.path.join( - args.output_dir, 'tb_logs') - args.logging_path = args.logging_path if args.logging_path else os.path.join( - args.output_dir, 'run.log') + args.tensorboard_path = args.tensorboard_path if args.tensorboard_path else os.path.join(args.output_dir, 'tb_logs') + args.logging_path = args.logging_path if args.logging_path else os.path.join(args.output_dir, 'run.log') # Default model_name's if not args.model_name_or_path: @@ -250,30 +209,37 @@ def get_args(MODELS): if not args.candidate_embedding_model_name: args.candidate_embedding_model_name = args.model_name_or_path - if args.model_type in MODELS: - configClass = MODELS[args.model_type][-2] + if args.model_type in base_models: + config_class = base_models[args.model_type][-2] else: raise NameError('NotImplemented') - config = build_config(configClass, args) + config = build_config(config_class, args) return args, config -def build_config(configClass, args): - if args.model_type == 'fasttext': - config = configClass.from_pretrained('bert-base-uncased') - config.model_type == 'fasttext' - config.fasttext_path = args.model_name_or_path - config.vocab_size = None - elif not os.path.exists(args.model_name_or_path): - config = configClass.from_pretrained(args.model_name_or_path) +def get_git_info(): + repo = Repo(os.path.dirname(os.path.realpath(__file__)), search_parent_directories=True) + branch_name = repo.active_branch.name + commit_hex = repo.head.object.hexsha + + info = f"{branch_name}/{commit_hex}" + return info + + +def build_config(config_class, args): + config = config_class.from_pretrained(args.model_name_or_path) + config.code_version = get_git_info() + if not os.path.exists(args.model_name_or_path): config.tokenizer_name = args.model_name_or_path - elif 'tod-bert' in args.model_name_or_path.lower(): - config = configClass.from_pretrained(args.model_name_or_path) + try: + config.tokenizer_name = config.tokenizer_name + except AttributeError: config.tokenizer_name = args.model_name_or_path - else: - config = configClass.from_pretrained(args.model_name_or_path) - if args.candidate_embedding_model_name: - config.candidate_embedding_model_name = args.candidate_embedding_model_name + try: + config.candidate_embedding_model_name = config.candidate_embedding_model_name + except: + if args.candidate_embedding_model_name: + config.candidate_embedding_model_name = args.candidate_embedding_model_name config.max_dialogue_len = args.max_dialogue_len config.max_turn_len = args.max_turn_len config.max_slot_len = args.max_slot_len diff --git a/convlab/policy/mle/loader.py b/convlab/policy/mle/loader.py index ebc01a01..bb898ab4 100755 --- a/convlab/policy/mle/loader.py +++ b/convlab/policy/mle/loader.py @@ -2,6 +2,9 @@ import os import pickle import torch import torch.utils.data as data +from copy import deepcopy + +from tqdm import tqdm from convlab.policy.vector.vector_binary import VectorBinary from convlab.util import load_policy_data, load_dataset @@ -12,18 +15,20 @@ from convlab.policy.vector.dataset import ActDataset class PolicyDataVectorizer: - def __init__(self, dataset_name='multiwoz21', vector=None): + def __init__(self, dataset_name='multiwoz21', vector=None, dst=None): self.dataset_name = dataset_name if vector is None: self.vector = VectorBinary(dataset_name) else: self.vector = vector + self.dst = dst self.process_data() def process_data(self): - - processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), - f'processed_data/{self.dataset_name}_{type(self.vector).__name__}') + name = f"{self.dataset_name}_" + name += f"{type(self.dst).__name__}_" if self.dst is not None else "" + name += f"{type(self.vector).__name__}" + processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), name) if os.path.exists(processed_dir): print('Load processed data file') self._load_data(processed_dir) @@ -42,15 +47,27 @@ class PolicyDataVectorizer: self.data[split] = [] raw_data = data_split[split] - for data_point in raw_data: - state = default_state() + if self.dst is not None: + self.dst.init_session() + + for data_point in tqdm(raw_data): + if self.dst is None: + state = default_state() + + state['belief_state'] = data_point['context'][-1]['state'] + state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts']) + else: + last_system_utt = data_point['context'][-2]['utterance'] if len(data_point['context']) > 1 else '' + self.dst.state['history'].append(['sys', last_system_utt]) - state['belief_state'] = data_point['context'][-1]['state'] - state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts']) - last_system_act = data_point['context'][-2]['dialogue_acts'] \ - if len(data_point['context']) > 1 else {} + usr_utt = data_point['context'][-1]['utterance'] + state = deepcopy(self.dst.update(usr_utt)) + self.dst.state['history'].append(['usr', usr_utt]) + last_system_act = data_point['context'][-2]['dialogue_acts'] if len(data_point['context']) > 1 else {} state['system_action'] = flatten_acts(last_system_act) state['terminated'] = data_point['terminated'] + if self.dst is not None and state['terminated']: + self.dst.init_session() state['booked'] = data_point['booked'] dialogue_act = flatten_acts(data_point['dialogue_acts']) diff --git a/convlab/policy/mle/train.py b/convlab/policy/mle/train.py index 2b82a476..c2477760 100755 --- a/convlab/policy/mle/train.py +++ b/convlab/policy/mle/train.py @@ -137,15 +137,6 @@ class MLE_Trainer(MLE_Trainer_Abstract): def __init__(self, manager, vector, cfg): self._init_data(manager, cfg) - try: - self.use_entropy = manager.use_entropy - self.use_mutual_info = manager.use_mutual_info - self.use_confidence_scores = manager.use_confidence_scores - except: - self.use_entropy = False - self.use_mutual_info = False - self.use_confidence_scores = False - # override the loss defined in the MLE_Trainer_Abstract to support pos_weight pos_weight = cfg['pos_weight'] * torch.ones(vector.da_dim).to(device=DEVICE) self.multi_entropy_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight) @@ -161,6 +152,10 @@ def arg_parser(): parser.add_argument("--seed", type=int, default=0) parser.add_argument("--eval_freq", type=int, default=1) parser.add_argument("--dataset_name", type=str, default="multiwoz21") + parser.add_argument("--use_masking", action='store_true') + + parser.add_argument("--dst", type=str, default=None) + parser.add_argument("--dst_args", type=str, default=None) args = parser.parse_args() return args @@ -181,8 +176,28 @@ if __name__ == '__main__': set_seed(args.seed) logging.info(f"Seed used: {args.seed}") - vector = VectorBinary(dataset_name=args.dataset_name, use_masking=False) - manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector) + if args.dst is None: + vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking) + dst = None + elif args.dst == "setsumbt": + dst_args = [arg.split('=', 1) for arg in args.dst_args.split(', ') + if '=' in arg] if args.dst_args is not None else [] + dst_args = {key: eval(value) for key, value in dst_args} + from convlab.dst.setsumbt import SetSUMBTTracker + dst = SetSUMBTTracker(**dst_args) + if dst.return_confidence_scores: + from convlab.policy.vector.vector_uncertainty import VectorUncertainty + vector = VectorUncertainty(dataset_name=args.dataset_name, use_masking=args.use_masking, + manually_add_entity_names=False, + use_confidence_scores=dst.return_confidence_scores, + confidence_thresholds=dst.confidence_thresholds, + use_state_total_uncertainty=dst.return_belief_state_entropy, + use_state_knowledge_uncertainty=dst.return_belief_state_mutual_info) + else: + vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking) + else: + raise NameError(f"Tracker: {args.tracker} not implemented.") + manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector, dst=dst) agent = MLE_Trainer(manager, vector, cfg) logging.info('Start training') diff --git a/convlab/policy/ppo/setsumbt_end_baseline_config.json b/convlab/policy/ppo/setsumbt_config.json similarity index 53% rename from convlab/policy/ppo/setsumbt_end_baseline_config.json rename to convlab/policy/ppo/setsumbt_config.json index ea84dd76..5a13ee82 100644 --- a/convlab/policy/ppo/setsumbt_end_baseline_config.json +++ b/convlab/policy/ppo/setsumbt_config.json @@ -1,22 +1,22 @@ { "model": { - "load_path": "supervised", + "load_path": "/gpfs/project/niekerk/src/ConvLab3/convlab/policy/mle/experiments/experiment_2022-11-13-12-56-34/save/supervised", "pretrained_load_path": "", "use_pretrained_initialisation": false, "batchsz": 1000, "seed": 0, "epoch": 50, "eval_frequency": 5, - "process_num": 4, + "process_num": 2, "num_eval_dialogues": 500, - "sys_semantic_to_usr": false + "sys_semantic_to_usr": true }, "vectorizer_sys": { "uncertainty_vector_mul": { - "class_path": "convlab.policy.vector.vector_multiwoz_uncertainty.MultiWozVector", + "class_path": "convlab.policy.vector.vector_binary.VectorBinary", "ini_params": { "use_masking": false, - "manually_add_entity_names": false, + "manually_add_entity_names": true, "seed": 0 } } @@ -24,12 +24,9 @@ "nlu_sys": {}, "dst_sys": { "setsumbt-mul": { - "class_path": "convlab.dst.setsumbt.multiwoz.Tracker.SetSUMBTTracker", + "class_path": "convlab.dst.setsumbt.SetSUMBTTracker", "ini_params": { - "model_path": "https://zenodo.org/record/5497808/files/setsumbt_end.zip", - "get_confidence_scores": true, - "return_mutual_info": false, - "return_entropy": true + "model_path": "/gpfs/project/niekerk/models/setsumbt_models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-distribution_distillation-Seed0-30-08-22-15-00" } } }, @@ -41,16 +38,7 @@ } } }, - "nlu_usr": { - "BERTNLU": { - "class_path": "convlab.nlu.jointBERT.multiwoz.BERTNLU", - "ini_params": { - "mode": "sys", - "config_file": "multiwoz_sys_context.json", - "model_file": "https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip" - } - } - }, + "nlu_usr": {}, "dst_usr": {}, "policy_usr": { "RulePolicy": { @@ -65,7 +53,7 @@ "class_path": "convlab.nlg.template.multiwoz.TemplateNLG", "ini_params": { "is_user": true, - "label_noise": 0.0, + "label_noise": 0.05, "text_noise": 0.0 } } diff --git a/convlab/policy/ppo/setsumbt_unc_config.json b/convlab/policy/ppo/setsumbt_unc_config.json new file mode 100644 index 00000000..6b7d115a --- /dev/null +++ b/convlab/policy/ppo/setsumbt_unc_config.json @@ -0,0 +1,65 @@ +{ + "model": { + "load_path": "/gpfs/project/niekerk/src/ConvLab3/convlab/policy/mle/experiments/experiment_2022-11-10-10-37-30/save/supervised", + "pretrained_load_path": "", + "use_pretrained_initialisation": false, + "batchsz": 1000, + "seed": 0, + "epoch": 50, + "eval_frequency": 5, + "process_num": 2, + "num_eval_dialogues": 500, + "sys_semantic_to_usr": true + }, + "vectorizer_sys": { + "uncertainty_vector_mul": { + "class_path": "convlab.policy.vector.vector_uncertainty.VectorUncertainty", + "ini_params": { + "use_masking": false, + "manually_add_entity_names": true, + "seed": 0, + "use_confidence_scores": true, + "use_state_knowledge_uncertainty": true + } + } + }, + "nlu_sys": {}, + "dst_sys": { + "setsumbt-mul": { + "class_path": "convlab.dst.setsumbt.SetSUMBTTracker", + "ini_params": { + "model_path": "/gpfs/project/niekerk/models/setsumbt_models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-distribution_distillation-Seed0-30-08-22-15-00", + "return_confidence_scores": true, + "return_belief_state_mutual_info": true + } + } + }, + "sys_nlg": { + "TemplateNLG": { + "class_path": "convlab.nlg.template.multiwoz.TemplateNLG", + "ini_params": { + "is_user": false + } + } + }, + "nlu_usr": {}, + "dst_usr": {}, + "policy_usr": { + "RulePolicy": { + "class_path": "convlab.policy.rule.multiwoz.RulePolicy", + "ini_params": { + "character": "usr" + } + } + }, + "usr_nlg": { + "TemplateNLG": { + "class_path": "convlab.nlg.template.multiwoz.TemplateNLG", + "ini_params": { + "is_user": true, + "label_noise": 0.05, + "text_noise": 0.0 + } + } + } +} \ No newline at end of file diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py index 50b06aab..45681169 100755 --- a/convlab/policy/ppo/train.py +++ b/convlab/policy/ppo/train.py @@ -199,7 +199,7 @@ if __name__ == '__main__': logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \ init_logging(os.path.dirname(os.path.abspath(__file__)), mode) - args = [('model', 'seed', seed)] if seed is not None else list() + args = [('model', 'seed', seed)] if seed else list() environment_config = load_config_file(path) save_config(vars(parser.parse_args()), environment_config, config_save_path) @@ -228,14 +228,6 @@ if __name__ == '__main__': env, sess = env_config(conf, policy_sys) - # Setup uncertainty thresholding - if env.sys_dst: - try: - if env.sys_dst.use_confidence_scores: - policy_sys.vector.setup_uncertain_query(env.sys_dst.thresholds) - except: - logging.info('Uncertainty threshold not set.') - policy_sys.current_time = current_time policy_sys.log_dir = config_save_path.replace('configs', 'logs') policy_sys.save_dir = save_path @@ -261,7 +253,7 @@ if __name__ == '__main__': if idx % conf['model']['eval_frequency'] == 0 and idx != 0: time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) - logging.info(f"Evaluating after Dialogues: {idx * conf['model']['batchsz']} - {time_now}" + '-' * 60) + logging.info(f"Evaluating at Epoch: {idx} - {time_now}" + '-'*60) eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path) diff --git a/convlab/policy/vector/dataset.py b/convlab/policy/vector/dataset.py index b1481854..0aa1b7ad 100755 --- a/convlab/policy/vector/dataset.py +++ b/convlab/policy/vector/dataset.py @@ -18,26 +18,6 @@ class ActDataset(data.Dataset): return self.num_total -class ActDatasetKG(data.Dataset): - def __init__(self, action_batch, a_masks, current_domain_mask_batch, non_current_domain_mask_batch): - self.action_batch = action_batch - self.a_masks = a_masks - self.current_domain_mask_batch = current_domain_mask_batch - self.non_current_domain_mask_batch = non_current_domain_mask_batch - self.num_total = len(action_batch) - - def __getitem__(self, index): - action = self.action_batch[index] - action_mask = self.a_masks[index] - current_domain_mask = self.current_domain_mask_batch[index] - non_current_domain_mask = self.non_current_domain_mask_batch[index] - - return action, action_mask, current_domain_mask, non_current_domain_mask, index - - def __len__(self): - return self.num_total - - class ActStateDataset(data.Dataset): def __init__(self, s_s, a_s, next_s): self.s_s = s_s diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 89f22203..8b7d8ff0 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -2,11 +2,10 @@ import os import sys import numpy as np -import logging from copy import deepcopy from convlab.policy.vec import Vector -from convlab.util.custom_util import flatten_acts, timeout +from convlab.util.custom_util import flatten_acts from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da from convlab.util import load_ontology, load_database, load_dataset @@ -23,20 +22,18 @@ class VectorBase(Vector): super().__init__() - logging.info(f"Vectorizer: Data set used is {dataset_name}") self.set_seed(seed) self.ontology = load_ontology(dataset_name) try: # execute to make sure that the database exists or is downloaded otherwise - if dataset_name == "multiwoz21": - load_database(dataset_name) + load_database(dataset_name) # the following two lines are needed for pickling correctly during multi-processing exec(f'from data.unified_datasets.{dataset_name}.database import Database') self.db = eval('Database()') self.db_domains = self.db.domains except Exception as e: self.db = None - self.db_domains = [] + self.db_domains = None print(f"VectorBase: {e}") self.dataset_name = dataset_name @@ -275,10 +272,6 @@ class VectorBase(Vector): 2. If there is an entity available, can not say NoOffer or NoBook ''' mask_list = np.zeros(self.da_dim) - - if number_entities_dict is None: - return mask_list - for i in range(self.da_dim): action = self.vec2act[i] domain, intent, slot, value = action.split('-') diff --git a/convlab/policy/vector/vector_binary.py b/convlab/policy/vector/vector_binary.py index 3671178b..e780dc64 100755 --- a/convlab/policy/vector/vector_binary.py +++ b/convlab/policy/vector/vector_binary.py @@ -8,7 +8,7 @@ from .vector_base import VectorBase class VectorBinary(VectorBase): def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True, - seed=0): + seed=0, **kwargs): super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed) diff --git a/convlab/policy/vector/vector_multiwoz_uncertainty.py b/convlab/policy/vector/vector_multiwoz_uncertainty.py deleted file mode 100644 index 6a0850f4..00000000 --- a/convlab/policy/vector/vector_multiwoz_uncertainty.py +++ /dev/null @@ -1,238 +0,0 @@ -# -*- coding: utf-8 -*- -import sys -import os -import numpy as np -import logging -from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da -from convlab.util.multiwoz.state import default_state -from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA -from .vector_binary import VectorBinary as VectorBase - -DEFAULT_INTENT_FILEPATH = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname( - os.path.dirname(os.path.abspath(__file__))))), - 'data/multiwoz/trackable_intent.json' -) - - -SLOT_MAP = {'taxi_types': 'car type'} - - -class MultiWozVector(VectorBase): - - def __init__(self, voc_file=None, voc_opp_file=None, character='sys', - intent_file=DEFAULT_INTENT_FILEPATH, - use_confidence_scores=False, - use_entropy=False, - use_mutual_info=False, - use_masking=False, - manually_add_entity_names=False, - seed=0, - shrink=False): - - self.use_confidence_scores = use_confidence_scores - self.use_entropy = use_entropy - self.use_mutual_info = use_mutual_info - self.thresholds = None - - super().__init__(voc_file, voc_opp_file, character, intent_file, use_masking, manually_add_entity_names, seed) - - def get_state_dim(self): - self.belief_state_dim = 0 - for domain in self.belief_domains: - for slot in default_state()['belief_state'][domain.lower()]['semi']: - # Dim 1 - indicator/confidence score - # Dim 2 - Entropy (Total uncertainty) / Mutual information (knowledge unc) - slot_dim = 1 if not self.use_entropy else 2 - slot_dim += 1 if self.use_mutual_info else 0 - self.belief_state_dim += slot_dim - - self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \ - len(self.db_domains) + 6 * len(self.db_domains) + 1 - - def dbquery_domain(self, domain): - """ - query entities of specified domain - Args: - domain string: - domain to query - Returns: - entities list: - list of entities of the specified domain - """ - # Get all user constraints - constraint = self.state[domain.lower()]['semi'] - constraint = {k: i for k, i in constraint.items() if i and i not in ['dontcare', "do n't care", "do not care"]} - - # Remove constraints for which the uncertainty is high - if self.confidence_scores is not None and self.use_confidence_scores and self.thresholds != None: - # Collect threshold values for each domain-slot pair - thres = self.thresholds.get(domain.lower(), {}) - thres = {k: thres.get(k, 0.05) for k in constraint} - # Get confidence scores for each constraint - probs = self.confidence_scores.get(domain.lower(), {}) - probs = {k: probs.get(k, {}).get('inform', 1.0) - for k in constraint} - - # Filter out constraints for which confidence is lower than threshold - constraint = {k: i for k, i in constraint.items() - if probs[k] >= thres[k]} - - return self.db.query(domain.lower(), constraint.items()) - - # Add thresholds for db_queries - def setup_uncertain_query(self, thresholds): - self.use_confidence_scores = True - self.thresholds = thresholds - logging.info('DB Search uncertainty activated.') - - def vectorize_user_act_confidence_scores(self, state, opp_action): - """Return confidence scores for the user actions""" - opp_act_vec = np.zeros(self.da_opp_dim) - for da in self.opp2vec: - domain, intent, slot, value = da.split('-') - if domain.lower() in state['belief_state_probs']: - # Map slot name to match user actions - slot = REF_SYS_DA[domain].get( - slot, slot) if domain in REF_SYS_DA else slot - slot = slot if slot else 'none' - slot = SLOT_MAP.get(slot, slot) - domain = domain.lower() - - if slot in state['belief_state_probs'][domain]: - prob = state['belief_state_probs'][domain][slot] - elif slot.lower() in state['belief_state_probs'][domain]: - prob = state['belief_state_probs'][domain][slot.lower()] - else: - prob = {} - - intent = intent.lower() - if intent in prob: - prob = float(prob[intent]) - elif da in opp_action: - prob = 1.0 - else: - prob = 0.0 - elif da in opp_action: - prob = 1.0 - else: - prob = 0.0 - opp_act_vec[self.opp2vec[da]] = prob - - return opp_act_vec - - def state_vectorize(self, state): - """vectorize a state - - Args: - state (dict): - Dialog state - action (tuple): - Dialog act - Returns: - state_vec (np.array): - Dialog state vector - """ - self.state = state['belief_state'] - self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None - domain_active_dict = {} - for domain in self.belief_domains: - domain_active_dict[domain] = False - - # when character is sys, to help query database when da is booking-book - # update current domain according to user action - if self.character == 'sys': - action = state['user_action'] - for intent, domain, slot, value in action: - domain_active_dict[domain] = True - - action = state['user_action'] if self.character == 'sys' else state['system_action'] - opp_action = delexicalize_da(action, self.requestable) - opp_action = flat_da(opp_action) - if 'belief_state_probs' in state and self.use_confidence_scores: - opp_act_vec = self.vectorize_user_act_confidence_scores( - state, opp_action) - else: - opp_act_vec = np.zeros(self.da_opp_dim) - for da in opp_action: - if da in self.opp2vec: - prob = 1.0 - opp_act_vec[self.opp2vec[da]] = prob - - action = state['system_action'] if self.character == 'sys' else state['user_action'] - action = delexicalize_da(action, self.requestable) - action = flat_da(action) - last_act_vec = np.zeros(self.da_dim) - for da in action: - if da in self.act2vec: - last_act_vec[self.act2vec[da]] = 1. - - belief_state = np.zeros(self.belief_state_dim) - i = 0 - for domain in self.belief_domains: - if self.use_confidence_scores and 'belief_state_probs' in state: - for slot in state['belief_state'][domain.lower()]['semi']: - if slot in state['belief_state_probs'][domain.lower()]: - prob = state['belief_state_probs'][domain.lower() - ][slot] - prob = prob['inform'] if 'inform' in prob else None - if prob: - belief_state[i] = float(prob) - i += 1 - else: - for slot, value in state['belief_state'][domain.lower()]['semi'].items(): - if value and value != 'not mentioned': - belief_state[i] = 1. - i += 1 - if 'active_domains' in state: - domain_active = state['active_domains'][domain.lower()] - domain_active_dict[domain] = domain_active - else: - if [slot for slot, value in state['belief_state'][domain.lower()]['semi'].items() if value]: - domain_active_dict[domain] = True - - # Add knowledge and/or total uncertainty to the belief state - if self.use_entropy and 'entropy' in state: - for domain in self.belief_domains: - for slot in state['belief_state'][domain.lower()]['semi']: - if slot in state['entropy'][domain.lower()]: - belief_state[i] = float( - state['entropy'][domain.lower()][slot]) - i += 1 - - if self.use_mutual_info and 'mutual_information' in state: - for domain in self.belief_domains: - for slot in state['belief_state'][domain.lower()]['semi']: - if slot in state['mutual_information'][domain.lower()]: - belief_state[i] = float( - state['mutual_information'][domain.lower()][slot]) - i += 1 - - book = np.zeros(len(self.db_domains)) - for i, domain in enumerate(self.db_domains): - if state['belief_state'][domain.lower()]['book']['booked']: - book[i] = 1. - - degree, number_entities_dict = self.pointer() - - final = 1. if state['terminated'] else 0. - - state_vec = np.r_[opp_act_vec, last_act_vec, - belief_state, book, degree, final] - assert len(state_vec) == self.state_dim - - if self.use_mask is not None: - # None covers the case for policies that don't use masking at all, so do not expect an output "state_vec, mask" - if self.use_mask: - domain_mask = self.compute_domain_mask(domain_active_dict) - entity_mask = self.compute_entity_mask(number_entities_dict) - general_mask = self.compute_general_mask() - mask = domain_mask + entity_mask + general_mask - for i in range(self.da_dim): - mask[i] = -int(bool(mask[i])) * sys.maxsize - else: - mask = np.zeros(self.da_dim) - - return state_vec, mask - else: - return state_vec diff --git a/convlab/policy/vector/vector_nodes.py b/convlab/policy/vector/vector_nodes.py index 2e073669..c2f6258f 100644 --- a/convlab/policy/vector/vector_nodes.py +++ b/convlab/policy/vector/vector_nodes.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- import sys import numpy as np -import logging - from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da from .vector_base import VectorBase @@ -10,11 +8,9 @@ from .vector_base import VectorBase class VectorNodes(VectorBase): def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True, - seed=0, filter_state=True): + seed=0): super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed) - self.filter_state = filter_state - logging.info(f"We filter state by active domains: {self.filter_state}") def get_state_dim(self): self.belief_state_dim = 0 @@ -60,16 +56,9 @@ class VectorNodes(VectorBase): self.get_user_act_feature(state) self.get_sys_act_feature(state) domain_active_dict = self.get_user_goal_feature(state, domain_active_dict) + number_entities_dict = self.get_db_features() self.get_general_features(state, domain_active_dict) - if self.db is not None: - number_entities_dict = self.get_db_features() - else: - number_entities_dict = None - - if self.filter_state: - self.kg_info = self.filter_inactive_domains(domain_active_dict) - if self.use_mask: mask = self.get_mask(domain_active_dict, number_entities_dict) for i in range(self.da_dim): @@ -100,15 +89,13 @@ class VectorNodes(VectorBase): feature_type = 'user goal' for domain in self.belief_domains: - # the if case is needed because SGD only saves the dialogue state info for active domains - if domain in state['belief_state']: - for slot, value in state['belief_state'][domain].items(): - description = f"user goal-{domain}-{slot}".lower() - value = 1.0 if (value and value != "not mentioned") else 0.0 - self.add_graph_node(domain, feature_type, description, value) - - if [slot for slot, value in state['belief_state'][domain].items() if value]: - domain_active_dict[domain] = True + for slot, value in state['belief_state'][domain].items(): + description = f"user goal-{domain}-{slot}".lower() + value = 1.0 if (value and value != "not mentioned") else 0.0 + self.add_graph_node(domain, feature_type, description, value) + + if [slot for slot, value in state['belief_state'][domain].items() if value]: + domain_active_dict[domain] = True return domain_active_dict def get_sys_act_feature(self, state): @@ -141,12 +128,11 @@ class VectorNodes(VectorBase): def get_general_features(self, state, domain_active_dict): feature_type = 'general' - if 'booked' in state: - for i, domain in enumerate(self.db_domains): - if domain in state['booked']: - description = f"general-{domain}-booked".lower() - value = 1.0 if state['booked'][domain] else 0.0 - self.add_graph_node(domain, feature_type, description, value) + for i, domain in enumerate(self.db_domains): + if domain in state['booked']: + description = f"general-{domain}-booked".lower() + value = 1.0 if state['booked'][domain] else 0.0 + self.add_graph_node(domain, feature_type, description, value) for domain in self.domains: if domain == 'general': @@ -154,17 +140,3 @@ class VectorNodes(VectorBase): value = 1.0 if domain_active_dict[domain] else 0 description = f"general-{domain}".lower() self.add_graph_node(domain, feature_type, description, value) - - def filter_inactive_domains(self, domain_active_dict): - - kg_filtered = [] - for node in self.kg_info: - domain = node['domain'] - if domain in domain_active_dict: - if domain_active_dict[domain]: - kg_filtered.append(node) - else: - kg_filtered.append(node) - - return kg_filtered - diff --git a/convlab/policy/vector/vector_uncertainty.py b/convlab/policy/vector/vector_uncertainty.py new file mode 100644 index 00000000..7da05449 --- /dev/null +++ b/convlab/policy/vector/vector_uncertainty.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +import sys +import numpy as np +import logging + +from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da +from convlab.policy.vector.vector_binary import VectorBinary + + +class VectorUncertainty(VectorBinary): + """Vectorise state and state uncertainty predictions""" + + def __init__(self, + dataset_name: str = 'multiwoz21', + character: str = 'sys', + use_masking: bool = False, + manually_add_entity_names: bool = True, + seed: str = 0, + use_confidence_scores: bool = True, + confidence_thresholds: dict = None, + use_state_total_uncertainty: bool = False, + use_state_knowledge_uncertainty: bool = False): + """ + Args: + dataset_name: Name of environment dataset + character: Character of the agent (sys/usr) + use_masking: If true certain actions are masked during devectorisation + manually_add_entity_names: If true inform entity name actions are manually added + seed: Seed + use_confidence_scores: If true confidence scores are used in state vectorisation + confidence_thresholds: If true confidence thresholds are used in database querying + use_state_total_uncertainty: If true state entropy is added to the state vector + use_state_knowledge_uncertainty: If true state mutual information is added to the state vector + """ + + self.use_confidence_scores = use_confidence_scores + self.use_state_total_uncertainty = use_state_total_uncertainty + self.use_state_knowledge_uncertainty = use_state_knowledge_uncertainty + if confidence_thresholds is not None: + self.setup_uncertain_query(confidence_thresholds) + + super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed) + + def get_state_dim(self): + self.belief_state_dim = 0 + + for domain in self.ontology['state']: + for slot in self.ontology['state'][domain]: + # Dim 1 - indicator/confidence score + # Dim 2 - Entropy (Total uncertainty) / Mutual information (knowledge unc) + slot_dim = 1 if not self.use_state_total_uncertainty else 2 + slot_dim += 1 if self.use_state_knowledge_uncertainty else 0 + self.belief_state_dim += slot_dim + + self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \ + len(self.db_domains) + 6 * len(self.db_domains) + 1 + + # Add thresholds for db_queries + def setup_uncertain_query(self, confidence_thresholds): + self.use_confidence_scores = True + self.confidence_thresholds = confidence_thresholds + logging.info('DB Search uncertainty activated.') + + def dbquery_domain(self, domain): + """ + query entities of specified domain + Args: + domain string: + domain to query + Returns: + entities list: + list of entities of the specified domain + """ + # Get all user constraints + constraints = {slot: value for slot, value in self.state[domain].items() + if slot and value not in ['dontcare', + "do n't care", "do not care"]} if domain in self.state else dict() + + # Remove constraints for which the uncertainty is high + if self.confidence_scores is not None and self.use_confidence_scores and self.confidence_thresholds is not None: + # Collect threshold values for each domain-slot pair + threshold = self.confidence_thresholds.get(domain, dict()) + threshold = {slot: threshold.get(slot, 0.05) for slot in constraints} + # Get confidence scores for each constraint + probs = self.confidence_scores.get(domain, dict()) + probs = {slot: probs.get(slot, {}).get('inform', 1.0) for slot in constraints} + + # Filter out constraints for which confidence is lower than threshold + constraints = {slot: value for slot, value in constraints.items() if probs[slot] >= threshold[slot]} + + return self.db.query(domain, constraints.items(), topk=10) + + def vectorize_user_act(self, state): + """Return confidence scores for the user actions""" + self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None + action = state['user_action'] if self.character == 'sys' else state['system_action'] + opp_action = delexicalize_da(action, self.requestable) + opp_action = flat_da(opp_action) + opp_act_vec = np.zeros(self.da_opp_dim) + for da in opp_action: + if da in self.opp2vec: + if 'belief_state_probs' in state and self.use_confidence_scores: + domain, intent, slot, value = da.split('-') + if domain in state['belief_state_probs']: + slot = slot if slot else 'none' + if slot in state['belief_state_probs'][domain]: + prob = state['belief_state_probs'][domain][slot] + elif slot.lower() in state['belief_state_probs'][domain]: + prob = state['belief_state_probs'][domain][slot.lower()] + else: + prob = dict() + + if intent in prob: + prob = float(prob[intent]) + else: + prob = 1.0 + else: + prob = 1.0 + else: + prob = 1.0 + opp_act_vec[self.opp2vec[da]] = prob + + return opp_act_vec + + def vectorize_belief_state(self, state, domain_active_dict): + belief_state = np.zeros(self.belief_state_dim) + i = 0 + for domain in self.belief_domains: + if self.use_confidence_scores and 'belief_state_probs' in state: + for slot in state['belief_state'][domain]: + prob = None + if slot in state['belief_state_probs'][domain]: + prob = state['belief_state_probs'][domain][slot] + prob = prob['inform'] if 'inform' in prob else None + if prob: + belief_state[i] = float(prob) + i += 1 + else: + for slot, value in state['belief_state'][domain].items(): + if value and value != 'not mentioned': + belief_state[i] = 1. + i += 1 + + if 'active_domains' in state: + domain_active = state['active_domains'][domain] + domain_active_dict[domain] = domain_active + else: + if [slot for slot, value in state['belief_state'][domain].items() if value]: + domain_active_dict[domain] = True + + # Add knowledge and/or total uncertainty to the belief state + if self.use_state_total_uncertainty and 'entropy' in state: + for domain in self.belief_domains: + for slot in state['belief_state'][domain]: + if slot in state['entropy'][domain]: + belief_state[i] = float(state['entropy'][domain][slot]) + i += 1 + + if self.use_state_knowledge_uncertainty and 'mutual_information' in state: + for domain in self.belief_domains: + for slot in state['belief_state'][domain]: + if slot in state['mutual_information'][domain]: + belief_state[i] = float(state['mutual_information'][domain][slot]) + i += 1 + + return belief_state, domain_active_dict diff --git a/convlab/util/custom_util.py b/convlab/util/custom_util.py index aad6c4cd..c79c6f0d 100644 --- a/convlab/util/custom_util.py +++ b/convlab/util/custom_util.py @@ -21,7 +21,6 @@ from convlab.evaluator.multiwoz_eval import MultiWozEvaluator from convlab.util import load_dataset import shutil -import signal slot_mapping = {"pricerange": "price range", "post": "postcode", "arriveBy": "arrive by", "leaveAt": "leave at", @@ -35,22 +34,6 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = DEVICE -class timeout: - def __init__(self, seconds=10, error_message='Timeout'): - self.seconds = seconds - self.error_message = error_message - - def handle_timeout(self, signum, frame): - raise TimeoutError(self.error_message) - - def __enter__(self): - signal.signal(signal.SIGALRM, self.handle_timeout) - signal.alarm(self.seconds) - - def __exit__(self, type, value, traceback): - signal.alarm(0) - - class NumpyEncoder(json.JSONEncoder): """ Special json encoder for numpy types """ @@ -171,20 +154,20 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do if conf['model']['process_num'] == 1: complete_rate, success_rate, success_rate_strict, avg_return, turns, \ avg_actions, task_success, book_acts, inform_acts, request_acts, \ - select_acts, offer_acts, recommend_acts = evaluate(sess, + select_acts, offer_acts = evaluate(sess, num_dialogues=conf['model']['num_eval_dialogues'], sys_semantic_to_usr=conf['model'][ 'sys_semantic_to_usr'], save_flag=save_eval, save_path=log_save_path, goals=goals) - total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts + total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts else: complete_rate, success_rate, success_rate_strict, avg_return, turns, \ avg_actions, task_success, book_acts, inform_acts, request_acts, \ - select_acts, offer_acts, recommend_acts = \ + select_acts, offer_acts = \ evaluate_distributed(sess, list(range(1000, 1000 + conf['model']['num_eval_dialogues'])), conf['model']['process_num'], goals) - total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts + total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts task_success_gathered = {} for task_dict in task_success: @@ -195,40 +178,22 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do task_success = task_success_gathered policy_sys.is_train = True - - mean_complete, err_complete = np.average(complete_rate), np.std(complete_rate) / np.sqrt(len(complete_rate)) - mean_success, err_success = np.average(success_rate), np.std(success_rate) / np.sqrt(len(success_rate)) - mean_success_strict, err_success_strict = np.average(success_rate_strict), np.std(success_rate_strict) / np.sqrt(len(success_rate_strict)) - mean_return, err_return = np.average(avg_return), np.std(avg_return) / np.sqrt(len(avg_return)) - mean_turns, err_turns = np.average(turns), np.std(turns) / np.sqrt(len(turns)) - mean_actions, err_actions = np.average(avg_actions), np.std(avg_actions) / np.sqrt(len(avg_actions)) - - logging.info(f"Complete: {mean_complete}+-{round(err_complete, 2)}, " - f"Success: {mean_success}+-{round(err_success, 2)}, " - f"Success strict: {mean_success_strict}+-{round(err_success_strict, 2)}, " - f"Average Return: {mean_return}+-{round(err_return, 2)}, " - f"Turns: {mean_turns}+-{round(err_turns, 2)}, " - f"Average Actions: {mean_actions}+-{round(err_actions, 2)}, " + logging.info(f"Complete: {complete_rate}, Success: {success_rate}, Success strict: {success_rate_strict}, " + f"Average Return: {avg_return}, Turns: {turns}, Average Actions: {avg_actions}, " f"Book Actions: {book_acts/total_acts}, Inform Actions: {inform_acts/total_acts}, " f"Request Actions: {request_acts/total_acts}, Select Actions: {select_acts/total_acts}, " - f"Offer Actions: {offer_acts/total_acts}, Recommend Actions: {recommend_acts/total_acts}") + f"Offer Actions: {offer_acts/total_acts}") for key in task_success: logging.info( f"{key}: Num: {len(task_success[key])} Success: {np.average(task_success[key]) if len(task_success[key]) > 0 else 0}") - return {"complete_rate": mean_complete, - "success_rate": mean_success, - "success_rate_strict": mean_success_strict, - "avg_return": mean_return, - "turns": mean_turns, - "avg_actions": mean_actions, - "book_acts": book_acts/total_acts, - "inform_acts": inform_acts/total_acts, - "request_acts": request_acts/total_acts, - "select_acts": select_acts/total_acts, - "offer_acts": offer_acts/total_acts, - "recommend_acts": recommend_acts/total_acts} + return {"complete_rate": complete_rate, + "success_rate": success_rate, + "success_rate_strict": success_rate_strict, + "avg_return": avg_return, + "turns": turns, + "avg_actions": avg_actions} def env_config(conf, policy_sys, check_book_constraints=True): @@ -240,6 +205,14 @@ def env_config(conf, policy_sys, check_book_constraints=True): policy_usr = conf['policy_usr_activated'] usr_nlg = conf['usr_nlg_activated'] + # Setup uncertainty thresholding + if dst_sys: + try: + if dst_sys.return_confidence_scores: + policy_sys.vector.setup_uncertain_query(dst_sys.confidence_thresholds) + except: + logging.info('Uncertainty threshold not set.') + simulator = PipelineAgent(nlu_usr, dst_usr, policy_usr, usr_nlg, 'user') system_pipeline = PipelineAgent(nlu_sys, dst_sys, policy_sys, sys_nlg, 'sys', return_semantic_acts=conf['model']['sys_semantic_to_usr']) @@ -321,7 +294,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False task_success = {'All_user_sim': [], 'All_evaluator': [], "All_evaluator_strict": [], 'total_return': [], 'turns': [], 'avg_actions': [], 'total_booking_acts': [], 'total_inform_acts': [], 'total_request_acts': [], - 'total_select_acts': [], 'total_offer_acts': [], 'total_recommend_acts': []} + 'total_select_acts': [], 'total_offer_acts': []} dial_count = 0 for seed in range(1000, 1000 + num_dialogues): set_seed(seed) @@ -337,7 +310,6 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False request = 0 select = 0 offer = 0 - recommend = 0 # this 40 represents the max turn of dialogue for i in range(40): sys_response, user_response, session_over, reward = sess.next_turn( @@ -360,8 +332,6 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False select += 1 if intent.lower() == 'offerbook': offer += 1 - if intent.lower() == 'recommend': - recommend += 1 avg_actions += len(acts) turn_counter += 1 turns += 1 @@ -398,8 +368,6 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False task_success['total_request_acts'].append(request) task_success['total_select_acts'].append(select) task_success['total_offer_acts'].append(offer) - task_success['total_offer_acts'].append(offer) - task_success['total_recommend_acts'].append(recommend) # print(agent_sys.agent_saves) eval_save['Conversation {}'.format(str(dial_count))] = [ @@ -415,11 +383,12 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False save_file.close() # save dialogue_info and clear mem - return task_success['All_user_sim'], task_success['All_evaluator'], task_success['All_evaluator_strict'], \ - task_success['total_return'], task_success['turns'], task_success['avg_actions'], task_success, \ + return np.average(task_success['All_user_sim']), np.average(task_success['All_evaluator']), \ + np.average(task_success['All_evaluator_strict']), np.average(task_success['total_return']), \ + np.average(task_success['turns']), np.average(task_success['avg_actions']), task_success, \ np.average(task_success['total_booking_acts']), np.average(task_success['total_inform_acts']), \ np.average(task_success['total_request_acts']), np.average(task_success['total_select_acts']), \ - np.average(task_success['total_offer_acts']), np.average(task_success['total_recommend_acts']) + np.average(task_success['total_offer_acts']) def model_downloader(download_dir, model_path): @@ -570,18 +539,21 @@ def get_config(filepath, args) -> dict: vec_name = [model for model in conf['vectorizer_sys']] vec_name = vec_name[0] if vec_name else None if dst_name and 'setsumbt' in dst_name.lower(): - if 'get_confidence_scores' in conf['dst_sys'][dst_name]['ini_params']: - conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = conf['dst_sys'][dst_name]['ini_params']['get_confidence_scores'] + if 'return_confidence_scores' in conf['dst_sys'][dst_name]['ini_params']: + param = conf['dst_sys'][dst_name]['ini_params']['return_confidence_scores'] + conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = param else: conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = False - if 'return_mutual_info' in conf['dst_sys'][dst_name]['ini_params']: - conf['vectorizer_sys'][vec_name]['ini_params']['use_mutual_info'] = conf['dst_sys'][dst_name]['ini_params']['return_mutual_info'] + if 'return_belief_state_mutual_info' in conf['dst_sys'][dst_name]['ini_params']: + param = conf['dst_sys'][dst_name]['ini_params']['return_belief_state_mutual_info'] + conf['vectorizer_sys'][vec_name]['ini_params']['use_state_knowledge_uncertainty'] = param else: - conf['vectorizer_sys'][vec_name]['ini_params']['use_mutual_info'] = False - if 'return_entropy' in conf['dst_sys'][dst_name]['ini_params']: - conf['vectorizer_sys'][vec_name]['ini_params']['use_entropy'] = conf['dst_sys'][dst_name]['ini_params']['return_entropy'] + conf['vectorizer_sys'][vec_name]['ini_params']['use_state_knowledge_uncertainty'] = False + if 'return_belief_state_entropy' in conf['dst_sys'][dst_name]['ini_params']: + param = conf['dst_sys'][dst_name]['ini_params']['return_belief_state_entropy'] + conf['vectorizer_sys'][vec_name]['ini_params']['use_state_total_uncertainty'] = param else: - conf['vectorizer_sys'][vec_name]['ini_params']['use_entropy'] = False + conf['vectorizer_sys'][vec_name]['ini_params']['use_state_total_uncertainty'] = False from convlab.nlu import NLU from convlab.dst import DST @@ -610,8 +582,7 @@ def get_config(filepath, args) -> dict: cls_path = infos.get('class_path', '') cls = map_class(cls_path) conf[unit + '_class'] = cls - conf[unit + '_activated'] = conf[unit + - '_class'](**conf[unit][model]['ini_params']) + conf[unit + '_activated'] = conf[unit + '_class'](**conf[unit][model]['ini_params']) print("Loaded " + model + " for " + unit) return conf -- GitLab