diff --git a/.gitignore b/.gitignore index 136bbce75fa69edcf426d6a0ef9aaf0e85e70662..c879a524e95127b0e91aeb95734c4bbfade592fc 100644 --- a/.gitignore +++ b/.gitignore @@ -102,8 +102,6 @@ convlab/dst/trade/multiwoz_config/ convlab/deploy/bert_multiwoz_all.zip convlab/deploy/templates/dialog_eg.html test.py -convlab/dst/setsumbt/models/* - *budget*.pdf *convlab/policy/vector/action_dicts diff --git a/convlab/__init__.py b/convlab/__init__.py index 4b14a707ee43021fc97936c5f64d95c2cb7aa659..43cc4bc75aab6c4576a4097a549079e521d20e20 100755 --- a/convlab/__init__.py +++ b/convlab/__init__.py @@ -1,3 +1,4 @@ + import os from convlab.nlu import NLU from convlab.dst import DST diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py index 6216eaaac9fb81615f903f048d6d85766ce663c5..508d06b56080f72146d65a98286a0fd099f9b4c2 100755 --- a/convlab/dialog_agent/env.py +++ b/convlab/dialog_agent/env.py @@ -50,10 +50,10 @@ class Environment(): observation) if self.sys_nlu else observation self.sys_dst.state['user_action'] = dialog_act state = self.sys_dst.update(dialog_act) - state = deepcopy(state) + self.sys_dst.state['history'].append(["sys", model_response]) + self.sys_dst.state['history'].append(["usr", observation]) - state['history'].append(["sys", model_response]) - state['history'].append(["usr", observation]) + state = deepcopy(state) terminated = self.usr.is_terminated() diff --git a/convlab/dst/evaluate_unified_datasets.py b/convlab/dst/evaluate_unified_datasets.py index 4b4accf1bb78b1580c8235bebde1160ed771bd59..d4e0720dc02bf90efcfebb1780630211f0722f7f 100644 --- a/convlab/dst/evaluate_unified_datasets.py +++ b/convlab/dst/evaluate_unified_datasets.py @@ -1,73 +1,42 @@ import json from pprint import pprint -import numpy as np - def evaluate(predict_result): predict_result = json.load(open(predict_result)) - metrics = {'TP': 0, 'FP': 0, 'FN': 0} - jga = [] - aga = [] - fga = [] - l2_err = [] - lamb = [0.25, 0.5, 0.75, 1.0] + metrics = {'TP':0, 'FP':0, 'FN':0} + acc = [] for sample in predict_result: pred_state = sample['predictions']['state'] gold_state = sample['state'] - utt_idx = sample['utt_idx'] - - predicts = {(domain, slot, ''.join(value.split()).lower()) for domain in pred_state - for slot, value in pred_state[domain].items() if value} - labels = {(domain, slot, ''.join(value.split()).lower()) for domain in gold_state - for slot, value in gold_state[domain].items() if value} - predicts, labels = sorted(list(predicts)), sorted(list(labels)) - - # Flexible goal accuracy (see https://arxiv.org/pdf/2204.03375.pdf) - weighted_err = [1] * len(lamb) - if utt_idx == 0: - err_idx = -999999 - predicts_prev = [] - labels_prev = [] - - if predicts != labels: - err_idx = utt_idx - weighted_err = [0] * len(lamb) - else: - if predicts != labels: - predicts_changes = [ele for ele in predicts if ele not in predicts_prev] - labels_changes = [ele for ele in labels if ele not in labels_prev] - - new_predict_err = [ele for ele in predicts_changes if ele not in labels] - new_predict_miss = [ele for ele in labels_changes if ele not in predicts] - - if new_predict_err or new_predict_miss: - weighted_err = [0] * len(lamb) - err_idx = utt_idx + flag = True + for domain in gold_state: + for slot, values in gold_state[domain].items(): + if domain not in pred_state or slot not in pred_state[domain]: + predict_values = '' else: - err_age = utt_idx - err_idx - weighted_err = [1 - np.exp(-l * err_age) for l in lamb] - predicts_prev = predicts - labels_prev = labels - fga.append(weighted_err) - - _l2 = 2.0 * 
len([ele for ele in labels if ele not in predicts]) - _l2 += 2.0 * len([ele for ele in predicts if ele not in labels]) - l2_err.append(_l2) + predict_values = ''.join(pred_state[domain][slot].split()).lower() + if len(values) > 0: + if len(predict_values) > 0: + values = [''.join(value.split()).lower() for value in values.split('|')] + predict_values = [''.join(value.split()).lower() for value in predict_values.split('|')] + if any([value in values for value in predict_values]): + metrics['TP'] += 1 + else: + metrics['FP'] += 1 + metrics['FN'] += 1 + flag = False + else: + metrics['FN'] += 1 + flag = False + else: + if len(predict_values) > 0: + metrics['FP'] += 1 + flag = False - flag = True - for ele in predicts: - if ele in labels: - metrics['TP'] += 1 - else: - metrics['FP'] += 1 - for ele in labels: - if ele not in predicts: - metrics['FN'] += 1 - flag &= (predicts == labels) - jga.append(flag) + acc.append(flag) TP = metrics.pop('TP') FP = metrics.pop('FP') @@ -78,10 +47,7 @@ def evaluate(predict_result): metrics[f'slot_f1'] = f1 metrics[f'slot_precision'] = precision metrics[f'slot_recall'] = recall - metrics['joint_goal_accuracy'] = sum(jga) / len(jga) - for i, l in enumerate(lamb): - metrics[f'flexible_goal_accuracy_{l}'] = sum(weighted_err[i] for weighted_err in fga)/len(fga) - metrics['l2_norm_error'] = sum(l2_err) / len(l2_err) + metrics['accuracy'] = sum(acc)/len(acc) return metrics diff --git a/convlab/dst/setsumbt/__init__.py b/convlab/dst/setsumbt/__init__.py index 9492faa9c9a20d1c476819bb995900ca71d56607..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/convlab/dst/setsumbt/__init__.py +++ b/convlab/dst/setsumbt/__init__.py @@ -1 +0,0 @@ -from convlab.dst.setsumbt.tracker import SetSUMBTTracker \ No newline at end of file diff --git a/convlab/dst/setsumbt/calibration_plots.py b/convlab/dst/setsumbt/calibration_plots.py index a41f280d3349164a2a67333d0ab176a37cbe50ea..379057e6411082d466b10a81027bb57e4131bb9b 100644 --- a/convlab/dst/setsumbt/calibration_plots.py +++ b/convlab/dst/setsumbt/calibration_plots.py @@ -35,7 +35,7 @@ def main(): path = args.data_dir models = os.listdir(path) - models = [os.path.join(path, model, 'test.predictions') for model in models] + models = [os.path.join(path, model, 'test.belief') for model in models] fig = plt.figure(figsize=(14,8)) font=20 @@ -56,16 +56,16 @@ def main(): def get_calibration(path, device, n_bins=10, temperature=1.00): - probs = torch.load(path, map_location=device) - y_true = probs['state_labels'] - probs = probs['belief_states'] + logits = torch.load(path, map_location=device) + y_true = logits['labels'] + logits = logits['belief_states'] - y_pred = {slot: probs[slot].reshape(-1, probs[slot].size(-1)).argmax(-1) for slot in probs} + y_pred = {slot: logits[slot].reshape(-1, logits[slot].size(-1)).argmax(-1) for slot in logits} goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred} goal_acc = sum([goal_acc[slot] for slot in goal_acc]) goal_acc = (goal_acc == len(y_true)).int() - scores = [probs[slot].reshape(-1, probs[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in probs] + scores = [logits[slot].reshape(-1, logits[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in logits] scores = torch.cat(scores, 0).min(0)[0] step = 1.0 / float(n_bins) diff --git a/convlab/dst/setsumbt/dataset/__init__.py b/convlab/dst/setsumbt/dataset/__init__.py deleted file mode 100644 index 17b1f93b3b39f95827cf6c09e8826383cd00b805..0000000000000000000000000000000000000000 --- 
a/convlab/dst/setsumbt/dataset/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from convlab.dst.setsumbt.dataset.unified_format import get_dataloader, change_batch_size -from convlab.dst.setsumbt.dataset.ontology import get_slot_candidate_embeddings diff --git a/convlab/dst/setsumbt/dataset/ontology.py b/convlab/dst/setsumbt/dataset/ontology.py deleted file mode 100644 index 81e207805c47dcde8b7194ff5f7fac8ff10b1c2f..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/dataset/ontology.py +++ /dev/null @@ -1,133 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Create Ontology Embeddings""" - -import json -import os -import random -from copy import deepcopy - -import torch -import numpy as np - - -def set_seed(args): - """ - Set random seeds - - Args: - args (Arguments class): Arguments class containing seed and number of gpus to use - """ - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - if args.n_gpu > 0: - torch.cuda.manual_seed_all(args.seed) - - -def encode_candidates(candidates: list, args, tokenizer, embedding_model) -> torch.tensor: - """ - Embed candidates - - Args: - candidates (list): List of candidate descriptions - args (argument class): Runtime arguments - tokenizer (transformers Tokenizer): Tokenizer for the embedding_model - embedding_model (transformer Model): Transformer model for embedding candidate descriptions - - Returns: - feats (torch.tensor): Embeddings of the candidate descriptions - """ - # Tokenize candidate descriptions - feats = [tokenizer.encode_plus(val, add_special_tokens=True,max_length=args.max_candidate_len, - padding='max_length', truncation='longest_first') - for val in candidates] - - # Encode tokenized descriptions - with torch.no_grad(): - feats = {key: torch.tensor([f[key] for f in feats]).to(embedding_model.device) for key in feats[0]} - embedded_feats = embedding_model(**feats) # [num_candidates, max_candidate_len, hidden_dim] - - # Reduce/pool descriptions embeddings if required - if args.set_similarity: - feats = embedded_feats.last_hidden_state.detach().cpu() # [num_candidates, max_candidate_len, hidden_dim] - elif args.candidate_pooling == 'cls': - feats = embedded_feats.pooler_output.detach().cpu() # [num_candidates, hidden_dim] - elif args.candidate_pooling == "mean": - feats = embedded_feats.last_hidden_state.detach().cpu() - feats = feats.sum(1) - feats = torch.nn.functional.layer_norm(feats, feats.size()) - feats = feats.detach().cpu() # [num_candidates, hidden_dim] - - return feats - - -def get_slot_candidate_embeddings(ontology: dict, set_type: str, args, tokenizer, embedding_model, save_to_file=True): - """ - Get embeddings for slots and candidates - - Args: - ontology (dict): Dictionary of domain-slot pair descriptions and possible value sets - set_type (str): Subset of the dataset being used (train/validation/test) - args (argument class): 
Runtime arguments - tokenizer (transformers Tokenizer): Tokenizer for the embedding_model - embedding_model (transformer Model): Transformer model for embedding candidate descriptions - save_to_file (bool): Indication of whether to save information to file - - Returns: - slots (dict): domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot - """ - # Set model to eval mode - embedding_model.eval() - - slots = dict() - for domain, subset in ontology.items(): - for slot, slot_info in subset.items(): - # Get description or use "domain-slot" - if args.use_descriptions: - desc = slot_info['description'] - else: - desc = f"{domain}-{slot}" - - # Encode domain-slot pair description - slot_emb = encode_candidates([desc], args, tokenizer, embedding_model)[0] - - # Obtain possible value set and discard requestable value - values = deepcopy(slot_info['possible_values']) - is_requestable = False - if '?' in values: - is_requestable = True - values.remove('?') - - # Encode value candidates - if values: - feats = encode_candidates(values, args, tokenizer, embedding_model) - else: - feats = None - - # Store domain-slot description embeddings, candidate embeddings and requestable flag for each domain-slot - slots[f"{domain}-{slot}"] = (slot_emb, feats, is_requestable) - - # Dump tensors and ontology for use in training and evaluation - if save_to_file: - writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type) - torch.save(slots, writer) - - writer = open(os.path.join(args.output_dir, 'database', '%s.json' % set_type), 'w') - json.dump(ontology, writer, indent=2) - writer.close() - - return slots diff --git a/convlab/dst/setsumbt/dataset/unified_format.py b/convlab/dst/setsumbt/dataset/unified_format.py deleted file mode 100644 index 83d8f776c135ac596fed322cd39c1c31c456a481..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/dataset/unified_format.py +++ /dev/null @@ -1,385 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Convlab3 Unified Format Dialogue Datasets""" - -from copy import deepcopy - -import torch -import transformers -from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler -from transformers.tokenization_utils import PreTrainedTokenizer -from tqdm import tqdm - -from convlab.util import load_dataset -from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values, - get_values_from_data, ontology_add_requestable_slots, - get_requestable_slots, load_dst_data, extract_dialogues, - combine_value_sets) - -transformers.logging.set_verbosity_error() - - -def convert_examples_to_features(data: list, - ontology: dict, - tokenizer: PreTrainedTokenizer, - max_turns: int = 12, - max_seq_len: int = 64) -> dict: - """ - Convert dialogue examples to model input features and labels - - Args: - data (list): List of all extracted dialogues - ontology (dict): Ontology dictionary containing slots, slot descriptions and - possible value sets including requests - tokenizer (PreTrainedTokenizer): Tokenizer for the encoder model used - max_turns (int): Maximum numbers of turns in a dialogue - max_seq_len (int): Maximum number of tokens in a dialogue turn - - Returns: - features (dict): All inputs and labels required to train the model - """ - features = dict() - ontology = deepcopy(ontology) - - # Get encoder input for system, user utterance pairs - input_feats = [] - for dial in tqdm(data): - dial_feats = [] - for turn in dial: - if len(turn['system_utterance']) == 0: - usr = turn['user_utterance'] - dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens=True, - max_length=max_seq_len, padding='max_length', - truncation='longest_first')) - else: - usr = turn['user_utterance'] - sys = turn['system_utterance'] - dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens=True, - max_length=max_seq_len, padding='max_length', - truncation='longest_first')) - # Truncate - if len(dial_feats) >= max_turns: - break - input_feats.append(dial_feats) - del dial_feats - - # Perform turn level padding - input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) - for dial in input_feats] - if 'token_type_ids' in input_feats[0][0]: - token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) - for dial in input_feats] - else: - token_type_ids = None - if 'attention_mask' in input_feats[0][0]: - attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) - for dial in input_feats] - else: - attention_mask = None - del input_feats - - # Create torch data tensors - features['input_ids'] = torch.tensor(input_ids) - features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None - features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None - del input_ids, token_type_ids, attention_mask - - # Extract all informable and requestable slots from the ontology - informable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain] - if ontology[domain][slot]['possible_values'] - and ontology[domain][slot]['possible_values'] != ['?']] - requestable_slots = [f"{domain}-{slot}" for domain in ontology for slot in ontology[domain] - if '?' 
in ontology[domain][slot]['possible_values']] - for slot in requestable_slots: - domain, slot = slot.split('-', 1) - ontology[domain][slot]['possible_values'].remove('?') - - # Extract a list of domains from the ontology slots - domains = list(set(informable_slots + requestable_slots)) - domains = list(set([slot.split('-', 1)[0] for slot in domains])) - - # Create slot labels - for domslot in tqdm(informable_slots): - labels = [] - for dial in data: - labs = [] - for turn in dial: - value = [v for d, substate in turn['state'].items() for s, v in substate.items() - if f'{d}-{s}' == domslot] - value = value[0] if value else 'none' - domain, slot = domslot.split('-', 1) - if value in ontology[domain][slot]['possible_values']: - value = ontology[domain][slot]['possible_values'].index(value) - else: - value = -1 # If value is not in ontology then we do not penalise the model - labs.append(value) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['state_labels-' + domslot] = labels - - # Create requestable slot labels - for domslot in tqdm(requestable_slots): - labels = [] - for dial in data: - labs = [] - for turn in dial: - domain, slot = domslot.split('-', 1) - acts = [act['intent'] for act in turn['dialogue_acts'] - if act['domain'] == domain and act['slot'] == slot] - if acts: - act_ = acts[0] - if act_ == 'request': - labs.append(1) - else: - labs.append(0) - else: - labs.append(0) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['request_labels-' + domslot] = labels - - # General act labels (1-goodbye, 2-thank you) - labels = [] - for dial in tqdm(data): - labs = [] - for turn in dial: - acts = [act['intent'] for act in turn['dialogue_acts'] if act['intent'] in ['bye', 'thank']] - if acts: - if 'bye' in acts: - labs.append(1) - else: - labs.append(2) - else: - labs.append(0) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['general_act_labels'] = labels - - # Create active domain labels - for domain in tqdm(domains): - labels = [] - for dial in data: - labs = [] - for turn in dial: - if domain in turn['active_domains']: - labs.append(1) - else: - labs.append(0) - if len(labs) >= max_turns: - break - labs = labs + [-1] * (max_turns - len(labs)) - labels.append(labs) - - labels = torch.tensor(labels) - features['active_domain_labels-' + domain] = labels - - del labels - - return features - - -class UnifiedFormatDataset(Dataset): - """ - Class for preprocessing, and storing data easily from the Convlab3 unified format. 
- - Attributes: - dataset_dict (dict): Dictionary containing all the data in dataset - ontology (dict): Set of all domain-slot-value triplets in the ontology of the model - features (dict): Set of numeric features containing all inputs and labels formatted for the SetSUMBT model - """ - def __init__(self, - dataset_name: str, - set_type: str, - tokenizer: PreTrainedTokenizer, - max_turns: int = 12, - max_seq_len: int = 64, - train_ratio: float = 1.0, - seed: int = 0, - data: dict = None, - ontology: dict = None): - """ - Args: - dataset_name (str): Name of the dataset/s to load (multiple to be seperated by +) - set_type (str): Subset of the dataset to load (train, validation or test) - tokenizer (transformers tokenizer): Tokenizer for the encoder model used - max_turns (int): Maximum numbers of turns in a dialogue - max_seq_len (int): Maximum number of tokens in a dialogue turn - train_ratio (float): Fraction of training data to use during training - seed (int): Seed governing random order of ids for subsampling - data (dict): Dataset features for loading from dict - ontology (dict): Ontology dict for loading from dict - """ - if data is not None: - self.ontology = ontology - self.features = data - else: - if '+' in dataset_name: - dataset_args = [{"dataset_name": name} for name in dataset_name.split('+')] - else: - dataset_args = [{"dataset_name": dataset_name}] - if train_ratio != 1.0: - for dataset_args_ in dataset_args: - dataset_args_['dial_ids_order'] = seed - dataset_args_['split2ratio'] = {'train': train_ratio, 'validation': train_ratio} - self.dataset_dicts = [load_dataset(**dataset_args_) for dataset_args_ in dataset_args] - self.ontology = get_ontology_slots(dataset_name) - values = [get_values_from_data(dataset) for dataset in self.dataset_dicts] - self.ontology = ontology_add_values(self.ontology, combine_value_sets(values)) - self.ontology = ontology_add_requestable_slots(self.ontology, get_requestable_slots(self.dataset_dicts)) - - data = [load_dst_data(dataset_dict, data_split=set_type, speaker='all', - dialogue_acts=True, split_to_turn=False) - for dataset_dict in self.dataset_dicts] - data_list = [data_[set_type] for data_ in data] - - data = [] - for data_ in data_list: - data += extract_dialogues(data_) - self.features = convert_examples_to_features(data, self.ontology, tokenizer, max_turns, max_seq_len) - - def __getitem__(self, index: int) -> dict: - """ - Obtain dialogues with specific ids from dataset - - Args: - index (int/list/tensor): Index/indices of dialogues to get - - Returns: - features (dict): All inputs and labels required to train the model - """ - return {label: self.features[label][index] for label in self.features - if self.features[label] is not None} - - def __len__(self): - """ - Get number of dialogues in the dataset - - Returns: - len (int): Number of dialogues in the dataset object - """ - return self.features['input_ids'].size(0) - - def resample(self, size: int = None) -> Dataset: - """ - Resample subset of the dataset - - Args: - size (int): Number of dialogues to sample - - Returns: - self (Dataset): Dataset object - """ - # If no subset size is specified we resample a set with the same size as the full dataset - n_dialogues = self.__len__() - if not size: - size = n_dialogues - - dialogues = torch.randint(low=0, high=n_dialogues, size=(size,)) - self.features = self.__getitem__(dialogues) - - return self - - def to(self, device): - """ - Map all data to a device - - Args: - device (torch device): Device to map data to - """ - self.device = 
device - self.features = {label: self.features[label].to(device) for label in self.features - if self.features[label] is not None} - - @classmethod - def from_datadict(cls, data: dict, ontology: dict): - return cls(None, None, None, data=data, ontology=ontology) - - -def get_dataloader(dataset_name: str, - set_type: str, - batch_size: int, - tokenizer: PreTrainedTokenizer, - max_turns: int = 12, - max_seq_len: int = 64, - device='cpu', - resampled_size: int = None, - train_ratio: float = 1.0, - seed: int = 0) -> DataLoader: - ''' - Module to create torch dataloaders - - Args: - dataset_name (str): Name of the dataset to load - set_type (str): Subset of the dataset to load (train, validation or test) - batch_size (int): Batch size for the dataloader - tokenizer (transformers tokenizer): Tokenizer for the encoder model used - max_turns (int): Maximum numbers of turns in a dialogue - max_seq_len (int): Maximum number of tokens in a dialogue turn - device (torch device): Device to map data to - resampled_size (int): Number of dialogues to sample - train_ratio (float): Ratio of training data to use for training - seed (int): Seed governing random order of ids for subsampling - - Returns: - loader (torch dataloader): Dataloader to train and evaluate the setsumbt model - ''' - data = UnifiedFormatDataset(dataset_name, set_type, tokenizer, max_turns, max_seq_len, train_ratio=train_ratio, - seed=seed) - data.to(device) - - if resampled_size: - data = data.resample(resampled_size) - - if set_type in ['test', 'validation']: - sampler = SequentialSampler(data) - else: - sampler = RandomSampler(data) - loader = DataLoader(data, sampler=sampler, batch_size=batch_size) - - return loader - - -def change_batch_size(loader: DataLoader, batch_size: int) -> DataLoader: - """ - Change the batch size of a preloaded loader - - Args: - loader (DataLoader): Dataloader to train and evaluate the setsumbt model - batch_size (int): Batch size for the dataloader - - Returns: - loader (DataLoader): Dataloader to train and evaluate the setsumbt model - """ - - if 'SequentialSampler' in str(loader.sampler): - sampler = SequentialSampler(loader.dataset) - else: - sampler = RandomSampler(loader.dataset) - loader = DataLoader(loader.dataset, sampler=sampler, batch_size=batch_size) - - return loader diff --git a/convlab/dst/setsumbt/dataset/utils.py b/convlab/dst/setsumbt/dataset/utils.py deleted file mode 100644 index dded04aac31102d9e81f5de288b9399b0c190d0c..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/dataset/utils.py +++ /dev/null @@ -1,404 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Convlab3 Unified dataset data processing utilities""" - -from convlab.util import load_ontology, load_dst_data, load_nlu_data -from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME - - -def get_ontology_slots(dataset_name: str) -> dict: - """ - Function to extract slots, slot descriptions and categorical slot values from the dataset ontology. - - Args: - dataset_name (str): Dataset name - - Returns: - ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values - """ - dataset_names = dataset_name.split('+') if '+' in dataset_name else [dataset_name] - ontology_slots = dict() - for dataset_name in dataset_names: - ontology = load_ontology(dataset_name) - domains = [domain for domain in ontology['domains'] if domain not in ['booking', 'general']] - for domain in domains: - domain_name = DOMAINS_MAP.get(domain, domain.lower()) - if domain_name not in ontology_slots: - ontology_slots[domain_name] = dict() - for slot, slot_info in ontology['domains'][domain]['slots'].items(): - if slot not in ontology_slots[domain_name]: - ontology_slots[domain_name][slot] = {'description': slot_info['description'], - 'possible_values': list()} - if slot_info['is_categorical']: - ontology_slots[domain_name][slot]['possible_values'] += slot_info['possible_values'] - - ontology_slots[domain_name][slot]['possible_values'] = list(set(ontology_slots[domain_name][slot]['possible_values'])) - - return ontology_slots - - -def get_values_from_data(dataset: dict) -> dict: - """ - Function to extract slots, slot descriptions and categorical slot values from the dataset ontology. - - Args: - dataset (dict): Dataset dictionary obtained using the load_dataset function - - Returns: - value_sets (dict): Dictionary containing possible values obtained from dataset - """ - data = load_dst_data(dataset, data_split='all', speaker='user') - value_sets = {} - for set_type, dataset in data.items(): - for turn in dataset: - for domain, substate in turn['state'].items(): - domain_name = DOMAINS_MAP.get(domain, domain.lower()) - if domain not in value_sets: - value_sets[domain_name] = {} - for slot, value in substate.items(): - if slot not in value_sets[domain_name]: - value_sets[domain_name][slot] = [] - if value and value not in value_sets[domain_name][slot]: - value_sets[domain_name][slot].append(value) - - return clean_values(value_sets) - - -def combine_value_sets(value_sets: list) -> dict: - """ - Function to combine value sets extracted from different datasets - - Args: - value_sets (list): List of value sets extracted using the get_values_from_data function - - Returns: - value_set (dict): Dictionary containing possible values obtained from datasets - """ - value_set = value_sets[0] - for _value_set in value_sets[1:]: - for domain, domain_info in _value_set.items(): - for slot, possible_values in domain_info.items(): - if domain not in value_set: - value_set[domain] = dict() - if slot not in value_set[domain]: - value_set[domain][slot] = list() - value_set[domain][slot] += _value_set[domain][slot] - value_set[domain][slot] = list(set(value_set[domain][slot])) - - return value_set - - -def clean_values(value_sets: dict, value_map: dict = VALUE_MAP) -> dict: - """ - Function to clean up the possible value sets extracted from the states in the dataset - - Args: - value_sets (dict): Dictionary containing possible values obtained from dataset - value_map (dict): Label map to avoid duplication and typos in values - - Returns: - clean_vals (dict): 
Cleaned Dictionary containing possible values obtained from dataset - """ - clean_vals = {} - for domain, subset in value_sets.items(): - clean_vals[domain] = {} - for slot, values in subset.items(): - # Remove pipe separated values - values = list(set([val.split('|', 1)[0] for val in values])) - - # Map values using value_map - for old, new in value_map.items(): - values = list(set([val.replace(old, new) for val in values])) - - # Remove empty and dontcare from possible value sets - values = [val for val in values if val not in ['', 'dontcare']] - - # MultiWOZ specific value sets for quantity, time and boolean slots - if 'people' in slot or 'duration' in slot or 'stay' in slot: - values = QUANTITIES - elif 'time' in slot or 'leave' in slot or 'arrive' in slot: - values = TIME - elif 'parking' in slot or 'internet' in slot: - values = ['yes', 'no'] - - clean_vals[domain][slot] = values - - return clean_vals - - -def ontology_add_values(ontology_slots: dict, value_sets: dict) -> dict: - """ - Add value sets obtained from the dataset to the ontology - Args: - ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values - value_sets (dict): Cleaned Dictionary containing possible values obtained from dataset - - Returns: - ontology_slots (dict): Ontology dictionary containing slots, slot descriptions and possible value sets - """ - ontology = {} - for domain in sorted(ontology_slots): - ontology[domain] = {} - for slot in sorted(ontology_slots[domain]): - if not ontology_slots[domain][slot]['possible_values']: - if domain in value_sets: - if slot in value_sets[domain]: - ontology_slots[domain][slot]['possible_values'] = value_sets[domain][slot] - if ontology_slots[domain][slot]['possible_values']: - values = sorted(ontology_slots[domain][slot]['possible_values']) - ontology_slots[domain][slot]['possible_values'] = ['none', 'do not care'] + values - - ontology[domain][slot] = ontology_slots[domain][slot] - - return ontology - - -def get_requestable_slots(datasets: list) -> dict: - """ - Function to get set of requestable slots from the dataset action labels. 
- Args: - dataset (dict): Dataset dictionary obtained using the load_dataset function - - Returns: - slots (dict): Dictionary containing requestable domain-slot pairs - """ - datasets = [load_nlu_data(dataset, data_split='all', speaker='user') for dataset in datasets] - - slots = {} - for data in datasets: - for set_type, subset in data.items(): - for turn in subset: - requests = [act for act in turn['dialogue_acts']['categorical'] if act['intent'] == 'request'] - requests += [act for act in turn['dialogue_acts']['non-categorical'] if act['intent'] == 'request'] - requests += [act for act in turn['dialogue_acts']['binary'] if act['intent'] == 'request'] - requests = [(act['domain'], act['slot']) for act in requests] - for domain, slot in requests: - domain_name = DOMAINS_MAP.get(domain, domain.lower()) - if domain_name not in slots: - slots[domain_name] = [] - slots[domain_name].append(slot) - - slots = {domain: list(set(slot_list)) for domain, slot_list in slots.items()} - - return slots - - -def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict) -> dict: - """ - Add requestable slots obtained from the dataset to the ontology - Args: - ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values - requestable_slots (dict): Dictionary containing requestable domain-slot pairs - - Returns: - ontology_slots (dict): Ontology dictionary containing slots, slot descriptions and - possible value sets including requests - """ - for domain in ontology_slots: - for slot in ontology_slots[domain]: - if domain in requestable_slots: - if slot in requestable_slots[domain]: - ontology_slots[domain][slot]['possible_values'].append('?') - - return ontology_slots - - -def extract_turns(dialogue: list) -> list: - """ - Extract the required information from the data provided by unified loader - Args: - dialogue (list): List of turns within a dialogue - - Returns: - turns (list): List of turns within a dialogue - """ - turns = [] - turn_info = {} - for turn in dialogue: - if turn['speaker'] == 'system': - turn_info['system_utterance'] = turn['utterance'] - - # System utterance in the first turn is always empty as conversation is initiated by the user - if turn['utt_idx'] == 1: - turn_info['system_utterance'] = '' - - if turn['speaker'] == 'user': - turn_info['user_utterance'] = turn['utterance'] - - # Inform acts not required by model - turn_info['dialogue_acts'] = [act for act in turn['dialogue_acts']['categorical'] - if act['intent'] not in ['inform']] - turn_info['dialogue_acts'] += [act for act in turn['dialogue_acts']['non-categorical'] - if act['intent'] not in ['inform']] - turn_info['dialogue_acts'] += [act for act in turn['dialogue_acts']['binary'] - if act['intent'] not in ['inform']] - - turn_info['state'] = turn['state'] - - if 'system_utterance' in turn_info and 'user_utterance' in turn_info: - turns.append(turn_info) - turn_info = {} - - return turns - - -def clean_states(turns: list) -> list: - """ - Clean the state within each turn of a dialogue (cleaning values and mapping to options used in ontology) - Args: - turns (list): List of turns within a dialogue - - Returns: - clean_turns (list): List of turns within a dialogue - """ - clean_turns = [] - for turn in turns: - clean_state = {} - clean_acts = [] - for act in turn['dialogue_acts']: - domain = act['domain'] - act['domain'] = DOMAINS_MAP.get(domain, domain.lower()) - clean_acts.append(act) - for domain, subset in turn['state'].items(): - domain_name = 
DOMAINS_MAP.get(domain, domain.lower()) - clean_state[domain_name] = {} - for slot, value in subset.items(): - # Remove pipe separated values - value = value.split('|', 1)[0] - - # Map values using value_map - for old, new in VALUE_MAP.items(): - value = value.replace(old, new) - - # Map dontcare to "do not care" and empty to 'none' - value = value.replace('dontcare', 'do not care') - value = value if value else 'none' - - # Map quantity values to the integer quantity value - if 'people' in slot or 'duration' in slot or 'stay' in slot: - try: - if value not in ['do not care', 'none']: - value = int(value) - value = str(value) if value < 10 else QUANTITIES[-1] - except: - value = value - # Map time values to the most appropriate value in the standard time set - elif 'time' in slot or 'leave' in slot or 'arrive' in slot: - try: - if value not in ['do not care', 'none']: - # Strip after/before from time value - value = value.replace('after ', '').replace('before ', '') - # Extract hours and minutes from different possible formats - if ':' not in value and len(value) == 4: - h, m = value[:2], value[2:] - elif len(value) == 1: - h = int(value) - m = 0 - elif 'pm' in value: - h = int(value.replace('pm', '')) + 12 - m = 0 - elif 'am' in value: - h = int(value.replace('pm', '')) - m = 0 - elif ':' in value: - h, m = value.split(':') - elif ';' in value: - h, m = value.split(';') - # Map to closest 5 minutes - if int(m) % 5 != 0: - m = round(int(m) / 5) * 5 - h = int(h) - if m == 60: - m = 0 - h += 1 - if h >= 24: - h -= 24 - # Set in standard 24 hour format - h, m = int(h), int(m) - value = '%02i:%02i' % (h, m) - except: - value = value - # Map boolean slots to yes/no value - elif 'parking' in slot or 'internet' in slot: - if value not in ['do not care', 'none']: - if value == 'free': - value = 'yes' - elif True in [v in value.lower() for v in ['yes', 'no']]: - value = [v for v in ['yes', 'no'] if v in value][0] - - clean_state[domain_name][slot] = value - turn['state'] = clean_state - turn['dialogue_acts'] = clean_acts - clean_turns.append(turn) - - return clean_turns - - -def get_active_domains(turns: list) -> list: - """ - Get active domains at each turn in a dialogue - Args: - turns (list): List of turns within a dialogue - - Returns: - turns (list): List of turns within a dialogue - """ - for turn_id in range(len(turns)): - # At first turn all domains with not none values in the state are active - if turn_id == 0: - domains = [d for d, substate in turns[turn_id]['state'].items() for s, v in substate.items() if v != 'none'] - domains += [act['domain'] for act in turns[turn_id]['dialogue_acts'] if act['domain'] in turns[turn_id]['state']] - domains = [DOMAINS_MAP.get(domain, domain.lower()) for domain in domains] - turns[turn_id]['active_domains'] = list(set(domains)) - else: - # Use changes in domains to identify active domains - domains = [] - for domain, substate in turns[turn_id]['state'].items(): - domain_name = DOMAINS_MAP.get(domain, domain.lower()) - for slot, value in substate.items(): - if value != turns[turn_id - 1]['state'][domain][slot]: - val = value - else: - val = 'none' - if value == 'none': - val = 'none' - if val != 'none': - domains.append(domain_name) - # Add all domains activated by a user action - domains += [act['domain'] for act in turns[turn_id]['dialogue_acts'] - if act['domain'] in turns[turn_id]['state']] - turns[turn_id]['active_domains'] = list(set(domains)) - - return turns - - -def extract_dialogues(data: list) -> list: - """ - Extract all dialogues from dataset 
- Args: - data (list): List of all dialogues in a subset of the data - - Returns: - dialogues (list): List of all extracted dialogues - """ - dialogues = [] - for dial in data: - turns = extract_turns(dial['turns']) - turns = clean_states(turns) - turns = get_active_domains(turns) - dialogues.append(turns) - - return dialogues diff --git a/convlab/dst/setsumbt/dataset/value_maps.py b/convlab/dst/setsumbt/dataset/value_maps.py deleted file mode 100644 index 619600a7b0a57096918058ff117aa2ca5aac864a..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/dataset/value_maps.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convlab3 Unified dataset value maps""" - - -# MultiWOZ specific label map to avoid duplication and typos in values -VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast', - 'cityroomz': 'city roomz', ' ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott', - 'worth house': 'the worth house', 'alesbray lodge guest house': 'aylesbray lodge', - 'huntingdon hotel': 'huntingdon marriott hotel', 'huntingd': 'huntingdon marriott hotel', - 'jamaicanchinese': 'chinese', 'barbequemodern european': 'modern european', - 'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"} - - -# Domain map for SGD and TM Data -DOMAINS_MAP = {'Alarm_1': 'alarm', 'Banks_1': 'banks', 'Banks_2': 'banks', 'Buses_1': 'bus', 'Buses_2': 'bus', - 'Buses_3': 'bus', 'Calendar_1': 'calendar', 'Events_1': 'events', 'Events_2': 'events', - 'Events_3': 'events', 'Flights_1': 'flights', 'Flights_2': 'flights', 'Flights_3': 'flights', - 'Flights_4': 'flights', 'Homes_1': 'homes', 'Homes_2': 'homes', 'Hotels_1': 'hotel', - 'Hotels_2': 'hotel', 'Hotels_3': 'hotel', 'Hotels_4': 'hotel', 'Media_1': 'media', - 'Media_2': 'media', 'Media_3': 'media', 'Messaging_1': 'messaging', 'Movies_1': 'movies', - 'Movies_2': 'movies', 'Movies_3': 'movies', 'Music_1': 'music', 'Music_2': 'music', 'Music_3': 'music', - 'Payment_1': 'payment', 'RentalCars_1': 'rentalcars', 'RentalCars_2': 'rentalcars', - 'RentalCars_3': 'rentalcars', 'Restaurants_1': 'restaurant', 'Restaurants_2': 'restaurant', - 'RideSharing_1': 'ridesharing', 'RideSharing_2': 'ridesharing', 'Services_1': 'services', - 'Services_2': 'services', 'Services_3': 'services', 'Services_4': 'services', 'Trains_1': 'train', - 'Travel_1': 'travel', 'Weather_1': 'weather', 'movie_ticket': 'movies', - 'restaurant_reservation': 'restaurant', 'coffee_ordering': 'coffee', 'pizza_ordering': 'takeout', - 'auto_repair': 'car_repairs', 'flights': 'flights', 'food-ordering': 'takeout', 'hotels': 'hotel', - 'movies': 'movies', 'music': 'music', 'restaurant-search': 'restaurant', 'sports': 'sports', - 'movie': 'movies'} - - -# Generic value sets for quantity and time slots -QUANTITIES = ['1', '2', 
'3', '4', '5', '6', '7', '8', '9', '10 or more'] -TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)] -TIME = ['%02i:%02i' % t for l in TIME for t in l] \ No newline at end of file diff --git a/convlab/dst/setsumbt/distillation_setup.py b/convlab/dst/setsumbt/distillation_setup.py index e54ccaa224d80aa7a009f7cf538900e289d5dd4f..e0d87bb964041cae19d58351cd6b31f6d836f125 100644 --- a/convlab/dst/setsumbt/distillation_setup.py +++ b/convlab/dst/setsumbt/distillation_setup.py @@ -1,57 +1,42 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Get ensemble predictions and build distillation dataloaders""" - from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser import os -import json import torch -from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +import transformers +from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler +from transformers import RobertaConfig, BertConfig from tqdm import tqdm -from convlab.dst.setsumbt.dataset.unified_format import UnifiedFormatDataset +import convlab +from convlab.dst.setsumbt.multiwoz.dataset.multiwoz21 import EnsembleMultiWoz21 from convlab.dst.setsumbt.modeling import EnsembleSetSUMBT -from convlab.dst.setsumbt.modeling import training -DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' +DEVICE = 'cuda' -def main(): +def args_parser(): parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) parser.add_argument('--model_path', type=str) parser.add_argument('--model_type', type=str) parser.add_argument('--set_type', type=str) parser.add_argument('--batch_size', type=int) + parser.add_argument('--ensemble_size', type=int) parser.add_argument('--reduction', type=str, default='mean') parser.add_argument('--get_ensemble_distributions', action='store_true') parser.add_argument('--build_dataloaders', action='store_true') - args = parser.parse_args() + + return parser.parse_args() + + +def main(): + args = args_parser() if args.get_ensemble_distributions: get_ensemble_distributions(args) elif args.build_dataloaders: path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data') data = torch.load(path) - - reader = open(os.path.join(args.model_path, 'database', f'{args.set_type}.json'), 'r') - ontology = json.load(reader) - reader.close() - - loader = get_loader(data, ontology, args.set_type, args.batch_size) + loader = get_loader(data, args.set_type, args.batch_size) path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader') torch.save(loader, path) @@ -59,22 +44,10 @@ def main(): raise NameError("NotImplemented") -def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size: int = 3) -> DataLoader: - """ - Build dataloader from ensemble prediction data - - Args: - data: Dictionary of ensemble predictions - ontology: Data ontology - set_type: Data subset (train/validation/test) - 
batch_size: Number of dialogues per batch - - Returns: - loader: Data loader object - """ +def get_loader(data, set_type='train', batch_size=3): data = flatten_data(data) data = do_label_padding(data) - data = UnifiedFormatDataset.from_datadict(data, ontology) + data = EnsembleMultiWoz21(data) if set_type == 'train': sampler = RandomSampler(data) else: @@ -84,16 +57,7 @@ def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size: return loader -def do_label_padding(data: dict) -> dict: - """ - Add padding to the ensemble predictions (used as labels in distillation) - - Args: - data: Dictionary of ensemble predictions - - Returns: - data: Padded ensemble predictions - """ +def do_label_padding(data): if 'attention_mask' in data: dialogs, turns = torch.where(data['attention_mask'].sum(-1) == 0.0) else: @@ -106,17 +70,13 @@ def do_label_padding(data: dict) -> dict: return data -def flatten_data(data: dict) -> dict: - """ - Map data to flattened feature format used in training - Args: - data: Ensemble prediction data - - Returns: - data: Flattened ensemble prediction data - """ +map_dict = {'belief_state': 'belief', 'greeting_act_belief': 'goodbye_belief', + 'state_labels': 'labels', 'request_labels': 'request', + 'domain_labels': 'active', 'greeting_labels': 'goodbye'} +def flatten_data(data): data_new = dict() for label, feats in data.items(): + label = map_dict.get(label, label) if type(feats) == dict: for label_, feats_ in feats.items(): data_new[label + '-' + label_] = feats_ @@ -127,11 +87,13 @@ def flatten_data(data: dict) -> dict: def get_ensemble_distributions(args): - """ - Load data and get ensemble predictions - Args: - args: Runtime arguments - """ + if args.model_type == 'roberta': + config = RobertaConfig + elif args.model_type == 'bert': + config = BertConfig + config = config.from_pretrained(args.model_path) + config.ensemble_size = args.ensemble_size + device = DEVICE model = EnsembleSetSUMBT.from_pretrained(args.model_path) @@ -145,7 +107,16 @@ def get_ensemble_distributions(args): dataloader = torch.load(dataloader) database = torch.load(database) - training.set_ontology_embeddings(model, database) + # Get slot and value embeddings + slots = {slot: val for slot, val in database.items()} + values = {slot: val[1] for slot, val in database.items()} + del database + + # Load model ontology + model.add_slot_candidates(slots) + for slot in model.informable_slot_ids: + model.add_value_candidates(slot, values[slot], replace=True) + del slots, values print('Environment set up.') @@ -154,24 +125,18 @@ def get_ensemble_distributions(args): attention_mask = [] state_labels = {slot: [] for slot in model.informable_slot_ids} request_labels = {slot: [] for slot in model.requestable_slot_ids} - active_domain_labels = {domain: [] for domain in model.domain_ids} - general_act_labels = [] - - is_noisy = [] if 'is_noisy' in dataloader.dataset.features else None - + domain_labels = {domain: [] for domain in model.domain_ids} + greeting_labels = [] belief_state = {slot: [] for slot in model.informable_slot_ids} - request_probs = {slot: [] for slot in model.requestable_slot_ids} - active_domain_probs = {domain: [] for domain in model.domain_ids} - general_act_probs = [] + request_belief = {slot: [] for slot in model.requestable_slot_ids} + domain_belief = {domain: [] for domain in model.domain_ids} + greeting_act_belief = [] model.eval() for batch in tqdm(dataloader, desc='Batch:'): ids = batch['input_ids'] tt_ids = batch['token_type_ids'] if 'token_type_ids' in batch else 
None mask = batch['attention_mask'] if 'attention_mask' in batch else None - if 'is_noisy' in batch: - is_noisy.append(batch['is_noisy']) - input_ids.append(ids) token_type_ids.append(tt_ids) attention_mask.append(mask) @@ -181,59 +146,57 @@ def get_ensemble_distributions(args): mask = mask.to(device) if mask is not None else None for slot in state_labels: - state_labels[slot].append(batch['state_labels-' + slot]) - if model.config.predict_actions: + state_labels[slot].append(batch['labels-' + slot]) + if model.config.predict_intents: for slot in request_labels: - request_labels[slot].append(batch['request_labels-' + slot]) + request_labels[slot].append(batch['request-' + slot]) for domain in domain_labels: - domain_labels[domain].append(batch['active_domain_labels-' + domain]) - greeting_labels.append(batch['general_act_labels']) + domain_labels[domain].append(batch['active-' + domain]) + greeting_labels.append(batch['goodbye']) with torch.no_grad(): - p, p_req, p_dom, p_gen, _ = model(ids, mask, tt_ids, reduction=args.reduction) + p, p_req, p_dom, p_bye, _ = model(ids, mask, tt_ids, + reduction=args.reduction) for slot in belief_state: belief_state[slot].append(p[slot].cpu()) - if model.config.predict_actions: - for slot in request_probs: - request_probs[slot].append(p_req[slot].cpu()) - for domain in active_domain_probs: - active_domain_probs[domain].append(p_dom[domain].cpu()) - general_act_probs.append(p_gen.cpu()) + if model.config.predict_intents: + for slot in request_belief: + request_belief[slot].append(p_req[slot].cpu()) + for domain in domain_belief: + domain_belief[domain].append(p_dom[domain].cpu()) + greeting_act_belief.append(p_bye.cpu()) input_ids = torch.cat(input_ids, 0) if input_ids[0] is not None else None token_type_ids = torch.cat(token_type_ids, 0) if token_type_ids[0] is not None else None attention_mask = torch.cat(attention_mask, 0) if attention_mask[0] is not None else None - is_noisy = torch.cat(is_noisy, 0) if is_noisy is not None else None state_labels = {slot: torch.cat(l, 0) for slot, l in state_labels.items()} - if model.config.predict_actions: + if model.config.predict_intents: request_labels = {slot: torch.cat(l, 0) for slot, l in request_labels.items()} - active_domain_labels = {domain: torch.cat(l, 0) for domain, l in active_domain_labels.items()} - general_act_labels = torch.cat(general_act_labels, 0) + domain_labels = {domain: torch.cat(l, 0) for domain, l in domain_labels.items()} + greeting_labels = torch.cat(greeting_labels, 0) belief_state = {slot: torch.cat(p, 0) for slot, p in belief_state.items()} - if model.config.predict_actions: - request_probs = {slot: torch.cat(p, 0) for slot, p in request_probs.items()} - active_domain_probs = {domain: torch.cat(p, 0) for domain, p in active_domain_probs.items()} - general_act_probs = torch.cat(general_act_probs, 0) + if model.config.predict_intents: + request_belief = {slot: torch.cat(p, 0) for slot, p in request_belief.items()} + domain_belief = {domain: torch.cat(p, 0) for domain, p in domain_belief.items()} + greeting_act_belief = torch.cat(greeting_act_belief, 0) data = {'input_ids': input_ids} if token_type_ids is not None: data['token_type_ids'] = token_type_ids if attention_mask is not None: data['attention_mask'] = attention_mask - if is_noisy is not None: - data['is_noisy'] = is_noisy data['state_labels'] = state_labels data['belief_state'] = belief_state - if model.config.predict_actions: + if model.config.predict_intents: data['request_labels'] = request_labels - data['active_domain_labels'] 
= active_domain_labels - data['general_act_labels'] = general_act_labels - data['request_probs'] = request_probs - data['active_domain_probs'] = active_domain_probs - data['general_act_probs'] = general_act_probs + data['domain_labels'] = domain_labels + data['greeting_labels'] = greeting_labels + data['request_belief'] = request_belief + data['domain_belief'] = domain_belief + data['greeting_act_belief'] = greeting_act_belief file = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data') torch.save(data, file) diff --git a/convlab/dst/setsumbt/do/calibration.py b/convlab/dst/setsumbt/do/calibration.py new file mode 100644 index 0000000000000000000000000000000000000000..27ee058eca882ce7e10937f9640d143b88e57f5e --- /dev/null +++ b/convlab/dst/setsumbt/do/calibration.py @@ -0,0 +1,481 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Run SetSUMBT Calibration""" + +import logging +import random +import os +from shutil import copy2 as copy + +import torch +from transformers import (BertModel, BertConfig, BertTokenizer, + RobertaModel, RobertaConfig, RobertaTokenizer, + AdamW, get_linear_schedule_with_warmup) +from tqdm import tqdm, trange +from tensorboardX import SummaryWriter +from torch.distributions import Categorical + +from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT +from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT +from convlab.dst.setsumbt.multiwoz import multiwoz21 +from convlab.dst.setsumbt.multiwoz import ontology as embeddings +from convlab.dst.setsumbt.utils import get_args, upload_local_directory_to_gcs, update_args +from convlab.dst.setsumbt.modeling import calibration_utils +from convlab.dst.setsumbt.modeling import ensemble_utils +from convlab.dst.setsumbt.loss.ece import ece, jg_ece, l2_acc + + +# Datasets +DATASETS = { + 'multiwoz21': multiwoz21 +} + +MODELS = { + 'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer), + 'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer) +} + + +def main(args=None, config=None): + # Get arguments + if args is None: + args, config = get_args(MODELS) + + # Select Dataset object + if args.dataset in DATASETS: + Dataset = DATASETS[args.dataset] + else: + raise NameError('NotImplemented') + + if args.model_type in MODELS: + SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type] + else: + raise NameError('NotImplemented') + + # Set up output directory + OUTPUT_DIR = args.output_dir + if not os.path.exists(OUTPUT_DIR): + os.mkdir(OUTPUT_DIR) + args.output_dir = OUTPUT_DIR + if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')): + os.mkdir(os.path.join(OUTPUT_DIR, 'predictions')) + + paths = os.listdir(args.output_dir) if os.path.exists( + args.output_dir) else [] + if 'pytorch_model.bin' in paths and 'config.json' in paths: + args.model_name_or_path = args.output_dir + config = 
ConfigClass.from_pretrained(args.model_name_or_path) + else: + paths = os.listdir(args.output_dir) if os.path.exists( + args.output_dir) else [] + paths = [os.path.join(args.output_dir, p) + for p in paths if 'checkpoint-' in p] + if paths: + paths = paths[0] + args.model_name_or_path = paths + config = ConfigClass.from_pretrained(args.model_name_or_path) + + if args.ensemble_size > 0: + paths = os.listdir(args.output_dir) if os.path.exists( + args.output_dir) else [] + paths = [os.path.join(args.output_dir, p) + for p in paths if 'ensemble_' in p] + if paths: + args.model_name_or_path = args.output_dir + config = ConfigClass.from_pretrained(args.model_name_or_path) + + args = update_args(args, config) + + # Set up data directory + DATA_DIR = args.data_dir + Dataset.set_datadir(DATA_DIR) + embeddings.set_datadir(DATA_DIR) + + if args.shrink_active_domains and args.dataset == 'multiwoz21': + Dataset.set_active_domains( + ['attraction', 'hotel', 'restaurant', 'taxi', 'train']) + + # Download and preprocess + Dataset.create_examples( + args.max_turn_len, args.predict_intents, args.force_processing) + + # Create logger + global logger + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + if 'stream' not in args.logging_path: + fh = logging.FileHandler(args.logging_path) + fh.setLevel(logging.INFO) + fh.setFormatter(formatter) + logger.addHandler(fh) + else: + ch = logging.StreamHandler() + ch.setLevel(level=logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + # Get device + if torch.cuda.is_available() and args.n_gpu > 0: + device = torch.device('cuda') + else: + device = torch.device('cpu') + args.n_gpu = 0 + + if args.n_gpu == 0: + args.fp16 = False + + # Set up model training/evaluation + calibration_utils.set_logger(logger, None) + calibration_utils.set_seed(args) + + if args.ensemble_size > 0: + ensemble_utils.set_logger(logger, None) + ensemble_utils.set_seed(args) + + # Perform tasks + + if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')): + pred = torch.load(os.path.join( + OUTPUT_DIR, 'predictions', 'test.predictions')) + labels = pred['labels'] + belief_states = pred['belief_states'] + if 'request_labels' in pred: + request_labels = pred['request_labels'] + request_belief = pred['request_belief'] + domain_labels = pred['domain_labels'] + domain_belief = pred['domain_belief'] + greeting_labels = pred['greeting_labels'] + greeting_belief = pred['greeting_belief'] + else: + request_belief = None + del pred + elif args.ensemble_size > 0: + # Get training batch loaders and ontology embeddings + if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')): + test_slots = torch.load(os.path.join( + OUTPUT_DIR, 'database', 'test.db')) + else: + # Create Tokenizer and embedding model for Data Loaders and ontology + encoder = CandidateEncoderModel.from_pretrained( + config.candidate_embedding_model_name) + tokenizer = Tokenizer(config.candidate_embedding_model_name) + embeddings.get_slot_candidate_embeddings( + 'test', args, tokenizer, encoder) + test_slots = torch.load(os.path.join( + OUTPUT_DIR, 'database', 'test.db')) + + exists = False + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): + test_dataloader = torch.load(os.path.join( + OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + if test_dataloader.batch_size == args.test_batch_size: + exists = True + if not exists: + tokenizer = 
Tokenizer(config.candidate_embedding_model_name) + test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len, + config.max_turn_len) + torch.save(test_dataloader, os.path.join( + OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + + config, models = ensemble.get_models( + args.model_name_or_path, device, ConfigClass, SetSumbtModel) + + belief_states, labels = ensemble_utils.get_predictions( + args, models, device, test_dataloader, test_slots) + torch.save({'belief_states': belief_states, 'labels': labels}, + os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')) + else: + # Get training batch loaders and ontology embeddings + if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')): + test_slots = torch.load(os.path.join( + OUTPUT_DIR, 'database', 'test.db')) + else: + # Create Tokenizer and embedding model for Data Loaders and ontology + encoder = CandidateEncoderModel.from_pretrained( + config.candidate_embedding_model_name) + tokenizer = Tokenizer(config.candidate_embedding_model_name) + embeddings.get_slot_candidate_embeddings( + 'test', args, tokenizer, encoder) + test_slots = torch.load(os.path.join( + OUTPUT_DIR, 'database', 'test.db')) + + exists = False + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): + test_dataloader = torch.load(os.path.join( + OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + if test_dataloader.batch_size == args.test_batch_size: + exists = True + if not exists: + tokenizer = Tokenizer(config.candidate_embedding_model_name) + test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len, + config.max_turn_len) + torch.save(test_dataloader, os.path.join( + OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + + # Initialise Model + model = SetSumbtModel.from_pretrained( + args.model_name_or_path, config=config) + model = model.to(device) + + # Get slot and value embeddings + slots = {slot: test_slots[slot] for slot in test_slots} + values = {slot: test_slots[slot][1] for slot in test_slots} + + # Load model ontology + model.add_slot_candidates(slots) + for slot in model.informable_slot_ids: + model.add_value_candidates(slot, values[slot], replace=True) + + belief_states = calibration.get_predictions( + args, model, device, test_dataloader) + belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = belief_states + out = {'belief_states': belief_states, 'labels': labels, + 'request_belief': request_belief, 'request_labels': request_labels, + 'domain_belief': domain_belief, 'domain_labels': domain_labels, + 'greeting_belief': greeting_belief, 'greeting_labels': greeting_labels} + torch.save(out, os.path.join( + OUTPUT_DIR, 'predictions', 'test.predictions')) + + # err = [ece(belief_states[slot].reshape(-1, belief_states[slot].size(-1)), labels[slot].reshape(-1), 10) + # for slot in belief_states] + # err = max(err) + # logger.info('ECE: %f' % err) + + # Calculate calibration metrics + + jg = jg_ece(belief_states, labels, 10) + logger.info('Joint Goal ECE: %f' % jg) + + binary_states = {} + for slot, p in belief_states.items(): + shp = p.shape + p = p.reshape(-1, p.size(-1)) + p_ = torch.ones(p.shape).to(p.device) * 1e-8 + p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8 + binary_states[slot] = p_.reshape(shp) + jg = jg_ece(binary_states, labels, 10) + logger.info('Joint Goal Binary ECE: %f' % jg) + + bs = {slot: torch.cat((p[:, :, 0].unsqueeze(-1), p[:, :, 1:].max(-1) + 
[0].unsqueeze(-1)), -1) for slot, p in belief_states.items()} + ls = {} + for slot, l in labels.items(): + y = torch.zeros((l.size(0), l.size(1))).to(l.device) + dials, turns = torch.where(l > 0) + y[dials, turns] = 1.0 + dials, turns = torch.where(l < 0) + y[dials, turns] = -1.0 + ls[slot] = y + + jg = jg_ece(bs, ls, 10) + logger.info('Slot presence ECE: %f' % jg) + + binary_states = {} + for slot, p in bs.items(): + shp = p.shape + p = p.reshape(-1, p.size(-1)) + p_ = torch.ones(p.shape).to(p.device) * 1e-8 + p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8 + binary_states[slot] = p_.reshape(shp) + jg = jg_ece(binary_states, ls, 10) + logger.info('Slot presence Binary ECE: %f' % jg) + + jg_acc = 0.0 + padding = torch.cat([item.unsqueeze(-1) + for _, item in labels.items()], -1).sum(-1) * -1.0 + padding = (padding == len(labels)) + padding = padding.reshape(-1) + for slot in belief_states: + topn = args.accuracy_topn + p_ = belief_states[slot] + gold = labels[slot] + + if p_.size(-1) <= topn: + topn = p_.size(-1) - 1 + if topn <= 0: + topn = 1 + + if topn > 1: + labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True) + labs = labs[:, :topn] + else: + labs = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1) + acc = [lab in s for lab, s, pad in zip( + gold.reshape(-1), labs, padding) if not pad] + acc = torch.tensor(acc).float() + + jg_acc += acc + + n_turns = jg_acc.size(0) + sl_acc = sum(jg_acc / len(belief_states)).float() + jg_acc = sum((jg_acc / len(belief_states)).int()).float() + + sl_acc /= n_turns + jg_acc /= n_turns + + logger.info('Joint Goal Accuracy: %f, Slot Accuracy: %f' % (jg_acc, sl_acc)) + + l2 = l2_acc(belief_states, labels, remove_belief=False) + logger.info(f'Model L2 Norm Goal Accuracy: {l2}') + l2 = l2_acc(belief_states, labels, remove_belief=True) + logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}') + + for slot in belief_states: + p = belief_states[slot] + p = p.reshape(-1, p.size(-1)) + p = torch.cat( + (p[:, 0].unsqueeze(-1), p[:, 1:].max(-1)[0].unsqueeze(-1)), -1) + belief_states[slot] = p + + l = labels[slot].reshape(-1) + l[l > 0] = 1 + labels[slot] = l + + f1 = 0.0 + for slot in belief_states: + prd = belief_states[slot].argmax(-1) + tp = ((prd == 1) * (labels[slot] == 1)).sum() + fp = ((prd == 1) * (labels[slot] == 0)).sum() + fn = ((prd == 0) * (labels[slot] == 1)).sum() + if tp > 0: + f1 += tp / (tp + 0.5 * (fp + fn)) + f1 /= len(belief_states) + logger.info(f'Truncated Goal F1 Score: {f1}') + + l2 = l2_acc(belief_states, labels, remove_belief=False) + logger.info(f'Model L2 Norm Truncated Goal Accuracy: {l2}') + l2 = l2_acc(belief_states, labels, remove_belief=True) + logger.info(f'Binary Model L2 Norm Truncated Goal Accuracy: {l2}') + + if request_belief is not None: + tp, fp, fn = 0.0, 0.0, 0.0 + for slot in request_belief: + p = request_belief[slot] + l = request_labels[slot] + + tp += (p.round().int() * (l == 1)).reshape(-1).float() + fp += (p.round().int() * (l == 0)).reshape(-1).float() + fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float() + tp /= len(request_belief) + fp /= len(request_belief) + fn /= len(request_belief) + f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum())) + logger.info('Request F1 Score: %f' % f1.item()) + + for slot in request_belief: + p = request_belief[slot] + p = p.unsqueeze(-1) + p = torch.cat((1 - p, p), -1) + request_belief[slot] = p + jg = jg_ece(request_belief, request_labels, 10) + logger.info('Request Joint Goal ECE: %f' % jg) + + binary_states = {} + for slot, p in
request_belief.items(): + shp = p.shape + p = p.reshape(-1, p.size(-1)) + p_ = torch.ones(p.shape).to(p.device) * 1e-8 + p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8 + binary_states[slot] = p_.reshape(shp) + jg = jg_ece(binary_states, request_labels, 10) + logger.info('Request Joint Goal Binary ECE: %f' % jg) + + tp, fp, fn = 0.0, 0.0, 0.0 + for dom in domain_belief: + p = domain_belief[dom] + l = domain_labels[dom] + + tp += (p.round().int() * (l == 1)).reshape(-1).float() + fp += (p.round().int() * (l == 0)).reshape(-1).float() + fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float() + tp /= len(domain_belief) + fp /= len(domain_belief) + fn /= len(domain_belief) + f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum())) + logger.info('Domain F1 Score: %f' % f1.item()) + + for dom in domain_belief: + p = domain_belief[dom] + p = p.unsqueeze(-1) + p = torch.cat((1 - p, p), -1) + domain_belief[dom] = p + jg = jg_ece(domain_belief, domain_labels, 10) + logger.info('Domain Joint Goal ECE: %f' % jg) + + binary_states = {} + for slot, p in domain_belief.items(): + shp = p.shape + p = p.reshape(-1, p.size(-1)) + p_ = torch.ones(p.shape).to(p.device) * 1e-8 + p_[range(p.size(0)), p.argmax(-1)] = 1.0 - 1e-8 + binary_states[slot] = p_.reshape(shp) + jg = jg_ece(binary_states, domain_labels, 10) + logger.info('Domain Joint Goal Binary ECE: %f' % jg) + + tp = ((greeting_belief.argmax(-1) > 0) * + (greeting_labels > 0)).reshape(-1).float().sum() + fp = ((greeting_belief.argmax(-1) > 0) * + (greeting_labels == 0)).reshape(-1).float().sum() + fn = ((greeting_belief.argmax(-1) == 0) * + (greeting_labels > 0)).reshape(-1).float().sum() + f1 = tp / (tp + 0.5 * (fp + fn)) + logger.info('Greeting F1 Score: %f' % f1.item()) + + err = ece(greeting_belief.reshape(-1, greeting_belief.size(-1)), + greeting_labels.reshape(-1), 10) + logger.info('Greetings ECE: %f' % err) + + greeting_belief = greeting_belief.reshape(-1, greeting_belief.size(-1)) + binary_states = torch.ones(greeting_belief.shape).to( + greeting_belief.device) * 1e-8 + binary_states[range(greeting_belief.size(0)), + greeting_belief.argmax(-1)] = 1.0 - 1e-8 + err = ece(binary_states, greeting_labels.reshape(-1), 10) + logger.info('Greetings Binary ECE: %f' % err) + + for slot in request_belief: + p = request_belief[slot].unsqueeze(-1) + request_belief[slot] = torch.cat((1 - p, p), -1) + + l2 = l2_acc(request_belief, request_labels, remove_belief=False) + logger.info(f'Model L2 Norm Request Accuracy: {l2}') + l2 = l2_acc(request_belief, request_labels, remove_belief=True) + logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}') + + for slot in domain_belief: + p = domain_belief[slot].unsqueeze(-1) + domain_belief[slot] = torch.cat((1 - p, p), -1) + + l2 = l2_acc(domain_belief, domain_labels, remove_belief=False) + logger.info(f'Model L2 Norm Domain Accuracy: {l2}') + l2 = l2_acc(domain_belief, domain_labels, remove_belief=True) + logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}') + + greeting_labels = {'bye': greeting_labels} + greeting_belief = {'bye': greeting_belief} + + l2 = l2_acc(greeting_belief, greeting_labels, remove_belief=False) + logger.info(f'Model L2 Norm Greeting Accuracy: {l2}') + l2 = l2_acc(greeting_belief, greeting_labels, remove_belief=True) + logger.info(f'Binary Model L2 Norm Greeting Accuracy: {l2}') + + +if __name__ == "__main__": + main() diff --git a/convlab/dst/setsumbt/do/evaluate.py b/convlab/dst/setsumbt/do/evaluate.py deleted file mode 100644 index
2fe351b3d5c2af187da58ffcc46e8184013bbcdb..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/do/evaluate.py +++ /dev/null @@ -1,296 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Run SetSUMBT Calibration""" - -import logging -import os - -import torch -from transformers import (BertModel, BertConfig, BertTokenizer, - RobertaModel, RobertaConfig, RobertaTokenizer) - -from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT -from convlab.dst.setsumbt.dataset import unified_format -from convlab.dst.setsumbt.dataset import ontology as embeddings -from convlab.dst.setsumbt.utils import get_args, update_args -from convlab.dst.setsumbt.modeling import evaluation_utils -from convlab.dst.setsumbt.loss.uncertainty_measures import ece, jg_ece, l2_acc -from convlab.dst.setsumbt.modeling import training - - -# Available model -MODELS = { - 'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer), - 'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer) -} - - -def main(args=None, config=None): - # Get arguments - if args is None: - args, config = get_args(MODELS) - - if args.model_type in MODELS: - SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type] - else: - raise NameError('NotImplemented') - - # Set up output directory - OUTPUT_DIR = args.output_dir - args.output_dir = OUTPUT_DIR - if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')): - os.mkdir(os.path.join(OUTPUT_DIR, 'predictions')) - - # Set pretrained model path to the trained checkpoint - paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else [] - if 'pytorch_model.bin' in paths and 'config.json' in paths: - args.model_name_or_path = args.output_dir - config = ConfigClass.from_pretrained(args.model_name_or_path) - else: - paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p] - if paths: - paths = paths[0] - args.model_name_or_path = paths - config = ConfigClass.from_pretrained(args.model_name_or_path) - - args = update_args(args, config) - - # Create logger - global logger - logger = logging.getLogger(__name__) - logger.setLevel(logging.INFO) - - formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y') - - fh = logging.FileHandler(args.logging_path) - fh.setLevel(logging.INFO) - fh.setFormatter(formatter) - logger.addHandler(fh) - - # Get device - if torch.cuda.is_available() and args.n_gpu > 0: - device = torch.device('cuda') - else: - device = torch.device('cpu') - args.n_gpu = 0 - - if args.n_gpu == 0: - args.fp16 = False - - # Set up model training/evaluation - evaluation_utils.set_seed(args) - - # Perform tasks - if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')): - pred = torch.load(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')) - state_labels = pred['state_labels'] - belief_states = 
pred['belief_states'] - if 'request_labels' in pred: - request_labels = pred['request_labels'] - request_probs = pred['request_probs'] - active_domain_labels = pred['active_domain_labels'] - active_domain_probs = pred['active_domain_probs'] - general_act_labels = pred['general_act_labels'] - general_act_probs = pred['general_act_probs'] - else: - request_probs = None - del pred - else: - # Get training batch loaders and ontology embeddings - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): - test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - if test_dataloader.batch_size != args.test_batch_size: - test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size) - else: - tokenizer = Tokenizer(config.candidate_embedding_model_name) - test_dataloader = unified_format.get_dataloader(args.dataset, 'test', - args.test_batch_size, tokenizer, args.max_dialogue_len, - config.max_turn_len) - torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')): - test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db')) - else: - encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name) - test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, - 'test', args, tokenizer, encoder) - - # Initialise Model - model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config) - model = model.to(device) - - training.set_ontology_embeddings(model, test_slots) - - belief_states = evaluation_utils.get_predictions(args, model, device, test_dataloader) - state_labels = belief_states[1] - request_probs = belief_states[2] - request_labels = belief_states[3] - active_domain_probs = belief_states[4] - active_domain_labels = belief_states[5] - general_act_probs = belief_states[6] - general_act_labels = belief_states[7] - belief_states = belief_states[0] - out = {'belief_states': belief_states, 'state_labels': state_labels, 'request_probs': request_probs, - 'request_labels': request_labels, 'active_domain_probs': active_domain_probs, - 'active_domain_labels': active_domain_labels, 'general_act_probs': general_act_probs, - 'general_act_labels': general_act_labels} - torch.save(out, os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')) - - # Calculate calibration metrics - jg = jg_ece(belief_states, state_labels, 10) - logger.info('Joint Goal ECE: %f' % jg) - - jg_acc = 0.0 - padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0 - padding = (padding == len(state_labels)) - padding = padding.reshape(-1) - for slot in belief_states: - p_ = belief_states[slot] - gold = state_labels[slot] - - pred = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1) - acc = [lab in s for lab, s, pad in zip(gold.reshape(-1), pred, padding) if not pad] - acc = torch.tensor(acc).float() - - jg_acc += acc - - n_turns = jg_acc.size(0) - jg_acc = sum((jg_acc / len(belief_states)).int()).float() - - jg_acc /= n_turns - - logger.info(f'Joint Goal Accuracy: {jg_acc}') - - l2 = l2_acc(belief_states, state_labels, remove_belief=False) - logger.info(f'Model L2 Norm Goal Accuracy: {l2}') - l2 = l2_acc(belief_states, state_labels, remove_belief=True) - logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}') - - padding = torch.cat([item.unsqueeze(-1) for _, item in state_labels.items()], -1).sum(-1) * -1.0 - padding = (padding == 
len(state_labels)) - padding = padding.reshape(-1) - - tp, fp, fn, tn, n = 0.0, 0.0, 0.0, 0.0, 0.0 - for slot in belief_states: - p_ = belief_states[slot] - gold = state_labels[slot].reshape(-1) - p_ = p_.reshape(-1, p_.size(-1)) - - p_ = p_[~padding].argmax(-1) - gold = gold[~padding] - - tp += (p_ == gold)[gold != 0].int().sum().item() - fp += (p_ != 0)[gold == 0].int().sum().item() - fp += (p_ != gold)[gold != 0].int().sum().item() - fp -= (p_ == 0)[gold != 0].int().sum().item() - fn += (p_ == 0)[gold != 0].int().sum().item() - tn += (p_ == 0)[gold == 0].int().sum().item() - n += p_.size(0) - - acc = (tp + tn) / n - prec = tp / (tp + fp) - rec = tp / (tp + fn) - f1 = 2 * (prec * rec) / (prec + rec) - - logger.info(f"Slot Accuracy: {acc}, Slot F1: {f1}, Slot Precision: {prec}, Slot Recall: {rec}") - - if request_probs is not None: - tp, fp, fn = 0.0, 0.0, 0.0 - for slot in request_probs: - p = request_probs[slot] - l = request_labels[slot] - - tp += (p.round().int() * (l == 1)).reshape(-1).float() - fp += (p.round().int() * (l == 0)).reshape(-1).float() - fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float() - tp /= len(request_probs) - fp /= len(request_probs) - fn /= len(request_probs) - f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum())) - logger.info('Request F1 Score: %f' % f1.item()) - - for slot in request_probs: - p = request_probs[slot] - p = p.unsqueeze(-1) - p = torch.cat((1 - p, p), -1) - request_probs[slot] = p - jg = jg_ece(request_probs, request_labels, 10) - logger.info('Request Joint Goal ECE: %f' % jg) - - tp, fp, fn = 0.0, 0.0, 0.0 - for dom in active_domain_probs: - p = active_domain_probs[dom] - l = active_domain_labels[dom] - - tp += (p.round().int() * (l == 1)).reshape(-1).float() - fp += (p.round().int() * (l == 0)).reshape(-1).float() - fn += ((1 - p.round().int()) * (l == 1)).reshape(-1).float() - tp /= len(active_domain_probs) - fp /= len(active_domain_probs) - fn /= len(active_domain_probs) - f1 = tp.sum() / (tp.sum() + 0.5 * (fp.sum() + fn.sum())) - logger.info('Domain F1 Score: %f' % f1.item()) - - for dom in active_domain_probs: - p = active_domain_probs[dom] - p = p.unsqueeze(-1) - p = torch.cat((1 - p, p), -1) - active_domain_probs[dom] = p - jg = jg_ece(active_domain_probs, active_domain_labels, 10) - logger.info('Domain Joint Goal ECE: %f' % jg) - - tp = ((general_act_probs.argmax(-1) > 0) * - (general_act_labels > 0)).reshape(-1).float().sum() - fp = ((general_act_probs.argmax(-1) > 0) * - (general_act_labels == 0)).reshape(-1).float().sum() - fn = ((general_act_probs.argmax(-1) == 0) * - (general_act_labels > 0)).reshape(-1).float().sum() - f1 = tp / (tp + 0.5 * (fp + fn)) - logger.info('General Act F1 Score: %f' % f1.item()) - - err = ece(general_act_probs.reshape(-1, general_act_probs.size(-1)), - general_act_labels.reshape(-1), 10) - logger.info('General Act ECE: %f' % err) - - for slot in request_probs: - p = request_probs[slot].unsqueeze(-1) - request_probs[slot] = torch.cat((1 - p, p), -1) - - l2 = l2_acc(request_probs, request_labels, remove_belief=False) - logger.info(f'Model L2 Norm Request Accuracy: {l2}') - l2 = l2_acc(request_probs, request_labels, remove_belief=True) - logger.info(f'Binary Model L2 Norm Request Accuracy: {l2}') - - for slot in active_domain_probs: - p = active_domain_probs[slot].unsqueeze(-1) - active_domain_probs[slot] = torch.cat((1 - p, p), -1) - - l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=False) - logger.info(f'Model L2 Norm Domain Accuracy: {l2}') - l2 = 
l2_acc(active_domain_probs, active_domain_labels, remove_belief=True) - logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}') - - general_act_labels = {'general': general_act_labels} - general_act_probs = {'general': general_act_probs} - - l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False) - logger.info(f'Model L2 Norm General Act Accuracy: {l2}') - l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False) - logger.info(f'Binary Model L2 Norm General Act Accuracy: {l2}') - - -if __name__ == "__main__": - main() diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py index e3b4f7fff907fc9bf4880bcc16250f075bee16ca..821dca598c814240f39e359cedec7ef795a341b5 100644 --- a/convlab/dst/setsumbt/do/nbt.py +++ b/convlab/dst/setsumbt/do/nbt.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,23 +16,33 @@ """Run SetSUMBT training/eval""" import logging +import random import os from shutil import copy2 as copy -import json import torch +from torch.nn import DataParallel from transformers import (BertModel, BertConfig, BertTokenizer, - RobertaModel, RobertaConfig, RobertaTokenizer) + RobertaModel, RobertaConfig, RobertaTokenizer, + AdamW, get_linear_schedule_with_warmup) +from tqdm import tqdm, trange +import numpy as np from tensorboardX import SummaryWriter -from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT -from convlab.dst.setsumbt.dataset import unified_format +from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT +from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT +from convlab.dst.setsumbt.multiwoz import multiwoz21 from convlab.dst.setsumbt.modeling import training -from convlab.dst.setsumbt.dataset import ontology as embeddings +from convlab.dst.setsumbt.multiwoz import ontology as embeddings from convlab.dst.setsumbt.utils import get_args, update_args +from convlab.dst.setsumbt.modeling import ensemble_utils -# Available model +# Datasets +DATASETS = { + 'multiwoz21': multiwoz21 +} + MODELS = { 'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer), 'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer) @@ -44,6 +54,12 @@ def main(args=None, config=None): if args is None: args, config = get_args(MODELS) + # Select Dataset object + if args.dataset in DATASETS: + Dataset = DATASETS[args.dataset] + else: + raise NameError('NotImplemented') + if args.model_type in MODELS: SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type] else: @@ -58,19 +74,53 @@ def main(args=None, config=None): args.output_dir = OUTPUT_DIR # Set pretrained model path to the trained checkpoint - paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else [] - if 'pytorch_model.bin' in paths and 'config.json' in paths: - args.model_name_or_path = args.output_dir - config = ConfigClass.from_pretrained(args.model_name_or_path) - else: - paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p] + if args.do_train: + paths = os.listdir(args.output_dir) if os.path.exists( + args.output_dir) else [] + paths = [os.path.join(args.output_dir, p) + for p in paths if 'checkpoint-' in p] if paths: paths = paths[0] args.model_name_or_path = paths config = 
ConfigClass.from_pretrained(args.model_name_or_path) + else: + paths = os.listdir(args.output_dir) if os.path.exists( + args.output_dir) else [] + if 'pytorch_model.bin' in paths and 'config.json' in paths: + args.model_name_or_path = args.output_dir + config = ConfigClass.from_pretrained(args.model_name_or_path) + else: + paths = os.listdir(args.output_dir) if os.path.exists( + args.output_dir) else [] + if 'pytorch_model.bin' in paths and 'config.json' in paths: + args.model_name_or_path = args.output_dir + config = ConfigClass.from_pretrained(args.model_name_or_path) + else: + paths = os.listdir(args.output_dir) if os.path.exists( + args.output_dir) else [] + paths = [os.path.join(args.output_dir, p) + for p in paths if 'checkpoint-' in p] + if paths: + paths = paths[0] + args.model_name_or_path = paths + config = ConfigClass.from_pretrained(args.model_name_or_path) args = update_args(args, config) + # Set up data directory + DATA_DIR = args.data_dir + Dataset.set_datadir(DATA_DIR) + embeddings.set_datadir(DATA_DIR) + + # If use shrinked domains, remove bus and hospital domains from the training data and model ontology + if args.shrink_active_domains and args.dataset == 'multiwoz21': + Dataset.set_active_domains( + ['attraction', 'hotel', 'restaurant', 'taxi', 'train']) + + # Download and preprocess + Dataset.create_examples( + args.max_turn_len, args.predict_actions, args.force_processing) + # Create TensorboardX writer tb_writer = SummaryWriter(logdir=args.tensorboard_path) @@ -79,12 +129,19 @@ def main(args=None, config=None): logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) - formatter = logging.Formatter('%(asctime)s - %(message)s', '%H:%M %m-%d-%y') + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') - fh = logging.FileHandler(args.logging_path) - fh.setLevel(logging.INFO) - fh.setFormatter(formatter) - logger.addHandler(fh) + if 'stream' not in args.logging_path: + fh = logging.FileHandler(args.logging_path) + fh.setLevel(logging.INFO) + fh.setFormatter(formatter) + logger.addHandler(fh) + else: + ch = logging.StreamHandler() + ch.setLevel(level=logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) # Get device if torch.cuda.is_available() and args.n_gpu > 0: @@ -97,11 +154,14 @@ def main(args=None, config=None): args.fp16 = False # Initialise Model - model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config) + model = SetSumbtModel.from_pretrained( + args.model_name_or_path, config=config) model = model.to(device) # Create Tokenizer and embedding model for Data Loaders and ontology - encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name) + encoder = model.roberta if args.model_type == 'roberta' else None + encoder = model.bert if args.model_type == 'bert' else encoder + tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config) # Set up model training/evaluation @@ -110,62 +170,39 @@ def main(args=None, config=None): embeddings.set_seed(args) if args.ensemble_size > 1: + ensemble_utils.set_logger(logger, tb_writer) + ensemble.set_seed(args) logger.info('Building %i resampled dataloaders each of size %i' % (args.ensemble_size, args.data_sampling_size)) - dataloaders = [unified_format.get_dataloader(args.dataset, - 'train', - args.train_batch_size, - tokenizer, - args.max_dialogue_len, - args.max_turn_len, - resampled_size=args.data_sampling_size, - train_ratio=args.dataset_train_ratio, - seed=args.seed) - for _ in range(args.ensemble_size)] + 
dataloaders = ensemble_utils.build_train_loaders(args, tokenizer, Dataset) logger.info('Dataloaders built.') - for i, loader in enumerate(dataloaders): - path = os.path.join(OUTPUT_DIR, 'ens-%i' % i) + path = os.path.join(OUTPUT_DIR, 'ensemble-%i' % i) if not os.path.exists(path): os.mkdir(path) path = os.path.join(path, 'train.dataloader') torch.save(loader, path) logger.info('Dataloaders saved.') - train_dataloader = unified_format.get_dataloader(args.dataset, - 'train', - args.train_batch_size, - tokenizer, - args.max_dialogue_len, - args.max_turn_len, - resampled_size=args.data_sampling_size, - train_ratio=args.dataset_train_ratio, - seed=args.seed) - torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) - dev_dataloader = unified_format.get_dataloader(args.dataset, - 'validation', - args.train_batch_size, - tokenizer, - args.max_dialogue_len, - args.max_turn_len, - resampled_size=args.data_sampling_size, - train_ratio=args.dataset_train_ratio, - seed=args.seed) - torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - test_dataloader = unified_format.get_dataloader(args.dataset, - 'test', - args.train_batch_size, - tokenizer, - args.max_dialogue_len, - args.max_turn_len, - resampled_size=args.data_sampling_size, - train_ratio=args.dataset_train_ratio, - seed=args.seed) - torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - - embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology, 'train', args, tokenizer, encoder) - embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, 'dev', args, tokenizer, encoder) - embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, 'test', args, tokenizer, encoder) + train_slots = embeddings.get_slot_candidate_embeddings( + 'train', args, tokenizer, encoder) + dev_slots = embeddings.get_slot_candidate_embeddings( + 'dev', args, tokenizer, encoder) + test_slots = embeddings.get_slot_candidate_embeddings( + 'test', args, tokenizer, encoder) + + train_dataloader = Dataset.get_dataloader( + 'train', args.train_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len) + torch.save(train_dataloader, os.path.join( + OUTPUT_DIR, 'dataloaders', 'train.dataloader')) + dev_dataloader = Dataset.get_dataloader( + 'dev', args.dev_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len) + torch.save(dev_dataloader, os.path.join( + OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + test_dataloader = Dataset.get_dataloader( + 'test', args.test_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len) + torch.save(test_dataloader, os.path.join( + OUTPUT_DIR, 'dataloaders', 'test.dataloader')) # Do not perform standard training after ensemble setup is created return 0 @@ -173,51 +210,47 @@ def main(args=None, config=None): # Perform tasks # TRAINING if args.do_train: - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')): - train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) - if train_dataloader.batch_size != args.train_batch_size: - train_dataloader = unified_format.change_batch_size(train_dataloader, args.train_batch_size) - else: - if args.data_sampling_size <= 0: - args.data_sampling_size = None - train_dataloader = unified_format.get_dataloader(args.dataset, - 'train', - args.train_batch_size, - tokenizer, - args.max_dialogue_len, - config.max_turn_len, - resampled_size=args.data_sampling_size, -
train_ratio=args.dataset_train_ratio, - seed=args.seed) - torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')) - # Get training batch loaders and ontology embeddings if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')): - train_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'train.db')) + train_slots = torch.load(os.path.join( + OUTPUT_DIR, 'database', 'train.db')) else: - train_slots = embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology, - 'train', args, tokenizer, encoder) + train_slots = embeddings.get_slot_candidate_embeddings( + 'train', args, tokenizer, encoder) + if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')): + dev_slots = torch.load(os.path.join( + OUTPUT_DIR, 'database', 'dev.db')) + else: + dev_slots = embeddings.get_slot_candidate_embeddings( + 'dev', args, tokenizer, encoder) + + exists = False + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')): + train_dataloader = torch.load(os.path.join( + OUTPUT_DIR, 'dataloaders', 'train.dataloader')) + if train_dataloader.batch_size == args.train_batch_size: + exists = True + if not exists: + if args.data_sampling_size <= 0: + args.data_sampling_size = None + train_dataloader = Dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len, + config.max_turn_len, resampled_size=args.data_sampling_size) + torch.save(train_dataloader, os.path.join( + OUTPUT_DIR, 'dataloaders', 'train.dataloader')) # Get development set batch loaders= and ontology embeddings if args.do_eval: + exists = False if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')): - dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - if dev_dataloader.batch_size != args.dev_batch_size: - dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size) - else: - dev_dataloader = unified_format.get_dataloader(args.dataset, - 'validation', - args.dev_batch_size, - tokenizer, - args.max_dialogue_len, - config.max_turn_len) - torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - - if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')): - dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db')) - else: - dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, - 'dev', args, tokenizer, encoder) + dev_dataloader = torch.load(os.path.join( + OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + if dev_dataloader.batch_size == args.dev_batch_size: + exists = True + if not exists: + dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len, + config.max_turn_len) + torch.save(dev_dataloader, os.path.join( + OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) else: dev_dataloader = None dev_slots = None @@ -226,80 +259,94 @@ def main(args=None, config=None): training.set_ontology_embeddings(model, train_slots) # TRAINING !!!!!!!!!!!!!!!!!! 
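# --- Illustrative aside, not part of the patch ------------------------------
# Every dataloader above follows the same cache-or-rebuild pattern: reuse the
# saved dataloader only when its batch size still matches, otherwise rebuild
# it and overwrite the cache. A minimal sketch of that pattern, assuming a
# Dataset module exposing get_dataloader as in this file; the helper name
# load_or_build_dataloader is hypothetical.
import os
import torch

def load_or_build_dataloader(output_dir, split, batch_size, build_fn):
    """Return the cached dataloader if its batch size matches, else rebuild and cache it."""
    path = os.path.join(output_dir, 'dataloaders', f'{split}.dataloader')
    if os.path.exists(path):
        loader = torch.load(path)
        if loader.batch_size == batch_size:
            return loader
    loader = build_fn()  # e.g. Dataset.get_dataloader(split, batch_size, ...)
    torch.save(loader, path)
    return loader
# -----------------------------------------------------------------------------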
- training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots) + training.train(args, model, device, train_dataloader, dev_dataloader, train_slots, dev_slots, + embeddings=embeddings, tokenizer=tokenizer) # Copy final best model to the output dir checkpoints = os.listdir(OUTPUT_DIR) checkpoints = [p for p in checkpoints if 'checkpoint' in p] checkpoints = sorted([int(p.split('-')[-1]) for p in checkpoints]) - best_checkpoint = os.path.join(OUTPUT_DIR, f'checkpoint-{checkpoints[-1]}') - copy(os.path.join(best_checkpoint, 'pytorch_model.bin'), os.path.join(OUTPUT_DIR, 'pytorch_model.bin')) - copy(os.path.join(best_checkpoint, 'config.json'), os.path.join(OUTPUT_DIR, 'config.json')) + best_checkpoint = checkpoints[-1] + best_checkpoint = os.path.join( + OUTPUT_DIR, f'checkpoint-{best_checkpoint}') + copy(os.path.join(best_checkpoint, 'pytorch_model.bin'), + os.path.join(OUTPUT_DIR, 'pytorch_model.bin')) + copy(os.path.join(best_checkpoint, 'config.json'), + os.path.join(OUTPUT_DIR, 'config.json')) # Load best model for evaluation - model = SetSumbtModel.from_pretrained(OUTPUT_DIR) + model = SetSumbtModel.from_pretrained(OUTPUT_DIR) model = model.to(device) # Evaluation on the development set if args.do_eval: - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')): - dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - if dev_dataloader.batch_size != args.dev_batch_size: - dev_dataloader = unified_format.change_batch_size(dev_dataloader, args.dev_batch_size) - else: - dev_dataloader = unified_format.get_dataloader(args.dataset, - 'validation', - args.dev_batch_size, - tokenizer, - args.max_dialogue_len, - config.max_turn_len) - torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) - + # Get development set batch loaders and ontology embeddings if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')): - dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db')) + dev_slots = torch.load(os.path.join( + OUTPUT_DIR, 'database', 'dev.db')) else: - dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, - 'dev', args, tokenizer, encoder) + dev_slots = embeddings.get_slot_candidate_embeddings( + 'dev', args, tokenizer, encoder) + + exists = False + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')): + dev_dataloader = torch.load(os.path.join( + OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) + if dev_dataloader.batch_size == args.dev_batch_size: + exists = True + if not exists: + dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len, + config.max_turn_len) + torch.save(dev_dataloader, os.path.join( + OUTPUT_DIR, 'dataloaders', 'dev.dataloader')) # Load model ontology training.set_ontology_embeddings(model, dev_slots) # EVALUATION - jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss = training.evaluate(args, model, device, dev_dataloader) - training.log_info('dev', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1) + jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate( + args, model, device, dev_dataloader) + if req_f1: + logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f' + % (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1)) + else: + logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f' + % (loss, jg_acc, sl_acc)) # Evaluation on the test set if args.do_test: - if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): - test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - if test_dataloader.batch_size != args.test_batch_size: - test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size) - else: - test_dataloader = unified_format.get_dataloader(args.dataset, 'test', - args.test_batch_size, tokenizer, args.max_dialogue_len, - config.max_turn_len) - torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')) - + # Get test set batch loaders and ontology embeddings if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')): - test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db')) + test_slots = torch.load(os.path.join( + OUTPUT_DIR, 'database', 'test.db')) else: - test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, - 'test', args, tokenizer, encoder) + test_slots = embeddings.get_slot_candidate_embeddings( + 'test', args, tokenizer, encoder) + + exists = False + if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')): + test_dataloader = torch.load(os.path.join( + OUTPUT_DIR, 'dataloaders', 'test.dataloader')) + if test_dataloader.batch_size == args.test_batch_size: + exists = True + if not exists: + test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len, + config.max_turn_len) + torch.save(test_dataloader, os.path.join( + OUTPUT_DIR, 'dataloaders', 'test.dataloader')) # Load model ontology training.set_ontology_embeddings(model, test_slots) # TESTING - jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, output = training.evaluate(args, model, device, test_dataloader, - return_eval_output=True) - - if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')): - os.mkdir(os.path.join(OUTPUT_DIR, 'predictions')) - writer = open(os.path.join(OUTPUT_DIR, 'predictions', 'test.json'), 'w') - json.dump(output, writer) - writer.close() - - training.log_info('test', loss, jg_acc, sl_acc, req_f1, dom_f1, gen_f1) + jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate( + args, model, device, test_dataloader) + if req_f1: + logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f' + % (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1)) + else: + logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f' + % (loss, jg_acc, sl_acc)) tb_writer.close() diff --git a/convlab/dst/setsumbt/loss/__init__.py b/convlab/dst/setsumbt/loss/__init__.py deleted file mode 100644 index 475f7646126ea03b630efcbbc688f86c5a8ec16e..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/loss/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from convlab.dst.setsumbt.loss.bayesian_matching import BayesianMatchingLoss, BinaryBayesianMatchingLoss -from convlab.dst.setsumbt.loss.kl_distillation import KLDistillationLoss, BinaryKLDistillationLoss -from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss -from convlab.dst.setsumbt.loss.endd_loss import RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss diff --git a/convlab/dst/setsumbt/loss/bayesian.py b/convlab/dst/setsumbt/loss/bayesian.py new file mode 100644 index 0000000000000000000000000000000000000000..e52d8d07733383c7f95b6825b4ab5e5e1c7a0977 --- /dev/null +++ b/convlab/dst/setsumbt/loss/bayesian.py @@ -0,0 +1,144 @@ +# -*-
coding: utf-8 -*- +# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Bayesian Matching Activation and Loss Functions""" + +import torch +from torch import digamma, lgamma +from torch.nn import Module + + +# Inverse Linear activation function +def invlinear(x): + z = (1.0 / (1.0 - x)) * (x < 0) + z += (1.0 + x) * (x >= 0) + return z + +# Exponential activation function +def exponential(x): + return torch.exp(x) + + +# Dirichlet activation function for the model +def dirichlet(a): + p = exponential(a) + repeat_dim = (1,)*(len(p.shape)-1) + (p.size(-1),) + p = p / p.sum(-1).unsqueeze(-1).repeat(repeat_dim) + return p + + +# Pytorch BayesianMatchingLoss nn.Module +class BayesianMatchingLoss(Module): + + def __init__(self, lamb=0.01, ignore_index=-1): + super(BayesianMatchingLoss, self).__init__() + + self.lamb = lamb + self.ignore_index = ignore_index + + def forward(self, alpha, labels, prior=None): + # Assert input sizes + assert alpha.dim() == 2 # Observations, predictive distribution + assert labels.dim() == 1 # Label for each observation + assert labels.size(0) == alpha.size(0) # Equal number of observation + + # Confirm predictive distribution dimension + if labels.max() <= alpha.size(-1): + dimension = alpha.size(-1) + else: + raise NameError('Label dimension %i is larger than prediction dimension %i.' 
% (labels.max(), alpha.size(-1))) + + # Remove observations with no labels + if prior is not None: + prior = prior[labels != self.ignore_index] + alpha = exponential(alpha[labels != self.ignore_index]) + labels = labels[labels != self.ignore_index] + + # Initialise and reshape prior parameters + if prior is None: + prior = torch.ones(dimension) + prior = prior.to(alpha.device) + + # KL divergence term + lb = lgamma(alpha.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(alpha)).sum(-1) + e = digamma(alpha) - digamma(alpha.sum(-1)).unsqueeze(-1).repeat((1, alpha.size(-1))) + e = ((alpha - prior) * e).sum(-1) + kl = lb + e + kl *= self.lamb + del lb, e, prior + + # Expected log likelihood + expected_likelihood = digamma(alpha[range(labels.size(0)), labels]) - digamma(alpha.sum(1)) + del alpha, labels + + # Apply ELBO loss and mean reduction + loss = (kl - expected_likelihood).mean() + del kl, expected_likelihood + + return loss + + +# Pytorch BayesianMatchingLoss nn.Module +class BinaryBayesianMatchingLoss(Module): + + def __init__(self, lamb=0.01, ignore_index=-1): + super(BinaryBayesianMatchingLoss, self).__init__() + + self.lamb = lamb + self.ignore_index = ignore_index + + def forward(self, alpha, labels, prior=None): + # Assert input sizes + assert alpha.dim() == 1 # Observations, predictive distribution + assert labels.dim() == 1 # Label for each observation + assert labels.size(0) == alpha.size(0) # Equal number of observation + + # Confirm predictive distribution dimension + if labels.max() <= 2: + dimension = 2 + else: + raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.max(), alpha.size(-1))) + + # Remove observations with no labels + if prior is not None: + prior = prior[labels != self.ignore_index] + alpha = alpha[labels != self.ignore_index] + alpha_sum = 1 + (1 / self.lamb) + alpha = (torch.sigmoid(alpha) * alpha_sum).reshape(-1, 1) + alpha = torch.cat((alpha_sum - alpha, alpha), 1) + labels = labels[labels != self.ignore_index] + + # Initialise and reshape prior parameters + if prior is None: + prior = torch.ones(dimension) + prior = prior.to(alpha.device) + + # KL divergence term + lb = lgamma(alpha.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(alpha)).sum(-1) + e = digamma(alpha) - digamma(alpha.sum(-1)).unsqueeze(-1).repeat((1, alpha.size(-1))) + e = ((alpha - prior) * e).sum(-1) + kl = lb + e + kl *= self.lamb + del lb, e, prior + + # Expected log likelihood + expected_likelihood = digamma(alpha[range(labels.size(0)), labels.long()]) - digamma(alpha.sum(1)) + del alpha, labels + + # Apply ELBO loss and mean reduction + loss = (kl - expected_likelihood).mean() + del kl, expected_likelihood + + return loss diff --git a/convlab/dst/setsumbt/loss/bayesian_matching.py b/convlab/dst/setsumbt/loss/bayesian_matching.py deleted file mode 100644 index 3e91444d60afeeb6e2ca54192dd2283810fc5135..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/loss/bayesian_matching.py +++ /dev/null @@ -1,115 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
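# --- Illustrative aside, not part of the patch ------------------------------
# Minimal usage sketch for the BayesianMatchingLoss added in loss/bayesian.py
# above. The shapes are assumptions for illustration: eight observations over
# a five-way label set. Raw logits are passed in, since the loss applies the
# exponential activation internally to obtain Dirichlet parameters; labels
# equal to ignore_index (-1) are dropped, and a uniform Dirichlet prior is
# used when none is supplied.
import torch
from convlab.dst.setsumbt.loss.bayesian import BayesianMatchingLoss

criterion = BayesianMatchingLoss(lamb=0.01, ignore_index=-1)
logits = torch.randn(8, 5, requires_grad=True)  # model outputs per observation
labels = torch.randint(0, 5, (8,))              # gold label index per observation
loss = criterion(logits, labels)                # ELBO: KL term minus expected log likelihood
loss.backward()
# -----------------------------------------------------------------------------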
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Bayesian Matching Activation and Loss Functions (see https://arxiv.org/pdf/2002.07965.pdf for details)""" - -import torch -from torch import digamma, lgamma -from torch.nn import Module - - -class BayesianMatchingLoss(Module): - """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation""" - - def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module: - """ - Args: - lamb (float): Weighting factor for the KL Divergence component - ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. - """ - super(BayesianMatchingLoss, self).__init__() - - self.lamb = lamb - self.ignore_index = ignore_index - - def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor: - """ - Args: - inputs (Tensor): Predictive distribution - labels (Tensor): Label indices - prior (Tensor): Prior distribution over label classes - - Returns: - loss (Tensor): Loss value - """ - # Assert input sizes - assert inputs.dim() == 2 # Observations, predictive distribution - assert labels.dim() == 1 # Label for each observation - assert labels.size(0) == inputs.size(0) # Equal number of observation - - # Confirm predictive distribution dimension - if labels.max() <= inputs.size(-1): - dimension = inputs.size(-1) - else: - raise NameError(f'Label dimension {labels.max()} is larger than prediction dimension {inputs.size(-1)}.') - - # Remove observations to be ignored in loss calculation - if prior is not None: - prior = prior[labels != self.ignore_index] - inputs = torch.exp(inputs[labels != self.ignore_index]) - labels = labels[labels != self.ignore_index] - - # Initialise and reshape prior parameters - if prior is None: - prior = torch.ones(dimension).to(inputs.device) - prior = prior.to(inputs.device) - - # KL divergence term (divergence of predictive distribution from prior over label classes - regularisation term) - log_gamma_term = lgamma(inputs.sum(-1)) - lgamma(prior.sum(-1)) + (lgamma(prior) - lgamma(inputs)).sum(-1) - div_term = digamma(inputs) - digamma(inputs.sum(-1)).unsqueeze(-1).repeat((1, inputs.size(-1))) - div_term = ((inputs - prior) * div_term).sum(-1) - kl_term = log_gamma_term + div_term - kl_term *= self.lamb - del log_gamma_term, div_term, prior - - # Expected log likelihood - expected_likelihood = digamma(inputs[range(labels.size(0)), labels]) - digamma(inputs.sum(-1)) - del inputs, labels - - # Apply ELBO loss and mean reduction - loss = (kl_term - expected_likelihood).mean() - del kl_term, expected_likelihood - - return loss - - -class BinaryBayesianMatchingLoss(BayesianMatchingLoss): - """Bayesian matching loss (https://arxiv.org/pdf/2002.07965.pdf) implementation""" - - def __init__(self, lamb: float = 0.001, ignore_index: int = -1) -> Module: - """ - Args: - lamb (float): Weighting factor for the KL Divergence component - ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. 
- """ - super(BinaryBayesianMatchingLoss, self).__init__(lamb, ignore_index) - - def forward(self, inputs: torch.Tensor, labels: torch.Tensor, prior: torch.Tensor = None) -> torch.Tensor: - """ - Args: - inputs (Tensor): Predictive distribution - labels (Tensor): Label indices - prior (Tensor): Prior distribution over label classes - - Returns: - loss (Tensor): Loss value - """ - - # Create 2D input dirichlet distribution - input_sum = 1 + (1 / self.lamb) - inputs = (torch.sigmoid(inputs) * input_sum).reshape(-1, 1) - inputs = torch.cat((input_sum - inputs, inputs), 1) - - return super().forward(inputs, labels, prior=prior) diff --git a/convlab/dst/setsumbt/loss/distillation.py b/convlab/dst/setsumbt/loss/distillation.py new file mode 100644 index 0000000000000000000000000000000000000000..3cf13f10635376467f3b137adaa83367f26603ef --- /dev/null +++ b/convlab/dst/setsumbt/loss/distillation.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Bayesian Matching Activation and Loss Functions""" + +import torch +from torch import lgamma, log +from torch.nn import Module +from torch.nn.functional import kl_div + +from convlab.dst.setsumbt.loss.bayesian import BayesianMatchingLoss + + +# Pytorch BayesianMatchingLoss nn.Module +class DistillationKL(Module): + + def __init__(self, lamb=1e-4, ignore_index=-1): + super(DistillationKL, self).__init__() + + self.lamb = lamb + self.ignore_index = ignore_index + + def forward(self, alpha, labels, temp=1.0): + # Assert input sizes + assert alpha.dim() == 2 # Observations, predictive distribution + assert labels.dim() == 2 # Label for each observation + assert labels.size(0) == alpha.size(0) # Equal number of observation + + # Confirm predictive distribution dimension + if labels.size(-1) == alpha.size(-1): + dimension = alpha.size(-1) + else: + raise NameError('Label dimension %i is larger than prediction dimension %i.' 
% (labels.size(-1), alpha.size(-1))) + + alpha = torch.log(torch.softmax(alpha / temp, -1)) + ids = torch.where(labels[:, 0] != self.ignore_index)[0] + alpha = alpha[ids] + labels = labels[ids] + + labels = ((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1))) + + kl = kl_div(alpha, labels, reduction='none').sum(-1).mean() + return kl + + +# Pytorch BayesianMatchingLoss nn.Module +class BinaryDistillationKL(Module): + + def __init__(self, lamb=1e-4, ignore_index=-1): + super(BinaryDistillationKL, self).__init__() + + self.lamb = lamb + self.ignore_index = ignore_index + + def forward(self, alpha, labels, temp=1.0): + # Assert input sizes + assert alpha.dim() == 1 # Observations, predictive distribution + assert labels.dim() == 1 # Label for each observation + assert labels.size(0) == alpha.size(0) # Equal number of observation + + # Confirm predictive distribution dimension + # if labels.size(-1) == alpha.size(-1): + # dimension = alpha.size(-1) + # else: + # raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1))) + + alpha = torch.sigmoid(alpha / temp).unsqueeze(-1) + ids = torch.where(labels != self.ignore_index)[0] + alpha = alpha[ids] + labels = labels[ids] + + alpha = torch.log(torch.cat((1 - alpha, alpha), 1)) + + labels = labels.unsqueeze(-1) + labels = torch.cat((1 - labels, labels), -1) + labels = ((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1))) + + kl = kl_div(alpha, labels, reduction='none').sum(-1).mean() + return kl + + +# def smart_sort(x, permutation): +# assert x.dim() == permutation.dim() +# if x.dim() == 3: +# d1, d2, d3 = x.size() +# ret = x[torch.arange(d1).unsqueeze(-1).unsqueeze(-1).repeat((1, d2, d3)).flatten(), +# torch.arange(d2).unsqueeze(0).unsqueeze(-1).repeat((d1, 1, d3)).flatten(), +# permutation.flatten()].view(d1, d2, d3) +# return ret +# elif x.dim() == 2: +# d1, d2 = x.size() +# ret = x[torch.arange(d1).unsqueeze(-1).repeat((1, d2)).flatten(), +# permutation.flatten()].view(d1, d2) +# return ret + + +# # Pytorch BayesianMatchingLoss nn.Module +# class DistillationNLL(Module): + +# def __init__(self, lamb=1e-4, ignore_index=-1): +# super(DistillationNLL, self).__init__() + +# self.lamb = lamb +# self.ignore_index = ignore_index +# self.loss_add = BayesianMatchingLoss(lamb=0.001) + +# def forward(self, alpha, labels, temp=1.0): +# # Assert input sizes +# assert alpha.dim() == 2 # Observations, predictive distribution +# assert labels.dim() == 3 # Label for each observation +# assert labels.size(0) == alpha.size(0) # Equal number of observation + +# # Confirm predictive distribution dimension +# if labels.size(-1) == alpha.size(-1): +# dimension = alpha.size(-1) +# else: +# raise NameError('Label dimension %i is larger than prediction dimension %i.'
% (labels.size(-1), alpha.size(-1))) + +# alpha = torch.exp(alpha / temp) +# ids = torch.where(labels[:, 0, 0] != self.ignore_index)[0] +# alpha = alpha[ids] +# labels = labels[ids] + +# best_labels = labels.mean(-2).argmax(-1) +# loss2 = self.loss_add(alpha, best_labels) + +# topn = labels.mean(-2).argsort(-1, descending=True) +# n = 10 +# alpha = smart_sort(alpha, topn)[:, :n] +# labels = smart_sort(labels, topn.unsqueeze(-2).repeat((1, labels.size(-2), 1))) +# labels = labels[:, :, :n] +# labels = labels / labels.sum(-1).unsqueeze(-1).repeat((1, 1, labels.size(-1))) + +# labels = log(((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1)))) + +# loss = (alpha - 1) * labels.mean(-2) +# # loss = (alpha - 1) * labels +# loss = lgamma(alpha.sum(-1)) - lgamma(alpha).sum(-1) + loss.sum(-1) +# loss = -1.0 * loss.mean() +# # loss = -1.0 * loss.mean() / alpha.size(-1) + +# return loss + + +# # Pytorch BayesianMatchingLoss nn.Module +# class BinaryDistillationNLL(Module): + +# def __init__(self, lamb=1e-4, ignore_index=-1): +# super(BinaryDistillationNLL, self).__init__() + +# self.lamb = lamb +# self.ignore_index = ignore_index + +# def forward(self, alpha, labels, temp=0.0): +# # Assert input sizes +# assert alpha.dim() == 1 # Observations, predictive distribution +# assert labels.dim() == 2 # Label for each observation +# assert labels.size(0) == alpha.size(0) # Equal number of observation + +# # Confirm predictive distribution dimension +# # if labels.size(-1) == alpha.size(-1): +# # dimension = alpha.size(-1) +# # else: +# # raise NameError('Label dimension %i is larger than prediction dimension %i.' % (labels.size(-1), alpha.size(-1))) + +# # Remove observations with no labels +# ids = torch.where(labels[:, 0] != self.ignore_index)[0] +# # alpha_sum = 1 + (1 / self.lamb) +# alpha_sum = 10.0 +# alpha = (torch.sigmoid(alpha) * alpha_sum).reshape(-1, 1) +# alpha = alpha[ids] +# labels = labels[ids] + +# if temp != 1.0: +# alpha = torch.log(alpha + 1e-4) +# alpha = torch.exp(alpha / temp) + +# alpha = torch.cat((alpha_sum - alpha, alpha), 1) + +# labels = labels.unsqueeze(-1) +# labels = torch.cat((1 - labels, labels), -1) +# # labels[labels[:, 0, 0] == self.ignore_index] = 1 +# labels = log(((1 - self.lamb) * labels) + (self.lamb * (1 / labels.size(-1)))) + +# loss = (alpha - 1) * labels.mean(-2) +# loss = lgamma(alpha.sum(-1)) - lgamma(alpha).sum(-1) + loss.sum(-1) +# loss = -1.0 * loss.mean() + +# return loss diff --git a/convlab/dst/setsumbt/loss/uncertainty_measures.py b/convlab/dst/setsumbt/loss/ece.py similarity index 50% rename from convlab/dst/setsumbt/loss/uncertainty_measures.py rename to convlab/dst/setsumbt/loss/ece.py index 87c89dd31c724cc7d599230c6d4a15faee9b680e..034b9aa0bf5882aea49b08a64d7f93164208b5d9 100644 --- a/convlab/dst/setsumbt/loss/uncertainty_measures.py +++ b/convlab/dst/setsumbt/loss/ece.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,24 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
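# --- Illustrative aside, not part of the patch ------------------------------
# Minimal usage sketch for the DistillationKL loss added in loss/distillation.py
# above. The setup is an assumption for illustration: a student's logits over
# five classes are matched, at temperature 1.0, against the ensemble's mean
# predictive distribution acting as soft labels. Rows whose first label entry
# equals ignore_index are skipped, and the soft labels are smoothed by lamb
# before the KL divergence is taken.
import torch
from convlab.dst.setsumbt.loss.distillation import DistillationKL

criterion = DistillationKL(lamb=1e-4, ignore_index=-1)
student_logits = torch.randn(8, 5, requires_grad=True)
teacher_probs = torch.softmax(torch.randn(8, 5), dim=-1)  # e.g. mean over ensemble members
loss = criterion(student_logits, teacher_probs, temp=1.0)
loss.backward()
# -----------------------------------------------------------------------------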
-"""Uncertainty evaluation metrics for dialogue belief tracking""" +"""Expected calibration error""" import torch -def fill_bins(n_bins: int, probs: torch.Tensor) -> list: - """ - Function to split observations into bins based on predictive probabilities - - Args: - n_bins (int): Number of bins - probs (Tensor): Predictive probabilities for the observations - - Returns: - bins (list): List of observation ids for each bin - """ - assert probs.dim() == 2 - probs = probs.max(-1)[0] +def fill_bins(n_bins, logits): + assert logits.dim() == 2 + logits = logits.max(-1)[0] step = 1.0 / n_bins bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step) @@ -38,49 +28,29 @@ def fill_bins(n_bins: int, probs: torch.Tensor) -> list: for b in range(n_bins): lower, upper = bin_ranges[b], bin_ranges[b + 1] if b == 0: - ids = torch.where((probs >= lower) * (probs <= upper))[0] + ids = torch.where((logits >= lower) * (logits <= upper))[0] else: - ids = torch.where((probs > lower) * (probs <= upper))[0] + ids = torch.where((logits > lower) * (logits <= upper))[0] bins.append(ids) return bins -def bin_confidence(bins: list, probs: torch.Tensor) -> torch.Tensor: - """ - Compute the confidence score within each bin - - Args: - bins (list): List of observation ids for each bin - probs (Tensor): Predictive probabilities for the observations - - Returns: - scores (Tensor): Average confidence score within each bin - """ - probs = probs.max(-1)[0] +def bin_confidence(bins, logits): + logits = logits.max(-1)[0] scores = [] for b in bins: if b is not None: - scores.append(probs[b].mean()) + l = logits[b] + scores.append(l.mean()) else: scores.append(-1) scores = torch.tensor(scores) return scores -def bin_accuracy(bins: list, probs: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: - """ - Compute the accuracy score for observations in each bin - - Args: - bins (list): List of observation ids for each bin - probs (Tensor): Predictive probabilities for the observations - y_true (Tensor): Labels for the observations - - Returns: - acc (Tensor): Accuracies for the observations in each bin - """ - y_pred = probs.argmax(-1) +def bin_accuracy(bins, logits, y_true): + y_pred = logits.argmax(-1) acc = [] for b in bins: @@ -98,24 +68,13 @@ def bin_accuracy(bins: list, probs: torch.Tensor, y_true: torch.Tensor) -> torch return acc -def ece(probs: torch.Tensor, y_true: torch.Tensor, n_bins: int) -> float: - """ - Expected calibration error calculation +def ece(logits, y_true, n_bins): + bins = fill_bins(n_bins, logits) - Args: - probs (Tensor): Predictive probabilities for the observations - y_true (Tensor): Labels for the observations - n_bins (int): Number of bins + scores = bin_confidence(bins, logits) + acc = bin_accuracy(bins, logits, y_true) - Returns: - ece (float): Expected calibration error - """ - bins = fill_bins(n_bins, probs) - - scores = bin_confidence(bins, probs) - acc = bin_accuracy(bins, probs, y_true) - - n = probs.size(0) + n = logits.size(0) bk = torch.tensor([b.size(0) for b in bins]) ece = torch.abs(scores - acc) * bk / n @@ -125,30 +84,34 @@ def ece(probs: torch.Tensor, y_true: torch.Tensor, n_bins: int) -> float: return ece -def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float: - """ - Joint goal expected calibration error calculation - - Args: - belief_state (dict): Belief state probabilities for the dialogue turns - y_true (dict): Labels for the state in dialogue turns - n_bins (int): Number of bins - - Returns: - ece (float): Joint goal expected calibration error - """ - y_pred = {slot: 
bs.reshape(-1, bs.size(-1)).argmax(-1) for slot, bs in belief_state.items()} +def jg_ece(logits, y_true, n_bins): + y_pred = {slot: logits[slot].reshape(-1, logits[slot].size(-1)).argmax(-1) for slot in logits} goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred} goal_acc = sum([goal_acc[slot] for slot in goal_acc]) goal_acc = (goal_acc == len(y_true)).int() - # Confidence score is minimum across slots as a single bad predictions leads to incorrect prediction in state - scores = [bs.reshape(-1, bs.size(-1)).max(-1)[0].unsqueeze(0) for slot, bs in belief_state.items()] + scores = [logits[slot].reshape(-1, logits[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in logits] scores = torch.cat(scores, 0).min(0)[0] - bins = fill_bins(n_bins, scores.unsqueeze(-1)) + step = 1.0 / n_bins + bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step) + bins = [] + for b in range(n_bins): + lower, upper = bin_ranges[b], bin_ranges[b + 1] + if b == 0: + ids = torch.where((scores >= lower) * (scores <= upper))[0] + else: + ids = torch.where((scores > lower) * (scores <= upper))[0] + bins.append(ids) - conf = bin_confidence(bins, scores.unsqueeze(-1)) + conf = [] + for b in bins: + if b is not None: + l = scores[b] + conf.append(l.mean()) + else: + conf.append(-1) + conf = torch.tensor(conf) slot = [s for s in y_true][0] acc = [] @@ -164,7 +127,7 @@ def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float: acc.append(-1) acc = torch.tensor(acc) - n = belief_state[slot].reshape(-1, belief_state[slot].size(-1)).size(0) + n = logits[slot].reshape(-1, logits[slot].size(-1)).size(0) bk = torch.tensor([b.size(0) for b in bins]) ece = torch.abs(conf - acc) * bk / n @@ -174,22 +137,12 @@ def jg_ece(belief_state: dict, y_true: dict, n_bins: int) -> float: return ece -def l2_acc(belief_state: dict, labels: dict, remove_belief: bool = False) -> float: - """ - Compute L2 Error of belief state prediction - - Args: - belief_state (dict): Belief state probabilities for the dialogue turns - labels (dict): Labels for the state in dialogue turns - remove_belief (bool): Convert belief state to dialogue state - - Returns: - err (float): L2 Error of belief state prediction - """ +def l2_acc(belief_state, labels, remove_belief=False): # Get ids used for removing padding turns. padding = labels[list(labels.keys())[0]].reshape(-1) padding = torch.where(padding != -1)[0] + # l2 = [] state = [] labs = [] for slot, bs in belief_state.items(): @@ -210,8 +163,13 @@ def l2_acc(belief_state: dict, labels: dict, remove_belief: bool = False) -> flo y = torch.zeros(bs.shape).cuda() y[range(y.size(0)), lab] = 1.0 + # err = torch.sqrt(((y - bs) ** 2).sum(-1)) + # l2.append(err.unsqueeze(-1)) + state.append(bs) labs.append(y) + + # err = torch.cat(l2, -1).max(-1)[0] # Concatenate all slots into a single belief state state = torch.cat(state, -1) diff --git a/convlab/dst/setsumbt/loss/endd_loss.py b/convlab/dst/setsumbt/loss/endd_loss.py index c3b2497373fda86a9e4e2b847daf2081555a0952..d84c3f720d4970c520d97aa9e65a12647468d3ae 100644 --- a/convlab/dst/setsumbt/loss/endd_loss.py +++ b/convlab/dst/setsumbt/loss/endd_loss.py @@ -1,46 +1,30 @@ import torch -from torch.nn import Module -from torch.nn.functional import kl_div EPS = torch.finfo(torch.float32).eps - @torch.no_grad() -def compute_mkl(ensemble_mean_probs: torch.Tensor, ensemble_logprobs: torch.Tensor) -> torch.Tensor: - """ - Computing MKL in ensemble. 
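The calibration metrics above (fill_bins, bin_confidence, bin_accuracy, ece) reduce to a simple binning computation: group predictions by confidence, then compare each bin's mean confidence with its accuracy, weighted by bin size. A self-contained sketch on hypothetical predictions:

import torch

# Sketch of the expected calibration error computed by ece() above.
def expected_calibration_error(probs, y_true, n_bins=10):
    conf, y_pred = probs.max(-1)                  # confidence and predicted class
    edges = torch.linspace(0.0, 1.0, n_bins + 1)
    ece_val = torch.tensor(0.0)
    for b in range(n_bins):
        lower, upper = edges[b], edges[b + 1]
        if b == 0:
            in_bin = (conf >= lower) & (conf <= upper)
        else:
            in_bin = (conf > lower) & (conf <= upper)
        if in_bin.any():
            acc = (y_pred[in_bin] == y_true[in_bin]).float().mean()
            # |confidence - accuracy| weighted by the fraction of points in the bin
            ece_val = ece_val + (conf[in_bin].mean() - acc).abs() * in_bin.float().mean()
    return ece_val

probs = torch.softmax(torch.randn(100, 5), -1)    # hypothetical predictions
labels = torch.randint(0, 5, (100,))
print(expected_calibration_error(probs, labels))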
-
-    Args:
-        ensemble_mean_probs (Tensor): Marginal predictive distribution of the ensemble
-        ensemble_logprobs (Tensor): Log predictive distributions of individual ensemble members
-
-    Returns:
-        mkl (Tensor): MKL
-    """
-    mkl = kl_div(ensemble_logprobs, ensemble_mean_probs.unsqueeze(1).expand_as(ensemble_logprobs),reduction='none')
-    return mkl.sum(-1).mean(1)
-
+def compute_mkl(ensemble_probs, ensemble_mean_probs, ensemble_logprobs):
+    mkl = torch.nn.functional.kl_div(ensemble_logprobs, ensemble_mean_probs.unsqueeze(1).expand_as(ensemble_probs),
+                                     reduction='none').sum(-1).mean(1)
+    return mkl

 @torch.no_grad()
-def compute_ensemble_stats(ensemble_probs: torch.Tensor) -> dict:
-    """
-    Compute a range of ensemble uncertainty measures
-
-    Args:
-        ensemble_probs (Tensor): Predictive distributions of the ensemble members
-
-    Returns:
-        stats (dict): Dictionary of ensemble uncertainty measures
-    """
+def compute_ensemble_stats(ensemble_logits):
+    # ensemble_probs = torch.softmax(ensemble_logits, dim=-1)
+    # ensemble_mean_probs = ensemble_probs.mean(dim=1)
+    # ensemble_logprobs = torch.log_softmax(ensemble_logits, dim=-1)
+    ensemble_probs = ensemble_logits
     ensemble_mean_probs = ensemble_probs.mean(dim=1)

-    num_classes = ensemble_probs.size(-1)
-    ensemble_logprobs = torch.log(ensemble_probs + (1e-4 / num_classes))
+    num_classes = ensemble_logits.size(-1)
+    ensemble_logprobs = torch.log(ensemble_logits + (1e-4 / num_classes))

     entropy_of_expected = torch.distributions.Categorical(probs=ensemble_mean_probs).entropy()
     expected_entropy = torch.distributions.Categorical(probs=ensemble_probs).entropy().mean(dim=1)
     mutual_info = entropy_of_expected - expected_entropy

-    mkl = compute_mkl(ensemble_mean_probs, ensemble_logprobs)
+    mkl = compute_mkl(ensemble_probs, ensemble_mean_probs, ensemble_logprobs)
+
+    # num_classes = ensemble_logits.size(-1)

     ensemble_precision = (num_classes - 1) / (2 * mkl.unsqueeze(1) + EPS)

@@ -55,226 +39,108 @@
     }
     return stats

-
-def entropy(probs: torch.Tensor, dim: int = -1) -> torch.Tensor:
-    """
-    Compute entropy in a predictive distribution
-
-    Args:
-        probs (Tensor): Predictive distributions
-        dim (int): Dimension representing the predictive probabilities for a single prediction
-
-    Returns:
-        entropy (Tensor): Entropy
-    """
+def entropy(probs, dim: int = -1):
     return -(probs * (probs + EPS).log()).sum(dim=dim)


-def compute_dirichlet_uncertainties(dirichlet_params: torch.Tensor,
-                                    precisions: torch.Tensor,
-                                    expected_dirichlet: torch.Tensor) -> tuple:
+def compute_dirichlet_uncertainties(dirichlet_params, precisions, expected_dirichlet):
     """
     Function which computes measures of uncertainty for Dirichlet model.
-
-    Args:
-        dirichlet_params (Tensor): Dirichlet concentration parameters.
-        precisions (Tensor): Dirichlet Precisions
-        expected_dirichlet (Tensor): Probabities of expected categorical under Dirichlet.
-
-    Returns:
-        stats (tuple): Token level uncertainties
+    :param dirichlet_params: Tensor of size [batch_size, n_classes] of Dirichlet concentration parameters.
+    :param precisions: Tensor of size [batch_size, 1] of Dirichlet Precisions
+    :param expected_dirichlet: Tensor of size [batch_size, n_classes] of probabilities of expected categorical under Dirichlet.
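The statistics assembled in compute_ensemble_stats above decompose ensemble uncertainty into total uncertainty (entropy of the mean prediction), data uncertainty (mean entropy of the members) and knowledge uncertainty (their difference), plus the average member-to-mean KL returned by compute_mkl. A sketch with a hypothetical ensemble (not part of this change):

import torch

# 8 observations, 5 ensemble members, 3 classes; members given as probabilities.
ensemble_probs = torch.softmax(torch.randn(8, 5, 3), -1)
mean_probs = ensemble_probs.mean(dim=1)

entropy_of_expected = torch.distributions.Categorical(probs=mean_probs).entropy()
expected_entropy = torch.distributions.Categorical(probs=ensemble_probs).entropy().mean(dim=1)
mutual_info = entropy_of_expected - expected_entropy   # knowledge uncertainty

# Average member-to-mean KL, the quantity compute_mkl returns
logprobs = torch.log(ensemble_probs + 1e-10)
mkl = torch.nn.functional.kl_div(logprobs, mean_probs.unsqueeze(1).expand_as(ensemble_probs),
                                 reduction='none').sum(-1).mean(1)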
+ :return: Tensors of token level uncertainties of size [batch_size] """ batch_size, n_classes = dirichlet_params.size() entropy_of_expected = entropy(expected_dirichlet) - expected_entropy = -expected_dirichlet * (torch.digamma(dirichlet_params + 1) - torch.digamma(precisions + 1)) - expected_entropy = expected_entropy.sum(dim=-1) + expected_entropy = ( + -expected_dirichlet * (torch.digamma(dirichlet_params + 1) - torch.digamma(precisions + 1))).sum(dim=-1) - mutual_information = torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + 1 + EPS) - mutual_information += torch.digamma(precisions + 1 + EPS) - mutual_information *= -(expected_dirichlet + EPS) - mutual_information = mutual_information.sum(dim=-1) + mutual_information = -((expected_dirichlet + EPS) * ( + torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + 1 + EPS) + torch.digamma( + precisions + 1 + EPS))).sum(dim=-1) + # assert torch.allclose(mutual_information, entropy_of_expected - expected_entropy, atol=1e-4, rtol=0) epkl = (n_classes - 1) / precisions.squeeze(-1) - mkl = torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + EPS) - mkl += torch.digamma(precisions + EPS) - mkl *= expected_dirichlet - mkl = mkl.sum(dim=-1) - - stats = (entropy_of_expected.clamp(min=0), expected_entropy.clamp(min=0), mutual_information.clamp(min=0)) - stats += (epkl.clamp(min=0), mkl.clamp(min=0)) - - return stats - - -def get_dirichlet_parameters(logits: torch.Tensor, - parametrization, - add_to_alphas: float = 0, - dtype=torch.double) -> tuple: - """ - Get dirichlet parameters from model logits + mkl = (expected_dirichlet * ( + torch.log(expected_dirichlet + EPS) - torch.digamma(dirichlet_params + EPS) + torch.digamma( + precisions + EPS))).sum(dim=-1) - Args: - logits (Tensor): Model logits - parametrization (function): Mapping from logits to concentration parameters - add_to_alphas (float): Addition constant for stability - dtype (data type): Data type of the parameters + return entropy_of_expected.clamp(min=0), \ + expected_entropy.clamp(min=0), \ + mutual_information.clamp(min=0), \ + epkl.clamp(min=0), \ + mkl.clamp(min=0) - Return: - params (tuple): Concentration and precision parameters of the model Dirichlet - """ +def get_dirichlet_parameters(logits, parametrization, add_to_alphas=0, dtype=torch.double): max_val = torch.finfo(dtype).max / logits.size(-1) - 1 alphas = torch.clip(parametrization(logits.to(dtype=dtype)) + add_to_alphas, max=max_val) precision = torch.sum(alphas, dim=-1, dtype=dtype) return alphas, precision -def logits_to_mutual_info(logits: torch.Tensor) -> torch.Tensor: - """ - Map modfel logits to mutual information of model Dirichlet - - Args: - logits (Tensor): Model logits - - Returns: - mutual_information (Tensor): Mutual information of the model Dirichlet - """ +def logits_to_mutual_info(logits): alphas, precision = get_dirichlet_parameters(logits, torch.exp, 1.0) - normalized_probs = alphas / precision.unsqueeze(1) + unsqueezed_precision = precision.unsqueeze(1) + normalized_probs = alphas / unsqueezed_precision - _, _, mutual_information, _, _ = compute_dirichlet_uncertainties(alphas, precision.unsqueeze(1), normalized_probs) + entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = compute_dirichlet_uncertainties(alphas, + unsqueezed_precision, + normalized_probs) + + # Max entropy is log(K) for K classes. 
Hence relative MI is calculated as MI/log(K) + # mutual_information /= torch.log(torch.tensor(logits.size(-1))) return mutual_information -class RKLDirichletMediatorLoss(Module): - """Reverse KL Dirichlet Mediator Loss (https://arxiv.org/abs/2105.06987)""" - - def __init__(self, - model_offset: float = 1.0, - target_offset: float = 1, - ignore_index: int = -1, - parameterization=torch.exp): - """ - Args: - model_offset (float): Offset of model Dirichlet for stability - target_offset (float): Offset of target Dirichlet for stability - ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. - parameterization (function): Mapping from logits to concentration parameters - """ - super(RKLDirichletMediatorLoss, self).__init__() - - self.model_offset = model_offset - self.target_offset = target_offset - self.ignore_index = ignore_index - self.parameterization = parameterization - - def logits_to_mutual_info(self, logits: torch.Tensor) -> torch.Tensor: - """ - Map modfel logits to mutual information of model Dirichlet - - Args: - logits (Tensor): Model logits - - Returns: - mutual_information (Tensor): Mutual information of the model Dirichlet - """ - return logits_to_mutual_info(logits) - - def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - """ - Args: - logits (Tensor): Model logits - targets (Tensor): Ensemble predictive distributions - - Returns: - loss (Tensor): RKL dirichlet mediator loss value - """ - - # Remove padding - turns = torch.where(targets[:, 0, 0] != self.ignore_index)[0] - logits = logits[turns] - targets = targets[turns] - - ensemble_stats = compute_ensemble_stats(targets) - - alphas, precision = get_dirichlet_parameters(logits, self.parametrization, self.model_offset) - - normalized_probs = alphas / precision.unsqueeze(1) - - stats = compute_dirichlet_uncertainties(alphas, precision.unsqueeze(1), normalized_probs) - entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = stats - - stats = { - 'alpha_min': alphas.min(), - 'alpha_mean': alphas.mean(), - 'precision': precision, - 'entropy_of_expected': entropy_of_expected, - 'mutual_info': mutual_information, - 'mkl': mkl, - } - - num_classes = alphas.size(-1) - - ensemble_precision = ensemble_stats['precision'] - - ensemble_precision += self.target_offset * num_classes - ensemble_probs = ensemble_stats['mean_probs'] - - expected_kl_term = torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS) - expected_kl_term = -1.0 * torch.sum(ensemble_probs * expected_kl_term, dim=-1) - assert torch.isfinite(expected_kl_term).all(), (torch.max(alphas), torch.max(precision), alphas.dtype) - - differential_negentropy_term_ = torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS) - differential_negentropy_term_ *= alphas - 1.0 - differential_negentropy_term = torch.sum(torch.lgamma(alphas + EPS), dim=-1) - torch.lgamma(precision + EPS) - differential_negentropy_term -= torch.sum(differential_negentropy_term_, dim=-1) - assert torch.isfinite(differential_negentropy_term).all() - - loss = expected_kl_term - differential_negentropy_term / ensemble_precision.squeeze(-1) - assert torch.isfinite(loss).all() - - return torch.mean(loss), stats, ensemble_stats - - -class BinaryRKLDirichletMediatorLoss(RKLDirichletMediatorLoss): - """Reverse KL Dirichlet Mediator Loss (https://arxiv.org/abs/2105.06987)""" - - def __init__(self, - model_offset: float = 1.0, - target_offset: float = 1, - ignore_index: int = -1, - 
parameterization=torch.exp): - """ - Args: - model_offset (float): Offset of model Dirichlet for stability - target_offset (float): Offset of target Dirichlet for stability - ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. - parameterization (function): Mapping from logits to concentration parameters - """ - super(BinaryRKLDirichletMediatorLoss, self).__init__(model_offset, target_offset, - ignore_index, parameterization) - - def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - """ - Args: - logits (Tensor): Model logits - targets (Tensor): Ensemble predictive distributions - - Returns: - loss (Tensor): RKL dirichlet mediator loss value - """ - # Convert single target probability p to distribution [1-p, p] - targets = targets.reshape(-1, targets.size(-1), 1) - targets = torch.cat([1 - targets, targets], -1) - targets[targets[:, 1] == self.ignore_index] = self.ignore_index - - # Convert input logits into predictive distribution [1-z, z] - logits = torch.sigmoid(logits).unsqueeze(1) - logits = torch.cat((1 - logits, logits), 1) - logits = -1.0 * torch.log((1 / (logits + 1e-8)) - 1) # Inverse sigmoid - - return super().forward(logits, targets) +def rkl_dirichlet_mediator_loss(logits, ensemble_stats, model_offset, target_offset, parametrization=torch.exp): + turns = torch.where(ensemble_stats[:, 0, 0] != -1)[0] + logits = logits[turns] + ensemble_stats = ensemble_stats[turns] + + ensemble_stats = compute_ensemble_stats(ensemble_stats) + + alphas, precision = get_dirichlet_parameters(logits, parametrization, model_offset) + + unsqueezed_precision = precision.unsqueeze(1) + normalized_probs = alphas / unsqueezed_precision + + entropy_of_expected, expected_entropy, mutual_information, epkl, mkl = compute_dirichlet_uncertainties(alphas, + unsqueezed_precision, + normalized_probs) + + stats = { + 'alpha_min': alphas.min(), + 'alpha_mean': alphas.mean(), + 'precision': precision, + 'entropy_of_expected': entropy_of_expected, + 'mutual_info': mutual_information, + 'mkl': mkl, + } + + num_classes = alphas.size(-1) + + ensemble_precision = ensemble_stats['precision'] + + ensemble_precision += target_offset * num_classes + ensemble_probs = ensemble_stats['mean_probs'] + + expected_KL_term = -1.0 * torch.sum(ensemble_probs * (torch.digamma(alphas + EPS) + - torch.digamma(precision.unsqueeze(-1) + EPS)), dim=-1) + assert torch.isfinite(expected_KL_term).all(), (torch.max(alphas), torch.max(precision), alphas.dtype) + + differential_negentropy_term = torch.sum(torch.lgamma(alphas + EPS), dim=-1) - torch.lgamma(precision + EPS) \ + - torch.sum( + (alphas - 1) * (torch.digamma(alphas + EPS) - torch.digamma(precision.unsqueeze(-1) + EPS)), dim=-1) + assert torch.isfinite(differential_negentropy_term).all() + + cost = expected_KL_term - differential_negentropy_term / ensemble_precision.squeeze(-1) + + assert torch.isfinite(cost).all() + return torch.mean(cost), stats, ensemble_stats + diff --git a/convlab/dst/setsumbt/loss/kl_distillation.py b/convlab/dst/setsumbt/loss/kl_distillation.py deleted file mode 100644 index 9aee234ab68054f2b4a83d6feb5e453384d89e94..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/loss/kl_distillation.py +++ /dev/null @@ -1,104 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except 
in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""KL Divergence Ensemble Distillation loss""" - -import torch -from torch.nn import Module -from torch.nn.functional import kl_div - - -class KLDistillationLoss(Module): - """Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation""" - - def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module: - """ - Args: - lamb (float): Target smoothing parameter - ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. - """ - super(KLDistillationLoss, self).__init__() - - self.lamb = lamb - self.ignore_index = ignore_index - - def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor: - """ - Args: - inputs (Tensor): Predictive distribution - targets (Tensor): Target distribution (ensemble marginal) - temp (float): Temperature scaling coefficient for predictive distribution - - Returns: - loss (Tensor): Loss value - """ - # Assert input sizes - assert inputs.dim() == 2 # Observations, predictive distribution - assert targets.dim() == 2 # Label for each observation - assert targets.size(0) == inputs.size(0) # Equal number of observation - - # Confirm predictive distribution dimension - if targets.size(-1) != inputs.size(-1): - name_error = f'Target dimension {targets.size(-1)} is not the same as the prediction dimension ' - name_error += f'{inputs.size(-1)}.' - raise NameError(name_error) - - # Remove observations to be ignored in loss calculation - inputs = torch.log(torch.softmax(inputs / temp, -1)) - ids = torch.where(targets[:, 0] != self.ignore_index)[0] - inputs = inputs[ids] - targets = targets[ids] - - # Target smoothing - targets = ((1 - self.lamb) * targets) + (self.lamb / targets.size(-1)) - - return kl_div(inputs, targets, reduction='none').sum(-1).mean() - - -# Pytorch BayesianMatchingLoss nn.Module -class BinaryKLDistillationLoss(KLDistillationLoss): - """Binary Ensemble Distillation loss using KL Divergence (https://arxiv.org/pdf/1503.02531.pdf) implementation""" - - def __init__(self, lamb: float = 1e-4, ignore_index: int = -1) -> Module: - """ - Args: - lamb (float): Target smoothing parameter - ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. 
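For the distribution distillation setting, the rkl_dirichlet_mediator_loss added above fits a Dirichlet to the ensemble with a reverse-KL objective. A sketch of how its two terms are assembled, with hypothetical shapes and a stand-in constant for the ensemble precision estimate (the real code derives it from the ensemble's MKL):

import torch

EPS = torch.finfo(torch.float32).eps
logits = torch.randn(8, 3)                                # model logits [batch, classes]
ensemble_probs = torch.softmax(torch.randn(8, 5, 3), -1)  # hypothetical 5-member ensemble

alphas = torch.exp(logits.double()) + 1.0                 # Dirichlet concentrations
precision = alphas.sum(-1)

mean_probs = ensemble_probs.mean(1).double()
ensemble_precision = torch.full((8, 1), 10.0).double()    # stand-in estimate

expected_kl = -(mean_probs * (torch.digamma(alphas + EPS)
                              - torch.digamma(precision.unsqueeze(-1) + EPS))).sum(-1)
negentropy = (torch.lgamma(alphas + EPS).sum(-1) - torch.lgamma(precision + EPS)
              - ((alphas - 1) * (torch.digamma(alphas + EPS)
                                 - torch.digamma(precision.unsqueeze(-1) + EPS))).sum(-1))
loss = (expected_kl - negentropy / ensemble_precision.squeeze(-1)).mean()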
- """ - super(BinaryKLDistillationLoss, self).__init__(lamb, ignore_index) - - def forward(self, inputs: torch.Tensor, targets: torch.Tensor, temp: float = 1.0) -> torch.Tensor: - """ - Args: - inputs (Tensor): Predictive distribution - targets (Tensor): Target distribution (ensemble marginal) - temp (float): Temperature scaling coefficient for predictive distribution - - Returns: - loss (Tensor): Loss value - """ - # Assert input sizes - assert inputs.dim() == 1 # Observations, predictive distribution - assert targets.dim() == 1 # Label for each observation - assert targets.size(0) == inputs.size(0) # Equal number of observation - - # Convert input and target to 2D binary distribution for KL divergence computation - inputs = torch.sigmoid(inputs / temp).unsqueeze(-1) - inputs = torch.log(torch.cat((1 - inputs, inputs), 1)) - - targets = targets.unsqueeze(-1) - targets = torch.cat((1 - targets, targets), -1) - - return super().forward(input, targets, temp) diff --git a/convlab/dst/setsumbt/loss/labelsmoothing.py b/convlab/dst/setsumbt/loss/labelsmoothing.py index 70cef73170c0f539c3eeab888c50ae571b36c882..8fcc60afd50603cb5c2c84fd698fd11ce7fb7415 100644 --- a/convlab/dst/setsumbt/loss/labelsmoothing.py +++ b/convlab/dst/setsumbt/loss/labelsmoothing.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Label smoothing loss function""" +"""Inhibited Softmax Activation and Loss Functions""" import torch @@ -23,87 +23,66 @@ from torch.nn.functional import kl_div class LabelSmoothingLoss(Module): """ - Label smoothing loss minimises the KL-divergence between q_{smoothed ground truth prob}(w) - and p_{prob. computed by model}(w). + With label smoothing, + KL-divergence between q_{smoothed ground truth prob.}(w) + and p_{prob. computed by model}(w) is minimized. """ - - def __init__(self, label_smoothing: float = 0.05, ignore_index: int = -1) -> Module: - """ - Args: - label_smoothing (float): Label smoothing constant - ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. 
- """ + def __init__(self, label_smoothing=0.05, ignore_index=-1): super(LabelSmoothingLoss, self).__init__() assert 0.0 < label_smoothing <= 1.0 self.ignore_index = ignore_index self.label_smoothing = float(label_smoothing) - def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + def forward(self, logits, targets): """ - Args: - input (Tensor): Predictive distribution - labels (Tensor): Label indices - - Returns: - loss (Tensor): Loss value + output (FloatTensor): batch_size x n_classes + target (LongTensor): batch_size """ - # Assert input sizes - assert inputs.dim() == 2 - assert labels.dim() == 1 - assert self.label_smoothing <= ((inputs.size(-1) - 1) / inputs.size(-1)) - - # Confirm predictive distribution dimension - if labels.max() <= inputs.size(-1): - dimension = inputs.size(-1) - else: - raise NameError(f'Label dimension {labels.max()} is larger than prediction dimension {inputs.size(-1)}.') + assert logits.dim() == 2 + assert targets.dim() == 1 + assert self.label_smoothing <= ((logits.size(-1) - 1) / logits.size(-1)) - # Remove observations to be ignored in loss calculation - inputs = inputs[labels != self.ignore_index] - labels = labels[labels != self.ignore_index] + logits = logits[targets != self.ignore_index] + targets = targets[targets != self.ignore_index] - # Create target distribution - inputs = torch.log(torch.softmax(inputs, -1)) - targets = torch.ones(inputs.size()).float().to(inputs.device) - targets *= self.label_smoothing / (dimension - 1) - targets[range(labels.size(0)), labels] = 1.0 - self.label_smoothing + logits = torch.log(torch.softmax(logits, -1)) + labels = torch.ones(logits.size()).float().to(logits.device) + labels *= self.label_smoothing / (logits.size(-1) - 1) + labels[range(labels.size(0)), targets] = 1.0 - self.label_smoothing - return kl_div(inputs, targets, reduction='none').sum(-1).mean() + kl = kl_div(logits, labels, reduction='none').sum(-1).mean() + del logits, targets, labels + return kl -class BinaryLabelSmoothingLoss(LabelSmoothingLoss): +class BinaryLabelSmoothingLoss(Module): """ - Label smoothing loss minimises the KL-divergence between q_{smoothed ground truth prob}(w) - and p_{prob. computed by model}(w). + With label smoothing, + KL-divergence between q_{smoothed ground truth prob.}(w) + and p_{prob. computed by model}(w) is minimized. """ + def __init__(self, label_smoothing=0.05): + super(BinaryLabelSmoothingLoss, self).__init__() - def __init__(self, label_smoothing: float = 0.05, ignore_index: int = -1) -> Module: - """ - Args: - label_smoothing (float): Label smoothing constant - ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. 
- """ - super(BinaryLabelSmoothingLoss, self).__init__(label_smoothing, ignore_index) + assert 0.0 < label_smoothing <= 1.0 + self.label_smoothing = float(label_smoothing) - def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + def forward(self, logits, targets): """ - Args: - input (Tensor): Predictive distribution - labels (Tensor): Label indices - - Returns: - loss (Tensor): Loss value + output (FloatTensor): batch_size x n_classes + target (LongTensor): batch_size """ - # Assert input sizes - assert inputs.dim() == 1 - assert labels.dim() == 1 + assert logits.dim() == 1 + assert targets.dim() == 1 assert self.label_smoothing <= 0.5 - inputs = torch.sigmoid(inputs).reshape(-1, 1) - inputs = torch.log(torch.cat((1 - inputs, inputs), 1)) - targets = torch.ones(inputs.size()).float().to(inputs.device) - targets *= self.label_smoothing - targets[range(labels.size(0)), labels.long()] = 1.0 - self.label_smoothing + logits = torch.sigmoid(logits).reshape(-1, 1) + logits = torch.log(torch.cat((1 - logits, logits), 1)) + labels = torch.ones(logits.size()).float().to(logits.device) + labels *= self.label_smoothing + labels[range(labels.size(0)), targets.long()] = 1.0 - self.label_smoothing - return kl_div(inputs, targets, reduction='none').sum(-1).mean() + kl = kl_div(logits, labels, reduction='none').sum(-1).mean() + del logits, targets + return kl diff --git a/convlab/dst/setsumbt/modeling/__init__.py b/convlab/dst/setsumbt/modeling/__init__.py index 59f1439948421ac365e4602b7800c94d3b8b32dd..011a1a774e2d1a22e46e242d1812549895f2246b 100644 --- a/convlab/dst/setsumbt/modeling/__init__.py +++ b/convlab/dst/setsumbt/modeling/__init__.py @@ -1,5 +1,3 @@ from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT -from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT - -from convlab.dst.setsumbt.modeling.temperature_scheduler import LinearTemperatureScheduler +from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT, DropoutEnsembleSetSUMBT diff --git a/convlab/dst/setsumbt/modeling/bert_nbt.py b/convlab/dst/setsumbt/modeling/bert_nbt.py index 6762fb3891b4720c3889d8c0809b8791f3bf7633..8b402b6be09684b27bb73acf17e578bc0e3b4bbd 100644 --- a/convlab/dst/setsumbt/modeling/bert_nbt.py +++ b/convlab/dst/setsumbt/modeling/bert_nbt.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,10 +16,11 @@ """BERT SetSUMBT""" import torch +import transformers from torch.autograd import Variable from transformers import BertModel, BertPreTrainedModel -from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead +from convlab.dst.setsumbt.modeling.functional import _initialise, _nbt_forward class BertSetSUMBT(BertPreTrainedModel): @@ -34,37 +35,59 @@ class BertSetSUMBT(BertPreTrainedModel): for p in self.bert.parameters(): p.requires_grad = False - self.setsumbt = SetSUMBTHead(config) - self.add_slot_candidates = self.setsumbt.add_slot_candidates - self.add_value_candidates = self.setsumbt.add_value_candidates - - def forward(self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor = None, - hidden_state: torch.Tensor = None, - state_labels: torch.Tensor = None, - request_labels: torch.Tensor = None, - 
active_domain_labels: torch.Tensor = None, - general_act_labels: torch.Tensor = None, - get_turn_pooled_representation: bool = False, - calculate_state_mutual_info: bool = False): - """ - Args: - input_ids: Input token ids - attention_mask: Input padding mask - token_type_ids: Token type indicator - hidden_state: Latent internal dialogue belief state - state_labels: Dialogue state labels - request_labels: User request action labels - active_domain_labels: Current active domain labels - general_act_labels: General user action labels - get_turn_pooled_representation: Return pooled representation of the current dialogue turn - calculate_state_mutual_info: Return mutual information in the dialogue state - - Returns: - out: Tuple containing loss, predictive distributions, model statistics and state mutual information - """ + _initialise(self, config) + + # Add new slot candidates to the model + def add_slot_candidates(self, slot_candidates): + """slot_candidates is a list of tuples for each slot. + - The tuples contains the slot embedding, informable value embeddings and a request indicator. + - If the informable value embeddings is None the slot is not informable + - If the request indicator is false the slot is not requestable""" + if self.slot_embeddings.size(0) != 0: + embeddings = self.slot_embeddings.detach() + else: + embeddings = torch.zeros(0) + + for slot in slot_candidates: + if slot in self.slot_ids: + index = self.slot_ids[slot] + embeddings[index, :] = slot_candidates[slot][0] + else: + index = embeddings.size(0) + emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device) + embeddings = torch.cat((embeddings, emb), 0) + self.slot_ids[slot] = index + setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False)) + # Add slot to relevant requestable and informable slot lists + if slot_candidates[slot][2]: + self.requestable_slot_ids[slot] = index + if slot_candidates[slot][1] is not None: + self.informable_slot_ids[slot] = index + + domain = slot.split('-', 1)[0] + if domain not in self.domain_ids: + self.domain_ids[domain] = [] + self.domain_ids[domain].append(index) + self.domain_ids[domain] = list(set(self.domain_ids[domain])) + + self.slot_embeddings = Variable(embeddings, requires_grad=False) + + + # Add new value candidates to the model + def add_value_candidates(self, slot, value_candidates, replace=False): + embeddings = getattr(self, slot + '_value_embeddings') + + if embeddings.size(0) == 0 or replace: + embeddings = value_candidates + else: + embeddings = torch.cat((embeddings, value_candidates), 0) + + setattr(self, slot + '_value_embeddings', embeddings) + + + def forward(self, input_ids, token_type_ids, attention_mask, hidden_state=None, inform_labels=None, + request_labels=None, domain_labels=None, goodbye_labels=None, + get_turn_pooled_representation=False, calculate_inform_mutual_info=False): # Encode Dialogues batch_size, dialogue_size, turn_size = input_ids.size() @@ -80,10 +103,9 @@ class BertSetSUMBT(BertPreTrainedModel): turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1) if get_turn_pooled_representation: - return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask, - batch_size, dialogue_size, hidden_state, state_labels, - request_labels, active_domain_labels, general_act_labels, - calculate_state_mutual_info) + (bert_output.pooler_output,) - return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask, batch_size, - dialogue_size, hidden_state, 
state_labels, request_labels, active_domain_labels, - general_act_labels, calculate_state_mutual_info) + return _nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size, + dialogue_size, turn_size, hidden_state, inform_labels, request_labels, domain_labels, + goodbye_labels, calculate_inform_mutual_info) + (bert_output.pooler_output,) + return _nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size, dialogue_size, + turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels, + calculate_inform_mutual_info) diff --git a/convlab/dst/setsumbt/modeling/calibration_utils.py b/convlab/dst/setsumbt/modeling/calibration_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8514ac8d259162c5bcc55607bb8356de6d4b47c7 --- /dev/null +++ b/convlab/dst/setsumbt/modeling/calibration_utils.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Discriminative models calibration""" + +import random + +import torch +import numpy as np +from tqdm import tqdm + + +# Load logger and tensorboard summary writer +def set_logger(logger_, tb_writer_): + global logger, tb_writer + logger = logger_ + tb_writer = tb_writer_ + + +# Set seeds +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + logger.info('Seed set to %d.' 
% args.seed) + + +def get_predictions(args, model, device, dataloader): + logger.info(" Num Batches = %d", len(dataloader)) + + model.eval() + if args.dropout_iterations > 1: + model.train() + + belief_states = {slot: [] for slot in model.informable_slot_ids} + request_belief = {slot: [] for slot in model.requestable_slot_ids} + domain_belief = {dom: [] for dom in model.domain_ids} + greeting_belief = [] + labels = {slot: [] for slot in model.informable_slot_ids} + request_labels = {slot: [] for slot in model.requestable_slot_ids} + domain_labels = {dom: [] for dom in model.domain_ids} + greeting_labels = [] + epoch_iterator = tqdm(dataloader, desc="Iteration") + for step, batch in enumerate(epoch_iterator): + with torch.no_grad(): + input_ids = batch['input_ids'].to(device) + token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None + attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None + + if args.dropout_iterations > 1: + p = {slot: [] for slot in model.informable_slot_ids} + for _ in range(args.dropout_iterations): + p_, p_req_, p_dom_, p_bye_, _ = model(input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask) + for slot in model.informable_slot_ids: + p[slot].append(p_[slot].unsqueeze(0)) + + mu = {slot: torch.cat(p[slot], 0).mean(0) for slot in model.informable_slot_ids} + sig = {slot: torch.cat(p[slot], 0).var(0) for slot in model.informable_slot_ids} + p = {slot: mu[slot] / torch.sqrt(1 + sig[slot]) for slot in model.informable_slot_ids} + p = {slot: normalise(p[slot]) for slot in model.informable_slot_ids} + else: + p, p_req, p_dom, p_bye, _ = model(input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask) + + for slot in model.informable_slot_ids: + p_ = p[slot] + labs = batch['labels-' + slot].to(device) + + belief_states[slot].append(p_) + labels[slot].append(labs) + + if p_req is not None: + for slot in model.requestable_slot_ids: + p_ = p_req[slot] + labs = batch['request-' + slot].to(device) + + request_belief[slot].append(p_) + request_labels[slot].append(labs) + + for domain in model.domain_ids: + p_ = p_dom[domain] + labs = batch['active-' + domain].to(device) + + domain_belief[domain].append(p_) + domain_labels[domain].append(labs) + + greeting_belief.append(p_bye) + greeting_labels.append(batch['goodbye'].to(device)) + + for slot in belief_states: + belief_states[slot] = torch.cat(belief_states[slot], 0) + labels[slot] = torch.cat(labels[slot], 0) + if p_req is not None: + for slot in request_belief: + request_belief[slot] = torch.cat(request_belief[slot], 0) + request_labels[slot] = torch.cat(request_labels[slot], 0) + for domain in domain_belief: + domain_belief[domain] = torch.cat(domain_belief[domain], 0) + domain_labels[domain] = torch.cat(domain_labels[domain], 0) + greeting_belief = torch.cat(greeting_belief, 0) + greeting_labels = torch.cat(greeting_labels, 0) + else: + request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = [None]*6 + + return belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels + + +def normalise(p): + p_shape = p.size() + + p = p.reshape(-1, p_shape[-1]) + 1e-10 + p_sum = p.sum(-1).unsqueeze(1).repeat((1, p_shape[-1])) + p /= p_sum + + p = p.reshape(p_shape) + + return p diff --git a/convlab/dst/setsumbt/modeling/ensemble_nbt.py b/convlab/dst/setsumbt/modeling/ensemble_nbt.py index 
0ae03a60b1a754c7305028158adece48d71fee03..9f101d128c6b8c9093c4959834cf5aac35b322a2 100644
--- a/convlab/dst/setsumbt/modeling/ensemble_nbt.py
+++ b/convlab/dst/setsumbt/modeling/ensemble_nbt.py
@@ -18,6 +18,7 @@
 import os

 import torch
+import transformers
 from torch.nn import Module
 from transformers import RobertaConfig, BertConfig

@@ -28,13 +29,8 @@ MODELS = {'bert': BertSetSUMBT, 'roberta': RobertaSetSUMBT}


 class EnsembleSetSUMBT(Module):
-    """Ensemble SetSUMBT Model for joint ensemble prediction"""

     def __init__(self, config):
-        """
-        Args:
-            config (configuration): Model configuration class
-        """
         super(EnsembleSetSUMBT, self).__init__()
         self.config = config

@@ -42,98 +38,75 @@
         model_cls = MODELS[self.config.model_type]
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             setattr(self, attr, model_cls(config))
+

-    def _load(self, path: str):
-        """
-        Load parameters

-        Args:
-            path: Location of model parameters
-        """
+    # Load all ensemble member parameters
+    def load(self, path, config=None):
+        if config is None:
+            config = self.config
+
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             idx = attr.split('_', 1)[-1]
             state_dict = torch.load(os.path.join(path, f'pytorch_model_{idx}.bin'))
             getattr(self, attr).load_state_dict(state_dict)
+

-    def add_slot_candidates(self, slot_candidates: tuple):
-        """
-        Add slots to the model ontology, the tuples should contain the slot embedding, informable value embeddings
-        and a request indicator, if the informable value embeddings is None the slot is not informable and if
-        the request indicator is false the slot is not requestable.
-
-        Args:
-            slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator
-        """
+    # Add new slot candidates to the ensemble members
+    def add_slot_candidates(self, slot_candidates):
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             getattr(self, attr).add_slot_candidates(slot_candidates)
-        self.requestable_slot_ids = self.model_0.setsumbt.requestable_slot_ids
-        self.informable_slot_ids = self.model_0.setsumbt.informable_slot_ids
-        self.domain_ids = self.setsumbt.model_0.domain_ids
-
-    def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
-        """
-        Add value candidates for a slot
-
-        Args:
-            slot: Slot name
-            value_candidates: Value candidate embeddings
-            replace: If true existing value candidates are replaced
-        """
+        self.requestable_slot_ids = self.model_0.requestable_slot_ids
+        self.informable_slot_ids = self.model_0.informable_slot_ids
+        self.domain_ids = self.model_0.domain_ids
+
+
+    # Add new value candidates to the ensemble members
+    def add_value_candidates(self, slot, value_candidates, replace=False):
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             getattr(self, attr).add_value_candidates(slot, value_candidates, replace)
+

-    def forward(self,
-                input_ids: torch.Tensor,
-                attention_mask: torch.Tensor,
-                token_type_ids: torch.Tensor = None,
-                reduction: str = 'mean') -> tuple:
-        """
-        Args:
-            input_ids: Input token ids
-            attention_mask: Input padding mask
-            token_type_ids: Token type indicator
-            reduction: Reduction of ensemble member predictive distributions (mean, none)
-
-        Returns:
-
-        """
-        belief_state_probs = {slot: [] for slot in self.model_0.informable_slot_ids}
-        request_probs = {slot: [] for slot in self.model_0.requestable_slot_ids}
-        active_domain_probs = {dom: [] for dom in self.model_0.domain_ids}
-        general_act_probs = []
+    # Forward pass of full ensemble
+    def forward(self, input_ids, attention_mask, token_type_ids=None, reduction='mean'):
+        logits, request_logits, domain_logits, goodbye_scores = [], [], [], []
+        logits = {slot: [] for slot in self.model_0.informable_slot_ids}
+        request_logits = {slot: [] for slot in self.model_0.requestable_slot_ids}
+        domain_logits = {dom: [] for dom in self.model_0.domain_ids}
+        goodbye_scores = []
         for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
             # Prediction from each ensemble member
-            b, r, d, g, _ = getattr(self, attr)(input_ids=input_ids,
+            l, r, d, g, _ = getattr(self, attr)(input_ids=input_ids,
                                                 token_type_ids=token_type_ids,
                                                 attention_mask=attention_mask)
-            for slot in belief_state_probs:
-                belief_state_probs[slot].append(b[slot].unsqueeze(-2))
+            for slot in logits:
+                logits[slot].append(l[slot].unsqueeze(-2))
             if self.config.predict_intents:
-                for slot in request_probs:
-                    request_probs[slot].append(r[slot].unsqueeze(-1))
-                for dom in active_domain_probs:
-                    active_domain_probs[dom].append(d[dom].unsqueeze(-1))
-                general_act_probs.append(g.unsqueeze(-2))
+                for slot in request_logits:
+                    request_logits[slot].append(r[slot].unsqueeze(-1))
+                for dom in domain_logits:
+                    domain_logits[dom].append(d[dom].unsqueeze(-1))
+                goodbye_scores.append(g.unsqueeze(-2))

-        belief_state_probs = {slot: torch.cat(l, -2) for slot, l in belief_state_probs.items()}
+        logits = {slot: torch.cat(l, -2) for slot, l in logits.items()}
         if self.config.predict_intents:
-            request_probs = {slot: torch.cat(l, -1) for slot, l in request_probs.items()}
-            active_domain_probs = {dom: torch.cat(l, -1) for dom, l in active_domain_probs.items()}
-            general_act_probs = torch.cat(general_act_probs, -2)
+            request_logits = {slot: torch.cat(l, -1) for slot, l in request_logits.items()}
+            domain_logits = {dom: torch.cat(l, -1) for dom, l in domain_logits.items()}
+            goodbye_scores = torch.cat(goodbye_scores, -2)
         else:
-            request_probs = {}
-            active_domain_probs = {}
-            general_act_probs = torch.tensor(0.0)
+            request_logits = {}
+            domain_logits = {}
+            goodbye_scores = torch.tensor(0.0)

         # Apply reduction of ensemble to single posterior
         if reduction == 'mean':
-            belief_state_probs = {slot: l.mean(-2) for slot, l in belief_state_probs.items()}
-            request_probs = {slot: l.mean(-1) for slot, l in request_probs.items()}
-            active_domain_probs = {dom: l.mean(-1) for dom, l in active_domain_probs.items()}
-            general_act_probs = general_act_probs.mean(-2)
+            logits = {slot: l.mean(-2) for slot, l in logits.items()}
+            request_logits = {slot: l.mean(-1) for slot, l in request_logits.items()}
+            domain_logits = {dom: l.mean(-1) for dom, l in domain_logits.items()}
+            goodbye_scores = goodbye_scores.mean(-2)
         elif reduction != 'none':
             raise(NameError('Not Implemented!'))

-        return belief_state_probs, request_probs, active_domain_probs, general_act_probs, _
+        return logits, request_logits, domain_logits, goodbye_scores, _


     @classmethod
@@ -152,3 +125,88 @@
         model.load(path)

         return model
+
+
+class DropoutEnsembleSetSUMBT(Module):
+
+    def __init__(self, config):
+        super(DropoutEnsembleSetSUMBT, self).__init__()
+        self.config = config
+
+        model_cls = MODELS[self.config.model_type]
+        self.model = model_cls(config)
+        self.model.train()
+
+
+    def load(self, path, config=None):
+        if config is None:
+            config = self.config
+        state_dict = torch.load(os.path.join(path, 'pytorch_model.bin'))
+        self.model.load_state_dict(state_dict)
+
+
+    # Add new slot candidates to the model
+    def add_slot_candidates(self, slot_candidates):
+        self.model.add_slot_candidates(slot_candidates)
+        self.requestable_slot_ids = self.model.requestable_slot_ids
+        self.informable_slot_ids = self.model.informable_slot_ids
+        self.domain_ids = self.model.domain_ids
+
+
+    # Add new value candidates to the model
+    def add_value_candidates(self, slot, value_candidates, replace=False):
+        self.model.add_value_candidates(slot, value_candidates, replace)
+
+
+    def forward(self, input_ids, attention_mask, token_type_ids=None, reduction='mean'):
+        # Repeat the batch once per stochastic forward pass (ensemble member)
+        input_ids = input_ids.unsqueeze(0).repeat((self.config.ensemble_size, 1, 1, 1))
+        input_ids = input_ids.reshape(-1, input_ids.size(-2), input_ids.size(-1))
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(0).repeat((self.config.ensemble_size, 1, 1, 1))
+            attention_mask = attention_mask.reshape(-1, attention_mask.size(-2), attention_mask.size(-1))
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.unsqueeze(0).repeat((self.config.ensemble_size, 1, 1, 1))
+            token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-2), token_type_ids.size(-1))
+
+        self.model.train()
+        logits, request_logits, domain_logits, goodbye_scores, _ = self.model(input_ids=input_ids,
+                                                                              attention_mask=attention_mask,
+                                                                              token_type_ids=token_type_ids)
+
+        logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-2), l.size(-1)).transpose(0, 1).transpose(1, 2)
+                  for s, l in logits.items()}
+        request_logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-1)).transpose(0, 1).transpose(1, 2)
+                          for s, l in request_logits.items()}
+        domain_logits = {s: l.reshape(self.config.ensemble_size, -1, l.size(-1)).transpose(0, 1).transpose(1, 2)
+                         for s, l in domain_logits.items()}
+        goodbye_scores = goodbye_scores.reshape(self.config.ensemble_size, -1, goodbye_scores.size(-2), goodbye_scores.size(-1))
+        goodbye_scores = goodbye_scores.transpose(0, 1).transpose(1, 2)
+
+        if reduction == 'mean':
+            logits = {slot: l.mean(-2) for slot, l in logits.items()}
+            request_logits = {slot: l.mean(-1) for slot, l in request_logits.items()}
+            domain_logits = {dom: l.mean(-1) for dom, l in domain_logits.items()}
+            goodbye_scores = goodbye_scores.mean(-2)
+        elif reduction != 'none':
+            raise(NameError('Not Implemented!'))
+
+        return logits, request_logits, domain_logits, goodbye_scores, _
+
+
+    @classmethod
+    def from_pretrained(cls, path):
+        if not os.path.exists(os.path.join(path, 'config.json')):
+            raise(NameError('Could not find config.json in model path.'))
+        if not os.path.exists(os.path.join(path, 'pytorch_model.bin')):
+            raise(NameError('Could not find a model binary in the model path.'))
+
+        try:
+            config = RobertaConfig.from_pretrained(path)
+        except:
+            config = BertConfig.from_pretrained(path)
+
+        model = cls(config)
+        model.load(path)
+
+        return model
diff --git a/convlab/dst/setsumbt/modeling/ensemble_utils.py b/convlab/dst/setsumbt/modeling/ensemble_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..19f6abf81a4070b9498310adfab93d50f5a692f5
--- /dev/null
+++ b/convlab/dst/setsumbt/modeling/ensemble_utils.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Authors: Carel van Niekerk (niekerk@hhu.de)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
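DropoutEnsembleSetSUMBT above implements a Monte Carlo dropout ensemble: the batch is tiled ensemble_size times, dropout is kept active, and each tiled copy is treated as one member. A generic sketch of the same trick (model and sizes hypothetical, not part of this change):

import torch

ensemble_size, batch, classes = 10, 4, 3
net = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(),
                          torch.nn.Dropout(0.3), torch.nn.Linear(16, classes))
net.train()                                    # keep dropout stochastic

x = torch.randn(batch, 8)
x = x.unsqueeze(0).repeat(ensemble_size, 1, 1).reshape(-1, 8)

probs = torch.softmax(net(x), -1).reshape(ensemble_size, batch, classes)
mean_probs = probs.mean(0)                     # the 'mean' reduction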
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Discriminative models calibration""" + +import random +import os + +import torch +import numpy as np +from torch.distributions import Categorical +from torch.nn.functional import kl_div +from torch.nn import Module +from tqdm import tqdm + + +# Load logger and tensorboard summary writer +def set_logger(logger_, tb_writer_): + global logger, tb_writer + logger = logger_ + tb_writer = tb_writer_ + + +# Set seeds +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + logger.info('Seed set to %d.' % args.seed) + + +def build_train_loaders(args, tokenizer, dataset): + dataloaders = [dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len, + args.max_turn_len, resampled_size=args.data_sampling_size) + for _ in range(args.ensemble_size)] + return dataloaders diff --git a/convlab/dst/setsumbt/modeling/evaluation_utils.py b/convlab/dst/setsumbt/modeling/evaluation_utils.py deleted file mode 100644 index c73d4b6d32a485a2cf2b5948dbd6a9a4d7f346cb..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/modeling/evaluation_utils.py +++ /dev/null @@ -1,112 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
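build_train_loaders above gives each ensemble member its own resampled view of the training data. A generic sketch of that bagging-style setup (dataset and sizes hypothetical; the real code resamples inside the dataset's get_dataloader):

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 3, (100,)))
ensemble_size, resampled_size = 5, 80

# One bootstrap-sampled loader per ensemble member
loaders = [DataLoader(dataset,
                      sampler=RandomSampler(dataset, replacement=True,
                                            num_samples=resampled_size),
                      batch_size=16)
           for _ in range(ensemble_size)]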
-"""Evaluation Utilities""" - -import random - -import torch -import numpy as np -from tqdm import tqdm - - -def set_seed(args): - """ - Set random seeds - - Args: - args (Arguments class): Arguments class containing seed and number of gpus to use - """ - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - if args.n_gpu > 0: - torch.cuda.manual_seed_all(args.seed) - - -def get_predictions(args, model, device: torch.device, dataloader: torch.utils.data.DataLoader) -> tuple: - """ - Get model predictions - - Args: - args: Runtime arguments - model: SetSUMBT Model - device: Torch device - dataloader: Dataloader containing eval data - """ - model.eval() - - belief_states = {slot: [] for slot in model.setsumbt.informable_slot_ids} - request_probs = {slot: [] for slot in model.setsumbt.requestable_slot_ids} - active_domain_probs = {dom: [] for dom in model.setsumbt.domain_ids} - general_act_probs = [] - state_labels = {slot: [] for slot in model.setsumbt.informable_slot_ids} - request_labels = {slot: [] for slot in model.setsumbt.requestable_slot_ids} - active_domain_labels = {dom: [] for dom in model.setsumbt.domain_ids} - general_act_labels = [] - epoch_iterator = tqdm(dataloader, desc="Iteration") - for step, batch in enumerate(epoch_iterator): - with torch.no_grad(): - input_ids = batch['input_ids'].to(device) - token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None - attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None - - p, p_req, p_dom, p_gen, _ = model(input_ids=input_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask) - - for slot in belief_states: - p_ = p[slot] - labs = batch['state_labels-' + slot].to(device) - - belief_states[slot].append(p_) - state_labels[slot].append(labs) - - if p_req is not None: - for slot in request_probs: - p_ = p_req[slot] - labs = batch['request_labels-' + slot].to(device) - - request_probs[slot].append(p_) - request_labels[slot].append(labs) - - for domain in active_domain_probs: - p_ = p_dom[domain] - labs = batch['active_domain_labels-' + domain].to(device) - - active_domain_probs[domain].append(p_) - active_domain_labels[domain].append(labs) - - general_act_probs.append(p_gen) - general_act_labels.append(batch['general_act_labels'].to(device)) - - for slot in belief_states: - belief_states[slot] = torch.cat(belief_states[slot], 0) - state_labels[slot] = torch.cat(state_labels[slot], 0) - if p_req is not None: - for slot in request_probs: - request_probs[slot] = torch.cat(request_probs[slot], 0) - request_labels[slot] = torch.cat(request_labels[slot], 0) - for domain in active_domain_probs: - active_domain_probs[domain] = torch.cat(active_domain_probs[domain], 0) - active_domain_labels[domain] = torch.cat(active_domain_labels[domain], 0) - general_act_probs = torch.cat(general_act_probs, 0) - general_act_labels = torch.cat(general_act_labels, 0) - else: - request_probs, request_labels, active_domain_probs, active_domain_labels = [None] * 4 - general_act_probs, general_act_labels = [None] * 2 - - out = (belief_states, state_labels, request_probs, request_labels) - out += (active_domain_probs, active_domain_labels, general_act_probs, general_act_labels) - return out diff --git a/convlab/dst/setsumbt/modeling/functional.py b/convlab/dst/setsumbt/modeling/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..0dd083d0da080ca089e81d9ae01e5f0954243f61 --- /dev/null +++ 
b/convlab/dst/setsumbt/modeling/functional.py @@ -0,0 +1,456 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SetSUMBT functionals""" + +import torch +import transformers +from torch.autograd import Variable +from torch.nn import (MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout, + CosineSimilarity, CrossEntropyLoss, PairwiseDistance, + Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss) +from torch.nn.init import (xavier_normal_, constant_) + +from convlab.dst.setsumbt.loss.bayesian import BayesianMatchingLoss, BinaryBayesianMatchingLoss, dirichlet +from convlab.dst.setsumbt.loss.labelsmoothing import LabelSmoothingLoss, BinaryLabelSmoothingLoss +from convlab.dst.setsumbt.loss.distillation import DistillationKL, BinaryDistillationKL +from convlab.dst.setsumbt.loss.endd_loss import rkl_dirichlet_mediator_loss, logits_to_mutual_info + + +# Default belief tracker model initialisation function +def _initialise(self, config): + # Slot Utterance matching attention + self.slot_attention = MultiheadAttention( + config.hidden_size, config.slot_attention_heads) + + # Latent context tracker + # Initial state prediction + if not config.rnn_zero_init and config.nbt_type in ['gru', 'lstm']: + self.belief_init = Sequential(Linear(config.hidden_size, config.nbt_hidden_size), + ReLU(), Dropout(config.dropout_rate)) + + # Recurrent context tracker setup + if config.nbt_type == 'gru': + self.nbt = GRU(input_size=config.hidden_size, + hidden_size=config.nbt_hidden_size, + num_layers=config.nbt_layers, + dropout=0.0 if config.nbt_layers == 1 else config.dropout_rate, + batch_first=True) + # Initialise Parameters + xavier_normal_(self.nbt.weight_ih_l0) + xavier_normal_(self.nbt.weight_hh_l0) + constant_(self.nbt.bias_ih_l0, 0.0) + constant_(self.nbt.bias_hh_l0, 0.0) + elif config.nbt_type == 'lstm': + self.nbt = LSTM(input_size=config.hidden_size, + hidden_size=config.nbt_hidden_size, + num_layers=config.nbt_layers, + dropout=0.0 if config.nbt_layers == 1 else config.dropout_rate, + batch_first=True) + # Initialise Parameters + xavier_normal_(self.nbt.weight_ih_l0) + xavier_normal_(self.nbt.weight_hh_l0) + constant_(self.nbt.bias_ih_l0, 0.0) + constant_(self.nbt.bias_hh_l0, 0.0) + else: + raise NameError('Not Implemented') + + # Feature decoder and layer norm + self.intermediate = Linear(config.nbt_hidden_size, config.hidden_size) + self.layer_norm = LayerNorm(config.hidden_size) + + # Dropout + self.dropout = Dropout(config.dropout_rate) + + # Set pooler for set similarity model + if self.config.set_similarity: + # 1D convolutional set pooler + if self.config.set_pooling == 'cnn': + self.conv_pooler = Conv1d( + self.config.hidden_size, self.config.hidden_size, 3) + # Deep averaging network set pooler + elif self.config.set_pooling == 'dan': + self.avg_net = Sequential(Linear(self.config.hidden_size, 2 * self.config.hidden_size), GELU(), + Linear(2 *
self.config.hidden_size, self.config.hidden_size)) + + # Model ontology placeholders + self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False) + self.slot_ids = dict() + self.requestable_slot_ids = dict() + self.informable_slot_ids = dict() + self.domain_ids = dict() + + # Matching network similarity measure + if config.distance_measure == 'cosine': + self.distance = CosineSimilarity(dim=-1, eps=1e-8) + elif config.distance_measure == 'euclidean': + self.distance = PairwiseDistance(p=2.0, eps=1e-06, keepdim=False) + else: + raise NameError('NotImplemented') + + # Belief state loss function + if config.loss_function == 'crossentropy': + self.loss = CrossEntropyLoss(ignore_index=-1) + elif config.loss_function == 'bayesianmatching': + self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) + elif config.loss_function == 'labelsmoothing': + self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing) + elif config.loss_function == 'distillation': + self.loss = DistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing) + self.temp = 1.0 + elif config.loss_function == 'distribution_distillation': + self.loss = rkl_dirichlet_mediator_loss + self.temp = 1.0 + else: + raise NameError('NotImplemented') + + # Intent and domain prediction heads + if config.predict_actions: + self.request_gate = Linear(config.hidden_size, 1) + self.goodbye_gate = Linear(config.hidden_size, 3) + self.domain_gate = Linear(config.hidden_size, 1) + + # Intent and domain loss function + self.request_weight = float(self.config.user_request_loss_weight) + self.goodbye_weight = float(self.config.user_general_act_loss_weight) + self.domain_weight = float(self.config.active_domain_loss_weight) + if config.loss_function == 'crossentropy': + self.request_loss = BCEWithLogitsLoss() + self.goodbye_loss = CrossEntropyLoss(ignore_index=-1) + self.domain_loss = BCEWithLogitsLoss() + elif config.loss_function == 'labelsmoothing': + self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing) + self.goodbye_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing) + self.domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing) + elif config.loss_function == 'bayesianmatching': + self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) + self.goodbye_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) + self.domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) + elif config.loss_function == 'distillation': + self.request_loss = BinaryDistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing) + self.goodbye_loss = DistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing) + self.domain_loss = BinaryDistillationKL(ignore_index=-1, lamb=config.ensemble_smoothing) + + +# Default belief tracker forward pass. 
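+# Expected shapes, inferred from the reshapes in this function (a sketch of the contract, not an upstream guarantee): +# turn_embeddings: [batch_size * dialogue_size, turn_size, hidden_size] +# attention_mask: [batch_size * dialogue_size, turn_size, *] +# Returns (inform_probs, request_probs, domain_probs, goodbye_probs, context), prefixed with the loss and suffixed +# with a stats dict when any labels are given, with slot-level mutual information appended when +# calculate_inform_mutual_info is set.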
+def _nbt_forward(self, turn_embeddings, + turn_pooled_representation, + attention_mask, + batch_size, + dialogue_size, + turn_size, + hidden_state, + inform_labels, + request_labels, + domain_labels, + goodbye_labels, + calculate_inform_mutual_info): + hidden_size = turn_embeddings.size(-1) + # Initialise loss + loss = 0.0 + + # Goodbye predictions + goodbye_probs = None + if self.config.predict_actions: + # General action prediction + goodbye_scores = self.goodbye_gate( + turn_pooled_representation.reshape(batch_size * dialogue_size, hidden_size)) + + # Compute loss for general action predictions (weighted loss) + if goodbye_labels is not None: + if self.config.loss_function == 'distillation': + goodbye_labels = goodbye_labels.reshape(-1, goodbye_labels.size(-1)) + loss += self.goodbye_loss(goodbye_scores, goodbye_labels, self.temp) * self.goodbye_weight + elif self.config.loss_function == 'distribution_distillation': + goodbye_labels = goodbye_labels.reshape(-1, goodbye_labels.size(-2), goodbye_labels.size(-1)) + loss += self.loss(goodbye_scores, goodbye_labels, 1.0, 1.0)[0] * self.goodbye_weight + else: + goodbye_labels = goodbye_labels.reshape(-1) + loss += self.goodbye_loss(goodbye_scores, goodbye_labels) * self.goodbye_weight + + # Compute general action probabilities + if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'distillation', 'distribution_distillation']: + goodbye_probs = torch.softmax(goodbye_scores, -1).reshape(batch_size, dialogue_size, -1) + elif self.config.loss_function in ['bayesianmatching']: + goodbye_probs = dirichlet(goodbye_scores.reshape(batch_size, dialogue_size, -1)) + + # Slot utterance matching + num_slots = self.slot_embeddings.size(0) + slot_embeddings = self.slot_embeddings.reshape(-1, hidden_size) + slot_embeddings = slot_embeddings.unsqueeze(1).repeat((1, batch_size * dialogue_size, 1)).to(turn_embeddings.device) + + if self.config.set_similarity: + # Slot mask shape [num_slots * slot_len, batch_size * dialogue_size, 768] + slot_mask = (slot_embeddings != 0.0).float() + + # Turn embeddings shape [turn_size, batch_size * dialogue_size, 768] + turn_embeddings = turn_embeddings.transpose(0, 1) + # Compute key padding mask + key_padding_mask = (attention_mask[:, :, 0] == 0.0) + key_padding_mask[key_padding_mask[:, 0], :] = False + # Multi-head attention of slots over tokens + hidden, _ = self.slot_attention(query=slot_embeddings, + key=turn_embeddings, + value=turn_embeddings, + key_padding_mask=key_padding_mask) # [num_slots, batch_size * dialogue_size, 768] + + # Set embeddings for all masked tokens to 0 + attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1)) + hidden = hidden * attention_mask + if self.config.set_similarity: + hidden = hidden * slot_mask + # Hidden layer shape [num_dials, num_slots, num_turns, 768] + hidden = hidden.transpose(0, 1).reshape(batch_size, dialogue_size, slot_embeddings.size(0), -1).transpose(1, 2) + + # Latent context tracking + # [batch_size * num_slots, dialogue_size, 768] + hidden = hidden.reshape(batch_size * slot_embeddings.size(0), dialogue_size, -1) + + if self.config.nbt_type == 'gru': + self.nbt.flatten_parameters() + if hidden_state is None: + if self.config.rnn_zero_init: + context = torch.zeros(self.config.nbt_layers, batch_size * slot_embeddings.size(0), + self.config.nbt_hidden_size) + context = context.to(turn_embeddings.device) + else: + context = self.belief_init(hidden[:, 0, :]).unsqueeze(0).repeat((self.config.nbt_layers, 1, 1)) + else: +
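# Reuse the latent belief state carried over from the previous turn (e.g. during turn-by-turn inference) as the RNN initial state. +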
context = hidden_state.to(hidden.device) + + # [batch_size, dialogue_size, nbt_hidden_size] + belief_embedding, context = self.nbt(hidden, context) + elif self.config.nbt_type == 'lstm': + self.nbt.flatten_parameters() + if self.config.rnn_zero_init: + context = (torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size), + torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size)) + context = (context[0].to(turn_embeddings.device), + context[1].to(turn_embeddings.device)) + else: + context = (self.belief_init(hidden[:, 0, :]).unsqueeze(0).repeat((self.config.nbt_layers, 1, 1)), + torch.zeros(self.config.nbt_layers, batch_size * num_slots, self.config.nbt_hidden_size)) + context = (context[0], context[1].to(turn_embeddings.device)) + + # [batch_size, dialogue_size, nbt_hidden_size] + belief_embedding, context = self.nbt(hidden, context) + + # Decode features + belief_embedding = belief_embedding.reshape(batch_size, slot_embeddings.size(0), dialogue_size, -1).transpose(1, 2) + if self.config.set_similarity: + belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1, + self.config.nbt_hidden_size) + # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768] + # Normalisation and regularisation + belief_embedding = self.layer_norm(self.intermediate(belief_embedding)) + belief_embedding = self.dropout(belief_embedding) + + # Pooling of the set of latent context representation + if self.config.set_similarity: + slot_mask = slot_mask.transpose(0, 1).reshape(batch_size, dialogue_size, num_slots, -1, hidden_size) + belief_embedding = belief_embedding * slot_mask + + # Apply pooler to latent context sequence + if self.config.set_pooling == 'mean': + belief_embedding = belief_embedding.sum(-2) / slot_mask.sum(-2) + belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1) + elif self.config.set_pooling == 'cnn': + belief_embedding = belief_embedding.reshape(-1, slot_mask.size(-2), hidden_size).transpose(1, 2) + belief_embedding = self.conv_pooler(belief_embedding) + # Mean pooling after CNN + belief_embedding = belief_embedding.mean(-1).reshape(batch_size, dialogue_size, num_slots, -1) + elif self.config.set_pooling == 'dan': + # sqrt N reduction + belief_embedding = belief_embedding.sum(-2) / torch.sqrt(torch.tensor(slot_mask.sum(-2))) + # Deep averaging feature extractor + belief_embedding = self.avg_net(belief_embedding) + belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1) + + # Perform classification + if self.config.predict_actions: + # User request prediction + request_probs = dict() + for slot, slot_id in self.requestable_slot_ids.items(): + request_scores = self.request_gate(belief_embedding[:, :, slot_id, :]) + + # Store output probabilities + request_scores = request_scores.reshape(batch_size, dialogue_size) + mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size) + batches, dialogues = torch.where(mask == 0.0) + # Set request scores to 0.0 for padded turns + request_scores[batches, dialogues] = 0.0 + if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'bayesianmatching', + 'distillation', 'distribution_distillation']: + request_probs[slot] = torch.sigmoid(request_scores) + + if request_labels is not None: + # Compute request gate loss + request_scores = request_scores.reshape(-1) + if self.config.loss_function == 'distillation': + loss += self.request_loss(request_scores, request_labels[slot].reshape(-1), + 
self.temp) * self.request_weight + elif self.config.loss_function == 'distribution_distillation': + scores, labs = convert_probs_to_logits(request_scores, request_labels[slot]) + loss += self.loss(scores, labs, 1.0, 1.0)[0] * self.request_weight + else: + labs = request_labels[slot].reshape(-1) + request_scores = request_scores[labs != -1] + labs = labs[labs != -1].float() + loss += self.request_loss(request_scores, labs) * self.request_weight + + # Active domain prediction + domain_probs = dict() + for domain, slot_ids in self.domain_ids.items(): + belief = belief_embedding[:, :, slot_ids, :] + if len(slot_ids) > 1: + # SqrtN reduction across all slots within a domain + belief = belief.sum(2) / ((belief != 0.0).float().sum(2) ** 0.5) + domain_scores = self.domain_gate(belief) + + # Store output probabilities + domain_scores = domain_scores.reshape(batch_size, dialogue_size) + mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size) + batches, dialogues = torch.where(mask == 0.0) + domain_scores[batches, dialogues] = 0.0 + if self.config.loss_function in ['crossentropy', 'labelsmoothing', 'bayesianmatching', 'distillation', + 'distribution_distillation']: + domain_probs[domain] = torch.sigmoid(domain_scores) + + if domain_labels is not None: + # Compute domain prediction loss + domain_scores = domain_scores.reshape(-1) + if self.config.loss_function == 'distillation': + loss += self.domain_loss(domain_scores, domain_labels[domain].reshape(-1), + self.temp) * self.domain_weight + elif self.config.loss_function == 'distribution_distillation': + scores, labs = convert_probs_to_logits(domain_scores, domain_labels[domain]) + loss += self.loss(scores, labs, 1.0, 1.0)[0] * self.domain_weight + else: + labs = domain_labels[domain].reshape(-1) + domain_scores = domain_scores[labs != -1] + labs = labs[labs != -1].float() + loss += self.domain_loss(domain_scores, labs) * self.domain_weight + else: + request_probs, domain_probs = None, None + + # Informable slot predictions + inform_probs = dict() + out_dict = dict() + mutual_info = dict() + stats = dict() + for slot, slot_id in self.informable_slot_ids.items(): + # Get slot belief embedding and value candidates + candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device) + belief = belief_embedding[:, :, slot_id, :] + slot_size = candidate_embeddings.size(0) + + # Use similarity matching to produce belief state + if self.config.distance_measure in ['cosine', 'euclidean']: + belief = belief.unsqueeze(2).repeat((1, 1, slot_size, 1)) + belief = belief.reshape(-1, self.config.hidden_size) + + # Pooling of set of value candidate description representations + if self.config.set_similarity and self.config.set_pooling == 'mean': + candidate_mask = (candidate_embeddings != 0.0).float() + candidate_embeddings = candidate_embeddings.sum(1) / candidate_mask.sum(1) + elif self.config.set_similarity and self.config.set_pooling == 'cnn': + candidate_embeddings = candidate_embeddings.transpose(1, 2) + candidate_embeddings = self.conv_pooler(candidate_embeddings).mean(-1) + elif self.config.set_similarity and self.config.set_pooling == 'dan': + candidate_mask = (candidate_embeddings != 0.0).float() + candidate_embeddings = candidate_embeddings.sum(1) / torch.sqrt(torch.tensor(candidate_mask.sum(1))) + candidate_embeddings = self.avg_net(candidate_embeddings) + + candidate_embeddings = candidate_embeddings.unsqueeze(0).unsqueeze(0).repeat((batch_size, + dialogue_size, 1, 1)) + candidate_embeddings =
candidate_embeddings.reshape(-1, self.config.hidden_size) + + # Score value candidates + if self.config.distance_measure == 'cosine': + scores = self.distance(belief, candidate_embeddings) + # *27 here rescales the cosine similarity for better learning + scores = scores.reshape(batch_size * dialogue_size, -1) * 27.0 + elif self.config.distance_measure == 'euclidean': + scores = -1.0 * self.distance(belief, candidate_embeddings) + scores = scores.reshape(batch_size * dialogue_size, -1) + + # Calculate belief state + if self.config.loss_function in ['crossentropy', 'inhibitedce', + 'labelsmoothing', 'distillation', 'distribution_distillation']: + probs_ = torch.softmax(scores.reshape(batch_size, dialogue_size, -1), -1) + elif self.config.loss_function in ['bayesianmatching']: + probs_ = dirichlet(scores.reshape(batch_size, dialogue_size, -1)) + + # Compute knowledge uncertainty in the belief states + if calculate_inform_mutual_info and self.config.loss_function == 'distribution_distillation': + mutual_info[slot] = logits_to_mutual_info(scores).reshape(batch_size, dialogue_size) + + # Set padded turn probabilities to zero + mask = attention_mask[self.slot_ids[slot], :, 0].reshape(batch_size, dialogue_size) + batches, dialogues = torch.where(mask == 0.0) + probs_[batches, dialogues, :] = 0.0 + inform_probs[slot] = probs_ + + # Calculate belief state loss + if inform_labels is not None and slot in inform_labels: + if self.config.loss_function == 'bayesianmatching': + prior = torch.ones(scores.size(-1)).float().to(scores.device) + prior = prior * self.config.prior_constant + prior = prior.unsqueeze(0).repeat((scores.size(0), 1)) + + loss += self.loss(scores, inform_labels[slot].reshape(-1), prior=prior) + elif self.config.loss_function == 'distillation': + labels = inform_labels[slot] + labels = labels.reshape(-1, labels.size(-1)) + loss += self.loss(scores, labels, self.temp) + elif self.config.loss_function == 'distribution_distillation': + labels = inform_labels[slot] + labels = labels.reshape(-1, labels.size(-2), labels.size(-1)) + loss_, model_stats, ensemble_stats = self.loss(scores, labels, 1.0, 1.0) + loss += loss_ + + # Calculate stats regarding model precisions + precision = model_stats['precision'] + ensemble_precision = ensemble_stats['precision'] + stats[slot] = {'model_precision_min': precision.min(), + 'model_precision_max': precision.max(), + 'model_precision_mean': precision.mean(), + 'ensemble_precision_min': ensemble_precision.min(), + 'ensemble_precision_max': ensemble_precision.max(), + 'ensemble_precision_mean': ensemble_precision.mean()} + else: + loss += self.loss(scores, inform_labels[slot].reshape(-1)) + + # Return model outputs + out = inform_probs, request_probs, domain_probs, goodbye_probs, context + if inform_labels is not None or request_labels is not None or domain_labels is not None or goodbye_labels is not None: + out = (loss,) + out + (stats,) + if calculate_inform_mutual_info: + out = out + (mutual_info,) + return out + + +# Convert binary scores and labels to a two-class classification problem for distribution distillation +def convert_probs_to_logits(scores, labels): + # Convert single target probability p to distribution [1-p, p] + labels = labels.reshape(-1, labels.size(-1), 1) + labels = torch.cat([1 - labels, labels], -1) + + # Convert input scores into predictive distribution [1-z, z] + scores = torch.sigmoid(scores).unsqueeze(1) + scores = torch.cat((1 - scores, scores), 1) + scores = -1.0 * torch.log((1 / (scores + 1e-8)) - 1) # Inverse sigmoid + + return
scores, labels diff --git a/convlab/dst/setsumbt/modeling/roberta_nbt.py b/convlab/dst/setsumbt/modeling/roberta_nbt.py index f72d17fafa50553434b6d4dcd20b8e53d143892f..36920c5ca550a3295f31aaca53c2ebed8c22be37 100644 --- a/convlab/dst/setsumbt/modeling/roberta_nbt.py +++ b/convlab/dst/setsumbt/modeling/roberta_nbt.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,19 +16,16 @@ """RoBERTa SetSUMBT""" import torch +import transformers +from torch.autograd import Variable from transformers import RobertaModel, RobertaPreTrainedModel -from convlab.dst.setsumbt.modeling.setsumbt import SetSUMBTHead +from convlab.dst.setsumbt.modeling.functional import _initialise, _nbt_forward class RobertaSetSUMBT(RobertaPreTrainedModel): - """Roberta based SetSUMBT model""" def __init__(self, config): - """ - Args: - config (configuration): Model configuration class - """ super(RobertaSetSUMBT, self).__init__(config) self.config = config @@ -38,37 +35,60 @@ class RobertaSetSUMBT(RobertaPreTrainedModel): for p in self.roberta.parameters(): p.requires_grad = False - self.setsumbt = SetSUMBTHead(config) - self.add_slot_candidates = self.setsumbt.add_slot_candidates - self.add_value_candidates = self.setsumbt.add_value_candidates + _initialise(self, config) - def forward(self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor = None, - hidden_state: torch.Tensor = None, - state_labels: torch.Tensor = None, - request_labels: torch.Tensor = None, - active_domain_labels: torch.Tensor = None, - general_act_labels: torch.Tensor = None, - get_turn_pooled_representation: bool = False, - calculate_state_mutual_info: bool = False): - """ - Args: - input_ids: Input token ids - attention_mask: Input padding mask - token_type_ids: Token type indicator - hidden_state: Latent internal dialogue belief state - state_labels: Dialogue state labels - request_labels: User request action labels - active_domain_labels: Current active domain labels - general_act_labels: General user action labels - get_turn_pooled_representation: Return pooled representation of the current dialogue turn - calculate_state_mutual_info: Return mutual information in the dialogue state - Returns: - out: Tuple containing loss, predictive distributions, model statistics and state mutual information - """ + # Add new slot candidates to the model + def add_slot_candidates(self, slot_candidates): + """slot_candidates is a dictionary containing a tuple for each slot. + - Each tuple contains the slot embedding, informable value embeddings and a request indicator.
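+ - e.g. {'hotel-area': (slot_embedding, value_embeddings, True)} (a hypothetical entry for illustration; real keys and tensors come from the ontology)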
+ - If the informable value embeddings is None, the slot is not informable + - If the request indicator is False, the slot is not requestable""" + if self.slot_embeddings.size(0) != 0: + embeddings = self.slot_embeddings.detach() + else: + embeddings = torch.zeros(0) + + for slot in slot_candidates: + if slot in self.slot_ids: + index = self.slot_ids[slot] + embeddings[index, :] = slot_candidates[slot][0] + else: + index = embeddings.size(0) + emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device) + embeddings = torch.cat((embeddings, emb), 0) + self.slot_ids[slot] = index + setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False)) + # Add slot to relevant requestable and informable slot lists + if slot_candidates[slot][2]: + self.requestable_slot_ids[slot] = index + if slot_candidates[slot][1] is not None: + self.informable_slot_ids[slot] = index + + domain = slot.split('-', 1)[0] + if domain not in self.domain_ids: + self.domain_ids[domain] = [] + self.domain_ids[domain].append(index) + self.domain_ids[domain] = list(set(self.domain_ids[domain])) + + self.slot_embeddings = Variable(embeddings, requires_grad=False) + + + # Add new value candidates to the model + def add_value_candidates(self, slot, value_candidates, replace=False): + embeddings = getattr(self, slot + '_value_embeddings') + + if embeddings.size(0) == 0 or replace: + embeddings = value_candidates + else: + embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0) + + setattr(self, slot + '_value_embeddings', embeddings) + + + def forward(self, input_ids, attention_mask, token_type_ids=None, hidden_state=None, inform_labels=None, + request_labels=None, domain_labels=None, goodbye_labels=None, + get_turn_pooled_representation=False, calculate_inform_mutual_info=False): if token_type_ids is not None: token_type_ids = None @@ -86,10 +106,9 @@ class RobertaSetSUMBT(RobertaPreTrainedModel): turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1) if get_turn_pooled_representation: - return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, - batch_size, dialogue_size, hidden_state, state_labels, - request_labels, active_domain_labels, general_act_labels, - calculate_state_mutual_info) + (roberta_output.pooler_output,) - return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, - dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels, - general_act_labels, calculate_state_mutual_info) + return _nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, dialogue_size, + turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels, + calculate_inform_mutual_info) + (roberta_output.pooler_output,) + return _nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, dialogue_size, + turn_size, hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels, + calculate_inform_mutual_info) diff --git a/convlab/dst/setsumbt/modeling/setsumbt.py b/convlab/dst/setsumbt/modeling/setsumbt.py deleted file mode 100644 index 0249649f0840d66b0cec8a65c91aded906f62f85..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/modeling/setsumbt.py +++ /dev/null @@ -1,564 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf -# Authors: Carel van Niekerk (niekerk@hhu.de) -# -# Licensed under the Apache License,
Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""SetSUMBT Prediction Head""" - -import torch -from torch.autograd import Variable -from torch.nn import (Module, MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout, - CosineSimilarity, CrossEntropyLoss, PairwiseDistance, - Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss) -from torch.nn.init import (xavier_normal_, constant_) - -from convlab.dst.setsumbt.loss import (BayesianMatchingLoss, BinaryBayesianMatchingLoss, - KLDistillationLoss, BinaryKLDistillationLoss, - LabelSmoothingLoss, BinaryLabelSmoothingLoss, - RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss) - - -class SlotUtteranceMatching(Module): - """Slot Utterance matching attention based information extractor""" - - def __init__(self, hidden_size: int = 768, attention_heads: int = 12): - """ - Args: - hidden_size (int): Dimension of token embeddings - attention_heads (int): Number of attention heads to use in attention module - """ - super(SlotUtteranceMatching, self).__init__() - - self.attention = MultiheadAttention(hidden_size, attention_heads) - - def forward(self, - turn_embeddings: torch.Tensor, - attention_mask: torch.Tensor, - slot_embeddings: torch.Tensor) -> torch.Tensor: - """ - Args: - turn_embeddings: Embeddings for each token in each turn [n_turns, turn_length, hidden_size] - attention_mask: Padding mask for each turn [n_turns, turn_length, hidden_size] - slot_embeddings: Embeddings for each token in the slot descriptions - - Returns: - hidden: Information extracted from turn related to slot descriptions - """ - turn_embeddings = turn_embeddings.transpose(0, 1) - - key_padding_mask = (attention_mask[:, :, 0] == 0.0) - key_padding_mask[key_padding_mask[:, 0], :] = False - - hidden, _ = self.attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings, - key_padding_mask=key_padding_mask) - - attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1)) - hidden = hidden * attention_mask - - return hidden - - -class RecurrentNeuralBeliefTracker(Module): - """Recurrent latent neural belief tracking module""" - - def __init__(self, - nbt_type: str = 'gru', - rnn_zero_init: bool = False, - input_size: int = 768, - hidden_size: int = 300, - hidden_layers: int = 1, - dropout_rate: float = 0.3): - """ - Args: - nbt_type: Type of recurrent neural network (gru/lstm) - rnn_zero_init: Use zero initialised state for the RNN - input_size: Embedding size of the inputs - hidden_size: Hidden size of the RNN - hidden_layers: Number of RNN Layers - dropout_rate: Dropout rate - """ - super(RecurrentNeuralBeliefTracker, self).__init__() - - if rnn_zero_init: - self.belief_init = Sequential(Linear(input_size, hidden_size), ReLU(), Dropout(dropout_rate)) - else: - self.belief_init = None - - self.nbt_type = nbt_type - self.hidden_layers = hidden_layers - self.hidden_size = hidden_size - if nbt_type == 'gru': - self.nbt = GRU(input_size=input_size, - hidden_size=hidden_size, - num_layers=hidden_layers, - dropout=0.0 if hidden_layers == 1 else dropout_rate, - 
batch_first=True) - elif nbt_type == 'lstm': - self.nbt = LSTM(input_size=input_size, - hidden_size=hidden_size, - num_layers=hidden_layers, - dropout=0.0 if hidden_layers == 1 else dropout_rate, - batch_first=True) - else: - raise NameError('Not Implemented') - - # Initialise Parameters - xavier_normal_(self.nbt.weight_ih_l0) - xavier_normal_(self.nbt.weight_hh_l0) - constant_(self.nbt.bias_ih_l0, 0.0) - constant_(self.nbt.bias_hh_l0, 0.0) - - # Intermediate feature mapping and layer normalisation - self.intermediate = Linear(hidden_size, input_size) - self.layer_norm = LayerNorm(input_size) - self.dropout = Dropout(dropout_rate) - - def forward(self, inputs: torch.Tensor, hidden_state: torch.Tensor = None) -> torch.Tensor: - """ - Args: - inputs: Latent turn level information - hidden_state: Latent internal belief state - - Returns: - belief_embedding: Belief state embeddings - context: Latent internal belief state - """ - self.nbt.flatten_parameters() - if hidden_state is None: - if self.belief_init is None: - context = torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device) - else: - context = self.belief_init(inputs[:, 0, :]).unsqueeze(0).repeat((self.hidden_layers, 1, 1)) - if self.nbt_type == "lstm": - context = (context, torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device)) - else: - context = hidden_state.to(inputs.device) - - # [batch_size, dialogue_size, nbt_hidden_size] - belief_embedding, context = self.nbt(inputs, context) - - # Normalisation and regularisation - belief_embedding = self.layer_norm(self.intermediate(belief_embedding)) - belief_embedding = self.dropout(belief_embedding) - - return belief_embedding, context - - -class SetPooler(Module): - """Token set pooler""" - - def __init__(self, pooling_strategy: str = 'cnn', hidden_size: int = 768): - """ - Args: - pooling_strategy: Type of set pooler (cnn/dan/mean) - hidden_size: Token embedding size - """ - super(SetPooler, self).__init__() - - self.pooling_strategy = pooling_strategy - if pooling_strategy == 'cnn': - self.cnn_filter_size = 3 - self.pooler = Conv1d(hidden_size, hidden_size, self.cnn_filter_size) - elif pooling_strategy == 'dan': - self.pooler = Sequential(Linear(hidden_size, hidden_size), GELU(), Linear(2 * hidden_size, hidden_size)) - - def forward(self, inputs, attention_mask): - """ - Args: - inputs: Token set embeddings - attention_mask: Padding mask for the set of tokens - - Returns: - - """ - if self.pooling_strategy == "mean": - hidden = inputs.sum(1) / attention_mask.sum(1) - elif self.pooling_strategy == "cnn": - hidden = self.pooler(inputs.transpose(1, 2)).mean(-1) - elif self.pooling_strategy == 'dan': - hidden = inputs.sum(1) / torch.sqrt(torch.tensor(attention_mask.sum(1))) - hidden = self.pooler(hidden) - - return hidden - - -class SetSUMBTHead(Module): - """SetSUMBT Prediction Head for Language Models""" - - def __init__(self, config): - """ - Args: - config (configuration): Model configuration class - """ - super(SetSUMBTHead, self).__init__() - self.config = config - # Slot Utterance matching attention - self.slot_utterance_matching = SlotUtteranceMatching(config.hidden_size, config.slot_attention_heads) - - # Latent context tracker - self.nbt = RecurrentNeuralBeliefTracker(config.nbt_type, config.rnn_zero_init, config.hidden_size, - config.nbt_hidden_size, config.nbt_layers, config.dropout_rate) - - # Set pooler for set similarity model - if self.config.set_similarity: - self.set_pooler = SetPooler(config.set_pooling, 
config.hidden_size) - - # Model ontology placeholders - self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False) - self.slot_ids = dict() - self.requestable_slot_ids = dict() - self.informable_slot_ids = dict() - self.domain_ids = dict() - - # Matching network similarity measure - if config.distance_measure == 'cosine': - self.distance = CosineSimilarity(dim=-1, eps=1e-8) - elif config.distance_measure == 'euclidean': - self.distance = PairwiseDistance(p=2.0, eps=1e-6, keepdim=False) - else: - raise NameError('NotImplemented') - - # User goal prediction loss function - if config.loss_function == 'crossentropy': - self.loss = CrossEntropyLoss(ignore_index=-1) - elif config.loss_function == 'bayesianmatching': - self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - elif config.loss_function == 'labelsmoothing': - self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing) - elif config.loss_function == 'distillation': - self.loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) - self.temp = 1.0 - elif config.loss_function == 'distribution_distillation': - self.loss = RKLDirichletMediatorLoss(ignore_index=-1) - else: - raise NameError('NotImplemented') - - # Intent and domain prediction heads - if config.predict_actions: - self.request_gate = Linear(config.hidden_size, 1) - self.general_act_gate = Linear(config.hidden_size, 3) - self.active_domain_gate = Linear(config.hidden_size, 1) - - # Intent and domain loss function - self.request_weight = float(self.config.user_request_loss_weight) - self.general_act_weight = float(self.config.user_general_act_loss_weight) - self.active_domain_weight = float(self.config.active_domain_loss_weight) - if config.loss_function == 'crossentropy': - self.request_loss = BCEWithLogitsLoss() - self.general_act_loss = CrossEntropyLoss(ignore_index=-1) - self.active_domain_loss = BCEWithLogitsLoss() - elif config.loss_function == 'labelsmoothing': - self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing) - self.general_act_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing) - self.active_domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing) - elif config.loss_function == 'bayesianmatching': - self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - self.general_act_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - self.active_domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor) - elif config.loss_function == 'distillation': - self.request_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) - self.general_act_loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) - self.active_domain_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing) - elif config.loss_function == 'distribution_distillation': - self.request_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1) - self.general_act_loss = RKLDirichletMediatorLoss(ignore_index=-1) - self.active_domain_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1) - - def add_slot_candidates(self, slot_candidates: tuple): - """ - Add slots to the model ontology, the tuples should contain the slot embedding, informable value embeddings - and a request indicator, if the informable value embeddings is None the slot is not informable and if - the request indicator is false the 
slot is not requestable. - - Args: - slot_candidates: Tuple containing slot embedding, informable value embeddings and a request indicator - """ - if self.slot_embeddings.size(0) != 0: - embeddings = self.slot_embeddings.detach() - else: - embeddings = torch.zeros(0) - - for slot in slot_candidates: - if slot in self.slot_ids: - index = self.slot_ids[slot] - embeddings[index, :] = slot_candidates[slot][0] - else: - index = embeddings.size(0) - emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device) - embeddings = torch.cat((embeddings, emb), 0) - self.slot_ids[slot] = index - setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False)) - # Add slot to relevant requestable and informable slot lists - if slot_candidates[slot][2]: - self.requestable_slot_ids[slot] = index - if slot_candidates[slot][1] is not None: - self.informable_slot_ids[slot] = index - - domain = slot.split('-', 1)[0] - if domain not in self.domain_ids: - self.domain_ids[domain] = [] - self.domain_ids[domain].append(index) - self.domain_ids[domain] = list(set(self.domain_ids[domain])) - - self.slot_embeddings = Variable(embeddings, requires_grad=False) - - def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False): - """ - Add value candidates for a slot - - Args: - slot: Slot name - value_candidates: Value candidate embeddings - replace: If true existing value candidates are replaced - """ - embeddings = getattr(self, slot + '_value_embeddings') - - if embeddings.size(0) == 0 or replace: - embeddings = value_candidates - else: - embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0) - - setattr(self, slot + '_value_embeddings', embeddings) - - def forward(self, - turn_embeddings: torch.Tensor, - turn_pooled_representation: torch.Tensor, - attention_mask: torch.Tensor, - batch_size: int, - dialogue_size: int, - hidden_state: torch.Tensor = None, - state_labels: torch.Tensor = None, - request_labels: torch.Tensor = None, - active_domain_labels: torch.Tensor = None, - general_act_labels: torch.Tensor = None, - calculate_state_mutual_info: bool = False): - """ - Args: - turn_embeddings: Token embeddings in the current turn - turn_pooled_representation: Pooled representation of the current dialogue turn - attention_mask: Padding mask for the current dialogue turn - batch_size: Number of dialogues in the batch - dialogue_size: Number of turns in each dialogue - hidden_state: Latent internal dialogue belief state - state_labels: Dialogue state labels - request_labels: User request action labels - active_domain_labels: Current active domain labels - general_act_labels: General user action labels - calculate_state_mutual_info: Return mutual information in the dialogue state - - Returns: - out: Tuple containing loss, predictive distributions, model statistics and state mutual information - """ - hidden_size = turn_embeddings.size(-1) - # Initialise loss - loss = 0.0 - - # General Action predictions - general_act_probs = None - if self.config.predict_actions: - # General action prediction - general_act_logits = self.general_act_gate(turn_pooled_representation.reshape(batch_size * dialogue_size, - hidden_size)) - - # Compute loss for general action predictions (weighted loss) - if general_act_labels is not None: - if self.config.loss_function == 'distillation': - general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-1)) - loss += self.general_act_loss(general_act_logits, general_act_labels, - self.temp) * 
self.general_act_weight - elif self.config.loss_function == 'distribution_distillation': - general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-2), - general_act_labels.size(-1)) - loss += self.general_act_loss(general_act_logits, general_act_labels)[0] * self.general_act_weight - else: - general_act_labels = general_act_labels.reshape(-1) - loss += self.general_act_loss(general_act_logits, general_act_labels) * self.general_act_weight - - # Compute general action probabilities - general_act_probs = torch.softmax(general_act_logits, -1).reshape(batch_size, dialogue_size, -1) - - # Slot utterance matching - num_slots = self.slot_embeddings.size(0) - slot_embeddings = self.slot_embeddings.reshape(-1, hidden_size) - slot_embeddings = slot_embeddings.unsqueeze(1).repeat((1, batch_size * dialogue_size, 1)) - slot_embeddings = slot_embeddings.to(turn_embeddings.device) - - if self.config.set_similarity: - # Slot mask shape [num_slots * slot_len, batch_size * dialogue_size, 768] - slot_mask = (slot_embeddings != 0.0).float() - - hidden = self.slot_utterance_matching(turn_embeddings, attention_mask, slot_embeddings) - - if self.config.set_similarity: - hidden = hidden * slot_mask - # Hidden layer shape [num_dials, num_slots, num_turns, 768] - hidden = hidden.transpose(0, 1).reshape(batch_size, dialogue_size, slot_embeddings.size(0), -1).transpose(1, 2) - - # Latent context tracking - # [batch_size * num_slots, dialogue_size, 768] - hidden = hidden.reshape(batch_size * slot_embeddings.size(0), dialogue_size, -1) - belief_embedding, hidden_state = self.nbt(hidden, hidden_state) - - belief_embedding = belief_embedding.reshape(batch_size, slot_embeddings.size(0), - dialogue_size, -1).transpose(1, 2) - if self.config.set_similarity: - belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1, - self.config.hidden_size) - # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768] - - # Pooling of the set of latent context representation - if self.config.set_similarity: - slot_mask = slot_mask.transpose(0, 1).reshape(batch_size, dialogue_size, num_slots, -1, hidden_size) - belief_embedding = belief_embedding * slot_mask - - belief_embedding = self.set_pooler(belief_embedding.reshape(-1, slot_mask.size(-2), hidden_size), - slot_mask.reshape(-1, slot_mask.size(-2), hidden_size)) - belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1) - - # Perform classification - # Get padded batch, dialogue idx pairs - batches, dialogues = torch.where(attention_mask[:, 0, 0].reshape(batch_size, dialogue_size) == 0.0) - - if self.config.predict_actions: - # User request prediction - request_probs = dict() - for slot, slot_id in self.requestable_slot_ids.items(): - request_logits = self.request_gate(belief_embedding[:, :, slot_id, :]) - - # Store output probabilities - request_logits = request_logits.reshape(batch_size, dialogue_size) - # Set request scores to 0.0 for padded turns - request_logits[batches, dialogues] = 0.0 - request_probs[slot] = torch.sigmoid(request_logits) - - if request_labels is not None: - # Compute request gate loss - request_logits = request_logits.reshape(-1) - if self.config.loss_function == 'distillation': - loss += self.request_loss(request_logits, request_labels[slot].reshape(-1), - self.temp) * self.request_weight - elif self.config.loss_function == 'distribution_distillation': - loss += self.request_loss(request_logits, request_labels[slot])[0] * self.request_weight - else: - labs = 
request_labels[slot].reshape(-1) - request_logits = request_logits[labs != -1] - labs = labs[labs != -1].float() - loss += self.request_loss(request_logits, labs) * self.request_weight - - # Active domain prediction - active_domain_probs = dict() - for domain, slot_ids in self.domain_ids.items(): - belief = belief_embedding[:, :, slot_ids, :] - if len(slot_ids) > 1: - # SqrtN reduction across all slots within a domain - belief = belief.sum(2) / ((belief != 0.0).float().sum(2) ** 0.5) - active_domain_logits = self.active_domain_gate(belief) - - # Store output probabilities - active_domain_logits = active_domain_logits.reshape(batch_size, dialogue_size) - active_domain_logits[batches, dialogues] = 0.0 - active_domain_probs[domain] = torch.sigmoid(active_domain_logits) - - if active_domain_labels is not None and domain in active_domain_labels: - # Compute domain prediction loss - active_domain_logits = active_domain_logits.reshape(-1) - if self.config.loss_function == 'distillation': - loss += self.active_domain_loss(active_domain_logits, active_domain_labels[domain].reshape(-1), - self.temp) * self.active_domain_weight - elif self.config.loss_function == 'distribution_distillation': - loss += self.active_domain_loss(active_domain_logits, - active_domain_labels[domain])[0] * self.active_domain_weight - else: - labs = active_domain_labels[domain].reshape(-1) - active_domain_logits = active_domain_logits[labs != -1] - labs = labs[labs != -1].float() - loss += self.active_domain_loss(active_domain_logits, labs) * self.active_domain_weight - else: - request_probs, active_domain_probs = None, None - - # Dialogue state predictions - belief_state_probs = dict() - belief_state_mutual_info = dict() - belief_state_stats = dict() - for slot, slot_id in self.informable_slot_ids.items(): - # Get slot belief embedding and value candidates - candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device) - belief = belief_embedding[:, :, slot_id, :] - slot_size = candidate_embeddings.size(0) - - belief = belief.unsqueeze(2).repeat((1, 1, slot_size, 1)) - belief = belief.reshape(-1, self.config.hidden_size) - - if self.config.set_similarity: - candidate_embeddings = self.set_pooler(candidate_embeddings, (candidate_embeddings != 0.0).float()) - candidate_embeddings = candidate_embeddings.unsqueeze(0).unsqueeze(0).repeat((batch_size, - dialogue_size, 1, 1)) - candidate_embeddings = candidate_embeddings.reshape(-1, self.config.hidden_size) - - # Score value candidates - if self.config.distance_measure == 'cosine': - logits = self.distance(belief, candidate_embeddings) - # *27 here rescales the cosine similarity for better learning - logits = logits.reshape(batch_size * dialogue_size, -1) * 27.0 - elif self.config.distance_measure == 'euclidean': - logits = -1.0 * self.distance(belief, candidate_embeddings) - logits = logits.reshape(batch_size * dialogue_size, -1) - - # Calculate belief state - probs_ = torch.softmax(logits.reshape(batch_size, dialogue_size, -1), -1) - - # Compute knowledge uncertainty in the beleif states - if calculate_state_mutual_info and self.config.loss_function == 'distribution_distillation': - belief_state_mutual_info[slot] = self.loss.logits_to_mutual_info(logits).reshape(batch_size, dialogue_size) - - # Set padded turn probabilities to zero - probs_[batches, dialogues, :] = 0.0 - belief_state_probs[slot] = probs_ - - # Calculate belief state loss - if state_labels is not None and slot in state_labels: - if self.config.loss_function == 'bayesianmatching': - 
prior = torch.ones(logits.size(-1)).float().to(logits.device) - prior = prior * self.config.prior_constant - prior = prior.unsqueeze(0).repeat((logits.size(0), 1)) - - loss += self.loss(logits, state_labels[slot].reshape(-1), prior=prior) - elif self.config.loss_function == 'distillation': - labels = state_labels[slot] - labels = labels.reshape(-1, labels.size(-1)) - loss += self.loss(logits, labels, self.temp) - elif self.config.loss_function == 'distribution_distillation': - labels = state_labels[slot] - labels = labels.reshape(-1, labels.size(-2), labels.size(-1)) - loss_, model_stats, ensemble_stats = self.loss(logits, labels) - loss += loss_ - - # Calculate stats regarding model precisions - precision = model_stats['precision'] - ensemble_precision = ensemble_stats['precision'] - belief_state_stats[slot] = {'model_precision_min': precision.min(), - 'model_precision_max': precision.max(), - 'model_precision_mean': precision.mean(), - 'ensemble_precision_min': ensemble_precision.min(), - 'ensemble_precision_max': ensemble_precision.max(), - 'ensemble_precision_mean': ensemble_precision.mean()} - else: - loss += self.loss(logits, state_labels[slot].reshape(-1)) - - # Return model outputs - out = belief_state_probs, request_probs, active_domain_probs, general_act_probs, hidden_state - if state_labels is not None or request_labels is not None: - out = (loss,) + out + (belief_state_stats,) - if calculate_state_mutual_info: - out = out + (belief_state_mutual_info,) - return out diff --git a/convlab/dst/setsumbt/modeling/temperature_scheduler.py b/convlab/dst/setsumbt/modeling/temperature_scheduler.py index 654e83c5d1ad9dc908213cca8967a84893395b04..fab205befe3350c9beb9d81566c813cf00b55cf2 100644 --- a/convlab/dst/setsumbt/modeling/temperature_scheduler.py +++ b/convlab/dst/setsumbt/modeling/temperature_scheduler.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,70 +13,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Linear Temperature Scheduler Class""" - +"""Temperature Scheduler Class""" +import torch # Temp scheduler class for ensemble distillation -class LinearTemperatureScheduler: - """ - Temperature scheduler object used for distribution temperature scheduling in distillation +class TemperatureScheduler: - Attributes: - state (dict): Internal state of scheduler - """ - def __init__(self, - total_steps: int, - base_temp: float = 2.5, - cycle_len: float = 0.1): - """ - Args: - total_steps (int): Total number of training steps - base_temp (float): Starting temperature - cycle_len (float): Fraction of total steps used for scheduling cycle - """ - self.state = dict() + def __init__(self, total_steps, base_temp=2.5, cycle_len=0.1): + self.state = {} self.state['total_steps'] = total_steps self.state['current_step'] = 0 self.state['base_temp'] = base_temp self.state['current_temp'] = base_temp self.state['cycles'] = [int(total_steps * cycle_len / 2), int(total_steps * cycle_len)] - self.state['rate'] = (self.state['base_temp'] - 1.0) / (self.state['cycles'][1] - self.state['cycles'][0]) def step(self): - """ - Update temperature based on the schedule - """ self.state['current_step'] += 1 assert self.state['current_step'] <= self.state['total_steps'] if self.state['current_step'] > self.state['cycles'][0]: if self.state['current_step'] < self.state['cycles'][1]: - self.state['current_temp'] -= self.state['rate'] + rate = (self.state['base_temp'] - 1.0) / (self.state['cycles'][1] - self.state['cycles'][0]) + self.state['current_temp'] -= rate else: self.state['current_temp'] = 1.0 def temp(self): - """ - Get current temperature - - Returns: - temp (float): Current temperature for distribution scaling - """ return float(self.state['current_temp']) def state_dict(self): - """ - Return scheduler state - - Returns: - state (dict): Dictionary format state of the scheduler - """ return self.state - def load_state_dict(self, state_dict: dict): - """ - Load scheduler state from dictionary + def load_state_dict(self, sd): + self.state = sd + + +# if __name__ == "__main__": +# temp_scheduler = TemperatureScheduler(100) +# print(temp_scheduler.state_dict()) + +# temp = [] +# for i in range(100): +# temp.append(temp_scheduler.temp()) +# temp_scheduler.step() + +# temp_scheduler.load_state_dict(temp_scheduler.state_dict()) +# print(temp_scheduler.state_dict()) - Args: - state_dict (dict): Dictionary format state of the scheduler - """ - self.state = state_dict +# print(temp) diff --git a/convlab/dst/setsumbt/modeling/training.py b/convlab/dst/setsumbt/modeling/training.py index a898a42a9f7321942d940e246d24b25f7b15eedc..259c6e1da061ad6800f92e9237324181678459cb 100644 --- a/convlab/dst/setsumbt/modeling/training.py +++ b/convlab/dst/setsumbt/modeling/training.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,19 +13,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Training and evaluation utils""" +"""Training utils""" import random import os import logging -from copy import deepcopy import torch from torch.nn import DataParallel from torch.distributions import Categorical import numpy as np -from transformers import get_linear_schedule_with_warmup -from torch.optim import AdamW +from transformers import AdamW, get_linear_schedule_with_warmup from tqdm import tqdm, trange try: from apex import amp @@ -33,7 +31,7 @@ except: print('Apex not used') from convlab.dst.setsumbt.utils import clear_checkpoints -from convlab.dst.setsumbt.modeling import LinearTemperatureScheduler +from convlab.dst.setsumbt.modeling.temperature_scheduler import TemperatureScheduler # Load logger and tensorboard summary writer @@ -61,124 +59,18 @@ def set_ontology_embeddings(model, slots, load_slots=True): if load_slots: slots = {slot: embs for slot, embs in slots.items()} model.add_slot_candidates(slots) - for slot in model.setsumbt.informable_slot_ids: + for slot in model.informable_slot_ids: model.add_value_candidates(slot, values[slot], replace=True) -def log_info(global_step, loss, jg_acc=None, sl_acc=None, req_f1=None, dom_f1=None, gen_f1=None, stats=None): - """ - Log training statistics. - - Args: - global_step: Number of global training steps completed - loss: Training loss - jg_acc: Joint goal accuracy - sl_acc: Slot accuracy - req_f1: Request prediction F1 score - dom_f1: Active domain prediction F1 score - gen_f1: General action prediction F1 score - stats: Uncertainty measure statistics of model - """ - info = f"{global_step} steps complete, " if type(global_step) == int else "" - if global_step == 'training_complete': - info += f"Training Complete" - info += f"Loss since last update: {loss}. Validation set stats: " - if global_step == 'dev': - info += f"Validation set stats: Loss: {loss}, " - if global_step == 'test': - info += f"Test set stats: Loss: {loss}, " - info += f"Joint Goal Acc: {jg_acc}, Slot Acc: {sl_acc}, " - if req_f1 is not None: - info += f"Request F1 Score: {req_f1}, Active Domain F1 Score: {dom_f1}, " - info += f"General Action F1 Score: {gen_f1}" - logger.info(info) - - if type(global_step) == int: - tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step) - tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step) - if req_f1 is not None: - tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step) - tb_writer.add_scalar('ActiveDomainF1Score/Dev', dom_f1, global_step) - tb_writer.add_scalar('GeneralActionF1Score/Dev', gen_f1, global_step) - tb_writer.add_scalar('Loss/Dev', loss, global_step) - - if stats: - for slot, stats_slot in stats.items(): - for key, item in stats_slot.items(): - tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step) - - -def get_input_dict(batch: dict, - predict_actions: bool, - model_informable_slot_ids: list, - model_requestable_slot_ids: list = None, - model_domain_ids: list = None, - device = 'cpu') -> dict: - """ - Produce model input arguments - - Args: - batch: Batch of data from the dataloader - predict_actions: Model should predict user actions if set true - model_informable_slot_ids: List of model dialogue state slots - model_requestable_slot_ids: List of model requestable slots - model_domain_ids: List of model domains - device: Current torch device in use - - Returns: - input_dict: Dictrionary containing model inputs for the batch - """ - input_dict = dict() - - input_dict['input_ids'] = batch['input_ids'].to(device) - input_dict['token_type_ids'] = 
batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None - input_dict['attention_mask'] = batch['attention_mask'].to(device) if 'attention_mask' in batch else None - - if 'belief_state' in batch: - input_dict['state_labels'] = {slot: batch['belief_state-' + slot].to(device) for slot in model_informable_slot_ids - if ('belief_state-' + slot) in batch} - if predict_actions: - input_dict['request_labels'] = {slot: batch['request_probs-' + slot].to(device) - for slot in model_requestable_slot_ids - if ('request_probs-' + slot) in batch} - input_dict['active_domain_labels'] = {domain: batch['active_domain_probs-' + domain].to(device) - for domain in model_domain_ids - if ('active_domain_probs-' + domain) in batch} - input_dict['general_act_labels'] = batch['general_act_probs'].to(device) - else: - input_dict['state_labels'] = {slot: batch['state_labels-' + slot].to(device) - for slot in model_informable_slot_ids if ('state_labels-' + slot) in batch} - if predict_actions: - input_dict['request_labels'] = {slot: batch['request_labels-' + slot].to(device) - for slot in model_requestable_slot_ids - if ('request_labels-' + slot) in batch} - input_dict['active_domain_labels'] = {domain: batch['active_domain_labels-' + domain].to(device) - for domain in model_domain_ids - if ('active_domain_labels-' + domain) in batch} - input_dict['general_act_labels'] = batch['general_act_labels'].to(device) - - return input_dict - - -def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, slots_dev: dict): - """ - Train the SetSUMBT model. - - Args: - args: Runtime arguments - model: SetSUMBT Model instance to train - device: Torch device to use during training - train_dataloader: Dataloader containing the training data - dev_dataloader: Dataloader containing the validation set data - slots: Model ontology used for training - slots_dev: Model ontology used for evaluating on the validation set - """ +def train(args, model, device, train_dataloader, dev_dataloader, slots, slots_dev, embeddings=None, tokenizer=None): + """Train model!""" # Calculate the total number of training steps to be performed if args.max_training_steps > 0: t_total = args.max_training_steps - args.num_train_epochs = (len(train_dataloader) // args.gradient_accumulation_steps) + 1 - args.num_train_epochs = args.max_training_steps // args.num_train_epochs + args.num_train_epochs = args.max_training_steps // ( + (len(train_dataloader) // args.gradient_accumulation_steps) + 1) else: t_total = (len(train_dataloader) // args.gradient_accumulation_steps) * args.num_train_epochs @@ -196,12 +88,12 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl { "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0, - "lr": args.learning_rate + "lr":args.learning_rate }, ] # Initialise the optimizer - optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False) # Initialise linear lr scheduler num_warmup_steps = int(t_total * args.warmup_proportion) @@ -217,7 +109,8 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl # Set up fp16 and multi gpu usage if args.fp16: - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) + model, optimizer = amp.initialize( + model, optimizer, opt_level=args.fp16_opt_level) if args.n_gpu > 1: model = DataParallel(model) @@ -225,7 +118,7 @@ def 
train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl best_model = {'joint goal accuracy': 0.0, 'request f1 score': 0.0, 'active domain f1 score': 0.0, - 'general act f1 score': 0.0, + 'goodbye act f1 score': 0.0, 'train loss': np.inf} if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')): logger.info("Optimizer loaded from previous run.") @@ -243,27 +136,27 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl model.eval() set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False) - jg_acc, sl_acc, req_f1, dom_f1, gen_f1, _, _ = evaluate(args, model, device, dev_dataloader, is_train=True) + jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader) # Set model back to training mode model.train() model.zero_grad() set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False) else: - jg_acc, req_f1, dom_f1, gen_f1 = 0.0, 0.0, 0.0, 0.0 + jg_acc, req_f1, dom_f1, bye_f1 = 0.0, 0.0, 0.0, 0.0 best_model['joint goal accuracy'] = jg_acc best_model['request f1 score'] = req_f1 best_model['active domain f1 score'] = dom_f1 - best_model['general act f1 score'] = gen_f1 + best_model['goodbye act f1 score'] = bye_f1 # Log training set up - logger.info(f"Device: {device}, Number of GPUs: {args.n_gpu}, FP16 training: {args.fp16}") + logger.info("Device: %s, Number of GPUs: %s, FP16 training: %s" % (device, args.n_gpu, args.fp16)) logger.info("***** Running training *****") - logger.info(f" Num Batches = {len(train_dataloader)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {t_total}") + logger.info(" Num Batches = %d" % len(train_dataloader)) + logger.info(" Num Epochs = %d" % args.num_train_epochs) + logger.info(" Gradient Accumulation steps = %d" % args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d" % t_total) # Initialise training parameters global_step = 0 @@ -280,11 +173,11 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) logger.info(" Continuing training from checkpoint, will skip to saved global_step") - logger.info(f" Continuing training from epoch {epochs_trained}") - logger.info(f" Continuing training from global step {global_step}") - logger.info(f" Will skip the first {steps_trained_in_current_epoch} steps in the first epoch") + logger.info(" Continuing training from epoch %d" % epochs_trained) + logger.info(" Continuing training from global step %d" % global_step) + logger.info(" Will skip the first %d steps in the first epoch" % steps_trained_in_current_epoch) except ValueError: - logger.info(f" Starting fine-tuning.") + logger.info(" Starting fine-tuning.") # Prepare model for training tr_loss, logging_loss = 0.0, 0.0 @@ -303,15 +196,43 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl continue # Extract all label dictionaries from the batch - input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids, - model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device) + if 'goodbye_belief' in batch: + labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids + if ('belief-' + slot) in batch} + 
request_labels = {slot: batch['request_belief-' + slot].to(device) + for slot in model.requestable_slot_ids + if ('request_belief-' + slot) in batch} if args.predict_actions else None + domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids + if ('domain_belief-' + domain) in batch} if args.predict_actions else None + goodbye_labels = batch['goodbye_belief'].to( + device) if args.predict_actions else None + else: + labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids + if ('labels-' + slot) in batch} + request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids + if ('request-' + slot) in batch} if args.predict_actions else None + domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids + if ('active-' + domain) in batch} if args.predict_actions else None + goodbye_labels = batch['goodbye'].to( + device) if args.predict_actions else None + + # Extract all model inputs from batch + input_ids = batch['input_ids'].to(device) + token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None + attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None # Set up temperature scaling for the model if temp_scheduler is not None: - model.setsumbt.temp = temp_scheduler.temp() + model.temp = temp_scheduler.temp() # Forward pass to obtain loss - loss, _, _, _, _, _, stats = model(**input_dict) + loss, _, _, _, _, _, stats = model(input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + inform_labels=labels, + request_labels=request_labels, + domain_labels=domain_labels, + goodbye_labels=goodbye_labels) if args.n_gpu > 1: loss = loss.mean() @@ -337,6 +258,7 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl tb_writer.add_scalar('LearningRate', lr, global_step) if stats: + # print(stats.keys()) for slot, stats_slot in stats.items(): for key, item in stats_slot.items(): tb_writer.add_scalar(f'{key}_{slot}/Train', item, global_step) @@ -351,6 +273,7 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl tr_loss += loss.float().item() epoch_iterator.set_postfix(loss=loss.float().item()) + loss = 0.0 global_step += 1 # Save model checkpoint @@ -363,34 +286,52 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl model.eval() set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False) - jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader, - is_train=True) + jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader) # Log model eval information - log_info(global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, gen_f1, stats) + if req_f1 is not None: + logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f, Dev Request F1 Score = %f, Dev Domain F1 Score = %f, Dev Goodbye F1 Score = %f' + % (global_step, logging_loss / args.save_steps, jg_acc, sl_acc, req_f1, dom_f1, bye_f1)) + tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step) + tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step) + tb_writer.add_scalar('RequestF1Score/Dev', req_f1, global_step) + tb_writer.add_scalar('DomainF1Score/Dev', dom_f1, global_step) + tb_writer.add_scalar('GoodbyeF1Score/Dev', bye_f1, 
global_step) + else: + logger.info('%i steps complete, Loss since last update = %f, Dev Joint goal acc = %f, Dev Slot acc = %f' + % (global_step, logging_loss / args.save_steps, jg_acc, sl_acc)) + tb_writer.add_scalar('JointGoalAccuracy/Dev', jg_acc, global_step) + tb_writer.add_scalar('SlotAccuracy/Dev', sl_acc, global_step) + tb_writer.add_scalar('Loss/Dev', loss, global_step) + if stats: + for slot, stats_slot in stats.items(): + for key, item in stats_slot.items(): + tb_writer.add_scalar(f'{key}_{slot}/Dev', item, global_step) # Set model back to training mode model.train() model.zero_grad() set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots, load_slots=False) else: - log_info(global_step, logging_loss / args.save_steps) + jg_acc, req_f1 = 0.0, None + logger.info('%i steps complete, Loss since last update = %f' % (global_step, logging_loss / args.save_steps)) logging_loss = tr_loss # Compute the score of the best model try: - best_score = best_model['request f1 score'] * model.config.user_request_loss_weight - best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight - best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight + best_score = (best_model['request f1 score'] * model.config.user_request_loss_weight) + \ + (best_model['active domain f1 score'] * model.config.active_domain_loss_weight) + \ + (best_model['goodbye act f1 score'] * + model.config.user_general_act_loss_weight) except AttributeError: best_score = 0.0 best_score += best_model['joint goal accuracy'] # Compute the score of the current model try: - current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0 - current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0 - current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0 + current_score = (req_f1 * model.config.user_request_loss_weight) + \ + (dom_f1 * model.config.active_domain_loss_weight) + \ + (bye_f1 * model.config.user_general_act_loss_weight) if req_f1 is not None else 0.0 except AttributeError: current_score = 0.0 current_score += jg_acc @@ -412,10 +353,10 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl if req_f1: best_model['request f1 score'] = req_f1 best_model['active domain f1 score'] = dom_f1 - best_model['general act f1 score'] = gen_f1 + best_model['goodbye act f1 score'] = bye_f1 best_model['train loss'] = tr_loss / global_step - output_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}") + output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) @@ -445,14 +386,14 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl epoch_iterator.close() break - logger.info(f'Epoch {e + 1} complete, average training loss = {tr_loss / global_step}') + logger.info('Epoch %i complete, average training loss = %f' % (e + 1, tr_loss / global_step)) if args.max_training_steps > 0 and global_step > args.max_training_steps: train_iterator.close() break if args.patience > 0 and steps_since_last_update >= args.patience: train_iterator.close() - logger.info(f'Model has not improved for at least {args.patience} steps. Training stopped!') + logger.info('Model has not improved for at least %i steps. Training stopped!' 
% args.patience) break # Evaluate final model @@ -460,25 +401,30 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl model.eval() set_ontology_embeddings(model.module if args.n_gpu > 1 else model, slots_dev, load_slots=False) - jg_acc, sl_acc, req_f1, dom_f1, gen_f1, loss, stats = evaluate(args, model, device, dev_dataloader, - is_train=True) - - log_info('training_complete', tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, gen_f1) + jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats = train_eval(args, model, device, dev_dataloader) + if req_f1 is not None: + logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f, Dev Request F1 Score = %f, Dev Domain F1 Score = %f, Dev Goodbye F1 Score = %f' + % (tr_loss / global_step, jg_acc, sl_acc, req_f1, dom_f1, bye_f1)) + else: + logger.info('Training complete, Training Loss = %f, Dev Joint goal acc = %f, Dev Slot acc = %f' + % (tr_loss / global_step, jg_acc, sl_acc)) else: + jg_acc = 0.0 logger.info('Training complete!') # Store final model try: - best_score = best_model['request f1 score'] * model.config.user_request_loss_weight - best_score += best_model['active domain f1 score'] * model.config.active_domain_loss_weight - best_score += best_model['general act f1 score'] * model.config.user_general_act_loss_weight + best_score = (best_model['request f1 score'] * model.config.user_request_loss_weight) + \ + (best_model['active domain f1 score'] * model.config.active_domain_loss_weight) + \ + (best_model['goodbye act f1 score'] * + model.config.user_general_act_loss_weight) except AttributeError: best_score = 0.0 best_score += best_model['joint goal accuracy'] try: - current_score = req_f1 * model.config.user_request_loss_weight if req_f1 is not None else 0.0 - current_score += dom_f1 * model.config.active_domain_loss_weight if dom_f1 is not None else 0.0 - current_score += gen_f1 * model.config.user_general_act_loss_weight if gen_f1 is not None else 0.0 + current_score = (req_f1 * model.config.user_request_loss_weight) + \ + (dom_f1 * model.config.active_domain_loss_weight) + \ + (bye_f1 * model.config.user_general_act_loss_weight) if req_f1 is not None else 0.0 except AttributeError: current_score = 0.0 current_score += jg_acc @@ -510,85 +456,225 @@ def train(args, model, device, train_dataloader, dev_dataloader, slots: dict, sl torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt")) clear_checkpoints(args.output_dir) else: - logger.info('Final model not saved, as it is not the best performing model.') + logger.info( + 'Final model not saved, since it is not the best performing model.') + + +# Function for validation +def train_eval(args, model, device, dev_dataloader): + """Evaluate Model during training!""" + accuracy_jg = [] + accuracy_sl = [] + accuracy_req = [] + truepos_req, falsepos_req, falseneg_req = [], [], [] + truepos_dom, falsepos_dom, falseneg_dom = [], [], [] + truepos_bye, falsepos_bye, falseneg_bye = [], [], [] + accuracy_dom = [] + accuracy_bye = [] + turns = [] + for batch in dev_dataloader: + # Perform with no gradients stored + with torch.no_grad(): + if 'goodbye_belief' in batch: + labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids + if ('belief-' + slot) in batch} + request_labels = {slot: batch['request_belief-' + slot].to(device) for slot in model.requestable_slot_ids + if ('request_belief-' + slot) in batch} if args.predict_actions else None + domain_labels = {domain: batch['domain_belief-' + 
domain].to(device) for domain in model.domain_ids + if ('domain_belief-' + domain) in batch} if args.predict_actions else None + goodbye_labels = batch['goodbye_belief'].to( + device) if args.predict_actions else None + else: + labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids + if ('labels-' + slot) in batch} + request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids + if ('request-' + slot) in batch} if args.predict_actions else None + domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids + if ('active-' + domain) in batch} if args.predict_actions else None + goodbye_labels = batch['goodbye'].to( + device) if args.predict_actions else None + + input_ids = batch['input_ids'].to(device) + token_type_ids = batch['token_type_ids'].to( + device) if 'token_type_ids' in batch else None + attention_mask = batch['attention_mask'].to( + device) if 'attention_mask' in batch else None + + loss, p, p_req, p_dom, p_bye, _, stats = model(input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + inform_labels=labels, + request_labels=request_labels, + domain_labels=domain_labels, + goodbye_labels=goodbye_labels) + + jg_acc = 0.0 + req_acc = 0.0 + req_tp, req_fp, req_fn = 0.0, 0.0, 0.0 + dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0 + dom_acc = 0.0 + for slot in model.informable_slot_ids: + labels = batch['labels-' + slot].to(device) + p_ = p[slot] + + acc = (p_.argmax(-1) == labels).reshape(-1).float() + jg_acc += acc + + if model.config.predict_actions: + for slot in model.requestable_slot_ids: + p_req_ = p_req[slot] + request_labels = batch['request-' + slot].to(device) + + acc = (p_req_.round().int() == request_labels).reshape(-1).float() + tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float() + fp = (p_req_.round().int() * (request_labels == 0)).reshape(-1).float() + fn = ((1 - p_req_.round().int()) * (request_labels == 1)).reshape(-1).float() + req_acc += acc + req_tp += tp + req_fp += fp + req_fn += fn + + for domain in model.domain_ids: + p_dom_ = p_dom[domain] + domain_labels = batch['active-' + domain].to(device) + acc = (p_dom_.round().int() == domain_labels).reshape(-1).float() + tp = (p_dom_.round().int() * (domain_labels == 1)).reshape(-1).float() + fp = (p_dom_.round().int() * (domain_labels == 0)).reshape(-1).float() + fn = ((1 - p_dom_.round().int()) * (domain_labels == 1)).reshape(-1).float() + dom_acc += acc + dom_tp += tp + dom_fp += fp + dom_fn += fn -def evaluate(args, model, device, dataloader, return_eval_output=False, is_train=False): - """ - Evaluate model + goodbye_labels = batch['goodbye'].to(device) + bye_acc = (p_bye.argmax(-1) == goodbye_labels).reshape(-1).float().sum() + bye_tp = ((p_bye.argmax(-1) > 0) * (goodbye_labels > 0)).reshape(-1).float().sum() + bye_fp = ((p_bye.argmax(-1) > 0) * (goodbye_labels == 0)).reshape(-1).float().sum() + bye_fn = ((p_bye.argmax(-1) == 0) * (goodbye_labels > 0)).reshape(-1).float().sum() + else: + req_acc, dom_acc, bye_acc = None, None, torch.tensor(0.0) + req_tp, req_fp, req_fn = None, None, None + dom_tp, dom_fp, dom_fn = None, None, None + bye_tp, bye_fp, bye_fn = torch.tensor( + 0.0), torch.tensor(0.0), torch.tensor(0.0) + + sl_acc = sum(jg_acc / len(model.informable_slot_ids)).float() + jg_acc = sum((jg_acc / len(model.informable_slot_ids)).int()).float() + req_acc = sum(req_acc / len(model.requestable_slot_ids)).float() if req_acc is not None else torch.tensor(0.0) + req_tp 
= sum(req_tp / len(model.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0) + req_fp = sum(req_fp / len(model.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0) + req_fn = sum(req_fn / len(model.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0) + dom_tp = sum(dom_tp / len(model.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0) + dom_fp = sum(dom_fp / len(model.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0) + dom_fn = sum(dom_fn / len(model.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0) + dom_acc = sum(dom_acc / len(model.domain_ids)).float() if dom_acc is not None else torch.tensor(0.0) + n_turns = (labels >= 0).reshape(-1).sum().float().item() - Args: - args: Runtime arguments - model: SetSUMBT model instance - device: Torch device in use - dataloader: Dataloader of data to evaluate on - return_eval_output: If true return predicted and true states for all dialogues evaluated in semantic format - is_train: If true model is training and no logging is performed + accuracy_jg.append(jg_acc.item()) + accuracy_sl.append(sl_acc.item()) + accuracy_req.append(req_acc.item()) + truepos_req.append(req_tp.item()) + falsepos_req.append(req_fp.item()) + falseneg_req.append(req_fn.item()) + accuracy_dom.append(dom_acc.item()) + truepos_dom.append(dom_tp.item()) + falsepos_dom.append(dom_fp.item()) + falseneg_dom.append(dom_fn.item()) + accuracy_bye.append(bye_acc.item()) + truepos_bye.append(bye_tp.item()) + falsepos_bye.append(bye_fp.item()) + falseneg_bye.append(bye_fn.item()) + turns.append(n_turns) - Returns: - out: Evaluted model statistics - """ - return_eval_output = False if is_train else return_eval_output - if not is_train: - logger.info("***** Running evaluation *****") - logger.info(" Num Batches = %d", len(dataloader)) + # Global accuracy reduction across batches + turns = sum(turns) + jg_acc = sum(accuracy_jg) / turns + sl_acc = sum(accuracy_sl) / turns + if model.config.predict_actions: + req_acc = sum(accuracy_req) / turns + req_tp = sum(truepos_req) + req_fp = sum(falsepos_req) + req_fn = sum(falseneg_req) + req_f1 = req_tp / (req_tp + 0.5 * (req_fp + req_fn)) + dom_acc = sum(accuracy_dom) / turns + dom_tp = sum(truepos_dom) + dom_fp = sum(falsepos_dom) + dom_fn = sum(falseneg_dom) + dom_f1 = dom_tp / (dom_tp + 0.5 * (dom_fp + dom_fn)) + bye_tp = sum(truepos_bye) + bye_fp = sum(falsepos_bye) + bye_fn = sum(falseneg_bye) + bye_f1 = bye_tp / (bye_tp + 0.5 * (bye_fp + bye_fn)) + bye_acc = sum(accuracy_bye) / turns + else: + req_acc, dom_acc, bye_acc = None, None, None + req_f1, dom_f1, bye_f1 = None, None, None + + return jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss, stats + + +def evaluate(args, model, device, dataloader): + """Evaluate Model!""" + # Evaluate! 
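The action metrics in train_eval above aggregate true positive, false positive and false negative counts across batches and then compute F1 directly as tp / (tp + 0.5 * (fp + fn)), which is algebraically identical to the familiar 2*TP / (2*TP + FP + FN). A small self-contained sketch (the helper name is hypothetical, not from the patch):

    def f1_from_counts(tp, fp, fn):
        # F1 = TP / (TP + 0.5 * (FP + FN)) == 2 * TP / (2 * TP + FP + FN)
        denom = tp + 0.5 * (fp + fn)
        return tp / denom if denom > 0.0 else 0.0

    assert abs(f1_from_counts(8, 2, 2) - 0.8) < 1e-9  # precision = recall = 0.8

One caveat: train_eval divides without a zero guard, so a validation pass with no positive labels and no positive predictions for an action head would raise ZeroDivisionError; the version being removed from evaluate further down guards that case explicitly.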
+ logger.info("***** Running evaluation *****") + logger.info(" Num Batches = %d", len(dataloader)) tr_loss = 0.0 model.eval() - if return_eval_output: - ontology = dataloader.dataset.ontology + # logits = {slot: [] for slot in model.informable_slot_ids} accuracy_jg = [] accuracy_sl = [] + accuracy_req = [] truepos_req, falsepos_req, falseneg_req = [], [], [] truepos_dom, falsepos_dom, falseneg_dom = [], [], [] - truepos_gen, falsepos_gen, falseneg_gen = [], [], [] + truepos_bye, falsepos_bye, falseneg_bye = [], [], [] + accuracy_dom = [] + accuracy_bye = [] turns = [] - if return_eval_output: - evaluation_output = [] - epoch_iterator = tqdm(dataloader, desc="Iteration") if not is_train else dataloader + epoch_iterator = tqdm(dataloader, desc="Iteration") for batch in epoch_iterator: with torch.no_grad(): - input_dict = get_input_dict(batch, args.predict_actions, model.setsumbt.informable_slot_ids, - model.setsumbt.requestable_slot_ids, model.setsumbt.domain_ids, device) - - loss, p, p_req, p_dom, p_gen, _, stats = model(**input_dict) + if 'goodbye_belief' in batch: + labels = {slot: batch['belief-' + slot].to(device) for slot in model.informable_slot_ids + if ('belief-' + slot) in batch} + request_labels = {slot: batch['request_belief-' + slot].to(device) for slot in model.requestable_slot_ids + if ('request_belief-' + slot) in batch} if args.predict_actions else None + domain_labels = {domain: batch['domain_belief-' + domain].to(device) for domain in model.domain_ids + if ('domain_belief-' + domain) in batch} if args.predict_actions else None + goodbye_labels = batch['goodbye_belief'].to( + device) if args.predict_actions else None + else: + labels = {slot: batch['labels-' + slot].to(device) for slot in model.informable_slot_ids + if ('labels-' + slot) in batch} + request_labels = {slot: batch['request-' + slot].to(device) for slot in model.requestable_slot_ids + if ('request-' + slot) in batch} if args.predict_actions else None + domain_labels = {domain: batch['active-' + domain].to(device) for domain in model.domain_ids + if ('active-' + domain) in batch} if args.predict_actions else None + goodbye_labels = batch['goodbye'].to( + device) if args.predict_actions else None + + input_ids = batch['input_ids'].to(device) + token_type_ids = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None + attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None + + loss, p, p_req, p_dom, p_bye, _, _ = model(input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + inform_labels=labels, + request_labels=request_labels, + domain_labels=domain_labels, + goodbye_labels=goodbye_labels) jg_acc = 0.0 req_acc = 0.0 req_tp, req_fp, req_fn = 0.0, 0.0, 0.0 dom_tp, dom_fp, dom_fn = 0.0, 0.0, 0.0 dom_acc = 0.0 - - if return_eval_output: - eval_output_batch = [] - for dial_id, dial in enumerate(input_dict['input_ids']): - for turn_id, turn in enumerate(dial): - if turn.sum() != 0: - eval_output_batch.append({'dial_idx': dial_id, - 'utt_idx': turn_id, - 'state': {domain: {slot: '' for slot in substate} - for domain, substate in ontology.items()}, - 'predictions': {'state': {domain: {slot: '' for slot in substate} - for domain, substate in ontology.items()}} - }) - - for slot in model.setsumbt.informable_slot_ids: + for slot in model.informable_slot_ids: p_ = p[slot] - state_labels = batch['state_labels-' + slot].to(device) - - if return_eval_output: - prediction = p_.argmax(-1) - - for sample in eval_output_batch: - dom, slt = 
slot.split('-', 1) - pred = prediction[sample['dial_idx']][sample['utt_idx']].item() - pred = ontology[dom][slt]['possible_values'][pred] - lab = state_labels[sample['dial_idx']][sample['utt_idx']].item() - lab = ontology[dom][slt]['possible_values'][lab] - - sample['state'][dom][slt] = lab if lab != 'none' else '' - sample['predictions']['state'][dom][slt] = pred if pred != 'none' else '' + labels = batch['labels-' + slot].to(device) if args.temp_scaling > 0.0: p_ = torch.log(p_ + 1e-10) / args.temp_scaling @@ -597,18 +683,28 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train p_ = torch.log(p_ + 1e-10) / 1.0 p_ = torch.softmax(p_, -1) - acc = (p_.argmax(-1) == state_labels).reshape(-1).float() + # logits[slot].append(p_) + + if args.accuracy_samples > 0: + dist = Categorical(probs=p_.reshape(-1, p_.size(-1))) + lab_sample = dist.sample((args.accuracy_samples,)) + lab_sample = lab_sample.transpose(0, 1) + acc = [lab in s for lab, s in zip(labels.reshape(-1), lab_sample)] + acc = torch.tensor(acc).float() + elif args.accuracy_topn > 0: + labs = p_.reshape(-1, p_.size(-1)).argsort(dim=-1, descending=True) + labs = labs[:, :args.accuracy_topn] + acc = [lab in s for lab, s in zip(labels.reshape(-1), labs)] + acc = torch.tensor(acc).float() + else: + acc = (p_.argmax(-1) == labels).reshape(-1).float() jg_acc += acc - if return_eval_output: - evaluation_output += deepcopy(eval_output_batch) - eval_output_batch = [] - if model.config.predict_actions: - for slot in model.setsumbt.requestable_slot_ids: + for slot in model.requestable_slot_ids: p_req_ = p_req[slot] - request_labels = batch['request_labels-' + slot].to(device) + request_labels = batch['request-' + slot].to(device) acc = (p_req_.round().int() == request_labels).reshape(-1).float() tp = (p_req_.round().int() * (request_labels == 1)).reshape(-1).float() @@ -619,88 +715,85 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train req_fp += fp req_fn += fn - domains = [domain for domain in model.setsumbt.domain_ids if f'active_domain_labels-{domain}' in batch] - for domain in domains: + for domain in model.domain_ids: p_dom_ = p_dom[domain] - active_domain_labels = batch['active_domain_labels-' + domain].to(device) + domain_labels = batch['active-' + domain].to(device) - acc = (p_dom_.round().int() == active_domain_labels).reshape(-1).float() - tp = (p_dom_.round().int() * (active_domain_labels == 1)).reshape(-1).float() - fp = (p_dom_.round().int() * (active_domain_labels == 0)).reshape(-1).float() - fn = ((1 - p_dom_.round().int()) * (active_domain_labels == 1)).reshape(-1).float() + acc = (p_dom_.round().int() == domain_labels).reshape(-1).float() + tp = (p_dom_.round().int() * (domain_labels == 1)).reshape(-1).float() + fp = (p_dom_.round().int() * (domain_labels == 0)).reshape(-1).float() + fn = ((1 - p_dom_.round().int()) * (domain_labels == 1)).reshape(-1).float() dom_acc += acc dom_tp += tp dom_fp += fp dom_fn += fn - general_act_labels = batch['general_act_labels'].to(device) - gen_tp = ((p_gen.argmax(-1) > 0) * (general_act_labels > 0)).reshape(-1).float().sum() - gen_fp = ((p_gen.argmax(-1) > 0) * (general_act_labels == 0)).reshape(-1).float().sum() - gen_fn = ((p_gen.argmax(-1) == 0) * (general_act_labels > 0)).reshape(-1).float().sum() + goodbye_labels = batch['goodbye'].to(device) + bye_acc = (p_bye.argmax(-1) == goodbye_labels).reshape(-1).float().sum() + bye_tp = ((p_bye.argmax(-1) > 0) * (goodbye_labels > 0)).reshape(-1).float().sum() + bye_fp = ((p_bye.argmax(-1) 
> 0) * (goodbye_labels == 0)).reshape(-1).float().sum() + bye_fn = ((p_bye.argmax(-1) == 0) * (goodbye_labels > 0)).reshape(-1).float().sum() else: + req_acc, dom_acc, bye_acc = None, None, torch.tensor(0.0) req_tp, req_fp, req_fn = None, None, None dom_tp, dom_fp, dom_fn = None, None, None - gen_tp, gen_fp, gen_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0) - - sl_acc = sum(jg_acc / len(model.setsumbt.informable_slot_ids)).float() - jg_acc = sum((jg_acc == len(model.setsumbt.informable_slot_ids)).int()).float() - req_tp = sum(req_tp / len(model.setsumbt.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0) - req_fp = sum(req_fp / len(model.setsumbt.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0) - req_fn = sum(req_fn / len(model.setsumbt.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0) - dom_tp = sum(dom_tp / len(model.setsumbt.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0) - dom_fp = sum(dom_fp / len(model.setsumbt.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0) - dom_fn = sum(dom_fn / len(model.setsumbt.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0) - n_turns = (state_labels >= 0).reshape(-1).sum().float().item() + bye_tp, bye_fp, bye_fn = torch.tensor( + 0.0), torch.tensor(0.0), torch.tensor(0.0) + + sl_acc = sum(jg_acc / len(model.informable_slot_ids)).float() + jg_acc = sum((jg_acc / len(model.informable_slot_ids)).int()).float() + req_acc = sum(req_acc / len(model.requestable_slot_ids)).float() if req_acc is not None else torch.tensor(0.0) + req_tp = sum(req_tp / len(model.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0) + req_fp = sum(req_fp / len(model.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0) + req_fn = sum(req_fn / len(model.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0) + dom_tp = sum(dom_tp / len(model.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0) + dom_fp = sum(dom_fp / len(model.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0) + dom_fn = sum(dom_fn / len(model.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0) + dom_acc = sum(dom_acc / len(model.domain_ids)).float() if dom_acc is not None else torch.tensor(0.0) + n_turns = (labels >= 0).reshape(-1).sum().float().item() accuracy_jg.append(jg_acc.item()) accuracy_sl.append(sl_acc.item()) + accuracy_req.append(req_acc.item()) truepos_req.append(req_tp.item()) falsepos_req.append(req_fp.item()) falseneg_req.append(req_fn.item()) + accuracy_dom.append(dom_acc.item()) truepos_dom.append(dom_tp.item()) falsepos_dom.append(dom_fp.item()) falseneg_dom.append(dom_fn.item()) - truepos_gen.append(gen_tp.item()) - falsepos_gen.append(gen_fp.item()) - falseneg_gen.append(gen_fn.item()) + accuracy_bye.append(bye_acc.item()) + truepos_bye.append(bye_tp.item()) + falsepos_bye.append(bye_fp.item()) + falseneg_bye.append(bye_fn.item()) turns.append(n_turns) tr_loss += loss.item() + # for slot in logits: + # logits[slot] = torch.cat(logits[slot], 0) + # Global accuracy reduction across batches turns = sum(turns) jg_acc = sum(accuracy_jg) / turns sl_acc = sum(accuracy_sl) / turns if model.config.predict_actions: + req_acc = sum(accuracy_req) / turns req_tp = sum(truepos_req) req_fp = sum(falsepos_req) req_fn = sum(falseneg_req) - req_f1 = req_tp + 0.5 * (req_fp + req_fn) - req_f1 = req_tp / req_f1 if req_f1 != 0.0 else 0.0 + req_f1 = req_tp / (req_tp + 0.5 * 
(req_fp + req_fn)) + dom_acc = sum(accuracy_dom) / turns dom_tp = sum(truepos_dom) dom_fp = sum(falsepos_dom) dom_fn = sum(falseneg_dom) - dom_f1 = dom_tp + 0.5 * (dom_fp + dom_fn) - dom_f1 = dom_tp / dom_f1 if dom_f1 != 0.0 else 0.0 - gen_tp = sum(truepos_gen) - gen_fp = sum(falsepos_gen) - gen_fn = sum(falseneg_gen) - gen_f1 = gen_tp + 0.5 * (gen_fp + gen_fn) - gen_f1 = gen_tp / gen_f1 if gen_f1 != 0.0 else 0.0 + dom_f1 = dom_tp / (dom_tp + 0.5 * (dom_fp + dom_fn)) + bye_tp = sum(truepos_bye) + bye_fp = sum(falsepos_bye) + bye_fn = sum(falseneg_bye) + bye_f1 = bye_tp / (bye_tp + 0.5 * (bye_fp + bye_fn)) + bye_acc = sum(accuracy_bye) / turns else: - req_f1, dom_f1, gen_f1 = None, None, None - - if return_eval_output: - dial_idx = 0 - for sample in evaluation_output: - if dial_idx == 0 and sample['dial_idx'] == 0 and sample['utt_idx'] == 0: - dial_idx = 0 - elif dial_idx == 0 and sample['dial_idx'] != 0 and sample['utt_idx'] == 0: - dial_idx += 1 - elif sample['utt_idx'] == 0: - dial_idx += 1 - sample['dial_idx'] = dial_idx - - return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), evaluation_output - if is_train: - return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), stats - return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader) + req_acc, dom_acc, bye_acc = None, None, None + req_f1, dom_f1, bye_f1 = None, None, None + + return jg_acc, sl_acc, req_f1, dom_f1, bye_f1, tr_loss / len(dataloader) diff --git a/convlab/dst/setsumbt/multiwoz/Tracker.py b/convlab/dst/setsumbt/multiwoz/Tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..fed1a1a62334e9a721c1f5b7de442b17bfd14b76 --- /dev/null +++ b/convlab/dst/setsumbt/multiwoz/Tracker.py @@ -0,0 +1,455 @@ +import os +import json +import copy +import logging + +import torch +import transformers +from transformers import (BertModel, BertConfig, BertTokenizer, + RobertaModel, RobertaConfig, RobertaTokenizer) +from convlab.dst.setsumbt.modeling import (RobertaSetSUMBT, + BertSetSUMBT) + +from convlab.dst.dst import DST +from convlab.util.multiwoz.state import default_state +from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA +from convlab.dst.rule.multiwoz import normalize_value +from convlab.util.custom_util import model_downloader + +USE_CUDA = torch.cuda.is_available() + +# Map from SetSUMBT slot names to Convlab slot names +SLOT_MAP = {'arrive by': 'arriveBy', + 'leave at': 'leaveAt', + 'price range': 'pricerange', + 'trainid': 'trainID', + 'reference': 'Ref', + 'taxi types': 'car type'} + + +class SetSUMBTTracker(DST): + + def __init__(self, model_path="", model_type="roberta", + get_turn_pooled_representation=False, + get_confidence_scores=False, + threshold='auto', + return_entropy=False, + return_mutual_info=False, + store_full_belief_state=False): + super(SetSUMBTTracker, self).__init__() + + self.model_type = model_type + self.model_path = model_path + self.get_turn_pooled_representation = get_turn_pooled_representation + self.get_confidence_scores = get_confidence_scores + self.threshold = threshold + self.return_entropy = return_entropy + self.return_mutual_info = return_mutual_info + self.store_full_belief_state = store_full_belief_state + if self.store_full_belief_state: + self.full_belief_state = {} + self.info_dict = {} + + # Download model if needed + if not os.path.exists(self.model_path): + # Get path /.../convlab/dst/setsumbt/multiwoz/models + download_path = os.path.dirname(os.path.abspath(__file__)) + download_path = 
os.path.join(download_path, 'models') + if not os.path.exists(download_path): + os.mkdir(download_path) + model_downloader(download_path, self.model_path) + # Downloadable model path format http://.../setsumbt_model_name.zip + self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '') + self.model_path = os.path.join(download_path, self.model_path) + + # Select model type based on the encoder + if model_type == "roberta": + self.config = RobertaConfig.from_pretrained(self.model_path) + self.tokenizer = RobertaTokenizer + self.model = RobertaSetSUMBT + elif model_type == "bert": + self.config = BertConfig.from_pretrained(self.model_path) + self.tokenizer = BertTokenizer + self.model = BertSetSUMBT + else: + logging.debug("Name Error: Not Implemented") + + self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu') + + # Value dict for value normalisation + path = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) + path = os.path.join(path, 'data/multiwoz/value_dict.json') + self.value_dict = json.load(open(path)) + + self.load_weights() + + def load_weights(self): + # Load tokenizer and model checkpoints + logging.info('Loading SetSUMBT pretrained model.') + self.tokenizer = self.tokenizer.from_pretrained( + self.config.tokenizer_name) + logging.info( + f'Model tokenizer loaded from {self.config.tokenizer_name}.') + self.model = self.model.from_pretrained( + self.model_path, config=self.config) + logging.info(f'Model loaded from {self.model_path}.') + + # Transfer model to compute device and setup eval environment + self.model = self.model.to(self.device) + self.model.eval() + logging.info(f'Model transferred to device: {self.device}') + + logging.info('Loading model ontology') + f = open(os.path.join(self.model_path, 'ontology.json'), 'r') + self.ontology = json.load(f) + f.close() + + db = torch.load(os.path.join(self.model_path, 'ontology.db')) + # Get slot and value embeddings + slots = {slot: db[slot] for slot in db} + values = {slot: db[slot][1] for slot in db} + del db + + # Load model ontology + self.model.add_slot_candidates(slots) + for slot in values: + self.model.add_value_candidates(slot, values[slot], replace=True) + + if self.get_confidence_scores: + logging.info('Model will output action and state confidence scores.') + if self.get_confidence_scores: + self.get_thresholds(self.threshold) + logging.info('Uncertain Querying set up and thresholds set up at:') + logging.info(self.thresholds) + if self.return_entropy: + logging.info('Model will output state distribution entropy.') + if self.return_mutual_info: + logging.info('Model will output state distribution mutual information.') + logging.info('Ontology loaded successfully.') + + self.det_dic = {} + for domain, dic in REF_USR_DA.items(): + for key, value in dic.items(): + assert '-' not in key + self.det_dic[key.lower()] = key + '-' + domain + self.det_dic[value.lower()] = key + '-' + domain + + def get_thresholds(self, threshold='auto'): + self.thresholds = {} + for slot, value_candidates in self.ontology.items(): + domain, slot = slot.split('-', 1) + slot = REF_SYS_DA[domain.capitalize()].get(slot, slot) + slot = slot.strip().split()[1] if 'book ' in slot else slot + slot = SLOT_MAP.get(slot, slot) + + # Auto thresholds are set based on the number of value candidates per slot + if domain not in self.thresholds: + self.thresholds[domain] = {} + if threshold == 'auto': + thres = 1.0 / (float(len(value_candidates)) - 2.1) + 
self.thresholds[domain][slot] = max(0.05, thres) + else: + self.thresholds[domain][slot] = max(0.05, threshold) + + return self.thresholds + + def init_session(self): + self.state = default_state() + self.active_domains = {} + self.hidden_states = None + self.info_dict = {} + + def update(self, user_act=''): + prev_state = self.state + + # Convert dialogs into transformer input features (token_ids, masks, etc) + features = self.get_features(user_act) + # Model forward pass + pred_states, active_domains, user_acts, turn_pooled_representation, belief_state, entropy_, mutual_info_ = self.predict( + features) + + if entropy_ is not None: + entropy = {} + for slot, e in entropy_.items(): + domain, slot = slot.split('-', 1) + if domain not in entropy: + entropy[domain] = {} + if 'book' in slot: + assert slot.startswith('book ') + slot = slot.strip().split()[1] + slot = SLOT_MAP.get(slot, slot) + entropy[domain][slot] = e + del entropy_ + else: + entropy = None + + if mutual_info_ is not None: + mutual_info = {} + for slot, mi in mutual_info_.items(): + domain, slot = slot.split('-', 1) + if domain not in mutual_info: + mutual_info[domain] = {} + if 'book' in slot: + assert slot.startswith('book ') + slot = slot.strip().split()[1] + slot = SLOT_MAP.get(slot, slot) + mutual_info[domain][slot] = mi[0, 0] + else: + mutual_info = None + + if belief_state is not None: + bs_probs = {} + belief_state, request_dist, domain_dist, greeting_dist = belief_state + for slot, p in belief_state.items(): + domain, slot = slot.split('-', 1) + if domain not in bs_probs: + bs_probs[domain] = {} + if 'book' in slot: + assert slot.startswith('book ') + slot = slot.strip().split()[1] + slot = SLOT_MAP.get(slot, slot) + if slot not in bs_probs[domain]: + bs_probs[domain][slot] = {} + bs_probs[domain][slot]['inform'] = p + + for slot, p in request_dist.items(): + domain, slot = slot.split('-', 1) + if domain not in bs_probs: + bs_probs[domain] = {} + slot = SLOT_MAP.get(slot, slot) + if slot not in bs_probs[domain]: + bs_probs[domain][slot] = {} + bs_probs[domain][slot]['request'] = p + + for domain, p in domain_dist.items(): + if domain not in bs_probs: + bs_probs[domain] = {} + bs_probs[domain]['none'] = {'inform': p} + + if 'general' not in bs_probs: + bs_probs['general'] = {} + bs_probs['general']['none'] = greeting_dist + + new_domains = [d for d, active in active_domains.items() if active] + new_domains = [ + d for d in new_domains if not self.active_domains.get(d, False)] + self.active_domains = active_domains + + for domain in new_domains: + user_acts.append(['Inform', domain.capitalize(), 'none', 'none']) + + new_belief_state = copy.deepcopy(prev_state['belief_state']) + # user_acts = [] + for state, value in pred_states.items(): + domain, slot = state.split('-', 1) + value = '' if value == 'none' else value + value = 'dontcare' if value == 'do not care' else value + value = 'guesthouse' if value == 'guest house' else value + if slot not in ['name', 'book']: + if domain not in new_belief_state: + if domain == 'bus': + continue + else: + logging.debug( + 'Error: domain <{}> not in belief state'.format(domain)) + slot = REF_SYS_DA[domain.capitalize()].get(slot, slot) + assert 'semi' in new_belief_state[domain] + assert 'book' in new_belief_state[domain] + if 'book' in slot: + assert slot.startswith('book ') + slot = slot.strip().split()[1] + slot = SLOT_MAP.get(slot, slot) + + # Uncertainty clipping of state + if belief_state is not None: + if bs_probs[domain][slot].get('inform', 1.0) < 
self.thresholds[domain][slot]: + value = '' + + domain_dic = new_belief_state[domain] + value = normalize_value(self.value_dict, domain, slot, value) + if slot in domain_dic['semi']: + new_belief_state[domain]['semi'][slot] = value + if prev_state['belief_state'][domain]['semi'][slot] != value: + user_acts.append(['Inform', domain.capitalize( + ), REF_USR_DA[domain.capitalize()].get(slot, slot), value]) + elif slot in domain_dic['book']: + new_belief_state[domain]['book'][slot] = value + if prev_state['belief_state'][domain]['book'][slot] != value: + user_acts.append(['Inform', domain.capitalize( + ), REF_USR_DA[domain.capitalize()].get(slot, slot), value]) + elif slot.lower() in domain_dic['book']: + new_belief_state[domain]['book'][slot.lower()] = value + if prev_state['belief_state'][domain]['book'][slot.lower()] != value: + user_acts.append(['Inform', domain.capitalize( + ), REF_USR_DA[domain.capitalize()].get(slot.lower(), slot.lower()), value]) + else: + logging.debug( + 'unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format( + slot, value, domain, state) + ) + + new_state = copy.deepcopy(dict(prev_state)) + new_state['belief_state'] = new_belief_state + new_state['active_domains'] = self.active_domains + if belief_state is not None: + new_state['belief_state_probs'] = bs_probs + if entropy is not None: + new_state['entropy'] = entropy + if mutual_info is not None: + new_state['mutual_information'] = mutual_info + + new_state['user_action'] = user_acts + + user_requests = [[a, d, s, v] + for a, d, s, v in user_acts if a == 'Request'] + for act, domain, slot, value in user_requests: + k = REF_SYS_DA[domain].get(slot, slot) + domain = domain.lower() + if domain not in new_state['request_state']: + new_state['request_state'][domain] = {} + if k not in new_state['request_state'][domain]: + new_state['request_state'][domain][k] = 0 + + if turn_pooled_representation is not None: + new_state['turn_pooled_representation'] = turn_pooled_representation + + self.state = new_state + self.info_dict = copy.deepcopy(dict(new_state)) + + return self.state + + # Model prediction function + + def predict(self, features): + # Forward Pass + mutual_info = None + with torch.no_grad(): + turn_pooled_representation = None + if self.get_turn_pooled_representation: + belief_state, request, domain, goodbye, self.hidden_states, turn_pooled_representation = self.model(input_ids=features['input_ids'], + token_type_ids=features[ + 'token_type_ids'], + attention_mask=features[ + 'attention_mask'], + hidden_state=self.hidden_states, + get_turn_pooled_representation=True) + elif self.return_mutual_info: + belief_state, request, domain, goodbye, self.hidden_states, mutual_info = self.model(input_ids=features['input_ids'], + token_type_ids=features[ + 'token_type_ids'], + attention_mask=features[ + 'attention_mask'], + hidden_state=self.hidden_states, + get_turn_pooled_representation=False, + calculate_inform_mutual_info=True) + else: + belief_state, request, domain, goodbye, self.hidden_states = self.model(input_ids=features['input_ids'], + token_type_ids=features['token_type_ids'], + attention_mask=features['attention_mask'], + hidden_state=self.hidden_states, + get_turn_pooled_representation=False) + + # Convert belief state into dialog state + predictions = {slot: state[0, 0, :].argmax().item() + for slot, state in belief_state.items()} + predictions = {slot: self.ontology[slot][idx] + for slot, idx in predictions.items()} + predictions = {s: v for s, v in predictions.items() if v != 'none'} 
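The three dictionary comprehensions above are the whole state-decoding step in predict: a per-slot argmax over the value distribution, a lookup into the ontology's value list, and dropping slots whose best value is 'none'. Combined with the threshold clipping that update applies, the logic reduces to roughly the following sketch (toy ontology and distribution, not model output; using the argmax probability as the confidence is a simplification of the tracker's max over non-'none' values, and the flat threshold keys stand in for self.thresholds[domain][slot]):

    import torch

    ontology = {'restaurant-price range': ['none', 'do not care', 'cheap', 'expensive']}
    belief = {'restaurant-price range': torch.tensor([[[0.05, 0.05, 0.80, 0.10]]])}
    thresholds = {'restaurant-price range': 0.5}   # cf. get_thresholds above

    state = {}
    for slot, dist in belief.items():
        idx = dist[0, 0, :].argmax().item()    # most probable value index
        value = ontology[slot][idx]            # ontology lookup
        confidence = dist[0, 0, idx].item()
        # Drop 'none' and clip low-confidence values, as update() does.
        if value != 'none' and confidence >= thresholds[slot]:
            state[slot] = value

    print(state)  # {'restaurant-price range': 'cheap'}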
+ + if self.store_full_belief_state: + self.full_belief_state = belief_state + + # Obtain model output probabilities + if self.get_confidence_scores: + entropy = None + if self.return_entropy: + entropy = {slot: state[0, 0, :] + for slot, state in belief_state.items()} + entropy = {slot: self.relative_entropy( + p).item() for slot, p in entropy.items()} + + # Confidence score is the max probability across all not "none" values candidates. + belief_state = {slot: state[0, 0, 1:].max().item() + for slot, state in belief_state.items()} + request_dist = {SLOT_MAP.get( + slot, slot): p[0, 0].item() for slot, p in request.items()} + domain_dist = {domain: p[0, 0].item() + for domain, p in domain.items()} + greeting_dist = {'bye': goodbye[0, 0, 1].item( + ), 'thank': goodbye[0, 0, 2].item()} + belief_state = (belief_state, request_dist, + domain_dist, greeting_dist) + else: + belief_state = None + entropy = None + + # Construct request action prediction + request = [slot for slot, p in request.items() if p[0, 0].item() > 0.5] + request = [slot.split('-', 1) for slot in request] + request = [[domain, SLOT_MAP.get(slot, slot)] + for domain, slot in request] + request = [['Request', domain.capitalize(), REF_USR_DA[domain.capitalize()].get( + slot, slot), '?'] for domain, slot in request] + + # Construct active domain set + domain = {domain: p[0, 0].item() > 0.5 for domain, p in domain.items()} + + # Construct general domain action + goodbye = goodbye[0, 0, :].argmax(-1).item() + goodbye = [[], ['bye'], ['thank']][goodbye] + goodbye = [[act, 'general', 'none', 'none'] for act in goodbye] + + user_acts = request + goodbye + + return predictions, domain, user_acts, turn_pooled_representation, belief_state, entropy, mutual_info + + def relative_entropy(self, probs): + entropy = probs * torch.log(probs + 1e-8) + entropy = -entropy.sum() + # Maximum entropy of a K dimentional distribution is ln(K) + entropy /= torch.log(torch.tensor(probs.size(-1)).float()) + + return entropy + + # Convert dialog turns into model features + def get_features(self, user_act): + # Extract system utterance from dialog history + context = self.state['history'] + if context: + if context[-1][0] != 'sys': + system_act = '' + else: + system_act = context[-1][-1] + else: + system_act = '' + + # Tokenize dialog + features = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True, max_length=self.config.max_turn_len, + padding='max_length', truncation='longest_first') + + input_ids = torch.tensor(features['input_ids']).reshape( + 1, 1, -1).to(self.device) if 'input_ids' in features else None + token_type_ids = torch.tensor(features['token_type_ids']).reshape( + 1, 1, -1).to(self.device) if 'token_type_ids' in features else None + attention_mask = torch.tensor(features['attention_mask']).reshape( + 1, 1, -1).to(self.device) if 'attention_mask' in features else None + features = {'input_ids': input_ids, + 'token_type_ids': token_type_ids, 'attention_mask': attention_mask} + + return features + + +# if __name__ == "__main__": +# tracker = SetSUMBTTracker(model_type='roberta', model_path='/gpfs/project/niekerk/results/nbt/convlab_setsumbt_acts') +# # nlu_path='/gpfs/project/niekerk/data/bert_multiwoz_all_context.zip') +# tracker.init_session() +# state = tracker.update('hey. I need a cheap restaurant.') +# # tracker.state['history'].append(['usr', 'hey. 
I need a cheap restaurant.']) +# # tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?']) +# # state = tracker.update('If you have something Asian that would be great.') +# # tracker.state['history'].append(['usr', 'If you have something Asian that would be great.']) +# # tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.']) +# # state = tracker.update('Great. Where are they located?') +# # tracker.state['history'].append(['usr', 'Great. Where are they located?']) +# print(tracker.state) diff --git a/convlab/dst/setsumbt/multiwoz/__init__.py b/convlab/dst/setsumbt/multiwoz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f1fb894a0430545c62b78f9a6f4786c4e328a8 --- /dev/null +++ b/convlab/dst/setsumbt/multiwoz/__init__.py @@ -0,0 +1,2 @@ +from convlab.dst.setsumbt.multiwoz.dataset import multiwoz21, ontology +from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker \ No newline at end of file diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair b/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair new file mode 100644 index 0000000000000000000000000000000000000000..34df41d01e93ce27039e721e1ffb55bf9267e5a2 --- /dev/null +++ b/convlab/dst/setsumbt/multiwoz/dataset/mapping.pair @@ -0,0 +1,83 @@ +it's it is +don't do not +doesn't does not +didn't did not +you'd you would +you're you are +you'll you will +i'm i am +they're they are +that's that is +what's what is +couldn't could not +i've i have +we've we have +can't cannot +i'd i would +i'd i would +aren't are not +isn't is not +wasn't was not +weren't were not +won't will not +there's there is +there're there are +. . . +restaurants restaurant -s +hotels hotel -s +laptops laptop -s +cheaper cheap -er +dinners dinner -s +lunches lunch -s +breakfasts breakfast -s +expensively expensive -ly +moderately moderate -ly +cheaply cheap -ly +prices price -s +places place -s +venues venue -s +ranges range -s +meals meal -s +locations location -s +areas area -s +policies policy -s +children child -s +kids kid -s +kidfriendly kid friendly +cards card -s +upmarket expensive +inpricey cheap +inches inch -s +uses use -s +dimensions dimension -s +driverange drive range +includes include -s +computers computer -s +machines machine -s +families family -s +ratings rating -s +constraints constraint -s +pricerange price range +batteryrating battery rating +requirements requirement -s +drives drive -s +specifications specification -s +weightrange weight range +harddrive hard drive +batterylife battery life +businesses business -s +hours hour -s +one 1 +two 2 +three 3 +four 4 +five 5 +six 6 +seven 7 +eight 8 +nine 9 +ten 10 +eleven 11 +twelve 12 +anywhere any where +good bye goodbye diff --git a/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py b/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py new file mode 100644 index 0000000000000000000000000000000000000000..2c8e98f35429ce5194d82d07a1cf0de8fee54515 --- /dev/null +++ b/convlab/dst/setsumbt/multiwoz/dataset/multiwoz21.py @@ -0,0 +1,502 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MultiWOZ 2.1/2.3 Dialogue Dataset""" + +import os +import json +import requests +import zipfile +import io +from shutil import copy2 as copy + +import torch +from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler +from tqdm import tqdm + +from convlab.dst.setsumbt.multiwoz.dataset.utils import (clean_text, ACTIVE_DOMAINS, get_domains, set_util_domains, + fix_delexicalisation, extract_dialogue, PRICERANGE, + BOOLEAN, DAYS, QUANTITIES, TIME, VALUE_MAP, map_values) + + +# Set up global data_directory +def set_datadir(dir): + global DATA_DIR + DATA_DIR = dir + + +def set_active_domains(domains): + global ACTIVE_DOMAINS + ACTIVE_DOMAINS = [d for d in domains if d in ACTIVE_DOMAINS] + set_util_domains(ACTIVE_DOMAINS) + + +# MultiWOZ2.1 download link +URL = 'https://github.com/budzianowski/multiwoz/raw/master/data/MultiWOZ_2.1.zip' +def set_url(url): + global URL + URL = url + + +# Create Dialogue examples from the dataset +def create_examples(max_utt_len, get_requestable_slots=False, force_processing=False): + + # Load or download Raw Data + if not os.path.exists(DATA_DIR): + os.mkdir(DATA_DIR) + if not os.path.exists(os.path.join(DATA_DIR, 'data_raw.json')): + # Download data archive and extract + archive = _download() + data = _extract(archive) + + writer = open(os.path.join(DATA_DIR, 'data_raw.json'), 'w') + json.dump(data, writer, indent = 2) + del archive, writer + else: + reader = open(os.path.join(DATA_DIR, 'data_raw.json'), 'r') + data = json.load(reader) + + if force_processing or not os.path.exists(os.path.join(DATA_DIR, 'data_train.json')): + # Preprocess all dialogues + data_processed = _process(data['data'], data['system_acts']) + # Format data and split train, test and devlopment sets + train, dev, test = _split_data(data_processed, data['testListFile'], + data['valListFile'], max_utt_len) + + # Write data + writer = open(os.path.join(DATA_DIR, 'data_train.json'), 'w') + json.dump(train, writer, indent = 2) + writer = open(os.path.join(DATA_DIR, 'data_test.json'), 'w') + json.dump(test, writer, indent = 2) + writer = open(os.path.join(DATA_DIR, 'data_dev.json'), 'w') + json.dump(dev, writer, indent = 2) + writer.flush() + writer.close() + del writer + + # Extract slots and slot value candidates from the dataset + for set_type in ['train', 'dev', 'test']: + _get_ontology(set_type, get_requestable_slots) + + script_path = os.path.abspath(__file__).replace('/multiwoz21.py', '') + file_name = 'mwoz21_ont_request.json' if get_requestable_slots else 'mwoz21_ont.json' + copy(os.path.join(script_path, file_name), os.path.join(DATA_DIR, 'ontology_test.json')) + copy(os.path.join(script_path, 'mwoz21_slot_descriptions.json'), os.path.join(DATA_DIR, 'slot_descriptions.json')) + + +# Extract slots and slot value candidates from the dataset +def _get_ontology(set_type, get_requestable_slots=False): + + datasets = ['train'] + if set_type in ['test', 'dev']: + datasets.append('dev') + datasets.append('test') + + # Load examples + data = [] + for dataset in datasets: + reader = open(os.path.join(DATA_DIR, 'data_%s.json' % dataset), 'r') + data += 
json.load(reader) + + ontology = dict() + for dial in data: + for turn in dial['dialogue']: + for state in turn['dialogue_state']: + slot, value = state + value = map_values(value) + if slot not in ontology: + ontology[slot] = [value] + else: + ontology[slot].append(value) + + requestable_slots = [] + if get_requestable_slots: + for dial in data: + for turn in dial['dialogue']: + for act, dom, slot, val in turn['user_acts']: + if act == 'request': + requestable_slots.append(f'{dom}-{slot}') + requestable_slots = list(set(requestable_slots)) + + for slot in ontology: + if 'price' in slot: + ontology[slot] = PRICERANGE + if 'parking' in slot or 'internet' in slot: + ontology[slot] = BOOLEAN + if 'day' in slot: + ontology[slot] = DAYS + if 'people' in slot or 'duration' in slot or 'stay' in slot: + ontology[slot] = QUANTITIES + if 'time' in slot or 'leave' in slot or 'arrive' in slot: + ontology[slot] = TIME + if 'stars' in slot: + ontology[slot] += [str(i) for i in range(5)] + + # Sort slot values and add none and dontcare values + for slot in ontology: + ontology[slot] = list(set(ontology[slot])) + ontology[slot] = ['none', 'do not care'] + sorted([s for s in ontology[slot] if s not in ['none', 'do not care']]) + for slot in requestable_slots: + if slot in ontology: + ontology[slot].append('request') + else: + ontology[slot] = ['request'] + + writer = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'w') + json.dump(ontology, writer, indent=2) + writer.close() + + +# Convert dialogue examples to model input features and labels +def convert_examples_to_features(set_type, tokenizer, max_turns=12, max_seq_len=64): + + features = dict() + + # Load examples + reader = open(os.path.join(DATA_DIR, 'data_%s.json' % set_type), 'r') + data = json.load(reader) + + # Get encoder input for system, user utterance pairs + input_feats = [] + for dial in data: + dial_feats = [] + for turn in dial['dialogue']: + if len(turn['system_transcript']) == 0: + usr = turn['transcript'] + dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens = True, + max_length = max_seq_len, padding='max_length', + truncation = 'longest_first')) + else: + usr = turn['transcript'] + sys = turn['system_transcript'] + dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens = True, + max_length = max_seq_len, padding='max_length', + truncation = 'longest_first')) + if len(dial_feats) >= max_turns: + break + input_feats.append(dial_feats) + del dial_feats + + # Perform turn level padding + input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats] + if 'token_type_ids' in input_feats[0][0]: + token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats] + else: + token_type_ids = None + if 'attention_mask' in input_feats[0][0]: + attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats] + else: + attention_mask = None + del input_feats + + # Create torch data tensors + features['input_ids'] = torch.tensor(input_ids) + features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None + features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None + del input_ids, token_type_ids, attention_mask + + # Load ontology + reader = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r') + ontology = json.load(reader) + reader.close() + + informable_slots 
= [slot for slot, values in ontology.items() if values != ['request']] + requestable_slots = [slot for slot, values in ontology.items() if 'request' in values] + for slot in requestable_slots: + ontology[slot].remove('request') + + domains = list(set(informable_slots + requestable_slots)) + domains = list(set([slot.split('-', 1)[0] for slot in domains])) + + # Create slot labels + for slot in informable_slots: + labels = [] + for dial in data: + labs = [] + for turn in dial['dialogue']: + slots_active = [s for s, v in turn['dialogue_state']] + if slot in slots_active: + value = [v for s, v in turn['dialogue_state'] if s == slot][0] + else: + value = 'none' + if value in ontology[slot]: + value = ontology[slot].index(value) + else: + value = map_values(value) + if value in ontology[slot]: + value = ontology[slot].index(value) + else: + value = -1 # If value is not in ontology then we do not penalise the model + labs.append(value) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['labels-' + slot] = labels + + for slot in requestable_slots: + labels = [] + for dial in data: + labs = [] + for turn in dial['dialogue']: + slots_active = [[d, s] for i, d, s, v in turn['user_acts']] + if slot.split('-', 1) in slots_active: + act_ = [i for i, d, s, v in turn['user_acts'] if f"{d}-{s}" == slot][0] + if act_ == 'request': + labs.append(1) + else: + labs.append(0) + else: + labs.append(0) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['request-' + slot] = labels + + # Greeting act labels (0-no greeting, 1-goodbye, 2-thank you) + labels = [] + for dial in data: + labs = [] + for turn in dial['dialogue']: + greeting_active = [i for i, d, s, v in turn['user_acts'] if i in ['bye', 'thank']] + if greeting_active: + if 'bye' in greeting_active: + labs.append(1) + else : + labs.append(2) + else: + labs.append(0) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['goodbye'] = labels + + for domain in domains: + labels = [] + for dial in data: + labs = [] + for turn in dial['dialogue']: + if domain == turn['domain']: + labs.append(1) + else: + labs.append(0) + if len(labs) >= max_turns: + break + labs = labs + [-1] * (max_turns - len(labs)) + labels.append(labs) + + labels = torch.tensor(labels) + features['active-' + domain] = labels + + del labels + + return features + + +# MultiWOZ2.1 Dataset object +class MultiWoz21(Dataset): + + def __init__(self, set_type, tokenizer, max_turns=12, max_seq_len=64): + self.features = convert_examples_to_features(set_type, tokenizer, max_turns, max_seq_len) + + def __getitem__(self, index): + return {label: self.features[label][index] for label in self.features + if self.features[label] is not None} + + def __len__(self): + return self.features['input_ids'].size(0) + + def resample(self, size=None): + n_dialogues = self.__len__() + if not size: + size = n_dialogues + + dialogues = torch.randint(low=0, high=n_dialogues, size=(size,)) + self.features = {label: self.features[label][dialogues] for label in self.features + if self.features[label] is not None} + + return self + + def to(self, device): + self.device = device + self.features = {label: self.features[label].to(device) for label in self.features + if self.features[label] is not None} + + +# MultiWOZ2.1 Dataset object 
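+# A minimal usage sketch of the dataset object above (illustration only, in the style of the +# commented examples elsewhere in this patch; the tokenizer choice, data directory and +# max_utt_len value are assumptions, not part of this module): +# from transformers import BertTokenizer +# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') +# set_datadir('/path/to/multiwoz21') # hypothetical data directory +# create_examples(max_utt_len=50) +# train_set = MultiWoz21('train', tokenizer, max_turns=12, max_seq_len=64) +# batch = train_set[0] # dict of turn-padded tensors, e.g. batch['input_ids'] of shape (12, 64) + + +# Ensemble variant: wraps feature dictionaries that have already been computed elsewhere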
+class EnsembleMultiWoz21(Dataset): + def __init__(self, data): + self.features = data + + def __getitem__(self, index): + return {label: self.features[label][index] for label in self.features + if self.features[label] is not None} + + def __len__(self): + return self.features['input_ids'].size(0) + + def resample(self, size=None): + n_dialogues = self.__len__() + if not size: + size = n_dialogues + + dialogues = torch.randint(low=0, high=n_dialogues, size=(size,)) + self.features = {label: self.features[label][dialogues] for label in self.features + if self.features[label] is not None} + + def to(self, device): + self.device = device + self.features = {label: self.features[label].to(device) for label in self.features + if self.features[label] is not None} + + +# Module to create torch dataloaders +def get_dataloader(set_type, batch_size, tokenizer, max_turns=12, max_seq_len=64, device=None, resampled_size=None): + data = MultiWoz21(set_type, tokenizer, max_turns, max_seq_len) + data.to('cpu') + + if resampled_size: + data.resample(resampled_size) + + if set_type in ['test', 'dev']: + sampler = SequentialSampler(data) + else: + sampler = RandomSampler(data) + loader = DataLoader(data, sampler=sampler, batch_size=batch_size) + + return loader + + +def _download(chunk_size=1048576): + """Download data archive. + + Parameters: + chunk_size (int): Download chunk size. (default=1048576) + Returns: + archive: ZipFile archive object. + """ + # Download the archive byte string + req = requests.get(URL, stream=True) + archive = b'' + for n_chunks, chunk in tqdm(enumerate(req.iter_content(chunk_size=chunk_size)), desc='Download Chunk'): + if chunk: + archive += chunk + + # Convert the bytestring into a zipfile object + archive = io.BytesIO(archive) + archive = zipfile.ZipFile(archive) + + return archive + + +def _extract(archive): + """Extract the json dictionaries from the archive. + + Parameters: + archive: ZipFile archive object. + Returns: + data: Data dictionary. 
+ """ + files = [file for file in archive.filelist if ('.json' in file.filename or '.txt' in file.filename) + and 'MACOSX' not in file.filename] + objects = [] + for file in tqdm(files, desc='File'): + data = archive.open(file).read() + # Get data objects from the files + try: + data = json.loads(data) + except json.decoder.JSONDecodeError: + data = data.decode().split('\n') + objects.append(data) + + files = [file.filename.split('/')[-1].split('.')[0] for file in files] + + data = {file: data for file, data in zip(files, objects)} + return data + + +# Process files +def _process(dialogue_data, acts_data): + print('Processing Dialogues') + out = {} + for dial_name in tqdm(dialogue_data): + dialogue = dialogue_data[dial_name] + + prev_dom = '' + for turn_id, turn in enumerate(dialogue['log']): + dialogue['log'][turn_id]['text'] = clean_text(turn['text']) + if len(turn['metadata']) != 0: + crnt_dom = get_domains(dialogue['log'], turn_id, prev_dom) + prev_dom = crnt_dom + dialogue['log'][turn_id - 1]['domain'] = crnt_dom + + dialogue['log'][turn_id] = fix_delexicalisation(turn) + + out[dial_name] = dialogue + + return out + + +# Split data (train, dev, test) +def _split_data(dial_data, test, dev, max_utt_len): + train_dials, test_dials, dev_dials = [], [], [] + print('Formatting and Splitting Data') + for name in tqdm(dial_data): + dialogue = dial_data[name] + domains = [] + + dial = extract_dialogue(dialogue, max_utt_len) + if dial: + dialogue = dict() + dialogue['dialogue_idx'] = name + dialogue['domains'] = [] + dialogue['dialogue'] = [] + + for turn_id, turn in enumerate(dial): + turn_dialog = dict() + turn_dialog['system_transcript'] = dial[turn_id - 1]['sys'] if turn_id > 0 else '' + turn_dialog['turn_idx'] = turn_id + turn_dialog['dialogue_state'] = turn['ds'] + turn_dialog['transcript'] = turn['usr'] + # turn_dialog['system_acts'] = dial[turn_id - 1]['sys_a'] if turn_id > 0 else [] + turn_dialog['user_acts'] = turn['usr_a'] + turn_dialog['domain'] = turn['domain'] + dialogue['domains'].append(turn['domain']) + dialogue['dialogue'].append(turn_dialog) + + dialogue['domains'] = [d for d in list(set(dialogue['domains'])) if d != ''] + if True in [dom not in ACTIVE_DOMAINS for dom in dialogue['domains']]: + dialogue['domains'] = [] + dialogue['domains'] = [dom for dom in dialogue['domains'] if dom in ACTIVE_DOMAINS] + + if dialogue['domains']: + if name in test: + test_dials.append(dialogue) + elif name in dev: + dev_dials.append(dialogue) + else: + train_dials.append(dialogue) + + print('Number of Dialogues:\nTrain: %i\nDev: %i\nTest: %i' % (len(train_dials), len(dev_dials), len(test_dials))) + + return train_dials, dev_dials, test_dials diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json new file mode 100644 index 0000000000000000000000000000000000000000..b703793dc747535132b748d8b69838b4c151d8d5 --- /dev/null +++ b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont.json @@ -0,0 +1,2990 @@ +{ + "hotel-price range": [ + "none", + "do not care", + "cheap", + "expensive", + "moderate" + ], + "hotel-type": [ + "none", + "do not care", + "bed and breakfast", + "guest house", + "hotel" + ], + "hotel-parking": [ + "none", + "do not care", + "no", + "yes" + ], + "hotel-book day": [ + "none", + "do not care", + "friday", + "monday", + "saterday", + "sunday", + "thursday", + "tuesday", + "wednesday" + ], + "hotel-book people": [ + "none", + "do not care", + "1", + "10 or more", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + 
"9" + ], + "hotel-book stay": [ + "none", + "do not care", + "1", + "10 or more", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9" + ], + "train-destination": [ + "none", + "do not care", + "bishops stortford", + "kings lynn", + "london liverpool street", + "centre", + "bishop stortford", + "liverpool", + "leicester", + "broxbourne", + "gourmet burger kitchen", + "copper kettle", + "bournemouth", + "stevenage", + "liverpool street", + "norwich", + "huntingdon marriott hotel", + "city centre north", + "taj tandoori", + "the copper kettle", + "peterborough", + "ely", + "lecester", + "london", + "willi", + "stansted airport", + "huntington marriott", + "cambridge", + "gonv", + "glastonbury", + "hol", + "north", + "birmingham new street", + "norway", + "petersborough", + "london kings cross", + "curry prince", + "bishops storford" + ], + "train-arrive by": [ + "none", + "do not care", + "00:00", + "00:05", + "00:10", + "00:15", + "00:20", + "00:25", + "00:30", + "00:35", + "00:40", + "00:45", + "00:50", + "00:55", + "01:00", + "01:05", + "01:10", + "01:15", + "01:20", + "01:25", + "01:30", + "01:35", + "01:40", + "01:45", + "01:50", + "01:55", + "02:00", + "02:05", + "02:10", + "02:15", + "02:20", + "02:25", + "02:30", + "02:35", + "02:40", + "02:45", + "02:50", + "02:55", + "03:00", + "03:05", + "03:10", + "03:15", + "03:20", + "03:25", + "03:30", + "03:35", + "03:40", + "03:45", + "03:50", + "03:55", + "04:00", + "04:05", + "04:10", + "04:15", + "04:20", + "04:25", + "04:30", + "04:35", + "04:40", + "04:45", + "04:50", + "04:55", + "05:00", + "05:05", + "05:10", + "05:15", + "05:20", + "05:25", + "05:30", + "05:35", + "05:40", + "05:45", + "05:50", + "05:55", + "06:00", + "06:05", + "06:10", + "06:15", + "06:20", + "06:25", + "06:30", + "06:35", + "06:40", + "06:45", + "06:50", + "06:55", + "07:00", + "07:05", + "07:10", + "07:15", + "07:20", + "07:25", + "07:30", + "07:35", + "07:40", + "07:45", + "07:50", + "07:55", + "08:00", + "08:05", + "08:10", + "08:15", + "08:20", + "08:25", + "08:30", + "08:35", + "08:40", + "08:45", + "08:50", + "08:55", + "09:00", + "09:05", + "09:10", + "09:15", + "09:20", + "09:25", + "09:30", + "09:35", + "09:40", + "09:45", + "09:50", + "09:55", + "10:00", + "10:05", + "10:10", + "10:15", + "10:20", + "10:25", + "10:30", + "10:35", + "10:40", + "10:45", + "10:50", + "10:55", + "11:00", + "11:05", + "11:10", + "11:15", + "11:20", + "11:25", + "11:30", + "11:35", + "11:40", + "11:45", + "11:50", + "11:55", + "12:00", + "12:05", + "12:10", + "12:15", + "12:20", + "12:25", + "12:30", + "12:35", + "12:40", + "12:45", + "12:50", + "12:55", + "13:00", + "13:05", + "13:10", + "13:15", + "13:20", + "13:25", + "13:30", + "13:35", + "13:40", + "13:45", + "13:50", + "13:55", + "14:00", + "14:05", + "14:10", + "14:15", + "14:20", + "14:25", + "14:30", + "14:35", + "14:40", + "14:45", + "14:50", + "14:55", + "15:00", + "15:05", + "15:10", + "15:15", + "15:20", + "15:25", + "15:30", + "15:35", + "15:40", + "15:45", + "15:50", + "15:55", + "16:00", + "16:05", + "16:10", + "16:15", + "16:20", + "16:25", + "16:30", + "16:35", + "16:40", + "16:45", + "16:50", + "16:55", + "17:00", + "17:05", + "17:10", + "17:15", + "17:20", + "17:25", + "17:30", + "17:35", + "17:40", + "17:45", + "17:50", + "17:55", + "18:00", + "18:05", + "18:10", + "18:15", + "18:20", + "18:25", + "18:30", + "18:35", + "18:40", + "18:45", + "18:50", + "18:55", + "19:00", + "19:05", + "19:10", + "19:15", + "19:20", + "19:25", + "19:30", + "19:35", + "19:40", + "19:45", + "19:50", + "19:55", + "20:00", + 
"20:05", + "20:10", + "20:15", + "20:20", + "20:25", + "20:30", + "20:35", + "20:40", + "20:45", + "20:50", + "20:55", + "21:00", + "21:05", + "21:10", + "21:15", + "21:20", + "21:25", + "21:30", + "21:35", + "21:40", + "21:45", + "21:50", + "21:55", + "22:00", + "22:05", + "22:10", + "22:15", + "22:20", + "22:25", + "22:30", + "22:35", + "22:40", + "22:45", + "22:50", + "22:55", + "23:00", + "23:05", + "23:10", + "23:15", + "23:20", + "23:25", + "23:30", + "23:35", + "23:40", + "23:45", + "23:50", + "23:55" + ], + "train-departure": [ + "none", + "do not care", + "bishops stortford", + "kings lynn", + "brookshite", + "london liverpool street", + "cam", + "liverpool", + "bro", + "leicester", + "broxbourne", + "norwhich", + "saint johns", + "stevenage", + "stansted", + "london liverpool", + "cambrid", + "city hall", + "rosas bed and breakfast", + "alpha-milton", + "wandlebury country park", + "norwich", + "liecester", + "stratford", + "peterborough", + "duxford", + "ely", + "london", + "stansted airport", + "lon", + "cambridge", + "panahar", + "cineworld", + "leicaster", + "birmingham", + "cafe uno", + "camboats", + "huntingdon", + "birmingham new street", + "arbu", + "alpha milton", + "east london", + "london kings cross", + "hamilton lodge", + "aylesbray lodge guest", + "el shaddai" + ], + "train-day": [ + "none", + "do not care", + "friday", + "monday", + "saterday", + "sunday", + "thursday", + "tuesday", + "wednesday" + ], + "train-book people": [ + "none", + "do not care", + "1", + "10 or more", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9" + ], + "hotel-stars": [ + "none", + "do not care", + "0", + "1", + "2", + "3", + "4", + "5" + ], + "hotel-internet": [ + "none", + "do not care", + "no", + "yes" + ], + "hotel-name": [ + "a and b guest house", + "city roomz", + "carolina bed and breakfast", + "limehouse", + "anatolia", + "hamilton lodge", + "the lensfield hotel", + "rosa's bed and breakfast", + "gall", + "aylesbray lodge", + "kirkwood", + "cambridge belfry", + "warkworth house", + "gonville", + "belfy hotel", + "nus", + "alexander", + "super 5", + "aylesbray lodge guest house", + "the gonvile hotel", + "allenbell", + "nothamilton lodge", + "ashley hotel", + "autumn house", + "hobsons house", + "hotel", + "ashely hotel", + "caridge belfrey", + "el shaddia guest house", + "avalon", + "cote", + "city centre north bed and breakfast", + "the cambridge belfry", + "home from home", + "wandlebury coutn", + "wankworth house", + "city stop rest", + "the worth house", + "cityroomz", + "huntingdon marriottt hotel", + "none", + "lensfield", + "rosas bed and breakfast", + "leverton house", + "gonville hotel", + "holiday inn cambridge", + "do not care", + "archway house", + "lan hon", + "levert", + "acorn guest house", + "cambridge", + "the ashley hotel", + "el shaddai", + "sleeperz", + "alpha milton guest house", + "doubletree by hilton cambridge", + "tandoori palace", + "express by", + "express by holiday inn cambridge", + "north bed and breakfast", + "holiday inn", + "arbury lodge guest house", + "alexander bed and breakfast", + "huntingdon marriott hotel", + "royal spice", + "sou", + "finches bed and breakfast", + "the alpha milton", + "bridge guest house", + "the acorn guest house", + "kirkwood house", + "eraina", + "la margherit", + "lensfield hotel", + "marriott hotel", + "nusha", + "city centre bed and breakfast", + "the allenbell", + "university arms hotel", + "clare", + "cherr", + "wartworth", + "acorn place", + "lovell lodge", + "whale" + ], + "train-leave at": [ + "none", + 
"do not care", + "00:00", + "00:05", + "00:10", + "00:15", + "00:20", + "00:25", + "00:30", + "00:35", + "00:40", + "00:45", + "00:50", + "00:55", + "01:00", + "01:05", + "01:10", + "01:15", + "01:20", + "01:25", + "01:30", + "01:35", + "01:40", + "01:45", + "01:50", + "01:55", + "02:00", + "02:05", + "02:10", + "02:15", + "02:20", + "02:25", + "02:30", + "02:35", + "02:40", + "02:45", + "02:50", + "02:55", + "03:00", + "03:05", + "03:10", + "03:15", + "03:20", + "03:25", + "03:30", + "03:35", + "03:40", + "03:45", + "03:50", + "03:55", + "04:00", + "04:05", + "04:10", + "04:15", + "04:20", + "04:25", + "04:30", + "04:35", + "04:40", + "04:45", + "04:50", + "04:55", + "05:00", + "05:05", + "05:10", + "05:15", + "05:20", + "05:25", + "05:30", + "05:35", + "05:40", + "05:45", + "05:50", + "05:55", + "06:00", + "06:05", + "06:10", + "06:15", + "06:20", + "06:25", + "06:30", + "06:35", + "06:40", + "06:45", + "06:50", + "06:55", + "07:00", + "07:05", + "07:10", + "07:15", + "07:20", + "07:25", + "07:30", + "07:35", + "07:40", + "07:45", + "07:50", + "07:55", + "08:00", + "08:05", + "08:10", + "08:15", + "08:20", + "08:25", + "08:30", + "08:35", + "08:40", + "08:45", + "08:50", + "08:55", + "09:00", + "09:05", + "09:10", + "09:15", + "09:20", + "09:25", + "09:30", + "09:35", + "09:40", + "09:45", + "09:50", + "09:55", + "10:00", + "10:05", + "10:10", + "10:15", + "10:20", + "10:25", + "10:30", + "10:35", + "10:40", + "10:45", + "10:50", + "10:55", + "11:00", + "11:05", + "11:10", + "11:15", + "11:20", + "11:25", + "11:30", + "11:35", + "11:40", + "11:45", + "11:50", + "11:55", + "12:00", + "12:05", + "12:10", + "12:15", + "12:20", + "12:25", + "12:30", + "12:35", + "12:40", + "12:45", + "12:50", + "12:55", + "13:00", + "13:05", + "13:10", + "13:15", + "13:20", + "13:25", + "13:30", + "13:35", + "13:40", + "13:45", + "13:50", + "13:55", + "14:00", + "14:05", + "14:10", + "14:15", + "14:20", + "14:25", + "14:30", + "14:35", + "14:40", + "14:45", + "14:50", + "14:55", + "15:00", + "15:05", + "15:10", + "15:15", + "15:20", + "15:25", + "15:30", + "15:35", + "15:40", + "15:45", + "15:50", + "15:55", + "16:00", + "16:05", + "16:10", + "16:15", + "16:20", + "16:25", + "16:30", + "16:35", + "16:40", + "16:45", + "16:50", + "16:55", + "17:00", + "17:05", + "17:10", + "17:15", + "17:20", + "17:25", + "17:30", + "17:35", + "17:40", + "17:45", + "17:50", + "17:55", + "18:00", + "18:05", + "18:10", + "18:15", + "18:20", + "18:25", + "18:30", + "18:35", + "18:40", + "18:45", + "18:50", + "18:55", + "19:00", + "19:05", + "19:10", + "19:15", + "19:20", + "19:25", + "19:30", + "19:35", + "19:40", + "19:45", + "19:50", + "19:55", + "20:00", + "20:05", + "20:10", + "20:15", + "20:20", + "20:25", + "20:30", + "20:35", + "20:40", + "20:45", + "20:50", + "20:55", + "21:00", + "21:05", + "21:10", + "21:15", + "21:20", + "21:25", + "21:30", + "21:35", + "21:40", + "21:45", + "21:50", + "21:55", + "22:00", + "22:05", + "22:10", + "22:15", + "22:20", + "22:25", + "22:30", + "22:35", + "22:40", + "22:45", + "22:50", + "22:55", + "23:00", + "23:05", + "23:10", + "23:15", + "23:20", + "23:25", + "23:30", + "23:35", + "23:40", + "23:45", + "23:50", + "23:55" + ], + "restaurant-price range": [ + "none", + "do not care", + "cheap", + "expensive", + "moderate" + ], + "restaurant-food": [ + "british food", + "steakhouse", + "turkish", + "sushi", + "north american", + "scottish", + "french", + "austrian", + "korean", + "eastern european", + "swedish", + "gastro pub", + "modern eclectic", + "afternoon tea", + "welsh", + 
"christmas", + "tuscan", + "gastropub", + "sri lankan", + "molecular gastronomy", + "traditional american", + "italian", + "pizza", + "thai", + "south african", + "creative", + "english", + "asian", + "lebanese", + "hungarian", + "halal", + "portugese", + "modern english", + "african", + "light bites", + "malaysian", + "venetian", + "traditional", + "chinese", + "vegetarian", + "persian", + "thai and chinese", + "scandinavian", + "catalan", + "polynesian", + "crossover", + "canapes", + "cantonese", + "north african", + "seafood", + "brazilian", + "south indian", + "australasian", + "belgian", + "barbeque", + "the americas", + "indonesian", + "singaporean", + "irish", + "middle eastern", + "dojo noodle bar", + "caribbean", + "vietnamese", + "modern european", + "russian", + "none", + "german", + "world", + "japanese", + "moroccan", + "modern global", + "do not care", + "indian", + "british", + "american", + "danish", + "panasian", + "swiss", + "basque", + "north indian", + "modern american", + "australian", + "european", + "corsica", + "greek", + "northern european", + "mediterranean", + "portuguese", + "romanian", + "jamaican", + "polish", + "international", + "unusual", + "latin american", + "asian oriental", + "mexican", + "bistro", + "cuban", + "fusion", + "new zealand", + "spanish", + "eritrean", + "afghan", + "kosher" + ], + "attraction-name": [ + "downing college", + "fitzwilliam", + "clare college", + "ruskin gallery", + "sidney sussex college", + "great saint mary's church", + "cherry hinton water play park", + "wandlebury country park", + "cafe uno", + "place", + "broughton", + "cineworld cinema", + "jesus college", + "vue cinema", + "history of science museum", + "mumford theatre", + "whale of time", + "fitzbillies", + "christs church", + "churchill college", + "museum of classical archaeology", + "gonville and caius college", + "pizza", + "kirkwood", + "saint catharines college", + "kings college", + "parkside", + "by", + "st catharines college", + "saint john's college", + "cherry hinton water park", + "st christs college", + "christ's college", + "bangkok city", + "scudamores punti co", + "free", + "great saint marys church", + "milton country park", + "the fez club", + "soultree", + "autu", + "whipple museum of the history of science", + "aylesbray lodge guest house", + "broughton house gallery", + "peoples portraits exhibition", + "primavera", + "kettles yard", + "all saint's church", + "cinema cinema", + "regency gallery", + "corpus christi", + "corn cambridge exchange", + "da vinci pizzeria", + "school", + "hobsons house", + "cambride and country folk museum", + "north", + "da v", + "cambridge corn exchange", + "soul tree nightclub", + "cambridge arts theater", + "saint catharine's college", + "byard art", + "cambridge punter", + "cambridge university botanic gardens", + "castle galleries", + "museum of archaelogy and anthropogy", + "no specific location", + "cherry hinton hall", + "gallery at 12 a high street", + "parkside pools", + "queen's college", + "little saint mary's church", + "gallery", + "home from home", + "tenpin", + "the wandlebury", + "county folk museum", + "swimming pool", + "christs college", + "cafe jello museum", + "scott polar", + "christ college", + "cambridge museum of technology", + "abbey pool and astroturf pitch", + "king hedges learner pool", + "the cambridge arts theatre", + "the castle galleries", + "cambridge and country folk museum", + "kohinoor", + "scudamores punting co", + "sidney sussex", + "the man on the moon", + "little saint marys 
church", + "queens", + "the place", + "old school", + "churchill", + "churchills college", + "hughes hall", + "churchhill college", + "riverboat georgina", + "none", + "belf", + "cambridge temporary art", + "abc theatre", + "cambridge contemporary art museum", + "man on the moon", + "the junction", + "cherry hinton water play", + "adc theatre", + "gonville hotel", + "magdalene college", + "peoples portraits exhibition at girton college", + "boat", + "centre", + "sheep's green and lammas land park fen causeway", + "do not care", + "the mumford theatre", + "archway house", + "queens' college", + "williams art and antiques", + "funky fun house", + "cherry hinton village centre", + "camboats", + "cambridge", + "old schools", + "kettle's yard", + "whale of a time", + "the churchill college", + "cafe jello gallery", + "aut", + "salsa", + "city", + "clare hall", + "boating", + "pembroke college", + "kings hedges learner pool", + "caffe uno", + "lammas land park", + "museum", + "the fitzwilliam museum", + "the cherry hinton village centre", + "the cambridge corn exchange", + "fitzwilliam museum", + "museum of archaelogy and anthropology", + "fez club", + "the cambridge punter", + "saint johns college", + "emmanuel college", + "cambridge belf", + "scudamore", + "lynne strover gallery", + "king's college", + "whippple museum", + "trinity college", + "college in the north", + "sheep's green", + "kambar", + "museum of archaelogy", + "adc", + "garde", + "club salsa", + "people's portraits exhibition at girton college", + "botanic gardens", + "carol", + "college", + "gallery at twelve a high street", + "abbey pool and astroturf", + "cambridge book and print gallery", + "jesus green outdoor pool", + "scott polar museum", + "saint barnabas press gallery", + "cambridge artworks", + "older churches", + "cambridge contemporary art", + "cherry hinton hall and grounds", + "univ", + "jesus green", + "ballare", + "abbey pool", + "cambridge botanic gardens", + "nusha", + "worth house", + "thanh", + "university arms hotel", + "cambridge arts theatre", + "cafe jello", + "cambridge and county folk museum", + "the cambridge artworks", + "all saints church", + "holy trinity church", + "contemporary art museum", + "architectural churches", + "queens college", + "trinity street college" + ], + "restaurant-name": [ + "none", + "do not care", + "hotel du vin and bistro", + "ask", + "gourmet formal kitchen", + "the meze bar", + "lan hong house", + "cow pizza", + "one seven", + "prezzo", + "maharajah tandoori restaurant", + "alex", + "shanghai", + "golden wok", + "restaurant", + "fitzbillies", + "nil", + "copper kettle", + "meghna", + "hk fusion", + "bangkok city", + "hobsons house", + "tang chinese", + "anatolia", + "ugly duckling", + "anatolia and efes restaurant", + "sitar tandoori", + "city stop", + "ashley", + "pizza express fen ditton", + "molecular gastronomy", + "autumn house", + "el shaddia guesthouse", + "the grafton hotel", + "limehouse", + "gardenia", + "not metioned", + "hakka", + "michaelhouse cafe", + "pipasha", + "meze bar", + "archway", + "molecular gastonomy", + "yipee noodle bar", + "the peking", + "curry prince", + "midsummer house restaurant", + "pizza hut cherry hinton", + "the lucky star", + "stazione restaurant and coffee bar", + "shanghi family restaurant", + "good luck", + "j restaurant", + "bedouin", + "cott", + "little seoul", + "south", + "thanh binh", + "el", + "efes restaurant", + "kohinoor", + "clowns", + "india", + "the slug and lettuce", + "shiraz", + "barbakan", + "zizzi cambridge", + 
"restaurant one seven", + "slug and lettuce", + "travellers rest", + "binh", + "worth house", + "broughton house gallery", + "chiquito", + "the river bar steakhouse and grill", + "ros", + "golden house", + "india west", + "cam", + "panahar", + "restaurant 22", + "adden", + "indian", + "hu", + "jinling noodle bar", + "darrys cookhouse and wine shop", + "hobson house", + "cambridge be", + "el shaddai", + "ac", + "nandos", + "cambridge lodge", + "the cow pizza kitchen and bar", + "charlie", + "rajmahal", + "kymmoy", + "cambri", + "backstreet bistro", + "galleria", + "restaurant 2 two", + "chiquito restaurant bar", + "royal standard", + "lucky star", + "curry king", + "grafton hotel restaurant", + "mahal of cambridge", + "the bedouin", + "nus", + "the kohinoor", + "pizza hut fenditton", + "camboats", + "the gardenia", + "de luca cucina and bar", + "nusha", + "european", + "taj tandoori", + "tandoori palace", + "golden curry", + "efes", + "loch fyne", + "the maharajah tandoor", + "lovel", + "restaurant 17", + "clowns cafe", + "cambridge punter", + "bloomsbury restaurant", + "la mimosa", + "the cambridge chop house", + "funky", + "cotto", + "oak bistro", + "restaurant two two", + "pipasha restaurant", + "river bar steakhouse and grill", + "royal spice", + "the copper kettle", + "graffiti", + "nandos city centre", + "saffron brasserie", + "cambridge chop house", + "sitar", + "kitchen and bar", + "the good luck chinese food takeaway", + "clu", + "la tasca", + "cafe uno", + "cote", + "the varsity restaurant", + "bri", + "eraina", + "bridge", + "fin", + "cambridge lodge restaurant", + "grafton", + "hotpot", + "sala thong", + "margherita", + "wise buddha", + "the missing sock", + "seasame restaurant and bar", + "the dojo noodle bar", + "restaurant alimentum", + "gastropub", + "saigon city", + "la margherita", + "pizza hut", + "curry garden", + "ashley hotel", + "eraina and michaelhouse cafe", + "the golden curry", + "curry queen", + "cow pizza kitchen and bar", + "the peking restaurant:", + "hamilton lodge", + "alimentum", + "yippee noodle bar", + "2 two and cote", + "shanghai family restaurant", + "grafton hotel", + "yes", + "ali baba", + "dif", + "fitzbillies restaurant", + "peking restaurant", + "lev", + "nirala", + "the alex", + "tandoori", + "city stop restaurant", + "rice house", + "cityr", + "yu garden", + "meze bar restaurant", + "the", + "don pasquale pizzeria", + "rice boat", + "the hotpot", + "old school", + "the oak bistro", + "sesame restaurant and bar", + "pizza express", + "the gandhi", + "pizza hut fen ditton", + "charlie chan", + "da vinci pizzeria", + "dojo noodle bar", + "gourmet burger kitchen", + "the golden house", + "india house", + "hobso", + "missing sock", + "pizza hut city centre", + "parkside pools", + "riverside brasserie", + "caffe uno", + "primavera", + "the nirala", + "wagamama", + "au", + "ian hong house", + "frankie and bennys", + "4 kings parade city centre", + "shiraz restaurant", + "scudamores punt", + "mahal", + "saint johns chop house", + "de luca cucina and bar riverside brasserie", + "cocum", + "la raza" + ], + "attraction-type": [ + "none", + "do not care", + "architecture", + "boat", + "boating", + "camboats", + "church", + "churchills college", + "cinema", + "college", + "concert", + "concerthall", + "entertainment", + "gallery", + "gastropub", + "hiking", + "hotel", + "multiple sports", + "museum", + "museum kettles yard", + "night club", + "outdoor", + "park", + "pool", + "special", + "sports", + "swimming pool", + "theater", + "theatre", + "concert 
hall", + "local site", + "nightclub", + "hotspot" + ], + "taxi-leave at": [ + "none", + "do not care", + "00:00", + "00:05", + "00:10", + "00:15", + "00:20", + "00:25", + "00:30", + "00:35", + "00:40", + "00:45", + "00:50", + "00:55", + "01:00", + "01:05", + "01:10", + "01:15", + "01:20", + "01:25", + "01:30", + "01:35", + "01:40", + "01:45", + "01:50", + "01:55", + "02:00", + "02:05", + "02:10", + "02:15", + "02:20", + "02:25", + "02:30", + "02:35", + "02:40", + "02:45", + "02:50", + "02:55", + "03:00", + "03:05", + "03:10", + "03:15", + "03:20", + "03:25", + "03:30", + "03:35", + "03:40", + "03:45", + "03:50", + "03:55", + "04:00", + "04:05", + "04:10", + "04:15", + "04:20", + "04:25", + "04:30", + "04:35", + "04:40", + "04:45", + "04:50", + "04:55", + "05:00", + "05:05", + "05:10", + "05:15", + "05:20", + "05:25", + "05:30", + "05:35", + "05:40", + "05:45", + "05:50", + "05:55", + "06:00", + "06:05", + "06:10", + "06:15", + "06:20", + "06:25", + "06:30", + "06:35", + "06:40", + "06:45", + "06:50", + "06:55", + "07:00", + "07:05", + "07:10", + "07:15", + "07:20", + "07:25", + "07:30", + "07:35", + "07:40", + "07:45", + "07:50", + "07:55", + "08:00", + "08:05", + "08:10", + "08:15", + "08:20", + "08:25", + "08:30", + "08:35", + "08:40", + "08:45", + "08:50", + "08:55", + "09:00", + "09:05", + "09:10", + "09:15", + "09:20", + "09:25", + "09:30", + "09:35", + "09:40", + "09:45", + "09:50", + "09:55", + "10:00", + "10:05", + "10:10", + "10:15", + "10:20", + "10:25", + "10:30", + "10:35", + "10:40", + "10:45", + "10:50", + "10:55", + "11:00", + "11:05", + "11:10", + "11:15", + "11:20", + "11:25", + "11:30", + "11:35", + "11:40", + "11:45", + "11:50", + "11:55", + "12:00", + "12:05", + "12:10", + "12:15", + "12:20", + "12:25", + "12:30", + "12:35", + "12:40", + "12:45", + "12:50", + "12:55", + "13:00", + "13:05", + "13:10", + "13:15", + "13:20", + "13:25", + "13:30", + "13:35", + "13:40", + "13:45", + "13:50", + "13:55", + "14:00", + "14:05", + "14:10", + "14:15", + "14:20", + "14:25", + "14:30", + "14:35", + "14:40", + "14:45", + "14:50", + "14:55", + "15:00", + "15:05", + "15:10", + "15:15", + "15:20", + "15:25", + "15:30", + "15:35", + "15:40", + "15:45", + "15:50", + "15:55", + "16:00", + "16:05", + "16:10", + "16:15", + "16:20", + "16:25", + "16:30", + "16:35", + "16:40", + "16:45", + "16:50", + "16:55", + "17:00", + "17:05", + "17:10", + "17:15", + "17:20", + "17:25", + "17:30", + "17:35", + "17:40", + "17:45", + "17:50", + "17:55", + "18:00", + "18:05", + "18:10", + "18:15", + "18:20", + "18:25", + "18:30", + "18:35", + "18:40", + "18:45", + "18:50", + "18:55", + "19:00", + "19:05", + "19:10", + "19:15", + "19:20", + "19:25", + "19:30", + "19:35", + "19:40", + "19:45", + "19:50", + "19:55", + "20:00", + "20:05", + "20:10", + "20:15", + "20:20", + "20:25", + "20:30", + "20:35", + "20:40", + "20:45", + "20:50", + "20:55", + "21:00", + "21:05", + "21:10", + "21:15", + "21:20", + "21:25", + "21:30", + "21:35", + "21:40", + "21:45", + "21:50", + "21:55", + "22:00", + "22:05", + "22:10", + "22:15", + "22:20", + "22:25", + "22:30", + "22:35", + "22:40", + "22:45", + "22:50", + "22:55", + "23:00", + "23:05", + "23:10", + "23:15", + "23:20", + "23:25", + "23:30", + "23:35", + "23:40", + "23:45", + "23:50", + "23:55" + ], + "taxi-destination": [ + "none", + "do not care", + "a and b guest house", + "abbey pool and astroturf pitch", + "acorn guest house", + "adc theatre", + "addenbrookes hospital", + "alexander bed and breakfast", + "ali baba", + "all saints church", + "allenbell", + "alpha milton 
guest house", + "anatolia", + "arbury lodge guesthouse", + "archway house", + "ashley hotel", + "ask", + "attraction", + "autumn house", + "avalon", + "aylesbray lodge guest house", + "backstreet bistro", + "ballare", + "bangkok city", + "bedouin", + "birmingham new street train station", + "bishops stortford train station", + "bloomsbury restaurant", + "bridge guest house", + "broughton house gallery", + "broxbourne train station", + "byard art", + "cafe jello gallery", + "cafe uno", + "camboats", + "cambridge", + "cambridge and county folk museum", + "cambridge arts theatre", + "cambridge artworks", + "cambridge belfry", + "cambridge book and print gallery", + "cambridge chop house", + "cambridge contemporary art", + "cambridge county fair next to the city tourist museum", + "cambridge lodge restaurant", + "cambridge museum of technology", + "cambridge punter", + "cambridge road church of christ", + "cambridge train station", + "cambridge university botanic gardens", + "carolina bed and breakfast", + "castle galleries", + "charlie chan", + "cherry hinton hall and grounds", + "cherry hinton village centre", + "cherry hinton water park", + "cherry hinton water play", + "chiquito restaurant bar", + "christ college", + "churchills college", + "cineworld cinema", + "city centre north bed and breakfast", + "city stop restaurant", + "cityroomz", + "clare college", + "clare hall", + "clowns cafe", + "club salsa", + "cocum", + "copper kettle", + "corpus christi", + "cote", + "cotto", + "cow pizza kitchen and bar", + "curry garden", + "curry king", + "curry prince", + "da vinci pizzeria", + "darrys cookhouse and wine shop", + "de luca cucina and bar", + "dojo noodle bar", + "don pasquale pizzeria", + "downing college", + "efes restaurant", + "el shaddia guesthouse", + "ely train station", + "emmanuel college", + "eraina", + "express by holiday inn cambridge", + "finches bed and breakfast", + "finders corner newmarket road", + "fitzbillies restaurant", + "fitzwilliam museum", + "frankie and bennys", + "funky fun house", + "galleria", + "gallery at 12 a high street", + "gastropub", + "golden curry", + "golden house", + "golden wok", + "gonville and caius college", + "gonville hotel", + "good luck", + "gourmet burger kitchen", + "graffiti", + "grafton hotel restaurant", + "great saint marys church", + "hakka", + "hamilton lodge", + "hk fusion", + "hobsons house", + "holy trinity church", + "home from home", + "hotel du vin and bistro", + "hughes hall", + "huntingdon marriott hotel", + "ian hong", + "india house", + "j restaurant", + "jesus college", + "jesus green outdoor pool", + "jinling noodle bar", + "kambar", + "kettles yard", + "kings college", + "kings hedges learner pool", + "kirkwood house", + "kohinoor", + "kymmoy", + "la margherita", + "la mimosa", + "la raza", + "la tasca", + "lan hong house", + "leicester train station", + "lensfield hotel", + "limehouse", + "little saint marys church", + "little seoul", + "loch fyne", + "london kings cross train station", + "london liverpool street train station", + "lovell lodge", + "lynne strover gallery", + "magdalene college", + "mahal of cambridge", + "maharajah tandoori restaurant", + "meghna", + "meze bar", + "michaelhouse cafe", + "midsummer house restaurant", + "milton country park", + "mumford theatre", + "museum of archaelogy and anthropology", + "museum of classical archaeology", + "nandos", + "nandos city centre", + "nil", + "nirala", + "norwich train station", + "nusha", + "old schools", + "panahar", + "parkside police station", + 
"parkside pools", + "peking restaurant", + "pembroke college", + "peoples portraits exhibition at girton college", + "peterborough train station", + "pipasha restaurant", + "pizza express", + "pizza hut cherry hinton", + "pizza hut city centre", + "pizza hut fenditton", + "prezzo", + "primavera", + "queens college", + "rajmahal", + "regency gallery", + "restaurant 17", + "restaurant 2 two", + "restaurant alimentum", + "rice boat", + "rice house", + "riverboat georgina", + "riverside brasserie", + "rosas bed and breakfast", + "royal spice", + "royal standard", + "ruskin gallery", + "saffron brasserie", + "saigon city", + "saint barnabas", + "saint barnabas press gallery", + "saint catharines college", + "saint johns chop house", + "saint johns college", + "sala thong", + "scott polar museum", + "scudamores punting co", + "sesame restaurant and bar", + "shanghai family restaurant", + "sheeps green and lammas land park fen causeway", + "shiraz", + "sidney sussex college", + "sitar tandoori", + "sleeperz hotel", + "soul tree nightclub", + "st johns chop house", + "stansted airport train station", + "station road", + "stazione restaurant and coffee bar", + "stevenage train station", + "taj tandoori", + "tall monument", + "tandoori palace", + "tang chinese", + "tenpin", + "thanh binh", + "the anatolia", + "the cambridge corn exchange", + "the cambridge shop", + "the fez club", + "the gandhi", + "the gardenia", + "the hotpot", + "the junction", + "the lucky star", + "the man on the moon", + "the missing sock", + "the oak bistro", + "the place", + "the regent street city center", + "the river bar steakhouse and grill", + "the slug and lettuce", + "the varsity restaurant", + "travellers rest", + "trinity college", + "ugly duckling", + "university arms hotel", + "vue cinema", + "wagamama", + "wandlebury country park", + "wankworth hotel", + "warkworth house", + "whale of a time", + "whipple museum of the history of science", + "williams art and antiques", + "worth house", + "yippee noodle bar", + "yu garden", + "zizzi cambridge", + "leverton house", + "the cambridge chop house", + "saint john's college", + "churchill college", + "the nirala", + "the cow pizza kitchen and bar", + "christ's college", + "el shaddai", + "saint catharine's college", + "camb", + "the golden curry", + "little saint mary's church", + "country folk museum", + "meze bar restaurant", + "the cambridge belfry", + "the fitzwilliam museum", + "the lensfield hotel", + "pizza express fen ditton", + "the cambridge punter", + "king's college", + "the cherry hinton village centre", + "shiraz restaurant", + "sheep's green and lammas land park fen causeway", + "caffe uno", + "the ghandi", + "the copper kettle", + "man on the moon concert hall", + "alpha-milton guest house", + "queen's college", + "restaurant one seven", + "restaurant two two", + "city centre north b and b", + "rosa's bed and breakfast", + "the good luck chinese food takeaway", + "not museum of archaeology and anthropologymentioned", + "tandori in cambridge", + "kettle's yard", + "megna", + "grou", + "gallery at twelve a high street", + "maharajah tandoori restaurant", + "pizza hut fen ditton", + "gandhi", + "tranh binh", + "kambur", + "people's portraits exhibition at girton college", + "hotel", + "restaurant", + "the galleria", + "queens' college", + "great saint mary's church", + "theathre", + "cambridge artworks", + "acorn house", + "shiraz", + "riverboat georginawd", + "mic", + "the gallery at twelve", + "the soul tree", + "finches" + ], + "taxi-departure": [ + 
"none", + "do not care", + "172 chestertown road", + "4455 woodbridge road", + "a and b guest house", + "abbey pool and astroturf pitch", + "acorn guest house", + "adc theatre", + "addenbrookes hospital", + "alexander bed and breakfast", + "ali baba", + "all saints church", + "allenbell", + "alpha milton guest house", + "alyesbray lodge hotel", + "ambridge", + "anatolia", + "arbury lodge guesthouse", + "archway house", + "ashley hotel", + "ask", + "autumn house", + "avalon", + "aylesbray lodge guest house", + "backstreet bistro", + "ballare", + "bangkok city", + "bedouin", + "birmingham new street train station", + "bishops stortford train station", + "bloomsbury restaurant", + "bridge guest house", + "broughton house gallery", + "broxbourne train station", + "byard art", + "cafe jello gallery", + "cafe uno", + "caffee uno", + "camboats", + "cambridge", + "cambridge and county folk museum", + "cambridge arts theatre", + "cambridge artworks", + "cambridge belfry", + "cambridge book and print gallery", + "cambridge chop house", + "cambridge contemporary art", + "cambridge lodge restaurant", + "cambridge museum of technology", + "cambridge punter", + "cambridge towninfo centre", + "cambridge train station", + "cambridge university botanic gardens", + "carolina bed and breakfast", + "castle galleries", + "centre of town at my hotel", + "charlie chan", + "cherry hinton hall and grounds", + "cherry hinton village center", + "cherry hinton village centre", + "cherry hinton water play", + "chiquito restaurant bar", + "christ college", + "churchills college", + "cineworld cinema", + "citiroomz", + "city centre north bed and breakfast", + "city stop restaurant", + "cityroomz", + "clair hall", + "clare college", + "clare hall", + "clowns cafe", + "club salsa", + "cocum", + "copper kettle", + "corpus christi", + "cote", + "cotto", + "cow pizza kitchen and bar", + "curry garden", + "curry king", + "curry prince", + "curry queen", + "da vinci pizzeria", + "darrys cookhouse and wine shop", + "de luca cucina and bar", + "dojo noodle bar", + "don pasquale pizzeria", + "downing college", + "downing street", + "el shaddia guesthouse", + "ely", + "ely train station", + "emmanuel college", + "eraina", + "express by holiday inn cambridge", + "finches bed and breakfast", + "fitzbillies restaurant", + "fitzwilliam museum", + "frankie and bennys", + "funky fun house", + "galleria", + "gallery at 12 a high street", + "girton college", + "golden curry", + "golden house", + "golden wok", + "gonville and caius college", + "gonville hotel", + "good luck", + "gourmet burger kitchen", + "graffiti", + "grafton hotel restaurant", + "great saint marys church", + "hakka", + "hamilton lodge", + "hobsons house", + "holy trinity church", + "home", + "home from home", + "hotel", + "hotel du vin and bistro", + "hughes hall", + "huntingdon marriott hotel", + "india house", + "j restaurant", + "jesus college", + "jesus green outdoor pool", + "jinling noodle bar", + "junction theatre", + "kambar", + "kettles yard", + "kings college", + "kings hedges learner pool", + "kings lynn train station", + "kirkwood house", + "kohinoor", + "kymmoy", + "la margherita", + "la mimosa", + "la raza", + "la tasca", + "lan hong house", + "lensfield hotel", + "leverton house", + "limehouse", + "little saint marys church", + "little seoul", + "loch fyne", + "london kings cross train station", + "london liverpool street", + "london liverpool street train station", + "lovell lodge", + "lynne strover gallery", + "magdalene college", + "mahal of 
cambridge", + "maharajah tandoori restaurant", + "meghna", + "meze bar", + "michaelhouse cafe", + "milton country park", + "mumford theatre", + "museum", + "museum of archaelogy and anthropology", + "museum of classical archaeology", + "nandos", + "nandos city centre", + "new england", + "nirala", + "norwich train station", + "nstaot mentioned", + "nusha", + "old schools", + "panahar", + "parkside police station", + "parkside pools", + "peking restaurant", + "pembroke college", + "peoples portraits exhibition at girton college", + "peterborough train station", + "pizza express", + "pizza hut cherry hinton", + "pizza hut city centre", + "pizza hut fenditton", + "prezzo", + "primavera", + "queens college", + "rajmahal", + "regency gallery", + "restaurant 17", + "restaurant 2 two", + "restaurant alimentum", + "rice boat", + "rice house", + "riverboat georgina", + "riverside brasserie", + "rosas bed and breakfast", + "royal spice", + "royal standard", + "ruskin gallery", + "saffron brasserie", + "saigon city", + "saint barnabas press gallery", + "saint catharines college", + "saint johns chop house", + "saint johns college", + "sala thong", + "scott polar museum", + "scudamores punting co", + "sesame restaurant and bar", + "sheeps green and lammas land park", + "sheeps green and lammas land park fen causeway", + "shiraz", + "sidney sussex college", + "sitar tandoori", + "soul tree nightclub", + "st johns college", + "stazione restaurant and coffee bar", + "stevenage train station", + "taj tandoori", + "tandoori palace", + "tang chinese", + "tenpin", + "thanh binh", + "the cambridge corn exchange", + "the fez club", + "the gallery at 12", + "the gandhi", + "the gardenia", + "the hotpot", + "the junction", + "the lucky star", + "the man on the moon", + "the missing sock", + "the oak bistro", + "the place", + "the river bar steakhouse and grill", + "the slug and lettuce", + "the varsity restaurant", + "travellers rest", + "trinity college", + "ugly duckling", + "university arms hotel", + "vue cinema", + "wagamama", + "wandlebury country park", + "warkworth house", + "whale of a time", + "whipple museum of the history of science", + "williams art and antiques", + "worth house", + "yippee noodle bar", + "yu garden", + "zizzi cambridge", + "christ's college", + "city centre north b and b", + "the lensfield hotel", + "alpha-milton guest house", + "el shaddai", + "churchill college", + "the cambridge belfry", + "king's college", + "great saint mary's church", + "restaurant two two", + "queens' college", + "little saint mary's church", + "chinese city centre", + "kettle's yard", + "pizza hut", + "the golden curry", + "rosa's bed and breakfast", + "the cambridge punter", + "the byard art museum", + "saint catharine's college", + "meze bar restaurant", + "the good luck chinese food takeaway", + "restaurant one seven", + "pizza hut fen ditton", + "the nirala", + "the fitzwilliam museum", + "st. 
john's college", + "gallery at twelve a high street", + "sheep's green and lammas land park fen causeway", + "the cherry hinton village centre", + "pizza express fen ditton", + "corpus cristi", + "cas", + "acorn house", + "lens", + "the cambridge chop house", + "the copper kettle", + "the avalon", + "saint john's college", + "aylesbray lodge", + "the alexander bed and breakfast", + "cambridge belfy", + "people's portraits exhibition at girton college", + "gonville", + "caffe uno", + "the cow pizza kitchen and bar", + "lovell ldoge", + "cinema", + "shiraz restaurant", + "park", + "the allenbell" + ], + "restaurant-book day": [ + "none", + "do not care", + "friday", + "monday", + "saterday", + "sunday", + "thursday", + "tuesday", + "wednesday" + ], + "restaurant-book people": [ + "none", + "do not care", + "1", + "10 or more", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9" + ], + "restaurant-book time": [ + "none", + "do not care", + "00:00", + "00:05", + "00:10", + "00:15", + "00:20", + "00:25", + "00:30", + "00:35", + "00:40", + "00:45", + "00:50", + "00:55", + "01:00", + "01:05", + "01:10", + "01:15", + "01:20", + "01:25", + "01:30", + "01:35", + "01:40", + "01:45", + "01:50", + "01:55", + "02:00", + "02:05", + "02:10", + "02:15", + "02:20", + "02:25", + "02:30", + "02:35", + "02:40", + "02:45", + "02:50", + "02:55", + "03:00", + "03:05", + "03:10", + "03:15", + "03:20", + "03:25", + "03:30", + "03:35", + "03:40", + "03:45", + "03:50", + "03:55", + "04:00", + "04:05", + "04:10", + "04:15", + "04:20", + "04:25", + "04:30", + "04:35", + "04:40", + "04:45", + "04:50", + "04:55", + "05:00", + "05:05", + "05:10", + "05:15", + "05:20", + "05:25", + "05:30", + "05:35", + "05:40", + "05:45", + "05:50", + "05:55", + "06:00", + "06:05", + "06:10", + "06:15", + "06:20", + "06:25", + "06:30", + "06:35", + "06:40", + "06:45", + "06:50", + "06:55", + "07:00", + "07:05", + "07:10", + "07:15", + "07:20", + "07:25", + "07:30", + "07:35", + "07:40", + "07:45", + "07:50", + "07:55", + "08:00", + "08:05", + "08:10", + "08:15", + "08:20", + "08:25", + "08:30", + "08:35", + "08:40", + "08:45", + "08:50", + "08:55", + "09:00", + "09:05", + "09:10", + "09:15", + "09:20", + "09:25", + "09:30", + "09:35", + "09:40", + "09:45", + "09:50", + "09:55", + "10:00", + "10:05", + "10:10", + "10:15", + "10:20", + "10:25", + "10:30", + "10:35", + "10:40", + "10:45", + "10:50", + "10:55", + "11:00", + "11:05", + "11:10", + "11:15", + "11:20", + "11:25", + "11:30", + "11:35", + "11:40", + "11:45", + "11:50", + "11:55", + "12:00", + "12:05", + "12:10", + "12:15", + "12:20", + "12:25", + "12:30", + "12:35", + "12:40", + "12:45", + "12:50", + "12:55", + "13:00", + "13:05", + "13:10", + "13:15", + "13:20", + "13:25", + "13:30", + "13:35", + "13:40", + "13:45", + "13:50", + "13:55", + "14:00", + "14:05", + "14:10", + "14:15", + "14:20", + "14:25", + "14:30", + "14:35", + "14:40", + "14:45", + "14:50", + "14:55", + "15:00", + "15:05", + "15:10", + "15:15", + "15:20", + "15:25", + "15:30", + "15:35", + "15:40", + "15:45", + "15:50", + "15:55", + "16:00", + "16:05", + "16:10", + "16:15", + "16:20", + "16:25", + "16:30", + "16:35", + "16:40", + "16:45", + "16:50", + "16:55", + "17:00", + "17:05", + "17:10", + "17:15", + "17:20", + "17:25", + "17:30", + "17:35", + "17:40", + "17:45", + "17:50", + "17:55", + "18:00", + "18:05", + "18:10", + "18:15", + "18:20", + "18:25", + "18:30", + "18:35", + "18:40", + "18:45", + "18:50", + "18:55", + "19:00", + "19:05", + "19:10", + "19:15", + "19:20", + "19:25", + "19:30", + "19:35", + 
"19:40", + "19:45", + "19:50", + "19:55", + "20:00", + "20:05", + "20:10", + "20:15", + "20:20", + "20:25", + "20:30", + "20:35", + "20:40", + "20:45", + "20:50", + "20:55", + "21:00", + "21:05", + "21:10", + "21:15", + "21:20", + "21:25", + "21:30", + "21:35", + "21:40", + "21:45", + "21:50", + "21:55", + "22:00", + "22:05", + "22:10", + "22:15", + "22:20", + "22:25", + "22:30", + "22:35", + "22:40", + "22:45", + "22:50", + "22:55", + "23:00", + "23:05", + "23:10", + "23:15", + "23:20", + "23:25", + "23:30", + "23:35", + "23:40", + "23:45", + "23:50", + "23:55" + ], + "taxi-arrive by": [ + "none", + "do not care", + "00:00", + "00:05", + "00:10", + "00:15", + "00:20", + "00:25", + "00:30", + "00:35", + "00:40", + "00:45", + "00:50", + "00:55", + "01:00", + "01:05", + "01:10", + "01:15", + "01:20", + "01:25", + "01:30", + "01:35", + "01:40", + "01:45", + "01:50", + "01:55", + "02:00", + "02:05", + "02:10", + "02:15", + "02:20", + "02:25", + "02:30", + "02:35", + "02:40", + "02:45", + "02:50", + "02:55", + "03:00", + "03:05", + "03:10", + "03:15", + "03:20", + "03:25", + "03:30", + "03:35", + "03:40", + "03:45", + "03:50", + "03:55", + "04:00", + "04:05", + "04:10", + "04:15", + "04:20", + "04:25", + "04:30", + "04:35", + "04:40", + "04:45", + "04:50", + "04:55", + "05:00", + "05:05", + "05:10", + "05:15", + "05:20", + "05:25", + "05:30", + "05:35", + "05:40", + "05:45", + "05:50", + "05:55", + "06:00", + "06:05", + "06:10", + "06:15", + "06:20", + "06:25", + "06:30", + "06:35", + "06:40", + "06:45", + "06:50", + "06:55", + "07:00", + "07:05", + "07:10", + "07:15", + "07:20", + "07:25", + "07:30", + "07:35", + "07:40", + "07:45", + "07:50", + "07:55", + "08:00", + "08:05", + "08:10", + "08:15", + "08:20", + "08:25", + "08:30", + "08:35", + "08:40", + "08:45", + "08:50", + "08:55", + "09:00", + "09:05", + "09:10", + "09:15", + "09:20", + "09:25", + "09:30", + "09:35", + "09:40", + "09:45", + "09:50", + "09:55", + "10:00", + "10:05", + "10:10", + "10:15", + "10:20", + "10:25", + "10:30", + "10:35", + "10:40", + "10:45", + "10:50", + "10:55", + "11:00", + "11:05", + "11:10", + "11:15", + "11:20", + "11:25", + "11:30", + "11:35", + "11:40", + "11:45", + "11:50", + "11:55", + "12:00", + "12:05", + "12:10", + "12:15", + "12:20", + "12:25", + "12:30", + "12:35", + "12:40", + "12:45", + "12:50", + "12:55", + "13:00", + "13:05", + "13:10", + "13:15", + "13:20", + "13:25", + "13:30", + "13:35", + "13:40", + "13:45", + "13:50", + "13:55", + "14:00", + "14:05", + "14:10", + "14:15", + "14:20", + "14:25", + "14:30", + "14:35", + "14:40", + "14:45", + "14:50", + "14:55", + "15:00", + "15:05", + "15:10", + "15:15", + "15:20", + "15:25", + "15:30", + "15:35", + "15:40", + "15:45", + "15:50", + "15:55", + "16:00", + "16:05", + "16:10", + "16:15", + "16:20", + "16:25", + "16:30", + "16:35", + "16:40", + "16:45", + "16:50", + "16:55", + "17:00", + "17:05", + "17:10", + "17:15", + "17:20", + "17:25", + "17:30", + "17:35", + "17:40", + "17:45", + "17:50", + "17:55", + "18:00", + "18:05", + "18:10", + "18:15", + "18:20", + "18:25", + "18:30", + "18:35", + "18:40", + "18:45", + "18:50", + "18:55", + "19:00", + "19:05", + "19:10", + "19:15", + "19:20", + "19:25", + "19:30", + "19:35", + "19:40", + "19:45", + "19:50", + "19:55", + "20:00", + "20:05", + "20:10", + "20:15", + "20:20", + "20:25", + "20:30", + "20:35", + "20:40", + "20:45", + "20:50", + "20:55", + "21:00", + "21:05", + "21:10", + "21:15", + "21:20", + "21:25", + "21:30", + "21:35", + "21:40", + "21:45", + "21:50", + "21:55", + "22:00", + "22:05", + 
"22:10", + "22:15", + "22:20", + "22:25", + "22:30", + "22:35", + "22:40", + "22:45", + "22:50", + "22:55", + "23:00", + "23:05", + "23:10", + "23:15", + "23:20", + "23:25", + "23:30", + "23:35", + "23:40", + "23:45", + "23:50", + "23:55" + ], + "restaurant-area": [ + "none", + "do not care", + "centre", + "east", + "north", + "south", + "west" + ], + "hotel-area": [ + "none", + "do not care", + "centre", + "east", + "north", + "south", + "west" + ], + "attraction-area": [ + "none", + "do not care", + "centre", + "east", + "north", + "south", + "west" + ] +} \ No newline at end of file diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json new file mode 100644 index 0000000000000000000000000000000000000000..b0dd00fdf6dc2b824f1f50a44e776c63ce72f14b --- /dev/null +++ b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_ont_request.json @@ -0,0 +1,3128 @@ +{ + "hotel-price range": [ + "none", + "do not care", + "cheap", + "expensive", + "moderate", + "request" + ], + "hotel-type": [ + "none", + "do not care", + "bed and breakfast", + "guest house", + "hotel", + "request" + ], + "hotel-parking": [ + "none", + "do not care", + "no", + "yes", + "request" + ], + "hotel-book day": [ + "none", + "do not care", + "friday", + "monday", + "saterday", + "sunday", + "thursday", + "tuesday", + "wednesday" + ], + "hotel-book people": [ + "none", + "do not care", + "1", + "10 or more", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9" + ], + "hotel-book stay": [ + "none", + "do not care", + "1", + "10 or more", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9" + ], + "train-destination": [ + "none", + "do not care", + "bishops stortford", + "kings lynn", + "london liverpool street", + "centre", + "bishop stortford", + "liverpool", + "leicester", + "broxbourne", + "gourmet burger kitchen", + "copper kettle", + "bournemouth", + "stevenage", + "liverpool street", + "norwich", + "huntingdon marriott hotel", + "city centre north", + "taj tandoori", + "the copper kettle", + "peterborough", + "ely", + "lecester", + "london", + "willi", + "stansted airport", + "huntington marriott", + "cambridge", + "gonv", + "glastonbury", + "hol", + "north", + "birmingham new street", + "norway", + "petersborough", + "london kings cross", + "curry prince", + "bishops storford" + ], + "train-arrive by": [ + "none", + "do not care", + "00:00", + "00:05", + "00:10", + "00:15", + "00:20", + "00:25", + "00:30", + "00:35", + "00:40", + "00:45", + "00:50", + "00:55", + "01:00", + "01:05", + "01:10", + "01:15", + "01:20", + "01:25", + "01:30", + "01:35", + "01:40", + "01:45", + "01:50", + "01:55", + "02:00", + "02:05", + "02:10", + "02:15", + "02:20", + "02:25", + "02:30", + "02:35", + "02:40", + "02:45", + "02:50", + "02:55", + "03:00", + "03:05", + "03:10", + "03:15", + "03:20", + "03:25", + "03:30", + "03:35", + "03:40", + "03:45", + "03:50", + "03:55", + "04:00", + "04:05", + "04:10", + "04:15", + "04:20", + "04:25", + "04:30", + "04:35", + "04:40", + "04:45", + "04:50", + "04:55", + "05:00", + "05:05", + "05:10", + "05:15", + "05:20", + "05:25", + "05:30", + "05:35", + "05:40", + "05:45", + "05:50", + "05:55", + "06:00", + "06:05", + "06:10", + "06:15", + "06:20", + "06:25", + "06:30", + "06:35", + "06:40", + "06:45", + "06:50", + "06:55", + "07:00", + "07:05", + "07:10", + "07:15", + "07:20", + "07:25", + "07:30", + "07:35", + "07:40", + "07:45", + "07:50", + "07:55", + "08:00", + "08:05", + "08:10", + "08:15", + "08:20", + "08:25", + "08:30", 
+ "08:35", + "08:40", + "08:45", + "08:50", + "08:55", + "09:00", + "09:05", + "09:10", + "09:15", + "09:20", + "09:25", + "09:30", + "09:35", + "09:40", + "09:45", + "09:50", + "09:55", + "10:00", + "10:05", + "10:10", + "10:15", + "10:20", + "10:25", + "10:30", + "10:35", + "10:40", + "10:45", + "10:50", + "10:55", + "11:00", + "11:05", + "11:10", + "11:15", + "11:20", + "11:25", + "11:30", + "11:35", + "11:40", + "11:45", + "11:50", + "11:55", + "12:00", + "12:05", + "12:10", + "12:15", + "12:20", + "12:25", + "12:30", + "12:35", + "12:40", + "12:45", + "12:50", + "12:55", + "13:00", + "13:05", + "13:10", + "13:15", + "13:20", + "13:25", + "13:30", + "13:35", + "13:40", + "13:45", + "13:50", + "13:55", + "14:00", + "14:05", + "14:10", + "14:15", + "14:20", + "14:25", + "14:30", + "14:35", + "14:40", + "14:45", + "14:50", + "14:55", + "15:00", + "15:05", + "15:10", + "15:15", + "15:20", + "15:25", + "15:30", + "15:35", + "15:40", + "15:45", + "15:50", + "15:55", + "16:00", + "16:05", + "16:10", + "16:15", + "16:20", + "16:25", + "16:30", + "16:35", + "16:40", + "16:45", + "16:50", + "16:55", + "17:00", + "17:05", + "17:10", + "17:15", + "17:20", + "17:25", + "17:30", + "17:35", + "17:40", + "17:45", + "17:50", + "17:55", + "18:00", + "18:05", + "18:10", + "18:15", + "18:20", + "18:25", + "18:30", + "18:35", + "18:40", + "18:45", + "18:50", + "18:55", + "19:00", + "19:05", + "19:10", + "19:15", + "19:20", + "19:25", + "19:30", + "19:35", + "19:40", + "19:45", + "19:50", + "19:55", + "20:00", + "20:05", + "20:10", + "20:15", + "20:20", + "20:25", + "20:30", + "20:35", + "20:40", + "20:45", + "20:50", + "20:55", + "21:00", + "21:05", + "21:10", + "21:15", + "21:20", + "21:25", + "21:30", + "21:35", + "21:40", + "21:45", + "21:50", + "21:55", + "22:00", + "22:05", + "22:10", + "22:15", + "22:20", + "22:25", + "22:30", + "22:35", + "22:40", + "22:45", + "22:50", + "22:55", + "23:00", + "23:05", + "23:10", + "23:15", + "23:20", + "23:25", + "23:30", + "23:35", + "23:40", + "23:45", + "23:50", + "23:55", + "request" + ], + "train-departure": [ + "none", + "do not care", + "bishops stortford", + "kings lynn", + "brookshite", + "london liverpool street", + "cam", + "liverpool", + "bro", + "leicester", + "broxbourne", + "norwhich", + "saint johns", + "stevenage", + "stansted", + "london liverpool", + "cambrid", + "city hall", + "rosas bed and breakfast", + "alpha-milton", + "wandlebury country park", + "norwich", + "liecester", + "stratford", + "peterborough", + "duxford", + "ely", + "london", + "stansted airport", + "lon", + "cambridge", + "panahar", + "cineworld", + "leicaster", + "birmingham", + "cafe uno", + "camboats", + "huntingdon", + "birmingham new street", + "arbu", + "alpha milton", + "east london", + "london kings cross", + "hamilton lodge", + "aylesbray lodge guest", + "el shaddai" + ], + "train-day": [ + "none", + "do not care", + "friday", + "monday", + "saterday", + "sunday", + "thursday", + "tuesday", + "wednesday" + ], + "train-book people": [ + "none", + "do not care", + "1", + "10 or more", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9" + ], + "hotel-stars": [ + "none", + "do not care", + "0", + "1", + "2", + "3", + "4", + "5", + "request" + ], + "hotel-internet": [ + "none", + "do not care", + "no", + "yes", + "request" + ], + "hotel-name": [ + "none", + "do not care", + "a and b guest house", + "city roomz", + "carolina bed and breakfast", + "limehouse", + "anatolia", + "hamilton lodge", + "the lensfield hotel", + "rosa's bed and breakfast", + "gall", + "aylesbray 
lodge", + "kirkwood", + "cambridge belfry", + "warkworth house", + "gonville", + "belfy hotel", + "nus", + "alexander", + "super 5", + "aylesbray lodge guest house", + "the gonvile hotel", + "allenbell", + "nothamilton lodge", + "ashley hotel", + "autumn house", + "hobsons house", + "hotel", + "ashely hotel", + "caridge belfrey", + "el shaddia guest house", + "avalon", + "cote", + "city centre north bed and breakfast", + "the cambridge belfry", + "home from home", + "wandlebury coutn", + "wankworth house", + "city stop rest", + "the worth house", + "cityroomz", + "huntingdon marriottt hotel", + "lensfield", + "rosas bed and breakfast", + "leverton house", + "gonville hotel", + "holiday inn cambridge", + "archway house", + "lan hon", + "levert", + "acorn guest house", + "cambridge", + "the ashley hotel", + "el shaddai", + "sleeperz", + "alpha milton guest house", + "doubletree by hilton cambridge", + "tandoori palace", + "express by", + "express by holiday inn cambridge", + "north bed and breakfast", + "holiday inn", + "arbury lodge guest house", + "alexander bed and breakfast", + "huntingdon marriott hotel", + "royal spice", + "sou", + "finches bed and breakfast", + "the alpha milton", + "bridge guest house", + "the acorn guest house", + "kirkwood house", + "eraina", + "la margherit", + "lensfield hotel", + "marriott hotel", + "nusha", + "city centre bed and breakfast", + "the allenbell", + "university arms hotel", + "clare", + "cherr", + "wartworth", + "acorn place", + "lovell lodge", + "whale" + ], + "train-leave at": [ + "none", + "do not care", + "00:00", + "00:05", + "00:10", + "00:15", + "00:20", + "00:25", + "00:30", + "00:35", + "00:40", + "00:45", + "00:50", + "00:55", + "01:00", + "01:05", + "01:10", + "01:15", + "01:20", + "01:25", + "01:30", + "01:35", + "01:40", + "01:45", + "01:50", + "01:55", + "02:00", + "02:05", + "02:10", + "02:15", + "02:20", + "02:25", + "02:30", + "02:35", + "02:40", + "02:45", + "02:50", + "02:55", + "03:00", + "03:05", + "03:10", + "03:15", + "03:20", + "03:25", + "03:30", + "03:35", + "03:40", + "03:45", + "03:50", + "03:55", + "04:00", + "04:05", + "04:10", + "04:15", + "04:20", + "04:25", + "04:30", + "04:35", + "04:40", + "04:45", + "04:50", + "04:55", + "05:00", + "05:05", + "05:10", + "05:15", + "05:20", + "05:25", + "05:30", + "05:35", + "05:40", + "05:45", + "05:50", + "05:55", + "06:00", + "06:05", + "06:10", + "06:15", + "06:20", + "06:25", + "06:30", + "06:35", + "06:40", + "06:45", + "06:50", + "06:55", + "07:00", + "07:05", + "07:10", + "07:15", + "07:20", + "07:25", + "07:30", + "07:35", + "07:40", + "07:45", + "07:50", + "07:55", + "08:00", + "08:05", + "08:10", + "08:15", + "08:20", + "08:25", + "08:30", + "08:35", + "08:40", + "08:45", + "08:50", + "08:55", + "09:00", + "09:05", + "09:10", + "09:15", + "09:20", + "09:25", + "09:30", + "09:35", + "09:40", + "09:45", + "09:50", + "09:55", + "10:00", + "10:05", + "10:10", + "10:15", + "10:20", + "10:25", + "10:30", + "10:35", + "10:40", + "10:45", + "10:50", + "10:55", + "11:00", + "11:05", + "11:10", + "11:15", + "11:20", + "11:25", + "11:30", + "11:35", + "11:40", + "11:45", + "11:50", + "11:55", + "12:00", + "12:05", + "12:10", + "12:15", + "12:20", + "12:25", + "12:30", + "12:35", + "12:40", + "12:45", + "12:50", + "12:55", + "13:00", + "13:05", + "13:10", + "13:15", + "13:20", + "13:25", + "13:30", + "13:35", + "13:40", + "13:45", + "13:50", + "13:55", + "14:00", + "14:05", + "14:10", + "14:15", + "14:20", + "14:25", + "14:30", + "14:35", + "14:40", + "14:45", + "14:50", + "14:55", 
+ "15:00", + "15:05", + "15:10", + "15:15", + "15:20", + "15:25", + "15:30", + "15:35", + "15:40", + "15:45", + "15:50", + "15:55", + "16:00", + "16:05", + "16:10", + "16:15", + "16:20", + "16:25", + "16:30", + "16:35", + "16:40", + "16:45", + "16:50", + "16:55", + "17:00", + "17:05", + "17:10", + "17:15", + "17:20", + "17:25", + "17:30", + "17:35", + "17:40", + "17:45", + "17:50", + "17:55", + "18:00", + "18:05", + "18:10", + "18:15", + "18:20", + "18:25", + "18:30", + "18:35", + "18:40", + "18:45", + "18:50", + "18:55", + "19:00", + "19:05", + "19:10", + "19:15", + "19:20", + "19:25", + "19:30", + "19:35", + "19:40", + "19:45", + "19:50", + "19:55", + "20:00", + "20:05", + "20:10", + "20:15", + "20:20", + "20:25", + "20:30", + "20:35", + "20:40", + "20:45", + "20:50", + "20:55", + "21:00", + "21:05", + "21:10", + "21:15", + "21:20", + "21:25", + "21:30", + "21:35", + "21:40", + "21:45", + "21:50", + "21:55", + "22:00", + "22:05", + "22:10", + "22:15", + "22:20", + "22:25", + "22:30", + "22:35", + "22:40", + "22:45", + "22:50", + "22:55", + "23:00", + "23:05", + "23:10", + "23:15", + "23:20", + "23:25", + "23:30", + "23:35", + "23:40", + "23:45", + "23:50", + "23:55", + "request" + ], + "restaurant-price range": [ + "none", + "do not care", + "cheap", + "expensive", + "moderate", + "request" + ], + "restaurant-food": [ + "none", + "do not care", + "british food", + "steakhouse", + "turkish", + "sushi", + "north american", + "scottish", + "french", + "austrian", + "korean", + "eastern european", + "swedish", + "gastro pub", + "modern eclectic", + "afternoon tea", + "welsh", + "christmas", + "tuscan", + "gastropub", + "sri lankan", + "molecular gastronomy", + "traditional american", + "italian", + "pizza", + "thai", + "south african", + "creative", + "english", + "asian", + "lebanese", + "hungarian", + "halal", + "portugese", + "modern english", + "african", + "light bites", + "malaysian", + "venetian", + "traditional", + "chinese", + "vegetarian", + "persian", + "thai and chinese", + "scandinavian", + "catalan", + "polynesian", + "crossover", + "canapes", + "cantonese", + "north african", + "seafood", + "brazilian", + "south indian", + "australasian", + "belgian", + "barbeque", + "the americas", + "indonesian", + "singaporean", + "irish", + "middle eastern", + "dojo noodle bar", + "caribbean", + "vietnamese", + "modern european", + "russian", + "german", + "world", + "japanese", + "moroccan", + "modern global", + "indian", + "british", + "american", + "danish", + "panasian", + "swiss", + "basque", + "north indian", + "modern american", + "australian", + "european", + "corsica", + "greek", + "northern european", + "mediterranean", + "portuguese", + "romanian", + "jamaican", + "polish", + "international", + "unusual", + "latin american", + "asian oriental", + "mexican", + "bistro", + "cuban", + "fusion", + "new zealand", + "spanish", + "eritrean", + "afghan", + "kosher", + "request" + ], + "attraction-name": [ + "none", + "do not care", + "downing college", + "fitzwilliam", + "clare college", + "ruskin gallery", + "sidney sussex college", + "great saint mary's church", + "cherry hinton water play park", + "wandlebury country park", + "cafe uno", + "place", + "broughton", + "cineworld cinema", + "jesus college", + "vue cinema", + "history of science museum", + "mumford theatre", + "whale of time", + "fitzbillies", + "christs church", + "churchill college", + "museum of classical archaeology", + "gonville and caius college", + "pizza", + "kirkwood", + "saint catharines college", + "kings 
college", + "parkside", + "by", + "st catharines college", + "saint john's college", + "cherry hinton water park", + "st christs college", + "christ's college", + "bangkok city", + "scudamores punti co", + "free", + "great saint marys church", + "milton country park", + "the fez club", + "soultree", + "autu", + "whipple museum of the history of science", + "aylesbray lodge guest house", + "broughton house gallery", + "peoples portraits exhibition", + "primavera", + "kettles yard", + "all saint's church", + "cinema cinema", + "regency gallery", + "corpus christi", + "corn cambridge exchange", + "da vinci pizzeria", + "school", + "hobsons house", + "cambride and country folk museum", + "north", + "da v", + "cambridge corn exchange", + "soul tree nightclub", + "cambridge arts theater", + "saint catharine's college", + "byard art", + "cambridge punter", + "cambridge university botanic gardens", + "castle galleries", + "museum of archaelogy and anthropogy", + "no specific location", + "cherry hinton hall", + "gallery at 12 a high street", + "parkside pools", + "queen's college", + "little saint mary's church", + "gallery", + "home from home", + "tenpin", + "the wandlebury", + "county folk museum", + "swimming pool", + "christs college", + "cafe jello museum", + "scott polar", + "christ college", + "cambridge museum of technology", + "abbey pool and astroturf pitch", + "king hedges learner pool", + "the cambridge arts theatre", + "the castle galleries", + "cambridge and country folk museum", + "kohinoor", + "scudamores punting co", + "sidney sussex", + "the man on the moon", + "little saint marys church", + "queens", + "the place", + "old school", + "churchill", + "churchills college", + "hughes hall", + "churchhill college", + "riverboat georgina", + "belf", + "cambridge temporary art", + "abc theatre", + "cambridge contemporary art museum", + "man on the moon", + "the junction", + "cherry hinton water play", + "adc theatre", + "gonville hotel", + "magdalene college", + "peoples portraits exhibition at girton college", + "boat", + "centre", + "sheep's green and lammas land park fen causeway", + "the mumford theatre", + "archway house", + "queens' college", + "williams art and antiques", + "funky fun house", + "cherry hinton village centre", + "camboats", + "cambridge", + "old schools", + "kettle's yard", + "whale of a time", + "the churchill college", + "cafe jello gallery", + "aut", + "salsa", + "city", + "clare hall", + "boating", + "pembroke college", + "kings hedges learner pool", + "caffe uno", + "lammas land park", + "museum", + "the fitzwilliam museum", + "the cherry hinton village centre", + "the cambridge corn exchange", + "fitzwilliam museum", + "museum of archaelogy and anthropology", + "fez club", + "the cambridge punter", + "saint johns college", + "emmanuel college", + "cambridge belf", + "scudamore", + "lynne strover gallery", + "king's college", + "whippple museum", + "trinity college", + "college in the north", + "sheep's green", + "kambar", + "museum of archaelogy", + "adc", + "garde", + "club salsa", + "people's portraits exhibition at girton college", + "botanic gardens", + "carol", + "college", + "gallery at twelve a high street", + "abbey pool and astroturf", + "cambridge book and print gallery", + "jesus green outdoor pool", + "scott polar museum", + "saint barnabas press gallery", + "cambridge artworks", + "older churches", + "cambridge contemporary art", + "cherry hinton hall and grounds", + "univ", + "jesus green", + "ballare", + "abbey pool", + "cambridge botanic 
gardens", + "nusha", + "worth house", + "thanh", + "university arms hotel", + "cambridge arts theatre", + "cafe jello", + "cambridge and county folk museum", + "the cambridge artworks", + "all saints church", + "holy trinity church", + "contemporary art museum", + "architectural churches", + "queens college", + "trinity street college" + ], + "restaurant-name": [ + "none", + "do not care", + "hotel du vin and bistro", + "ask", + "gourmet formal kitchen", + "the meze bar", + "lan hong house", + "cow pizza", + "one seven", + "prezzo", + "maharajah tandoori restaurant", + "alex", + "shanghai", + "golden wok", + "restaurant", + "fitzbillies", + "nil", + "copper kettle", + "meghna", + "hk fusion", + "bangkok city", + "hobsons house", + "tang chinese", + "anatolia", + "ugly duckling", + "anatolia and efes restaurant", + "sitar tandoori", + "city stop", + "ashley", + "pizza express fen ditton", + "molecular gastronomy", + "autumn house", + "el shaddia guesthouse", + "the grafton hotel", + "limehouse", + "gardenia", + "not metioned", + "hakka", + "michaelhouse cafe", + "pipasha", + "meze bar", + "archway", + "molecular gastonomy", + "yipee noodle bar", + "the peking", + "curry prince", + "midsummer house restaurant", + "pizza hut cherry hinton", + "the lucky star", + "stazione restaurant and coffee bar", + "shanghi family restaurant", + "good luck", + "j restaurant", + "bedouin", + "cott", + "little seoul", + "south", + "thanh binh", + "el", + "efes restaurant", + "kohinoor", + "clowns", + "india", + "the slug and lettuce", + "shiraz", + "barbakan", + "zizzi cambridge", + "restaurant one seven", + "slug and lettuce", + "travellers rest", + "binh", + "worth house", + "broughton house gallery", + "chiquito", + "the river bar steakhouse and grill", + "ros", + "golden house", + "india west", + "cam", + "panahar", + "restaurant 22", + "adden", + "indian", + "hu", + "jinling noodle bar", + "darrys cookhouse and wine shop", + "hobson house", + "cambridge be", + "el shaddai", + "ac", + "nandos", + "cambridge lodge", + "the cow pizza kitchen and bar", + "charlie", + "rajmahal", + "kymmoy", + "cambri", + "backstreet bistro", + "galleria", + "restaurant 2 two", + "chiquito restaurant bar", + "royal standard", + "lucky star", + "curry king", + "grafton hotel restaurant", + "mahal of cambridge", + "the bedouin", + "nus", + "the kohinoor", + "pizza hut fenditton", + "camboats", + "the gardenia", + "de luca cucina and bar", + "nusha", + "european", + "taj tandoori", + "tandoori palace", + "golden curry", + "efes", + "loch fyne", + "the maharajah tandoor", + "lovel", + "restaurant 17", + "clowns cafe", + "cambridge punter", + "bloomsbury restaurant", + "la mimosa", + "the cambridge chop house", + "funky", + "cotto", + "oak bistro", + "restaurant two two", + "pipasha restaurant", + "river bar steakhouse and grill", + "royal spice", + "the copper kettle", + "graffiti", + "nandos city centre", + "saffron brasserie", + "cambridge chop house", + "sitar", + "kitchen and bar", + "the good luck chinese food takeaway", + "clu", + "la tasca", + "cafe uno", + "cote", + "the varsity restaurant", + "bri", + "eraina", + "bridge", + "fin", + "cambridge lodge restaurant", + "grafton", + "hotpot", + "sala thong", + "margherita", + "wise buddha", + "the missing sock", + "seasame restaurant and bar", + "the dojo noodle bar", + "restaurant alimentum", + "gastropub", + "saigon city", + "la margherita", + "pizza hut", + "curry garden", + "ashley hotel", + "eraina and michaelhouse cafe", + "the golden curry", + "curry queen", + "cow 
pizza kitchen and bar", + "the peking restaurant:", + "hamilton lodge", + "alimentum", + "yippee noodle bar", + "2 two and cote", + "shanghai family restaurant", + "grafton hotel", + "yes", + "ali baba", + "dif", + "fitzbillies restaurant", + "peking restaurant", + "lev", + "nirala", + "the alex", + "tandoori", + "city stop restaurant", + "rice house", + "cityr", + "yu garden", + "meze bar restaurant", + "the", + "don pasquale pizzeria", + "rice boat", + "the hotpot", + "old school", + "the oak bistro", + "sesame restaurant and bar", + "pizza express", + "the gandhi", + "pizza hut fen ditton", + "charlie chan", + "da vinci pizzeria", + "dojo noodle bar", + "gourmet burger kitchen", + "the golden house", + "india house", + "hobso", + "missing sock", + "pizza hut city centre", + "parkside pools", + "riverside brasserie", + "caffe uno", + "primavera", + "the nirala", + "wagamama", + "au", + "ian hong house", + "frankie and bennys", + "4 kings parade city centre", + "shiraz restaurant", + "scudamores punt", + "mahal", + "saint johns chop house", + "de luca cucina and bar riverside brasserie", + "cocum", + "la raza" + ], + "attraction-type": [ + "none", + "do not care", + "architecture", + "boat", + "boating", + "camboats", + "church", + "churchills college", + "cinema", + "college", + "concert", + "concerthall", + "entertainment", + "gallery", + "gastropub", + "hiking", + "hotel", + "multiple sports", + "museum", + "museum kettles yard", + "night club", + "outdoor", + "park", + "pool", + "special", + "sports", + "swimming pool", + "theater", + "theatre", + "concert hall", + "local site", + "nightclub", + "hotspot", + "request" + ], + "taxi-leave at": [ + "none", + "do not care", + "00:00", + "00:05", + "00:10", + "00:15", + "00:20", + "00:25", + "00:30", + "00:35", + "00:40", + "00:45", + "00:50", + "00:55", + "01:00", + "01:05", + "01:10", + "01:15", + "01:20", + "01:25", + "01:30", + "01:35", + "01:40", + "01:45", + "01:50", + "01:55", + "02:00", + "02:05", + "02:10", + "02:15", + "02:20", + "02:25", + "02:30", + "02:35", + "02:40", + "02:45", + "02:50", + "02:55", + "03:00", + "03:05", + "03:10", + "03:15", + "03:20", + "03:25", + "03:30", + "03:35", + "03:40", + "03:45", + "03:50", + "03:55", + "04:00", + "04:05", + "04:10", + "04:15", + "04:20", + "04:25", + "04:30", + "04:35", + "04:40", + "04:45", + "04:50", + "04:55", + "05:00", + "05:05", + "05:10", + "05:15", + "05:20", + "05:25", + "05:30", + "05:35", + "05:40", + "05:45", + "05:50", + "05:55", + "06:00", + "06:05", + "06:10", + "06:15", + "06:20", + "06:25", + "06:30", + "06:35", + "06:40", + "06:45", + "06:50", + "06:55", + "07:00", + "07:05", + "07:10", + "07:15", + "07:20", + "07:25", + "07:30", + "07:35", + "07:40", + "07:45", + "07:50", + "07:55", + "08:00", + "08:05", + "08:10", + "08:15", + "08:20", + "08:25", + "08:30", + "08:35", + "08:40", + "08:45", + "08:50", + "08:55", + "09:00", + "09:05", + "09:10", + "09:15", + "09:20", + "09:25", + "09:30", + "09:35", + "09:40", + "09:45", + "09:50", + "09:55", + "10:00", + "10:05", + "10:10", + "10:15", + "10:20", + "10:25", + "10:30", + "10:35", + "10:40", + "10:45", + "10:50", + "10:55", + "11:00", + "11:05", + "11:10", + "11:15", + "11:20", + "11:25", + "11:30", + "11:35", + "11:40", + "11:45", + "11:50", + "11:55", + "12:00", + "12:05", + "12:10", + "12:15", + "12:20", + "12:25", + "12:30", + "12:35", + "12:40", + "12:45", + "12:50", + "12:55", + "13:00", + "13:05", + "13:10", + "13:15", + "13:20", + "13:25", + "13:30", + "13:35", + "13:40", + "13:45", + "13:50", + "13:55", + 
"14:00", + "14:05", + "14:10", + "14:15", + "14:20", + "14:25", + "14:30", + "14:35", + "14:40", + "14:45", + "14:50", + "14:55", + "15:00", + "15:05", + "15:10", + "15:15", + "15:20", + "15:25", + "15:30", + "15:35", + "15:40", + "15:45", + "15:50", + "15:55", + "16:00", + "16:05", + "16:10", + "16:15", + "16:20", + "16:25", + "16:30", + "16:35", + "16:40", + "16:45", + "16:50", + "16:55", + "17:00", + "17:05", + "17:10", + "17:15", + "17:20", + "17:25", + "17:30", + "17:35", + "17:40", + "17:45", + "17:50", + "17:55", + "18:00", + "18:05", + "18:10", + "18:15", + "18:20", + "18:25", + "18:30", + "18:35", + "18:40", + "18:45", + "18:50", + "18:55", + "19:00", + "19:05", + "19:10", + "19:15", + "19:20", + "19:25", + "19:30", + "19:35", + "19:40", + "19:45", + "19:50", + "19:55", + "20:00", + "20:05", + "20:10", + "20:15", + "20:20", + "20:25", + "20:30", + "20:35", + "20:40", + "20:45", + "20:50", + "20:55", + "21:00", + "21:05", + "21:10", + "21:15", + "21:20", + "21:25", + "21:30", + "21:35", + "21:40", + "21:45", + "21:50", + "21:55", + "22:00", + "22:05", + "22:10", + "22:15", + "22:20", + "22:25", + "22:30", + "22:35", + "22:40", + "22:45", + "22:50", + "22:55", + "23:00", + "23:05", + "23:10", + "23:15", + "23:20", + "23:25", + "23:30", + "23:35", + "23:40", + "23:45", + "23:50", + "23:55", + "request" + ], + "taxi-destination": [ + "none", + "do not care", + "a and b guest house", + "abbey pool and astroturf pitch", + "acorn guest house", + "adc theatre", + "addenbrookes hospital", + "alexander bed and breakfast", + "ali baba", + "all saints church", + "allenbell", + "alpha milton guest house", + "anatolia", + "arbury lodge guesthouse", + "archway house", + "ashley hotel", + "ask", + "attraction", + "autumn house", + "avalon", + "aylesbray lodge guest house", + "backstreet bistro", + "ballare", + "bangkok city", + "bedouin", + "birmingham new street train station", + "bishops stortford train station", + "bloomsbury restaurant", + "bridge guest house", + "broughton house gallery", + "broxbourne train station", + "byard art", + "cafe jello gallery", + "cafe uno", + "camboats", + "cambridge", + "cambridge and county folk museum", + "cambridge arts theatre", + "cambridge artworks", + "cambridge belfry", + "cambridge book and print gallery", + "cambridge chop house", + "cambridge contemporary art", + "cambridge county fair next to the city tourist museum", + "cambridge lodge restaurant", + "cambridge museum of technology", + "cambridge punter", + "cambridge road church of christ", + "cambridge train station", + "cambridge university botanic gardens", + "carolina bed and breakfast", + "castle galleries", + "charlie chan", + "cherry hinton hall and grounds", + "cherry hinton village centre", + "cherry hinton water park", + "cherry hinton water play", + "chiquito restaurant bar", + "christ college", + "churchills college", + "cineworld cinema", + "city centre north bed and breakfast", + "city stop restaurant", + "cityroomz", + "clare college", + "clare hall", + "clowns cafe", + "club salsa", + "cocum", + "copper kettle", + "corpus christi", + "cote", + "cotto", + "cow pizza kitchen and bar", + "curry garden", + "curry king", + "curry prince", + "da vinci pizzeria", + "darrys cookhouse and wine shop", + "de luca cucina and bar", + "dojo noodle bar", + "don pasquale pizzeria", + "downing college", + "efes restaurant", + "el shaddia guesthouse", + "ely train station", + "emmanuel college", + "eraina", + "express by holiday inn cambridge", + "finches bed and breakfast", + "finders corner 
newmarket road", + "fitzbillies restaurant", + "fitzwilliam museum", + "frankie and bennys", + "funky fun house", + "galleria", + "gallery at 12 a high street", + "gastropub", + "golden curry", + "golden house", + "golden wok", + "gonville and caius college", + "gonville hotel", + "good luck", + "gourmet burger kitchen", + "graffiti", + "grafton hotel restaurant", + "great saint marys church", + "hakka", + "hamilton lodge", + "hk fusion", + "hobsons house", + "holy trinity church", + "home from home", + "hotel du vin and bistro", + "hughes hall", + "huntingdon marriott hotel", + "ian hong", + "india house", + "j restaurant", + "jesus college", + "jesus green outdoor pool", + "jinling noodle bar", + "kambar", + "kettles yard", + "kings college", + "kings hedges learner pool", + "kirkwood house", + "kohinoor", + "kymmoy", + "la margherita", + "la mimosa", + "la raza", + "la tasca", + "lan hong house", + "leicester train station", + "lensfield hotel", + "limehouse", + "little saint marys church", + "little seoul", + "loch fyne", + "london kings cross train station", + "london liverpool street train station", + "lovell lodge", + "lynne strover gallery", + "magdalene college", + "mahal of cambridge", + "maharajah tandoori restaurant", + "meghna", + "meze bar", + "michaelhouse cafe", + "midsummer house restaurant", + "milton country park", + "mumford theatre", + "museum of archaelogy and anthropology", + "museum of classical archaeology", + "nandos", + "nandos city centre", + "nil", + "nirala", + "norwich train station", + "nusha", + "old schools", + "panahar", + "parkside police station", + "parkside pools", + "peking restaurant", + "pembroke college", + "peoples portraits exhibition at girton college", + "peterborough train station", + "pipasha restaurant", + "pizza express", + "pizza hut cherry hinton", + "pizza hut city centre", + "pizza hut fenditton", + "prezzo", + "primavera", + "queens college", + "rajmahal", + "regency gallery", + "restaurant 17", + "restaurant 2 two", + "restaurant alimentum", + "rice boat", + "rice house", + "riverboat georgina", + "riverside brasserie", + "rosas bed and breakfast", + "royal spice", + "royal standard", + "ruskin gallery", + "saffron brasserie", + "saigon city", + "saint barnabas", + "saint barnabas press gallery", + "saint catharines college", + "saint johns chop house", + "saint johns college", + "sala thong", + "scott polar museum", + "scudamores punting co", + "sesame restaurant and bar", + "shanghai family restaurant", + "sheeps green and lammas land park fen causeway", + "shiraz", + "sidney sussex college", + "sitar tandoori", + "sleeperz hotel", + "soul tree nightclub", + "st johns chop house", + "stansted airport train station", + "station road", + "stazione restaurant and coffee bar", + "stevenage train station", + "taj tandoori", + "tall monument", + "tandoori palace", + "tang chinese", + "tenpin", + "thanh binh", + "the anatolia", + "the cambridge corn exchange", + "the cambridge shop", + "the fez club", + "the gandhi", + "the gardenia", + "the hotpot", + "the junction", + "the lucky star", + "the man on the moon", + "the missing sock", + "the oak bistro", + "the place", + "the regent street city center", + "the river bar steakhouse and grill", + "the slug and lettuce", + "the varsity restaurant", + "travellers rest", + "trinity college", + "ugly duckling", + "university arms hotel", + "vue cinema", + "wagamama", + "wandlebury country park", + "wankworth hotel", + "warkworth house", + "whale of a time", + "whipple museum of the history of 
science", + "williams art and antiques", + "worth house", + "yippee noodle bar", + "yu garden", + "zizzi cambridge", + "leverton house", + "the cambridge chop house", + "saint john's college", + "churchill college", + "the nirala", + "the cow pizza kitchen and bar", + "christ's college", + "el shaddai", + "saint catharine's college", + "camb", + "the golden curry", + "little saint mary's church", + "country folk museum", + "meze bar restaurant", + "the cambridge belfry", + "the fitzwilliam museum", + "the lensfield hotel", + "pizza express fen ditton", + "the cambridge punter", + "king's college", + "the cherry hinton village centre", + "shiraz restaurant", + "sheep's green and lammas land park fen causeway", + "caffe uno", + "the ghandi", + "the copper kettle", + "man on the moon concert hall", + "alpha-milton guest house", + "queen's college", + "restaurant one seven", + "restaurant two two", + "city centre north b and b", + "rosa's bed and breakfast", + "the good luck chinese food takeaway", + "not museum of archaeology and anthropologymentioned", + "tandori in cambridge", + "kettle's yard", + "megna", + "grou", + "gallery at twelve a high street", + "maharajah tandoori restaurant", + "pizza hut fen ditton", + "gandhi", + "tranh binh", + "kambur", + "people's portraits exhibition at girton college", + "hotel", + "restaurant", + "the galleria", + "queens' college", + "great saint mary's church", + "theathre", + "cambridge artworks", + "acorn house", + "shiraz", + "riverboat georginawd", + "mic", + "the gallery at twelve", + "the soul tree", + "finches" + ], + "taxi-departure": [ + "none", + "do not care", + "172 chestertown road", + "4455 woodbridge road", + "a and b guest house", + "abbey pool and astroturf pitch", + "acorn guest house", + "adc theatre", + "addenbrookes hospital", + "alexander bed and breakfast", + "ali baba", + "all saints church", + "allenbell", + "alpha milton guest house", + "alyesbray lodge hotel", + "ambridge", + "anatolia", + "arbury lodge guesthouse", + "archway house", + "ashley hotel", + "ask", + "autumn house", + "avalon", + "aylesbray lodge guest house", + "backstreet bistro", + "ballare", + "bangkok city", + "bedouin", + "birmingham new street train station", + "bishops stortford train station", + "bloomsbury restaurant", + "bridge guest house", + "broughton house gallery", + "broxbourne train station", + "byard art", + "cafe jello gallery", + "cafe uno", + "caffee uno", + "camboats", + "cambridge", + "cambridge and county folk museum", + "cambridge arts theatre", + "cambridge artworks", + "cambridge belfry", + "cambridge book and print gallery", + "cambridge chop house", + "cambridge contemporary art", + "cambridge lodge restaurant", + "cambridge museum of technology", + "cambridge punter", + "cambridge towninfo centre", + "cambridge train station", + "cambridge university botanic gardens", + "carolina bed and breakfast", + "castle galleries", + "centre of town at my hotel", + "charlie chan", + "cherry hinton hall and grounds", + "cherry hinton village center", + "cherry hinton village centre", + "cherry hinton water play", + "chiquito restaurant bar", + "christ college", + "churchills college", + "cineworld cinema", + "citiroomz", + "city centre north bed and breakfast", + "city stop restaurant", + "cityroomz", + "clair hall", + "clare college", + "clare hall", + "clowns cafe", + "club salsa", + "cocum", + "copper kettle", + "corpus christi", + "cote", + "cotto", + "cow pizza kitchen and bar", + "curry garden", + "curry king", + "curry prince", + "curry 
queen", + "da vinci pizzeria", + "darrys cookhouse and wine shop", + "de luca cucina and bar", + "dojo noodle bar", + "don pasquale pizzeria", + "downing college", + "downing street", + "el shaddia guesthouse", + "ely", + "ely train station", + "emmanuel college", + "eraina", + "express by holiday inn cambridge", + "finches bed and breakfast", + "fitzbillies restaurant", + "fitzwilliam museum", + "frankie and bennys", + "funky fun house", + "galleria", + "gallery at 12 a high street", + "girton college", + "golden curry", + "golden house", + "golden wok", + "gonville and caius college", + "gonville hotel", + "good luck", + "gourmet burger kitchen", + "graffiti", + "grafton hotel restaurant", + "great saint marys church", + "hakka", + "hamilton lodge", + "hobsons house", + "holy trinity church", + "home", + "home from home", + "hotel", + "hotel du vin and bistro", + "hughes hall", + "huntingdon marriott hotel", + "india house", + "j restaurant", + "jesus college", + "jesus green outdoor pool", + "jinling noodle bar", + "junction theatre", + "kambar", + "kettles yard", + "kings college", + "kings hedges learner pool", + "kings lynn train station", + "kirkwood house", + "kohinoor", + "kymmoy", + "la margherita", + "la mimosa", + "la raza", + "la tasca", + "lan hong house", + "lensfield hotel", + "leverton house", + "limehouse", + "little saint marys church", + "little seoul", + "loch fyne", + "london kings cross train station", + "london liverpool street", + "london liverpool street train station", + "lovell lodge", + "lynne strover gallery", + "magdalene college", + "mahal of cambridge", + "maharajah tandoori restaurant", + "meghna", + "meze bar", + "michaelhouse cafe", + "milton country park", + "mumford theatre", + "museum", + "museum of archaelogy and anthropology", + "museum of classical archaeology", + "nandos", + "nandos city centre", + "new england", + "nirala", + "norwich train station", + "nstaot mentioned", + "nusha", + "old schools", + "panahar", + "parkside police station", + "parkside pools", + "peking restaurant", + "pembroke college", + "peoples portraits exhibition at girton college", + "peterborough train station", + "pizza express", + "pizza hut cherry hinton", + "pizza hut city centre", + "pizza hut fenditton", + "prezzo", + "primavera", + "queens college", + "rajmahal", + "regency gallery", + "restaurant 17", + "restaurant 2 two", + "restaurant alimentum", + "rice boat", + "rice house", + "riverboat georgina", + "riverside brasserie", + "rosas bed and breakfast", + "royal spice", + "royal standard", + "ruskin gallery", + "saffron brasserie", + "saigon city", + "saint barnabas press gallery", + "saint catharines college", + "saint johns chop house", + "saint johns college", + "sala thong", + "scott polar museum", + "scudamores punting co", + "sesame restaurant and bar", + "sheeps green and lammas land park", + "sheeps green and lammas land park fen causeway", + "shiraz", + "sidney sussex college", + "sitar tandoori", + "soul tree nightclub", + "st johns college", + "stazione restaurant and coffee bar", + "stevenage train station", + "taj tandoori", + "tandoori palace", + "tang chinese", + "tenpin", + "thanh binh", + "the cambridge corn exchange", + "the fez club", + "the gallery at 12", + "the gandhi", + "the gardenia", + "the hotpot", + "the junction", + "the lucky star", + "the man on the moon", + "the missing sock", + "the oak bistro", + "the place", + "the river bar steakhouse and grill", + "the slug and lettuce", + "the varsity restaurant", + "travellers rest", + 
"trinity college", + "ugly duckling", + "university arms hotel", + "vue cinema", + "wagamama", + "wandlebury country park", + "warkworth house", + "whale of a time", + "whipple museum of the history of science", + "williams art and antiques", + "worth house", + "yippee noodle bar", + "yu garden", + "zizzi cambridge", + "christ's college", + "city centre north b and b", + "the lensfield hotel", + "alpha-milton guest house", + "el shaddai", + "churchill college", + "the cambridge belfry", + "king's college", + "great saint mary's church", + "restaurant two two", + "queens' college", + "little saint mary's church", + "chinese city centre", + "kettle's yard", + "pizza hut", + "the golden curry", + "rosa's bed and breakfast", + "the cambridge punter", + "the byard art museum", + "saint catharine's college", + "meze bar restaurant", + "the good luck chinese food takeaway", + "restaurant one seven", + "pizza hut fen ditton", + "the nirala", + "the fitzwilliam museum", + "st. john's college", + "gallery at twelve a high street", + "sheep's green and lammas land park fen causeway", + "the cherry hinton village centre", + "pizza express fen ditton", + "corpus cristi", + "cas", + "acorn house", + "lens", + "the cambridge chop house", + "the copper kettle", + "the avalon", + "saint john's college", + "aylesbray lodge", + "the alexander bed and breakfast", + "cambridge belfy", + "people's portraits exhibition at girton college", + "gonville", + "caffe uno", + "the cow pizza kitchen and bar", + "lovell ldoge", + "cinema", + "shiraz restaurant", + "park", + "the allenbell" + ], + "restaurant-book day": [ + "none", + "do not care", + "friday", + "monday", + "saterday", + "sunday", + "thursday", + "tuesday", + "wednesday" + ], + "restaurant-book people": [ + "none", + "do not care", + "1", + "10 or more", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9" + ], + "restaurant-book time": [ + "none", + "do not care", + "00:00", + "00:05", + "00:10", + "00:15", + "00:20", + "00:25", + "00:30", + "00:35", + "00:40", + "00:45", + "00:50", + "00:55", + "01:00", + "01:05", + "01:10", + "01:15", + "01:20", + "01:25", + "01:30", + "01:35", + "01:40", + "01:45", + "01:50", + "01:55", + "02:00", + "02:05", + "02:10", + "02:15", + "02:20", + "02:25", + "02:30", + "02:35", + "02:40", + "02:45", + "02:50", + "02:55", + "03:00", + "03:05", + "03:10", + "03:15", + "03:20", + "03:25", + "03:30", + "03:35", + "03:40", + "03:45", + "03:50", + "03:55", + "04:00", + "04:05", + "04:10", + "04:15", + "04:20", + "04:25", + "04:30", + "04:35", + "04:40", + "04:45", + "04:50", + "04:55", + "05:00", + "05:05", + "05:10", + "05:15", + "05:20", + "05:25", + "05:30", + "05:35", + "05:40", + "05:45", + "05:50", + "05:55", + "06:00", + "06:05", + "06:10", + "06:15", + "06:20", + "06:25", + "06:30", + "06:35", + "06:40", + "06:45", + "06:50", + "06:55", + "07:00", + "07:05", + "07:10", + "07:15", + "07:20", + "07:25", + "07:30", + "07:35", + "07:40", + "07:45", + "07:50", + "07:55", + "08:00", + "08:05", + "08:10", + "08:15", + "08:20", + "08:25", + "08:30", + "08:35", + "08:40", + "08:45", + "08:50", + "08:55", + "09:00", + "09:05", + "09:10", + "09:15", + "09:20", + "09:25", + "09:30", + "09:35", + "09:40", + "09:45", + "09:50", + "09:55", + "10:00", + "10:05", + "10:10", + "10:15", + "10:20", + "10:25", + "10:30", + "10:35", + "10:40", + "10:45", + "10:50", + "10:55", + "11:00", + "11:05", + "11:10", + "11:15", + "11:20", + "11:25", + "11:30", + "11:35", + "11:40", + "11:45", + "11:50", + "11:55", + "12:00", + "12:05", + "12:10", + 
"12:15", + "12:20", + "12:25", + "12:30", + "12:35", + "12:40", + "12:45", + "12:50", + "12:55", + "13:00", + "13:05", + "13:10", + "13:15", + "13:20", + "13:25", + "13:30", + "13:35", + "13:40", + "13:45", + "13:50", + "13:55", + "14:00", + "14:05", + "14:10", + "14:15", + "14:20", + "14:25", + "14:30", + "14:35", + "14:40", + "14:45", + "14:50", + "14:55", + "15:00", + "15:05", + "15:10", + "15:15", + "15:20", + "15:25", + "15:30", + "15:35", + "15:40", + "15:45", + "15:50", + "15:55", + "16:00", + "16:05", + "16:10", + "16:15", + "16:20", + "16:25", + "16:30", + "16:35", + "16:40", + "16:45", + "16:50", + "16:55", + "17:00", + "17:05", + "17:10", + "17:15", + "17:20", + "17:25", + "17:30", + "17:35", + "17:40", + "17:45", + "17:50", + "17:55", + "18:00", + "18:05", + "18:10", + "18:15", + "18:20", + "18:25", + "18:30", + "18:35", + "18:40", + "18:45", + "18:50", + "18:55", + "19:00", + "19:05", + "19:10", + "19:15", + "19:20", + "19:25", + "19:30", + "19:35", + "19:40", + "19:45", + "19:50", + "19:55", + "20:00", + "20:05", + "20:10", + "20:15", + "20:20", + "20:25", + "20:30", + "20:35", + "20:40", + "20:45", + "20:50", + "20:55", + "21:00", + "21:05", + "21:10", + "21:15", + "21:20", + "21:25", + "21:30", + "21:35", + "21:40", + "21:45", + "21:50", + "21:55", + "22:00", + "22:05", + "22:10", + "22:15", + "22:20", + "22:25", + "22:30", + "22:35", + "22:40", + "22:45", + "22:50", + "22:55", + "23:00", + "23:05", + "23:10", + "23:15", + "23:20", + "23:25", + "23:30", + "23:35", + "23:40", + "23:45", + "23:50", + "23:55" + ], + "taxi-arrive by": [ + "none", + "do not care", + "00:00", + "00:05", + "00:10", + "00:15", + "00:20", + "00:25", + "00:30", + "00:35", + "00:40", + "00:45", + "00:50", + "00:55", + "01:00", + "01:05", + "01:10", + "01:15", + "01:20", + "01:25", + "01:30", + "01:35", + "01:40", + "01:45", + "01:50", + "01:55", + "02:00", + "02:05", + "02:10", + "02:15", + "02:20", + "02:25", + "02:30", + "02:35", + "02:40", + "02:45", + "02:50", + "02:55", + "03:00", + "03:05", + "03:10", + "03:15", + "03:20", + "03:25", + "03:30", + "03:35", + "03:40", + "03:45", + "03:50", + "03:55", + "04:00", + "04:05", + "04:10", + "04:15", + "04:20", + "04:25", + "04:30", + "04:35", + "04:40", + "04:45", + "04:50", + "04:55", + "05:00", + "05:05", + "05:10", + "05:15", + "05:20", + "05:25", + "05:30", + "05:35", + "05:40", + "05:45", + "05:50", + "05:55", + "06:00", + "06:05", + "06:10", + "06:15", + "06:20", + "06:25", + "06:30", + "06:35", + "06:40", + "06:45", + "06:50", + "06:55", + "07:00", + "07:05", + "07:10", + "07:15", + "07:20", + "07:25", + "07:30", + "07:35", + "07:40", + "07:45", + "07:50", + "07:55", + "08:00", + "08:05", + "08:10", + "08:15", + "08:20", + "08:25", + "08:30", + "08:35", + "08:40", + "08:45", + "08:50", + "08:55", + "09:00", + "09:05", + "09:10", + "09:15", + "09:20", + "09:25", + "09:30", + "09:35", + "09:40", + "09:45", + "09:50", + "09:55", + "10:00", + "10:05", + "10:10", + "10:15", + "10:20", + "10:25", + "10:30", + "10:35", + "10:40", + "10:45", + "10:50", + "10:55", + "11:00", + "11:05", + "11:10", + "11:15", + "11:20", + "11:25", + "11:30", + "11:35", + "11:40", + "11:45", + "11:50", + "11:55", + "12:00", + "12:05", + "12:10", + "12:15", + "12:20", + "12:25", + "12:30", + "12:35", + "12:40", + "12:45", + "12:50", + "12:55", + "13:00", + "13:05", + "13:10", + "13:15", + "13:20", + "13:25", + "13:30", + "13:35", + "13:40", + "13:45", + "13:50", + "13:55", + "14:00", + "14:05", + "14:10", + "14:15", + "14:20", + "14:25", + "14:30", + "14:35", + "14:40", + 
"14:45", + "14:50", + "14:55", + "15:00", + "15:05", + "15:10", + "15:15", + "15:20", + "15:25", + "15:30", + "15:35", + "15:40", + "15:45", + "15:50", + "15:55", + "16:00", + "16:05", + "16:10", + "16:15", + "16:20", + "16:25", + "16:30", + "16:35", + "16:40", + "16:45", + "16:50", + "16:55", + "17:00", + "17:05", + "17:10", + "17:15", + "17:20", + "17:25", + "17:30", + "17:35", + "17:40", + "17:45", + "17:50", + "17:55", + "18:00", + "18:05", + "18:10", + "18:15", + "18:20", + "18:25", + "18:30", + "18:35", + "18:40", + "18:45", + "18:50", + "18:55", + "19:00", + "19:05", + "19:10", + "19:15", + "19:20", + "19:25", + "19:30", + "19:35", + "19:40", + "19:45", + "19:50", + "19:55", + "20:00", + "20:05", + "20:10", + "20:15", + "20:20", + "20:25", + "20:30", + "20:35", + "20:40", + "20:45", + "20:50", + "20:55", + "21:00", + "21:05", + "21:10", + "21:15", + "21:20", + "21:25", + "21:30", + "21:35", + "21:40", + "21:45", + "21:50", + "21:55", + "22:00", + "22:05", + "22:10", + "22:15", + "22:20", + "22:25", + "22:30", + "22:35", + "22:40", + "22:45", + "22:50", + "22:55", + "23:00", + "23:05", + "23:10", + "23:15", + "23:20", + "23:25", + "23:30", + "23:35", + "23:40", + "23:45", + "23:50", + "23:55", + "request" + ], + "restaurant-area": [ + "none", + "do not care", + "centre", + "east", + "north", + "south", + "west", + "request" + ], + "hotel-area": [ + "none", + "do not care", + "centre", + "east", + "north", + "south", + "west", + "request" + ], + "attraction-area": [ + "none", + "do not care", + "centre", + "east", + "north", + "south", + "west", + "request" + ], + "hospital-department": [ + "none", + "do not care", + "acute medical assessment unit", + "acute medicine for the elderly", + "antenatal", + "cambridge eye unit", + "cardiology", + "cardiology and coronary care unit", + "childrens oncology and haematology", + "childrens surgical and medicine", + "clinical decisions unit", + "clinical research facility", + "coronary care unit", + "diabetes and endocrinology", + "emergency department", + "gastroenterology", + "gynaecology", + "haematology", + "haematology and haematological oncology", + "haematology day unit", + "hepatobillary and gastrointestinal surgery regional referral centre", + "hepatology", + "infectious diseases", + "infusion services", + "inpatient occupational therapy", + "intermediate dependancy area", + "john farman intensive care unit", + "medical decisions unit", + "medicine for the elderly", + "neonatal unit", + "neurology", + "neurology neurosurgery", + "neurosciences", + "neurosciences critical care unit", + "oncology", + "oral and maxillofacial surgery and ent", + "paediatric clinic", + "paediatric day unit", + "paediatric intensive care unit", + "plastic and vascular surgery plastics", + "psychiatry", + "respiratory medicine", + "surgery", + "teenage cancer trust unit", + "transitional care", + "transplant high dependency unit", + "trauma and orthopaedics", + "trauma high dependency unit", + "urology" + ], + "police-postcode": [ + "request" + ], + "restaurant-postcode": [ + "request" + ], + "train-duration": [ + "request" + ], + "train-trainid": [ + "request" + ], + "hospital-address": [ + "request" + ], + "restaurant-phone": [ + "request" + ], + "hotel-phone": [ + "request" + ], + "restaurant-address": [ + "request" + ], + "hotel-postcode": [ + "request" + ], + "attraction-phone": [ + "request" + ], + "attraction-entrance fee": [ + "request" + ], + "hotel-reference": [ + "request" + ], + "taxi-taxi types": [ + "request" + ], + "attraction-address": [ + 
"request" + ], + "hospital-phone": [ + "request" + ], + "attraction-postcode": [ + "request" + ], + "police-address": [ + "request" + ], + "taxi-taxi phone": [ + "request" + ], + "train-price": [ + "request" + ], + "hospital-postcode": [ + "request" + ], + "police-phone": [ + "request" + ], + "hotel-address": [ + "request" + ], + "restaurant-reference": [ + "request" + ], + "train-reference": [ + "request" + ] +} \ No newline at end of file diff --git a/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json new file mode 100644 index 0000000000000000000000000000000000000000..87e315363ad3dfd8cadd5bd10cfd7e5047450160 --- /dev/null +++ b/convlab/dst/setsumbt/multiwoz/dataset/mwoz21_slot_descriptions.json @@ -0,0 +1,57 @@ +{ + "hotel-price range": "preferred cost or price of the hotel", + "hotel-type": "what is the type of the hotel", + "hotel-parking": "does the hotel have parking", + "hotel-book stay": "number of nights for the hotel reservation", + "hotel-book day": "starting day of the hotel booking", + "hotel-book people": "number of people for the hotel booking", + "hotel-area": "area or place of the hotel", + "hotel-stars": "star rating of the hotel", + "hotel-internet": "does the hotel have internet or wifi", + "hotel-name": "name of the hotel", + "hotel-phone": "phone number of the hotel", + "hotel-postcode": "postcode of the hotel", + "hotel-reference": "booking reference of the hotel booking", + "hotel-address": "street address of the hotel", + "train-destination": "train station you want to travel to", + "train-day": "day of the train booking", + "train-departure": "train station you want to leave from", + "train-arrive by": "arrival time of the train", + "train-book people": "number of people for the train booking", + "train-leave at": "departure time for the train", + "train-duration": "duration of the train journey", + "train-trainid": "train identifier or number", + "train-price": "how much does the train trip cost", + "train-reference": "booking reference of the train booking", + "attraction-type": "type of attraction or point of interest", + "attraction-area": "area or place of the attraction", + "attraction-name": "name of the attraction", + "attraction-phone": "phone number of the attraction", + "attraction-entrance fee": "entrace fee at the attraction", + "attraction-address": "street address of the attraction", + "attraction-postcode": "postcode of the attraction", + "restaurant-book people": "number of people for the restaurant booking", + "restaurant-book day": "weekday for the restaurant booking", + "restaurant-book time": "time of the restaurant booking", + "restaurant-food": "type of food served at the restaurant", + "restaurant-price range": "preferred cost or price of the restaurant", + "restaurant-name": "name of the restaurant", + "restaurant-area": "area or place of the restaurant", + "restaurant-postcode": "postcode of the restaurant", + "restaurant-phone": "phone number of the restaurant", + "restaurant-address": "street address of the restaurant", + "restaurant-reference": "booking reference of the hotel booking", + "taxi-leave at": "what time you want the taxi to leave by", + "taxi-destination": "where you want the taxi to drop you off", + "taxi-departure": "where you want the taxi to pick you up", + "taxi-arrive by": "what time you to arrive at your destination", + "taxi-taxi types": "vehicle type of the taxi", + "taxi-taxi phone": "phone number of the taxi", + 
"hospital-department": "name of hospital department", + "hospital-address": "street address of the hospital", + "hospital-phone": "phone number of the hospital", + "hospital-postcode": "postcode of the hospital", + "police-postcode": "postcode of the police station", + "police-address": "street address of the police station", + "police-phone": "phone number of the police station" +} \ No newline at end of file diff --git a/convlab/dst/setsumbt/multiwoz/dataset/ontology.py b/convlab/dst/setsumbt/multiwoz/dataset/ontology.py new file mode 100644 index 0000000000000000000000000000000000000000..c6b9c3365764eb8f1c1cab5d68674dd85b39ce2a --- /dev/null +++ b/convlab/dst/setsumbt/multiwoz/dataset/ontology.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf +# Authors: Carel van Niekerk (niekerk@hhu.de) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Create Ontology Embeddings""" + +import json +import os +import random + +import torch +import numpy as np + + +# Slot mapping table for description extractions +# SLOT_NAME_MAPPINGS = { +# 'arrive at': 'arriveAt', +# 'arrive by': 'arriveBy', +# 'leave at': 'leaveAt', +# 'leave by': 'leaveBy', +# 'arriveby': 'arriveBy', +# 'arriveat': 'arriveAt', +# 'leaveat': 'leaveAt', +# 'leaveby': 'leaveBy', +# 'price range': 'pricerange' +# } + +# Set up global data directory +def set_datadir(dir): + global DATA_DIR + DATA_DIR = dir + + +# Set seeds +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +# Get embeddings for slots and candidates +def get_slot_candidate_embeddings(set_type, args, tokenizer, embedding_model, save_to_file=True): + # Get set alots and candidates + reader = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r') + ontology = json.load(reader) + reader.close() + + reader = open(os.path.join(DATA_DIR, 'slot_descriptions.json'), 'r') + slot_descriptions = json.load(reader) + reader.close() + + embedding_model.eval() + + slots = dict() + for slot in ontology: + if args.use_descriptions: + # d, s = slot.split('-', 1) + # s = SLOT_NAME_MAPPINGS[s] if s in SLOT_NAME_MAPPINGS else s + # s = d + '-' + s + # if slot in slot_descriptions: + desc = slot_descriptions[slot] + # elif slot.lower() in slot_descriptions: + # desc = slot_descriptions[s.lower()] + # else: + # desc = slot.replace('-', ' ') + else: + desc = slot + + # Tokenize slot and get embeddings + feats = tokenizer.encode_plus(desc, add_special_tokens = True, + max_length = args.max_slot_len, padding='max_length', + truncation = 'longest_first') + + with torch.no_grad(): + input_ids = torch.tensor([feats['input_ids']]).to(embedding_model.device) # [1, max_slot_len] + if 'token_type_ids' in feats: + token_type_ids = torch.tensor([feats['token_type_ids']]).to(embedding_model.device) # [1, max_slot_len] + if 'attention_mask' in feats: + attention_mask = 
torch.tensor([feats['attention_mask']]).to(embedding_model.device) # [1, max_slot_len] + feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask) + attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1))) + feats = feats * attention_mask # [1, max_slot_len, hidden_dim] + else: + feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids) + else: + if 'attention_mask' in feats: + attention_mask = torch.tensor([feats['attention_mask']]).to(embedding_model.device) + feats, pooled_feats = embedding_model(input_ids=input_ids, attention_mask=attention_mask) + attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1))) + feats = feats * attention_mask # [1, max_slot_len, hidden_dim] + else: + feats, pooled_feats = embedding_model(input_ids=input_ids) # [1, max_slot_len, hidden_dim] + + if args.set_similarity: + slot_emb = feats[0, :, :].detach().cpu() # [seq_len, hidden_dim] + else: + if args.candidate_pooling == 'cls' and pooled_feats is not None: + slot_emb = pooled_feats[0, :].detach().cpu() # [hidden_dim] + elif args.candidate_pooling == 'mean': + feats = feats.sum(1) + feats = torch.nn.functional.layer_norm(feats, feats.size()) + slot_emb = feats[0, :].detach().cpu() # [hidden_dim] + + # Tokenize value candidates and get embeddings + values = ontology[slot] + is_requestable = False + if 'request' in values: + is_requestable = True + values.remove('request') + if values: + feats = [tokenizer.encode_plus(val, add_special_tokens = True, + max_length = args.max_candidate_len, padding='max_length', + truncation = 'longest_first') + for val in values] + with torch.no_grad(): + input_ids = torch.tensor([f['input_ids'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len] + if 'token_type_ids' in feats[0]: + token_type_ids = torch.tensor([f['token_type_ids'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len] + if 'attention_mask' in feats[0]: + attention_mask = torch.tensor([f['attention_mask'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len] + feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask) + attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1))) + feats = feats * attention_mask # [num_candidates, max_candidate_len, hidden_dim] + else: + feats, pooled_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids) # [num_candidates, max_candidate_len, hidden_dim] + else: + if 'attention_mask' in feats[0]: + attention_mask = torch.tensor([f['attention_mask'] for f in feats]).to(embedding_model.device) + feats, pooled_feats = embedding_model(input_ids=input_ids, attention_mask=attention_mask) + attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, feats.size(-1))) + feats = feats * attention_mask # [num_candidates, max_candidate_len, hidden_dim] + else: + feats, pooled_feats = embedding_model(input_ids=input_ids) # [num_candidates, max_candidate_len, hidden_dim] + + if args.set_similarity: + feats = feats.detach().cpu() # [num_candidates, max_candidate_len, hidden_dim] + else: + if args.candidate_pooling == 'cls' and pooled_feats is not None: + feats = pooled_feats.detach().cpu() + elif args.candidate_pooling == "mean": + feats = feats.sum(1) + feats = torch.nn.functional.layer_norm(feats, feats.size()) + feats = feats.detach().cpu() + else: + feats = None + 
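+ # Each ontology entry stores (slot description embedding, value-candidate embeddings, requestable flag); the candidate embeddings are None for request-only slots with no informable values.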
slots[slot] = (slot_emb, feats, is_requestable) + + # Dump tensors for use in training + if save_to_file: + writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type) + torch.save(slots, writer) + + return slots diff --git a/convlab/dst/setsumbt/multiwoz/dataset/utils.py b/convlab/dst/setsumbt/multiwoz/dataset/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..485dee643148ad1d37069064ebc4e2e4553b3dac --- /dev/null +++ b/convlab/dst/setsumbt/multiwoz/dataset/utils.py @@ -0,0 +1,446 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf +# Code adapted from the TRADE preprocessing code (https://github.com/jasonwu0731/trade-dst) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MultiWOZ2.1/3 data processing utilities""" + +import re +import os + +from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA +from convlab.dst.rule.multiwoz import normalize_value + +# ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train'] +ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train', 'hospital', 'police'] +def set_util_domains(domains): + global ACTIVE_DOMAINS + ACTIVE_DOMAINS = [d for d in domains if d in ACTIVE_DOMAINS] + +MAPPING_PATH = os.path.abspath(__file__).replace('utils.py', 'mapping.pair') +# Read replacement pairs from the mapping.pair file +REPLACEMENTS = [] +for line in open(MAPPING_PATH).readlines(): + tok_from, tok_to = line.replace('\n', '').split('\t') + REPLACEMENTS.append((' ' + tok_from + ' ', ' ' + tok_to + ' ')) + +# Extract belief state from mturk annotations +def build_dialoguestate(metadata, get_domains=False): + domains_list = [dom for dom in ACTIVE_DOMAINS if dom in metadata] + dialogue_state, domains = [], [] + for domain in domains_list: + active = False + # Extract booking information + booking = [] + for slot in sorted(metadata[domain]['book'].keys()): + if slot != 'booked': + if metadata[domain]['book'][slot] == 'not mentioned': + continue + if metadata[domain]['book'][slot] != '': + val = ['%s-book %s' % (domain, slot.strip().lower()), clean_text(metadata[domain]['book'][slot])] + dialogue_state.append(val) + active = True + + for slot in metadata[domain]['semi']: + if metadata[domain]['semi'][slot] == 'not mentioned': + continue + elif metadata[domain]['semi'][slot] in ['dont care', 'dontcare', "don't care", 'don not care', + 'do not care', 'does not care']: + dialogue_state.append(['%s-%s' % (domain, slot.strip().lower()), 'do not care']) + active = True + elif metadata[domain]['semi'][slot]: + dialogue_state.append(['%s-%s' % (domain, slot.strip().lower()), clean_text(metadata[domain]['semi'][slot])]) + active = True + + if active: + domains.append(domain) + + if get_domains: + return domains + return clean_dialoguestate(dialogue_state) + + +PRICERANGE = ['do not care', 'cheap', 'moderate', 'expensive'] +BOOLEAN = ['do not care', 'yes', 'no'] +DAYS = ['do not care', 'monday', 'tuesday', 'wednesday', 'thursday', + 'friday', 'saterday', 'sunday'] 
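As an aside to the candidate lists above: clean_dialoguestate (defined below in this file) snaps free-form annotation values onto these closed lists by an exact match followed by a first-substring match. A minimal sketch of that idea, not part of the patch; snap_to_candidates is a hypothetical name:

# Sketch only: candidate snapping as used by clean_dialoguestate below.
PRICERANGE = ['do not care', 'cheap', 'moderate', 'expensive']  # as defined in this file

def snap_to_candidates(value, candidates):
    if value in candidates:                            # exact match
        return value
    matches = [v for v in candidates if v in value]    # first substring match
    return matches[0] if matches else None             # None -> the value is dropped

assert snap_to_candidates('cheap', PRICERANGE) == 'cheap'
assert snap_to_candidates('moderately priced', PRICERANGE) == 'moderate'
assert snap_to_candidates('fancy', PRICERANGE) is None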
+QUANTITIES = ['do not care', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10 or more']
+TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)]
+TIME = ['do not care'] + ['%02i:%02i' % t for l in TIME for t in l]
+
+VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast',
+             'cityroomz': 'city roomz', ' ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott',
+             'worth house': 'the worth house', 'alesbray lodge guest house': 'aylesbray lodge',
+             'huntingdon hotel': 'huntingdon marriott hotel', 'huntingd': 'huntingdon marriott hotel',
+             'jamaicanchinese': 'chinese', 'barbequemodern european': 'modern european',
+             'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"}
+
+def map_values(value):
+    for old, new in VALUE_MAP.items():
+        value = value.replace(old, new)
+    return value
+
+def clean_dialoguestate(states, is_acts=False):
+    # path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))))
+    # path = os.path.join(path, 'data/multiwoz/value_dict.json')
+    # value_dict = json.load(open(path))
+    clean_state = []
+    for slot, value in states:
+        if 'pricerange' in slot:
+            d, s = slot.split('-', 1)
+            s = 'price range'
+            slot = f'{d}-{s}'
+            if value in PRICERANGE:
+                clean_state.append([slot, value])
+            elif True in [v in value for v in PRICERANGE]:
+                value = [v for v in PRICERANGE if v in value][0]
+                clean_state.append([slot, value])
+            elif value == '?' and is_acts:
+                clean_state.append([slot, value])
+            else:
+                continue
+        elif 'parking' in slot or 'internet' in slot:
+            if value in BOOLEAN:
+                clean_state.append([slot, value])
+            elif value == 'free':
+                value = 'yes'
+                clean_state.append([slot, value])
+            elif True in [v in value for v in BOOLEAN]:
+                value = [v for v in BOOLEAN if v in value][0]
+                clean_state.append([slot, value])
+            elif value == '?' and is_acts:
+                clean_state.append([slot, value])
+            else:
+                continue
+        elif 'day' in slot:
+            if value in DAYS:
+                clean_state.append([slot, value])
+            elif True in [v in value for v in DAYS]:
+                value = [v for v in DAYS if v in value][0]
+                clean_state.append([slot, value])
+            else:
+                continue
+        elif 'people' in slot or 'duration' in slot or 'stay' in slot:
+            if value in QUANTITIES:
+                clean_state.append([slot, value])
+            elif True in [v in value for v in QUANTITIES]:
+                value = [v for v in QUANTITIES if v in value][0]
+                clean_state.append([slot, value])
+            elif value == '?' and is_acts:
+                clean_state.append([slot, value])
+            else:
+                try:
+                    value = int(value)
+                    if value >= 10:
+                        value = '10 or more'
+                        clean_state.append([slot, value])
+                    else:
+                        continue
+                except:
+                    continue
+        elif 'time' in slot or 'leaveat' in slot or 'arriveby' in slot:
+            if 'leaveat' in slot:
+                d, s = slot.split('-', 1)
+                s = 'leave at'
+                slot = f'{d}-{s}'
+            if 'arriveby' in slot:
+                d, s = slot.split('-', 1)
+                s = 'arrive by'
+                slot = f'{d}-{s}'
+            if value in TIME:
+                if value == 'do not care':
+                    clean_state.append([slot, value])
+                else:
+                    h, m = value.split(':')
+                    if int(m) % 5 == 0:
+                        clean_state.append([slot, value])
+                    else:
+                        m = round(int(m) / 5) * 5
+                        h = int(h)
+                        if m == 60:
+                            m = 0
+                            h += 1
+                        if h >= 24:
+                            h -= 24
+                        value = '%02i:%02i' % (h, m)
+                        clean_state.append([slot, value])
+            elif True in [v in value for v in TIME]:
+                value = [v for v in TIME if v in value][0]
+                h, m = value.split(':')
+                if int(m) % 5 == 0:
+                    clean_state.append([slot, value])
+                else:
+                    m = round(int(m) / 5) * 5
+                    h = int(h)
+                    if m == 60:
+                        m = 0
+                        h += 1
+                    if h >= 24:
+                        h -= 24
+                    value = '%02i:%02i' % (h, m)
+                    clean_state.append([slot, value])
+            elif value == '?' and is_acts:
+                clean_state.append([slot, value])
+            else:
+                continue
+        elif 'stars' in slot:
+            if len(value) == 1 or value == 'do not care':
+                clean_state.append([slot, value])
+            elif value == '?' and is_acts:
+                clean_state.append([slot, value])
+            elif len(value) > 1:
+                try:
+                    value = int(value[0])
+                    value = str(value)
+                    clean_state.append([slot, value])
+                except:
+                    continue
+        elif 'area' in slot:
+            if '|' in value:
+                value = value.split('|', 1)[0]
+            clean_state.append([slot, value])
+        else:
+            if '|' in value:
+                value = value.split('|', 1)[0]
+            value = map_values(value)
+            # d, s = slot.split('-', 1)
+            # value = normalize_value(value_dict, d, s, value)
+            clean_state.append([slot, value])
+
+    return clean_state
+
+
+# Module to process a dialogue and check its validity
+def process_dialogue(dialogue, max_utt_len=128):
+    if len(dialogue['log']) % 2 != 0:
+        return None
+
+    # Extract user and system utterances
+    usr_utts, sys_utts = [], []
+    avg_len = sum(len(utt['text'].split(' ')) for utt in dialogue['log'])
+    avg_len = avg_len / len(dialogue['log'])
+    if avg_len > max_utt_len:
+        return None
+
+    # If the first turn is a system turn then ignore the dialogue
+    if dialogue['log'][0]['metadata']:
+        return None
+
+    usr, sys = None, None
+    for turn in dialogue['log']:
+        if not is_ascii(turn['text']):
+            return None
+
+        if not usr or not sys:
+            if len(turn['metadata']) == 0:
+                usr = turn
+            else:
+                sys = turn
+
+        if usr and sys:
+            states = build_dialoguestate(sys['metadata'], get_domains = False)
+            sys['dialogue_states'] = states
+
+            usr_utts.append(usr)
+            sys_utts.append(sys)
+            usr, sys = None, None
+
+    dial_clean = dict()
+    dial_clean['usr_log'] = usr_utts
+    dial_clean['sys_log'] = sys_utts
+    return dial_clean
+
+
+# Get new domains
+def get_act_domains(prev, crnt):
+    diff = {}
+    if not prev or not crnt:
+        return diff
+
+    for ((prev_dom, prev_val), (crnt_dom, crnt_val)) in zip(prev.items(), crnt.items()):
+        assert prev_dom == crnt_dom
+        if prev_val != crnt_val:
+            diff[crnt_dom] = crnt_val
+    return diff
+
+
+# Get current domains
+def get_domains(dial_log, turn_id, prev_domain):
+    if turn_id == 1:
+        active = build_dialoguestate(dial_log[turn_id]['metadata'], get_domains=True)
+        acts = format_acts(dial_log[turn_id].get('dialog_act', {})) if not active else []
+        acts = [domain for intent, domain, slot, value in acts if domain not in ['', 'general']]
+        active += acts
+        crnt = active[0] if active else ''
+    else:
+        active = get_act_domains(dial_log[turn_id - 2]['metadata'], dial_log[turn_id]['metadata'])
+        active = list(active.keys())
+        acts = format_acts(dial_log[turn_id].get('dialog_act', {})) if not active else []
+        acts = [domain for intent, domain, slot, value in acts if domain not in ['', 'general']]
+        active += acts
+        crnt = [prev_domain] if not active else active
+        crnt = crnt[0]
+
+    return crnt
+
+
+# Function to extract dialogue info from data
+def extract_dialogue(dialogue, max_utt_len=50):
+    dialogue = process_dialogue(dialogue, max_utt_len)
+    if not dialogue:
+        return None
+
+    usr_utts = [turn['text'] for turn in dialogue['usr_log']]
+    sys_utts = [turn['text'] for turn in dialogue['sys_log']]
+    # sys_acts = [format_acts(turn['dialog_act']) if 'dialog_act' in turn else [] for turn in dialogue['sys_log']]
+    usr_acts = [format_acts(turn['dialog_act']) if 'dialog_act' in turn else [] for turn in dialogue['usr_log']]
+    dialogue_states = [turn['dialogue_states'] for turn in dialogue['sys_log']]
+    domains = [turn['domain'] for turn in dialogue['usr_log']]
+
+    # dial = [{'usr': u, 'sys': s, 'usr_a': ua, 'sys_a': a, 'domain': d, 'ds': v}
+    #         for u, s, ua, a, d, v in zip(usr_utts, sys_utts, usr_acts, sys_acts, domains, dialogue_states)]
+    dial = [{'usr': u, 'sys': s, 'usr_a': ua, 'domain': d, 'ds': v}
+            for u, s, ua, d, v in zip(usr_utts, sys_utts, usr_acts, domains, dialogue_states)]
+    return dial
+
+
+def format_acts(acts):
+    new_acts = []
+    for key, item in acts.items():
+        domain, intent = key.split('-', 1)
+        if domain.lower() in ACTIVE_DOMAINS + ['general']:
+            state = []
+            for slot, value in item:
+                slot = str(REF_SYS_DA[domain].get(slot, slot)).lower() if domain in REF_SYS_DA else slot
+                value = clean_text(value)
+                slot = slot.replace('_', ' ').replace('ref', 'reference')
+                state.append([f'{domain.lower()}-{slot}', value])
+            state = clean_dialoguestate(state, is_acts=True)
+            if domain == 'general':
+                if intent in ['thank', 'bye']:
+                    state = [['general-none', 'none']]
+                else:
+                    state = []
+            for slot, value in state:
+                if slot not in ['train-people']:
+                    slot = slot.split('-', 1)[-1]
+                new_acts.append([intent.lower(), domain.lower(), slot, value])
+
+    return new_acts
+
+
+# Fix act labels
+def fix_delexicalisation(turn):
+    if 'dialog_act' in turn:
+        for dom, act in turn['dialog_act'].items():
+            if 'Attraction' in dom:
+                if 'restaurant_' in turn['text']:
+                    turn['text'] = turn['text'].replace("restaurant", "attraction")
+                if 'hotel_' in turn['text']:
+                    turn['text'] = turn['text'].replace("hotel", "attraction")
+            if 'Hotel' in dom:
+                if 'attraction_' in turn['text']:
+                    turn['text'] = turn['text'].replace("attraction", "hotel")
+                if 'restaurant_' in turn['text']:
+                    turn['text'] = turn['text'].replace("restaurant", "hotel")
+            if 'Restaurant' in dom:
+                if 'attraction_' in turn['text']:
+                    turn['text'] = turn['text'].replace("attraction", "restaurant")
+                if 'hotel_' in turn['text']:
+                    turn['text'] = turn['text'].replace("hotel", "restaurant")
+
+    return turn
+
+
+# Check that all characters in a string are ascii
+def is_ascii(s):
+    return all(ord(c) < 128 for c in s)
+
+
+# Insert white space
+def separate_token(token, text):
+    sidx = 0
+    while True:
+        # Find next instance of token
+        sidx = text.find(token, sidx)
+        if sidx == -1:
+            break
+        # If the token is already separated continue to next
+        if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
+                re.match('[0-9]', text[sidx + 1]):
+            sidx += 1
+            continue
+        # Create white space separation around token
+        if text[sidx - 1] != ' ':
+            text = text[:sidx] + ' ' + text[sidx:]
+            sidx += 1
+        if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
+            text = text[:sidx + 1] + ' ' + text[sidx + 1:]
+        sidx += 1
+    return text
+
+
+def clean_text(text):
+    # Strip leading and trailing whitespace
+    text = re.sub(r'^\s*|\s*$', '', text.strip().lower())
+
+    # Replace 'b&b' or 'b and b' with 'bed and breakfast'
+    text = re.sub(r"b&b", "bed and breakfast", text)
+    text = re.sub(r"b and b", "bed and breakfast", text)
+
+    # Fix apostrophes
+    text = re.sub(u"(\u2018|\u2019)", "'", text)
+
+    # Correct punctuation
+    text = text.replace(';', ',')
+    text = re.sub('$\/', '', text)
+    text = text.replace('/', ' and ')
+
+    # Replace special characters
+    text = text.replace('-', ' ')
+    text = re.sub('[\"\<>@\(\)]', '', text)
+
+    # Insert white space around special tokens:
+    for token in ['?', '.', ',', '!']:
+        text = separate_token(token, text)
+
+    # insert white space for 's
+    text = separate_token('\'s', text)
+
+    # replace stray apostrophes in it's, doesn't, you'd ... etc
+    text = re.sub('^\'', '', text)
+    text = re.sub('\'$', '', text)
+    text = re.sub('\'\s', ' ', text)
+    text = re.sub('\s\'', ' ', text)
+
+    # Perform pair replacements listed in the mapping.pair file
+    for fromx, tox in REPLACEMENTS:
+        text = ' ' + text + ' '
+        text = text.replace(fromx, tox)[1:-1]
+
+    # Remove multiple spaces
+    text = re.sub(' +', ' ', text)
+
+    # Concatenate numbers eg '1 3' -> '13'
+    tokens = text.split()
+    i = 1
+    while i < len(tokens):
+        if re.match(u'^\d+$', tokens[i]) and \
+                re.match(u'\d+$', tokens[i - 1]):
+            tokens[i - 1] += tokens[i]
+            del tokens[i]
+        else:
+            i += 1
+    text = ' '.join(tokens)
+
+    return text
diff --git a/convlab/dst/setsumbt/process_mwoz_data.py b/convlab/dst/setsumbt/process_mwoz_data.py
new file mode 100755
index 0000000000000000000000000000000000000000..701a523613961d83a5188fa9ab0cf786b19a5a7e
--- /dev/null
+++ b/convlab/dst/setsumbt/process_mwoz_data.py
@@ -0,0 +1,99 @@
+import os
+import json
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+import torch
+from tqdm import tqdm
+
+from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker
+from convlab.util.multiwoz.lexicalize import deflat_da, flat_da
+
+
+def load_data(path):
+    with open(path, 'r') as reader:
+        data = json.load(reader)
+        reader.close()
+
+    return data
+
+
+def load_tracker(model_checkpoint):
+    model = SetSUMBTTracker(model_path=model_checkpoint)
+    model.init_session()
+
+    return model
+
+
+def process_dialogue(dial, model, get_full_belief_state):
+    model.store_full_belief_state = get_full_belief_state
+    model.init_session()
+
+    model.state['history'].append(['sys', ''])
+    processed_dial = []
+    belief_state = {}
+    for turn in dial:
+        if not turn['metadata']:
+            state = model.update(turn['text'])
+            model.state['history'].append(['usr', turn['text']])
+
+            acts = model.state['user_action']
+            acts = [[val.replace('-', ' ') for val in act] for act in acts]
+            acts = flat_da(acts)
+            acts = deflat_da(acts)
+            turn['dialog_act'] = acts
+        else:
+            model.state['history'].append(['sys', turn['text']])
+            turn['metadata'] = model.state['belief_state']
+
+        if get_full_belief_state:
+            for slot, probs in model.full_belief_state.items():
+                if slot not in belief_state:
+                    belief_state[slot] = [probs[0]]
+                else:
+                    belief_state[slot].append(probs[0])
+
+        processed_dial.append(turn)
+
+    if get_full_belief_state:
+        belief_state = {slot: torch.cat(probs, 0).cpu() for slot, probs in belief_state.items()}
+
+    return processed_dial, belief_state
+
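As an aside, not part of the patch: the 5-minute rounding that clean_dialoguestate (in dataset/utils.py above) applies to 'leave at'/'arrive by' values appears twice inline; factored out, the same arithmetic could look like the sketch below, where round_to_5min is a hypothetical helper name.

def round_to_5min(value: str) -> str:
    # Round an 'HH:MM' string to the nearest 5-minute mark, wrapping past
    # midnight, mirroring the inline logic in clean_dialoguestate.
    h, m = (int(x) for x in value.split(':'))
    m = round(m / 5) * 5
    if m == 60:
        m, h = 0, h + 1
    if h >= 24:
        h -= 24
    return '%02i:%02i' % (h, m)

assert round_to_5min('08:58') == '09:00'
assert round_to_5min('23:59') == '00:00'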
+ +def process_dialogues(data, model, get_full_belief_state=False): + processed_data = {} + belief_states = {} + for dial_id, dial in tqdm(data.items()): + dial['log'], bs = process_dialogue(dial['log'], model, get_full_belief_state) + processed_data[dial_id] = dial + if get_full_belief_state: + belief_states[dial_id] = bs + + return processed_data, belief_states + + +def get_arguments(): + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + + parser.add_argument('--model_path') + parser.add_argument('--data_path') + parser.add_argument('--get_full_belief_state', action='store_true') + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_arguments() + + print('Loading data and model...') + data = load_data(os.path.join(args.data_path, 'data.json')) + model = load_tracker(args.model_path) + + print('Processing data...\n') + data, belief_states = process_dialogues(data, model, get_full_belief_state=args.get_full_belief_state) + + print('Saving results...\n') + torch.save(belief_states, os.path.join(args.data_path, 'setsumbt_belief_states.bin')) + with open(os.path.join(args.data_path, 'setsumbt_data.json'), 'w') as writer: + json.dump(data, writer, indent=2) + writer.close() diff --git a/convlab/dst/setsumbt/run.py b/convlab/dst/setsumbt/run.py index e45bf129f0c9f2c5c1fba01d4b5eb80e29a5a1f0..b9c9a75b86d47cd5db733b4755d6af11f08b827d 100644 --- a/convlab/dst/setsumbt/run.py +++ b/convlab/dst/setsumbt/run.py @@ -33,8 +33,8 @@ def main(): if args.run_nbt: from convlab.dst.setsumbt.do.nbt import main main(args, config) - if args.run_evaluation: - from convlab.dst.setsumbt.do.evaluate import main + if args.run_calibration: + from convlab.dst.setsumbt.do.calibration import main main(args, config) diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py deleted file mode 100644 index b58fc5bd17fdf486e544fe06c4f8c2bc0b83f8b3..0000000000000000000000000000000000000000 --- a/convlab/dst/setsumbt/tracker.py +++ /dev/null @@ -1,442 +0,0 @@ -import os -import json -import copy -import logging - -import torch -import transformers -from transformers import BertModel, BertConfig, BertTokenizer, RobertaModel, RobertaConfig, RobertaTokenizer - -from convlab.dst.setsumbt.modeling import RobertaSetSUMBT, BertSetSUMBT -from convlab.dst.setsumbt.modeling.training import set_ontology_embeddings -from convlab.dst.dst import DST -from convlab.util.custom_util import model_downloader - -USE_CUDA = torch.cuda.is_available() - - -class SetSUMBTTracker(DST): - """SetSUMBT Tracker object for Convlab dialogue system""" - - def __init__(self, - model_path: str = "", - model_type: str = "roberta", - return_turn_pooled_representation: bool = False, - return_confidence_scores: bool = False, - confidence_threshold='auto', - return_belief_state_entropy: bool = False, - return_belief_state_mutual_info: bool = False, - store_full_belief_state: bool = False): - """ - Args: - model_path: Model path or download URL - model_type: Transformer type (roberta/bert) - return_turn_pooled_representation: If true a turn level pooled representation is returned - return_confidence_scores: If true act confidence scores are included in the state - confidence_threshold: Confidence threshold value for constraints or option auto - return_belief_state_entropy: If true belief state distribution entropies are included in the state - return_belief_state_mutual_info: If true belief state distribution mutual infos are included in the state - store_full_belief_state: If true full belief state 
is stored within tracker object - """ - super(SetSUMBTTracker, self).__init__() - - self.model_type = model_type - self.model_path = model_path - self.return_turn_pooled_representation = return_turn_pooled_representation - self.return_confidence_scores = return_confidence_scores - self.confidence_threshold = confidence_threshold - self.return_belief_state_entropy = return_belief_state_entropy - self.return_belief_state_mutual_info = return_belief_state_mutual_info - self.store_full_belief_state = store_full_belief_state - if self.store_full_belief_state: - self.full_belief_state = {} - self.info_dict = {} - - # Download model if needed - if not os.path.exists(self.model_path): - # Get path /.../convlab/dst/setsumbt/multiwoz/models - download_path = os.path.dirname(os.path.abspath(__file__)) - download_path = os.path.join(download_path, 'models') - if not os.path.exists(download_path): - os.mkdir(download_path) - model_downloader(download_path, self.model_path) - # Downloadable model path format http://.../setsumbt_model_name.zip - self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '') - self.model_path = os.path.join(download_path, self.model_path) - - # Select model type based on the encoder - if model_type == "roberta": - self.config = RobertaConfig.from_pretrained(self.model_path) - self.tokenizer = RobertaTokenizer - self.model = RobertaSetSUMBT - elif model_type == "bert": - self.config = BertConfig.from_pretrained(self.model_path) - self.tokenizer = BertTokenizer - self.model = BertSetSUMBT - else: - logging.debug("Name Error: Not Implemented") - - self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu') - - self.load_weights() - - def load_weights(self): - """Load model weights and model ontology""" - logging.info('Loading SetSUMBT pretrained model.') - self.tokenizer = self.tokenizer.from_pretrained(self.config.tokenizer_name) - logging.info(f'Model tokenizer loaded from {self.config.tokenizer_name}.') - self.model = self.model.from_pretrained(self.model_path, config=self.config) - logging.info(f'Model loaded from {self.model_path}.') - - # Transfer model to compute device and setup eval environment - self.model = self.model.to(self.device) - self.model.eval() - logging.info(f'Model transferred to device: {self.device}') - - logging.info('Loading model ontology') - f = open(os.path.join(self.model_path, 'database', 'test.json'), 'r') - self.ontology = json.load(f) - f.close() - - db = torch.load(os.path.join(self.model_path, 'database', 'test.db')) - set_ontology_embeddings(self.model, db) - - if self.return_confidence_scores: - logging.info('Model returns user action and belief state confidence scores.') - self.get_thresholds(self.confidence_threshold) - logging.info('Uncertain Querying set up and thresholds set up at:') - logging.info(self.confidence_thresholds) - if self.return_belief_state_entropy: - logging.info('Model returns belief state distribution entropy scores (Total uncertainty).') - if self.return_belief_state_mutual_info: - logging.info('Model returns belief state distribution mutual information scores (Knowledge uncertainty).') - logging.info('Ontology loaded successfully.') - - def get_thresholds(self, threshold='auto') -> dict: - """ - Setup dictionary of domain specific confidence thresholds - - Args: - threshold: Threshold value or option auto - - Returns: - confidence_thresholds: Domain specific confidence thresholds - """ - self.confidence_thresholds = dict() - for domain, substate in self.ontology.items(): - for 
slot, slot_info in substate.items(): - # Auto thresholds are set based on the number of value candidates per slot - if domain not in self.confidence_thresholds: - self.confidence_thresholds[domain] = dict() - if threshold == 'auto': - thres = 1.0 / (float(len(slot_info['possible_values'])) - 2.1) - self.confidence_thresholds[domain][slot] = max(0.05, thres) - else: - self.confidence_thresholds[domain][slot] = max(0.05, threshold) - - return self.confidence_thresholds - - def init_session(self): - self.state = dict() - self.state['belief_state'] = dict() - for domain, substate in self.ontology.items(): - self.state['belief_state'][domain] = dict() - for slot, slot_info in substate.items(): - if slot_info['possible_values'] and slot_info['possible_values'] != ['?']: - self.state['belief_state'][domain][slot] = '' - self.state['history'] = [] - self.state['system_action'] = [] - self.state['user_action'] = [] - self.active_domains = {} - self.hidden_states = None - self.info_dict = {} - - def update(self, user_act: str = '') -> dict: - """ - Update user actions and dialogue and belief states. - - Args: - user_act: - - Returns: - - """ - prev_state = self.state - _output = self.predict(self.get_features(user_act)) - - # Format state entropy - if _output[5] is not None: - state_entropy = dict() - for slot, e in _output[5].items(): - domain, slot = slot.split('-', 1) - if domain not in state_entropy: - state_entropy[domain] = dict() - state_entropy[domain][slot] = e - else: - state_entropy = None - - # Format state mutual information - if _output[6] is not None: - state_mutual_info = dict() - for slot, mi in _output[5].items(): - domain, slot = slot.split('-', 1) - if domain not in state_mutual_info: - state_mutual_info[domain] = dict() - state_mutual_info[domain][slot] = mi[0, 0] - else: - state_mutual_info = None - - # Format all confidence scores - belief_state_confidence = None - if _output[4] is not None: - belief_state_confidence = dict() - belief_state_conf, request_probs, active_domain_probs, general_act_probs = _output[4] - for slot, p in belief_state_conf.items(): - domain, slot = slot.split('-', 1) - if domain not in belief_state_confidence: - belief_state_confidence[domain] = dict() - if slot not in belief_state_confidence[domain]: - belief_state_confidence[domain][slot] = dict() - belief_state_confidence[domain][slot]['inform'] = p - - for slot, p in request_probs.items(): - domain, slot = slot.split('-', 1) - if domain not in belief_state_confidence: - belief_state_confidence[domain] = dict() - if slot not in belief_state_confidence[domain]: - belief_state_confidence[domain][slot] = dict() - belief_state_confidence[domain][slot]['request'] = p - - for domain, p in active_domain_probs.items(): - if domain not in belief_state_confidence: - belief_state_confidence[domain] = dict() - belief_state_confidence[domain]['none'] = {'inform': p} - - if 'general' not in belief_state_confidence: - belief_state_confidence['general'] = dict() - belief_state_confidence['general']['none'] = general_act_probs - - # Get new domain activation actions - new_domains = [d for d, active in _output[1].items() if active] - new_domains = [d for d in new_domains if not self.active_domains.get(d, False)] - self.active_domains = _output[1] - - user_acts = _output[2] - for domain in new_domains: - user_acts.append(['inform', domain, 'none', 'none']) - - new_belief_state = copy.deepcopy(prev_state['belief_state']) - for domain, substate in _output[0].items(): - for slot, value in substate.items(): - value = '' if 
value == 'none' else value - value = 'dontcare' if value == 'do not care' else value - value = 'guesthouse' if value == 'guest house' else value - - if domain not in new_belief_state: - if domain == 'bus': - continue - else: - logging.debug('Error: domain <{}> not in belief state'.format(domain)) - - # Uncertainty clipping of state - if belief_state_confidence is not None: - threshold = self.confidence_thresholds[domain][slot] - if belief_state_confidence[domain][slot].get('inform', 1.0) < threshold: - value = '' - - new_belief_state[domain][slot] = value - if prev_state['belief_state'][domain][slot] != value: - user_acts.append(['inform', domain, slot, value]) - else: - bug = f'Unknown slot name <{slot}> with value <{value}> of domain <{domain}>' - logging.debug(bug) - - new_state = copy.deepcopy(dict(prev_state)) - new_state['belief_state'] = new_belief_state - new_state['active_domains'] = self.active_domains - if belief_state_confidence is not None: - new_state['belief_state_probs'] = belief_state_confidence - if state_entropy is not None: - new_state['entropy'] = state_entropy - if state_mutual_info is not None: - new_state['mutual_information'] = state_mutual_info - - user_acts = [act for act in user_acts if act not in new_state['system_action']] - new_state['user_action'] = user_acts - - if _output[3] is not None: - new_state['turn_pooled_representation'] = _output[3] - - self.state = new_state - self.info_dict = copy.deepcopy(dict(new_state)) - - return self.state - - def predict(self, features: dict) -> tuple: - """ - Model forward pass and prediction post processing. - - Args: - features: Dictionary of model input features - - Returns: - out: Model predictions and uncertainty features - """ - state_mutual_info = None - with torch.no_grad(): - turn_pooled_representation = None - if self.return_turn_pooled_representation: - _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'], - attention_mask=features['attention_mask'], hidden_state=self.hidden_states, - get_turn_pooled_representation=True) - belief_state = _outputs[0] - request_probs = _outputs[1] - active_domain_probs = _outputs[2] - general_act_probs = _outputs[3] - self.hidden_states = _outputs[4] - turn_pooled_representation = _outputs[5] - elif self.return_belief_state_mutual_info: - _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'], - attention_mask=features['attention_mask'], hidden_state=self.hidden_states, - get_turn_pooled_representation=True, calculate_inform_mutual_info=True) - belief_state = _outputs[0] - request_probs = _outputs[1] - active_domain_probs = _outputs[2] - general_act_probs = _outputs[3] - self.hidden_states = _outputs[4] - state_mutual_info = _outputs[5] - else: - _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'], - attention_mask=features['attention_mask'], hidden_state=self.hidden_states, - get_turn_pooled_representation=False) - belief_state, request_probs, active_domain_probs, general_act_probs, self.hidden_states = _outputs - - # Convert belief state into dialog state - dialogue_state = dict() - for slot, probs in belief_state.items(): - dom, slot = slot.split('-', 1) - if dom not in dialogue_state: - dialogue_state[dom] = dict() - val = self.ontology[dom][slot]['possible_values'][probs[0, 0, :].argmax().item()] - if val != 'none': - dialogue_state[dom][slot] = val - - if self.store_full_belief_state: - self.full_belief_state = belief_state - - # Obtain model 
output probabilities - if self.return_confidence_scores: - state_entropy = None - if self.return_belief_state_entropy: - state_entropy = {slot: probs[0, 0, :] for slot, probs in belief_state.items()} - state_entropy = {slot: self.relative_entropy(p).item() for slot, p in state_entropy.items()} - - # Confidence score is the max probability across all not "none" values candidates. - belief_state_conf = {slot: probs[0, 0, 1:].max().item() for slot, probs in belief_state.items()} - _request_probs = {slot: p[0, 0].item() for slot, p in request_probs.items()} - _active_domain_probs = {domain: p[0, 0].item() for domain, p in active_domain_probs.items()} - _general_act_probs = {'bye': general_act_probs[0, 0, 1].item(), 'thank': general_act_probs[0, 0, 2].item()} - confidence_scores = (belief_state_conf, _request_probs, _active_domain_probs, _general_act_probs) - else: - confidence_scores = None - state_entropy = None - - # Construct request action prediction - request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5] - request_acts = [slot.split('-', 1) for slot in request_acts] - request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts] - - # Construct active domain set - active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()} - - # Construct general domain action - general_acts = general_act_probs[0, 0, :].argmax(-1).item() - general_acts = [[], ['bye'], ['thank']][general_acts] - general_acts = [[act, 'general', 'none', 'none'] for act in general_acts] - - user_acts = request_acts + general_acts - - out = (dialogue_state, active_domains, user_acts, turn_pooled_representation, confidence_scores) - out += (state_entropy, state_mutual_info) - return out - - def relative_entropy(self, probs: torch.Tensor) -> torch.Tensor: - """ - Compute relative entrop for a probability distribution - - Args: - probs: Probability distributions - - Returns: - entropy: Relative entropy - """ - entropy = probs * torch.log(probs + 1e-8) - entropy = -entropy.sum() - # Maximum entropy of a K dimentional distribution is ln(K) - entropy /= torch.log(torch.tensor(probs.size(-1)).float()) - - return entropy - - def get_features(self, user_act: str) -> dict: - """ - Tokenize utterances and construct model input features - - Args: - user_act: User action string - - Returns: - features: Model input features - """ - # Extract system utterance from dialog history - context = self.state['history'] - if context: - if context[-1][0] != 'sys': - system_act = '' - else: - system_act = context[-1][-1] - else: - system_act = '' - - # Tokenize dialog - features = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True, - max_length=self.config.max_turn_len, padding='max_length', - truncation='longest_first') - - input_ids = torch.tensor(features['input_ids']).reshape( - 1, 1, -1).to(self.device) if 'input_ids' in features else None - token_type_ids = torch.tensor(features['token_type_ids']).reshape( - 1, 1, -1).to(self.device) if 'token_type_ids' in features else None - attention_mask = torch.tensor(features['attention_mask']).reshape( - 1, 1, -1).to(self.device) if 'attention_mask' in features else None - features = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask} - - return features - - -# if __name__ == "__main__": -# from convlab.policy.vector.vector_uncertainty import VectorUncertainty -# # from convlab.policy.vector.vector_binary import VectorBinary -# tracker = 
SetSUMBTTracker(model_path='/gpfs/project/niekerk/src/SetSUMBT/models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-labelsmoothing-Seed0-10-08-22-12-42', -# return_confidence_scores=True, confidence_threshold='auto', -# return_belief_state_entropy=True) -# vector = VectorUncertainty(use_state_total_uncertainty=True, confidence_thresholds=tracker.confidence_thresholds, -# use_masking=True) -# # vector = VectorBinary() -# tracker.init_session() -# -# state = tracker.update('hey. I need a cheap restaurant.') -# tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.']) -# tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?']) -# state = tracker.update('If you have something Asian that would be great.') -# tracker.state['history'].append(['usr', 'If you have something Asian that would be great.']) -# tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.']) -# tracker.state['system_action'] = [['inform', 'restaurant', 'food', 'chinese'], -# ['inform', 'restaurant', 'name', 'the golden wok']] -# state = tracker.update('Great. Where are they located?') -# tracker.state['history'].append(['usr', 'Great. Where are they located?']) -# state = tracker.state -# state['terminated'] = False -# state['booked'] = {} -# -# print(state) -# print(vector.state_vectorize(state)) diff --git a/convlab/dst/setsumbt/utils.py b/convlab/dst/setsumbt/utils.py index 23cec10ddd06872118af77bfac9ea856f3947677..75a6a1febd7510f1a6d152676b82d709abf177f0 100644 --- a/convlab/dst/setsumbt/utils.py +++ b/convlab/dst/setsumbt/utils.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf +# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf # Authors: Carel van Niekerk (niekerk@hhu.de) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,41 +15,57 @@ # limitations under the License. 
"""SetSUMBT utils""" +import re import os import shutil from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +from glob import glob from datetime import datetime +from google.cloud import storage -def get_args(base_models: dict): + +def get_args(MODELS): # Get arguments parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) # Optional - parser.add_argument('--tensorboard_path', help='Path to tensorboard', default='') + parser.add_argument('--tensorboard_path', + help='Path to tensorboard', default='') parser.add_argument('--logging_path', help='Path for log file', default='') - parser.add_argument('--seed', help='Seed value for reproducibility', default=0, type=int) + parser.add_argument( + '--seed', help='Seed value for reproducability', default=0, type=int) # DATASET (Optional) - parser.add_argument('--dataset', help='Dataset Name (See Convlab 3 unified format for possible datasets', - default='multiwoz21') - parser.add_argument('--dataset_train_ratio', help='Fraction of training set to use in training', default=1.0, - type=float) - parser.add_argument('--max_dialogue_len', help='Maximum number of turns per dialogue', default=12, type=int) - parser.add_argument('--max_turn_len', help='Maximum number of tokens per turn', default=64, type=int) - parser.add_argument('--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int) - parser.add_argument('--max_candidate_len', help='Maximum number of tokens per value candidate', default=12, - type=int) - parser.add_argument('--force_processing', action='store_true', help='Force preprocessing of data.') - parser.add_argument('--data_sampling_size', help='Resampled dataset size', default=-1, type=int) - parser.add_argument('--no_descriptions', help='Do not use slot descriptions rather than slot names for embeddings', + parser.add_argument( + '--dataset', help='Dataset Name: multiwoz21/simr', default='multiwoz21') + parser.add_argument('--shrink_active_domains', help='Shrink active domains to only well represented test set domains', + action='store_true') + parser.add_argument( + '--data_dir', help='Data storage directory', default=None) + parser.add_argument( + '--max_dialogue_len', help='Maximum number of turns per dialogue', default=12, type=int) + parser.add_argument( + '--max_turn_len', help='Maximum number of tokens per turn', default=64, type=int) + parser.add_argument( + '--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int) + parser.add_argument('--max_candidate_len', + help='Maximum number of tokens per value candidate', default=12, type=int) + parser.add_argument('--force_processing', action='store_true', + help='Force preprocessing of data.') + parser.add_argument('--data_sampling_size', + help='Resampled dataset size', default=-1, type=int) + parser.add_argument('--use_descriptions', help='Use slot descriptions rather than slot names for embeddings', action='store_true') # MODEL # Environment - parser.add_argument('--output_dir', help='Output storage directory', default=None) - parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta', default='roberta') - parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.', default=None) + parser.add_argument( + '--output_dir', help='Output storage directory', default=None) + parser.add_argument('--model_type', help='Encoder Model Type: bert/roberta', + default='roberta') + parser.add_argument('--model_name_or_path', help='Name or path of the 
pretrained model.', + default=None) parser.add_argument('--candidate_embedding_model_name', default=None, help='Name of the pretrained candidate embedding model.') @@ -58,73 +74,92 @@ def get_args(base_models: dict): action='store_true') parser.add_argument('--slot_attention_heads', help='Number of attention heads for slot conditioning', default=12, type=int) - parser.add_argument('--dropout_rate', help='Dropout Rate', default=0.3, type=float) - parser.add_argument('--nbt_type', help='Belief Tracker type: gru/lstm', default='gru') + parser.add_argument('--dropout_rate', help='Dropout Rate', + default=0.3, type=float) + parser.add_argument( + '--nbt_type', help='Belief Tracker type: gru/lstm', default='gru') parser.add_argument('--nbt_hidden_size', help='Hidden embedding size for the Neural Belief Tracker', default=300, type=int) - parser.add_argument('--nbt_layers', help='Number of RNN layers in the NBT', default=1, type=int) - parser.add_argument('--rnn_zero_init', help='Zero Initialise RNN hidden states', action='store_true') + parser.add_argument( + '--nbt_layers', help='Number of RNN layers in the NBT', default=1, type=int) + parser.add_argument( + '--rnn_zero_init', help='Zero Initialise RNN hidden states', action='store_true') parser.add_argument('--distance_measure', default='cosine', help='Similarity measure for candidate scoring: cosine/euclidean') - parser.add_argument('--ensemble_size', help='Number of models in ensemble', default=-1, type=int) - parser.add_argument('--no_set_similarity', action='store_true', help='Set True to not use set similarity') - parser.add_argument('--set_pooling', - help='Set pooling method for set similarity model using single embedding distances', + parser.add_argument( + '--ensemble_size', help='Number of models in ensemble', default=-1, type=int) + parser.add_argument('--set_similarity', action='store_true', + help='Set True to not use set similarity (Model tracks latent belief state as sequence and performs semantic similarity of sets)') + parser.add_argument('--set_pooling', help='Set pooling method for set similarity model using single embedding distances', default='cnn') - parser.add_argument('--candidate_pooling', - help='Pooling approach for non set based candidate representations: cls/mean', + parser.add_argument('--candidate_pooling', help='Pooling approach for non set based candidate representations: cls/mean', default='mean') - parser.add_argument('--no_action_prediction', help='Model does not predicts user actions and active domain', + parser.add_argument('--predict_actions', help='Model predicts user actions and active domain', action='store_true') # Loss - parser.add_argument('--loss_function', - help='Loss Function for training: crossentropy/bayesianmatching/labelsmoothing/...', + parser.add_argument('--loss_function', help='Loss Function for training: crossentropy/bayesianmatching/labelsmoothing/distillation/distribution_distillation', default='labelsmoothing') parser.add_argument('--kl_scaling_factor', help='Scaling factor for KL divergence in bayesian matching loss', type=float) parser.add_argument('--prior_constant', help='Constant parameter for prior in bayesian matching loss', type=float) - parser.add_argument('--ensemble_smoothing', help='Ensemble distribution smoothing constant', type=float) - parser.add_argument('--annealing_base_temp', help='Ensemble Distribution distillation temp annealing base temp', + parser.add_argument('--ensemble_smoothing', + help='Ensemble distribution smoothing constant', type=float) + 
parser.add_argument('--annealing_base_temp', help='Ensemble Distribution destillation temp annealing base temp', type=float) - parser.add_argument('--annealing_cycle_len', help='Ensemble Distribution distillation temp annealing cycle length', + parser.add_argument('--annealing_cycle_len', help='Ensemble Distribution destillation temp annealing cycle length', type=float) - parser.add_argument('--label_smoothing', help='Label smoothing coefficient.', type=float) - parser.add_argument('--user_goal_loss_weight', help='Weight of the user goal prediction loss. 0.0<weight<=1.0', - type=float) - parser.add_argument('--user_request_loss_weight', - help='Weight of the user request prediction loss. 0.0<weight<=1.0', type=float) - parser.add_argument('--user_general_act_loss_weight', - help='Weight of the user general act prediction loss. 0.0<weight<=1.0', type=float) - parser.add_argument('--active_domain_loss_weight', - help='Weight of the active domain prediction loss. 0.0<weight<=1.0', type=float) + parser.add_argument('--inhibiting_factor', + help='Inhibiting factor for Inhibited Softmax CE', type=float) + parser.add_argument('--label_smoothing', + help='Label smoothing coefficient.', type=float) + parser.add_argument( + '--user_goal_loss_weight', help='Weight of the user goal prediction loss. 0.0<weight<=1.0', type=float) + parser.add_argument( + '--user_request_loss_weight', help='Weight of the user request prediction loss. 0.0<weight<=1.0', type=float) + parser.add_argument( + '--user_general_act_loss_weight', help='Weight of the user general act prediction loss. 0.0<weight<=1.0', type=float) + parser.add_argument( + '--active_domain_loss_weight', help='Weight of the active domain prediction loss. 0.0<weight<=1.0', type=float) # TRAINING - parser.add_argument('--train_batch_size', help='Training Set Batch Size', default=8, type=int) - parser.add_argument('--max_training_steps', help='Maximum number of training update steps', default=-1, type=int) + parser.add_argument('--train_batch_size', + help='Training Set Batch Size', default=4, type=int) + parser.add_argument('--max_training_steps', help='Maximum number of training update steps', + default=-1, type=int) parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help='Number of batches accumulated for one update step') - parser.add_argument('--num_train_epochs', help='Number of training epochs', default=50, type=int) + parser.add_argument('--num_train_epochs', + help='Number of training epochs', default=50, type=int) parser.add_argument('--patience', help='Number of training steps without improving model before stopping.', - default=20, type=int) - parser.add_argument('--weight_decay', help='Weight decay rate', default=0.01, type=float) - parser.add_argument('--learning_rate', help='Initial Learning Rate', default=5e-5, type=float) - parser.add_argument('--warmup_proportion', help='Warmup proportion for linear scheduler', default=0.2, type=float) - parser.add_argument('--max_grad_norm', help='Maximum norm of the loss gradients', default=1.0, type=float) - parser.add_argument('--save_steps', help='Number of update steps between saving model', default=-1, type=int) - parser.add_argument('--keep_models', help='How many model checkpoints should be kept during training', - default=1, type=int) + default=25, type=int) + parser.add_argument( + '--weight_decay', help='Weight decay rate', default=0.01, type=float) + parser.add_argument('--learning_rate', + help='Initial Learning Rate', default=5e-5, type=float) + 
parser.add_argument('--warmup_proportion', help='Warmup proportion for linear scheduler', + default=0.2, type=float) + parser.add_argument( + '--max_grad_norm', help='Maximum norm of the loss gradients', default=1.0, type=float) + parser.add_argument( + '--save_steps', help='Number of update steps between saving model', default=-1, type=int) + parser.add_argument( + '--keep_models', help='How many model checkpoints should be kept during training', default=1, type=int) # CALIBRATION - parser.add_argument('--temp_scaling', help='Temperature scaling coefficient', default=1.0, type=float) + parser.add_argument( + '--temp_scaling', help='Temperature scaling coefficient', default=1.0, type=float) # EVALUATION - parser.add_argument('--dev_batch_size', help='Dev Set Batch Size', default=16, type=int) - parser.add_argument('--test_batch_size', help='Test Set Batch Size', default=16, type=int) + parser.add_argument('--dev_batch_size', + help='Dev Set Batch Size', default=16, type=int) + parser.add_argument('--test_batch_size', + help='Test Set Batch Size', default=16, type=int) # COMPUTING - parser.add_argument('--n_gpu', help='Number of GPUs to use', default=1, type=int) + parser.add_argument( + '--n_gpu', help='Number of GPUs to use', default=1, type=int) parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") parser.add_argument('--fp16_opt_level', type=str, default='O1', @@ -132,29 +167,32 @@ def get_args(base_models: dict): "See details at https://nvidia.github.io/apex/amp.html") # ACTIONS - parser.add_argument('--run_nbt', help='Run NBT script',action='store_true') - parser.add_argument('--run_evaluation', help='Run evaluation script', action='store_true') + parser.add_argument('--run_nbt', help='Run NBT script', + action='store_true') + parser.add_argument('--run_calibration', + help='Run calibration', action='store_true') # RUN_NBT ACTIONS - parser.add_argument('--do_train', help='Perform training', action='store_true') - parser.add_argument('--do_eval', help='Perform model evaluation during training', action='store_true') - parser.add_argument('--do_test', help='Evaluate model on test data', action='store_true') + parser.add_argument( + '--do_train', help='Perform training', action='store_true') + parser.add_argument( + '--do_eval', help='Perform model evaluation during training', action='store_true') + parser.add_argument( + '--do_test', help='Evaulate model on test data', action='store_true') args = parser.parse_args() - # Simplify args - args.set_similarity = not args.no_set_similarity - args.use_descriptions = not args.no_descriptions - args.predict_actions = not args.no_action_prediction - # Setup default directories + if not args.data_dir: + args.data_dir = os.path.dirname(os.path.abspath(__file__)) + args.data_dir = os.path.join(args.data_dir, 'data') + os.makedirs(args.data_dir, exist_ok=True) + if not args.output_dir: args.output_dir = os.path.dirname(os.path.abspath(__file__)) args.output_dir = os.path.join(args.output_dir, 'models') - name = 'SetSUMBT' if args.set_similarity else 'SUMBT' - name += '+ActPrediction' if args.predict_actions else '' - name += '-' + args.dataset - name += '-' + str(round(args.dataset_train_ratio*100)) + '%' if args.dataset_train_ratio != 1.0 else '' + name = 'SetSUMBT' + name += '-Acts' if args.predict_actions else '' name += '-' + args.model_type name += '-' + args.nbt_type name += '-' + args.distance_measure @@ -170,6 +208,9 @@ def get_args(base_models: dict): 
args.kl_scaling_factor = 0.001 if not args.prior_constant: args.prior_constant = 1.0 + if args.loss_function == 'inhibitedce': + if not args.inhibiting_factor: + args.inhibiting_factor = 1.0 if args.loss_function == 'labelsmoothing': if not args.label_smoothing: args.label_smoothing = 0.05 @@ -192,8 +233,10 @@ def get_args(base_models: dict): if not args.active_domain_loss_weight: args.active_domain_loss_weight = 0.2 - args.tensorboard_path = args.tensorboard_path if args.tensorboard_path else os.path.join(args.output_dir, 'tb_logs') - args.logging_path = args.logging_path if args.logging_path else os.path.join(args.output_dir, 'run.log') + args.tensorboard_path = args.tensorboard_path if args.tensorboard_path else os.path.join( + args.output_dir, 'tb_logs') + args.logging_path = args.logging_path if args.logging_path else os.path.join( + args.output_dir, 'run.log') # Default model_name's if not args.model_name_or_path: @@ -207,22 +250,28 @@ def get_args(base_models: dict): if not args.candidate_embedding_model_name: args.candidate_embedding_model_name = args.model_name_or_path - if args.model_type in base_models: - config_class = base_models[args.model_type][-2] + if args.model_type in MODELS: + configClass = MODELS[args.model_type][-2] else: raise NameError('NotImplemented') - config = build_config(config_class, args) + config = build_config(configClass, args) return args, config -def build_config(config_class, args): - config = config_class.from_pretrained(args.model_name_or_path) - if not os.path.exists(args.model_name_or_path): +def build_config(configClass, args): + if args.model_type == 'fasttext': + config = configClass.from_pretrained('bert-base-uncased') + config.model_type == 'fasttext' + config.fasttext_path = args.model_name_or_path + config.vocab_size = None + elif not os.path.exists(args.model_name_or_path): + config = configClass.from_pretrained(args.model_name_or_path) config.tokenizer_name = args.model_name_or_path - try: - config.tokenizer_name = config.tokenizer_name - except AttributeError: + elif 'tod-bert' in args.model_name_or_path.lower(): + config = configClass.from_pretrained(args.model_name_or_path) config.tokenizer_name = args.model_name_or_path + else: + config = configClass.from_pretrained(args.model_name_or_path) if args.candidate_embedding_model_name: config.candidate_embedding_model_name = args.candidate_embedding_model_name config.max_dialogue_len = args.max_dialogue_len diff --git a/convlab/dst/trippy/__init__.py b/convlab/dst/trippy/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/convlab/dst/trippy/multiwoz/__init__.py b/convlab/dst/trippy/multiwoz/__init__.py deleted file mode 100644 index d432596abfaf8c0410b2b520cd59725b92de8932..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from convlab.dst.trippy.multiwoz.trippy import TRIPPY diff --git a/convlab/dst/trippy/multiwoz/__init__.py~ b/convlab/dst/trippy/multiwoz/__init__.py~ deleted file mode 100644 index fed1a0ba2130022910438c88074a0791f5c795fe..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/__init__.py~ +++ /dev/null @@ -1 +0,0 @@ -from convlab2.dst.trippy.multiwoz.trippy import TRIPPY diff --git a/convlab/dst/trippy/multiwoz/modeling_bert_dst.prev.py b/convlab/dst/trippy/multiwoz/modeling_bert_dst.prev.py deleted file mode 100644 index 
1b39f7010516ddf72e326f07864434502cea389d..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/modeling_bert_dst.prev.py +++ /dev/null @@ -1,213 +0,0 @@ -# coding=utf-8 -# -# Copyright 2020 Heinrich Heine University Duesseldorf -# -# Part of this code is based on the source code of BERT-DST -# (arXiv:1907.03040) -# Part of this code is based on the source code of Transformers -# (arXiv:1910.03771) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch import nn -from torch.nn import CrossEntropyLoss - -#from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable) -#from transformers.modeling_bert import (BertModel, BertPreTrainedModel, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) -from transformers import (BertModel, BertConfig, BertPreTrainedModel) - -#@add_start_docstrings( -# """BERT Model with a classification heads for the DST task. """, -# BERT_START_DOCSTRING, -#) -class BertForDST(BertPreTrainedModel): - def __init__(self, config): - super(BertForDST, self).__init__(config) - self.slot_list = config.dst_slot_list - self.class_types = config.dst_class_types - self.class_labels = config.dst_class_labels - self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable - self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable - self.class_aux_feats_inform = config.dst_class_aux_feats_inform - self.class_aux_feats_ds = config.dst_class_aux_feats_ds - self.class_loss_ratio = config.dst_class_loss_ratio - - # Only use refer loss if refer class is present in dataset. 
- if 'refer' in self.class_types: - self.refer_index = self.class_types.index('refer') - else: - self.refer_index = -1 - - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.dst_dropout_rate) - self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) - - if self.class_aux_feats_inform: - self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - if self.class_aux_feats_ds: - self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - - aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2 - - for slot in self.slot_list: - self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels)) - self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2)) - self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) - - # Head for aux task - if hasattr(config, "aux_task_def"): - self.add_module("aux_out_projection", nn.Linear(config.hidden_size, int(config.aux_task_def['n_class']))) - - self.init_weights() - - #@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) - def forward(self, - input_ids, - input_mask=None, - segment_ids=None, - position_ids=None, - head_mask=None, - start_pos=None, - end_pos=None, - inform_slot_id=None, - refer_id=None, - class_label_id=None, - diag_state=None, - aux_task_def=None): - outputs = self.bert( - input_ids, - attention_mask=input_mask, - token_type_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask - ) - - sequence_output = outputs[0] - pooled_output = outputs[1] - - sequence_output = self.dropout(sequence_output) - pooled_output = self.dropout(pooled_output) - - if aux_task_def is not None: - if aux_task_def['task_type'] == "classification": - aux_logits = getattr(self, 'aux_out_projection')(pooled_output) - aux_logits = self.dropout_heads(aux_logits) - aux_loss_fct = CrossEntropyLoss() - aux_loss = aux_loss_fct(aux_logits, class_label_id) - # add hidden states and attention if they are here - return (aux_loss,) + outputs[2:] - elif aux_task_def['task_type'] == "span": - aux_logits = getattr(self, 'aux_out_projection')(sequence_output) - aux_start_logits, aux_end_logits = aux_logits.split(1, dim=-1) - aux_start_logits = self.dropout_heads(aux_start_logits) - aux_end_logits = self.dropout_heads(aux_end_logits) - aux_start_logits = aux_start_logits.squeeze(-1) - aux_end_logits = aux_end_logits.squeeze(-1) - - # If we are on multi-GPU, split add a dimension - if len(start_pos.size()) > 1: - start_pos = start_pos.squeeze(-1) - if len(end_pos.size()) > 1: - end_pos = end_pos.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = aux_start_logits.size(1) # This is a single index - start_pos.clamp_(0, ignored_index) - end_pos.clamp_(0, ignored_index) - - aux_token_loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - aux_start_loss = aux_token_loss_fct(torch.cat((aux_start_logits, aux_end_logits), 1), start_pos) - aux_end_loss = aux_token_loss_fct(torch.cat((aux_end_logits, aux_start_logits), 1), end_pos) - aux_loss = (aux_start_loss + aux_end_loss) / 2.0 - return (aux_loss,) + outputs[2:] - else: - raise Exception("Unknown task_type") - - # TODO: establish proper format in labels already? 
- if inform_slot_id is not None: - inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() - if diag_state is not None: - diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) - - total_loss = 0 - per_slot_per_example_loss = {} - per_slot_class_logits = {} - per_slot_start_logits = {} - per_slot_end_logits = {} - per_slot_refer_logits = {} - for slot in self.slot_list: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) - else: - pooled_output_aux = pooled_output - class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux)) - - token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output)) - start_logits, end_logits = token_logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux)) - - per_slot_class_logits[slot] = class_logits - per_slot_start_logits[slot] = start_logits - per_slot_end_logits[slot] = end_logits - per_slot_refer_logits[slot] = refer_logits - - # If there are no labels, don't compute loss - if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: - # If we are on multi-GPU, split add a dimension - if len(start_pos[slot].size()) > 1: - start_pos[slot] = start_pos[slot].squeeze(-1) - if len(end_pos[slot].size()) > 1: - end_pos[slot] = end_pos[slot].squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) # This is a single index - start_pos[slot].clamp_(0, ignored_index) - end_pos[slot].clamp_(0, ignored_index) - - class_loss_fct = CrossEntropyLoss(reduction='none') - token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index) - refer_loss_fct = CrossEntropyLoss(reduction='none') - - start_loss = token_loss_fct(start_logits, start_pos[slot]) - end_loss = token_loss_fct(end_logits, end_pos[slot]) - token_loss = (start_loss + end_loss) / 2.0 - - token_is_pointable = (start_pos[slot] > 0).float() - if not self.token_loss_for_nonpointable: - token_loss *= token_is_pointable - - refer_loss = refer_loss_fct(refer_logits, refer_id[slot]) - token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() - if not self.refer_loss_for_nonpointable: - refer_loss *= token_is_referrable - - class_loss = class_loss_fct(class_logits, class_label_id[slot]) - - if self.refer_index > -1: - per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss - else: - per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss - - total_loss += per_example_loss.sum() - per_slot_per_example_loss[slot] = per_example_loss - - # add hidden states and attention if they are here - outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + (pooled_output,) + outputs[2:] - - return outputs diff --git 
a/convlab/dst/trippy/multiwoz/modeling_bert_dst.py b/convlab/dst/trippy/multiwoz/modeling_bert_dst.py deleted file mode 100644 index 8dc899344dc7e17883096997782ea2f7bf85d0f6..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/modeling_bert_dst.py +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# -# Copyright 2020 Heinrich Heine University Duesseldorf -# -# Part of this code is based on the source code of BERT-DST -# (arXiv:1907.03040) -# Part of this code is based on the source code of Transformers -# (arXiv:1910.03771) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers import (BertModel, BertPreTrainedModel) - - -class BertForDST(BertPreTrainedModel): - def __init__(self, config): - super(BertForDST, self).__init__(config) - self.slot_list = config.dst_slot_list - self.class_types = config.dst_class_types - self.class_labels = config.dst_class_labels - self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable - self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable - self.stack_token_logits = config.dst_stack_token_logits - self.class_aux_feats_inform = config.dst_class_aux_feats_inform - self.class_aux_feats_ds = config.dst_class_aux_feats_ds - self.class_loss_ratio = config.dst_class_loss_ratio - - # Only use refer loss if refer class is present in dataset. - if 'refer' in self.class_types: - self.refer_index = self.class_types.index('refer') - else: - self.refer_index = -1 - - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.dst_dropout_rate) - self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) - - if self.class_aux_feats_inform: - self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - if self.class_aux_feats_ds: - self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - - aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2 - - for slot in self.slot_list: - self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels)) - self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2)) - self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) - - self.init_weights() - - def forward(self, - input_ids, - input_mask=None, - segment_ids=None, - position_ids=None, - head_mask=None, - start_pos=None, - end_pos=None, - inform_slot_id=None, - refer_id=None, - class_label_id=None, - diag_state=None): - outputs = self.bert( - input_ids, - attention_mask=input_mask, - token_type_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask - ) - - sequence_output = outputs[0] - pooled_output = outputs[1] - - sequence_output = self.dropout(sequence_output) - pooled_output = self.dropout(pooled_output) - - # TODO: establish proper format in labels already? 
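# --- [Editor's note] ----------------------------------------------------------
# Sketch of the dict-of-tensors format that inform_slot_id and diag_state are
# expected to arrive in, and what the torch.stack call just below turns it into.
# Slot names are illustrative; batch size 1.
import torch

slots = ['hotel-area', 'hotel-pricerange', 'train-day']
inform_slot_id = {s: torch.tensor([0]) for s in slots}
inform_slot_id['hotel-area'] = torch.tensor([1])   # system informed this slot

inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
print(inform_labels)                               # tensor([[1., 0., 0.]])
# --- [end note] ----------------------------------------------------------------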
- if inform_slot_id is not None: - inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() - if diag_state is not None: - diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) - - total_loss = 0 - per_slot_per_example_loss = {} - per_slot_class_logits = {} - per_slot_start_logits = {} - per_slot_end_logits = {} - per_slot_refer_logits = {} - for slot in self.slot_list: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) - else: - pooled_output_aux = pooled_output - class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux)) - - token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output)) - start_logits, end_logits = token_logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux)) - - per_slot_class_logits[slot] = class_logits - per_slot_start_logits[slot] = start_logits - per_slot_end_logits[slot] = end_logits - per_slot_refer_logits[slot] = refer_logits - - # If there are no labels, don't compute loss - if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: - # If we are on multi-GPU, split add a dimension - if len(start_pos[slot].size()) > 1: - start_pos[slot] = start_pos[slot].squeeze(-1) - if len(end_pos[slot].size()) > 1: - end_pos[slot] = end_pos[slot].squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) # This is a single index - start_pos[slot].clamp_(0, ignored_index) - end_pos[slot].clamp_(0, ignored_index) - - class_loss_fct = CrossEntropyLoss(reduction='none') - token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index) - refer_loss_fct = CrossEntropyLoss(reduction='none') - - if not self.stack_token_logits: - start_loss = token_loss_fct(start_logits, start_pos[slot]) - end_loss = token_loss_fct(end_logits, end_pos[slot]) - else: - start_loss = token_loss_fct(torch.cat((start_logits, end_logits), 1), start_pos[slot]) - end_loss = token_loss_fct(torch.cat((end_logits, start_logits), 1), end_pos[slot]) - - token_loss = (start_loss + end_loss) / 2.0 - - token_is_pointable = (start_pos[slot] > 0).float() - if not self.token_loss_for_nonpointable: - token_loss *= token_is_pointable - - refer_loss = refer_loss_fct(refer_logits, refer_id[slot]) - token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() - if not self.refer_loss_for_nonpointable: - refer_loss *= token_is_referrable - - class_loss = class_loss_fct(class_logits, class_label_id[slot]) - - if self.refer_index > -1: - per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss - else: - per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss - - total_loss += per_example_loss.sum() - per_slot_per_example_loss[slot] = per_example_loss - - # add hidden states and attention if they are here - outputs = 
(total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + (pooled_output,) + outputs[2:] - - return outputs diff --git a/convlab/dst/trippy/multiwoz/modeling_bert_dst.py~ b/convlab/dst/trippy/multiwoz/modeling_bert_dst.py~ deleted file mode 100644 index 809c26c5b449d2c800d8407e6dc2c561ff253d33..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/modeling_bert_dst.py~ +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# -# Copyright 2020 Heinrich Heine University Duesseldorf -# -# Part of this code is based on the source code of BERT-DST -# (arXiv:1907.03040) -# Part of this code is based on the source code of Transformers -# (arXiv:1910.03771) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers import (BertModel, BertPreTrainedModel) - - -class BertForDST(BertPreTrainedModel): - def __init__(self, config): - super(BertForDST, self).__init__(config) - self.slot_list = config.dst_slot_list - self.class_types = config.dst_class_types - self.class_labels = config.dst_class_labels - self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable - self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable - self.stack_token_logits = config.dst_stack_token_logits - self.class_aux_feats_inform = config.dst_class_aux_feats_inform - self.class_aux_feats_ds = config.dst_class_aux_feats_ds - self.class_loss_ratio = config.dst_class_loss_ratio - - # Only use refer loss if refer class is present in dataset. 
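# --- [Editor's note] ----------------------------------------------------------
# The per-example loss mixing repeated in each of these modeling files, written
# out with toy scalars; the 0.8 ratio is an assumption for illustration only.
import torch

class_loss_ratio = 0.8
class_loss = torch.tensor([1.2])
token_loss = torch.tensor([0.6])
refer_loss = torch.tensor([0.3])

# With a refer class, the non-class share is split evenly over span and refer:
with_refer = (class_loss_ratio * class_loss
              + ((1 - class_loss_ratio) / 2) * token_loss
              + ((1 - class_loss_ratio) / 2) * refer_loss)
# Without one, the span loss takes the full remainder:
without_refer = class_loss_ratio * class_loss + (1 - class_loss_ratio) * token_loss
print(with_refer.item(), without_refer.item())     # approx. 1.05 and 1.08
# --- [end note] ----------------------------------------------------------------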
- if 'refer' in self.class_types: - self.refer_index = self.class_types.index('refer') - else: - self.refer_index = -1 - - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.dst_dropout_rate) - self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) - - if self.class_aux_feats_inform: - self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - if self.class_aux_feats_ds: - self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - - aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2 - - for slot in self.slot_list: - self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels)) - self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2)) - self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) - - self.init_weights() - - def forward(self, - input_ids, - input_mask=None, - segment_ids=None, - position_ids=None, - head_mask=None, - start_pos=None, - end_pos=None, - inform_slot_id=None, - refer_id=None, - class_label_id=None, - diag_state=None): - outputs = self.bert( - input_ids, - attention_mask=input_mask, - token_type_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask - ) - - sequence_output = outputs[0] - pooled_output = outputs[1] - - sequence_output = self.dropout(sequence_output) - pooled_output = self.dropout(pooled_output) - - # TODO: establish proper format in labels already? - if inform_slot_id is not None: - inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() - if diag_state is not None: - diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) - - total_loss = 0 - per_slot_per_example_loss = {} - per_slot_class_logits = {} - per_slot_start_logits = {} - per_slot_end_logits = {} - per_slot_refer_logits = {} - for slot in self.slot_list: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) - else: - pooled_output_aux = pooled_output - class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux)) - - token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output)) - start_logits, end_logits = token_logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux)) - - per_slot_class_logits[slot] = class_logits - per_slot_start_logits[slot] = start_logits - per_slot_end_logits[slot] = end_logits - per_slot_refer_logits[slot] = refer_logits - - # If there are no labels, don't compute loss - if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: - # If we are on multi-GPU, split add a dimension - if len(start_pos[slot].size()) > 1: - start_pos[slot] = start_pos[slot].squeeze(-1) - if len(end_pos[slot].size()) > 1: - end_pos[slot] = end_pos[slot].squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) # This 
is a single index - start_pos[slot].clamp_(0, ignored_index) - end_pos[slot].clamp_(0, ignored_index) - - class_loss_fct = CrossEntropyLoss(reduction='none') - token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index) - refer_loss_fct = CrossEntropyLoss(reduction='none') - - if not self.stack_token_logits: - start_loss = token_loss_fct(start_logits, start_pos[slot]) - end_loss = token_loss_fct(end_logits, end_pos[slot]) - else: - start_loss = token_loss_fct(torch.cat((start_logits, end_logits), 1), start_pos[slot]) - end_loss = token_loss_fct(torch.cat((end_logits, start_logits), 1), end_pos[slot]) - - token_loss = (start_loss + end_loss) / 2.0 - - token_is_pointable = (start_pos[slot] > 0).float() - if not self.token_loss_for_nonpointable: - token_loss *= token_is_pointable - - refer_loss = refer_loss_fct(refer_logits, refer_id[slot]) - token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() - if not self.refer_loss_for_nonpointable: - refer_loss *= token_is_referrable - - class_loss = class_loss_fct(class_logits, class_label_id[slot]) - - if self.refer_index > -1: - per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss - else: - per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss - - total_loss += per_example_loss.sum() - per_slot_per_example_loss[slot] = per_example_loss - - # add hidden states and attention if they are here - outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:] - - return outputs diff --git a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.prev.py b/convlab/dst/trippy/multiwoz/modeling_roberta_dst.prev.py deleted file mode 100644 index 5faf311318714007421242467d2579dfed99bad9..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.prev.py +++ /dev/null @@ -1,237 +0,0 @@ -# coding=utf-8 -# -# Copyright 2020 Heinrich Heine University Duesseldorf -# -# Part of this code is based on the source code of BERT-DST -# (arXiv:1907.03040) -# Part of this code is based on the source code of Transformers -# (arXiv:1910.03771) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch import nn -from torch.nn import CrossEntropyLoss - -#from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable) -#from transformers.modeling_utils import (PreTrainedModel) -#from transformers.modeling_roberta import (RobertaModel, RobertaConfig, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, -# ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING, BertLayerNorm) -from transformers import (RobertaModel, RobertaConfig, RobertaPreTrainedModel) - - -#class RobertaPreTrainedModel(PreTrainedModel): -# """ An abstract class to handle weights initialization and -# a simple interface for dowloading and loading pretrained models. 
-# """ -# config_class = RobertaConfig -# pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP -# base_model_prefix = "roberta" -# -# def _init_weights(self, module): -# """ Initialize the weights """ -# if isinstance(module, (nn.Linear, nn.Embedding)): -# # Slightly different from the TF version which uses truncated_normal for initialization -# # cf https://github.com/pytorch/pytorch/pull/5617 -# module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) -# elif isinstance(module, BertLayerNorm): -# module.bias.data.zero_() -# module.weight.data.fill_(1.0) -# if isinstance(module, nn.Linear) and module.bias is not None: -# module.bias.data.zero_() - - -#@add_start_docstrings( -# """RoBERTa Model with classification heads for the DST task. """, -# ROBERTA_START_DOCSTRING, -#) -class RobertaForDST(RobertaPreTrainedModel): - def __init__(self, config): - super(RobertaForDST, self).__init__(config) - self.slot_list = config.dst_slot_list - self.class_types = config.dst_class_types - self.class_labels = config.dst_class_labels - self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable - self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable - self.class_aux_feats_inform = config.dst_class_aux_feats_inform - self.class_aux_feats_ds = config.dst_class_aux_feats_ds - self.class_loss_ratio = config.dst_class_loss_ratio - - # Only use refer loss if refer class is present in dataset. - if 'refer' in self.class_types: - self.refer_index = self.class_types.index('refer') - else: - self.refer_index = -1 - - self.roberta = RobertaModel(config) - self.dropout = nn.Dropout(config.dst_dropout_rate) - self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) - - if self.class_aux_feats_inform: - self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - if self.class_aux_feats_ds: - self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - - aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2 - - for slot in self.slot_list: - self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels)) - self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2)) - self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) - - # Head for aux task - if hasattr(config, "aux_task_def"): - self.add_module("aux_out_projection", nn.Linear(config.hidden_size, int(config.aux_task_def['n_class']))) - - self.init_weights() - - #@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) - def forward(self, - input_ids, - input_mask=None, - segment_ids=None, - position_ids=None, - head_mask=None, - start_pos=None, - end_pos=None, - inform_slot_id=None, - refer_id=None, - class_label_id=None, - diag_state=None, - aux_task_def=None): - outputs = self.roberta( - input_ids, - attention_mask=input_mask, - token_type_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask - ) - - sequence_output = outputs[0] - pooled_output = outputs[1] - - sequence_output = self.dropout(sequence_output) - pooled_output = self.dropout(pooled_output) - - if aux_task_def is not None: - if aux_task_def['task_type'] == "classification": - aux_logits = getattr(self, 'aux_out_projection')(pooled_output) - aux_logits = self.dropout_heads(aux_logits) - aux_loss_fct = CrossEntropyLoss() - aux_loss = aux_loss_fct(aux_logits, class_label_id) - # add hidden states and attention if they 
are here - return (aux_loss,) + outputs[2:] - elif aux_task_def['task_type'] == "span": - aux_logits = getattr(self, 'aux_out_projection')(sequence_output) - aux_start_logits, aux_end_logits = aux_logits.split(1, dim=-1) - aux_start_logits = self.dropout_heads(aux_start_logits) - aux_end_logits = self.dropout_heads(aux_end_logits) - aux_start_logits = aux_start_logits.squeeze(-1) - aux_end_logits = aux_end_logits.squeeze(-1) - - # If we are on multi-GPU, split add a dimension - if len(start_pos.size()) > 1: - start_pos = start_pos.squeeze(-1) - if len(end_pos.size()) > 1: - end_pos = end_pos.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = aux_start_logits.size(1) # This is a single index - start_pos.clamp_(0, ignored_index) - end_pos.clamp_(0, ignored_index) - - aux_token_loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - aux_start_loss = aux_token_loss_fct(torch.cat((aux_start_logits, aux_end_logits), 1), start_pos) - aux_end_loss = aux_token_loss_fct(torch.cat((aux_end_logits, aux_start_logits), 1), end_pos) - aux_loss = (aux_start_loss + aux_end_loss) / 2.0 - return (aux_loss,) + outputs[2:] - else: - raise Exception("Unknown task_type") - - # TODO: establish proper format in labels already? - if inform_slot_id is not None: - inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() - if diag_state is not None: - diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) - - total_loss = 0 - per_slot_per_example_loss = {} - per_slot_class_logits = {} - per_slot_start_logits = {} - per_slot_end_logits = {} - per_slot_refer_logits = {} - for slot in self.slot_list: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) - else: - pooled_output_aux = pooled_output - class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux)) - - token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output)) - start_logits, end_logits = token_logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux)) - - per_slot_class_logits[slot] = class_logits - per_slot_start_logits[slot] = start_logits - per_slot_end_logits[slot] = end_logits - per_slot_refer_logits[slot] = refer_logits - - # If there are no labels, don't compute loss - if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: - # If we are on multi-GPU, split add a dimension - if len(start_pos[slot].size()) > 1: - start_pos[slot] = start_pos[slot].squeeze(-1) - if len(end_pos[slot].size()) > 1: - end_pos[slot] = end_pos[slot].squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) # This is a single index - start_pos[slot].clamp_(0, ignored_index) - end_pos[slot].clamp_(0, ignored_index) - - class_loss_fct = CrossEntropyLoss(reduction='none') - token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index) - refer_loss_fct = 
CrossEntropyLoss(reduction='none') - - start_loss = token_loss_fct(start_logits, start_pos[slot]) - end_loss = token_loss_fct(end_logits, end_pos[slot]) - token_loss = (start_loss + end_loss) / 2.0 - - token_is_pointable = (start_pos[slot] > 0).float() - if not self.token_loss_for_nonpointable: - token_loss *= token_is_pointable - - refer_loss = refer_loss_fct(refer_logits, refer_id[slot]) - token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() - if not self.refer_loss_for_nonpointable: - refer_loss *= token_is_referrable - - class_loss = class_loss_fct(class_logits, class_label_id[slot]) - - if self.refer_index > -1: - per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss - else: - per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss - - total_loss += per_example_loss.sum() - per_slot_per_example_loss[slot] = per_example_loss - - # add hidden states and attention if they are here - outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + (pooled_output,) + outputs[2:] - - return outputs diff --git a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py b/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py deleted file mode 100644 index f4c2d773f80551965d0181c97c5fba64bbfb06b1..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# -# Copyright 2020 Heinrich Heine University Duesseldorf -# -# Part of this code is based on the source code of BERT-DST -# (arXiv:1907.03040) -# Part of this code is based on the source code of Transformers -# (arXiv:1910.03771) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers import (RobertaModel, RobertaConfig, RobertaPreTrainedModel) - - -class RobertaForDST(RobertaPreTrainedModel): - def __init__(self, config): - super(RobertaForDST, self).__init__(config) - self.slot_list = config.dst_slot_list - self.class_types = config.dst_class_types - self.class_labels = config.dst_class_labels - self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable - self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable - self.stack_token_logits = config.dst_stack_token_logits - self.class_aux_feats_inform = config.dst_class_aux_feats_inform - self.class_aux_feats_ds = config.dst_class_aux_feats_ds - self.class_loss_ratio = config.dst_class_loss_ratio - - # Only use refer loss if refer class is present in dataset. 
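# --- [Editor's note] ----------------------------------------------------------
# How the ignored_index handling in the forward pass below behaves: gold span
# positions outside the model input are clamped to seq_len, a class the
# CrossEntropyLoss is told to ignore. Toy shapes; values are made up.
import torch
from torch.nn import CrossEntropyLoss

seq_len = 8
start_logits = torch.randn(2, seq_len)
start_pos = torch.tensor([3, 100])            # second gold position is out of range
ignored_index = start_logits.size(1)          # == seq_len
start_pos.clamp_(0, ignored_index)            # -> tensor([3, 8])

loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
print(loss_fct(start_logits, start_pos))      # second element is 0. (ignored)
# --- [end note] ----------------------------------------------------------------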
- if 'refer' in self.class_types: - self.refer_index = self.class_types.index('refer') - else: - self.refer_index = -1 - - self.roberta = RobertaModel(config) - self.dropout = nn.Dropout(config.dst_dropout_rate) - self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) - - if self.class_aux_feats_inform: - self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - if self.class_aux_feats_ds: - self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - - aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2 - - for slot in self.slot_list: - self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels)) - self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2)) - self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) - - self.init_weights() - - def forward(self, - input_ids, - input_mask=None, - segment_ids=None, - position_ids=None, - head_mask=None, - start_pos=None, - end_pos=None, - inform_slot_id=None, - refer_id=None, - class_label_id=None, - diag_state=None): - outputs = self.roberta( - input_ids, - attention_mask=input_mask, - token_type_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask - ) - - sequence_output = outputs[0] - pooled_output = outputs[1] - - sequence_output = self.dropout(sequence_output) - pooled_output = self.dropout(pooled_output) - - # TODO: establish proper format in labels already? - if inform_slot_id is not None: - inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() - if diag_state is not None: - diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) - - total_loss = 0 - per_slot_per_example_loss = {} - per_slot_class_logits = {} - per_slot_start_logits = {} - per_slot_end_logits = {} - per_slot_refer_logits = {} - for slot in self.slot_list: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) - else: - pooled_output_aux = pooled_output - class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux)) - - token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output)) - start_logits, end_logits = token_logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux)) - - per_slot_class_logits[slot] = class_logits - per_slot_start_logits[slot] = start_logits - per_slot_end_logits[slot] = end_logits - per_slot_refer_logits[slot] = refer_logits - - # If there are no labels, don't compute loss - if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: - # If we are on multi-GPU, split add a dimension - if len(start_pos[slot].size()) > 1: - start_pos[slot] = start_pos[slot].squeeze(-1) - if len(end_pos[slot].size()) > 1: - end_pos[slot] = end_pos[slot].squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = 
start_logits.size(1) # This is a single index - start_pos[slot].clamp_(0, ignored_index) - end_pos[slot].clamp_(0, ignored_index) - - class_loss_fct = CrossEntropyLoss(reduction='none') - token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index) - refer_loss_fct = CrossEntropyLoss(reduction='none') - - if not self.stack_token_logits: - start_loss = token_loss_fct(start_logits, start_pos[slot]) - end_loss = token_loss_fct(end_logits, end_pos[slot]) - else: - start_loss = token_loss_fct(torch.cat((start_logits, end_logits), 1), start_pos[slot]) - end_loss = token_loss_fct(torch.cat((end_logits, start_logits), 1), end_pos[slot]) - - token_loss = (start_loss + end_loss) / 2.0 - - token_is_pointable = (start_pos[slot] > 0).float() - if not self.token_loss_for_nonpointable: - token_loss *= token_is_pointable - - refer_loss = refer_loss_fct(refer_logits, refer_id[slot]) - token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() - if not self.refer_loss_for_nonpointable: - refer_loss *= token_is_referrable - - class_loss = class_loss_fct(class_logits, class_label_id[slot]) - - if self.refer_index > -1: - per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss - else: - per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss - - total_loss += per_example_loss.sum() - per_slot_per_example_loss[slot] = per_example_loss - - # add hidden states and attention if they are here - outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + (pooled_output,) + outputs[2:] - - return outputs diff --git a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py~ b/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py~ deleted file mode 100644 index cdc2996f727c30bfa7b66e1b8bea2af28514a5a5..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/modeling_roberta_dst.py~ +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# -# Copyright 2020 Heinrich Heine University Duesseldorf -# -# Part of this code is based on the source code of BERT-DST -# (arXiv:1907.03040) -# Part of this code is based on the source code of Transformers -# (arXiv:1910.03771) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
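# --- [Editor's note] ----------------------------------------------------------
# Toy illustration of the stack_token_logits branch above (repeated below):
# when enabled, the start loss is computed over the concatenation of start and
# end logits, so the end positions act as extra distractor classes, and vice
# versa. Values are made up; batch of 1, sequence of 4.
import torch
from torch.nn import CrossEntropyLoss

start_logits = torch.randn(1, 4)
end_logits = torch.randn(1, 4)
start_pos = torch.tensor([2])

loss_fct = CrossEntropyLoss(reduction='none')
plain = loss_fct(start_logits, start_pos)
stacked = loss_fct(torch.cat((start_logits, end_logits), 1), start_pos)
print(plain.item(), stacked.item())
# --- [end note] ----------------------------------------------------------------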
- -import torch -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers import (RobertaModel, RobertaConfig, RobertaPreTrainedModel) - - -class RobertaForDST(RobertaPreTrainedModel): - def __init__(self, config): - super(RobertaForDST, self).__init__(config) - self.slot_list = config.dst_slot_list - self.class_types = config.dst_class_types - self.class_labels = config.dst_class_labels - self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable - self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable - self.stack_token_logits = config.dst_stack_token_logits - self.class_aux_feats_inform = config.dst_class_aux_feats_inform - self.class_aux_feats_ds = config.dst_class_aux_feats_ds - self.class_loss_ratio = config.dst_class_loss_ratio - - # Only use refer loss if refer class is present in dataset. - if 'refer' in self.class_types: - self.refer_index = self.class_types.index('refer') - else: - self.refer_index = -1 - - self.roberta = RobertaModel(config) - self.dropout = nn.Dropout(config.dst_dropout_rate) - self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) - - if self.class_aux_feats_inform: - self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - if self.class_aux_feats_ds: - self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - - aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2 - - for slot in self.slot_list: - self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels)) - self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2)) - self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) - - self.init_weights() - - def forward(self, - input_ids, - input_mask=None, - segment_ids=None, - position_ids=None, - head_mask=None, - start_pos=None, - end_pos=None, - inform_slot_id=None, - refer_id=None, - class_label_id=None, - diag_state=None): - outputs = self.roberta( - input_ids, - attention_mask=input_mask, - token_type_ids=segment_ids, - position_ids=position_ids, - head_mask=head_mask - ) - - sequence_output = outputs[0] - pooled_output = outputs[1] - - sequence_output = self.dropout(sequence_output) - pooled_output = self.dropout(pooled_output) - - # TODO: establish proper format in labels already? 
- if inform_slot_id is not None: - inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() - if diag_state is not None: - diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) - - total_loss = 0 - per_slot_per_example_loss = {} - per_slot_class_logits = {} - per_slot_start_logits = {} - per_slot_end_logits = {} - per_slot_refer_logits = {} - for slot in self.slot_list: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) - else: - pooled_output_aux = pooled_output - class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux)) - - token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output)) - start_logits, end_logits = token_logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux)) - - per_slot_class_logits[slot] = class_logits - per_slot_start_logits[slot] = start_logits - per_slot_end_logits[slot] = end_logits - per_slot_refer_logits[slot] = refer_logits - - # If there are no labels, don't compute loss - if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: - # If we are on multi-GPU, split add a dimension - if len(start_pos[slot].size()) > 1: - start_pos[slot] = start_pos[slot].squeeze(-1) - if len(end_pos[slot].size()) > 1: - end_pos[slot] = end_pos[slot].squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) # This is a single index - start_pos[slot].clamp_(0, ignored_index) - end_pos[slot].clamp_(0, ignored_index) - - class_loss_fct = CrossEntropyLoss(reduction='none') - token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index) - refer_loss_fct = CrossEntropyLoss(reduction='none') - - if not self.stack_token_logits: - start_loss = token_loss_fct(start_logits, start_pos[slot]) - end_loss = token_loss_fct(end_logits, end_pos[slot]) - else: - start_loss = token_loss_fct(torch.cat((start_logits, end_logits), 1), start_pos[slot]) - end_loss = token_loss_fct(torch.cat((end_logits, start_logits), 1), end_pos[slot]) - - token_loss = (start_loss + end_loss) / 2.0 - - token_is_pointable = (start_pos[slot] > 0).float() - if not self.token_loss_for_nonpointable: - token_loss *= token_is_pointable - - refer_loss = refer_loss_fct(refer_logits, refer_id[slot]) - token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() - if not self.refer_loss_for_nonpointable: - refer_loss *= token_is_referrable - - class_loss = class_loss_fct(class_logits, class_label_id[slot]) - - if self.refer_index > -1: - per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss - else: - per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss - - total_loss += per_example_loss.sum() - per_slot_per_example_loss[slot] = per_example_loss - - # add hidden states and attention if they are here - outputs = 
(total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:] - - return outputs diff --git a/convlab/dst/trippy/multiwoz/trippy.bck.py b/convlab/dst/trippy/multiwoz/trippy.bck.py deleted file mode 100644 index 0a41812a3718115faee90099aafdfcc25378bd50..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/trippy.bck.py +++ /dev/null @@ -1,444 +0,0 @@ -# Copyright 2021 Heinrich Heine University Duesseldorf -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import copy - -import torch -from transformers import (BertConfig, BertTokenizer, - RobertaConfig, RobertaTokenizer) - -from convlab2.dst.trippy.multiwoz.modeling_bert_dst import (BertForDST) -from convlab2.dst.trippy.multiwoz.modeling_roberta_dst import (RobertaForDST) - -from convlab2.dst.dst import DST -from convlab2.util.multiwoz.state import default_state -from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA -from convlab2.nlu.jointBERT.multiwoz import BERTNLU - -import pdb - - -MODEL_CLASSES = { - 'bert': (BertConfig, BertForDST, BertTokenizer), - 'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer), -} - - -TEMPLATE_STATE = { - "attraction": { - "type": "", - "name": "", - "area": "" - }, - "hotel": { - "name": "", - "area": "", - "parking": "", - "price range": "", - "stars": "", - "internet": "", - "type": "", - "book stay": "", - "book day": "", - "book people": "" - }, - "restaurant": { - "food": "", - "price range": "", - "name": "", - "area": "", - "book time": "", - "book day": "", - "book people": "" - }, - "taxi": { - "leave at": "", - "destination": "", - "departure": "", - "arrive by": "" - }, - "train": { - "leave at": "", - "destination": "", - "day": "", - "arrive by": "", - "departure": "", - "book people": "" - }, - "hospital": { - "department": "" - } -} - - -SLOT_MAP_TRIPPY_TO_UDF = { - 'hotel': { - 'pricerange': 'price range', - 'book_stay': 'book stay', - 'book_day': 'book day', - 'book_people': 'book people', - 'addr': 'address', - 'post': 'postcode' - }, - 'restaurant': { - 'pricerange': 'price range', - 'book_time': 'book time', - 'book_day': 'book day', - 'book_people': 'book people', - 'addr': 'address', - 'post': 'postcode' - }, - 'taxi': { - 'arriveBy': 'arrive by', - 'leaveAt': 'leave at', - 'arrive': 'arrive by', - 'leave': 'leave at', - 'car': 'type', - 'car type': 'type', - 'depart': 'departure', - 'dest': 'destination' - }, - 'train': { - 'arriveBy': 'arrive by', - 'leaveAt': 'leave at', - 'book_people': 'book people', - 'arrive': 'arrive by', - 'leave': 'leave at', - 'depart': 'departure', - 'dest': 'destination', - 'id': 'train id', - 'people': 'book people', - 'time': 'duration', - 'ticket': 'price', - 'trainid': 'train id' - }, - 'attraction': { - 'post': 'postcode', - 'addr': 'address', - 'fee': 'entrance fee', - 'price': 'price range' - }, - 'general': {}, - 'hospital': { - 'post': 'postcode', - 'addr': 'address' - }, - 'police': { - 'post': 
'postcode', - 'addr': 'address' - } -} - - -class TRIPPY(DST): - def print_header(self): - print(" _________ ________ ___ ________ ________ ___ ___ ") - print("|\___ ___\\\ __ \|\ \|\ __ \|\ __ \|\ \ / /|") - print("\|___ \ \_\ \ \|\ \ \ \ \ \|\ \ \ \|\ \ \ \/ / /") - print(" \ \ \ \ \ _ _\ \ \ \ ____\ \ ____\ \ / / ") - print(" \ \ \ \ \ \\\ \\\ \ \ \ \___|\ \ \___|\/ / / ") - print(" \ \__\ \ \__\\\ _\\\ \__\ \__\ \ \__\ __/ / / ") - print(" \|__| \|__|\|__|\|__|\|__| \|__||\___/ / ") - print(" (c) 2022 Heinrich Heine University \|___|/ ") - print() - - def __init__(self, model_type="roberta", model_name="roberta-base", model_path="", nlu_path=""): - super(TRIPPY, self).__init__() - - self.print_header() - - self.model_type = model_type.lower() - self.model_name = model_name.lower() - self.model_path = model_path - self.nlu_path = nlu_path - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type] - self.config = self.config_class.from_pretrained(self.model_path) - # TODO: update config (parameters) - - self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} - - self.load_weights() - - def load_weights(self): - self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name) # TODO: do_lower_case=args.do_lower_case ? - self.model = self.model_class.from_pretrained(self.model_path, config=self.config) - self.model.to(self.device) - self.model.eval() - self.nlu = BERTNLU(model_file=self.nlu_path) # TODO: remove, once TripPy takes over its task - - def init_session(self): - self.state = default_state() # Initialise as empty state - self.state['belief_state'] = copy.deepcopy(TEMPLATE_STATE) - # TODO: define internal variables here as well that are tracked but not forwarded - - def update(self, user_act=''): - prev_state = self.state - - # TODO: add asserts to check format. if wrong, print suggested config - #print("--") - - # --- Get inform memory and auxiliary features --- - - # If system_action is plain text, get acts using NLU - if isinstance(prev_state['system_action'], str): - acts, _ = self.get_acts(prev_state['system_action']) - elif isinstance(prev_state['system_action'], list): - acts = prev_state['system_action'] - else: - raise Exception('Unknown format for system action:', prev_state['system_action']) - inform_aux, inform_mem = self.get_inform_aux(acts) - printed_inform_mem = False - for s in inform_mem: - if inform_mem[s] != 'none': - if not printed_inform_mem: - print("DST: inform_mem:") - print(s, ':', inform_mem[s]) - printed_inform_mem = True - - # --- Tokenize dialogue context and feed DST model --- - - features = self.get_features(self.state['history'], ds_aux=self.ds_aux, inform_aux=inform_aux) - pred_states, cls_representation = self.predict(features, inform_mem) - - # --- Update ConvLab-style dialogue state --- - - new_belief_state = copy.deepcopy(prev_state['belief_state']) - user_acts = [] - for state, value in pred_states.items(): - if value == 'none': - continue - domain, slot = state.split('-', 1) - # TODO: value normalizations? - if domain == 'hotel' and slot == 'type': - value = "hotel" if value == "yes" else "guesthouse" - # TODO: needed? 
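# --- [Editor's note] ----------------------------------------------------------
# Sketch of the slot-name mapping applied just below: TripPy-internal names are
# translated to the ConvLab belief-state names via SLOT_MAP_TRIPPY_TO_UDF,
# falling back to the original name when no entry exists. The mini-map here is
# a trimmed stand-in for the full table defined above.
SLOT_MAP = {'hotel': {'pricerange': 'price range', 'book_stay': 'book stay'}}

domain, slot, value = 'hotel', 'pricerange', 'cheap'
udf_slot = SLOT_MAP[domain].get(slot, slot)        # -> 'price range'
belief_state = {'hotel': {'price range': '', 'book stay': ''}}
if udf_slot in belief_state[domain]:
    belief_state[domain][udf_slot] = value
print(belief_state['hotel'])                       # {'price range': 'cheap', 'book stay': ''}
# --- [end note] ----------------------------------------------------------------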
- if domain not in new_belief_state: - if domain == 'bus': - continue - else: - raise Exception('Domain <{}> not in belief state'.format(domain)) - slot = SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot) - if slot in new_belief_state[domain]: - new_belief_state[domain][slot] = value # TODO: value normalization? - user_acts.append(['inform', domain, SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot), value]) # TODO: value normalization? - else: - raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain)) - - # TODO: only for debugging for now! - if re.search("not.*book.*", prev_state['user_action']) is not None: - user_acts.append(['inform', 'train', 'notbook', 'none']) - - # Update request_state - #new_request_state = copy.deepcopy(prev_state['request_state']) - - new_state = copy.deepcopy(dict(prev_state)) - new_state['belief_state'] = new_belief_state - #new_state['request_state'] = new_request_state - - # Get requestable slots from NLU, until implemented in DST. - nlu_user_acts, nlu_system_acts = self.get_acts(user_act) - for e in nlu_user_acts: - nlu_a, nlu_d, nlu_s, nlu_v = e - nlu_a = nlu_a.lower() - nlu_d = nlu_d.lower() - nlu_s = nlu_s.lower() - nlu_v = nlu_v.lower() - if nlu_a != 'inform': - user_acts.append([nlu_a, nlu_d, SLOT_MAP_TRIPPY_TO_UDF[nlu_d].get(nlu_s, nlu_s), nlu_v]) - new_state['system_action'] = nlu_system_acts # Empty when DST for user -> needed? - new_state['user_action'] = user_acts - - #new_state['cls_representation'] = cls_representation # TODO: needed by Nunu? - - self.state = new_state - - # (Re)set internal states - if self.state['terminated']: - #print("!!! RESET DS_AUX") - self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} - else: - self.ds_aux = self.update_ds_aux(self.state['belief_state']) - #print("ds:", [self.ds_aux[s][0].item() for s in self.ds_aux]) - - #print("--") - return self.state - - def predict(self, features, inform_mem): - with torch.no_grad(): - outputs = self.model(input_ids=features['input_ids'], - input_mask=features['attention_mask'], - inform_slot_id=features['inform_slot_id'], - diag_state=features['diag_state']) - - input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked! 
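# --- [Editor's note] ----------------------------------------------------------
# For orientation, a stand-in for the tuple unpacked just below; the layout
# matches the return value assembled at the end of the modeling files deleted
# above, while the shapes (batch 1, seq 8, 9 classes, 3 slots) are assumptions.
import torch

slots = ['hotel-area', 'hotel-pricerange', 'train-day']
outputs = (torch.tensor(0.),                       # [0] total_loss
           {s: torch.zeros(1) for s in slots},     # [1] per-slot per-example loss
           {s: torch.zeros(1, 9) for s in slots},  # [2] class logits
           {s: torch.zeros(1, 8) for s in slots},  # [3] start logits
           {s: torch.zeros(1, 8) for s in slots},  # [4] end logits
           {s: torch.zeros(1, 4) for s in slots},  # [5] refer logits (n_slots + 1)
           torch.zeros(1, 768))                    # [6] pooled CLS representation
print(int(outputs[2]['hotel-area'][0].argmax()))   # class prediction for one slot
# --- [end note] ----------------------------------------------------------------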
- - #total_loss = outputs[0] - #per_slot_per_example_loss = outputs[1] - per_slot_class_logits = outputs[2] - per_slot_start_logits = outputs[3] - per_slot_end_logits = outputs[4] - per_slot_refer_logits = outputs[5] - - cls_representation = outputs[6] - - # TODO: maybe add assert to check that batch=1 - - predictions = {slot: 'none' for slot in self.config.dst_slot_list} - - for slot in self.config.dst_slot_list: - class_logits = per_slot_class_logits[slot][0] - start_logits = per_slot_start_logits[slot][0] - end_logits = per_slot_end_logits[slot][0] - refer_logits = per_slot_refer_logits[slot][0] - - class_prediction = int(class_logits.argmax()) - start_prediction = int(start_logits.argmax()) - end_prediction = int(end_logits.argmax()) - refer_prediction = int(refer_logits.argmax()) - - if class_prediction == self.config.dst_class_types.index('dontcare'): - predictions[slot] = 'dontcare' - elif class_prediction == self.config.dst_class_types.index('copy_value'): - predictions[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1]) - predictions[slot] = re.sub("(^| )##", "", predictions[slot]) - if "\u0120" in predictions[slot]: - predictions[slot] = re.sub(" ", "", predictions[slot]) - predictions[slot] = re.sub("\u0120", " ", predictions[slot]) - predictions[slot] = predictions[slot].strip() - elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'): - predictions[slot] = "yes" # 'true' - elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'): - predictions[slot] = "no" # 'false' - elif class_prediction == self.config.dst_class_types.index('inform'): - #print("INFORM:", slot, ",", predictions[slot], "->", inform_mem[slot]) - predictions[slot] = inform_mem[slot] - # Referral case is handled below - - # Referral case. All other slot values need to be seen first in order - # to be able to do this correctly. - for slot in self.config.dst_slot_list: - class_logits = per_slot_class_logits[slot][0] - refer_logits = per_slot_refer_logits[slot][0] - - class_prediction = int(class_logits.argmax()) - refer_prediction = int(refer_logits.argmax()) - - if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'): - # Only slots that have been mentioned before can be referred to. - # One can think of a situation where one slot is referred to in the same utterance. - # This phenomenon is however currently not properly covered in the training data - # label generation process. - predictions[slot] = predictions[self.config.dst_slot_list[refer_prediction - 1]] - - if class_prediction > 0: - print(" ", slot, "->", class_prediction, ",", predictions[slot]) - - return predictions, cls_representation - - def get_features(self, context, ds_aux=None, inform_aux=None): - assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models - input_tokens = ['<s>'] - for e_itr, e in enumerate(reversed(context)): - #input_tokens.append(e[1].lower() if e[1] != 'null' else ' ') # TODO: normalise text - input_tokens.append(e[1] if e[1] != 'null' else ' ') # TODO: normalise text - if e_itr < 2: - input_tokens.append('</s> </s>') - if e_itr == 0: - input_tokens.append('</s> </s>') - input_tokens.append('</s>') - input_tokens = ' '.join(input_tokens) - - # TODO: delex sys utt somehow, or refrain from using delex for sys utts? 
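# --- [Editor's note] ----------------------------------------------------------
# What the serialization loop above produces for a short history; the
# utterances are invented, but the separator pattern ('<s>' first, doubled
# '</s>' after the two most recent turns) is the one in the code above.
context = [['sys', 'Which area do you prefer?'],
           ['usr', 'Somewhere in the centre.']]
input_tokens = ['<s>']
for e_itr, e in enumerate(reversed(context)):
    input_tokens.append(e[1] if e[1] != 'null' else ' ')
    if e_itr < 2:
        input_tokens.append('</s> </s>')
    if e_itr == 0:
        input_tokens.append('</s> </s>')
input_tokens.append('</s>')
print(' '.join(input_tokens))
# <s> Somewhere in the centre. </s> </s> </s> </s> Which area do you prefer? </s> </s> </s>
# --- [end note] ----------------------------------------------------------------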
- features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=False, max_length=self.config.dst_max_seq_length) - - input_ids = torch.tensor(features['input_ids']).reshape(1,-1).to(self.device) - attention_mask = torch.tensor(features['attention_mask']).reshape(1,-1).to(self.device) - features = {'input_ids': input_ids, - 'attention_mask': attention_mask, - 'inform_slot_id': inform_aux, - 'diag_state': ds_aux} - - return features - - def update_ds_aux(self, state, terminated=False): - ds_aux = copy.deepcopy(self.ds_aux) # TODO: deepcopy necessary? just update class variable? - for slot in self.config.dst_slot_list: - d, s = slot.split('-') - ds_aux[slot][0] = int(state[d][SLOT_MAP_TRIPPY_TO_UDF[d].get(s, s)] != '') - return ds_aux - - # TODO: consider "booked" values? - def get_inform_aux(self, state): - inform_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} - inform_mem = {slot: 'none' for slot in self.config.dst_slot_list} - for e in state: - #print(e) - #pdb.set_trace() - a, d, s, v = e - if a in ['inform', 'recommend', 'select', 'book', 'offerbook']: - #ds_d = d.lower() - #if s in REF_SYS_DA[d]: - # ds_s = REF_SYS_DA[d][s] - #elif s in REF_SYS_DA['Booking']: - # ds_s = "book_" + REF_SYS_DA['Booking'][s] - #else: - # ds_s = s.lower() - # #raise Exception('Slot <{}> of domain <{}> unknown'.format(s, d)) - slot = "%s-%s" % (d, s) - if slot in inform_aux: - inform_aux[slot][0] = 1 - inform_mem[slot] = v - return inform_aux, inform_mem - - # TODO: fix, still a mess... - def get_acts(self, user_act): - context = self.state['history'] - if context: - if context[-1][0] != 'sys': - system_act = '' - context = [t for s,t in context] - else: - system_act = context[-1][-1] - context = [t for s,t in context[:-1]] - else: - system_act = '' - context = [''] - - #print(" SYS:", system_act, context) - system_acts = self.nlu.predict(system_act, context=context) - - context.append(system_act) - #print(" USR:", user_act, context) - user_acts = self.nlu.predict(user_act, context=context) - - return user_acts, system_acts - - -# if __name__ == "__main__": -# tracker = TRIPPY(model_type='roberta', model_path='/path/to/model', -# nlu_path='/path/to/nlu') -# tracker.init_session() -# state = tracker.update('hey. I need a cheap restaurant.') -# tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.']) -# tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?']) -# state = tracker.update('If you have something Asian that would be great.') -# tracker.state['history'].append(['usr', 'If you have something Asian that would be great.']) -# tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.']) -# state = tracker.update('Great. Where are they located?') -# tracker.state['history'].append(['usr', 'Great. Where are they located?']) -# print(tracker.state) diff --git a/convlab/dst/trippy/multiwoz/trippy.py b/convlab/dst/trippy/multiwoz/trippy.py deleted file mode 100644 index 6bebf35341e4284ca6c911a7f797149b6c265069..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/trippy.py +++ /dev/null @@ -1,653 +0,0 @@ -# Copyright 2021 Heinrich Heine University Duesseldorf -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import re -import json -import copy - -import torch -from transformers import (BertConfig, BertTokenizer, - RobertaConfig, RobertaTokenizer) - -from convlab.dst.trippy.multiwoz.modeling_bert_dst import (BertForDST) -from convlab.dst.trippy.multiwoz.modeling_roberta_dst import (RobertaForDST) - -from convlab.dst.dst import DST -from convlab.util.multiwoz.state import default_state -from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA -from convlab.nlu.jointBERT.multiwoz import BERTNLU -from convlab.nlg.template.multiwoz import TemplateNLG -from convlab.dst.rule.multiwoz.dst_util import normalize_value - -from convlab.util import relative_import_module_from_unified_datasets -ONTOLOGY = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'ontology') -TEMPLATE_STATE = ONTOLOGY['state'] - -import pdb -import time - - -MODEL_CLASSES = { - 'bert': (BertConfig, BertForDST, BertTokenizer), - 'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer), -} - - -SLOT_MAP_TRIPPY_TO_UDF = { - 'hotel': { - 'pricerange': 'price range', - 'book_stay': 'book stay', - 'book_day': 'book day', - 'book_people': 'book people', - 'addr': 'address', - 'post': 'postcode', - 'price': 'price range', - 'people': 'book people' - }, - 'restaurant': { - 'pricerange': 'price range', - 'book_time': 'book time', - 'book_day': 'book day', - 'book_people': 'book people', - 'addr': 'address', - 'post': 'postcode', - 'price': 'price range', - 'people': 'book people' - }, - 'taxi': { - 'arriveBy': 'arrive by', - 'leaveAt': 'leave at', - 'arrive': 'arrive by', - 'leave': 'leave at', - 'car': 'type', - 'car type': 'type', - 'depart': 'departure', - 'dest': 'destination' - }, - 'train': { - 'arriveBy': 'arrive by', - 'leaveAt': 'leave at', - 'book_people': 'book people', - 'arrive': 'arrive by', - 'leave': 'leave at', - 'depart': 'departure', - 'dest': 'destination', - 'id': 'train id', - 'people': 'book people', - 'time': 'duration', - 'ticket': 'price', - 'trainid': 'train id' - }, - 'attraction': { - 'post': 'postcode', - 'addr': 'address', - 'fee': 'entrance fee', - 'price': 'entrance fee' - }, - 'general': {}, - 'hospital': { - 'post': 'postcode', - 'addr': 'address' - }, - 'police': { - 'post': 'postcode', - 'addr': 'address' - } -} - - -class TRIPPY(DST): - def print_header(self): - print(" _________ ________ ___ ________ ________ ___ ___ ") - print("|\___ ___\\\ __ \|\ \|\ __ \|\ __ \|\ \ / /|") - print("\|___ \ \_\ \ \|\ \ \ \ \ \|\ \ \ \|\ \ \ \/ / /") - print(" \ \ \ \ \ _ _\ \ \ \ ____\ \ ____\ \ / / ") - print(" \ \ \ \ \ \\\ \\\ \ \ \ \___|\ \ \___|\/ / / ") - print(" \ \__\ \ \__\\\ _\\\ \__\ \__\ \ \__\ __/ / / ") - print(" \|__| \|__|\|__|\|__|\|__| \|__||\___/ / ") - print(" (c) 2022 Heinrich Heine University \|___|/ ") - print() - - def print_dialog(self, hst): - #print("Dialogue %s, turn %s:" % (self.global_diag_cnt, int(len(hst) / 2) - 1)) - print("Dialogue %s, turn %s:" % (self.global_diag_cnt, self.global_turn_cnt)) - for utt in hst[:-2]: - print(" \033[92m%s\033[0m" % (utt)) - if len(hst) > 1: - print(" ", hst[-2]) - print(" ", hst[-1]) - 
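# Illustrative sketch (assumption, not original code): how SLOT_MAP_TRIPPY_TO_UDF
# above is used throughout this class. TripPy-internal slot names map to the
# unified data format, and unmapped names fall back to themselves.
def to_udf_slot(domain, slot):
    # Mirrors the recurring `SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot)` idiom.
    return SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot)

assert to_udf_slot('hotel', 'pricerange') == 'price range'
assert to_udf_slot('hotel', 'area') == 'area'  # pass-through for unmapped names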
- def print_inform_memory(self, inform_mem): - print("Inform memory:") - is_all_none = True - for s in inform_mem: - if inform_mem[s] != 'none': - print(" %s = %s" % (s, inform_mem[s])) - is_all_none = False - if is_all_none: - print(" -") - - def eval_user_acts(self, user_act, user_acts): - print("User acts:") - for ua in user_acts: - if ua not in user_act: - print(" \033[33m%s\033[0m" % (ua)) - else: - print(" \033[92m%s\033[0m" % (ua)) - for ua in user_act: - if ua not in user_acts: - print(" \033[91m%s\033[0m" % (ua)) - - def eval_dialog_state(self, state_updates, new_belief_state): - print("Dialogue state:") - for d in self.gt_belief_state: - print(" %s:" % (d)) - for s in new_belief_state[d]: - is_printed = False - is_updated = False - if state_updates[d][s] > 0: - is_updated = True - if is_updated: - print("\033[3m", end='') - if new_belief_state[d][s] != self.gt_belief_state[d][s]: - self.global_eval_stats[d][s]['FP'] += 1 - if self.gt_belief_state[d][s] == '': - print(" \033[33m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='') - else: - print(" \033[91m%s: %s\033[0m (label: %s)" % (s, new_belief_state[d][s] if new_belief_state[d][s] != '' else 'none', self.gt_belief_state[d][s]), end='') - self.global_eval_stats[d][s]['FN'] += 1 - is_printed = True - elif new_belief_state[d][s] != '': - print(" \033[92m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='') - self.global_eval_stats[d][s]['TP'] += 1 - is_printed = True - if is_updated: - print(" (%s)" % (self.config.dst_class_types[state_updates[d][s]])) - elif is_printed: - print() - - def eval_print_stats(self): - print("Statistics:") - for d in self.global_eval_stats: - for s in self.global_eval_stats[d]: - TP = self.global_eval_stats[d][s]['TP'] - FP = self.global_eval_stats[d][s]['FP'] - FN = self.global_eval_stats[d][s]['FN'] - prec = TP / ( TP + FP + 1e-8) - rec = TP / ( TP + FN + 1e-8) - f1 = 2 * ((prec * rec) / (prec + rec + 1e-8)) - print(" %s %s Recall: %.2f, Precision: %.2f, F1: %.2f" % (d, s, rec, prec, f1)) - - def __init__(self, model_type="roberta", - model_name="roberta-base", - model_path="", - nlu_path="", - no_eval=False, - no_history=False, - no_normalize_value=False, - gt_user_acts=False, - gt_ds=False, - gt_request_acts=False): - super(TRIPPY, self).__init__() - - self.print_header() - - self.model_type = model_type.lower() - self.model_name = model_name.lower() - self.model_path = model_path - self.nlu_path = nlu_path - self.no_eval = no_eval - self.no_history = no_history - self.no_normalize_value = no_normalize_value - self.gt_user_acts = gt_user_acts - self.gt_ds = gt_ds - self.gt_request_acts = gt_request_acts - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type] - self.config = self.config_class.from_pretrained(self.model_path, local_files_only=True) # TODO: parameterize - # TODO: update config (parameters) - - # For debugging only - path = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) - path = os.path.join(path, 'data/multiwoz/value_dict.json') - self.value_dict = json.load(open(path)) - - self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} - - self.global_eval_stats = copy.deepcopy(TEMPLATE_STATE) - for d in self.global_eval_stats: - for s in self.global_eval_stats[d]: - self.global_eval_stats[d][s] = {'TP': 0, 'FP': 0, 'FN': 0} - self.global_diag_cnt = -3 - 
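# Illustrative sketch (not original code) of the smoothed per-slot scores that
# eval_print_stats above reports; the 1e-8 terms only guard against division
# by zero for slots that never occur.
def smoothed_prf1(tp, fp, fn, eps=1e-8):
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * ((prec * rec) / (prec + rec + eps))
    return prec, rec, f1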
self.global_turn_cnt = -1 - - self.load_weights() - - def load_weights(self): - self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name, local_files_only=True) # TODO: do_lower_case=args.do_lower_case ? # TODO: parameterize - self.model = self.model_class.from_pretrained(self.model_path, config=self.config, local_files_only=True) # TODO: parameterize - self.model.to(self.device) - self.model.eval() - self.nlu = BERTNLU(model_file=self.nlu_path) # This is used for internal evaluation - self.nlg_usr = TemplateNLG(is_user=True) - self.nlg_sys = TemplateNLG(is_user=False) - - def init_session(self): - self.state = default_state() # Initialise as empty state - self.state['belief_state'] = copy.deepcopy(TEMPLATE_STATE) - self.nlg_history = [] - self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} - self.gt_belief_state = copy.deepcopy(TEMPLATE_STATE) - self.global_diag_cnt += 1 - self.global_turn_cnt = -1 - - def update_gt_belief_state(self, user_act): - for intent, domain, slot, value in user_act: - if domain == 'police': - continue - if intent == 'inform': - if slot == 'none' or slot == '': - continue - domain_dic = self.gt_belief_state[domain] - if slot in domain_dic: - #nvalue = normalize_value(self.value_dict, domain, slot, value) - self.gt_belief_state[domain][slot] = value # nvalue - #elif slot != 'none' or slot != '': - # raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain)) - - # TODO: receive semantic, convert semantic -> text -> semantic for sanity check - # For TripPy: receive semantic, convert semantic -> text (with context) as input to DST - # - allows for accuracy estimates - # - allows isolating inform prediction from request prediction (as can be taken from input for sanity check) - def update(self, user_act=''): - def normalize_values(text): - text_to_num = {"zero": "0", "one": "1", "me": "1", "two": "2", "three": "3", "four": "4", "five": "5", "six": "6", "seven": "7"} - #text = re.sub("^(\d{2}) : (\d{2})$", r"\1:\2", text) # Times - #text = re.sub(" ?' ?s", "s", text) # Genitive - text = re.sub("\s*(\W)\s*", r"\1" , text) # Re-attach special characters - text = re.sub("s'([^s])", r"s' \1", text) # Add space after plural genitive apostrophe - if text in text_to_num: - text = text_to_num[text] - return text - - prev_state = self.state - - if not self.no_eval: - print("-" * 40) - - #nlg_history = [] - ##for h in prev_state['history'][-2:]: # TODO: make this an option? 
- #for h in prev_state['history']: - # nlg_history.append([h[0], self.get_text(h[1], is_user=(h[0]=='user'))]) - ## Special case: at the beginning of the dialog the history might be empty (depending on policy) - #if len(nlg_history) == 0: - # nlg_history.append(['sys', self.get_text(prev_state['system_action'], is_user=False)]) - # nlg_history.append(['user', self.get_text(prev_state['user_action'], is_user=True)]) - if self.no_history: - self.nlg_history = [] - self.nlg_history.append(['sys', self.get_text(prev_state['system_action'], is_user=False, normalize=True)]) - self.nlg_history.append(['user', self.get_text(prev_state['user_action'], is_user=True, normalize=True)]) - self.global_turn_cnt += 1 - if not self.no_eval: - self.print_dialog(self.nlg_history) - - # --- Get inform memory and auxiliary features --- - - # If system_action is plain text, get acts using NLU - if isinstance(prev_state['user_action'], str): - u_acts, s_acts = self.get_acts() - elif isinstance(prev_state['user_action'], list): - u_acts = user_act # same as prev_state['user_action'] - s_acts = prev_state['system_action'] - else: - raise Exception('Unknown format for user action:', prev_state['user_action']) - inform_aux, inform_mem = self.get_inform_aux(s_acts) - if not self.no_eval: - self.print_inform_memory(inform_mem) - - # --- Tokenize dialogue context and feed DST model --- - - ##features = self.get_features(self.state['history'], ds_aux=self.ds_aux, inform_aux=inform_aux) - used_ds_aux = None if not self.config.dst_class_aux_feats_ds else self.ds_aux - used_inform_aux = None if not self.config.dst_class_aux_feats_inform else inform_aux - features = self.get_features(self.nlg_history, ds_aux=used_ds_aux, inform_aux=used_inform_aux) - pred_states, pred_classes, cls_representation = self.predict(features, inform_mem) - - # --- Update ConvLab-style dialogue state --- - - new_belief_state = copy.deepcopy(prev_state['belief_state']) - user_acts = [] - for state, value in pred_states.items(): - value = normalize_values(value) - if value == 'none': - continue - domain, slot = state.split('-', 1) - # Value normalization # TODO: according to trippy rules? 
- if domain == 'hotel' and slot == 'type': - value = "hotel" if value == "yes" else "guesthouse" - if not self.no_normalize_value: - value = normalize_value(self.value_dict, domain, slot, value) - slot = SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot) - if slot in new_belief_state[domain]: - new_belief_state[domain][slot] = value - user_acts.append(['inform', domain, SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot), value]) - else: - raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain)) - - self.update_gt_belief_state(u_acts) # For evaluation - - # BELIEF STATE UPDATE - new_state = copy.deepcopy(dict(prev_state)) - new_state['belief_state'] = new_belief_state # TripPy - if self.gt_ds: - new_state['belief_state'] = self.gt_belief_state # Rule - - state_updates = {} - for cl in pred_classes: - cl_d, cl_s = cl.split('-') - # Some reformatting for the evaluation further down - if cl_d not in state_updates: - state_updates[cl_d] = {} - state_updates[cl_d][SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s)] = pred_classes[cl] - # We care only about the requestable slots here - if self.config.dst_class_types[pred_classes[cl]] != 'request': - continue - if cl_d != 'general' and cl_s == 'none': - user_acts.append(['inform', cl_d, '', '']) - elif cl_d == 'general': - user_acts.append([SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), 'general', '', '']) - #user_acts.append(['bye', 'general', '', '']) # Map "thank" to "bye"? Mind "hello" as well! - elif not self.gt_request_acts: - user_acts.append(['request', cl_d, SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), '']) - - # TODO: For debugging -> doesn't make a difference - #for e in user_act: - # nlu_a, nlu_d, nlu_s, nlu_v = e - # nlu_a = nlu_a.lower() - # nlu_d = nlu_d.lower() - # nlu_s = nlu_s.lower() - # nlu_v = nlu_v.lower() - # # Mostly requestables - # if nlu_a == 'inform' and nlu_d == 'train' and nlu_s == 'notbook': - # user_acts.append([nlu_a, nlu_d, 'NotBook', 'none']) - - # TODO: fix # TODO: still needed? - if 0: - domain = '' - is_inform = False - is_request = False - is_notbook = False - for act in user_act: - _, _, slot, _ = act - if slot == "NotBook": - is_notbook = True - for act in user_acts: - intent, domain, slot, value = act - if intent == 'inform': - is_inform = True - if intent == 'request': - is_request = True - if is_inform and not is_request and not is_notbook and domain != '' and domain != "general": - user_acts = [['inform', domain, '', '']] + user_acts - - # USER ACTS UPDATE - new_state['user_action'] = user_acts # TripPy - # ONLY FOR DEBUGGING - if self.gt_user_acts: - new_state['user_action'] = u_acts # Rule - elif self.gt_request_acts: - for e in u_acts: - ea, _, _, _ = e - if ea == 'request': - user_acts.append(e) - - if not self.no_eval: - self.eval_user_acts(u_acts, user_acts) - self.eval_dialog_state(state_updates, new_belief_state) - - #new_state['cls_representation'] = cls_representation # TODO: needed by Nunu? 
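# Illustrative sketch (format assumed from the code above): user_acts collects
# ConvLab-style [intent, domain, slot, value] tuples, for example:
example_user_acts = [
    ['inform', 'hotel', 'price range', 'cheap'],  # from a predicted slot value
    ['request', 'hotel', 'postcode', ''],         # from a 'request' slot gate
    ['thank', 'general', '', ''],                 # from a general-domain gate
]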
- - self.state = new_state - - # Print eval statistics - if self.state['terminated'] and not self.no_eval: - print("Booked:", self.state['booked']) - self.eval_print_stats() - print("=" * 10, "End of the dialogue", "=" * 10) - #self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} - #else: - self.ds_aux = self.update_ds_aux(self.state['belief_state'], pred_states) - #print("ds:", [self.ds_aux[s][0].item() for s in self.ds_aux]) - - return self.state - - def predict(self, features, inform_mem): - #aaa_time = time.time() - with torch.no_grad(): - outputs = self.model(input_ids=features['input_ids'], - input_mask=features['attention_mask'], - inform_slot_id=features['inform_slot_id'], - diag_state=features['diag_state']) - #bbb_time = time.time() - - input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked! - - #total_loss = outputs[0] - #per_slot_per_example_loss = outputs[1] - per_slot_class_logits = outputs[2] - per_slot_start_logits = outputs[3] - per_slot_end_logits = outputs[4] - per_slot_refer_logits = outputs[5] - - cls_representation = outputs[6] - - # TODO: maybe add assert to check that batch=1 - - predictions = {slot: 'none' for slot in self.config.dst_slot_list} - class_predictions = {slot: 0 for slot in self.config.dst_slot_list} - - for slot in self.config.dst_slot_list: - class_logits = per_slot_class_logits[slot][0] - start_logits = per_slot_start_logits[slot][0] - end_logits = per_slot_end_logits[slot][0] - refer_logits = per_slot_refer_logits[slot][0] - - class_prediction = int(class_logits.argmax()) - start_prediction = int(start_logits.argmax()) - end_prediction = int(end_logits.argmax()) - refer_prediction = int(refer_logits.argmax()) - - if class_prediction == self.config.dst_class_types.index('dontcare'): - predictions[slot] = 'dontcare' - elif class_prediction == self.config.dst_class_types.index('copy_value'): - predictions[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1]) - predictions[slot] = re.sub("(^| )##", "", predictions[slot]) - if "\u0120" in predictions[slot]: - predictions[slot] = re.sub(" ", "", predictions[slot]) - predictions[slot] = re.sub("\u0120", " ", predictions[slot]) - predictions[slot] = predictions[slot].strip() - elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'): - predictions[slot] = "yes" # 'true' - elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'): - predictions[slot] = "no" # 'false' - elif class_prediction == self.config.dst_class_types.index('inform'): - #print("INFORM:", slot, ",", predictions[slot], "->", inform_mem[slot]) - predictions[slot] = inform_mem[slot] - # Referral case is handled below - - # Referral case. All other slot values need to be seen first in order - # to be able to do this correctly. - for slot in self.config.dst_slot_list: - class_logits = per_slot_class_logits[slot][0] - refer_logits = per_slot_refer_logits[slot][0] - - class_prediction = int(class_logits.argmax()) - refer_prediction = int(refer_logits.argmax()) - - if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'): - # Only slots that have been mentioned before can be referred to. - # First try to resolve a reference within the same turn. (One can think of a situation - # where one slot is referred to in the same utterance. 
This phenomenon is however - # currently not properly covered in the training data label generation process) - # Then try to resolve a reference given the current dialogue state. - predictions[slot] = predictions[self.config.dst_slot_list[refer_prediction - 1]] - if predictions[slot] == 'none': - referred_slot = self.config.dst_slot_list[refer_prediction - 1] - referred_slot_d, referred_slot_s = referred_slot.split('-') - referred_slot_s = SLOT_MAP_TRIPPY_TO_UDF[referred_slot_d].get(referred_slot_s, referred_slot_s) - if self.state['belief_state'][referred_slot_d][referred_slot_s] != '': - predictions[slot] = self.state['belief_state'][referred_slot_d][referred_slot_s] - if predictions[slot] == 'none': - ref_slot = self.config.dst_slot_list[refer_prediction - 1] - if ref_slot == 'hotel-name': - predictions[slot] = 'the hotel' - elif ref_slot == 'restaurant-name': - predictions[slot] = 'the restaurant' - elif ref_slot == 'attraction-name': - predictions[slot] = 'the attraction' - elif ref_slot == 'hotel-area': - predictions[slot] = 'same area as the hotel' - elif ref_slot == 'restaurant-area': - predictions[slot] = 'same area as the restaurant' - elif ref_slot == 'attraction-area': - predictions[slot] = 'same area as the attraction' - elif ref_slot == 'hotel-pricerange': - predictions[slot] = 'in the same price range as the hotel' - elif ref_slot == 'restaurant-pricerange': - predictions[slot] = 'in the same price range as the restaurant' - - class_predictions[slot] = class_prediction - #if class_prediction > 0: - # print(" ", slot, "->", class_prediction, ",", predictions[slot]) - #ccc_time = time.time() - #print("TIME:", bbb_time - aaa_time, ccc_time - bbb_time) - - return predictions, class_predictions, cls_representation - - def get_features(self, context, ds_aux=None, inform_aux=None): - assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models - input_tokens = ['<s>'] - e_itr = 0 - for e_itr, e in enumerate(reversed(context)): - #input_tokens.append(e[1].lower() if e[1] != 'null' else ' ') # TODO: normalise text - input_tokens.append(e[1] if e[1] != 'null' else ' ') # TODO: normalise text - if e_itr < 2: - input_tokens.append('</s> </s>') - if e_itr == 0: - input_tokens.append('</s> </s>') - input_tokens.append('</s>') - input_tokens = ' '.join(input_tokens) - - # TODO: delex sys utt somehow, or refrain from using delex for sys utts? - features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=False, max_length=self.config.dst_max_seq_length) - - input_ids = torch.tensor(features['input_ids']).reshape(1,-1).to(self.device) - attention_mask = torch.tensor(features['attention_mask']).reshape(1,-1).to(self.device) - features = {'input_ids': input_ids, - 'attention_mask': attention_mask, - 'inform_slot_id': inform_aux, - 'diag_state': ds_aux} - - return features - - def update_ds_aux(self, state, pred_states, terminated=False): - ds_aux = copy.deepcopy(self.ds_aux) # TODO: deepcopy necessary? just update class variable? - for slot in self.config.dst_slot_list: - d, s = slot.split('-') - if d in state and s in state[d]: - ds_aux[slot][0] = int(state[d][SLOT_MAP_TRIPPY_TO_UDF[d].get(s, s)] != '') - else: - # Requestable slots are not found in the DS - ds_aux[slot][0] = int(pred_states[slot] != 'none') - return ds_aux - - # TODO: consider "booked" values? 
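# Illustrative sketch (assumed semantics, not original code) of what
# get_inform_aux below derives from system acts: only inform-like acts
# populate the memory that 'inform' class predictions later copy from.
example_system_acts = [['inform', 'hotel', 'pricerange', 'cheap'],
                       ['request', 'hotel', 'area', '?']]
example_inform_mem = {'%s-%s' % (d, s): v
                      for a, d, s, v in example_system_acts
                      if a in ['inform', 'recommend', 'select', 'book', 'offerbook']}
assert example_inform_mem == {'hotel-pricerange': 'cheap'}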
- def get_inform_aux(self, state): - inform_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} - inform_mem = {slot: 'none' for slot in self.config.dst_slot_list} - for e in state: - #print(e) - #pdb.set_trace() - a, d, s, v = e - if a in ['inform', 'recommend', 'select', 'book', 'offerbook']: - #ds_d = d.lower() - #if s in REF_SYS_DA[d]: - # ds_s = REF_SYS_DA[d][s] - #elif s in REF_SYS_DA['Booking']: - # ds_s = "book_" + REF_SYS_DA['Booking'][s] - #else: - # ds_s = s.lower() - # #raise Exception('Slot <{}> of domain <{}> unknown'.format(s, d)) - slot = "%s-%s" % (d, s) - if slot in inform_aux: - inform_aux[slot][0] = 1 - inform_mem[slot] = v - return inform_aux, inform_mem - - def get_acts(self): - context = self.state['history'] - if context[-1][0] != 'user': - raise Exception("Wrong order of utterances, check your input.") - system_act = context[-2][-1] - user_act = context[-1][-1] - system_context = [t for s,t in context[:-2]] - user_context = [t for s,t in context[:-1]] - - #print(" SYS:", system_act, system_context) - system_acts = self.nlu.predict(system_act, context=system_context) - - #print(" USR:", user_act, user_context) - user_acts = self.nlu.predict(user_act, context=user_context) - - return user_acts, system_acts - - def get_text(self, act, is_user=False, normalize=False): - if act == 'null': - return 'null' - if not isinstance(act, list): - result = act - elif is_user: - result = self.nlg_usr.generate(act) - else: - result = self.nlg_sys.generate(act) - if normalize: - return self.normalize_text(result) - else: - return result - - def normalize_text(self, text): - norm_text = text.lower() - #norm_text = re.sub("n't", " not", norm_text) # Does not make much of a difference - #norm_text = re.sub("ca not", "cannot", norm_text) - norm_text = ' '.join([tok for tok in map(str.strip, re.split("(\W+)", norm_text)) if len(tok) > 0]) - return norm_text - - -# if __name__ == "__main__": -# tracker = TRIPPY(model_type='roberta', model_path='/path/to/model', -# nlu_path='/path/to/nlu') -# tracker.init_session() -# state = tracker.update('hey. I need a cheap restaurant.') -# tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.']) -# tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?']) -# state = tracker.update('If you have something Asian that would be great.') -# tracker.state['history'].append(['usr', 'If you have something Asian that would be great.']) -# tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.']) -# state = tracker.update('Great. Where are they located?') -# tracker.state['history'].append(['usr', 'Great. Where are they located?']) -# print(tracker.state) diff --git a/convlab/dst/trippy/multiwoz/trippy.py~ b/convlab/dst/trippy/multiwoz/trippy.py~ deleted file mode 100644 index fd1317036efafbf9c0b9ed21ba1aa7ea15955934..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/trippy.py~ +++ /dev/null @@ -1,547 +0,0 @@ -# Copyright 2021 Heinrich Heine University Duesseldorf -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import copy - -import torch -from transformers import (BertConfig, BertTokenizer, - RobertaConfig, RobertaTokenizer) - -from convlab.dst.trippy.multiwoz.modeling_bert_dst import (BertForDST) -from convlab.dst.trippy.multiwoz.modeling_roberta_dst import (RobertaForDST) - -from convlab.dst.dst import DST -from convlab.util.multiwoz.state import default_state -from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA -from convlab.nlu.jointBERT.multiwoz import BERTNLU -from convlab.nlg.template.multiwoz import TemplateNLG - -from convlab.util import relative_import_module_from_unified_datasets -ONTOLOGY = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'ontology') -TEMPLATE_STATE = ONTOLOGY['state'] - -# For debugging only -from convlab.dst.rule.multiwoz.dst_util import normalize_value -import os -import json - -import pdb - - -MODEL_CLASSES = { - 'bert': (BertConfig, BertForDST, BertTokenizer), - 'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer), -} - - -SLOT_MAP_TRIPPY_TO_UDF = { - 'hotel': { - 'pricerange': 'price range', - 'book_stay': 'book stay', - 'book_day': 'book day', - 'book_people': 'book people', - 'addr': 'address', - 'post': 'postcode', - 'price': 'price range', - 'people': 'book people' - }, - 'restaurant': { - 'pricerange': 'price range', - 'book_time': 'book time', - 'book_day': 'book day', - 'book_people': 'book people', - 'addr': 'address', - 'post': 'postcode', - 'price': 'price range', - 'people': 'book people' - }, - 'taxi': { - 'arriveBy': 'arrive by', - 'leaveAt': 'leave at', - 'arrive': 'arrive by', - 'leave': 'leave at', - 'car': 'type', - 'car type': 'type', - 'depart': 'departure', - 'dest': 'destination' - }, - 'train': { - 'arriveBy': 'arrive by', - 'leaveAt': 'leave at', - 'book_people': 'book people', - 'arrive': 'arrive by', - 'leave': 'leave at', - 'depart': 'departure', - 'dest': 'destination', - 'id': 'train id', - 'people': 'book people', - 'time': 'duration', - 'ticket': 'price', - 'trainid': 'train id' - }, - 'attraction': { - 'post': 'postcode', - 'addr': 'address', - 'fee': 'entrance fee', - 'price': 'entrance fee' - }, - 'general': {}, - 'hospital': { - 'post': 'postcode', - 'addr': 'address' - }, - 'police': { - 'post': 'postcode', - 'addr': 'address' - } -} - - -class TRIPPY(DST): - def print_header(self): - print(" _________ ________ ___ ________ ________ ___ ___ ") - print("|\___ ___\\\ __ \|\ \|\ __ \|\ __ \|\ \ / /|") - print("\|___ \ \_\ \ \|\ \ \ \ \ \|\ \ \ \|\ \ \ \/ / /") - print(" \ \ \ \ \ _ _\ \ \ \ ____\ \ ____\ \ / / ") - print(" \ \ \ \ \ \\\ \\\ \ \ \ \___|\ \ \___|\/ / / ") - print(" \ \__\ \ \__\\\ _\\\ \__\ \__\ \ \__\ __/ / / ") - print(" \|__| \|__|\|__|\|__|\|__| \|__||\___/ / ") - print(" (c) 2022 Heinrich Heine University \|___|/ ") - print() - - def print_dialog(self, hst): - print("Dialogue turn %s:" % (int(len(hst) / 2) - 1)) - for utt in hst[:-2]: - print(" \033[92m%s\033[0m" % (utt)) - if len(hst) > 1: - print(" ", hst[-2]) - print(" ", hst[-1]) - - def print_inform_memory(self, inform_mem): - print("Inform memory:") - is_all_none = True - for s in inform_mem: - if inform_mem[s] != 'none': - print(" %s = %s" % (s, inform_mem[s])) - is_all_none = False - if is_all_none: - print(" -") - - def eval_user_acts(self, user_act, user_acts): - print("User acts:") - for ua in user_acts: - if ua not in user_act: - print(" \033[33m%s\033[0m" % (ua)) - else: - 
print(" \033[92m%s\033[0m" % (ua)) - for ua in user_act: - if ua not in user_acts: - print(" \033[91m%s\033[0m" % (ua)) - - def eval_dialog_state(self, state_updates, new_belief_state): - print("Dialogue state:") - for d in self.gt_belief_state: - print(" %s:" % (d)) - for s in new_belief_state[d]: - is_printed = False - is_updated = False - if state_updates[d][s] > 0: - is_updated = True - if is_updated: - print("\033[3m", end='') - if new_belief_state[d][s] != self.gt_belief_state[d][s]: - if self.gt_belief_state[d][s] == '': - print(" \033[33m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='') - else: - print(" \033[91m%s: %s\033[0m (label: %s)" % (s, new_belief_state[d][s] if new_belief_state[d][s] != '' else 'none', self.gt_belief_state[d][s]), end='') - is_printed = True - elif new_belief_state[d][s] != '': - print(" \033[92m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='') - is_printed = True - if is_updated: - print(" (%s)" % (self.config.dst_class_types[state_updates[d][s]])) - elif is_printed: - print() - - def __init__(self, model_type="roberta", model_name="roberta-base", model_path="", nlu_path=""): - super(TRIPPY, self).__init__() - - self.print_header() - - self.model_type = model_type.lower() - self.model_name = model_name.lower() - self.model_path = model_path - self.nlu_path = nlu_path - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type] - self.config = self.config_class.from_pretrained(self.model_path) - # TODO: update config (parameters) - - # For debugging only - path = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) - path = os.path.join(path, 'data/multiwoz/value_dict.json') - self.value_dict = json.load(open(path)) - - self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} - - self.load_weights() - - def load_weights(self): - self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name) # TODO: do_lower_case=args.do_lower_case ? 
- self.model = self.model_class.from_pretrained(self.model_path, config=self.config) - self.model.to(self.device) - self.model.eval() - self.nlu = BERTNLU(model_file=self.nlu_path) # TODO: remove, once TripPy takes over its task - self.nlg_usr = TemplateNLG(is_user=True) - self.nlg_sys = TemplateNLG(is_user=False) - - def init_session(self): - self.state = default_state() # Initialise as empty state - self.state['belief_state'] = copy.deepcopy(TEMPLATE_STATE) - self.gt_belief_state = copy.deepcopy(TEMPLATE_STATE) - - def update_gt_belief_state(self, user_act): - print(user_act) - for intent, domain, slot, value in user_act: - if domain == 'police': - continue - if intent == 'inform': - if slot == 'none' or slot == '': - continue - domain_dic = self.gt_belief_state[domain] - if slot in domain_dic: - #nvalue = normalize_value(self.value_dict, domain, slot, value) - self.gt_belief_state[domain][slot] = value # nvalue - #elif slot != 'none' or slot != '': - # raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain)) - - # TODO: receive semantic, convert semantic -> text -> semantic for sanity check - # For TripPy: receive semantic, convert semantic -> text (with context) as input to DST - # - allows for accuracy estimates - # - allows isolating inform prediction from request prediction (as can be taken from input for sanity check) - def update(self, user_act=''): - prev_state = self.state - - user_act2 = self.get_acts(user_act) - - print("-" * 40) - - nlg_history = [] - #for h in prev_state['history'][-2:]: # TODO: make this an option? - for h in prev_state['history']: - nlg_history.append([h[0], self.get_text(h[1], is_user=(h[0]=='user'))]) - self.print_dialog(nlg_history) - - # --- Get inform memory and auxiliary features --- - - # If system_action is plain text, get acts using NLU - if isinstance(prev_state['system_action'], str): - acts, _ = self.get_acts(prev_state['system_action']) - elif isinstance(prev_state['system_action'], list): - acts = prev_state['system_action'] - else: - raise Exception('Unknown format for system action:', prev_state['system_action']) - inform_aux, inform_mem = self.get_inform_aux(acts) - self.print_inform_memory(inform_mem) - - # --- Tokenize dialogue context and feed DST model --- - - ##features = self.get_features(self.state['history'], ds_aux=self.ds_aux, inform_aux=inform_aux) - used_ds_aux = None if not self.config.dst_class_aux_feats_ds else self.ds_aux - used_inform_aux = None if not self.config.dst_class_aux_feats_inform else inform_aux - features = self.get_features(nlg_history, ds_aux=used_ds_aux, inform_aux=used_inform_aux) - pred_states, class_preds, cls_representation = self.predict(features, inform_mem) - - # --- Update ConvLab-style dialogue state --- - - new_belief_state = copy.deepcopy(prev_state['belief_state']) - user_acts = [] - for state, value in pred_states.items(): - if value == 'none': - continue - domain, slot = state.split('-', 1) - # TODO: value normalizations? - if domain == 'hotel' and slot == 'type': - value = "hotel" if value == "yes" else "guesthouse" - orig_slot = slot - slot = SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot) - if slot in new_belief_state[domain]: - new_belief_state[domain][slot] = value # TODO: value normalization? - user_acts.append(['inform', domain, SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot), value]) # TODO: value normalization? 
- else: - raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain)) - - self.update_gt_belief_state(user_act2) # For evaluation - - # BELIEF STATE UPDATE - new_state = copy.deepcopy(dict(prev_state)) - new_state['belief_state'] = new_belief_state # TripPy - #new_state['belief_state'] = self.gt_belief_state # Rule - - state_updates = {} - for cl in class_preds: - cl_d, cl_s = cl.split('-') - # Some reformatting for the evaluation further down - if cl_d not in state_updates: - state_updates[cl_d] = {} - state_updates[cl_d][SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s)] = class_preds[cl] - # We care only about the requestable slots here - if self.config.dst_class_types[class_preds[cl]] != 'request': - continue - if cl_d != 'general' and cl_s == 'none': - user_acts.append(['inform', cl_d, '', '']) - elif cl_d == 'general': - user_acts.append([SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), 'general', '', '']) - #user_acts.append(['bye', 'general', '', '']) # Map "thank" to "bye"? Mind "hello" as well! - else: - user_acts.append(['request', cl_d, SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), '']) - - # TODO: For debugging -> doesn't make a difference - #for e in user_act: - # nlu_a, nlu_d, nlu_s, nlu_v = e - # nlu_a = nlu_a.lower() - # nlu_d = nlu_d.lower() - # nlu_s = nlu_s.lower() - # nlu_v = nlu_v.lower() - # # Mostly requestables - # if nlu_a == 'inform' and nlu_d == 'train' and nlu_s == 'notbook': - # user_acts.append([nlu_a, nlu_d, 'NotBook', 'none']) - - # TODO: fix # TODO: still needed? - if 0: - domain = '' - is_inform = False - is_request = False - is_notbook = False - for act in user_act: - _, _, slot, _ = act - if slot == "NotBook": - is_notbook = True - for act in user_acts: - intent, domain, slot, value = act - if intent == 'inform': - is_inform = True - if intent == 'request': - is_request = True - if is_inform and not is_request and not is_notbook and domain != '' and domain != "general": - user_acts = [['inform', domain, '', '']] + user_acts - - # USER ACTS UPDATE - new_state['user_action'] = user_acts # TripPy - #new_state['user_action'] = user_act # Rule - - self.eval_user_acts(user_act2, user_acts) - self.eval_dialog_state(state_updates, new_belief_state) - - #new_state['cls_representation'] = cls_representation # TODO: needed by Nunu? - - self.state = new_state - - # (Re)set internal states - if self.state['terminated']: - print("=" * 15, "End of dialog", "=" * 15) - self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} - else: - self.ds_aux = self.update_ds_aux(self.state['belief_state'], pred_states) - #print("ds:", [self.ds_aux[s][0].item() for s in self.ds_aux]) - - return self.state - - def predict(self, features, inform_mem): - with torch.no_grad(): - outputs = self.model(input_ids=features['input_ids'], - input_mask=features['attention_mask'], - inform_slot_id=features['inform_slot_id'], - diag_state=features['diag_state']) - - input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked! 
- - #total_loss = outputs[0] - #per_slot_per_example_loss = outputs[1] - per_slot_class_logits = outputs[2] - per_slot_start_logits = outputs[3] - per_slot_end_logits = outputs[4] - per_slot_refer_logits = outputs[5] - - cls_representation = outputs[6] - - # TODO: maybe add assert to check that batch=1 - - predictions = {slot: 'none' for slot in self.config.dst_slot_list} - class_predictions = {slot: 0 for slot in self.config.dst_slot_list} - - for slot in self.config.dst_slot_list: - class_logits = per_slot_class_logits[slot][0] - start_logits = per_slot_start_logits[slot][0] - end_logits = per_slot_end_logits[slot][0] - refer_logits = per_slot_refer_logits[slot][0] - - class_prediction = int(class_logits.argmax()) - start_prediction = int(start_logits.argmax()) - end_prediction = int(end_logits.argmax()) - refer_prediction = int(refer_logits.argmax()) - - if class_prediction == self.config.dst_class_types.index('dontcare'): - predictions[slot] = 'dontcare' - elif class_prediction == self.config.dst_class_types.index('copy_value'): - predictions[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1]) - predictions[slot] = re.sub("(^| )##", "", predictions[slot]) - if "\u0120" in predictions[slot]: - predictions[slot] = re.sub(" ", "", predictions[slot]) - predictions[slot] = re.sub("\u0120", " ", predictions[slot]) - predictions[slot] = predictions[slot].strip() - elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'): - predictions[slot] = "yes" # 'true' - elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'): - predictions[slot] = "no" # 'false' - elif class_prediction == self.config.dst_class_types.index('inform'): - #print("INFORM:", slot, ",", predictions[slot], "->", inform_mem[slot]) - predictions[slot] = inform_mem[slot] - # Referral case is handled below - - # Referral case. All other slot values need to be seen first in order - # to be able to do this correctly. - # TODO: right now, resolution is only attempted within single turn. Consider previous state instead! - for slot in self.config.dst_slot_list: - class_logits = per_slot_class_logits[slot][0] - refer_logits = per_slot_refer_logits[slot][0] - - class_prediction = int(class_logits.argmax()) - refer_prediction = int(refer_logits.argmax()) - - if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'): - # Only slots that have been mentioned before can be referred to. - # One can think of a situation where one slot is referred to in the same utterance. - # This phenomenon is however currently not properly covered in the training data - # label generation process. 
- predictions[slot] = predictions[self.config.dst_slot_list[refer_prediction - 1]] - - class_predictions[slot] = class_prediction - #if class_prediction > 0: - # print(" ", slot, "->", class_prediction, ",", predictions[slot]) - - return predictions, class_predictions, cls_representation - - def get_features(self, context, ds_aux=None, inform_aux=None): - assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models - input_tokens = ['<s>'] - e_itr = 0 - for e_itr, e in enumerate(reversed(context)): - #input_tokens.append(e[1].lower() if e[1] != 'null' else ' ') # TODO: normalise text - input_tokens.append(e[1] if e[1] != 'null' else ' ') # TODO: normalise text - if e_itr < 2: - input_tokens.append('</s> </s>') - if e_itr == 0: - input_tokens.append('</s> </s>') - input_tokens.append('</s>') - input_tokens = ' '.join(input_tokens) - - # TODO: delex sys utt somehow, or refrain from using delex for sys utts? - features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=False, max_length=self.config.dst_max_seq_length) - - input_ids = torch.tensor(features['input_ids']).reshape(1,-1).to(self.device) - attention_mask = torch.tensor(features['attention_mask']).reshape(1,-1).to(self.device) - features = {'input_ids': input_ids, - 'attention_mask': attention_mask, - 'inform_slot_id': inform_aux, - 'diag_state': ds_aux} - - return features - - def update_ds_aux(self, state, pred_states, terminated=False): - ds_aux = copy.deepcopy(self.ds_aux) # TODO: deepcopy necessary? just update class variable? - for slot in self.config.dst_slot_list: - d, s = slot.split('-') - if d in state and s in state[d]: - ds_aux[slot][0] = int(state[d][SLOT_MAP_TRIPPY_TO_UDF[d].get(s, s)] != '') - else: - # Requestable slots are not found in the DS - ds_aux[slot][0] = int(pred_states[slot] != 'none') - return ds_aux - - # TODO: consider "booked" values? - def get_inform_aux(self, state): - inform_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} - inform_mem = {slot: 'none' for slot in self.config.dst_slot_list} - for e in state: - #print(e) - #pdb.set_trace() - a, d, s, v = e - if a in ['inform', 'recommend', 'select', 'book', 'offerbook']: - #ds_d = d.lower() - #if s in REF_SYS_DA[d]: - # ds_s = REF_SYS_DA[d][s] - #elif s in REF_SYS_DA['Booking']: - # ds_s = "book_" + REF_SYS_DA['Booking'][s] - #else: - # ds_s = s.lower() - # #raise Exception('Slot <{}> of domain <{}> unknown'.format(s, d)) - slot = "%s-%s" % (d, s) - if slot in inform_aux: - inform_aux[slot][0] = 1 - inform_mem[slot] = v - return inform_aux, inform_mem - - # TODO: fix, still a mess... 
- def get_acts(self, user_act): - context = self.state['history'] - if context: - if context[-1][0] != 'sys': - system_act = '' - context = [t for s,t in context] - else: - system_act = context[-1][-1] - context = [t for s,t in context[:-1]] - else: - system_act = '' - context = [''] - - #print(" SYS:", system_act, context) - system_acts = [] # self.nlu.predict(system_act, context=context) - context = [''] - - context.append(system_act) - #print(" USR:", user_act, context) - user_acts = self.nlu.predict(user_act, context=context) - - return user_acts, system_acts - - def get_text(self, user_act, is_user=False): - if user_act == 'null': - return 'null' - if not isinstance(user_act, list): - return user_act - if is_user: - return self.nlg_usr.generate(user_act) - else: - return self.nlg_sys.generate(user_act) - - -# if __name__ == "__main__": -# tracker = TRIPPY(model_type='roberta', model_path='/path/to/model', -# nlu_path='/path/to/nlu') -# tracker.init_session() -# state = tracker.update('hey. I need a cheap restaurant.') -# tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.']) -# tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?']) -# state = tracker.update('If you have something Asian that would be great.') -# tracker.state['history'].append(['usr', 'If you have something Asian that would be great.']) -# tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.']) -# state = tracker.update('Great. Where are they located?') -# tracker.state['history'].append(['usr', 'Great. Where are they located?']) -# print(tracker.state) diff --git a/convlab/dst/trippy/multiwoz/trippypublic b/convlab/dst/trippy/multiwoz/trippypublic deleted file mode 160000 index a4e3f675b1f075a697e774833aec72f8148000b0..0000000000000000000000000000000000000000 --- a/convlab/dst/trippy/multiwoz/trippypublic +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a4e3f675b1f075a697e774833aec72f8148000b0 diff --git a/convlab/dst/trippyr/__init__.py b/convlab/dst/trippyr/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/convlab/dst/trippyr/multiwoz/__init__.py b/convlab/dst/trippyr/multiwoz/__init__.py deleted file mode 100644 index b2b9444a46ae879cbd3670b067e3712a4e34c43b..0000000000000000000000000000000000000000 --- a/convlab/dst/trippyr/multiwoz/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from convlab.dst.trippyr.multiwoz.trippyr import TRIPPYR diff --git a/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py b/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py deleted file mode 100644 index 36e1a1970b0a46ffd26cf5d2a4752f3178d0c748..0000000000000000000000000000000000000000 --- a/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py +++ /dev/null @@ -1,825 +0,0 @@ -# coding=utf-8 -# -# Copyright 2020 Heinrich Heine University Duesseldorf -# -# Part of this code is based on the source code of BERT-DST -# (arXiv:1907.03040) -# Part of this code is based on the source code of Transformers -# (arXiv:1910.03771) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch import nn -from torch.nn import CrossEntropyLoss -from torch.nn import TripletMarginLoss -from torch.nn import PairwiseDistance -from torch.nn import MultiheadAttention -import torch.nn.functional as F - -from transformers import (RobertaModel, RobertaConfig, RobertaPreTrainedModel) - -import time - - -class RobertaForDST(RobertaPreTrainedModel): - def __init__(self, config): - super(RobertaForDST, self).__init__(config) - self.slot_list = config.dst_slot_list - self.noncategorical = config.dst_noncategorical - self.categorical = config.dst_categorical - self.class_types = config.dst_class_types - self.class_labels = config.dst_class_labels - self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable - self.value_loss_for_nonpointable = config.dst_value_loss_for_nonpointable - self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable - self.stack_token_logits = False # config.dst_stack_token_logits # TODO - self.class_aux_feats_inform = config.dst_class_aux_feats_inform - self.class_aux_feats_ds = config.dst_class_aux_feats_ds - self.class_loss_ratio = config.dst_class_loss_ratio - self.slot_attention_heads = config.dst_slot_attention_heads - self.dropout_rate = config.dst_dropout_rate - self.heads_dropout_rate = config.dst_heads_dropout_rate - - self.debug_fix_slot_embs = config.debug_fix_slot_embs - self.debug_joint_slot_gate = config.debug_joint_slot_gate - self.debug_joint_refer_gate = config.debug_joint_refer_gate - self.debug_simple_joint_slot_gate = config.debug_simple_joint_slot_gate - self.debug_separate_seq_tagging = config.debug_separate_seq_tagging - self.debug_sigmoid_sequence_tagging = config.debug_sigmoid_sequence_tagging - self.debug_att_output = config.debug_att_output - self.debug_tanh_for_att_output = config.debug_tanh_for_att_output - self.debug_stack_att = config.debug_stack_att - self.debug_stack_rep = config.debug_stack_rep - self.debug_sigmoid_slot_gates = config.debug_sigmoid_slot_gates - self.debug_att_slot_gates = config.debug_att_slot_gates - self.debug_use_triplet_loss = config.debug_use_triplet_loss - self.debug_use_tlf = config.debug_use_tlf - self.debug_value_att_none_class = config.debug_value_att_none_class - self.debug_tag_none_target = config.debug_tag_none_target - self.triplet_loss_weight = config.triplet_loss_weight - self.none_weight = config.none_weight - self.pretrain_loss_function = config.pretrain_loss_function - self.token_loss_function = config.token_loss_function - self.value_loss_function = config.value_loss_function - self.class_loss_function = config.class_loss_function - self.sequence_tagging_dropout = -1 - if config.sequence_tagging_dropout > 0.0: - self.sequence_tagging_dropout = int(1 / config.sequence_tagging_dropout) - self.ignore_o_tags = config.ignore_o_tags - self.debug_slot_embs_per_step_nograd = config.debug_slot_embs_per_step_nograd - try: - self.debug_use_cross_attention = config.debug_use_cross_attention - except: - self.debug_use_cross_attention = False - - self.val_rep_mode = config.val_rep_mode # TODO: for debugging at the moment. 
- - config.output_hidden_states = True # TODO - - #config.dst_dropout_rate = 0.0 # TODO: for debugging - #config.dst_heads_dropout_rate = 0.0 # TODO: for debugging - - self.roberta = RobertaModel(config) - self.dropout = nn.Dropout(config.dst_dropout_rate) - self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) - - # -- Dialogue state tracking functionality -- - - # Only use refer loss if refer class is present in dataset. - if 'refer' in self.class_types: - self.refer_index = self.class_types.index('refer') - else: - self.refer_index = -1 - - if self.class_aux_feats_inform: - self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - if self.class_aux_feats_ds: - self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - - aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2 - - # Slot specific gates - for slot in self.slot_list: - if not self.debug_joint_slot_gate: - if self.debug_stack_att: - if self.debug_sigmoid_slot_gates: - for cl in range(self.class_labels): - self.add_module("class_" + slot + "_" + str(cl), nn.Linear(config.hidden_size * 2 + aux_dims, 1)) - elif self.debug_att_slot_gates: - self.add_module("class_" + slot, MultiheadAttention(config.hidden_size * 2 + aux_dims, self.slot_attention_heads)) - else: - self.add_module("class_" + slot, nn.Linear(config.hidden_size * 2 + aux_dims, self.class_labels)) - else: - if self.debug_sigmoid_slot_gates: - for cl in range(self.class_labels): - self.add_module("class_" + slot + "_" + str(cl), nn.Linear(config.hidden_size + aux_dims, 1)) - elif self.debug_att_slot_gates: - self.add_module("class_" + slot, MultiheadAttention(config.hidden_size + aux_dims, self.slot_attention_heads)) - else: - self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels)) - #self.add_module("token_" + slot, nn.Linear(config.hidden_size, 1)) - if not self.debug_joint_refer_gate: - self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) - if self.debug_separate_seq_tagging: - if self.debug_sigmoid_sequence_tagging: - self.add_module("token_" + slot, nn.Linear(config.hidden_size, 1)) - else: - self.add_module("token_" + slot, MultiheadAttention(config.hidden_size, self.slot_attention_heads)) - - if self.debug_att_output: - self.class_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) - # Conditioned sequence tagging - if not self.debug_separate_seq_tagging: - if self.debug_sigmoid_sequence_tagging: - self.h1t = nn.Linear(config.hidden_size, config.hidden_size) - self.h2t = nn.Linear(config.hidden_size * 2, config.hidden_size * 2) - self.llt = nn.Linear(config.hidden_size * 2, 1) - else: - self.token_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) - if self.debug_joint_refer_gate: - self.refer_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) - self.value_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) - #self.slot_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) - - self.token_layer_norm = nn.LayerNorm(config.hidden_size) - self.token_layer_norm2 = nn.LayerNorm(config.hidden_size) - self.class_layer_norm = nn.LayerNorm(config.hidden_size) - self.class_layer_norm2 = nn.LayerNorm(config.hidden_size) - - self.tlinear = nn.Linear(config.hidden_size, config.hidden_size) - self.clinear = nn.Linear(config.hidden_size, config.hidden_size) - - # 
Conditioned slot gate - self.h1c = nn.Linear(config.hidden_size + aux_dims, config.hidden_size) - if self.debug_stack_att: - self.h0c = nn.Linear(config.hidden_size, config.hidden_size) - self.h2c = nn.Linear(config.hidden_size * 3, config.hidden_size * 3) - if self.debug_sigmoid_slot_gates: - for cl in range(self.class_labels): - self.add_module("llc_" + str(cl), nn.Linear(config.hidden_size * 3, 1)) - elif self.debug_att_slot_gates: - if self.debug_att_output and self.debug_simple_joint_slot_gate: - self.llc = MultiheadAttention(config.hidden_size * 2, self.slot_attention_heads) - else: - self.llc = MultiheadAttention(config.hidden_size * 3, self.slot_attention_heads) - else: - if self.debug_att_output and self.debug_simple_joint_slot_gate: - self.llc = nn.Linear(config.hidden_size * 2, self.class_labels) - else: - self.llc = nn.Linear(config.hidden_size * 3, self.class_labels) - else: - if self.debug_att_slot_gates: - self.h2c = nn.Linear(config.hidden_size * 2, config.hidden_size * 1) - else: - self.h2c = nn.Linear(config.hidden_size * 2, config.hidden_size * 2) - if self.debug_sigmoid_slot_gates: - for cl in range(self.class_labels): - self.add_module("llc_" + str(cl), nn.Linear(config.hidden_size * 2, 1)) - elif self.debug_att_slot_gates: - if self.debug_att_output and self.debug_simple_joint_slot_gate: - self.llc = MultiheadAttention(config.hidden_size * 1, self.slot_attention_heads) - else: - self.llc = MultiheadAttention(config.hidden_size * 2, self.slot_attention_heads) - else: - if self.debug_att_output and self.debug_simple_joint_slot_gate: - self.llc = nn.Linear(config.hidden_size * 1, self.class_labels) - else: - self.llc = nn.Linear(config.hidden_size * 2, self.class_labels) - - # Conditioned refer gate - self.h2r = nn.Linear(config.hidden_size * 2, config.hidden_size * 1) - - # -- Spanless sequence tagging functionality -- - - self.dis = PairwiseDistance(p=2) - self.sigmoid = nn.Sigmoid() - self.tanh = nn.Tanh() - self.relu = nn.ReLU() - self.gelu = nn.GELU() - self.binary_cross_entropy = F.binary_cross_entropy - self.triplet_loss = TripletMarginLoss(margin=1.0, reduction="none") - self.mse = nn.MSELoss(reduction="none") - self.refer_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=len(self.slot_list)) # Ignore 'none' target - self.value_loss_fct = CrossEntropyLoss(reduction='none') - - if self.none_weight != 1.0: - none_weight = self.none_weight - weight_mass = none_weight + (self.class_labels - 1) - none_weight /= weight_mass - other_weights = 1 / weight_mass - #self.clweights = torch.tensor([none_weight] + [other_weights] * (self.class_labels - 1)) # .to(outputs[0].device) - self.clweights = torch.tensor([other_weights] * self.class_labels) # .to(outputs[0].device) - self.clweights[self.class_types.index('none')] = none_weight - if self.debug_sigmoid_slot_gates: - self.class_loss_fct = F.binary_cross_entropy # (reduction="none") # weights for classes are controlled when adding losses - else: - self.class_loss_fct = CrossEntropyLoss(weight=self.clweights, reduction='none') - else: - if self.debug_sigmoid_slot_gates: - self.class_loss_fct = F.binary_cross_entropy # (reduction="none") - else: - self.class_loss_fct = CrossEntropyLoss(reduction='none') - - self.init_weights() - - def compute_triplet_loss_att(self, att_output, pos_sampling_input, neg_sampling_input, slot): - loss = self.triplet_loss(att_output, pos_sampling_input[slot].squeeze(), neg_sampling_input[slot].squeeze()) - sample_mask = neg_sampling_input[slot].squeeze(1).sum(1) != 0 - loss *= sample_mask 
- return loss - - def forward(self, batch, step=None, epoch=None, t_slot=None, mode=None): - assert(mode in [None, "pretrain", "tag", "encode", "encode_vals", "represent"]) # TODO - - input_ids = batch['input_ids'] - input_mask = batch['input_mask'] - #segment_ids = batch['segment_ids'] - usr_mask = batch['usr_mask'] - start_pos = batch['start_pos'] - end_pos = batch['end_pos'] - refer_id = batch['refer_id'] - class_label_id = batch['class_label_id'] - inform_slot_id = batch['inform_slot_id'] - diag_state = batch['diag_state'] - slot_ids = batch['slot_ids'] if 'slot_ids' in batch else None # TODO: fix? - slot_mask = batch['slot_mask'] if 'slot_mask' in batch else None # TODO: fix? - cl_ids = batch['cl_ids'] if 'cl_ids' in batch else None # TODO: fix? - cl_mask = batch['cl_mask'] if 'cl_mask' in batch else None # TODO: fix? - pos_sampling_input = batch['pos_sampling_input'] - neg_sampling_input = batch['neg_sampling_input'] - value_labels = batch['value_labels'] if 'value_labels' in batch else None # TODO: fix? - - batch_input_mask = input_mask - if slot_ids is not None and slot_mask is not None: - if self.debug_slot_embs_per_step_nograd: - with torch.no_grad(): - outputs_slot = self.roberta( - slot_ids, - attention_mask=slot_mask, - token_type_ids=None, # segment_ids, - position_ids=None, - head_mask=None - ) - else: - input_ids = torch.cat((input_ids, slot_ids)) - input_mask = torch.cat((input_mask, slot_mask)) - if cl_ids is not None and cl_mask is not None: - input_ids = torch.cat((input_ids, cl_ids)) - input_mask = torch.cat((input_mask, cl_mask)) - - outputs = self.roberta( - input_ids, - attention_mask=input_mask, - token_type_ids=None, # segment_ids, - position_ids=None, - head_mask=None - ) - - sequence_output = outputs[0] - pooled_output = outputs[1] - - if cl_ids is not None and cl_mask is not None: - sequence_output = sequence_output[:-1 * len(cl_ids), :, :] - encoded_classes = pooled_output[-1 * len(cl_ids):, :] - pooled_output = pooled_output[:-1 * len(cl_ids), :] - if slot_ids is not None and slot_mask is not None: - if self.debug_slot_embs_per_step_nograd: - encoded_slots_seq = outputs_slot[0] - encoded_slots_pooled = outputs_slot[1] - else: - encoded_slots_seq = sequence_output[-1 * len(slot_ids):, :, :] - sequence_output = sequence_output[:-1 * len(slot_ids), :, :] - encoded_slots_pooled = pooled_output[-1 * len(slot_ids):, :] - pooled_output = pooled_output[:-1 * len(slot_ids), :] - - sequence_output = self.dropout(sequence_output) - pooled_output = self.dropout(pooled_output) - - #sequence_output11 = outputs[2][-2] - - inverted_input_mask = ~(batch_input_mask.bool()) - if usr_mask is None: - usr_mask = input_mask - inverted_usr_mask = ~(usr_mask.bool()) - - if mode == "encode": # Create vector representations only - return pooled_output, sequence_output, None - - if mode == "encode_vals": # Create vector representations only - no_seq_w = 1 / batch_input_mask.sum(1) - uniform_weights = batch_input_mask * no_seq_w.unsqueeze(1) - pooled_output_vals = torch.matmul(uniform_weights, sequence_output).squeeze(1) - pooled_output_vals = self.token_layer_norm(pooled_output_vals) - return pooled_output_vals, None, None - - if mode == "pretrain": - pos_vectors = {} - pos_weights = {} - pos_vectors, pos_weights = self.token_att( - query=encoded_slots_pooled.squeeze(1).unsqueeze(0), - key=sequence_output.transpose(0, 1), - value=sequence_output.transpose(0, 1), - key_padding_mask=inverted_input_mask, - need_weights=True) - pos_vectors = pos_vectors.squeeze(0) - pos_vectors = 
self.token_layer_norm(pos_vectors) - pos_weights = pos_weights.squeeze(1) - - #neg_vectors = {} - neg_weights = {} - #neg_vectors[slot], neg_weights[slot] = self.token_att( - # query=neg_sampling_input[slot].squeeze(1).unsqueeze(0), - # key=sequence_output.transpose(0, 1), - # value=sequence_output.transpose(0, 1), - # key_padding_mask=inverted_input_mask, - # need_weights=True) - #neg_vectors[slot] = neg_vectors[slot].squeeze(0) - #neg_vectors[slot] = self.token_layer_norm(neg_vectors[slot]) - #neg_weights[slot] = neg_weights[slot].squeeze(1) - - pos_labels_clipped = torch.clamp(start_pos.float(), min=0, max=1) - pos_labels_clipped_scaled = pos_labels_clipped / torch.clamp(pos_labels_clipped.sum(1).unsqueeze(1), min=1) # scaled - #no_seq_w = 1 / batch_input_mask.sum(1) - #neg_labels_clipped = batch_input_mask * no_seq_w.unsqueeze(1) - if self.pretrain_loss_function == "mse": - pos_token_loss = self.mse(pos_weights, pos_labels_clipped_scaled) # TODO: MSE might be better for scaled targets - else: - pos_token_loss = self.binary_cross_entropy(pos_weights, pos_labels_clipped_scaled, reduction="none") - if self.ignore_o_tags: - pos_token_loss *= pos_labels_clipped # TODO: good idea? - #mm = torch.clamp(start_pos[slot].float(), min=0, max=1) - #mm = torch.clamp(mm * 10, min=1) - #pos_token_loss *= mm - pos_token_loss = pos_token_loss.sum(1) - #neg_token_loss = self.mse(neg_weights[slot], neg_labels_clipped) - #neg_token_loss = neg_token_loss.sum(1) - - #triplet_loss = self.compute_triplet_loss_att(pos_vectors[slot], pos_sampling_input, neg_sampling_input, slot) - - per_example_loss = pos_token_loss # + triplet_loss # + neg_token_loss - total_loss = per_example_loss.sum() - - return (total_loss, pos_weights, neg_weights,) - - if mode == "tag": - query = torch.stack(list(batch['value_reps'].values())).transpose(1, 2).reshape(-1, pooled_output.size()[0], pooled_output.size()[1]) - _, weights = self.token_att(query=query, - key=sequence_output.transpose(0, 1), - value=sequence_output.transpose(0, 1), - key_padding_mask=inverted_input_mask + inverted_usr_mask, - need_weights=True) - return (weights,) - - aaa_time = time.time() - # Attention for sequence tagging - vectors = {} - weights = {} - for s_itr, slot in enumerate(self.slot_list): - if slot_ids is not None and slot_mask is not None: - encoded_slot_seq = encoded_slots_seq[s_itr] - encoded_slot_pooled = encoded_slots_pooled[s_itr] - else: - encoded_slot_seq = batch['encoded_slots_seq'][slot] - encoded_slot_pooled = batch['encoded_slots_pooled'][slot] - if self.debug_separate_seq_tagging: - if self.debug_sigmoid_sequence_tagging: - weights[slot] = self.sigmoid(getattr(self, "token_" + slot)(sequence_output)).squeeze(2) - vectors[slot] = torch.matmul(weights[slot], sequence_output) - vectors[slot] = torch.diagonal(vectors[slot]).transpose(0, 1) - vectors[slot] = self.token_layer_norm(vectors[slot]) - else: - if self.debug_use_cross_attention: - query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1) - else: - query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0) - vectors[slot], weights[slot] = getattr(self, "token_" + slot)( - query=query, - key=sequence_output.transpose(0, 1), - value=sequence_output.transpose(0, 1), - key_padding_mask=inverted_input_mask + inverted_usr_mask, - need_weights=True) # TODO: use usr_mask better or worse? - vectors[slot] = vectors[slot].squeeze(0) - #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize? 
- if self.debug_tanh_for_att_output: - #vectors[slot] = self.tanh(vectors[slot]) - vectors[slot] = self.token_layer_norm2(vectors[slot]) - vectors[slot] = self.tanh(self.tlinear(vectors[slot])) - #vectors[slot] += pooled_output # Residual - vectors[slot] = self.token_layer_norm(vectors[slot]) - #vectors[slot] = self.dropout_heads(vectors[slot]) - weights[slot] = weights[slot].squeeze(1) - if self.debug_stack_rep: - vectors[slot] = torch.cat((pooled_output, vectors[slot]), 1) - else: - if self.debug_sigmoid_sequence_tagging: - # Conditioned sequence tagging - sequence_output_add = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(1).expand(sequence_output.size()) - xxt = self.gelu(self.h1t(sequence_output)) - yyt = self.gelu(self.h2t(torch.cat((sequence_output_add, xxt), 2))) - weights[slot] = self.sigmoid(self.llt(yyt)).squeeze(2) - vectors[slot] = torch.matmul(weights[slot], sequence_output) - vectors[slot] = torch.diagonal(vectors[slot]).transpose(0, 1) - vectors[slot] = self.token_layer_norm(vectors[slot]) - else: - if self.debug_use_cross_attention: - query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1) - else: - query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0) - vectors[slot], weights[slot] = self.token_att( - query=query, - key=sequence_output.transpose(0, 1), - value=sequence_output.transpose(0, 1), - key_padding_mask=inverted_input_mask + inverted_usr_mask, - need_weights=True) # TODO: use usr_mask better or worse? - vectors[slot] = vectors[slot].squeeze(0) - if self.debug_use_cross_attention: - vectors[slot] = torch.mean(vectors[slot] * (batch_input_mask + usr_mask).transpose(0, 1).unsqueeze(-1), dim=0) - #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize? - if self.debug_tanh_for_att_output: - #vectors[slot] = self.tanh(vectors[slot]) - vectors[slot] = self.token_layer_norm2(vectors[slot]) - vectors[slot] = self.tanh(self.tlinear(vectors[slot])) - #vectors[slot] += pooled_output # Residual - vectors[slot] = self.token_layer_norm(vectors[slot]) - #vectors[slot] = self.dropout_heads(vectors[slot]) - if self.debug_use_cross_attention: - weights[slot] = torch.mean(weights[slot] * (batch_input_mask + usr_mask).unsqueeze(-1), dim=1) - weights[slot] = weights[slot].squeeze(1) - if self.debug_stack_rep: - vectors[slot] = torch.cat((pooled_output, vectors[slot]), 1) - - if mode == "represent": # Create vector representations only - return vectors, None, weights - - # TODO: establish proper format in labels already? 
- if inform_slot_id is not None: - inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() - if diag_state is not None: - diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) - - bbb_time = time.time() - total_loss = 0 - total_cl_loss = 0 - total_tk_loss = 0 - total_tp_loss = 0 - per_slot_per_example_loss = {} - per_slot_per_example_cl_loss = {} - per_slot_per_example_tk_loss = {} - per_slot_per_example_tp_loss = {} - per_slot_att_weights = {} - per_slot_class_logits = {} - per_slot_start_logits = {} - per_slot_end_logits = {} - per_slot_value_logits = {} - per_slot_refer_logits = {} - for s_itr, slot in enumerate(self.slot_list): - #if t_slot is not None and slot != t_slot: - # continue - if slot_ids is not None and slot_mask is not None: - encoded_slot_seq = encoded_slots_seq[s_itr] - encoded_slot_pooled = encoded_slots_pooled[s_itr] - else: - encoded_slot_seq = batch['encoded_slots_seq'][slot] - encoded_slot_pooled = batch['encoded_slots_pooled'][slot] - - # Attention for slot gates - if self.debug_att_output: - if self.debug_use_cross_attention: - query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1) - else: - query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0) - att_output, c_weights = self.class_att( - query=query, - key=sequence_output.transpose(0, 1), - value=sequence_output.transpose(0, 1), - key_padding_mask=inverted_input_mask, - need_weights=True) - if self.debug_use_cross_attention: - att_output = torch.mean(att_output, dim=0) - c_weights = torch.mean(c_weights, dim=1) - if self.debug_tanh_for_att_output: - #att_output = self.tanh(att_output) - att_output = self.class_layer_norm2(att_output) - att_output = self.tanh(self.clinear(att_output)) - att_output = self.class_layer_norm(att_output) - att_output = self.dropout_heads(att_output) - per_slot_att_weights[slot] = c_weights.squeeze(1) - else: - per_slot_att_weights[slot] = None - - # Conditioned slot gate, or separate slot gates - if self.debug_joint_slot_gate: - if self.debug_att_output: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1))) - elif self.class_aux_feats_inform: - xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels)), 1))) - elif self.class_aux_feats_ds: - xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1))) - else: - xx = self.gelu(self.h1c(att_output.squeeze(0))) - if self.debug_stack_att: - x0 = self.gelu(self.h0c(pooled_output)) - else: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - xx = self.gelu(self.h1c(torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1))) - elif self.class_aux_feats_inform: - xx = self.gelu(self.h1c(torch.cat((pooled_output, self.inform_projection(inform_labels)), 1))) - elif self.class_aux_feats_ds: - xx = self.gelu(self.h1c(torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1))) - else: - xx = self.gelu(self.h1c(pooled_output)) - if self.debug_att_output and not self.debug_simple_joint_slot_gate and self.debug_stack_att: - yy = self.gelu(self.h2c(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), x0, xx), 1))) - elif self.debug_att_output and self.debug_simple_joint_slot_gate and self.debug_stack_att: - yy = torch.cat((pooled_output, 
att_output.squeeze(0)), 1) - elif self.debug_att_output and self.debug_simple_joint_slot_gate: - yy = att_output.squeeze(0) - else: - yy = self.gelu(self.h2c(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), xx), 1))) - slot_gate_input = yy - slot_gate_layer = "llc" - else: - if self.debug_att_output: - if self.debug_stack_att: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1) - else: - slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0)), 1) - else: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - slot_gate_input = torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - slot_gate_input = torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels),), 1) - elif self.class_aux_feats_ds: - slot_gate_input = torch.cat((att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1) - else: - slot_gate_input = att_output.squeeze(0) - else: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - slot_gate_input = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - slot_gate_input = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - slot_gate_input = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) - else: - slot_gate_input = pooled_output - slot_gate_layer = "class_" + slot - - # Conditioned refer gate, or separate refer gates - if self.debug_joint_slot_gate and self.debug_joint_refer_gate: - slot_refer_input = self.gelu(self.h2r(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), xx), 1))) - else: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - slot_refer_input = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - slot_refer_input = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - slot_refer_input = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) - else: - slot_refer_input = pooled_output - - # Slot gate classification - if self.debug_sigmoid_slot_gates: - class_logits = [] - for cl in range(self.class_labels): - class_logits.append(self.sigmoid(getattr(self, slot_gate_layer + "_" + str(cl))(slot_gate_input))) - class_logits = torch.stack(class_logits, 1).squeeze(-1) - elif self.debug_att_slot_gates: - # TODO: implement separate gates as well - if cl_ids is not None and cl_mask is not None: - bla = encoded_classes.unsqueeze(1).expand(-1, pooled_output.size()[0], -1) - else: - bla = torch.stack(list(batch['encoded_classes'].values())).expand(-1, pooled_output.size()[0], -1) # TODO - #_, class_logits = self.slot_att( - _, class_logits = getattr(self, slot_gate_layer)( - query=slot_gate_input.unsqueeze(0), - key=bla, - value=bla, - need_weights=True) - class_logits = class_logits.squeeze(1) - else: - class_logits = getattr(self, 
slot_gate_layer)(slot_gate_input) - - class_logits = self.dropout_heads(class_logits) - - #token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output).squeeze(-1)) - token_weights = weights[slot] - - # --- - - if self.triplet_loss_weight > 0.0 and not self.debug_use_triplet_loss: - slot_values = torch.stack(list(batch['encoded_slot_values'][slot].values())) # TODO: filter? - slot_values = slot_values.expand((-1, pooled_output.size(0), -1)) - #yy = (batch['pos_sampling_input'][slot].sum(2) > 0.0) - #if self.debug_value_att_none_class: - # ww = batch['value_labels'][slot][:, 1:] * yy - #else: - # ww = batch['value_labels'][slot] * yy - #wwx = ww == 0 - #bbb = slot_values * wwx.transpose(0, 1).unsqueeze(2) - #wwy = ww == 1 - #yyy = batch['pos_sampling_input'][slot].expand(-1, slot_values.size(0), -1) * wwy.unsqueeze(2) - #slot_values = bbb + yyy.transpose(0, 1) - if self.debug_value_att_none_class: - slot_values = torch.cat((vectors[slot].unsqueeze(0), slot_values)) - #slot_values = slot_values.expand(-1, pooled_output.size()[0], -1) - #slot_values = torch.stack((batch['pos_sampling_input'][slot], batch['neg_sampling_input'][slot])) # TODO: filter? - #slot_values = torch.stack((vectors[slot].unsqueeze(1), batch['pos_sampling_input'][slot], batch['neg_sampling_input'][slot])) # TODO: filter? - #slot_values = slot_values.squeeze(2) - _, value_weights = self.value_att( - query=vectors[slot].unsqueeze(0), - key=slot_values, - value=slot_values, - need_weights=True) - ##vectors[slot] = vectors[slot].squeeze(0) - #vectors[slot] = torch.matmul(weights[slot], sequence_output).squeeze(1) - #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize? - #if self.debug_tanh_for_att_output: - #vectors[slot] = self.tanh(vectors[slot]) - #vectors[slot] = self.token_layer_norm2(vectors[slot]) - #vectors[slot] = self.tanh(self.tlinear(vectors[slot])) - #vectors[slot] += pooled_output # Residual - #vectors[slot] = self.token_layer_norm(vectors[slot]) - #vectors[slot] = self.dropout_heads(vectors[slot]) - value_weights = value_weights.squeeze(1) - else: - value_weights = None - - # --- - - # TODO: implement choice between joint_refer_gate and individual refer gates, analogous to - # slot gates. Use same input as slot gate, i.e., for joint case, use yy, for individual - # use pooled. This is stored in slot_gate_input. 
- - # --- - - if self.debug_joint_refer_gate: - if slot_ids is not None and slot_mask is not None: - refer_slots = encoded_slots_pooled.unsqueeze(1).expand(-1, pooled_output.size()[0], -1) - else: - #refer_slots = torch.stack((list(self.encoded_slots.values()))).expand(-1, pooled_output.size()[0], -1) - refer_slots = torch.stack(list(batch['encoded_slots_pooled'].values())).expand(-1, pooled_output.size()[0], -1) - _, refer_weights = self.refer_att( - query=slot_refer_input.unsqueeze(0), - key=refer_slots, - value=refer_slots, - need_weights=True) - refer_weights = refer_weights.squeeze(1) - refer_logits = refer_weights - else: - refer_logits = getattr(self, "refer_" + slot)(slot_refer_input) - - refer_logits = self.dropout_heads(refer_logits) - - per_slot_class_logits[slot] = class_logits - per_slot_start_logits[slot] = token_weights # TODO - per_slot_value_logits[slot] = value_weights - #per_slot_start_logits[slot] = self.sigmoid(token_logits) # TODO # this is for sigmoid approach - per_slot_refer_logits[slot] = refer_logits - - # If there are no labels, don't compute loss - if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: - # If we are on multi-GPU, split add a dimension - if len(start_pos[slot].size()) > 1: - start_pos[slot] = start_pos[slot].squeeze(-1) - - # TODO: solve this using the sequence_tagging def? - labels_clipped = torch.clamp(start_pos[slot].float(), min=0, max=1) - labels_clipped_scaled = labels_clipped / torch.clamp(labels_clipped.sum(1).unsqueeze(1), min=1) # Scale targets? - no_seq_mask = labels_clipped_scaled.sum(1) == 0 - no_seq_w = 1 / batch_input_mask.sum(1) - labels_clipped_scaled += batch_input_mask * (no_seq_mask * no_seq_w).unsqueeze(1) - #token_weights = self.sigmoid(token_logits) # TODO # this is for sigmoid approach - if self.token_loss_function == "mse": - token_loss = self.mse(token_weights, labels_clipped_scaled) # TODO: MSE might be better for scaled targets - else: - token_loss = self.binary_cross_entropy(token_weights, labels_clipped_scaled, reduction="none") - if self.ignore_o_tags: - token_loss *= labels_clipped # TODO: good idea? - - # TODO: do negative examples have to be balanced due to their large number? - token_loss = token_loss.sum(1) - token_is_pointable = (start_pos[slot].sum(1) > 0).float() - if not self.token_loss_for_nonpointable: - token_loss *= token_is_pointable - - value_loss = torch.zeros(token_is_pointable.size(), device=token_is_pointable.device) - if self.triplet_loss_weight > 0.0: - if self.debug_use_triplet_loss: - value_loss = self.compute_triplet_loss_att(vectors[slot], pos_sampling_input, neg_sampling_input, slot) - #triplet_loss = torch.clamp(triplet_loss, max=1) # TODO: parameterize # Not the best idea I think... - else: - value_labels_clipped = torch.clamp(value_labels[slot].float(), min=0, max=1) - value_labels_clipped /= torch.clamp(value_labels_clipped.sum(1).unsqueeze(1), min=1) # Scale targets? 
- value_no_seq_mask = value_labels_clipped.sum(1) == 0 - value_no_seq_w = 1 / value_labels_clipped.size(1) - value_labels_clipped += (value_no_seq_mask * value_no_seq_w).unsqueeze(1) - if self.value_loss_function == "mse": - value_loss = self.mse(value_weights, value_labels_clipped) # TODO: scale value_labels to also cover nonpointable cases and multitarget cases - else: - value_loss = self.binary_cross_entropy(value_weights, value_labels_clipped, reduction="none") - #print(slot) - #print(value_labels_clipped) - #print(value_weights) - #print(value_loss) - value_loss = value_loss.sum(1) - token_is_matchable = token_is_pointable - if self.debug_tag_none_target: - token_is_matchable *= (start_pos[slot][:, 1] == 0).float() - if not self.value_loss_for_nonpointable: - value_loss *= token_is_matchable - - # Re-definition necessary to make slot-independent prediction possible - self.refer_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=len(self.slot_list)) # Ignore 'none' target - refer_loss = self.refer_loss_fct(refer_logits, refer_id[slot]) - token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() - if not self.refer_loss_for_nonpointable: - refer_loss *= token_is_referrable - - if self.debug_sigmoid_slot_gates: - class_loss = [] - for cl in range(self.class_labels): - #class_loss.append(self.binary_cross_entropy(class_logits[:,cl], labels, reduction="none")) - class_loss.append(self.class_loss_fct(class_logits[:, cl], (class_label_id[slot] == cl).float(), reduction="none")) - class_loss = torch.stack(class_loss, 1) - if self.none_weight != 1.0: - class_loss *= self.clweights.to(outputs[0].device) - class_loss = class_loss.sum(1) - elif self.debug_att_slot_gates: - if self.class_loss_function == "mse": - class_loss = self.mse(class_logits, torch.nn.functional.one_hot(class_label_id[slot], self.class_labels).float()) - else: - class_loss = self.binary_cross_entropy(class_logits, torch.nn.functional.one_hot(class_label_id[slot], self.class_labels).float(), reduction="none") - if self.none_weight != 1.0: - class_loss *= self.clweights.to(outputs[0].device) - class_loss = class_loss.sum(1) - else: - class_loss = self.class_loss_fct(class_logits, class_label_id[slot]) - - #print("%15s, class loss: %.3f, token loss: %.3f, triplet loss: %.3f" % (slot, (class_loss.sum()).item(), (token_loss.sum()).item(), (triplet_loss.sum()).item())) - #print("%15s, class loss: %.3f, token loss: %.3f, value loss: %.3f" % (slot, (class_loss.sum()).item(), (token_loss.sum()).item(), (value_loss.sum()).item())) - - st_switch = int(not (self.sequence_tagging_dropout >= 1 and step is not None and step % self.sequence_tagging_dropout == 0)) - - if self.refer_index > -1: - #per_example_loss = (self.class_loss_ratio) * class_loss + st_switch * ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + self.triplet_loss_weight * triplet_loss - per_example_loss = (self.class_loss_ratio) * class_loss + st_switch * ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + self.triplet_loss_weight * value_loss - #per_example_loss = class_loss - else: - #per_example_loss = self.class_loss_ratio * class_loss + st_switch * (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * triplet_loss - per_example_loss = self.class_loss_ratio * class_loss + st_switch * (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * value_loss - #if epoch is not None and epoch > 20: - # per_example_loss = 
self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * triplet_loss -                #else: -                #    per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss -                #per_example_loss = class_loss -                if self.debug_use_tlf: -                    per_example_loss *= 1.0 + ((batch['diag_len'] - batch['turn_id']) * 0.05) - -                total_loss += per_example_loss.sum() -                total_cl_loss += class_loss.sum() -                total_tk_loss += token_loss.sum() -                #total_tp_loss += triplet_loss.sum() -                total_tp_loss += value_loss.sum() -                per_slot_per_example_loss[slot] = per_example_loss -                per_slot_per_example_cl_loss[slot] = class_loss -                per_slot_per_example_tk_loss[slot] = token_loss -                #per_slot_per_example_tp_loss[slot] = triplet_loss -                per_slot_per_example_tp_loss[slot] = value_loss -        ccc_time = time.time() -        #print("TIME:", bbb_time - aaa_time, ccc_time - bbb_time) # 0.028620243072509766 - -        # add hidden states and attention if they are here -        outputs = (total_loss, -                   total_cl_loss, -                   total_tk_loss, -                   total_tp_loss, -                   per_slot_per_example_loss, -                   per_slot_per_example_cl_loss, -                   per_slot_per_example_tk_loss, -                   per_slot_per_example_tp_loss, -                   per_slot_class_logits, -                   per_slot_start_logits, -                   per_slot_end_logits, -                   per_slot_value_logits, -                   per_slot_refer_logits, -                   per_slot_att_weights,) + (vectors, weights,) + (pooled_output,) # + outputs[2:] -        #outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits, per_slot_att_weights,) + (vectors, weights,) # + outputs[2:] - -        return outputs diff --git a/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py~ b/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py~ deleted file mode 100644 index 2e4dec8bfea07922e25261bdb07798df559c5c4e..0000000000000000000000000000000000000000 --- a/convlab/dst/trippyr/multiwoz/modeling_roberta_dst.py~ +++ /dev/null @@ -1,854 +0,0 @@
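A note for readers of the deletion above: the `none_weight` construction in `RobertaForDST.__init__` normalises a single down-weight for the dominant 'none' gate class into a full class-weight vector for the slot-gate cross-entropy. A minimal, self-contained sketch of that trick, assuming illustrative `class_types` and a hypothetical `none_weight` of 0.1 (the deleted code reads both from the DST config):

import torch
from torch.nn import CrossEntropyLoss

# Illustrative values; the deleted model takes these from its config.
class_types = ['none', 'dontcare', 'copy_value', 'refer', 'inform']
class_labels = len(class_types)
none_weight = 0.1  # down-weight the over-represented 'none' class

# Normalise so 'none' gets none_weight / mass and every other class 1 / mass,
# as in the deleted __init__.
weight_mass = none_weight + (class_labels - 1)
clweights = torch.full((class_labels,), 1.0 / weight_mass)
clweights[class_types.index('none')] = none_weight / weight_mass

class_loss_fct = CrossEntropyLoss(weight=clweights, reduction='none')

logits = torch.randn(4, class_labels)   # (batch, class_labels)
labels = torch.tensor([0, 0, 2, 3])     # mostly 'none'
per_example_loss = class_loss_fct(logits, labels)  # 'none' rows count less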
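Likewise, several of the deleted gates (`llc` under `debug_att_slot_gates`, `value_att`, `refer_att`) use `nn.MultiheadAttention` only for its attention weights: the weights over the candidate embeddings serve directly as logits over classes, values, or slots, and the attention output itself is discarded. A sketch of the pattern with made-up sizes:

import torch
from torch.nn import MultiheadAttention

hidden_size, heads, batch, n_candidates = 768, 12, 4, 7  # made-up sizes
att = MultiheadAttention(hidden_size, heads)

query = torch.randn(1, batch, hidden_size)                   # one gate vector per example
candidates = torch.randn(n_candidates, batch, hidden_size)   # encoded classes/values/slots

# With need_weights=True the weights come back as (batch, 1, n_candidates);
# the first return value (the attended output) is unused here.
_, weights = att(query=query, key=candidates, value=candidates, need_weights=True)
logits = weights.squeeze(1)                                  # (batch, n_candidates)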
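The sequence-tagging loss in the deleted forward pass turns the multi-hot `start_pos` labels into a distribution over tokens, falls back to a uniform distribution when nothing is tagged, and compares the attention weights against it with (by default) binary cross-entropy. Roughly, on a toy batch:

import torch
import torch.nn.functional as F

# Toy batch: example 0 tags tokens 1-2, example 1 tags nothing.
start_pos = torch.tensor([[0., 1., 1., 0.],
                          [0., 0., 0., 0.]])
input_mask = torch.ones_like(start_pos)

labels = torch.clamp(start_pos, min=0, max=1)
targets = labels / torch.clamp(labels.sum(1, keepdim=True), min=1)  # scale to sum to 1
no_seq = (targets.sum(1) == 0).float()
targets = targets + input_mask * (no_seq / input_mask.sum(1)).unsqueeze(1)  # uniform fallback

token_weights = torch.softmax(torch.randn_like(targets), dim=1)  # stand-in for attention weights
token_loss = F.binary_cross_entropy(token_weights, targets, reduction="none").sum(1)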
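Finally, `compute_triplet_loss_att` treats an all-zero negative sample as padding: the per-example triplet loss is multiplied by a mask that zeroes those rows. In isolation:

import torch
from torch.nn import TripletMarginLoss

triplet_loss = TripletMarginLoss(margin=1.0, reduction="none")

anchor   = torch.randn(4, 768)
positive = torch.randn(4, 768)
negative = torch.randn(4, 768)
negative[2] = 0.0  # all-zero row stands for "no negative sample"

loss = triplet_loss(anchor, positive, negative)
loss = loss * (negative.sum(1) != 0).float()  # drop examples without a real negative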
- """ - config_class = RobertaConfig - pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP - base_model_prefix = "roberta" - - def _init_weights(self, module): - """ Initialize the weights """ - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, BertLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - -@add_start_docstrings( - """RoBERTa Model with classification heads for the DST task. """, - ROBERTA_START_DOCSTRING, -) -class RobertaForDST(RobertaPreTrainedModel): - def __init__(self, config): - super(RobertaForDST, self).__init__(config) - self.slot_list = config.dst_slot_list - self.noncategorical = config.dst_noncategorical - self.categorical = config.dst_categorical - self.class_types = config.dst_class_types - self.class_labels = config.dst_class_labels - self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable - self.value_loss_for_nonpointable = config.dst_value_loss_for_nonpointable - self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable - self.stack_token_logits = False # config.dst_stack_token_logits # TODO - self.class_aux_feats_inform = config.dst_class_aux_feats_inform - self.class_aux_feats_ds = config.dst_class_aux_feats_ds - self.class_loss_ratio = config.dst_class_loss_ratio - self.slot_attention_heads = config.dst_slot_attention_heads - self.dropout_rate = config.dst_dropout_rate - self.heads_dropout_rate = config.dst_heads_dropout_rate - - self.debug_fix_slot_embs = config.debug_fix_slot_embs - self.debug_joint_slot_gate = config.debug_joint_slot_gate - self.debug_joint_refer_gate = config.debug_joint_refer_gate - self.debug_simple_joint_slot_gate = config.debug_simple_joint_slot_gate - self.debug_separate_seq_tagging = config.debug_separate_seq_tagging - self.debug_sigmoid_sequence_tagging = config.debug_sigmoid_sequence_tagging - self.debug_att_output = config.debug_att_output - self.debug_tanh_for_att_output = config.debug_tanh_for_att_output - self.debug_stack_att = config.debug_stack_att - self.debug_stack_rep = config.debug_stack_rep - self.debug_sigmoid_slot_gates = config.debug_sigmoid_slot_gates - self.debug_att_slot_gates = config.debug_att_slot_gates - self.debug_use_triplet_loss = config.debug_use_triplet_loss - self.debug_use_tlf = config.debug_use_tlf - self.debug_value_att_none_class = config.debug_value_att_none_class - self.debug_tag_none_target = config.debug_tag_none_target - self.triplet_loss_weight = config.triplet_loss_weight - self.none_weight = config.none_weight - self.pretrain_loss_function = config.pretrain_loss_function - self.token_loss_function = config.token_loss_function - self.value_loss_function = config.value_loss_function - self.class_loss_function = config.class_loss_function - self.sequence_tagging_dropout = -1 - if config.sequence_tagging_dropout > 0.0: - self.sequence_tagging_dropout = int(1 / config.sequence_tagging_dropout) - self.ignore_o_tags = config.ignore_o_tags - self.debug_slot_embs_per_step_nograd = config.debug_slot_embs_per_step_nograd - try: - self.debug_use_cross_attention = config.debug_use_cross_attention - except: - self.debug_use_cross_attention = False - - self.val_rep_mode = config.val_rep_mode # TODO: for 
debugging at the moment. - - config.output_hidden_states = True # TODO - - #config.dst_dropout_rate = 0.0 # TODO: for debugging - #config.dst_heads_dropout_rate = 0.0 # TODO: for debugging - - self.roberta = RobertaModel(config) - self.dropout = nn.Dropout(config.dst_dropout_rate) - self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) - - # -- Dialogue state tracking functionality -- - - # Only use refer loss if refer class is present in dataset. - if 'refer' in self.class_types: - self.refer_index = self.class_types.index('refer') - else: - self.refer_index = -1 - - if self.class_aux_feats_inform: - self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - if self.class_aux_feats_ds: - self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) - - aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2 - - # Slot specific gates - for slot in self.slot_list: - if not self.debug_joint_slot_gate: - if self.debug_stack_att: - if self.debug_sigmoid_slot_gates: - for cl in range(self.class_labels): - self.add_module("class_" + slot + "_" + str(cl), nn.Linear(config.hidden_size * 2 + aux_dims, 1)) - elif self.debug_att_slot_gates: - self.add_module("class_" + slot, MultiheadAttention(config.hidden_size * 2 + aux_dims, self.slot_attention_heads)) - else: - self.add_module("class_" + slot, nn.Linear(config.hidden_size * 2 + aux_dims, self.class_labels)) - else: - if self.debug_sigmoid_slot_gates: - for cl in range(self.class_labels): - self.add_module("class_" + slot + "_" + str(cl), nn.Linear(config.hidden_size + aux_dims, 1)) - elif self.debug_att_slot_gates: - self.add_module("class_" + slot, MultiheadAttention(config.hidden_size + aux_dims, self.slot_attention_heads)) - else: - self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels)) - #self.add_module("token_" + slot, nn.Linear(config.hidden_size, 1)) - if not self.debug_joint_refer_gate: - self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) - if self.debug_separate_seq_tagging: - if self.debug_sigmoid_sequence_tagging: - self.add_module("token_" + slot, nn.Linear(config.hidden_size, 1)) - else: - self.add_module("token_" + slot, MultiheadAttention(config.hidden_size, self.slot_attention_heads)) - - if self.debug_att_output: - self.class_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) - # Conditioned sequence tagging - if not self.debug_separate_seq_tagging: - if self.debug_sigmoid_sequence_tagging: - self.h1t = nn.Linear(config.hidden_size, config.hidden_size) - self.h2t = nn.Linear(config.hidden_size * 2, config.hidden_size * 2) - self.llt = nn.Linear(config.hidden_size * 2, 1) - else: - self.token_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) - if self.debug_joint_refer_gate: - self.refer_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) - self.value_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) - #self.slot_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) - - self.token_layer_norm = nn.LayerNorm(config.hidden_size) - self.token_layer_norm2 = nn.LayerNorm(config.hidden_size) - self.class_layer_norm = nn.LayerNorm(config.hidden_size) - self.class_layer_norm2 = nn.LayerNorm(config.hidden_size) - - self.tlinear = nn.Linear(config.hidden_size, config.hidden_size) - self.clinear = nn.Linear(config.hidden_size, 
config.hidden_size) - - # Conditioned slot gate - self.h1c = nn.Linear(config.hidden_size + aux_dims, config.hidden_size) - if self.debug_stack_att: - self.h0c = nn.Linear(config.hidden_size, config.hidden_size) - self.h2c = nn.Linear(config.hidden_size * 3, config.hidden_size * 3) - if self.debug_sigmoid_slot_gates: - for cl in range(self.class_labels): - self.add_module("llc_" + str(cl), nn.Linear(config.hidden_size * 3, 1)) - elif self.debug_att_slot_gates: - if self.debug_att_output and self.debug_simple_joint_slot_gate: - self.llc = MultiheadAttention(config.hidden_size * 2, self.slot_attention_heads) - else: - self.llc = MultiheadAttention(config.hidden_size * 3, self.slot_attention_heads) - else: - if self.debug_att_output and self.debug_simple_joint_slot_gate: - self.llc = nn.Linear(config.hidden_size * 2, self.class_labels) - else: - self.llc = nn.Linear(config.hidden_size * 3, self.class_labels) - else: - if self.debug_att_slot_gates: - self.h2c = nn.Linear(config.hidden_size * 2, config.hidden_size * 1) - else: - self.h2c = nn.Linear(config.hidden_size * 2, config.hidden_size * 2) - if self.debug_sigmoid_slot_gates: - for cl in range(self.class_labels): - self.add_module("llc_" + str(cl), nn.Linear(config.hidden_size * 2, 1)) - elif self.debug_att_slot_gates: - if self.debug_att_output and self.debug_simple_joint_slot_gate: - self.llc = MultiheadAttention(config.hidden_size * 1, self.slot_attention_heads) - else: - self.llc = MultiheadAttention(config.hidden_size * 2, self.slot_attention_heads) - else: - if self.debug_att_output and self.debug_simple_joint_slot_gate: - self.llc = nn.Linear(config.hidden_size * 1, self.class_labels) - else: - self.llc = nn.Linear(config.hidden_size * 2, self.class_labels) - - # Conditioned refer gate - self.h2r = nn.Linear(config.hidden_size * 2, config.hidden_size * 1) - - # -- Spanless sequence tagging functionality -- - - self.dis = PairwiseDistance(p=2) - self.sigmoid = nn.Sigmoid() - self.tanh = nn.Tanh() - self.relu = nn.ReLU() - self.gelu = nn.GELU() - self.binary_cross_entropy = F.binary_cross_entropy - self.triplet_loss = TripletMarginLoss(margin=1.0, reduction="none") - self.mse = nn.MSELoss(reduction="none") - self.refer_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=len(self.slot_list)) # Ignore 'none' target - self.value_loss_fct = CrossEntropyLoss(reduction='none') - - if self.none_weight != 1.0: - none_weight = self.none_weight - weight_mass = none_weight + (self.class_labels - 1) - none_weight /= weight_mass - other_weights = 1 / weight_mass - #self.clweights = torch.tensor([none_weight] + [other_weights] * (self.class_labels - 1)) # .to(outputs[0].device) - self.clweights = torch.tensor([other_weights] * self.class_labels) # .to(outputs[0].device) - self.clweights[self.class_types.index('none')] = none_weight - if self.debug_sigmoid_slot_gates: - self.class_loss_fct = F.binary_cross_entropy # (reduction="none") # weights for classes are controlled when adding losses - else: - self.class_loss_fct = CrossEntropyLoss(weight=self.clweights, reduction='none') - else: - if self.debug_sigmoid_slot_gates: - self.class_loss_fct = F.binary_cross_entropy # (reduction="none") - else: - self.class_loss_fct = CrossEntropyLoss(reduction='none') - - self.init_weights() - - def compute_triplet_loss_att(self, att_output, pos_sampling_input, neg_sampling_input, slot): - loss = self.triplet_loss(att_output, pos_sampling_input[slot].squeeze(), neg_sampling_input[slot].squeeze()) - sample_mask = neg_sampling_input[slot].squeeze(1).sum(1) 
!= 0 - loss *= sample_mask - return loss - - @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) - def forward(self, batch, step=None, epoch=None, t_slot=None, mode=None): - assert(mode in [None, "pretrain", "tag", "encode", "encode_vals", "represent"]) # TODO - - input_ids = batch['input_ids'] - input_mask = batch['input_mask'] - #segment_ids = batch['segment_ids'] - usr_mask = batch['usr_mask'] - start_pos = batch['start_pos'] - end_pos = batch['end_pos'] - refer_id = batch['refer_id'] - class_label_id = batch['class_label_id'] - inform_slot_id = batch['inform_slot_id'] - diag_state = batch['diag_state'] - slot_ids = batch['slot_ids'] if 'slot_ids' in batch else None # TODO: fix? - slot_mask = batch['slot_mask'] if 'slot_mask' in batch else None # TODO: fix? - cl_ids = batch['cl_ids'] if 'cl_ids' in batch else None # TODO: fix? - cl_mask = batch['cl_mask'] if 'cl_mask' in batch else None # TODO: fix? - pos_sampling_input = batch['pos_sampling_input'] - neg_sampling_input = batch['neg_sampling_input'] - value_labels = batch['value_labels'] if 'value_labels' in batch else None # TODO: fix? - - batch_input_mask = input_mask - if slot_ids is not None and slot_mask is not None: - if self.debug_slot_embs_per_step_nograd: - with torch.no_grad(): - outputs_slot = self.roberta( - slot_ids, - attention_mask=slot_mask, - token_type_ids=None, # segment_ids, - position_ids=None, - head_mask=None - ) - else: - input_ids = torch.cat((input_ids, slot_ids)) - input_mask = torch.cat((input_mask, slot_mask)) - if cl_ids is not None and cl_mask is not None: - input_ids = torch.cat((input_ids, cl_ids)) - input_mask = torch.cat((input_mask, cl_mask)) - - outputs = self.roberta( - input_ids, - attention_mask=input_mask, - token_type_ids=None, # segment_ids, - position_ids=None, - head_mask=None - ) - - sequence_output = outputs[0] - pooled_output = outputs[1] - - if cl_ids is not None and cl_mask is not None: - sequence_output = sequence_output[:-1 * len(cl_ids), :, :] - encoded_classes = pooled_output[-1 * len(cl_ids):, :] - pooled_output = pooled_output[:-1 * len(cl_ids), :] - if slot_ids is not None and slot_mask is not None: - if self.debug_slot_embs_per_step_nograd: - encoded_slots_seq = outputs_slot[0] - encoded_slots_pooled = outputs_slot[1] - else: - encoded_slots_seq = sequence_output[-1 * len(slot_ids):, :, :] - sequence_output = sequence_output[:-1 * len(slot_ids), :, :] - encoded_slots_pooled = pooled_output[-1 * len(slot_ids):, :] - pooled_output = pooled_output[:-1 * len(slot_ids), :] - - sequence_output = self.dropout(sequence_output) - pooled_output = self.dropout(pooled_output) - - #sequence_output11 = outputs[2][-2] - - inverted_input_mask = ~(batch_input_mask.bool()) - if usr_mask is None: - usr_mask = input_mask - inverted_usr_mask = ~(usr_mask.bool()) - - if mode == "encode": # Create vector representations only - return pooled_output, sequence_output, None - - if mode == "encode_vals": # Create vector representations only - no_seq_w = 1 / batch_input_mask.sum(1) - uniform_weights = batch_input_mask * no_seq_w.unsqueeze(1) - pooled_output_vals = torch.matmul(uniform_weights, sequence_output).squeeze(1) - pooled_output_vals = self.token_layer_norm(pooled_output_vals) - return pooled_output_vals, None, None - - if mode == "pretrain": - pos_vectors = {} - pos_weights = {} - pos_vectors, pos_weights = self.token_att( - query=encoded_slots_pooled.squeeze(1).unsqueeze(0), - key=sequence_output.transpose(0, 1), - value=sequence_output.transpose(0, 1), - 
key_padding_mask=inverted_input_mask, - need_weights=True) - pos_vectors = pos_vectors.squeeze(0) - pos_vectors = self.token_layer_norm(pos_vectors) - pos_weights = pos_weights.squeeze(1) - - #neg_vectors = {} - neg_weights = {} - #neg_vectors[slot], neg_weights[slot] = self.token_att( - # query=neg_sampling_input[slot].squeeze(1).unsqueeze(0), - # key=sequence_output.transpose(0, 1), - # value=sequence_output.transpose(0, 1), - # key_padding_mask=inverted_input_mask, - # need_weights=True) - #neg_vectors[slot] = neg_vectors[slot].squeeze(0) - #neg_vectors[slot] = self.token_layer_norm(neg_vectors[slot]) - #neg_weights[slot] = neg_weights[slot].squeeze(1) - - pos_labels_clipped = torch.clamp(start_pos.float(), min=0, max=1) - pos_labels_clipped_scaled = pos_labels_clipped / torch.clamp(pos_labels_clipped.sum(1).unsqueeze(1), min=1) # scaled - #no_seq_w = 1 / batch_input_mask.sum(1) - #neg_labels_clipped = batch_input_mask * no_seq_w.unsqueeze(1) - if self.pretrain_loss_function == "mse": - pos_token_loss = self.mse(pos_weights, pos_labels_clipped_scaled) # TODO: MSE might be better for scaled targets - else: - pos_token_loss = self.binary_cross_entropy(pos_weights, pos_labels_clipped_scaled, reduction="none") - if self.ignore_o_tags: - pos_token_loss *= pos_labels_clipped # TODO: good idea? - #mm = torch.clamp(start_pos[slot].float(), min=0, max=1) - #mm = torch.clamp(mm * 10, min=1) - #pos_token_loss *= mm - pos_token_loss = pos_token_loss.sum(1) - #neg_token_loss = self.mse(neg_weights[slot], neg_labels_clipped) - #neg_token_loss = neg_token_loss.sum(1) - - #triplet_loss = self.compute_triplet_loss_att(pos_vectors[slot], pos_sampling_input, neg_sampling_input, slot) - - per_example_loss = pos_token_loss # + triplet_loss # + neg_token_loss - total_loss = per_example_loss.sum() - - return (total_loss, pos_weights, neg_weights,) - - if mode == "tag": - query = torch.stack(list(batch['value_reps'].values())).transpose(1, 2).reshape(-1, pooled_output.size()[0], pooled_output.size()[1]) - _, weights = self.token_att(query=query, - key=sequence_output.transpose(0, 1), - value=sequence_output.transpose(0, 1), - key_padding_mask=inverted_input_mask + inverted_usr_mask, - need_weights=True) - return (weights,) - - aaa_time = time.time() - # Attention for sequence tagging - vectors = {} - weights = {} - for s_itr, slot in enumerate(self.slot_list): - if slot_ids is not None and slot_mask is not None: - encoded_slot_seq = encoded_slots_seq[s_itr] - encoded_slot_pooled = encoded_slots_pooled[s_itr] - else: - encoded_slot_seq = batch['encoded_slots_seq'][slot] - encoded_slot_pooled = batch['encoded_slots_pooled'][slot] - if self.debug_separate_seq_tagging: - if self.debug_sigmoid_sequence_tagging: - weights[slot] = self.sigmoid(getattr(self, "token_" + slot)(sequence_output)).squeeze(2) - vectors[slot] = torch.matmul(weights[slot], sequence_output) - vectors[slot] = torch.diagonal(vectors[slot]).transpose(0, 1) - vectors[slot] = self.token_layer_norm(vectors[slot]) - else: - if self.debug_use_cross_attention: - query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1) - else: - query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0) - vectors[slot], weights[slot] = getattr(self, "token_" + slot)( - query=query, - key=sequence_output.transpose(0, 1), - value=sequence_output.transpose(0, 1), - key_padding_mask=inverted_input_mask + inverted_usr_mask, - need_weights=True) # TODO: use usr_mask better or worse? 
- vectors[slot] = vectors[slot].squeeze(0) - #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize? - if self.debug_tanh_for_att_output: - #vectors[slot] = self.tanh(vectors[slot]) - vectors[slot] = self.token_layer_norm2(vectors[slot]) - vectors[slot] = self.tanh(self.tlinear(vectors[slot])) - #vectors[slot] += pooled_output # Residual - vectors[slot] = self.token_layer_norm(vectors[slot]) - #vectors[slot] = self.dropout_heads(vectors[slot]) - weights[slot] = weights[slot].squeeze(1) - if self.debug_stack_rep: - vectors[slot] = torch.cat((pooled_output, vectors[slot]), 1) - else: - if self.debug_sigmoid_sequence_tagging: - # Conditioned sequence tagging - sequence_output_add = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(1).expand(sequence_output.size()) - xxt = self.gelu(self.h1t(sequence_output)) - yyt = self.gelu(self.h2t(torch.cat((sequence_output_add, xxt), 2))) - weights[slot] = self.sigmoid(self.llt(yyt)).squeeze(2) - vectors[slot] = torch.matmul(weights[slot], sequence_output) - vectors[slot] = torch.diagonal(vectors[slot]).transpose(0, 1) - vectors[slot] = self.token_layer_norm(vectors[slot]) - else: - if self.debug_use_cross_attention: - query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1) - else: - query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0) - vectors[slot], weights[slot] = self.token_att( - query=query, - key=sequence_output.transpose(0, 1), - value=sequence_output.transpose(0, 1), - key_padding_mask=inverted_input_mask + inverted_usr_mask, - need_weights=True) # TODO: use usr_mask better or worse? - vectors[slot] = vectors[slot].squeeze(0) - if self.debug_use_cross_attention: - vectors[slot] = torch.mean(vectors[slot] * (batch_input_mask + usr_mask).transpose(0, 1).unsqueeze(-1), dim=0) - #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize? - if self.debug_tanh_for_att_output: - #vectors[slot] = self.tanh(vectors[slot]) - vectors[slot] = self.token_layer_norm2(vectors[slot]) - vectors[slot] = self.tanh(self.tlinear(vectors[slot])) - #vectors[slot] += pooled_output # Residual - vectors[slot] = self.token_layer_norm(vectors[slot]) - #vectors[slot] = self.dropout_heads(vectors[slot]) - if self.debug_use_cross_attention: - weights[slot] = torch.mean(weights[slot] * (batch_input_mask + usr_mask).unsqueeze(-1), dim=1) - weights[slot] = weights[slot].squeeze(1) - if self.debug_stack_rep: - vectors[slot] = torch.cat((pooled_output, vectors[slot]), 1) - - if mode == "represent": # Create vector representations only - return vectors, None, weights - - # TODO: establish proper format in labels already? 
- if inform_slot_id is not None: - inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() - if diag_state is not None: - diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) - - bbb_time = time.time() - total_loss = 0 - total_cl_loss = 0 - total_tk_loss = 0 - total_tp_loss = 0 - per_slot_per_example_loss = {} - per_slot_per_example_cl_loss = {} - per_slot_per_example_tk_loss = {} - per_slot_per_example_tp_loss = {} - per_slot_att_weights = {} - per_slot_class_logits = {} - per_slot_start_logits = {} - per_slot_end_logits = {} - per_slot_value_logits = {} - per_slot_refer_logits = {} - for s_itr, slot in enumerate(self.slot_list): - #if t_slot is not None and slot != t_slot: - # continue - if slot_ids is not None and slot_mask is not None: - encoded_slot_seq = encoded_slots_seq[s_itr] - encoded_slot_pooled = encoded_slots_pooled[s_itr] - else: - encoded_slot_seq = batch['encoded_slots_seq'][slot] - encoded_slot_pooled = batch['encoded_slots_pooled'][slot] - - # Attention for slot gates - if self.debug_att_output: - if self.debug_use_cross_attention: - query = encoded_slot_seq.squeeze().unsqueeze(0).expand(sequence_output.size()).transpose(0, 1) - else: - query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0) - att_output, c_weights = self.class_att( - query=query, - key=sequence_output.transpose(0, 1), - value=sequence_output.transpose(0, 1), - key_padding_mask=inverted_input_mask, - need_weights=True) - if self.debug_use_cross_attention: - att_output = torch.mean(att_output, dim=0) - c_weights = torch.mean(c_weights, dim=1) - if self.debug_tanh_for_att_output: - #att_output = self.tanh(att_output) - att_output = self.class_layer_norm2(att_output) - att_output = self.tanh(self.clinear(att_output)) - att_output = self.class_layer_norm(att_output) - att_output = self.dropout_heads(att_output) - per_slot_att_weights[slot] = c_weights.squeeze(1) - else: - per_slot_att_weights[slot] = None - - # Conditioned slot gate, or separate slot gates - if self.debug_joint_slot_gate: - if self.debug_att_output: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1))) - elif self.class_aux_feats_inform: - xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels)), 1))) - elif self.class_aux_feats_ds: - xx = self.gelu(self.h1c(torch.cat((att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1))) - else: - xx = self.gelu(self.h1c(att_output.squeeze(0))) - if self.debug_stack_att: - x0 = self.gelu(self.h0c(pooled_output)) - else: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - xx = self.gelu(self.h1c(torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1))) - elif self.class_aux_feats_inform: - xx = self.gelu(self.h1c(torch.cat((pooled_output, self.inform_projection(inform_labels)), 1))) - elif self.class_aux_feats_ds: - xx = self.gelu(self.h1c(torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1))) - else: - xx = self.gelu(self.h1c(pooled_output)) - if self.debug_att_output and not self.debug_simple_joint_slot_gate and self.debug_stack_att: - yy = self.gelu(self.h2c(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), x0, xx), 1))) - elif self.debug_att_output and self.debug_simple_joint_slot_gate and self.debug_stack_att: - yy = torch.cat((pooled_output, 
att_output.squeeze(0)), 1) - elif self.debug_att_output and self.debug_simple_joint_slot_gate: - yy = att_output.squeeze(0) - else: - yy = self.gelu(self.h2c(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), xx), 1))) - slot_gate_input = yy - slot_gate_layer = "llc" - else: - if self.debug_att_output: - if self.debug_stack_att: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1) - else: - slot_gate_input = torch.cat((pooled_output, att_output.squeeze(0)), 1) - else: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - slot_gate_input = torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - slot_gate_input = torch.cat((att_output.squeeze(0), self.inform_projection(inform_labels),), 1) - elif self.class_aux_feats_ds: - slot_gate_input = torch.cat((att_output.squeeze(0), self.ds_projection(diag_state_labels)), 1) - else: - slot_gate_input = att_output.squeeze(0) - else: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - slot_gate_input = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - slot_gate_input = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - slot_gate_input = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) - else: - slot_gate_input = pooled_output - slot_gate_layer = "class_" + slot - - # Conditioned refer gate, or separate refer gates - if self.debug_joint_slot_gate and self.debug_joint_refer_gate: - slot_refer_input = self.gelu(self.h2r(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), xx), 1))) - else: - if self.class_aux_feats_inform and self.class_aux_feats_ds: - slot_refer_input = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) - elif self.class_aux_feats_inform: - slot_refer_input = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) - elif self.class_aux_feats_ds: - slot_refer_input = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) - else: - slot_refer_input = pooled_output - - # Slot gate classification - if self.debug_sigmoid_slot_gates: - class_logits = [] - for cl in range(self.class_labels): - class_logits.append(self.sigmoid(getattr(self, slot_gate_layer + "_" + str(cl))(slot_gate_input))) - class_logits = torch.stack(class_logits, 1).squeeze(-1) - elif self.debug_att_slot_gates: - # TODO: implement separate gates as well - if cl_ids is not None and cl_mask is not None: - bla = encoded_classes.unsqueeze(1).expand(-1, pooled_output.size()[0], -1) - else: - bla = torch.stack(list(batch['encoded_classes'].values())).expand(-1, pooled_output.size()[0], -1) # TODO - #_, class_logits = self.slot_att( - _, class_logits = getattr(self, slot_gate_layer)( - query=slot_gate_input.unsqueeze(0), - key=bla, - value=bla, - need_weights=True) - class_logits = class_logits.squeeze(1) - else: - class_logits = getattr(self, 
slot_gate_layer)(slot_gate_input) - - class_logits = self.dropout_heads(class_logits) - - #token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output).squeeze(-1)) - token_weights = weights[slot] - - # --- - - if self.triplet_loss_weight > 0.0 and not self.debug_use_triplet_loss: - slot_values = torch.stack(list(batch['encoded_slot_values'][slot].values())) # TODO: filter? - slot_values = slot_values.expand((-1, pooled_output.size(0), -1)) - #yy = (batch['pos_sampling_input'][slot].sum(2) > 0.0) - #if self.debug_value_att_none_class: - # ww = batch['value_labels'][slot][:, 1:] * yy - #else: - # ww = batch['value_labels'][slot] * yy - #wwx = ww == 0 - #bbb = slot_values * wwx.transpose(0, 1).unsqueeze(2) - #wwy = ww == 1 - #yyy = batch['pos_sampling_input'][slot].expand(-1, slot_values.size(0), -1) * wwy.unsqueeze(2) - #slot_values = bbb + yyy.transpose(0, 1) - if self.debug_value_att_none_class: - slot_values = torch.cat((vectors[slot].unsqueeze(0), slot_values)) - #slot_values = slot_values.expand(-1, pooled_output.size()[0], -1) - #slot_values = torch.stack((batch['pos_sampling_input'][slot], batch['neg_sampling_input'][slot])) # TODO: filter? - #slot_values = torch.stack((vectors[slot].unsqueeze(1), batch['pos_sampling_input'][slot], batch['neg_sampling_input'][slot])) # TODO: filter? - #slot_values = slot_values.squeeze(2) - _, value_weights = self.value_att( - query=vectors[slot].unsqueeze(0), - key=slot_values, - value=slot_values, - need_weights=True) - ##vectors[slot] = vectors[slot].squeeze(0) - #vectors[slot] = torch.matmul(weights[slot], sequence_output).squeeze(1) - #vectors[slot] = vectors[slot] / torch.norm(vectors[slot], dim=1, keepdim=True) # Normalize? - #if self.debug_tanh_for_att_output: - #vectors[slot] = self.tanh(vectors[slot]) - #vectors[slot] = self.token_layer_norm2(vectors[slot]) - #vectors[slot] = self.tanh(self.tlinear(vectors[slot])) - #vectors[slot] += pooled_output # Residual - #vectors[slot] = self.token_layer_norm(vectors[slot]) - #vectors[slot] = self.dropout_heads(vectors[slot]) - value_weights = value_weights.squeeze(1) - else: - value_weights = None - - # --- - - # TODO: implement choice between joint_refer_gate and individual refer gates, analogous to - # slot gates. Use same input as slot gate, i.e., for joint case, use yy, for individual - # use pooled. This is stored in slot_gate_input. 
- - # --- - - if self.debug_joint_refer_gate: - if slot_ids is not None and slot_mask is not None: - refer_slots = encoded_slots_pooled.unsqueeze(1).expand(-1, pooled_output.size()[0], -1) - else: - #refer_slots = torch.stack((list(self.encoded_slots.values()))).expand(-1, pooled_output.size()[0], -1) - refer_slots = torch.stack(list(batch['encoded_slots_pooled'].values())).expand(-1, pooled_output.size()[0], -1) - _, refer_weights = self.refer_att( - query=slot_refer_input.unsqueeze(0), - key=refer_slots, - value=refer_slots, - need_weights=True) - refer_weights = refer_weights.squeeze(1) - refer_logits = refer_weights - else: - refer_logits = getattr(self, "refer_" + slot)(slot_refer_input) - - refer_logits = self.dropout_heads(refer_logits) - - per_slot_class_logits[slot] = class_logits - per_slot_start_logits[slot] = token_weights # TODO - per_slot_value_logits[slot] = value_weights - #per_slot_start_logits[slot] = self.sigmoid(token_logits) # TODO # this is for sigmoid approach - per_slot_refer_logits[slot] = refer_logits - - # If there are no labels, don't compute loss - if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: - # If we are on multi-GPU, split add a dimension - if len(start_pos[slot].size()) > 1: - start_pos[slot] = start_pos[slot].squeeze(-1) - - # TODO: solve this using the sequence_tagging def? - labels_clipped = torch.clamp(start_pos[slot].float(), min=0, max=1) - labels_clipped_scaled = labels_clipped / torch.clamp(labels_clipped.sum(1).unsqueeze(1), min=1) # Scale targets? - no_seq_mask = labels_clipped_scaled.sum(1) == 0 - no_seq_w = 1 / batch_input_mask.sum(1) - labels_clipped_scaled += batch_input_mask * (no_seq_mask * no_seq_w).unsqueeze(1) - #token_weights = self.sigmoid(token_logits) # TODO # this is for sigmoid approach - if self.token_loss_function == "mse": - token_loss = self.mse(token_weights, labels_clipped_scaled) # TODO: MSE might be better for scaled targets - else: - token_loss = self.binary_cross_entropy(token_weights, labels_clipped_scaled, reduction="none") - if self.ignore_o_tags: - token_loss *= labels_clipped # TODO: good idea? - - # TODO: do negative examples have to be balanced due to their large number? - token_loss = token_loss.sum(1) - token_is_pointable = (start_pos[slot].sum(1) > 0).float() - if not self.token_loss_for_nonpointable: - token_loss *= token_is_pointable - - value_loss = torch.zeros(token_is_pointable.size(), device=token_is_pointable.device) - if self.triplet_loss_weight > 0.0: - if self.debug_use_triplet_loss: - value_loss = self.compute_triplet_loss_att(vectors[slot], pos_sampling_input, neg_sampling_input, slot) - #triplet_loss = torch.clamp(triplet_loss, max=1) # TODO: parameterize # Not the best idea I think... - else: - value_labels_clipped = torch.clamp(value_labels[slot].float(), min=0, max=1) - value_labels_clipped /= torch.clamp(value_labels_clipped.sum(1).unsqueeze(1), min=1) # Scale targets? 
- value_no_seq_mask = value_labels_clipped.sum(1) == 0 - value_no_seq_w = 1 / value_labels_clipped.size(1) - value_labels_clipped += (value_no_seq_mask * value_no_seq_w).unsqueeze(1) - if self.value_loss_function == "mse": - value_loss = self.mse(value_weights, value_labels_clipped) # TODO: scale value_labels to also cover nonpointable cases and multitarget cases - else: - value_loss = self.binary_cross_entropy(value_weights, value_labels_clipped, reduction="none") - #print(slot) - #print(value_labels_clipped) - #print(value_weights) - #print(value_loss) - value_loss = value_loss.sum(1) - token_is_matchable = token_is_pointable - if self.debug_tag_none_target: - token_is_matchable *= (start_pos[slot][:, 1] == 0).float() - if not self.value_loss_for_nonpointable: - value_loss *= token_is_matchable - - # Re-definition necessary to make slot-independent prediction possible - self.refer_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=len(self.slot_list)) # Ignore 'none' target - refer_loss = self.refer_loss_fct(refer_logits, refer_id[slot]) - token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() - if not self.refer_loss_for_nonpointable: - refer_loss *= token_is_referrable - - if self.debug_sigmoid_slot_gates: - class_loss = [] - for cl in range(self.class_labels): - #class_loss.append(self.binary_cross_entropy(class_logits[:,cl], labels, reduction="none")) - class_loss.append(self.class_loss_fct(class_logits[:, cl], (class_label_id[slot] == cl).float(), reduction="none")) - class_loss = torch.stack(class_loss, 1) - if self.none_weight != 1.0: - class_loss *= self.clweights.to(outputs[0].device) - class_loss = class_loss.sum(1) - elif self.debug_att_slot_gates: - if self.class_loss_function == "mse": - class_loss = self.mse(class_logits, torch.nn.functional.one_hot(class_label_id[slot], self.class_labels).float()) - else: - class_loss = self.binary_cross_entropy(class_logits, torch.nn.functional.one_hot(class_label_id[slot], self.class_labels).float(), reduction="none") - if self.none_weight != 1.0: - class_loss *= self.clweights.to(outputs[0].device) - class_loss = class_loss.sum(1) - else: - class_loss = self.class_loss_fct(class_logits, class_label_id[slot]) - - #print("%15s, class loss: %.3f, token loss: %.3f, triplet loss: %.3f" % (slot, (class_loss.sum()).item(), (token_loss.sum()).item(), (triplet_loss.sum()).item())) - #print("%15s, class loss: %.3f, token loss: %.3f, value loss: %.3f" % (slot, (class_loss.sum()).item(), (token_loss.sum()).item(), (value_loss.sum()).item())) - - st_switch = int(not (self.sequence_tagging_dropout >= 1 and step is not None and step % self.sequence_tagging_dropout == 0)) - - if self.refer_index > -1: - #per_example_loss = (self.class_loss_ratio) * class_loss + st_switch * ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + self.triplet_loss_weight * triplet_loss - per_example_loss = (self.class_loss_ratio) * class_loss + st_switch * ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + self.triplet_loss_weight * value_loss - #per_example_loss = class_loss - else: - #per_example_loss = self.class_loss_ratio * class_loss + st_switch * (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * triplet_loss - per_example_loss = self.class_loss_ratio * class_loss + st_switch * (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * value_loss - #if epoch is not None and epoch > 20: - # per_example_loss = 
self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss + self.triplet_loss_weight * triplet_loss - #else: - # per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss - #per_example_loss = class_loss - if self.debug_use_tlf: - per_example_loss *= 1.0 + ((batch['diag_len'] - batch['turn_id']) * 0.05) - - total_loss += per_example_loss.sum() - total_cl_loss += class_loss.sum() - total_tk_loss += token_loss.sum() - #total_tp_loss += triplet_loss.sum() - total_tp_loss += value_loss.sum() - per_slot_per_example_loss[slot] = per_example_loss - per_slot_per_example_cl_loss[slot] = class_loss - per_slot_per_example_tk_loss[slot] = token_loss - #per_slot_per_example_tp_loss[slot] = triplet_loss - per_slot_per_example_tp_loss[slot] = value_loss - ccc_time = time.time() - #print(bbb_time - aaa_time, ccc_time - bbb_time) # 0.028620243072509766 - - # add hidden states and attention if they are here - outputs = (total_loss, - total_cl_loss, - total_tk_loss, - total_tp_loss, - per_slot_per_example_loss, - per_slot_per_example_cl_loss, - per_slot_per_example_tk_loss, - per_slot_per_example_tp_loss, - per_slot_class_logits, - per_slot_start_logits, - per_slot_end_logits, - per_slot_value_logits, - per_slot_refer_logits, - per_slot_att_weights,) + (vectors, weights,) + (pooled_output,) # + outputs[2:] - #outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits, per_slot_att_weights,) + (vectors, weights,) # + outputs[2:] - - return outputs diff --git a/convlab/dst/trippyr/multiwoz/trippyr.py b/convlab/dst/trippyr/multiwoz/trippyr.py deleted file mode 100644 index f89de78cc1024d5fce60767addc3ce803f3e23ad..0000000000000000000000000000000000000000 --- a/convlab/dst/trippyr/multiwoz/trippyr.py +++ /dev/null @@ -1,793 +0,0 @@ -# Copyright 2021 Heinrich Heine University Duesseldorf -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
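# A minimal sketch of the per-slot loss mixing in the modeling code above (the
# refer_index > -1 branch). The hyperparameter values are assumed for illustration;
# st_switch is 0 on steps where sequence_tagging_dropout drops the token loss.
import torch

def mix_losses(class_loss, token_loss, refer_loss, value_loss,
               class_loss_ratio=0.8, triplet_loss_weight=1.0, st_switch=1):
    # The weight remaining after the slot-gate loss is split evenly between the
    # token (span) loss and the refer loss; the value loss is weighted separately.
    side = (1.0 - class_loss_ratio) / 2
    return (class_loss_ratio * class_loss
            + st_switch * side * token_loss
            + side * refer_loss
            + triplet_loss_weight * value_loss)

per_example_loss = mix_losses(torch.rand(8), torch.rand(8), torch.rand(8), torch.rand(8))
total_loss = per_example_loss.sum()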
- -import os -import re -import json -import copy -import pickle - -import torch -from transformers import (RobertaConfig, RobertaTokenizer) - -from convlab.dst.trippyr.multiwoz.modeling_roberta_dst import (RobertaForDST) - -from convlab.dst.dst import DST -from convlab.util.multiwoz.state import default_state -from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA -from convlab.nlu.jointBERT.multiwoz import BERTNLU -from convlab.nlg.template.multiwoz import TemplateNLG -from convlab.dst.rule.multiwoz import normalize_value - -from convlab.util import relative_import_module_from_unified_datasets -ONTOLOGY = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'ontology') -TEMPLATE_STATE = ONTOLOGY['state'] - -import pdb -import time - - -MODEL_CLASSES = { - 'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer), -} - - -SLOT_MAP_TRIPPY_TO_UDF = { - 'hotel': { - 'pricerange': 'price range', - 'book_stay': 'book stay', - 'book_day': 'book day', - 'book_people': 'book people', - 'addr': 'address', - 'post': 'postcode', - 'price': 'price range', - 'people': 'book people' - }, - 'restaurant': { - 'pricerange': 'price range', - 'book_time': 'book time', - 'book_day': 'book day', - 'book_people': 'book people', - 'addr': 'address', - 'post': 'postcode', - 'price': 'price range', - 'people': 'book people' - }, - 'taxi': { - 'arriveBy': 'arrive by', - 'leaveAt': 'leave at', - 'arrive': 'arrive by', - 'leave': 'leave at', - 'car': 'type', - 'car type': 'type', - 'depart': 'departure', - 'dest': 'destination' - }, - 'train': { - 'arriveBy': 'arrive by', - 'leaveAt': 'leave at', - 'book_people': 'book people', - 'arrive': 'arrive by', - 'leave': 'leave at', - 'depart': 'departure', - 'dest': 'destination', - 'id': 'train id', - 'people': 'book people', - 'time': 'duration', - 'ticket': 'price', - 'trainid': 'train id' - }, - 'attraction': { - 'post': 'postcode', - 'addr': 'address', - 'fee': 'entrance fee', - 'price': 'entrance fee' - }, - 'general': {}, - 'hospital': { - 'post': 'postcode', - 'addr': 'address' - }, - 'police': { - 'post': 'postcode', - 'addr': 'address' - } -} - - -class TRIPPYR(DST): - def print_header(self): - print(" _________ ________ ___ ________ ________ ___ ___ ________ ") - print("|\___ ___\\\ __ \|\ \|\ __ \|\ __ \|\ \ / /| |\ __ \ ") - print("\|___ \ \_\ \ \|\ \ \ \ \ \|\ \ \ \|\ \ \ \/ / /______\ \ \|\ \ ") - print(" \ \ \ \ \ _ _\ \ \ \ ____\ \ ____\ \ / /\_______\ \ _ _\ ") - print(" \ \ \ \ \ \\\ \\\ \ \ \ \___|\ \ \___|\/ / /\|_______|\ \ \\\ \| ") - print(" \ \__\ \ \__\\\ _\\\ \__\ \__\ \ \__\ __/ / / \ \__\\\ _\ ") - print(" \|__| \|__|\|__|\|__|\|__| \|__||\___/ / \|__|\|__|") - print(" (c) 2022 Heinrich Heine University \|___|/ ") - print() - - def print_dialog(self, hst): - #print("Dialogue %s, turn %s:" % (self.global_diag_cnt, int(len(hst) / 2) - 1)) - print("Dialogue %s, turn %s:" % (self.global_diag_cnt, self.global_turn_cnt)) - for utt in hst[:-2]: - print(" \033[92m%s\033[0m" % (utt)) - if len(hst) > 1: - print(" ", hst[-2]) - print(" ", hst[-1]) - - def print_inform_memory(self, inform_mem): - print("Inform memory:") - is_all_none = True - for s in inform_mem: - if inform_mem[s] != 'none': - print(" %s = %s" % (s, inform_mem[s])) - is_all_none = False - if is_all_none: - print(" -") - - def eval_user_acts(self, user_act, user_acts): - print("User acts:") - for ua in user_acts: - if ua not in user_act: - print(" \033[33m%s\033[0m" % (ua)) - else: - print(" \033[92m%s\033[0m" % (ua)) - for ua in user_act: - 
if ua not in user_acts: - print(" \033[91m%s\033[0m" % (ua)) - - def eval_dialog_state(self, state_updates, new_belief_state): - print("Dialogue state:") - for d in self.gt_belief_state: - print(" %s:" % (d)) - for s in new_belief_state[d]: - is_printed = False - is_updated = False - if state_updates[d][s] > 0: - is_updated = True - if is_updated: - print("\033[3m", end='') - if new_belief_state[d][s] != self.gt_belief_state[d][s]: - self.global_eval_stats[d][s]['FP'] += 1 - if self.gt_belief_state[d][s] == '': - print(" \033[33m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='') - else: - print(" \033[91m%s: %s\033[0m (label: %s)" % (s, new_belief_state[d][s] if new_belief_state[d][s] != '' else 'none', self.gt_belief_state[d][s]), end='') - self.global_eval_stats[d][s]['FN'] += 1 - is_printed = True - elif new_belief_state[d][s] != '': - print(" \033[92m%s: %s\033[0m" % (s, new_belief_state[d][s]), end='') - self.global_eval_stats[d][s]['TP'] += 1 - is_printed = True - if is_updated: - print(" (%s)" % (self.config.dst_class_types[state_updates[d][s]])) - elif is_printed: - print() - - def eval_print_stats(self): - print("Statistics:") - for d in self.global_eval_stats: - for s in self.global_eval_stats[d]: - TP = self.global_eval_stats[d][s]['TP'] - FP = self.global_eval_stats[d][s]['FP'] - FN = self.global_eval_stats[d][s]['FN'] - prec = TP / ( TP + FP + 1e-8) - rec = TP / ( TP + FN + 1e-8) - f1 = 2 * ((prec * rec) / (prec + rec + 1e-8)) - print(" %s %s Recall: %.2f, Precision: %.2f, F1: %.2f" % (d, s, rec, prec, f1)) - - def __init__(self, model_type="roberta", - model_name="roberta-base", - model_path="", - nlu_path="", - emb_path="", - no_eval=False, - no_history=False, - no_normalize_value=False, - no_smoothing=False, - gt_user_acts=False, - gt_ds=False, - gt_request_acts=False, - fp16=False): - super(TRIPPYR, self).__init__() - - self.print_header() - - self.model_type = model_type.lower() - self.model_name = model_name.lower() - self.model_path = model_path - self.nlu_path = nlu_path - self.emb_path = emb_path - self.no_eval = no_eval - self.no_history = no_history - self.no_normalize_value = no_normalize_value - self.no_smoothing = no_smoothing - self.gt_user_acts = gt_user_acts - self.gt_ds = gt_ds - self.gt_request_acts = gt_request_acts - self.fp16 = fp16 - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type] - self.config = self.config_class.from_pretrained(self.model_path, local_files_only=True) # TODO: parameterize - # TODO: update config (parameters) - - path = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) - path = os.path.join(path, 'data/multiwoz/value_dict.json') - self.value_dict = json.load(open(path)) - - self.global_eval_stats = copy.deepcopy(TEMPLATE_STATE) - for d in self.global_eval_stats: - for s in self.global_eval_stats[d]: - self.global_eval_stats[d][s] = {'TP': 0, 'FP': 0, 'FN': 0} - self.global_diag_cnt = -3 - self.global_turn_cnt = -1 - - self.load_weights() - self.load_embeddings(self.emb_path, self.fp16) - - def load_weights(self): - self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name, local_files_only=True) # TODO: do_lower_case=args.do_lower_case ? 
# TODO: parameterize - self.model = self.model_class.from_pretrained(self.model_path, config=self.config, local_files_only=True) # TODO: parameterize - self.model.to(self.device) - self.model.eval() - self.nlu = BERTNLU(model_file=self.nlu_path) # This is used for internal evaluation - self.nlg_usr = TemplateNLG(is_user=True) - self.nlg_sys = TemplateNLG(is_user=False) - - def load_embeddings(self, emb_path, fp16=False): - self.encoded_slots_pooled = pickle.load(open(os.path.join(emb_path, "encoded_slots_pooled.pickle"), "rb")) - self.encoded_slots_seq = pickle.load(open(os.path.join(emb_path, "encoded_slots_seq.pickle"), "rb")) - self.encoded_slot_values = pickle.load(open(os.path.join(emb_path, "encoded_slot_values_test.pickle"), "rb")) # TODO: maybe load a clean list in accordance with ontology? - if fp16: - for e in self.encoded_slots_pooled: - self.encoded_slots_pooled[e] = self.encoded_slots_pooled[e].type(torch.float32) - for e in self.encoded_slots_seq: - self.encoded_slots_seq[e] = self.encoded_slots_seq[e].type(torch.float32) - for e in self.encoded_slot_values: - for f in self.encoded_slot_values[e]: - self.encoded_slot_values[e][f] = self.encoded_slot_values[e][f].type(torch.float32) - - def init_session(self): - self.state = default_state() # Initialise as empty state - self.state['belief_state'] = copy.deepcopy(TEMPLATE_STATE) - self.nlg_history = [] - self.gt_belief_state = copy.deepcopy(TEMPLATE_STATE) - self.global_diag_cnt += 1 - self.global_turn_cnt = -1 - - def update_gt_belief_state(self, user_act): - for intent, domain, slot, value in user_act: - if domain == 'police': - continue - if intent == 'inform': - if slot == 'none' or slot == '': - continue - domain_dic = self.gt_belief_state[domain] - if slot in domain_dic: - #nvalue = normalize_value(self.value_dict, domain, slot, value) - self.gt_belief_state[domain][slot] = value # nvalue - #elif slot != 'none' or slot != '': - # raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain)) - - # TODO: receive semantic, convert semantic -> text -> semantic for sanity check - # For TripPy: receive semantic, convert semantic -> text (with context) as input to DST - # - allows for accuracy estimates - # - allows isolating inform prediction from request prediction (as can be taken from input for sanity check) - def update(self, user_act=''): - def normalize_values(text): - text_to_num = {"zero": "0", "one": "1", "me": "1", "two": "2", "three": "3", "four": "4", "five": "5", "six": "6", "seven": "7"} - #text = re.sub("^(\d{2}) : (\d{2})$", r"\1:\2", text) # Times - #text = re.sub(" ?' ?s", "s", text) # Genitive - text = re.sub("\s*(\W)\s*", r"\1" , text) # Re-attach special characters - text = re.sub("s'([^s])", r"s' \1", text) # Add space after plural genitive apostrophe - if text in text_to_num: - text = text_to_num[text] - return text - - def filter_sequences(seqs, mode="first"): - if mode == "first": - return seqs[0][0][0] - elif mode == "max_first": - max_conf = 0 - max_idx = 0 - for e_itr, e in enumerate(seqs[0]): - if e[1] > max_conf: - max_conf = e[1] - max_idx = e_itr - return seqs[0][max_idx][0] - elif mode == "max": - max_conf = 0 - max_t_idx = 0 - for t_itr, t in enumerate(seqs): - for e_itr, e in enumerate(t): - if e[1] > max_conf: - max_conf = e[1] - max_t_idx = t_itr - max_idx = e_itr - return seqs[max_t_idx][max_idx][0] - else: - print("WARN: mode %s unknown. Aborting." 
% mode) - exit() - - prev_state = self.state - - if not self.no_eval: - print("-" * 40) - - #nlg_history = [] - ##for h in prev_state['history'][-2:]: # TODO: make this an option? - #for h in prev_state['history']: - # nlg_history.append([h[0], self.get_text(h[1], is_user=(h[0]=='user'), normalize=True)]) - ## Special case: at the beginning of the dialog the history might be empty (depending on policy) - #if len(nlg_history) == 0: - # nlg_history.append(['sys', self.get_text(prev_state['system_action'], is_user=False, normalize=True)]) - # nlg_history.append(['user', self.get_text(prev_state['user_action'], is_user=True, normalize=True)]) - if self.no_history: - self.nlg_history = [] - self.nlg_history.append(['sys', self.get_text(prev_state['system_action'], is_user=False, normalize=True)]) - self.nlg_history.append(['user', self.get_text(prev_state['user_action'], is_user=True, normalize=True)]) - self.global_turn_cnt += 1 - if not self.no_eval: - self.print_dialog(self.nlg_history) - - # --- Get inform memory and auxiliary features --- - - # If system_action is plain text, get acts using NLU - if isinstance(prev_state['user_action'], str): - u_acts, s_acts = self.get_acts() - elif isinstance(prev_state['user_action'], list): - u_acts = user_act # same as prev_state['user_action'] - s_acts = prev_state['system_action'] - else: - raise Exception('Unknown format for user action:', prev_state['user_action']) - inform_mem = self.get_inform_mem(s_acts) - if not self.no_eval: - self.print_inform_memory(inform_mem) - #printed_inform_mem = False - #for s in inform_mem: - # if inform_mem[s] != 'none': - # if not printed_inform_mem: - # print("DST: inform_mem:") - # print(s, ':', inform_mem[s]) - # printed_inform_mem = True - - # --- Tokenize dialogue context and feed DST model --- - - features = self.get_features(self.nlg_history) - pred_states, pred_classes, cls_representation = self.predict(features, inform_mem) - - # --- Update ConvLab-style dialogue state --- - - new_belief_state = copy.deepcopy(prev_state['belief_state']) - user_acts = [] - for state, value in pred_states.items(): - if isinstance(value, list): - value = filter_sequences(value, mode="max") - value = normalize_values(value) - if value == 'none': - continue - domain, slot = state.split('-', 1) - # Value normalization # TODO: according to trippy rules? 
- if domain == 'hotel' and slot == 'type': - value = "hotel" if value == "yes" else "guesthouse" - if not self.no_normalize_value: - value = normalize_value(self.value_dict, domain, slot, value) - slot = SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot) - if slot in new_belief_state[domain]: - new_belief_state[domain][slot] = value - user_acts.append(['inform', domain, SLOT_MAP_TRIPPY_TO_UDF[domain].get(slot, slot), value]) - else: - raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain)) - - self.update_gt_belief_state(u_acts) # For evaluation - - # BELIEF STATE UPDATE - new_state = copy.deepcopy(dict(prev_state)) - new_state['belief_state'] = new_belief_state # TripPy-R - if self.gt_ds: - new_state['belief_state'] = self.gt_belief_state # Rule - - state_updates = {} - for cl in pred_classes: - cl_d, cl_s = cl.split('-') - # Some reformatting for the evaluation further down - if cl_d not in state_updates: - state_updates[cl_d] = {} - state_updates[cl_d][SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s)] = pred_classes[cl] - # We care only about the requestable slots here - if self.config.dst_class_types[pred_classes[cl]] != 'request': - continue - if cl_d != 'general' and cl_s == 'none': - user_acts.append(['inform', cl_d, '', '']) - elif cl_d == 'general': - user_acts.append([SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), 'general', '', '']) - #user_acts.append(['bye', 'general', '', '']) # Map "thank" to "bye"? Mind "hello" as well! - elif not self.gt_request_acts: - user_acts.append(['request', cl_d, SLOT_MAP_TRIPPY_TO_UDF[cl_d].get(cl_s, cl_s), '']) - - # USER ACTS UPDATE - new_state['user_action'] = user_acts # TripPy-R - # ONLY FOR DEBUGGING - if self.gt_user_acts: - new_state['user_action'] = u_acts # Rule - elif self.gt_request_acts: - for e in u_acts: - ea, _, _, _ = e - if ea == 'request': - user_acts.append(e) - - if not self.no_eval: - self.eval_user_acts(u_acts, user_acts) - self.eval_dialog_state(state_updates, new_belief_state) - - #new_state['cls_representation'] = cls_representation # TODO: needed by Nunu? 
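# A minimal sketch of the slot-name mapping applied in update() above: predicted
# "domain-slot" keys are translated into the unified dataset schema via
# SLOT_MAP_TRIPPY_TO_UDF; the two-entry map below is an illustrative excerpt.
slot_map = {'hotel': {'pricerange': 'price range', 'book_people': 'book people'}}

def to_unified(state_key):
    domain, slot = state_key.split('-', 1)  # e.g. 'hotel-pricerange'
    return domain, slot_map.get(domain, {}).get(slot, slot)

assert to_unified('hotel-pricerange') == ('hotel', 'price range')
assert to_unified('hotel-area') == ('hotel', 'area')  # unmapped slots pass through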
- - self.state = new_state - - # Print eval statistics - if self.state['terminated'] and not self.no_eval: - print("Booked:", self.state['booked']) - self.eval_print_stats() - print("=" * 10, "End of the dialogue", "=" * 10) - - return self.state - - def predict(self, features, inform_mem): - def _tokenize(text): - if "\u0120" in text: - text = re.sub(" ", "", text) - text = re.sub("\u0120", " ", text) - text = text.strip() - return text - #return ' '.join([tok for tok in map(str.strip, re.split("(\W+)", text)) if len(tok) > 0]) - - def get_spans(pred, norm_logits, input_tokens, usr_utt_spans): - span_indices = [i for i in range(len(pred)) if pred[i]] - prev_si = None - spans = [] - #confs = [] - for si in span_indices: - if prev_si is None or si - prev_si > 1: - spans.append(([], [], [])) - #confs.append([]) - #spans[-1].append(input_tokens[si]) - spans[-1][0].append(si) - spans[-1][1].append(input_tokens[si]) - spans[-1][2].append(norm_logits[si]) - #confs[-1].append(norm_logits[si]) - prev_si = si - #spans = [' '.join(t for t in s) for s in spans] - spans = [(min(i), max(i), ' '.join(t for t in s), (sum(c) / len(c)).item()) for (i, s, c) in spans] - #confs = [(sum(c) / len(c)).item() for c in confs] - final_spans = {} - for s in spans: - for us_itr, us in enumerate(usr_utt_spans): - if s[0] >= us[0] and s[1] <= us[1]: - if us_itr not in final_spans: - final_spans[us_itr] = [] - final_spans[us_itr].append(s[2:]) - break - final_spans = list(final_spans.values()) - return final_spans # , confs - - def get_usr_utt_spans(usr_mask): - span_indices = [i for i in range(len(usr_mask)) if usr_mask[i]] - prev_si = None - spans = [] - for si in span_indices: - if prev_si is None or si - prev_si > 1: - spans.append([]) - spans[-1].append(si) - prev_si = si - spans = [[min(s), max(s)] for s in spans] - return spans - - def smooth_roberta_predictions(pred, input_tokens): - smoothed_pred = pred.detach().clone() - # Forward - span = False - for i in range(len(pred)): - if pred[i] > 0: - span = True - elif span and input_tokens[i][0] != "\u0120" and input_tokens[i][0] != "<": - smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens - elif span and (input_tokens[i][0] == "\u0120" or input_tokens[i][0] == "<"): - span = False - # Backward - span = False - for i in range(len(pred) - 1, 0, -1): - if pred[i] > 0: - span = True - if span and input_tokens[i][0] != "\u0120" and input_tokens[i][0] != "<": - smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens - elif span and input_tokens[i][0] == "\u0120": - smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens - span = False - #if pred != smoothed_pred: - # print(get_spans(pred, input_tokens)) - # print(get_spans(smoothed_pred, input_tokens)) - return smoothed_pred - - #aaa_time = time.time() - with torch.no_grad(): - outputs = self.model(features) # TODO: mode etc - #bbb_time = time.time() - - input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked! 
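# A minimal sketch of the span grouping done by get_usr_utt_spans() above:
# consecutive positive token indices are merged into [start, end] spans.
def group_spans(indices):
    spans, prev = [], None
    for i in indices:
        if prev is None or i - prev > 1:
            spans.append([i, i])  # start a new span
        else:
            spans[-1][1] = i      # extend the current span
        prev = i
    return spans

assert group_spans([2, 3, 4, 8, 9]) == [[2, 4], [8, 9]]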
- - # assign identified spans to their respective usr turns (simply append spans as list of lists) - usr_utt_spans = get_usr_utt_spans(features['usr_mask'][0][1:]) - - per_slot_class_logits = outputs[8] # [2] - per_slot_start_logits = outputs[9] # [3] - per_slot_end_logits = outputs[10] # [4] - per_slot_value_logits = outputs[11] # [4] - per_slot_refer_logits = outputs[12] # [5] - - cls_representation = outputs[16] - - # TODO: maybe add assert to check that batch=1 - - predictions = {slot: 'none' for slot in self.config.dst_slot_list} - class_predictions = {slot: 'none' for slot in self.config.dst_slot_list} - - for slot in self.config.dst_slot_list: - class_logits = per_slot_class_logits[slot][0] - start_logits = per_slot_start_logits[slot][0] - end_logits = per_slot_end_logits[slot][0] if slot in per_slot_end_logits else None - value_logits = per_slot_value_logits[slot][0] if per_slot_value_logits[slot] is not None else None - refer_logits = per_slot_refer_logits[slot][0] - - weights = start_logits[:len(features['input_ids'][0])] - norm_logits = torch.div(torch.clamp(weights - torch.mean(weights), min=0), torch.max(weights)) - - class_prediction = int(class_logits.argmax()) - start_prediction = norm_logits > 0.0 - if not self.no_smoothing: - start_prediction = smooth_roberta_predictions(start_prediction, input_tokens) # Slow! - start_prediction[0] = False # Ignore <s> - end_prediction = int(end_logits.argmax()) if end_logits is not None else None - refer_prediction = int(refer_logits.argmax()) - - if class_prediction == self.config.dst_class_types.index('dontcare'): - predictions[slot] = 'dontcare' - elif class_prediction == self.config.dst_class_types.index('copy_value'): - spans = get_spans(start_prediction[1:], norm_logits[1:], input_tokens[1:], usr_utt_spans) - if len(spans) > 0: - for e_itr in range(len(spans)): - for ee_itr in range(len(spans[e_itr])): - tmp = list(spans[e_itr][ee_itr]) - tmp[0] = _tokenize(tmp[0]) - spans[e_itr][ee_itr] = tuple(tmp) - predictions[slot] = spans - else: - predictions[slot] = "none" - elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'): - predictions[slot] = "yes" # 'true' - elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'): - predictions[slot] = "no" # 'false' - elif class_prediction == self.config.dst_class_types.index('inform'): - #print("INFORM:", slot, ",", predictions[slot], "->", inform_mem[slot]) - predictions[slot] = inform_mem[slot] - #elif class_prediction == self.config.dst_class_types.index('request'): - # if slot in ["hotel-internet", "hotel-parking"]: - # predictions[slot] = "yes" # 'true' - # Referral case is handled below - - # Referral case. All other slot values need to be seen first in order - # to be able to do this correctly. - for slot in self.config.dst_slot_list: - class_logits = per_slot_class_logits[slot][0] - refer_logits = per_slot_refer_logits[slot][0] - - class_prediction = int(class_logits.argmax()) - refer_prediction = int(refer_logits.argmax()) - - if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'): - # Only slots that have been mentioned before can be referred to. - # First try to resolve a reference within the same turn. (One can think of a situation - # where one slot is referred to in the same utterance. 
This phenomenon is however - # currently not properly covered in the training data label generation process) - # Then try to resolve a reference given the current dialogue state. - predictions[slot] = predictions[list(self.config.dst_slot_list.keys())[refer_prediction]] - if predictions[slot] == 'none': - referred_slot = list(self.config.dst_slot_list.keys())[refer_prediction] - referred_slot_d, referred_slot_s = referred_slot.split('-') - referred_slot_s = SLOT_MAP_TRIPPY_TO_UDF[referred_slot_d].get(referred_slot_s, referred_slot_s) - #pdb.set_trace() - if self.state['belief_state'][referred_slot_d][referred_slot_s] != '': - predictions[slot] = self.state['belief_state'][referred_slot_d][referred_slot_s] - if predictions[slot] == 'none': - ref_slot = list(self.config.dst_slot_list.keys())[refer_prediction] - if ref_slot == 'hotel-name': - predictions[slot] = 'the hotel' - elif ref_slot == 'restaurant-name': - predictions[slot] = 'the restaurant' - elif ref_slot == 'attraction-name': - predictions[slot] = 'the attraction' - elif ref_slot == 'hotel-area': - predictions[slot] = 'same area as the hotel' - elif ref_slot == 'restaurant-area': - predictions[slot] = 'same area as the restaurant' - elif ref_slot == 'attraction-area': - predictions[slot] = 'same area as the attraction' - elif ref_slot == 'hotel-pricerange': - predictions[slot] = 'in the same price range as the hotel' - elif ref_slot == 'restaurant-pricerange': - predictions[slot] = 'in the same price range as the restaurant' - - class_predictions[slot] = class_prediction - - # TODO: value normalization - # TODO: value matching - #ccc_time = time.time() - #print("TIME:", bbb_time - aaa_time, ccc_time - bbb_time) - - return predictions, class_predictions, cls_representation - - def get_features(self, context): - def to_device(batch, device): - if isinstance(batch, tuple): - batch_on_device = tuple([to_device(element, device) for element in batch]) - if isinstance(batch, dict): - batch_on_device = {k: to_device(v, device) for k, v in batch.items()} - else: - batch_on_device = batch.to(device) if batch is not None else batch - return batch_on_device - - assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models - input_tokens = ['<s>'] - e_itr = 0 - for e_itr, e in enumerate(reversed(context)): - input_tokens.append(e[1] if e[1] != 'null' else ' ') - if e_itr < 2: - input_tokens.append('</s> </s>') - else: - input_tokens.append('</s>') - # Ignore history for now - #if e_itr == 1: - # break - if e_itr == 0: - input_tokens.append('</s> </s>') - input_tokens.append('</s>') - input_tokens = ' '.join(input_tokens) - - # TODO: delex sys utt somehow, or refrain from using delex for sys utts? 
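# Illustration of the serialization built above, assuming a four-turn history
# [['sys', s1], ['user', u1], ['sys', s2], ['user', u2]]: turns are appended
# newest-first, the two most recent each followed by '</s> </s>' and older
# turns by a single '</s>':
#
#   '<s> u2 </s> </s> s2 </s> </s> u1 </s> s1 </s>'
#
# encode_plus is then called with add_special_tokens=False because the
# separator tokens are already present in the string.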
- features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=False, max_length=self.config.dst_max_seq_length) - - input_ids = torch.tensor(features['input_ids']).reshape(1,-1) - input_mask = torch.tensor(features['attention_mask']).reshape(1,-1) - usr_mask = torch.zeros(input_ids.size()) - usr_seen = False - sys_seen = False - usr_sep = 0 - sys_sep = 0 - hst_cnt = 0 - for i_itr, i in enumerate(input_ids[0,:]): - if i_itr == 0: - continue - is_usr = True - if i == 1: - is_usr = False - if i == 2: - is_usr = False - if not usr_seen: - usr_sep += 1 - if usr_sep == 2: - usr_seen = True - elif not sys_seen: - sys_sep += 1 - if sys_sep == 2: - sys_seen = True - else: - hst_cnt += 1 - if usr_seen and not sys_seen: - is_usr = False - elif usr_seen and sys_seen and hst_cnt % 2 == 1: - is_usr = False - if is_usr: - usr_mask[0,i_itr] = 1 - #usr_mask = torch.tensor(features['attention_mask']).reshape(1,-1) # TODO - features = {'input_ids': input_ids, - 'input_mask': input_mask, - 'usr_mask': usr_mask, - 'start_pos': None, - 'end_pos': None, - 'refer_id': None, - 'class_label_id': None, - 'inform_slot_id': None, - 'diag_state': None, - 'pos_sampling_input': None, - 'neg_sampling_input': None, - 'encoded_slots_pooled': self.encoded_slots_pooled, - 'encoded_slots_seq': self.encoded_slots_seq, - 'encoded_slot_values': self.encoded_slot_values} - - return to_device(features, self.device) - - # TODO: consider "booked" values? - def get_inform_mem(self, state): - inform_mem = {slot: 'none' for slot in self.config.dst_slot_list} - for e in state: - a, d, s, v = e - if a in ['inform', 'recommend', 'select', 'book', 'offerbook']: - #ds_d = d.lower() - #if s in REF_SYS_DA[d]: - # ds_s = REF_SYS_DA[d][s] - #elif s in REF_SYS_DA['Booking']: - # ds_s = "book_" + REF_SYS_DA['Booking'][s] - #else: - # ds_s = s.lower() - # #raise Exception('Slot <{}> of domain <{}> unknown'.format(s, d)) - slot = "%s-%s" % (d, s) - if slot in inform_mem: - inform_mem[slot] = v - return inform_mem - - def get_acts(self): - context = self.state['history'] - if context[-1][0] != 'user': - raise Exception("Wrong order of utterances, check your input.") - system_act = context[-2][-1] - user_act = context[-1][-1] - system_context = [t for s,t in context[:-2]] - user_context = [t for s,t in context[:-1]] - - #print(" SYS:", system_act, system_context) - system_acts = self.nlu.predict(system_act, context=system_context) - - #print(" USR:", user_act, user_context) - user_acts = self.nlu.predict(user_act, context=user_context) - - return user_acts, system_acts - - def get_text(self, act, is_user=False, normalize=False): - if act == 'null': - return 'null' - if not isinstance(act, list): - result = act - elif is_user: - result = self.nlg_usr.generate(act) - else: - result = self.nlg_sys.generate(act) - if normalize: - return self.normalize_text(result) - else: - return result - - def normalize_text(self, text): - norm_text = text.lower() - #norm_text = re.sub("n't", " not", norm_text) # Does not make much of a difference - #norm_text = re.sub("ca not", "cannot", norm_text) - norm_text = ' '.join([tok for tok in map(str.strip, re.split("(\W+)", norm_text)) if len(tok) > 0]) - return norm_text - - -# if __name__ == "__main__": -# tracker = TRIPPY(model_type='roberta', model_path='/path/to/model', -# nlu_path='/path/to/nlu') -# tracker.init_session() -# state = tracker.update('hey. I need a cheap restaurant.') -# tracker.state['history'].append(['usr', 'hey. 
I need a cheap restaurant.']) -# tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?']) -# state = tracker.update('If you have something Asian that would be great.') -# tracker.state['history'].append(['usr', 'If you have something Asian that would be great.']) -# tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.']) -# state = tracker.update('Great. Where are they located?') -# tracker.state['history'].append(['usr', 'Great. Where are they located?']) -# print(tracker.state) diff --git a/convlab/dst/trippyr/multiwoz/trippyr.py~ b/convlab/dst/trippyr/multiwoz/trippyr.py~ deleted file mode 100644 index fc5696227ffbd6de639af9eda92983b77c41059e..0000000000000000000000000000000000000000 --- a/convlab/dst/trippyr/multiwoz/trippyr.py~ +++ /dev/null @@ -1,521 +0,0 @@ -# Copyright 2021 Heinrich Heine University Duesseldorf -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import re -import json -import copy -import pickle - -import torch -from transformers import (RobertaConfig, RobertaTokenizer) - -from convlab2.dst.trippyr.multiwoz.modeling_roberta_dst import (RobertaForDST) - -from convlab2.dst.dst import DST -from convlab2.util.multiwoz.state import default_state -from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA -from convlab2.nlu.jointBERT.multiwoz import BERTNLU -from convlab2.dst.rule.multiwoz import normalize_value - -MODEL_CLASSES = { - 'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer), -} - - -class TRIPPYR(DST): - def __init__(self, model_type="roberta", model_name="roberta-base", model_path="", nlu_path="", emb_path="", fp16=False): - super(TRIPPYR, self).__init__() - - self.model_type = model_type.lower() - self.model_name = model_name.lower() - self.model_path = model_path - self.nlu_path = nlu_path - - path = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) - path = os.path.join(path, 'data/multiwoz/value_dict.json') - self.value_dict = json.load(open(path)) - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type] - self.config = self.config_class.from_pretrained(self.model_path) - # TODO: update config (parameters) - - self.load_weights() - self.load_embeddings(emb_path, fp16) - - def load_weights(self): - self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name) # TODO: do_lower_case=args.do_lower_case ? 
- self.model = self.model_class.from_pretrained(self.model_path, config=self.config) - self.model.to(self.device) - self.model.eval() - self.nlu = BERTNLU(model_file=self.nlu_path) # TODO: remove, once TripPy takes over its task - - def load_embeddings(self, emb_path, fp16=False): - self.encoded_slots_pooled = pickle.load(open(os.path.join(emb_path, "encoded_slots_pooled.pickle"), "rb")) - self.encoded_slots_seq = pickle.load(open(os.path.join(emb_path, "encoded_slots_seq.pickle"), "rb")) - self.encoded_slot_values = pickle.load(open(os.path.join(emb_path, "encoded_slot_values_test.pickle"), "rb")) - if fp16: - for e in self.encoded_slots_pooled: - self.encoded_slots_pooled[e] = self.encoded_slots_pooled[e].type(torch.float32) - for e in self.encoded_slots_seq: - self.encoded_slots_seq[e] = self.encoded_slots_seq[e].type(torch.float32) - for e in self.encoded_slot_values: - for f in self.encoded_slot_values[e]: - self.encoded_slot_values[e][f] = self.encoded_slot_values[e][f].type(torch.float32) - - def init_session(self): - self.state = default_state() - self.hidden_states = None - - def update(self, user_act=''): - def filter_sequences(seqs, mode="first"): - if mode == "first": - return tokenize(seqs[0][0][0]) - elif mode == "max_first": - max_conf = 0 - max_idx = 0 - for e_itr, e in enumerate(seqs[0]): - if e[1] > max_conf: - max_conf = e[1] - max_idx = e_itr - return tokenize(seqs[0][max_idx][0]) - elif mode == "max": - max_conf = 0 - max_t_idx = 0 - for t_itr, t in enumerate(seqs): - for e_itr, e in enumerate(t): - if e[1] > max_conf: - max_conf = e[1] - max_t_idx = t_itr - max_idx = e_itr - return tokenize(seqs[max_t_idx][max_idx][0]) - else: - print("WARN: mode %s unknown. Aborting." % mode) - exit() - - def tokenize(text): - if "\u0120" in text: - text = re.sub(" ", "", text) - text = re.sub("\u0120", " ", text) - text = text.strip() - return ' '.join([tok for tok in map(str.strip, re.split("(\W+)", text)) if len(tok) > 0]) - - prev_state = self.state - #print("--") - - # --- Get inform memory --- - - # If system_action is plain text, get acts using NLU - if isinstance(prev_state['system_action'], str): - acts, _ = self.get_acts(prev_state['system_action']) - elif isinstance(prev_state['system_action'], list): - acts = prev_state['system_action'] - else: - raise Exception('Unknown format for system action:', prev_state['system_action']) - inform_mem = self.get_inform_mem(acts) - #print("inform_mem:") - #for s in inform_mem: - # if inform_mem[s] != 'none': - # print(s, ':', inform_mem[s]) - - # --- Tokenize dialogue context and feed DST model --- - - features = self.get_features(self.state['history']) - pred_states, pred_classes, cls_representation = self.predict(features, inform_mem) - #print(pred_states) - #print(pred_classes) - #import pdb - #pdb.set_trace() - - # --- Update ConvLab-style dialogue state --- - - new_belief_state = copy.deepcopy(prev_state['belief_state']) - user_acts = [] - for state, value in pred_states.items(): - if isinstance(value, list): - value = filter_sequences(value, mode="max") - else: - value = tokenize(value) - #if pred_classes[state] > 0: - # print(pred_classes[state], state, value) - - domain, slot = state.split('-', 1) - - # 1.) Domain prediction - if slot == "none": # and pred_classes[state] == 3: - continue # for now, continue - - # 2.) 
Requests and greetings - if pred_classes[state] == 7: - if domain == "general": - user_acts.append([slot, 'general', 'none', 'none']) - else: - user_acts.append(['Request', domain.capitalize(), REF_USR_DA[domain.capitalize()].get(slot, slot.capitalize()), '?']) - - # 3.) Informable slots - if value == 'none': - continue - # Value normalization # TODO: according to trippy rules? - if domain == 'hotel' and slot == 'type': - value = "hotel" if value == "yes" else "guesthouse" - value = normalize_value(self.value_dict, domain, slot, value) - if slot not in ['name', 'book']: - if domain not in new_belief_state: - if domain == 'bus': - continue - else: - raise Exception('Domain <{}> not in belief state'.format(domain)) - slot = REF_SYS_DA[domain.capitalize()].get(slot, slot) if domain.capitalize() in REF_SYS_DA else slot - assert 'semi' in new_belief_state[domain] - assert 'book' in new_belief_state[domain] - if 'book' in slot: - assert slot.startswith('book_') - s = slot.split('_')[1] - if s in new_belief_state[domain]['book']: - new_belief_state[domain]['book'][s] = value - user_acts.append(['Inform', domain.capitalize(), REF_USR_DA[domain.capitalize()].get(s, s.capitalize()), value]) - elif slot in new_belief_state[domain]['semi']: - new_belief_state[domain]['semi'][slot] = value - user_acts.append(['Inform', domain.capitalize(), REF_USR_DA[domain.capitalize()].get(slot, slot.capitalize()), value]) - else: - raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain)) - - # Update request_state - #new_request_state = copy.deepcopy(prev_state['request_state']) - - new_state = copy.deepcopy(dict(prev_state)) - new_state['belief_state'] = new_belief_state - #new_state['request_state'] = new_request_state - - nlu_user_acts, nlu_system_acts = self.get_acts(user_act) - user_acts = [] - for e in nlu_user_acts: - if e[0] != 'Inform': - user_acts.append(e) - #new_state['system_action'] = nlu_system_acts # Empty when DST for user -> needed? 
- new_state['user_action'] = user_acts - - new_state['cls_representation'] = cls_representation - - self.state = new_state - - #print("--") - return self.state - - def predict(self, features, inform_mem): - def _tokenize(text): - if "\u0120" in text: - text = re.sub(" ", "", text) - text = re.sub("\u0120", " ", text) - text = text.strip() - return ' '.join([tok for tok in map(str.strip, re.split("(\W+)", text)) if len(tok) > 0]) - - def get_spans(pred, norm_logits, input_tokens, usr_utt_spans): - span_indices = [i for i in range(len(pred)) if pred[i]] - prev_si = None - spans = [] - #confs = [] - for si in span_indices: - if prev_si is None or si - prev_si > 1: - spans.append(([], [], [])) - #confs.append([]) - #spans[-1].append(input_tokens[si]) - spans[-1][0].append(si) - spans[-1][1].append(input_tokens[si]) - spans[-1][2].append(norm_logits[si]) - #confs[-1].append(norm_logits[si]) - prev_si = si - #spans = [' '.join(t for t in s) for s in spans] - spans = [(min(i), max(i), ' '.join(t for t in s), (sum(c) / len(c)).item()) for (i, s, c) in spans] - #confs = [(sum(c) / len(c)).item() for c in confs] - final_spans = {} - for s in spans: - for us_itr, us in enumerate(usr_utt_spans): - if s[0] >= us[0] and s[1] <= us[1]: - if us_itr not in final_spans: - final_spans[us_itr] = [] - final_spans[us_itr].append(s[2:]) - break - final_spans = list(final_spans.values()) - return final_spans # , confs - - def get_usr_utt_spans(usr_mask): - span_indices = [i for i in range(len(usr_mask)) if usr_mask[i]] - prev_si = None - spans = [] - for si in span_indices: - if prev_si is None or si - prev_si > 1: - spans.append([]) - spans[-1].append(si) - prev_si = si - spans = [[min(s), max(s)] for s in spans] - return spans - - def smooth_roberta_predictions(pred, input_tokens): - smoothed_pred = pred.detach().clone() - # Forward - span = False - i = 0 - while i < len(pred): - if pred[i] > 0: - span = True - elif span and input_tokens[i][0] != "\u0120" and input_tokens[i][0] != "<": - smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens - elif span and (input_tokens[i][0] == "\u0120" or input_tokens[i][0] == "<"): - span = False - i += 1 - # Backward - span = False - i = len(pred) - 1 - while i >= 0: - if pred[i] > 0: - span = True - if span and input_tokens[i][0] != "\u0120" and input_tokens[i][0] != "<": - smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens - elif span and input_tokens[i][0] == "\u0120": - smoothed_pred[i] = 1 # TODO: make sure to use label for in-span tokens - span = False - i -= 1 - #if pred != smoothed_pred: - # print(get_spans(pred, input_tokens)) - # print(get_spans(smoothed_pred, input_tokens)) - return smoothed_pred - - with torch.no_grad(): - outputs = self.model(features) # TODO: mode etc - - input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked! 
- - # assign identified spans to their respective usr turns (simply append spans as list of lists) - usr_utt_spans = get_usr_utt_spans(features['usr_mask'][0][1:]) - - per_slot_class_logits = outputs[8] # [2] - per_slot_start_logits = outputs[9] # [3] - per_slot_end_logits = outputs[10] # [4] - per_slot_value_logits = outputs[11] # [4] - per_slot_refer_logits = outputs[12] # [5] - - cls_representation = outputs[16] - - # TODO: maybe add assert to check that batch=1 - - predictions = {slot: 'none' for slot in self.config.dst_slot_list} - class_predictions = {slot: 'none' for slot in self.config.dst_slot_list} - - for slot in self.config.dst_slot_list: - class_logits = per_slot_class_logits[slot][0] - start_logits = per_slot_start_logits[slot][0] - end_logits = per_slot_end_logits[slot][0] if slot in per_slot_end_logits else None - value_logits = per_slot_value_logits[slot][0] if per_slot_value_logits[slot] is not None else None - refer_logits = per_slot_refer_logits[slot][0] - - weights = start_logits[:len(features['input_ids'][0])] - norm_logits = torch.clamp(weights - torch.mean(weights), min=0) / max(weights) - - class_prediction = int(class_logits.argmax()) - start_prediction = norm_logits > 0.0 - start_prediction = smooth_roberta_predictions(start_prediction, input_tokens) - start_prediction[0] = False # Ignore <s> - end_prediction = int(end_logits.argmax()) if end_logits is not None else None - refer_prediction = int(refer_logits.argmax()) - - if class_prediction == self.config.dst_class_types.index('dontcare'): - predictions[slot] = 'dontcare' - elif class_prediction == self.config.dst_class_types.index('copy_value'): - spans = get_spans(start_prediction[1:], norm_logits[1:], input_tokens[1:], usr_utt_spans) - if len(spans) > 0: - for e_itr in range(len(spans)): - for ee_itr in range(len(spans[e_itr])): - tmp = list(spans[e_itr][ee_itr]) - tmp[0] = _tokenize(tmp[0]) - spans[e_itr][ee_itr] = tuple(tmp) - predictions[slot] = spans - else: - predictions[slot] = "none" - elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'): - predictions[slot] = "yes" # 'true' - elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'): - predictions[slot] = "no" # 'false' - elif class_prediction == self.config.dst_class_types.index('inform'): - #print("INFORM:", slot, ",", predictions[slot], "->", inform_mem[slot]) - predictions[slot] = inform_mem[slot] - #elif class_prediction == self.config.dst_class_types.index('request'): - # if slot in ["hotel-internet", "hotel-parking"]: - # predictions[slot] = "yes" # 'true' - # Referral case is handled below - - class_predictions[slot] = class_prediction - - # Referral case. All other slot values need to be seen first in order - # to be able to do this correctly. - for slot in self.config.dst_slot_list: - class_logits = per_slot_class_logits[slot][0] - refer_logits = per_slot_refer_logits[slot][0] - - class_prediction = int(class_logits.argmax()) - refer_prediction = int(refer_logits.argmax()) - - if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'): - # Only slots that have been mentioned before can be referred to. - # One can think of a situation where one slot is referred to in the same utterance. - # This phenomenon is however currently not properly covered in the training data - # label generation process. 
- predictions[slot] = predictions[list(self.config.dst_slot_list.keys())[refer_prediction]] - - # TODO: value normalization - # TODO: value matching - - #if class_prediction > 0: - # print(" ", slot, "->", class_prediction, ",", predictions[slot]) - - return predictions, class_predictions, cls_representation - - def get_features(self, context): - def to_device(batch, device): - if isinstance(batch, tuple): - batch_on_device = tuple([to_device(element, device) for element in batch]) - if isinstance(batch, dict): - batch_on_device = {k: to_device(v, device) for k, v in batch.items()} - else: - batch_on_device = batch.to(device) if batch is not None else batch - return batch_on_device - - assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models - input_tokens = [] # ['<s>'] - for e_itr, e in enumerate(reversed(context)): - input_tokens.append(e[1] if e[1] != 'null' else ' ') - if e_itr < 2: - input_tokens.append('</s> </s>') - else: - input_tokens.append('</s>') - # Ignore history for now - if e_itr == 1: - break - if e_itr == 0: - input_tokens.append('</s> </s>') - #input_tokens.append('</s>') - input_tokens = ' '.join(input_tokens) - - # TODO: delex sys utt somehow, or refrain from using delex for sys utts? - features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=True, max_length=self.config.dst_max_seq_length) - - input_ids = torch.tensor(features['input_ids']).reshape(1,-1) - input_mask = torch.tensor(features['attention_mask']).reshape(1,-1) - usr_mask = torch.zeros(input_ids.size()) - usr_seen = False - sys_seen = False - usr_sep = 0 - sys_sep = 0 - hst_cnt = 0 - for i_itr, i in enumerate(input_ids[0,:]): - if i_itr == 0: - continue - is_usr = True - if i == 1: - is_usr = False - if i == 2: - is_usr = False - if not usr_seen: - usr_sep += 1 - if usr_sep == 2: - usr_seen = True - elif not sys_seen: - sys_sep += 1 - if sys_sep == 2: - sys_seen = True - else: - hst_cnt += 1 - if usr_seen and not sys_seen: - is_usr = False - elif usr_seen and sys_seen and hst_cnt % 2 == 1: - is_usr = False - if is_usr: - usr_mask[0,i_itr] = 1 - #usr_mask = torch.tensor(features['attention_mask']).reshape(1,-1) # TODO - features = {'input_ids': input_ids, - 'input_mask': input_mask, - 'usr_mask': usr_mask, - 'start_pos': None, - 'end_pos': None, - 'refer_id': None, - 'class_label_id': None, - 'inform_slot_id': None, - 'diag_state': None, - 'pos_sampling_input': None, - 'neg_sampling_input': None, - 'encoded_slots_pooled': self.encoded_slots_pooled, - 'encoded_slots_seq': self.encoded_slots_seq, - 'encoded_slot_values': self.encoded_slot_values} - - return to_device(features, self.device) - - # TODO: consider "booked" values? - def get_inform_mem(self, state): - inform_mem = {slot: 'none' for slot in self.config.dst_slot_list} - for e in state: - a, d, s, v = e - if a in ['Inform', 'Recommend', 'Select', 'Book', 'OfferBook']: - ds_d = d.lower() - if s in REF_SYS_DA[d]: - ds_s = REF_SYS_DA[d][s] - elif s in REF_SYS_DA['Booking']: - ds_s = "book_" + REF_SYS_DA['Booking'][s] - else: - ds_s = s.lower() - #raise Exception('Slot <{}> of domain <{}> unknown'.format(s, d)) - slot = "%s-%s" % (ds_d, ds_s) - if slot in inform_mem: - inform_mem[slot] = v - return inform_mem - - # TODO: fix, still a mess... 
- def get_acts(self, user_act): - context = self.state['history'] - if context: - if context[-1][0] != 'sys': - system_act = '' - context = [t for s,t in context] - else: - system_act = context[-1][-1] - context = [t for s,t in context[:-1]] - else: - system_act = '' - context = [''] - - #print(" SYS:", system_act, context) - system_acts = self.nlu.predict(system_act, context=context) - - context.append(system_act) - #print(" USR:", user_act, context) - user_acts = self.nlu.predict(user_act, context=context) - - return user_acts, system_acts - - -# if __name__ == "__main__": -# tracker = TRIPPY(model_type='roberta', model_path='/path/to/model', -# nlu_path='/path/to/nlu') -# tracker.init_session() -# state = tracker.update('hey. I need a cheap restaurant.') -# tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.']) -# tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?']) -# state = tracker.update('If you have something Asian that would be great.') -# tracker.state['history'].append(['usr', 'If you have something Asian that would be great.']) -# tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.']) -# state = tracker.update('Great. Where are they located?') -# tracker.state['history'].append(['usr', 'Great. Where are they located?']) -# print(tracker.state) diff --git a/convlab/evaluator/multiwoz_eval.py b/convlab/evaluator/multiwoz_eval.py index a29af17b6b1da3be4e4347326f4d72e4187fd4c0..c89361b2a84824198e516941b71806440a9ba3a5 100755 --- a/convlab/evaluator/multiwoz_eval.py +++ b/convlab/evaluator/multiwoz_eval.py @@ -5,16 +5,29 @@ import re import numpy as np from copy import deepcopy +from data.unified_datasets.multiwoz21.preprocess import reverse_da, reverse_da_slot_name_map +from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA from convlab.evaluator.evaluator import Evaluator from data.unified_datasets.multiwoz21.preprocess import reverse_da_slot_name_map from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple from convlab.util.multiwoz.dbquery import Database -from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA from convlab.util import relative_import_module_from_unified_datasets -DEBUG = False -reverse_da = relative_import_module_from_unified_datasets( - 'multiwoz21', 'preprocess.py', 'reverse_da') +# import reflect table +REF_SYS_DA_M = {} +for dom, ref_slots in REF_SYS_DA.items(): + dom = dom.lower() + REF_SYS_DA_M[dom] = {} + for slot_a, slot_b in ref_slots.items(): + if slot_a == 'Ref': + slot_b = 'ref' + REF_SYS_DA_M[dom][slot_a.lower()] = slot_b + REF_SYS_DA_M[dom]['none'] = 'none' +REF_SYS_DA_M['taxi']['phone'] = 'phone' +REF_SYS_DA_M['taxi']['car'] = 'car type' + +reverse_da = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da') + requestable = \ {'attraction': ['post', 'phone', 'addr', 'fee', 'area', 'type'], @@ -28,20 +41,15 @@ requestable = \ belief_domains = requestable.keys() mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'name': 'name', 'phone': 'phone', - 'post': 'postcode', 'price': 'pricerange', 'ref': 'ref', - 'price range': 'pricerange', 'address': 'address', 'postcode': 'postcode'}, + 'post': 'postcode', 'price': 'pricerange', 'ref': 'ref'}, 'hotel': {'addr': 'address', 'area': 'area', 'internet': 'internet', 'parking': 'parking', 'name': 'name', - 'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 
'type': 'type', 'ref': 'ref', - 'price range': 'pricerange', 'address': 'address', 'postcode': 'postcode'}, + 'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type', 'ref': 'ref'}, 'attraction': {'addr': 'address', 'area': 'area', 'fee': 'entrance fee', 'name': 'name', 'phone': 'phone', - 'post': 'postcode', 'type': 'type', 'entrance fee': 'entrance fee'}, - 'train': {'id': 'trainID', 'arrive': 'arrive by', 'day': 'day', 'depart': 'departure', 'dest': 'destination', - 'time': 'duration', 'leave': 'leave at', 'ticket': 'price', 'ref': 'ref', - 'arrive by': 'arrive by', 'leave at': 'leave at', 'departure': 'departure', 'destination': "destination", "duration": "duration", "price": "price"}, - 'taxi': {'car': 'car type', 'phone': 'phone', - 'car type': 'car type'}, - 'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department', - 'postcode': 'postcode', 'address': 'address'}, + 'post': 'postcode', 'type': 'type'}, + 'train': {'id': 'trainID', 'arrive': 'arriveBy', 'day': 'day', 'depart': 'departure', 'dest': 'destination', + 'time': 'duration', 'leave': 'leaveAt', 'ticket': 'price', 'ref': 'ref'}, + 'taxi': {'car': 'car type', 'phone': 'phone'}, + 'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'}, 'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}} @@ -59,7 +67,6 @@ for dom, ref_slots in REF_SYS_DA.items(): REF_SYS_DA_M[dom]['none'] = 'none' REF_SYS_DA_M['taxi']['phone'] = 'phone' REF_SYS_DA_M['taxi']['car'] = 'car type' -REF_SYS_DA_M['train']['id'] = 'train id' DEF_VAL_UNK = '?' # Unknown DEF_VAL_DNC = 'dontcare' # Do not care DEF_VAL_NUL = 'none' # for none @@ -68,6 +75,14 @@ DEF_VAL_NOBOOK = 'no' # for booked NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK] +# Not sure values in inform +DEF_VAL_UNK = '?' 
@@ -277,8 +270,7 @@ class MultiWozEvaluator(Evaluator):
                     if value == value_predicted:
                         match += 1
             except Exception as e:
-                if DEBUG:
-                    print("Tracker probably does not track that slot.", e)
+                print("Tracker probably does not track that slot.", e)
                 # if tracker does not track it, it trivially matches since policy has no chance otherwise
                 match += 1
 
@@ -319,7 +311,6 @@ class MultiWozEvaluator(Evaluator):
                 else:
                     # print('FN + 1')
                     reqt_not_inform.add(('request', domain, k))
-                    print("reqt_not_inform.add(('request', domain, k))", domain, k)
                     FN += 1
             for k in inform_slot[domain]:
                 # exclude slots that are informed by users
@@ -329,13 +320,12 @@ class MultiWozEvaluator(Evaluator):
                     # print('FP + 1 @2', k)
                     inform_not_reqt.add(('inform', domain, k,))
                     FP += 1
-
         return TP, FP, FN, bad_inform, reqt_not_inform, inform_not_reqt
 
     def _check_value(self, domain, key, value):
         if key == "area":
             return value.lower() in ["centre", "east", "south", "west", "north"]
-        elif key == "arriveBy" or key == "leaveAt" or key == "arrive by" or key == "leave at":
+        elif key == "arriveBy" or key == "leaveAt":
            return time_re.match(value)
         elif key == "day":
             return value.lower() in ["monday", "tuesday", "wednesday", "thursday", "friday",
@@ -348,13 +338,13 @@ class MultiWozEvaluator(Evaluator):
             return re.match(r'^\d{11}$', value) or domain == "restaurant"
         elif key == "price":
             return 'pound' in value
-        elif key == "pricerange" or key == "price range":
+        elif key == "pricerange":
             return value in ["cheap", "expensive", "moderate", "free"] or domain == "attraction"
         elif key == "postcode":
             return re.match(r'^cb\d{1,3}[a-z]{2,3}$', value) or value == 'pe296fl'
         elif key == "stars":
             return re.match(r'^\d$', value)
-        elif key == "trainID" or key == "train id":
+        elif key == "trainID":
             return re.match(r'^tr\d{4}$', value.lower())
         else:
             return True
@@ -426,9 +416,6 @@ class MultiWozEvaluator(Evaluator):
 
         TP, FP, FN, bad_inform, reqt_not_inform, inform_not_reqt = self._inform_F1_goal(
             goal, self.sys_da_array)
-        if len(reqt_not_inform) > 0:
-            print("bad_inform", bad_inform)
-            print("reqt_not_inform", reqt_not_inform)
         if aggregate:
             try:
                 rec = TP / (TP + FN)
@@ -463,14 +450,6 @@ class MultiWozEvaluator(Evaluator):
                     book_constraint_sess == 1 or book_constraint_sess is None) else 0
                 return self.success if not self.check_book_constraints else self.success_strict
             else:
-                print("===== fail reason =====")
-                if goal_sess != 1:
-                    print("goal_sess", goal_sess)
-                if book_sess is not None and book_sess < 1:
-                    print("book_sess", book_sess)
-                if inform_sess[1] is not None and inform_sess[1] < 1:
-                    print("inform_sess", inform_sess)
-
                 self.complete = 1 if booking_done and (
                     inform_sess[1] == 1 or inform_sess[1] is None) else 0
                 self.success = 0
@@ -668,12 +647,11 @@ class MultiWozEvaluator(Evaluator):
                     slot = reverse_da_slot_name_map[domain][slot]
                 else:
                     slot = slot.capitalize()
-
                 if intent.lower() in ['inform', 'recommend']:
                     if domain.lower() in goal:
                         if 'reqt' in goal[domain.lower()]:
-                            if REF_SYS_DA_M.get(domain.lower(), {}).get(slot.lower(), slot.lower()) in goal[domain.lower()][
-                                    'reqt']:
+                            if REF_SYS_DA_M.get(domain.lower(), {}).get(slot.lower(), slot.lower()) \
+                                    in goal[domain.lower()]['reqt']:
                                 if val in NOT_SURE_VALS:
                                     val = '\"' + val + '\"'
                                 goal[domain.lower()]['reqt'][
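The `REF_SYS_DA_M` reflect table used in the hunk above lower-cases the domain and slot keys of `REF_SYS_DA` so that system dialogue-act slots can be translated to the evaluator's slot names, with the `.get(domain, {}).get(slot, slot)` pattern falling back to the raw slot name for anything unmapped. A small self-contained illustration; the inline `REF_SYS_DA` below is a toy stand-in for the real table in `convlab.util.multiwoz.multiwoz_slot_trans`:

    # Toy stand-in for convlab.util.multiwoz.multiwoz_slot_trans.REF_SYS_DA;
    # the real table maps abbreviated system-act slots to full slot names.
    REF_SYS_DA = {'Taxi': {'Dest': 'destination', 'Ref': 'Ref'}}

    REF_SYS_DA_M = {}
    for dom, ref_slots in REF_SYS_DA.items():
        dom = dom.lower()
        REF_SYS_DA_M[dom] = {}
        for slot_a, slot_b in ref_slots.items():
            if slot_a == 'Ref':
                slot_b = 'ref'  # normalise the booking-reference slot name
            REF_SYS_DA_M[dom][slot_a.lower()] = slot_b
        REF_SYS_DA_M[dom]['none'] = 'none'
    REF_SYS_DA_M['taxi']['phone'] = 'phone'
    REF_SYS_DA_M['taxi']['car'] = 'car type'

    # Lookup with pass-through fallback, as in the goal-matching code:
    assert REF_SYS_DA_M.get('taxi', {}).get('dest', 'dest') == 'destination'
    assert REF_SYS_DA_M.get('taxi', {}).get('leave', 'leave') == 'leave'  # unmapped slot is kept as-is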
'\"' + val + '\"' goal[domain.lower()]['reqt'][ diff --git a/convlab/nlg/template/multiwoz/manual_system_template_nlg.bck.json b/convlab/nlg/template/multiwoz/manual_system_template_nlg.bck.json deleted file mode 100755 index aa5f6b69615092d5c2bb64e645e048ef081c20fe..0000000000000000000000000000000000000000 --- a/convlab/nlg/template/multiwoz/manual_system_template_nlg.bck.json +++ /dev/null @@ -1,1452 +0,0 @@ -{ - "Attraction-Inform": { - "Addr": [ - "it is located in #ATTRACTION-INFORM-ADDR#", - "adress is #ATTRACTION-INFORM-ADDR#", - "It is on #ATTRACTION-INFORM-ADDR# .", - "their address in our system is listed as #ATTRACTION-INFORM-ADDR# .", - "The address is #ATTRACTION-INFORM-ADDR# .", - "it 's located at #ATTRACTION-INFORM-ADDR# .", - "#ATTRACTION-INFORM-ADDR# is the address", - "They are located at #ATTRACTION-INFORM-ADDR# ." - ], - "none": [ - "what information would you like about it today ?", - "i have their info , what would you like to know ?" - ], - "Choice": [ - "I ' ve found #ATTRACTION-INFORM-CHOICE# places for you to go . Do you have any specific ideas in mind ?", - "sorry about that , there are actually #ATTRACTION-INFORM-CHOICE# .", - "sure , there are #ATTRACTION-INFORM-CHOICE# for you to choose from .", - "There are #ATTRACTION-INFORM-CHOICE# . Would you like me to recommend one for you ?", - "We have #ATTRACTION-INFORM-CHOICE# of those ! Anything specific you need or just a recommendation ?", - "sure , there are #ATTRACTION-INFORM-CHOICE# options for you", - "sure , there are #ATTRACTION-INFORM-CHOICE# in that area .", - "we have #ATTRACTION-INFORM-CHOICE# options , can i reccomend for you ?", - "there are #ATTRACTION-INFORM-CHOICE# , anything in particular you are looking for ?", - "We have #ATTRACTION-INFORM-CHOICE# such location ." - ], - "Post": [ - "The postcode of the attraction is #ATTRACTION-INFORM-POST# .", - "The post code is #ATTRACTION-INFORM-POST# .", - "Its postcode is #ATTRACTION-INFORM-POST# .", - "Their postcode is #ATTRACTION-INFORM-POST# ." - ], - "Fee": [ - "Its entrance fee is #ATTRACTION-INFORM-FEE# .", - "The entry fee is #ATTRACTION-INFORM-FEE# .", - "their entrance fee is #ATTRACTION-INFORM-FEE# by our system currently ." - ], - "Name": [ - "I think a fun place to visit is #ATTRACTION-INFORM-NAME# .", - "#ATTRACTION-INFORM-NAME# looks good .", - "#ATTRACTION-INFORM-NAME# is available , would that work for you ?", - "we have #ATTRACTION-INFORM-NAME# .", - "#ATTRACTION-INFORM-NAME# is popular among visitors .", - "How about #ATTRACTION-INFORM-NAME# ?", - "What about #ATTRACTION-INFORM-NAME# ?", - "you might want to try the #ATTRACTION-INFORM-NAME# ." - ], - "Area": [ - "That one is located in the #ATTRACTION-INFORM-AREA# .", - "it is located in the #ATTRACTION-INFORM-AREA# .", - "They are located within the #ATTRACTION-INFORM-AREA# .", - "it 's located in the #ATTRACTION-INFORM-AREA# .", - "That is in the #ATTRACTION-INFORM-AREA# .", - "It will be located in the #ATTRACTION-INFORM-AREA# .", - "it is in the #ATTRACTION-INFORM-AREA# of the city", - "It is in the #ATTRACTION-INFORM-AREA# ." - ], - "Phone": [ - "The attraction phone number is #ATTRACTION-INFORM-PHONE# .", - "Here is the attraction phone number , #ATTRACTION-INFORM-PHONE# ." 
- ], - "Type": [ - "It is listed as a #ATTRACTION-INFORM-TYPE# attraction .", - "it is a #ATTRACTION-INFORM-TYPE# .", - "There are some wonderful #ATTRACTION-INFORM-TYPE# in that area .", - "it 's considered a #ATTRACTION-INFORM-TYPE# .", - "Would you be interested in visiting a #ATTRACTION-INFORM-TYPE# ?", - "It 's a #ATTRACTION-INFORM-TYPE# attraction .", - "It is listed as #ATTRACTION-INFORM-TYPE# ." - ], - "Open": [ - "#ATTRACTION-INFORM-OPEN# in our database .", - "#ATTRACTION-INFORM-OPEN# ." - ], - "Price": [ - "it is in the #ATTRACTION-INFORM-PRICE# price range", - "The fee is #ATTRACTION-INFORM-PRICE# ." - ] - }, - "Attraction-NoOffer": { - "none": [ - "There are no attractions matching that description .", - "I ' m sorry but I have not found any matches .", - "I ' m sorry there are no matches .", - "There are none available at this time .", - "I ' m sorry . I ' m not finding any attractions that meet your criteria .", - "we do nt have any in that area .", - "I do n't have anything meeting that criteria ." - ], - "Type": [ - "There are no #ATTRACTION-NOOFFER-TYPE# close to the area you are requesting", - "No , I ' m sorry , I am not finding anything with #ATTRACTION-NOOFFER-TYPE# .", - "I ' m sorry , but it does n't look like we have a #ATTRACTION-NOOFFER-TYPE# that matches your criteria .", - "Unfortunately there are no #ATTRACTION-NOOFFER-TYPE# venues in that location .", - "I ' m sorry , I do n't see any #ATTRACTION-NOOFFER-TYPE# attractions in that area of town . Is there anything else you 'd be interested in seeing ?", - "There are no #ATTRACTION-NOOFFER-TYPE# in this area .", - "I ' m sorry . There are no #ATTRACTION-NOOFFER-TYPE# listed in that area .", - "Unfortunately I can not find anything strictly categorized as #ATTRACTION-NOOFFER-TYPE# in that area can you provide more specifications ?", - "There are no #ATTRACTION-NOOFFER-TYPE# in that area ." - ], - "Name": [ - "There is no listing for #ATTRACTION-NOOFFER-NAME#", - "i am sorry but i actually am not finding any information for #ATTRACTION-NOOFFER-NAME# ." - ], - "Area": [ - "no such attractions in #ATTRACTION-NOOFFER-AREA#", - "I ' m sorry , I could n't find anything like that in the #ATTRACTION-NOOFFER-AREA# .", - "sorry , i could n't find anything in the #ATTRACTION-NOOFFER-AREA# .", - "I am sorry , I am unable to locate an attraction in #ATTRACTION-NOOFFER-AREA# ?" - ], - "Fee": [ - "sorry , i could n't find anything with #ATTRACTION-NOOFFER-Fee# .", - "There are no attractions with #ATTRACTION-NOOFFER-Fee# ." - ], - "Addr": [ - "I ' m sorry , but I do n't see any attractions at #ATTRACTION-NOOFFER-ADDR# ." - ] - }, - "Attraction-Recommend": { - "Name": [ - "we have #ATTRACTION-RECOMMEND-NAME#", - "#ATTRACTION-RECOMMEND-NAME# looks good , would you like to head there ?", - "Would you like #ATTRACTION-RECOMMEND-NAME# ?", - "you would love #ATTRACTION-RECOMMEND-NAME#", - "I recommend #ATTRACTION-RECOMMEND-NAME# ", - "I would suggest #ATTRACTION-RECOMMEND-NAME# .", - "I 'd recommend #ATTRACTION-RECOMMEND-NAME# . Would you like some information on it ?", - "You would love #ATTRACTION-RECOMMEND-NAME#", - "#ATTRACTION-RECOMMEND-NAME# meets your requirements .", - "how about #ATTRACTION-RECOMMEND-NAME# ? they 're pretty fun ." - ], - "Type": [ - "I would suggest visiting one of the famous #ATTRACTION-RECOMMEND-TYPE# .", - "How about a #ATTRACTION-RECOMMEND-TYPE# ?", - "Would a #ATTRACTION-RECOMMEND-TYPE# work for you ?", - "It 's an #ATTRACTION-RECOMMEND-TYPE# . Great for the whole family , especially the younger ones !" 
- ], - "none": [ - "It 's a really fun attraction with lots of interesting things to do in it ." - ], - "Fee": [ - "Its entrance fee is #ATTRACTION-RECOMMEND-FEE# .", - "The entry fee is #ATTRACTION-RECOMMEND-FEE# .", - "their entrance fee is #ATTRACTION-RECOMMEND-FEE# by our system currently ." - ], - "Addr": [ - "it is located in #ATTRACTION-RECOMMEND-ADDR#", - "adress is #ATTRACTION-RECOMMEND-ADDR#", - "It is on #ATTRACTION-RECOMMEND-ADDR# .", - "their address in our system is listed as #ATTRACTION-RECOMMEND-ADDR# .", - "The address is #ATTRACTION-RECOMMEND-ADDR# .", - "it 's located at #ATTRACTION-RECOMMEND-ADDR# .", - "#ATTRACTION-RECOMMEND-ADDR# is the address", - "They are located at #ATTRACTION-RECOMMEND-ADDR# ." - ], - "Post": [ - "The postcode is #ATTRACTION-RECOMMEND-POST# .", - "The post code is #ATTRACTION-RECOMMEND-POST# .", - "Its postcode is #ATTRACTION-RECOMMEND-POST# .", - "Their postcode is #ATTRACTION-RECOMMEND-POST# ." - ], - "Phone": [ - "The attraction phone number is #ATTRACTION-RECOMMEND-PHONE# .", - "Here is the attraction phone number , #ATTRACTION-RECOMMEND-PHONE# ." - ], - "Area": [ - "That one is located in the #ATTRACTION-RECOMMEND-AREA# .", - "it is located in the #ATTRACTION-RECOMMEND-AREA# .", - "They are located within the #ATTRACTION-RECOMMEND-AREA# .", - "it 's located in the #ATTRACTION-RECOMMEND-AREA# .", - "That is in the #ATTRACTION-RECOMMEND-AREA# .", - "It will be located in the #ATTRACTION-RECOMMEND-AREA# .", - "it is in the #ATTRACTION-RECOMMEND-AREA# of the city", - "It is in the #ATTRACTION-RECOMMEND-AREA# ." - ], - "Price": [ - "it is in the #ATTRACTION-RECOMMEND-PRICE# price range", - "The fee is #ATTRACTION-RECOMMEND-PRICE# ." - ], - "Choice": [ - "I ' ve found #ATTRACTION-RECOMMEND-CHOICE# places for you to go . Do you have any specific ideas in mind ?", - "sure , there are #ATTRACTION-RECOMMEND-CHOICE# for you to choose from .", - "There are #ATTRACTION-RECOMMEND-CHOICE# . Would you like me to recommend one for you ?", - "We have #ATTRACTION-RECOMMEND-CHOICE# of those ! Anything specific you need or just a recommendation ?", - "sure , there are #ATTRACTION-RECOMMEND-CHOICE# options for you", - "sure , there are #ATTRACTION-RECOMMEND-CHOICE# in that area .", - "we have #ATTRACTION-RECOMMEND-CHOICE# options , can i reccomend for you ?", - "there are #ATTRACTION-RECOMMEND-CHOICE# , anything in particular you are looking for ?", - "We have #ATTRACTION-RECOMMEND-CHOICE# such location ." - ], - "Open": [ - "#ATTRACTION-RECOMMEND-OPEN# in our database .", - "#ATTRACTION-RECOMMEND-OPEN# ." - ] - }, - "Attraction-Request": { - "Type": [ - "What type of attraction are you looking for ?", - "Please specify the type of attraction you 're interested in .", - "What type of attraction are you interested in ?", - "What kind of attraction are you interested in ?", - "What kind of attraction were you looking for ?", - "you have any particular attraction in mind ?", - "what type of attractions are you interested in ?", - "What sort of attraction would you like it to be ?" - ], - "Area": [ - "Any particular area ?", - "is there a certain area of town you would prefer ?", - "I have various attractions all over town . 
Is there a specific area you are wanting to find something to do ?", - "Do you have a part of town you prefer ?", - "What part of town would you like it", - "Do you have a preference for the area of town you wish to visit ?", - "Where in town would you like to go ?", - "Which part of town would you prefer ?", - "Is there a specific area you are looking for ?" - ], - "Name": [ - "What is the name of the attraction ?", - "What attraction are you thinking about ?", - "I ' m sorry for the confusion , what attraction are you interested in ?", - "What attraction were you thinking of ?", - "Do you know the name of it ?", - "can you give me the name of it ?" - ], - "Price": [ - "any specific price range to help narrow down available options ?", - "What price range would you like ?", - "what is your price range for that ?", - "What price range are you looking for ?", - "What price point is good for you ?", - "Does a entrance fee make any difference ?", - "Would you like a free entrance fee or paid ?", - "Do you need free admission or pay to get in ?", - "What price point would you like ?" - ] - }, - "Booking-Book": { - "Ref": [ - "Booking was successful . Reference number is : #BOOKING-BOOK-REF# .", - "your reference number is #BOOKING-BOOK-REF# .", - "Here is the booking information : Booking was successful . Reference number is : #BOOKING-BOOK-REF#", - "Reference number is : #BOOKING-BOOK-REF# .", - "All set . Your reference number is #BOOKING-BOOK-REF# ." - ], - "Name": [ - "the #BOOKING-BOOK-NAME# seems appropriate . i have booked it for you .", - "i have booked you #BOOKING-BOOK-NAME#", - "I booked you at #BOOKING-BOOK-NAME#", - "Your reservation is at #BOOKING-BOOK-NAME# ." - ], - "none": [ - "Thanks , booking has been completed .", - "Booking was successful .", - "Your booking was successful .", - "You 're all booked ." - ], - "Day": [ - "I was able to book you for #BOOKING-BOOK-DAY# .", - "I have reserved for #BOOKING-BOOK-DAY#", - "I was able to get you that reservation on #BOOKING-BOOK-DAY# ." - ], - "Time": [ - "Would #BOOKING-BOOK-TIME# be a convenient time for you ?" - ], - "Stay": [ - "I ' ve booked you for #BOOKING-BOOK-STAY# night .", - "I was able to book your reservation for #BOOKING-BOOK-STAY# days ." - ], - "People": [ - "I was able to book it for you for #BOOKING-BOOK-PEOPLE# people .", - "I have booked for #BOOKING-BOOK-PEOPLE# people .", - "I was able to book a reservation for #BOOKING-BOOK-PEOPLE# people ." - ] - }, - "Booking-Inform": { - "none": [ - "Shall I try to start and book you into one ?", - "I will book it for you and get a reference number ?", - "Would you like for me to try and make a reservation ?", - "I will go ahead and book that now .", - "Can I make a reservation for you ?", - "Would you like me to book it ?" - ], - "Ref": [ - "Booking was successful . Reference number is : #BOOKING-INFORM-REF#.", - "I was able to book it , reference number is #BOOKING-INFORM-REF# ." - ], - "Name": [ - "Did you need to book the #BOOKING-INFORM-NAME# ?", - "It is #BOOKING-INFORM-NAME# . Do you want a reservation ?", - "#BOOKING-INFORM-NAME# . Would you like to book a reservation ?", - "Would you like to book the #BOOKING-INFORM-NAME# ?", - "Want me to book #BOOKING-INFORM-NAME# ?", - "Would you like me to book the #BOOKING-INFORM-NAME# for you ?", - "I ' ve located #BOOKING-INFORM-NAME# , would you like me to assist you with booking ?" 
- ], - "People": [ - "Will you be booking for #BOOKING-INFORM-PEOPLE# people ?", - "Would you like to book for #BOOKING-INFORM-PEOPLE# people ?", - "that was #BOOKING-INFORM-PEOPLE# , correct ?", - "i want to confirm this , do i book for #BOOKING-INFORM-PEOPLE# person ?", - "So for #BOOKING-INFORM-PEOPLE# people in total ?", - "Ok I will book it for you for #BOOKING-INFORM-PEOPLE# people", - "I will book that for #BOOKING-INFORM-PEOPLE# people .", - "Do you want reservations for #BOOKING-INFORM-PEOPLE# people ?" - ], - "Day": [ - "Okay , so you would like the reservation for #BOOKING-INFORM-DAY# ?", - "Will you be coming in on #BOOKING-INFORM-DAY# ?", - "Would you like this reservation be for #BOOKING-INFORM-DAY# ?", - "Do you want the reservations to begin on #BOOKING-INFORM-DAY# ?" - ], - "Time": [ - "#BOOKING-INFORM-TIME# is available , would you like me to book that for you ?", - "I try to make the reservation for #BOOKING-INFORM-TIME# ." - ], - "Stay": [ - "#BOOKING-INFORM-STAY# . Would you like me to book that ?", - "For #BOOKING-INFORM-STAY# day ?" - ] - }, - "Booking-NoBook": { - "none": [ - "I am unable to book this for you . Do you have any other preferences ?", - "Booking was unsuccessful do you have any other preference ?", - "I ' m sorry , I was unable to reserve rooms . Would you like to try anything else ?", - "I ' m sorry those are not available .", - "Unfortunately the booking was not successful ." - ], - "Day": [ - "I apologize , but it looks like #BOOKING-NOBOOK-DAY# is not working .", - "I ' m sorry #BOOKING-NOBOOK-DAY# is n't working either .", - "I ' m sorry but i ' m unable to make the reservation on #BOOKING-NOBOOK-DAY# .", - "sorry , but #BOOKING-NOBOOK-DAY# is all booked", - "we are currently full on #BOOKING-NOBOOK-DAY# would you like to book at another hotel ?", - "I ' m sorry , but there 's nothing available starting on #BOOKING-NOBOOK-DAY# .", - "I am unable to book for #BOOKING-NOBOOK-DAY# .", - "#BOOKING-NOBOOK-DAY# is not available ." - ], - "Stay": [ - "Sorry , the hotel ca n't accommodate you for #BOOKING-NOBOOK-STAY# . want to change dates ?", - "Neither is available for #BOOKING-NOBOOK-STAY# nights .", - "They do n't have a room available for #BOOKING-NOBOOK-STAY# nights . Anything else you 'd like me to try ?", - "Unfortunately it can not be booked for #BOOKING-NOBOOK-STAY# days . Did you want to get information about a different hotel instead ?" - ], - "Ref": [ - "Great the booking was successful , your reference number is #BOOKING-NOBOOK-REF#.", - "Booking was successful . Reference number is : #BOOKING-NOBOOK-REF# .", - "Okay that booking was successful and your reference number is #BOOKING-NOBOOK-REF#." - ], - "Name": [ - "Let 's decide on #BOOKING-NOBOOK-NAME# . Unfortunately , that appears to already be booked . Do you want to try one of the others ?" - ], - "Time": [ - "I am sorry they do not have a table at #BOOKING-NOBOOK-TIME# , perhaps another restaurant ?" - ], - "People": [ - "I am sorry , I am unable to make a reservation for #BOOKING-NOBOOK-PEOPLE# people", - "I ' m unable to book for #BOOKING-NOBOOK-PEOPLE# people .", - "I ' m sorry but there is no availability for #BOOKING-NOBOOK-PEOPLE# people ." - ] - }, - "Booking-Request": { - "Day": [ - "What day would you like your booking for ?", - "What day would you like that reservation ?", - "what day would you like the booking to be made for ?", - "What day would you like to book ?", - "Ok , what day would you like to make the reservation on ?" 
- ], - "Stay": [ - "How many nights will you be staying ?", - "And how many nights ?", - "for how many days ?", - "And for how many days ?", - "how many days would you like to stay ?", - "How many nights would you like to book it for ?", - "And what nights would you like me to reserve for you ?", - "How many nights are you wanting to stay ?", - "How many days will you be staying ?" - ], - "People": [ - "For how many people ?", - "How many people will be ?", - "How many people will be with you ?", - "How many people is the reservation for ?" - ], - "Time": [ - "Do you have a time preference ?", - "what time are you looking for a reservation at ?", - "For what time ?", - "What time would you like me to make your reservation ?", - "What time would you like the reservation for ?", - "what time should I make the reservation for ?", - "What time would you prefer ?", - "What time would you like the reservation for ?" - ] - }, - "Hotel-Inform": { - "Internet": [ - "it has wifi .", - "the place provides free wifi .", - "the wifi is included .", - "There is wifi available at the hotel .", - "internet is available .", - "it has free wifi ." - ], - "Stars": [ - "It is rated #HOTEL-INFORM-STARS# stars ,", - "It is rated #HOTEL-INFORM-STARS# stars , is that okay ?", - "It has #HOTEL-INFORM-STARS# stars .", - "It has a rating of #HOTEL-INFORM-STARS# stars .", - "They have a #HOTEL-INFORM-STARS# Star rating", - "The hotel is #HOTEL-INFORM-STARS# stars .", - "It is a #HOTEL-INFORM-STARS#-star rating .", - "it does have #HOTEL-INFORM-STARS# stars ." - ], - "Name": [ - "#HOTEL-INFORM-NAME# is available would you like to try that ?", - "Does the #HOTEL-INFORM-NAME# work ?", - "You can try #HOTEL-INFORM-NAME#", - "#HOTEL-INFORM-NAME# is a great place", - "how about #HOTEL-INFORM-NAME# ?", - "How about #HOTEL-INFORM-NAME# ?", - "Okay , how about #HOTEL-INFORM-NAME# ?", - "what about #HOTEL-INFORM-NAME# ?", - "How about #HOTEL-INFORM-NAME# ?" - ], - "Area": [ - "It is in the #HOTEL-INFORM-AREA# area .", - "They are located in the #HOTEL-INFORM-AREA# .", - "it is in the #HOTEL-INFORM-AREA# .", - "It 's located in the #HOTEL-INFORM-AREA# .", - "It is in the #HOTEL-INFORM-AREA# part of town .", - "It 's in #HOTEL-INFORM-AREA# .", - "It is indeed in the #HOTEL-INFORM-AREA# ." - ], - "Parking": [ - "It does include free parking .", - "they have free parking .", - "it offers free parking .", - "the parking is free ." - ], - "Phone": [ - "The hotel phone number is #HOTEL-INFORM-PHONE# .", - "The phone number of the hotel is #HOTEL-INFORM-PHONE# ." - ], - "Choice": [ - "i have #HOTEL-INFORM-CHOICE# options for you", - "There are #HOTEL-INFORM-CHOICE# of those .", - "We have #HOTEL-INFORM-CHOICE# such places .", - "I have #HOTEL-INFORM-CHOICE# different options for you !" - ], - "Addr": [ - "The hotel address is #HOTEL-INFORM-ADDR# .", - "They are located at #HOTEL-INFORM-ADDR#", - "It is located at #HOTEL-INFORM-ADDR#" - ], - "Post": [ - "The postal code for that hotel is #HOTEL-INFORM-POST# .", - "The postcode is #HOTEL-INFORM-POST# .", - "their postcode is #HOTEL-INFORM-POST#" - ], - "none": [ - "Yes , it fits all those needs .", - "Yes it does" - ], - "Type": [ - "It is a #HOTEL-INFORM-TYPE# ." - ], - "Price": [ - "Its listed as #HOTEL-INFORM-PRICE# .", - "It is in the #HOTEL-INFORM-PRICE# price range .", - "It is a #HOTEL-INFORM-PRICE# place .", - "It is #HOTEL-INFORM-PRICE# .", - "This is an #HOTEL-INFORM-PRICE# hotel .", - "It is #HOTEL-INFORM-PRICE# priced .", - "The price range is #HOTEL-INFORM-PRICE# ." 
- ], - "Ref": [ - "the reference number is #HOTEL-INFORM-REF# .", - "Your reference number is #HOTEL-INFORM-REF#.", - "You 're all set ! Your reference number is #HOTEL-INFORM-REF# .", - "Your reference number is #HOTEL-INFORM-REF#", - "The reference number is #HOTEL-INFORM-REF#", - "Reference number is : #HOTEL-INFORM-REF# .", - "The Reference number is : #HOTEL-INFORM-REF# ." - ] - }, - "Hotel-NoOffer": { - "Type": [ - "Sorry there is no #HOTEL-NOOFFER-TYPE# fitting the description you asked for", - "I was not able to find any #HOTEL-NOOFFER-TYPE# that met those requirements .", - "There are no #HOTEL-NOOFFER-TYPE# that meet that criteria , would you like information about the hotel options ?", - "I ' m sorry , I ' m afraid I do n't see any #HOTEL-NOOFFER-TYPE# matching that description . Do you want to try a different price range or star rating ?", - "i ca n't find any #HOTEL-NOOFFER-TYPE# that fit your criteria , i ' m sorry .", - "It is n't , and unfortunately I do n't have a #HOTEL-NOOFFER-TYPE# that matches that criteria .", - "I ' m sorry , there are no #HOTEL-NOOFFER-TYPE# that match your preferences .", - "no #HOTEL-NOOFFER-TYPE# meet your criteria ." - ], - "none": [ - "Sorry , my search did n't bring back any results .", - "I ' m sorry , I can not help you with hotels . Are you sure that 's what you 're looking for ?", - "I was unable to find any matching places for that .", - "I ' m sorry there are no matches .", - "There were no matches found .", - "There are no hotels meeting these requirements .", - "Nothing fits all of that criteria .", - "Sorry , I ' m not finding anything ." - ], - "Stars": [ - "I am sorry , but that hotel does not have a #HOTEL-NOOFFER-STARS# star rating , would you like another option ?", - "Unfortunately , I could n't find anything with #HOTEL-NOOFFER-STARS# stars .", - "I am sorry I have no listings for any with #HOTEL-NOOFFER-STARS# stars .", - "I am sorry , there are not #HOTEL-NOOFFER-STARS# stars available .", - "I have not found anything with a star of #HOTEL-NOOFFER-STARS# ." - ], - "Parking": [ - "I ' m not showing anything in that area of town with no parking ." - ], - "Area": [ - "Sorry there are none in the #HOTEL-NOOFFER-AREA# .", - "There are none in the #HOTEL-NOOFFER-AREA# . Perhaps another criteria change ?", - "I ' m sorry , no , none in the #HOTEL-NOOFFER-AREA# .", - "There are n't any that match your criteria in the #HOTEL-NOOFFER-AREA# . Any other suggestions ?", - "I have nothing in the #HOTEL-NOOFFER-AREA# . Can I try something else ?" - ], - "Name": [ - "I am not finding anything for #HOTEL-NOOFFER-NAME# that suit your needs", - "There is no #HOTEL-NOOFFER-NAME# in our system .", - "#HOTEL-NOOFFER-NAME# is not available ." - ], - "Price": [ - "I ' m sorry , I do n't have anything in the #HOTEL-NOOFFER-PRICE# price range , would you like to search for something else ?", - "There is none that is #HOTEL-NOOFFER-PRICE# . Would you like to change your criteria ?", - "I ' m sorry , there are no #HOTEL-NOOFFER-PRICE# hotels . Would you like to try searching for something else ?" - ], - "Internet": [ - "There does n't appear to be a hotel with wifi .", - "I ' m sorry there are no results for hotels with free internet .", - "There are not any with wifi ." - ] - }, - "Hotel-Recommend": { - "Name": [ - "How about #HOTEL-RECOMMEND-NAME# ? It has all the attributes you requested and a great name !", - "How about #HOTEL-RECOMMEND-NAME# ? 
Fits your request perfectly .", - "Would #HOTEL-RECOMMEND-NAME# work for you ?", - "Everyone seems to enjoy the #HOTEL-RECOMMEND-NAME# .", - "How about #HOTEL-RECOMMEND-NAME# ?", - "#HOTEL-RECOMMEND-NAME# looks like it would be a good choice .", - "Will #HOTEL-RECOMMEND-NAME# be alright ?", - "I would suggest #HOTEL-RECOMMEND-NAME#" - ], - "Type": [ - "would a #HOTEL-RECOMMEND-TYPE# be OK ?" - ], - "Price": [ - "Its listed as #HOTEL-RECOMMEND-PRICE# .", - "It is in the #HOTEL-RECOMMEND-PRICE# price range .", - "It is a #HOTEL-RECOMMEND-PRICE# place .", - "It is #HOTEL-RECOMMEND-PRICE# .", - "This is an #HOTEL-RECOMMEND-PRICE# hotel .", - "The price range is #HOTEL-RECOMMEND-PRICE# ." - ], - "Area": [ - "It is in the #HOTEL-RECOMMEND-AREA# area .", - "They are located in the #HOTEL-RECOMMEND-AREA# .", - "it is in the #HOTEL-RECOMMEND-AREA# .", - "It 's located in the #HOTEL-RECOMMEND-AREA# .", - "It is in the #HOTEL-RECOMMEND-AREA# part of town .", - "It 's in #HOTEL-RECOMMEND-AREA# .", - "It is indeed in the #HOTEL-RECOMMEND-AREA# ." - ], - "Addr": [ - "The address is #HOTEL-RECOMMEND-ADDR# .", - "They are located at #HOTEL-RECOMMEND-ADDR#", - "The address is #HOTEL-RECOMMEND-ADDR# .", - "It is located at #HOTEL-RECOMMEND-ADDR#" - ], - "Post": [ - "The postal code for that hotel is #HOTEL-RECOMMEND-POST# .", - "The postcode is #HOTEL-RECOMMEND-POST# .", - "their postcode is #HOTEL-RECOMMEND-POST#" - ], - "Internet": [ - "it has wifi .", - "the place provides free wifi .", - "the wifi is included .", - "There is wifi available at the hotel .", - "internet is available .", - "it has free wifi ." - ], - "Parking": [ - "It does include free parking .", - "they have free parking .", - "it offers free parking .", - "the parking is free ." - ], - "Stars": [ - "It is rated #HOTEL-RECOMMEND-STARS# stars ,", - "It is rated #HOTEL-RECOMMEND-STARS# stars , is that okay ?", - "It has #HOTEL-RECOMMEND-STARS# stars .", - "It has a rating of #HOTEL-RECOMMEND-STARS# stars .", - "it has #HOTEL-RECOMMEND-STARS# star rating .", - "They have a #HOTEL-RECOMMEND-STARS# Star rating", - "The hotel is #HOTEL-RECOMMEND-STARS# stars .", - "It is a #HOTEL-RECOMMEND-STARS#-star rating .", - "it does have #HOTEL-RECOMMEND-STARS# stars ." - ], - "none": [ - "I would suggest that place .", - "May I recommend something to you ?", - "would you like me to make a recommendation ?", - "would you like a recommendation ?" - ], - "Phone": [ - "The hotel phone number is #HOTEL-RECOMMEND-PHONE# .", - "The phone number of the hotel is #HOTEL-RECOMMEND-PHONE# ." - ], - "Choice": [ - "i have #HOTEL-RECOMMEND-CHOICE# options for you", - "There are #HOTEL-RECOMMEND-CHOICE# of those .", - "We have #HOTEL-RECOMMEND-CHOICE# such places .", - "I have #HOTEL-RECOMMEND-CHOICE# different options for you !" - ] - }, - "Hotel-Request": { - "Area": [ - "Okay , do you have a specific area you want to stay in ?", - "What area are you looking to stay in ?", - "Remind me of the area you need that in .", - "Do you have an idea on the location ?", - "What area would you like the hotel in ?", - "Is there a specific area of town you 're interested in ?", - "What area of town would you like to be in ?", - "What area would you like to stay in ?", - "Let me get some additional information . What area of town would you like to stay in ?" 
- ], - "Price": [ - "What is your price range ?", - "What price range were you thinking ?", - "What price range do you prefer ?", - "Okay , do you have any price range you 're looking for ?", - "Is there a price you are looking for ?", - "Do you have a price range preference ?", - "Is there a price range you prefer ?", - "do you have a price range preference ?", - "what price range are you looking for ?", - "What is the price range for you ?" - ], - "Type": [ - "Would you like a guesthouse or a hotel ?", - "Would you like to stay in a guesthouse , or in a hotel ?", - "would you like a guesthouse or hotel ?", - "Okay , were you looking for a hotel or a guesthouse ?", - "Do you have a preference for a hotel versus a guesthouse ?", - "Would you like to try a hotel ?", - "Were you looking for a hotel with a guesthouse ?", - "What type of hotel are you looking for ?" - ], - "Internet": [ - "Would you prefer one with free internet ?", - "do you need free internet ?", - "Would you prefer one with internet ?" - ], - "Name": [ - "do you have the name of the hotel ?", - "Which hotel is it ?", - "do you know what you 're looking for ?", - "Do you have the hotel name ?", - "What is the name of the hotel you are looking for ?", - "could you please give me the name of the Hotel you are looking for ?", - "Can you give me the name of the place ?", - "What is the name of the hotel you 'd like to book ?", - "Would you like to tell me the name of that hotel ?" - ], - "Stars": [ - "How many stars would you like ?", - "Is there a number of stars you prefer ?", - "Is there a certain star rating you would like it to have ?", - "Do you have a preference of star rating ?", - "Do you have a preference of number of stars ?", - "How many stars should the hotel be rated for ?", - "What star rating do you prefer ?", - "How many stars are you looking for ?" - ], - "Parking": [ - "Will you need parking ?", - "Do you need it to have free parking ?", - "Do you have a parking preference ?", - "Does the hotel needs to have free parking ?", - "Will you need free parking ?", - "Do you need free parking ?", - "Will you need parking while you 're there ?", - "Will you be needing free parking ?" - ] - }, - "Restaurant-Inform": { - "Addr": [ - "they are located at #RESTAURANT-INFORM-ADDR#", - "It is at #RESTAURANT-INFORM-ADDR#", - "The restaurant address is #RESTAURANT-INFORM-ADDR# .", - "Their address is #RESTAURANT-INFORM-ADDR# ." - ], - "Price": [ - "They are in the #RESTAURANT-INFORM-PRICE# price range .", - "It is a #RESTAURANT-INFORM-PRICE# restaurant .", - "This is a #RESTAURANT-INFORM-PRICE# one .", - "They are in the #RESTAURANT-INFORM-PRICE# price range .", - "They are #RESTAURANT-INFORM-PRICE#", - "It 's in the #RESTAURANT-INFORM-PRICE# price range .", - "This restaurant is in the #RESTAURANT-INFORM-PRICE# price range ." - ], - "Food": [ - "They serve #RESTAURANT-INFORM-FOOD# food .", - "They serve #RESTAURANT-INFORM-FOOD# .", - "It is #RESTAURANT-INFORM-FOOD# food .", - "That is a #RESTAURANT-INFORM-FOOD# restaurant ." - ], - "Name": [ - "How about #RESTAURANT-INFORM-NAME# ?", - "#RESTAURANT-INFORM-NAME# looks like a good place .", - "How does the #RESTAURANT-INFORM-NAME# sound ?", - "Okay , how about #RESTAURANT-INFORM-NAME# ?", - "Would you like to try #RESTAURANT-INFORM-NAME# ?", - "There is a restaurant called #RESTAURANT-INFORM-NAME# that meets your criteria .", - "How about #RESTAURANT-INFORM-NAME# ?", - "there 's a place called #RESTAURANT-INFORM-NAME#", - "Would you like to try #RESTAURANT-INFORM-NAME# ?" 
- ], - "Choice": [ - "I have #RESTAURANT-INFORM-CHOICE# different restaurants I can give you some information for . They are all pretty good .", - "there are #RESTAURANT-INFORM-CHOICE# different places that match your description .", - "There are #RESTAURANT-INFORM-CHOICE# restaurants in that area that fit that criteria .", - "i have #RESTAURANT-INFORM-CHOICE# options for you", - "We have #RESTAURANT-INFORM-CHOICE# such places .", - "I have #RESTAURANT-INFORM-CHOICE# options for you !", - "there are #RESTAURANT-INFORM-CHOICE# available restaurants ." - ], - "Post": [ - "The restaurant postcode is #RESTAURANT-INFORM-POST# .", - "Their postcode is #RESTAURANT-INFORM-POST#", - "The post code is #RESTAURANT-INFORM-POST# ." - ], - "Phone": [ - "The number of the restaurant is #RESTAURANT-INFORM-PHONE# .", - "The restaurant's phone number is #RESTAURANT-INFORM-PHONE# .", - "The phone number of the restaurant is #RESTAURANT-INFORM-PHONE# .", - "#RESTAURANT-INFORM-PHONE# is the restaurant phone number" - ], - "Area": [ - "it is in the #RESTAURANT-INFORM-AREA# area .", - "It is located in the #RESTAURANT-INFORM-AREA# ." - ], - "Ref": [ - "I was able to get you into that restaurant and your reference number is #RESTAURANT-INFORM-REF#.", - "The reference number is #RESTAURANT-INFORM-REF# .", - "I have a reference number for you . It is #RESTAURANT-INFORM-REF# .", - "the reference number is #RESTAURANT-INFORM-REF#.", - "The reference number is #RESTAURANT-INFORM-REF#", - "Reference number is : #RESTAURANT-INFORM-REF#.", - "Your reference number is #RESTAURANT-INFORM-REF# . You will likely need it .", - "Your reference code is #RESTAURANT-INFORM-REF# .", - "Your reference number is #RESTAURANT-INFORM-REF# ." - ], - "none": [ - "Is there anything else I can help you with ?", - "I have found that restaurant for you", - "It 's a perfect fit .", - "That 's a great place ." - ] - }, - "Restaurant-NoOffer": { - "Food": [ - "There are no #RESTAURANT-NOOFFER-FOOD# food places , shall I run another search ?", - "I do not have anything in that price range for #RESTAURANT-NOOFFER-FOOD# . Another criteria perhaps ?", - "I am sorry there are not #RESTAURANT-NOOFFER-FOOD# restaurants . Can I check for a Chinese restaurant ?", - "I am unable to find any #RESTAURANT-NOOFFER-FOOD# restaurants in town .", - "I have nothing with #RESTAURANT-NOOFFER-FOOD# . Do you have another preference ?", - "There are no #RESTAURANT-NOOFFER-FOOD# restaurants .", - "I did not find any #RESTAURANT-NOOFFER-FOOD# restaurants .", - "There no #RESTAURANT-NOOFFER-FOOD# restaurants that I can find right now . Would something else work ?", - "I ' m sorry I have no restaurants serving #RESTAURANT-NOOFFER-FOOD# food .", - "There are no #RESTAURANT-NOOFFER-FOOD# restaurants unfortunately ." - ], - "none": [ - "I do n't have anything meeting that criteria . Can I look for something else ?", - "we do nt have a place that matches those qualities . can you try something else ?", - "I am afraid there is none .", - "We do n't have any of those , sad to say . Want to broaden the search ?", - "There are no matching records found for that request .", - "No , I ' m sorry . 
The search did n't pull up any matches .", - "There is no listing for this restaurant" - ], - "Area": [ - "Sorry , there are no restaurants like that in the #RESTAURANT-NOOFFER-AREA# .", - "I did not find any restaurants in #RESTAURANT-NOOFFER-AREA# .", - "I am sorry there is none even in the #RESTAURANT-NOOFFER-AREA#", - "I am sorry there are no restaurants in #RESTAURANT-NOOFFER-AREA# that match that description .", - "There are none in #RESTAURANT-NOOFFER-AREA# of town .", - "I am sorry but there are no restaurants that fit that criteria in the #RESTAURANT-NOOFFER-AREA# .", - "i have n't found any in the #RESTAURANT-NOOFFER-AREA#", - "there no such restraunts in #RESTAURANT-NOOFFER-AREA#" - ], - "Price": [ - "I do n't have anything in the #RESTAURANT-NOOFFER-PRICE# range that fits that criteria .", - "There are none in #RESTAURANT-NOOFFER-PRICE# , perhaps something else ?", - "There are no #RESTAURANT-NOOFFER-PRICE# ones .", - "No #RESTAURANT-NOOFFER-PRICE# restaurant" - ], - "Name": [ - "i ' m sorry . i can not find details for #RESTAURANT-NOOFFER-NAME# ." - ] - }, - "Restaurant-Recommend": { - "Name": [ - "How about #RESTAURANT-RECOMMEND-NAME# ?", - "#RESTAURANT-RECOMMEND-NAME# has some great reviews .", - "I would suggest #RESTAURANT-RECOMMEND-NAME# .", - "The #RESTAURANT-RECOMMEND-NAME# is a nice place would you like to try that one ?", - "Excellent . #RESTAURANT-RECOMMEND-NAME# is just your thing .", - "I would suggest #RESTAURANT-RECOMMEND-NAME# .", - "#RESTAURANT-RECOMMEND-NAME# sounds like it might be what you are looking for .", - "#RESTAURANT-RECOMMEND-NAME# matches your description .", - "I have a place called #RESTAURANT-RECOMMEND-NAME# , does that sound like something you would enjoy ?", - "I recommend #RESTAURANT-RECOMMEND-NAME#" - ], - "Food": [ - "Would you like #RESTAURANT-RECOMMEND-FOOD# food ?", - "Would you like to try an #RESTAURANT-RECOMMEND-FOOD# restaurant ?", - "How about #RESTAURANT-RECOMMEND-FOOD# ?", - "Okay , may I suggest #RESTAURANT-RECOMMEND-FOOD# food ?" - ], - "Price": [ - "They are in the #RESTAURANT-RECOMMEND-PRICE# price range .", - "It is a #RESTAURANT-RECOMMEND-PRICE# restaurant .", - "This is a #RESTAURANT-RECOMMEND-PRICE# one .", - "They are in the #RESTAURANT-RECOMMEND-PRICE# price range .", - "They are #RESTAURANT-RECOMMEND-PRICE#", - "It 's in the #RESTAURANT-RECOMMEND-PRICE# price range .", - "This restaurant is in the #RESTAURANT-RECOMMEND-PRICE# price range ." - ], - "Area": [ - "it is in the #RESTAURANT-RECOMMEND-AREA# area .", - "It is located in the #RESTAURANT-RECOMMEND-AREA# ." - ], - "Addr": [ - "they are located at #RESTAURANT-RECOMMEND-ADDR#", - "It is at #RESTAURANT-RECOMMEND-ADDR#", - "The address is #RESTAURANT-RECOMMEND-ADDR# .", - "Their address is #RESTAURANT-RECOMMEND-ADDR# ." - ], - "Post": [ - "The postcode is #RESTAURANT-RECOMMEND-POST# .", - "Their postcode is #RESTAURANT-RECOMMEND-POST#", - "The post code is #RESTAURANT-RECOMMEND-POST# ." - ], - "Phone": [ - "The number of the restaurant is #RESTAURANT-RECOMMEND-PHONE# .", - "The restaurant's phone number is #RESTAURANT-RECOMMEND-PHONE#", - "The phone number of the restaurant is #RESTAURANT-RECOMMEND-PHONE# .", - "#RESTAURANT-RECOMMEND-PHONE# is the restaurant phone number" - ], - "none": [ - "Is there anything else I can help you with ?", - "I have found that restaurant for you", - "It 's a perfect fit .", - "That 's a great place ." 
- ] - }, - "Restaurant-Request": { - "Area": [ - "what location ?", - "What area of town would you prefer ?", - "Do you have a specific area in mind ?", - "We need some more information . Where are looking to eat ?", - "Do you have an area of town you prefer ?", - "Which side of town would you prefer ?", - "What area should the restaurant be in ?", - "Do you have an area preference ?", - "Do you have a location preference ?" - ], - "Food": [ - "Do you have any specific type of food you would like ?", - "What type of food are you looking for ?", - "Do you have a preference in food type ?", - "What cuisine are you interested in ?", - "What type of food would you like ?", - "What type of food do you want to eat ?", - "Is there a certain kind of food you would like ?", - "what type of food would you like ?", - "what type of food would you like to eat ?", - "did you have a specific kind of cuisine in mind ?" - ], - "Price": [ - "is there a price range that you prefer ?", - "Do you have a price range ?", - "what price range would you like to stay within ?", - "Do you have a preference for the price range ?", - "Do you have a certain price range you would like ?", - "Did you have a price range in mind ?", - "what is the price range you are looking for ?", - "what price range are you looking for ?", - "Do you have a price range in mind ?", - "Is there a price range you would prefer to stay within ?" - ], - "Name": [ - "Do you know the name ?", - "what is the name of the restaurant ?", - "What 's the name of the restaurant you 're looking for ?", - "what is the name of the restaurant ?", - "What is the name of the restaurant you have in mind ?", - "Are you looking for something in particular ?", - "what is the name of the restaurant you are needing information on ?", - "Do you know the name of the location ?", - "Is there a certain restaurant you 're looking for ?" - ] - }, - "Taxi-Inform": { - "Arrive": [ - "it will arrive by #TAXI-INFORM-ARRIVE#", - "the taxi is due to arrive at #TAXI-INFORM-ARRIVE# .", - "I can book one for the closest time to #TAXI-INFORM-ARRIVE# .", - "you will arrive by #TAXI-INFORM-ARRIVE# ." - ], - "none": [ - "Your taxi will be available and has been booked .", - "I have booked you a Taxi that fits your needs .", - "The Taxi has been booked as requested .", - "you have been assigned a specific car .", - "I have booked your taxi", - "Okay I completed a booking for you" - ], - "Car": [ - "The model of the car was #TAXI-INFORM-CAR# .", - "A #TAXI-INFORM-CAR# is booked for you .", - "The taxi is a #TAXI-INFORM-CAR#", - "a #TAXI-INFORM-CAR# is booked .", - "I was able to book a #TAXI-INFORM-CAR# for you !", - "it will be a #TAXI-INFORM-CAR# .", - "Your booking is complete , a #TAXI-INFORM-CAR# will be picking you up .", - "The car arriving will be a #TAXI-INFORM-CAR# .", - "You got the #TAXI-INFORM-CAR# enjoy the ride .", - "I have successfully booked you a #TAXI-INFORM-CAR# ." - ], - "Phone": [ - "The contact number is #TAXI-INFORM-PHONE# .", - "Contact number for the taxi is #TAXI-INFORM-PHONE# .", - "their contact number is #TAXI-INFORM-PHONE#", - "the contact is #TAXI-INFORM-PHONE# .", - "You can give them a call at #TAXI-INFORM-PHONE#", - "The contact number is #TAXI-INFORM-PHONE# .", - "you can reach them on #TAXI-INFORM-PHONE#" - ], - "Leave": [ - "It will leave at #TAXI-INFORM-LEAVE# ." - ], - "Dest": [ - "it 's taking you to the #TAXI-INFORM-DEST# in time for your reservation .", - "I have booked a taxi to take you to #TAXI-INFORM-DEST# for your reservation ." 
- ], - "Depart": [ - "It will pick you up at #TAXI-INFORM-DEPART# .", - "Okay , I ' ve booked you a taxi leaving #TAXI-INFORM-DEPART# .", - "I have booked you a taxi from #TAXI-INFORM-DEPART# .", - "I have booked a taxi to pick you up at #TAXI-INFORM-DEPART# ." - ] - }, - "Taxi-Request": { - "Depart": [ - "Where will you be departing from ?", - "Where will you leave from ?", - "I need to know where you 'd like picked up at .", - "Where are you departing from ?", - "Where are you departing from , please ?", - "where are you leaving from ?", - "Where will you be leaving from ?" - ], - "Dest": [ - "what will your destination be ?", - "Where will the taxi be taking you ?", - "I need to know where you are traveling to", - "what is your destination please", - "What is the destination ?", - "Where are you going ?", - "And where will you be going ?", - "Where will you be going to ?" - ], - "Leave": [ - "What time would you like the taxi to pick you up ?", - "What time would you like to leave ?", - "What time would you like to be picked up ?", - "what time do you want to leave by ?", - "What time would you like to be picked up ?", - "What time do you need to book a taxi for ?", - "When would you like to leave ?", - "what time will you be leaving .", - "Can you tell me what time you would like the taxi to pick you up ?", - "I need to know what time you need to leave ." - ], - "Arrive": [ - "at what time would you like to arrive by ?", - "when would you like to arrive ?", - "What time would you like to arrive ?", - "Do you have a arrival time in mind ?", - "When is your arrival time ?", - "when would you like to arrive ?", - "when would you like to arrive by ?", - "What time would you like to arrive there by ?", - "What time would you like to arrive at your destination ?" - ] - }, - "Train-Inform": { - "Ticket": [ - "The fare is #TRAIN-INFORM-TICKET# per ticket .", - "The price of those tickets are #TRAIN-INFORM-TICKET# .", - "It is #TRAIN-INFORM-TICKET#", - "The price is #TRAIN-INFORM-TICKET# per ticket .", - "The price is #TRAIN-INFORM-TICKET# .", - "The trip will cost #TRAIN-INFORM-TICKET# .", - "The fare is #TRAIN-INFORM-TICKET#", - "It would cost #TRAIN-INFORM-TICKET# .", - "The cost of the one way journey is #TRAIN-INFORM-TICKET# .", - "The price is #TRAIN-INFORM-TICKET# per ticket ." - ], - "Leave": [ - "it leaves at #TRAIN-INFORM-LEAVE#", - "I have a train leaving at #TRAIN-INFORM-LEAVE# would that be okay ?", - "There is a train that leaves at #TRAIN-INFORM-LEAVE# .", - "How about #TRAIN-INFORM-LEAVE# will that work for you ?", - "There is a train meeting your criteria and is leaving at #TRAIN-INFORM-LEAVE# .", - "I would say you should leave by #TRAIN-INFORM-LEAVE#" - ], - "Id": [ - "The train ID is #TRAIN-INFORM-ID# .", - "Their ID is #TRAIN-INFORM-ID# .", - "#TRAIN-INFORM-ID# would be your perfect fit ." - ], - "Ref": [ - "The reference number is #TRAIN-INFORM-REF# .", - "The reference number of your trip is #TRAIN-INFORM-REF# .", - "Here is the reference number #TRAIN-INFORM-REF#", - "your reference number is #TRAIN-INFORM-REF#" - ], - "Time": [ - "The travel time is #TRAIN-INFORM-TIME# .", - "That would be #TRAIN-INFORM-TIME# .", - "#TRAIN-INFORM-TIME# would be the total duration .", - "The trip will last #TRAIN-INFORM-TIME#", - "the travel will take #TRAIN-INFORM-TIME#", - "The trip is #TRAIN-INFORM-TIME# .", - "The travel time for the trip is #TRAIN-INFORM-TIME# one way ." 
- ], - "Arrive": [ - "It arrives at #TRAIN-INFORM-ARRIVE# .", - "it should arrive by #TRAIN-INFORM-ARRIVE#", - "The arrival time is #TRAIN-INFORM-ARRIVE# ." - ], - "Depart": [ - "that train is departing from #TRAIN-INFORM-DEPART# .", - "it departs from #TRAIN-INFORM-DEPART# .", - "the train will be departing from #TRAIN-INFORM-DEPART# ." - ], - "Choice": [ - "I have #TRAIN-INFORM-CHOICE# trains that meet your criteria .", - "There are #TRAIN-INFORM-CHOICE# .", - "There are #TRAIN-INFORM-CHOICE# options", - "There are #TRAIN-INFORM-CHOICE# trains available .", - "There are #TRAIN-INFORM-CHOICE# total trips available to you" - ], - "none": [ - "Yes it is .", - "Yes it does ." - ], - "Day": [ - "The train is for #TRAIN-INFORM-DAY# you are all set", - "that train leaves on #TRAIN-INFORM-DAY# ." - ], - "Dest": [ - "the booking is for arriving in #TRAIN-INFORM-DEST# .", - "the train stop is #TRAIN-INFORM-DEST# ." - ], - "People": [ - "I booked #TRAIN-INFORM-PEOPLE# tickets ." - ] - }, - "Train-NoOffer": { - "Depart": [ - "There is no train leaving #TRAIN-NOOFFER-DEPART# .", - "There are no trains leaving from #TRAIN-NOOFFER-DEPART# ." - ], - "none": [ - "I ' m sorry , unfortunately there are no trains available at that time .", - "I do not have any trains that match your request .", - "There is no train leaving at that time .", - "I ' m sorry , there are no trains that meet your criteria ." - ], - "Leave": [ - "There is not one that leaves at #TRAIN-NOOFFER-LEAVE# .", - "There is no train leaving at #TRAIN-NOOFFER-LEAVE# .", - "There are no trains that leave after #TRAIN-NOOFFER-LEAVE# .", - "There are no trains leaving at #TRAIN-NOOFFER-LEAVE# .", - "Unfortunately , there are no trains that leave after #TRAIN-NOOFFER-LEAVE# ." - ], - "Day": [ - "No I ' m sorry there are not on a #TRAIN-NOOFFER-DAY# .", - "There are no trains on #TRAIN-NOOFFER-DAY# .", - "Unfortunately , there is not a train on #TRAIN-NOOFFER-DAY# matching your criteria .", - "There are no trains on #TRAIN-NOOFFER-DAY# ." - ], - "Dest": [ - "I ' m sorry there are no trains to #TRAIN-NOOFFER-Dest# .", - "There are no trains going to #TRAIN-NOOFFER-Dest# . Do you have another destination in mind ?", - "There do n't seem to be any trains going to #TRAIN-NOOFFER-Dest# ." - ], - "Arrive": [ - "I am sorry there are not trains to arrive at #TRAIN-NOOFFER-ARRIVE# .", - "There are no trains arriving by #TRAIN-NOOFFER-ARRIVE# .", - "I ' m sorry , we do n't have any trains arriving by #TRAIN-NOOFFER-ARRIVE# .", - "no trains arrives by #TRAIN-NOOFFER-ARRIVE#", - "I do not have a train arriving at #TRAIN-NOOFFER-ARRIVE# ." - ], - "Id": [ - "Sorry it looks like #TRAIN-NOOFFER-ID# is no longer available .", - "I ' m sorry , train #TRAIN-NOOFFER-ARRIVE# is not available ." - ] - }, - "Train-OfferBook": { - "none": [ - "Ok I will book that for you and get you a confirmation number", - "Ok great . Would you like for me to go ahead and book the train for you ?", - "Would you like me to reserve your train tickets ?", - "I will book that for you now .", - "Great , I have a train that meets your criteria . Would you like me to book it for you ?", - "Would you like me to book this train for you ?", - "Well , I can book it for YOU if you would like . Would you like me to ?", - "Did you want reservations ?" - ], - "Depart": [ - "Would you like me to book a train from #TRAIN-OFFERBOOK-DEPART# for you ?" 
- ], - "Id": [ - "Train #TRAIN-OFFERBOOK-ID# would work for you .", - "Okay the trainID is #TRAIN-OFFERBOOK-ID# ", - "The #TRAIN-OFFERBOOK-ID# meets your criteria .", - "Shall I go ahead and book you for train #TRAIN-OFFERBOOK-ID# ?", - "Okay , would you like me to book #TRAIN-OFFERBOOK-ID# ?", - "#TRAIN-OFFERBOOK-ID# looks like it would work for you", - "I would suggest booking #TRAIN-OFFERBOOK-ID#", - "Train #TRAIN-OFFERBOOK-ID# meets your criteria" - ], - "Arrive": [ - "Will you train arriving at #TRAIN-OFFERBOOK-ARRIVE# work ok for you ?", - "I have a train arriving at #TRAIN-OFFERBOOK-ARRIVE# . Would that do ?", - "It arrives by #TRAIN-OFFERBOOK-ARRIVE# would that be okay to book for you ?", - "I can get you tickets for an arrival time at #TRAIN-OFFERBOOK-ARRIVE# , is that okay ?", - "Will an arrival time of #TRAIN-OFFERBOOK-ARRIVE# work for you ?", - "Would you like me to book the #TRAIN-OFFERBOOK-ARRIVE# train ?" - ], - "People": [ - "I will book it for you for #TRAIN-OFFERBOOK-PEOPLE# people", - "I will go ahead and book #TRAIN-OFFERBOOK-PEOPLE# tickets .", - "i just want to confirm if i am booking #TRAIN-OFFERBOOK-PEOPLE# ticket", - "I will book #TRAIN-OFFERBOOK-PEOPLE# tickets for you .", - "I will book it for #TRAIN-OFFERBOOK-PEOPLE# people ." - ], - "Ticket": [ - "The price is #TRAIN-OFFERBOOK-TICKET# per ticket ", - "The cost per seat is #TRAIN-OFFERBOOK-TICKET# .", - "The price of the ticket is #TRAIN-OFFERBOOK-TICKET# " - ], - "Leave": [ - "Would you like me to book you on the #TRAIN-OFFERBOOK-LEAVE# train ?", - "There is a train that leaves at #TRAIN-OFFERBOOK-LEAVE# would you like me to book that train for you ?", - "Okay ! How about the train that leaves at #TRAIN-OFFERBOOK-LEAVE# ?", - "I would recommend the train that leaves at #TRAIN-OFFERBOOK-LEAVE# . Would you like me to book that ?", - "There is a train arriving at #TRAIN-OFFERBOOK-LEAVE# would you like me to book tickets for that one ?", - "The earliest train is at #TRAIN-OFFERBOOK-LEAVE# , do you want me to book it ?", - "There is a train leaving at #TRAIN-OFFERBOOK-LEAVE# would you like me to book this ?", - "Great , I have a train leaving there at #TRAIN-OFFERBOOK-LEAVE# . Would you like to book that ?", - "I can book a #TRAIN-OFFERBOOK-LEAVE# for you .", - "We can book you for the train leaving at #TRAIN-OFFERBOOK-LEAVE# ." - ], - "Choice": [ - "I have #TRAIN-OFFERBOOK-CHOICE# trains available that meet all of your requirements , would you like me to book a ticket for you ?", - "There are #TRAIN-OFFERBOOK-CHOICE# trains available . Should I book a train for you ?" - ], - "Dest": [ - "Would you like me to book a train to #TRAIN-OFFERBOOK-DEST# for you ?" - ], - "Time": [ - "The travel time is #TRAIN-OFFERBOOK-TIME# , would you like me to book it for you ?" - ], - "Ref": [ - "Your reference number is #TRAIN-OFFERBOOK-REF# ." - ], - "Day": [ - "Would you like to take the train on #TRAIN-OFFERBOOK-DAY# ?", - "I can book you on #TRAIN-OFFERBOOK-DAY#", - "I can book your tickets for #TRAIN-OFFERBOOK-DAY# ." - ] - }, - "Train-OfferBooked": { - "Ref": [ - "I ' ve booked your train tickets , and your reference number is #TRAIN-OFFERBOOKED-REF#.", - "Your booking was successful . Your reference number is #TRAIN-OFFERBOOKED-REF# .", - "I have made those reservations and your reference number is #TRAIN-OFFERBOOKED-REF# ." 
- ], - "none": [ - "Wonderful , your train has been booked !", - "your reservation is booked", - "ok , i got that fixed for you .", - "Booking was successful", - "Great your booking is all set !", - "You have been booked !" - ], - "Ticket": [ - "The cost of your ticket will be #TRAIN-OFFERBOOKED-TICKET#", - "the total fee is #TRAIN-OFFERBOOKED-TICKET# payable at the station", - "It is #TRAIN-OFFERBOOKED-TICKET# .", - "Booking was successful , the total fee is #TRAIN-OFFERBOOKED-TICKET# payable at the station ." - ], - "Id": [ - "the train ID is #TRAIN-OFFERBOOKED-ID# .", - "Ok . You should be set . The booking was successful . The train number is #TRAIN-OFFERBOOKED-ID# ." - ], - "People": [ - "I have successfully make a booking for #TRAIN-OFFERBOOKED-PEOPLE# on that train .", - "I have booked you #TRAIN-OFFERBOOKED-PEOPLE# tickets ." - ], - "Leave": [ - "i have booked you one leaving at #TRAIN-OFFERBOOKED-LEAVE# ." - ], - "Time": [ - "The travel time is #TRAIN-OFFERBOOKED-TIME# .", - "That would be #TRAIN-OFFERBOOKED-TIME# .", - "#TRAIN-OFFERBOOKED-TIME# would be the total duration .", - "The trip will last #TRAIN-OFFERBOOKED-TIME#", - "the travel will take #TRAIN-OFFERBOOKED-TIME#", - "The trip is #TRAIN-OFFERBOOKED-TIME# .", - "The travel time for the trip is #TRAIN-OFFERBOOKED-TIME# one way ." - ], - "Day": [ - "The train is for #TRAIN-OFFERBOOKED-DAY# you are all set", - "that train leaves on #TRAIN-OFFERBOOKED-DAY# ." - ], - "Depart": [ - "that train is departing from #TRAIN-OFFERBOOKED-DEPART# .", - "it departs from #TRAIN-OFFERBOOKED-DEPART# .", - "the train will be departing from #TRAIN-OFFERBOOKED-DEPART# ." - ], - "Dest": [ - "the booking is for arriving in #TRAIN-OFFERBOOKED-DEST# .", - "the train stop is #TRAIN-OFFERBOOKED-DEST# ." - ], - "Arrive": [ - "It arrives at #TRAIN-OFFERBOOKED-ARRIVE# .", - "it should arrive by #TRAIN-OFFERBOOKED-ARRIVE#", - "The arrival time is #TRAIN-OFFERBOOKED-ARRIVE# ." - ], - "Choice": [ - "I have #TRAIN-OFFERBOOKED-CHOICE# trains that meet your criteria .", - "There are #TRAIN-OFFERBOOKED-CHOICE# .", - "There are #TRAIN-OFFERBOOKED-CHOICE# options", - "There are #TRAIN-OFFERBOOKED-CHOICE# trains available .", - "There are #TRAIN-OFFERBOOKED-CHOICE# total trips available to you" - ] - }, - "Train-Request": { - "Day": [ - "Can you confirm your desired travel day ?", - "Can you tell me what day you would like to travel , please ?", - "Can you tell me what day you would like to travel ?", - "can you tell me which day you 'd like to travel on ?", - "What day would you like ?", - "What day will you travel on ?", - "On what day will you be traveling ?", - "What day will you be traveling ?", - "what day did you have in mind ?", - "what day would you like to travel ?" - ], - "Dest": [ - "Where would you like to go to ?", - "Where is your destination ?", - "where would you like your train to take you ?", - "Where are you heading to ?", - "Where are you headed ?", - "Where will you be arriving at ?", - "What is your destination ?", - "What station would you like to arrive at ?" - ], - "People": [ - "How many tickets would you like ?", - "how many tickets would you like me to book ?", - "for how many tickets ?", - "For how many people ?", - "how many tickets do you need ?", - "No problem . How many seats would you like to book ?" 
- ], - "Depart": [ - "where will you be departing from ?", - "Where are you departing from ?", - "Where will you be leaving from ?", - "where did you want to depart from ?", - "Where are you departing from ?", - "Where will you be traveling from ?" - ], - "Leave": [ - "what time would you like to depart ?", - "when would you like to travel ?", - "Is there a certain time you are wanting to leave ?", - "When would you like to leave by ?", - "departure time in mind ?", - "When would you like the train to depart ?", - "what time do you want to depart ?", - "when would you like to leave by ?", - "what time would you like to leave ?" - ], - "Arrive": [ - "Is there a time you would prefer to arrive ?", - "What time would you like to arrive by ?", - "What time do you need to arrive ?", - "What time do you want to arrive by ?", - "Do you have an arrival time in mind ?", - "Is there a time you would like to arrive by ?", - "Is there a time you would like to get there by ?", - "Is there a time you need to arrive by ?" - ] - }, - "Police-Inform": { - "Addr": [ - "it is located in #POLICE-INFORM-ADDR#", - "adress is #POLICE-INFORM-ADDR#", - "It is on #POLICE-INFORM-ADDR# .", - "their address in our system is listed as #POLICE-INFORM-ADDR# .", - "The address is #POLICE-INFORM-ADDR# .", - "it 's located at #POLICE-INFORM-ADDR# .", - "#POLICE-INFORM-ADDR# is the address", - "They are located at #POLICE-INFORM-ADDR# ." - ], - "Post": [ - "The postcode of the police is #POLICE-INFORM-POST# .", - "The post code is #POLICE-INFORM-POST# .", - "Its postcode is #POLICE-INFORM-POST# .", - "Their postcode is #POLICE-INFORM-POST# ." - ], - "Name": [ - "I think a fun place to visit is #POLICE-INFORM-NAME# .", - "#POLICE-INFORM-NAME# looks good .", - "#POLICE-INFORM-NAME# is available , would that work for you ?", - "we have #POLICE-INFORM-NAME# .", - "#POLICE-INFORM-NAME# is popular among visitors .", - "How about #POLICE-INFORM-NAME# ?", - "What about #POLICE-INFORM-NAME# ?", - "you might want to try the #POLICE-INFORM-NAME# ." - ], - "Phone": [ - "The police phone number is #POLICE-INFORM-PHONE# .", - "Here is the police phone number , #POLICE-INFORM-PHONE# ." - ] - }, - "Hospital-Inform": { - "Addr": [ - "it is located in #HOSPITAL-INFORM-ADDR#", - "adress is #HOSPITAL-INFORM-ADDR#", - "It is on #HOSPITAL-INFORM-ADDR# .", - "their address in our system is listed as #HOSPITAL-INFORM-ADDR# .", - "The address is #HOSPITAL-INFORM-ADDR# .", - "it 's located at #HOSPITAL-INFORM-ADDR# .", - "#HOSPITAL-INFORM-ADDR# is the address", - "They are located at #HOSPITAL-INFORM-ADDR# ." - ], - "Post": [ - "The postcode of the hospital is #HOSPITAL-INFORM-POST# .", - "The post code is #HOSPITAL-INFORM-POST# .", - "Its postcode is #HOSPITAL-INFORM-POST# .", - "Their postcode is #HOSPITAL-INFORM-POST# ." - ], - "Department": [ - "The department of the hospital is #HOSPITAL-INFORM-POST# .", - "The department is #HOSPITAL-INFORM-POST# .", - "Its department is #HOSPITAL-INFORM-POST# .", - "Their department is #HOSPITAL-INFORM-POST# ." - - ], - "Phone": [ - "The hospital phone number is #HOSPITAL-INFORM-PHONE# .", - "Here is the hospital phone number , #HOSPITAL-INFORM-PHONE# ." 
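The inform-style templates above are delexicalised: each surface form carries a `#DOMAIN-INTENT-SLOT#` placeholder that is substituted with the concrete slot value at generation time. A minimal fill sketch, assuming a `templates` dict shaped like this JSON; the helper name `realize` is hypothetical, but the `replace` call mirrors `TemplateNLG._manual_generate` further down in this diff:

```python
import random

templates = {'Train-OfferBooked': {'Ref': [
    "Your booking was successful . Your reference number is #TRAIN-OFFERBOOKED-REF# .",
]}}

def realize(templates, dialog_act, slot, value):
    # choose one surface form for (dialog_act, slot) and fill its placeholder
    sentence = random.choice(templates[dialog_act][slot])
    placeholder = '#{}-{}#'.format(dialog_act.upper(), slot.upper())
    return sentence.replace(placeholder, str(value))

print(realize(templates, 'Train-OfferBooked', 'Ref', 'ABC1234'))
# -> Your booking was successful . Your reference number is ABC1234 .
```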
- ] - }, - "Hospital-Request": { - "Department": [ - "What is the name of the hospital department ?", - "What hospital department are you thinking about ?", - "I ' m sorry for the confusion , what hospital department are you interested in ?", - "What hospital department were you thinking of ?", - "Do you know the department of it ?", - "can you give me the department of it ?" - ] - }, - "general-bye": { - "none": [ - "Thank you for using our services .", - "Goodbye . If you think of anything else you need do n't hesitate to contact us .", - "You are very welcome . Goodbye .", - "Thank you and enjoy your visit . Have a great day .", - "I ' m happy to help , and I hope you enjoy your stay !", - "Thank you and goodbye .", - "Thank you for using our system !" - ] - }, - "general-greet": { - "none": [ - "You are more than welcome !", - "Have a good day .", - "I ' m happy to have been able to help you today .", - "Glad to have been of help . Thank you for using the Cambridge TownInfo centre . Enjoy the rest of your day !", - "Thank you for contacting the help desk . Have a great day .", - "Thank you for contacting us and have a nice day .", - "Ok , thank you . Have a good day .", - "Thank you for using our services ." - ] - }, - "general-reqmore": { - "none": [ - "You are welcome . Is there anything else I can help you with today ?", - "is there anything else I can help you with ?", - "Is there anything else I can help you with today ?", - "Did you need any further assistance today ?" - ] - }, - "general-reqinfo": { - "none": [ - "Could you please provide me with more information ." - ] - }, - "general-welcome": { - "none": [ - "you are welcome", - "Welcome , it was a pleasure serving you .", - "You 're welcome ! I hope you have a wonderful trip !", - "Okay ! Glad I could help . Enjoy your stay .", - "You are welcome . Have a good day !", - "you are welcome", - "You 're welcome . Have a good day !" - ] - } -} diff --git a/convlab/nlg/template/multiwoz/manual_system_template_nlg.json b/convlab/nlg/template/multiwoz/manual_system_template_nlg.json index 300369b525d3d5c4781ea00587f198e56657a37b..414f0b80b740d0c33568ba21bab5e13e822d95df 100755 --- a/convlab/nlg/template/multiwoz/manual_system_template_nlg.json +++ b/convlab/nlg/template/multiwoz/manual_system_template_nlg.json @@ -225,7 +225,7 @@ "Do you know the name of it ?", "can you give me the name of it ?" ], - "Fee": [ + "Price": [ "any specific price range to help narrow down available options ?", "What price range would you like ?", "what is your price range for that ?", @@ -363,6 +363,42 @@ "I ' m sorry but there is no availability for #BOOKING-NOBOOK-PEOPLE# people ." ] }, + "Booking-Request": { + "Day": [ + "What day would you like your booking for ?", + "What day would you like that reservation ?", + "what day would you like the booking to be made for ?", + "What day would you like to book ?", + "Ok , what day would you like to make the reservation on ?" + ], + "Stay": [ + "How many nights will you be staying ?", + "And how many nights ?", + "for how many days ?", + "And for how many days ?", + "how many days would you like to stay ?", + "How many nights would you like to book it for ?", + "And what nights would you like me to reserve for you ?", + "How many nights are you wanting to stay ?", + "How many days will you be staying ?" + ], + "People": [ + "For how many people ?", + "How many people will be ?", + "How many people will be with you ?", + "How many people is the reservation for ?" 
+ ], + "Time": [ + "Do you have a time preference ?", + "what time are you looking for a reservation at ?", + "For what time ?", + "What time would you like me to make your reservation ?", + "What time would you like the reservation for ?", + "what time should I make the reservation for ?", + "What time would you prefer ?", + "What time would you like the reservation for ?" + ] + }, "Hotel-Inform": { "Internet": [ "it has wifi .", @@ -661,30 +697,6 @@ "Do you need free parking ?", "Will you need parking while you 're there ?", "Will you be needing free parking ?" - ], - "Day": [ - "What day would you like your booking for ?", - "What day would you like that reservation ?", - "what day would you like the booking to be made for ?", - "What day would you like to book ?", - "Ok , what day would you like to make the reservation on ?" - ], - "Stay": [ - "How many nights will you be staying ?", - "And how many nights ?", - "for how many days ?", - "And for how many days ?", - "how many days would you like to stay ?", - "How many nights would you like to book it for ?", - "And what nights would you like me to reserve for you ?", - "How many nights are you wanting to stay ?", - "How many days will you be staying ?" - ], - "People": [ - "For how many people ?", - "How many people will be ?", - "How many people will be with you ?", - "How many people is the reservation for ?" ] }, "Restaurant-Inform": { @@ -906,29 +918,6 @@ "what is the name of the restaurant you are needing information on ?", "Do you know the name of the location ?", "Is there a certain restaurant you 're looking for ?" - ], - "Day": [ - "What day would you like your booking for ?", - "What day would you like that reservation ?", - "what day would you like the booking to be made for ?", - "What day would you like to book ?", - "Ok , what day would you like to make the reservation on ?" - ], - "People": [ - "For how many people ?", - "How many people will be ?", - "How many people will be with you ?", - "How many people is the reservation for ?" - ], - "Time": [ - "Do you have a time preference ?", - "what time are you looking for a reservation at ?", - "For what time ?", - "What time would you like me to make your reservation ?", - "What time would you like the reservation for ?", - "what time should I make the reservation for ?", - "What time would you prefer ?", - "What time would you like the reservation for ?" ] }, "Taxi-Inform": { @@ -1342,77 +1331,6 @@ "Is there a time you need to arrive by ?" ] }, - "Police-Inform": { - "Addr": [ - "it is located in #POLICE-INFORM-ADDR#", - "adress is #POLICE-INFORM-ADDR#", - "It is on #POLICE-INFORM-ADDR# .", - "their address in our system is listed as #POLICE-INFORM-ADDR# .", - "The address is #POLICE-INFORM-ADDR# .", - "it 's located at #POLICE-INFORM-ADDR# .", - "#POLICE-INFORM-ADDR# is the address", - "They are located at #POLICE-INFORM-ADDR# ." - ], - "Post": [ - "The postcode of the police is #POLICE-INFORM-POST# .", - "The post code is #POLICE-INFORM-POST# .", - "Its postcode is #POLICE-INFORM-POST# .", - "Their postcode is #POLICE-INFORM-POST# ." - ], - "Name": [ - "I think a fun place to visit is #POLICE-INFORM-NAME# .", - "#POLICE-INFORM-NAME# looks good .", - "#POLICE-INFORM-NAME# is available , would that work for you ?", - "we have #POLICE-INFORM-NAME# .", - "#POLICE-INFORM-NAME# is popular among visitors .", - "How about #POLICE-INFORM-NAME# ?", - "What about #POLICE-INFORM-NAME# ?", - "you might want to try the #POLICE-INFORM-NAME# ." 
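The hunks above introduce the shared `Booking-Request` block and delete the duplicated `Day`/`Stay`/`People`/`Time` prompts from `Hotel-Request` and `Restaurant-Request`. A flat dialog act still reaches the right entry through the key construction in `TemplateNLG.generate` (shown later in this diff); the act tuple below is purely illustrative:

```python
# flat acts are [intent, domain, slot, value]
intent, domain, slot, value = 'Request', 'Booking', 'Day', '?'

# generate() groups acts under 'domain-intent', lowercased
key = '-'.join([domain.lower(), intent.lower()])
assert key == 'booking-request'
# _manual_generate then samples from template['booking-request']['day']
```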
- ], - "Phone": [ - "The police phone number is #POLICE-INFORM-PHONE# .", - "Here is the police phone number , #POLICE-INFORM-PHONE# ." - ] - }, - "Hospital-Inform": { - "Addr": [ - "it is located in #HOSPITAL-INFORM-ADDR#", - "adress is #HOSPITAL-INFORM-ADDR#", - "It is on #HOSPITAL-INFORM-ADDR# .", - "their address in our system is listed as #HOSPITAL-INFORM-ADDR# .", - "The address is #HOSPITAL-INFORM-ADDR# .", - "it 's located at #HOSPITAL-INFORM-ADDR# .", - "#HOSPITAL-INFORM-ADDR# is the address", - "They are located at #HOSPITAL-INFORM-ADDR# ." - ], - "Post": [ - "The postcode of the hospital is #HOSPITAL-INFORM-POST# .", - "The post code is #HOSPITAL-INFORM-POST# .", - "Its postcode is #HOSPITAL-INFORM-POST# .", - "Their postcode is #HOSPITAL-INFORM-POST# ." - ], - "Department": [ - "The department of the hospital is #HOSPITAL-INFORM-POST# .", - "The department is #HOSPITAL-INFORM-POST# .", - "Its department is #HOSPITAL-INFORM-POST# .", - "Their department is #HOSPITAL-INFORM-POST# ." - - ], - "Phone": [ - "The hospital phone number is #HOSPITAL-INFORM-PHONE# .", - "Here is the hospital phone number , #HOSPITAL-INFORM-PHONE# ." - ] - }, - "Hospital-Request": { - "Department": [ - "What is the name of the hospital department ?", - "What hospital department are you thinking about ?", - "I ' m sorry for the confusion , what hospital department are you interested in ?", - "What hospital department were you thinking of ?", - "Do you know the department of it ?", - "can you give me the department of it ?" - ] - }, "general-bye": { "none": [ "Thank you for using our services .", @@ -1460,4 +1378,4 @@ "You 're welcome . Have a good day !" ] } -} +} \ No newline at end of file diff --git a/convlab/nlg/template/multiwoz/nlg.bck.py b/convlab/nlg/template/multiwoz/nlg.bck.py deleted file mode 100755 index c657f13058e5aa273a7898de113b611be628cf8a..0000000000000000000000000000000000000000 --- a/convlab/nlg/template/multiwoz/nlg.bck.py +++ /dev/null @@ -1,502 +0,0 @@ -import json -import random -import os -from pprint import pprint -import collections -import logging - -import numpy as np -import re - -from convlab.nlg import NLG -from convlab.nlg.template.multiwoz.noise_functions import delete_random_token, random_token_permutation, spelling_noise -from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA -from convlab.util import relative_import_module_from_unified_datasets -from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple - -reverse_da = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da') - -def lower_keys(x, d=None): - if isinstance(x, list): - return [lower_keys(v) for v in x] - elif isinstance(x, dict): - #return {k.lower(): lower_keys(v) for k, v in x.items()} - return {SLOT_MAP_TO_UDF.get(k.title(), k.lower()): lower_keys(v) for k, v in x.items()} - else: - #return x - y = re.sub("#([^\-]+)-([^\-]+)-([^\-]+)#", lambda x: "#" + x.group(1) + "-" + x.group(2) + "-" + SLOT_MAP_TO_UDF[d].get(x.group(3).title(), x.group(3)).upper() + "#", x) - return y - - -def read_json(filename): - raw_data = None - with open(filename, 'r') as f: - raw_data = json.load(f) - #return lower_keys(json.load(f), None) - data = {} - for di in raw_data: - new_di = di.lower() - data[new_di] = {} - d, _ = di.split('-') - for s in raw_data[di]: - new_s = SLOT_MAP_TO_UDF[d].get(s.title(), s.lower()) if d in SLOT_MAP_TO_UDF else s.lower() - data[new_di][new_s] = [] - for t in raw_data[di][s]: - new_t = 
re.sub("#([^\-]+)-([^\-]+)-([^\-]+)#", lambda x: "#" + x.group(1) + "-" + x.group(2) + "-" + SLOT_MAP_TO_UDF[d].get(x.group(3).title(), x.group(3)).upper() + "#" if d in SLOT_MAP_TO_UDF else x.group(3).upper() + "#", t) - data[new_di][new_s].append(new_t) - return data - - -# supported slot -Slot2word = { - 'Fee': 'fee', - 'Addr': 'address', - 'Area': 'area', - 'Stars': 'stars', - 'Internet': 'Internet', - 'Department': 'department', - 'Choice': 'choice', - 'Ref': 'reference number', - 'Food': 'food', - 'Type': 'type', - 'Price': 'price range', - 'Stay': 'stay', - 'Phone': 'phone number', - 'Post': 'postcode', - 'Day': 'day', - 'Name': 'name', - 'Car': 'car type', - 'Leave': 'leave', - 'Time': 'time', - 'Arrive': 'arrive', - 'Ticket': 'ticket', - 'Depart': 'departure', - 'People': 'people', - 'Dest': 'destination', - 'Parking': 'parking', - 'Open': 'open', - 'Id': 'Id', - # 'TrainID': 'TrainID' -} - - -SLOT_MAP_TO_UDF = { - 'Attraction': { - 'Addr': 'address', - 'Post': 'postcode', - 'Fee': 'entrance fee' - }, - 'Hospital': { - 'Post': 'postcode', - 'Addr': 'address' - }, - 'Hotel': { - 'Addr': 'address', - 'Post': 'postcode', - 'Price': 'price range', - 'Stay': 'book stay', - 'Day': 'book day', - 'People': 'book people' - }, - 'Police': { - 'Post': 'postcode', - 'Addr': 'address' - }, - 'Restaurant': { - 'Price': 'price range', - 'Time': 'book time', - 'People': 'book people', - 'Day': 'book day', - 'Addr': 'address', - 'Post': 'postcode' - }, - 'Taxi': { - 'Leave': 'leave at', - 'Arrive': 'arrive by', - 'Dest': 'destination', - 'Depart': 'departure', - 'Car': 'type' - }, - 'Train': { - 'Time': 'duration', - 'Ticket': 'price', - 'Leave': 'leave at', - 'Id': 'train id', - 'Arrive': 'arrive by', - 'Depart': 'departure', - 'People': 'book people', - 'Dest': 'destination' - } -} - - -SLOT_MAP_TO_UDF2 = { - 'Addr': 'address', - 'Area': 'area', - 'Arrive': 'arrive by', - 'Car': 'type', - 'Choice': 'choice', - 'Day': 'day', - 'Depart': 'departure', - 'Department': 'department', - 'Dest': 'destination', - 'Fee': 'entrance fee', - 'Food': 'food', - 'Id': 'train id', - 'Internet': 'internet', - 'Leave': 'leave at', - 'Name': 'name', - 'Open': 'open', - 'Parking': 'parking', - 'People': 'book people', - 'Phone': 'phone', - 'Post': 'postcode', - 'Price': 'price range', - 'Ref': 'ref', - 'Stars': 'stars', - 'Stay': 'book stay', - 'Ticket': 'price', - 'Time': 'time', - 'TrainID': 'train id', - 'Type': 'type' -} - -slot2word = dict((k.lower(), v.lower()) for k, v in Slot2word.items()) - - -class TemplateNLG(NLG): - - def __init__(self, is_user, mode="manual", label_noise=0.0, text_noise=0.0, seed=0): - """ - Args: - is_user: - if dialog_act from user or system - mode: - - `auto`: templates extracted from data without manual modification, may have no match; - - - `manual`: templates with manual modification, sometimes verbose; - - - `auto_manual`: use auto templates first. When fails, use manual templates. - - both template are dict, *_template[dialog_act][slot] is a list of templates. 
- """ - super().__init__() - self.is_user = is_user - self.mode = mode - self.label_noise = label_noise - self.text_noise = text_noise - if not is_user: - self.label_noise, self.text_noise = 0.0, 0.0 - - print("NLG seed " + str(seed)) - logging.info(f'Building {"user" if is_user else "system"} template NLG module using {mode} templates.') - if self.label_noise > 0.0 or self.text_noise > 0.0: - self.seed = seed - np.random.seed(seed) - random.seed(seed) - logging.info(f'Template NLG will generate {self.label_noise * 100}% noise in values and {self.text_noise * 100}% random text noise.') - - self.load_templates() - - def load_templates(self): - template_dir = os.path.dirname(os.path.abspath(__file__)) - if self.is_user: - if 'manual' in self.mode: - self.manual_user_template = read_json(os.path.join( - template_dir, 'manual_user_template_nlg.json')) - if 'auto' in self.mode: - self.auto_user_template = read_json(os.path.join( - template_dir, 'auto_user_template_nlg.json')) - else: - if 'manual' in self.mode: - self.manual_system_template = read_json(os.path.join( - template_dir, 'manual_system_template_nlg.json')) - if 'auto' in self.mode: - self.auto_system_template = read_json(os.path.join(template_dir, 'auto_system_template_nlg.json')) - logging.info('NLG templates loaded.') - - if self.label_noise > 0.0 and self.is_user: - self.label_map = read_json(os.path.join(template_dir, 'label_maps.json')) - logging.info('NLG value noise label map loaded.') - - def sorted_dialog_act(self, dialog_acts): - new_action_group = {} - for item in dialog_acts: - intent, domain, slot, value = item - if domain not in new_action_group: - new_action_group[domain] = { - 'nooffer': [], 'inform-name': [], 'inform-other': [], 'request': [], 'other': []} - if intent == 'NoOffer': - new_action_group[domain]['nooffer'].append(item) - elif intent == 'Inform' and slot == 'Name': - new_action_group[domain]['inform-name'].append(item) - elif intent == 'Inform': - new_action_group[domain]['inform-other'].append(item) - elif intent == 'request': - new_action_group[domain]['request'].append(item) - else: - new_action_group[domain]['other'].append(item) - - new_action = [] - if 'general' in new_action_group: - new_action += new_action_group['general']['other'] - del new_action_group['general'] - for domain in new_action_group: - for k in ['other', 'request', 'inform-other', 'inform-name', 'nooffer']: - new_action = new_action_group[domain][k] + new_action - return new_action - - def noisy_dialog_acts(self, dialog_acts): - if self.label_noise > 0.0: - noisy_acts = [] - for intent, domain, slot, value in dialog_acts: - if intent == 'Inform': - if value in self.label_map: - if np.random.uniform() < self.label_noise: - value = self.label_map[value] - value = np.random.choice(value) - noisy_acts.append([intent, domain, slot, value]) - return noisy_acts - return dialog_acts - - def generate(self, dialog_acts): - """NLG for Multiwoz dataset - - Args: - dialog_acts - Returns: - generated sentence - """ - print("dialog_acts0:", dialog_acts) - dialog_acts = unified_format(dialog_acts) - print("dialog_acts1:", dialog_acts) - dialog_acts = reverse_da(dialog_acts) - print("dialog_acts2:", dialog_acts) - dialog_acts = act_dict_to_flat_tuple(dialog_acts) - print("dialog_acts3:", dialog_acts) - dialog_acts = self.noisy_dialog_acts( - dialog_acts) if self.is_user else dialog_acts - dialog_acts = self.sorted_dialog_act(dialog_acts) - action = collections.OrderedDict() - for intent, domain, slot, value in dialog_acts: - k = 
'-'.join([domain.lower(), intent.lower()]) - action.setdefault(k, []) - action[k].append([slot.lower(), value]) - dialog_acts = action - print("dialog_acts4:", dialog_acts) - mode = self.mode - try: - is_user = self.is_user - if mode == 'manual': - if is_user: - template = self.manual_user_template - else: - template = self.manual_system_template - - return self._manual_generate(dialog_acts, template) - - elif mode == 'auto': - if is_user: - template = self.auto_user_template - else: - template = self.auto_system_template - - return self._auto_generate(dialog_acts, template) - - elif mode == 'auto_manual': - if is_user: - template1 = self.auto_user_template - template2 = self.manual_user_template - else: - template1 = self.auto_system_template - template2 = self.manual_system_template - - res = self._auto_generate(dialog_acts, template1) - if res == 'None': - res = self._manual_generate(dialog_acts, template2) - return res - - else: - raise Exception( - "Invalid mode! available mode: auto, manual, auto_manual") - except Exception as e: - print('Error in processing:') - pprint(dialog_acts) - raise e - - def _postprocess(self, sen): - sen_strip = sen.strip() - sen = ''.join([val.capitalize() if i == 0 else val for i, - val in enumerate(sen_strip)]) - if len(sen) > 0 and sen[-1] != '?' and sen[-1] != '.': - sen += '.' - sen += ' ' - return sen - - def _add_random_noise(self, sen): - if self.text_noise > 0.0: - end = sen[-3:] - sen = sen[:-3] - sen = random_token_permutation( - sen, probability=self.text_noise / 2) - sen = delete_random_token(sen, probability=self.text_noise) - sen += end - return sen - - def _add_random_noise(self, sen): - if self.text_noise > 0.0: - end = sen[-3:] - sen = sen[:-3] - sen = spelling_noise(sen, prob=self.text_noise / 2) - # sen = random_token_permutation(sen, probability=self.text_noise / 4) - sen = delete_random_token(sen, probability=self.text_noise / 2) - sen += end - return sen - - def _manual_generate(self, dialog_acts, template): - sentences = '' - for dialog_act, slot_value_pairs in dialog_acts.items(): - intent = dialog_act.split('-') - if 'select' == intent[1]: - slot2values = {} - for slot, value in slot_value_pairs: - slot2values.setdefault(slot, []) - slot2values[slot].append(value) - for slot, values in slot2values.items(): - if slot == 'none': - continue - sentence = 'Do you prefer ' + values[0] - for i, value in enumerate(values[1:]): - if i == (len(values) - 2): - sentence += ' or ' + value - else: - sentence += ' , ' + value - sentence += ' {} ? '.format(slot2word[slot]) - sentences += self._add_random_noise(sentence) - elif 'request' == intent[1]: - for slot, value in slot_value_pairs: - if dialog_act not in template or slot not in template[dialog_act]: - if dialog_act not in template: - print("dialog_act not in template:", dialog_act) - else: - print("slot not in template:", dialog_act, slot) - import pdb - pdb.set_trace() - sentence = 'What is the {} of {} ? 
'.format( - slot.lower(), dialog_act.split('-')[0].lower()) - sentences += self._add_random_noise(sentence) - else: - sentence = random.choice(template[dialog_act][slot]) - sentence = self._postprocess(sentence) - sentences += self._add_random_noise(sentence) - elif 'general' == intent[0] and dialog_act in template: - sentence = random.choice(template[dialog_act]['none']) - sentence = self._postprocess(sentence) - sentences += self._add_random_noise(sentence) - else: - for slot, value in slot_value_pairs: - if isinstance(value, str): - value_lower = value.lower() - if value in ["do nt care", "do n't care", "dontcare"]: - sentence = 'I don\'t care about the {} of the {}'.format( - slot, dialog_act.split('-')[0]) - elif self.is_user and dialog_act.split('-')[1] == 'inform' and slot == 'choice' and value_lower == 'any': - # user have no preference, any choice is ok - sentence = random.choice([ - "Please pick one for me. ", - "Anyone would be ok. ", - "Just select one for me. " - ]) - elif slot == 'price' and 'same price range' in value_lower: - sentence = random.choice([ - "it just needs to be {} .".format(value), - "Oh , I really need something {} .".format(value), - "I would prefer something that is {} .".format( - value), - "it needs to be {} .".format(value) - ]) - elif slot in ['internet', 'parking'] and value_lower == 'no': - sentence = random.choice([ - "It does n't need to have {} .".format(slot), - "I do n't need free {} .".format(slot), - ]) - elif dialog_act in template and slot in template[dialog_act]: - sentence = random.choice(template[dialog_act][slot]) - if 'not available' in value.lower(): - domain_ = dialog_act.split('-', 1)[0] - slot_ = slot2word.get(slot, None) - slot_ = REF_SYS_DA.get(domain_, {}).get( - slot.title(), None) if not slot_ else slot_ - if slot_: - sentence = f"Sorry, I do not know the {slot_.lower()} of the {domain_.lower()}." - else: - sentence = "Sorry, I do not have that information." - else: - sentence = sentence.replace( - '#{}-{}#'.format(dialog_act.upper(), slot.upper()), str(value)) - elif slot == 'notbook': - sentence = random.choice([ - "I do not need to book. ", - "I 'm not looking to make a booking at the moment." - ]) - else: - if slot in slot2word: - if 'not available' in value.lower(): - domain_ = dialog_act.split('-', 1)[0] - slot_ = slot2word[slot] - if slot_: - sentence = f"Sorry, I do not know the {slot_.lower()} of the {domain_.lower()}." - else: - sentence = "Sorry, I do not have that information." - else: - sentence = 'The {} is {} . 
'.format( - slot2word[slot], str(value)) - else: - sentence = '' - sentence = self._postprocess(sentence) - sentences += self._add_random_noise(sentence) - return sentences.strip() - - def _auto_generate(self, dialog_acts, template): - sentences = '' - for dialog_act, slot_value_pairs in dialog_acts.items(): - key = '' - for s, v in sorted(slot_value_pairs, key=lambda x: x[0]): - key += s + ';' - if dialog_act in template and key in template[dialog_act]: - sentence = random.choice(template[dialog_act][key]) - if 'request' in dialog_act or 'general' in dialog_act: - sentence = self._postprocess(sentence) - sentences += sentence - else: - for s, v in sorted(slot_value_pairs, key=lambda x: x[0]): - if v != 'none': - sentence = sentence.replace( - '#{}-{}#'.format(dialog_act.upper(), s.upper()), v, 1) - sentence = self._postprocess(sentence) - sentences += sentence - else: - return 'None' - return sentences.strip() - - -def example(): - # dialog act - dialog_acts = [['Inform', 'Hotel', 'Area', 'east'], [ - 'Inform', 'Hotel', 'Internet', 'no'], ['welcome', 'general', 'none', 'none']] - #dialog_acts = [['Inform', 'Restaurant', 'NotBook', 'none']] - print(dialog_acts) - - # system model for manual, auto, auto_manual - nlg_sys_manual = TemplateNLG(is_user=False, mode='manual') - nlg_sys_auto = TemplateNLG(is_user=False, mode='auto') - nlg_sys_auto_manual = TemplateNLG(is_user=False, mode='auto_manual') - - # generate - print('manual : ', nlg_sys_manual.generate(dialog_acts)) - print('auto : ', nlg_sys_auto.generate(dialog_acts)) - print('auto_manual : ', nlg_sys_auto_manual.generate(dialog_acts)) - - -if __name__ == '__main__': - example() diff --git a/convlab/nlg/template/multiwoz/nlg.py b/convlab/nlg/template/multiwoz/nlg.py index 5f362ebbabd68ee0b6fab62c692caf0e8da436e9..f83a6db4e2ad6f77f9bdc154cfda7bf2db0ff2c5 100755 --- a/convlab/nlg/template/multiwoz/nlg.py +++ b/convlab/nlg/template/multiwoz/nlg.py @@ -31,33 +31,33 @@ def read_json(filename): # supported slot Slot2word = { - 'Fee': 'entrance fee', + 'Fee': 'fee', 'Addr': 'address', 'Area': 'area', - 'Stars': 'number of stars', - 'Internet': 'internet', + 'Stars': 'stars', + 'Internet': 'Internet', 'Department': 'department', 'Choice': 'choice', 'Ref': 'reference number', 'Food': 'food', 'Type': 'type', 'Price': 'price range', - 'Stay': 'length of the stay', + 'Stay': 'stay', 'Phone': 'phone number', 'Post': 'postcode', 'Day': 'day', 'Name': 'name', 'Car': 'car type', - 'Leave': 'departure time', + 'Leave': 'leave', 'Time': 'time', - 'Arrive': 'arrival time', - 'Ticket': 'ticket price', + 'Arrive': 'arrive', + 'Ticket': 'ticket', 'Depart': 'departure', - 'People': 'number of people', + 'People': 'people', 'Dest': 'destination', 'Parking': 'parking', - 'Open': 'opening hours', - 'Id': 'id', + 'Open': 'open', + 'Id': 'Id', # 'TrainID': 'TrainID' } @@ -271,10 +271,6 @@ class TemplateNLG(NLG): elif 'request' == intent[1]: for slot, value in slot_value_pairs: if dialog_act not in template or slot not in template[dialog_act]: - if dialog_act not in template: - print("WARNING (nlg.py): (User?: %s) dialog_act '%s' not in template!" % (self.is_user, dialog_act)) - else: - print("WARNING (nlg.py): (User?: %s) slot '%s' of dialog_act '%s' not in template!" % (self.is_user, slot, dialog_act)) sentence = 'What is the {} of {} ? 
'.format( slot.lower(), dialog_act.split('-')[0].lower()) sentences += self._add_random_noise(sentence) @@ -292,7 +288,7 @@ class TemplateNLG(NLG): value_lower = value.lower() if value in ["do nt care", "do n't care", "dontcare"]: sentence = 'I don\'t care about the {} of the {}'.format( - slot2word.get(slot, slot), dialog_act.split('-')[0]) + slot, dialog_act.split('-')[0]) elif self.is_user and dialog_act.split('-')[1] == 'inform' and slot == 'choice' and value_lower == 'any': # user have no preference, any choice is ok sentence = random.choice([ diff --git a/convlab/nlu/jointBERT/dataloader.py b/convlab/nlu/jointBERT/dataloader.py index 2aa04ae3cb2735e55717037524543b9f1b7a8039..d1fcbc7a4864211a9956cedacc7c3479c195733a 100755 --- a/convlab/nlu/jointBERT/dataloader.py +++ b/convlab/nlu/jointBERT/dataloader.py @@ -21,7 +21,7 @@ class Dataloader: self.intent2id = dict([(x, i) for i, x in enumerate(intent_vocab)]) self.id2tag = dict([(i, x) for i, x in enumerate(tag_vocab)]) self.tag2id = dict([(x, i) for i, x in enumerate(tag_vocab)]) - self.tokenizer = BertTokenizer.from_pretrained(pretrained_weights, local_files_only=True) + self.tokenizer = BertTokenizer.from_pretrained(pretrained_weights) self.data = {} self.intent_weight = [1] * len(self.intent2id) diff --git a/convlab/nlu/jointBERT/jointBERT.py b/convlab/nlu/jointBERT/jointBERT.py index 9550c50b51707e29d8cadd0d67d125ee7716f92a..5f73c9aba6808ff92e15d7967bf6a2aa43419c75 100755 --- a/convlab/nlu/jointBERT/jointBERT.py +++ b/convlab/nlu/jointBERT/jointBERT.py @@ -12,7 +12,7 @@ class JointBERT(nn.Module): self.intent_weight = intent_weight if intent_weight is not None else torch.tensor([1.]*intent_dim) print(model_config['pretrained_weights']) - self.bert = BertModel.from_pretrained(model_config['pretrained_weights'], local_files_only=True) + self.bert = BertModel.from_pretrained(model_config['pretrained_weights']) self.dropout = nn.Dropout(model_config['dropout']) self.context = model_config['context'] self.finetune = model_config['finetune'] diff --git a/convlab/nlu/jointBERT/multiwoz/nlu.py b/convlab/nlu/jointBERT/multiwoz/nlu.py index 1373919e5861156c87a2ba14d6506d13e0842204..e25fbad1227c4b1f85ae2ae42a8ac899fa61d7b8 100755 --- a/convlab/nlu/jointBERT/multiwoz/nlu.py +++ b/convlab/nlu/jointBERT/multiwoz/nlu.py @@ -74,8 +74,7 @@ class BERTNLU(NLU): for token in token_list: token = token.strip() self.nlp.tokenizer.add_special_case( - #token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) - token, [{ORTH: token}]) + token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) logging.info("BERTNLU loaded") def predict(self, utterance, context=list()): diff --git a/convlab/policy/evaluate.py b/convlab/policy/evaluate.py index 4d89faff2a9dd133dffec76e8964d7e427dc27f2..7a692261869f35e587c34a26e425d1489abdcf56 100755 --- a/convlab/policy/evaluate.py +++ b/convlab/policy/evaluate.py @@ -14,7 +14,6 @@ from convlab.policy.rule.multiwoz import RulePolicy from convlab.task.multiwoz.goal_generator import GoalGenerator from convlab.util.custom_util import set_seed, get_config, env_config, create_goals, data_goals from tqdm import tqdm -from pprint import pprint def init_logging(log_dir_path, path_suffix=None): @@ -60,10 +59,7 @@ def evaluate(config_path, model_name, verbose=False, model_path="", goals_from_d policy_sys = GDPL(vectorizer=conf['vectorizer_sys_activated']) elif model_name == "DDPT": from convlab.policy.vtrace_DPT import VTRACE - policy_sys = VTRACE( - is_train=False, vectorizer=conf['vectorizer_sys_activated']) - else: - print("Unknown model 
name", model_name) + policy_sys = VTRACE(is_train=False, vectorizer=conf['vectorizer_sys_activated']) try: if model_path: @@ -83,7 +79,6 @@ def evaluate(config_path, model_name, verbose=False, model_path="", goals_from_d if goals_from_data: logging.info("read goals from dataset...") goals = data_goals(dialogues, dataset="multiwoz21", dial_ids_order=0) - else: logging.info("create goals from goal_generator...") goals = create_goals(goal_generator, num_goals=dialogues, @@ -100,21 +95,18 @@ def evaluate(config_path, model_name, verbose=False, model_path="", goals_from_d task_succ_strict = 0 complete = 0 - # if verbose: - # logging.info("NEW EPISODE!!!!" + "-" * 80) - # logging.info(f"\n Seed: {seed}") - # logging.info(f"GOAL: {sess.evaluator.goal}") - # logging.info("\n") - dialog = [] + if verbose: + logging.info("NEW EPISODE!!!!" + "-" * 80) + logging.info(f"\n Seed: {seed}") + logging.info(f"GOAL: {sess.evaluator.goal}") + logging.info("\n") for i in range(40): sys_response, user_response, session_over, reward = sess.next_turn( sys_response) if verbose: - dialog.append({"usr": user_response}) - dialog.append({"sys": sys_response}) - logging.info(f"usr {user_response}") - logging.info(f"sys {sys_response}") + logging.info(f"USER RESPONSE: {user_response}") + logging.info(f"SYS RESPONSE: {sys_response}") actions += len(sys_response) length = len(sys_response) @@ -133,19 +125,18 @@ def evaluate(config_path, model_name, verbose=False, model_path="", goals_from_d task_succ = sess.evaluator.task_success() task_succ = sess.evaluator.success task_succ_strict = sess.evaluator.success_strict - complete = sess.evaluator.complete - # TODO check the definision of complete rate - # complete = sess.user_agent.policy.policy.goal.task_complete() - - # if goals_from_data: - # complete = sess.user_agent.policy.policy.goal.task_complete() - # else: - # complete = sess.evaluator.complete + if goals_from_data: + complete = sess.user_agent.policy.policy.goal.task_complete() + else: + complete = sess.evaluator.complete break if verbose: logging.info(f"Complete: {complete}") logging.info(f"Success: {task_succ}") + logging.info(f"Success strict: {task_succ_strict}") + logging.info(f"Return: {total_return}") + logging.info(f"Average actions: {actions / turns}") task_success['Complete'].append(complete) task_success['Success'].append(task_succ) diff --git a/convlab/policy/evaluate_distributed.py b/convlab/policy/evaluate_distributed.py index d58b7dee7ae242292f303d6442fad74e40566e97..1f7b3ffe93c040e6e18aa0ccd88b8c21fbcd178c 100644 --- a/convlab/policy/evaluate_distributed.py +++ b/convlab/policy/evaluate_distributed.py @@ -73,9 +73,7 @@ def sampler(pid, queue, evt, sess, seed_range, goals): if session_over is True: success = sess.evaluator.task_success() - # TODO check the differenct between complete and success - # complete = sess.evaluator.complete - complete = sess.user_agent.policy.policy.goal.task_complete() + complete = sess.evaluator.complete success = sess.evaluator.success success_strict = sess.evaluator.success_strict break diff --git a/convlab/policy/genTUS/evaluate.py b/convlab/policy/genTUS/evaluate.py deleted file mode 100644 index 87de854970d2701900ba180d2bf15736071e0c1a..0000000000000000000000000000000000000000 --- a/convlab/policy/genTUS/evaluate.py +++ /dev/null @@ -1,257 +0,0 @@ -import json -import os -import sys -from argparse import ArgumentParser -from pprint import pprint - -import torch -from convlab.nlg.evaluate import fine_SER -from datasets import load_metric - -# from 
convlab.policy.genTUS.pg.stepGenTUSagent import \ -# stepGenTUSPG as UserPolicy -from convlab.policy.genTUS.stepGenTUS import UserActionPolicy -from tqdm import tqdm - -sys.path.append(os.path.dirname(os.path.dirname( - os.path.dirname(os.path.abspath(__file__))))) - - -def arg_parser(): - parser = ArgumentParser() - parser.add_argument("--model-checkpoint", type=str, help="the model path") - parser.add_argument("--model-weight", type=str, - help="the model weight", default="") - parser.add_argument("--input-file", type=str, help="the testing input file", - default="") - parser.add_argument("--generated-file", type=str, help="the generated results", - default="") - parser.add_argument("--only-action", action="store_true") - parser.add_argument("--dataset", default="multiwoz") - parser.add_argument("--do-semantic", action="store_true", - help="do semantic evaluation") - parser.add_argument("--do-nlg", action="store_true", - help="do nlg generation") - parser.add_argument("--do-golden-nlg", action="store_true", - help="do golden nlg generation") - return parser.parse_args() - - -class Evaluator: - def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False): - self.dataset = dataset - self.model_checkpoint = model_checkpoint - self.model_weight = model_weight - # if model_weight: - # self.usr_policy = UserPolicy( - # self.model_checkpoint, only_action=only_action) - # self.usr_policy.load(model_weight) - # self.usr = self.usr_policy.usr - # else: - self.usr = UserActionPolicy( - model_checkpoint, only_action=only_action, dataset=self.dataset) - self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin")) - - def generate_results(self, f_eval, golden=False): - in_file = json.load(open(f_eval)) - r = { - "input": [], - "golden_acts": [], - "golden_utts": [], - "gen_acts": [], - "gen_utts": [] - } - for dialog in tqdm(in_file['dialog']): - inputs = dialog["in"] - labels = self.usr._parse_output(dialog["out"]) - if golden: - usr_act = labels["action"] - usr_utt = self.usr.generate_text_from_give_semantic( - inputs, usr_act) - - else: - output = self.usr._parse_output( - self.usr._generate_action(inputs)) - usr_act = self.usr._remove_illegal_action(output["action"]) - usr_utt = output["text"] - r["input"].append(inputs) - r["golden_acts"].append(labels["action"]) - r["golden_utts"].append(labels["text"]) - r["gen_acts"].append(usr_act) - r["gen_utts"].append(usr_utt) - - return r - - def read_generated_result(self, f_eval): - in_file = json.load(open(f_eval)) - r = { - "input": [], - "golden_acts": [], - "golden_utts": [], - "gen_acts": [], - "gen_utts": [] - } - for dialog in tqdm(in_file['dialog']): - for x in dialog: - r[x].append(dialog[x]) - - return r - - def nlg_evaluation(self, input_file=None, generated_file=None, golden=False): - if input_file: - print("Force generation") - gen_r = self.generate_results(input_file, golden) - - elif generated_file: - gen_r = self.read_generated_result(generated_file) - else: - print("You must specify the input_file or the generated_file") - - nlg_eval = { - "golden": golden, - "metrics": {}, - "dialog": [] - } - for input, golden_act, golden_utt, gen_act, gen_utt in zip(gen_r["input"], gen_r["golden_acts"], gen_r["golden_utts"], gen_r["gen_acts"], gen_r["gen_utts"]): - nlg_eval["dialog"].append({ - "input": input, - "golden_acts": golden_act, - "golden_utts": golden_utt, - "gen_acts": gen_act, - "gen_utts": gen_utt - }) - - if golden: - print("Calculate BLEU") - bleu_metric = load_metric("sacrebleu") - labels = [[utt] for utt 
in gen_r["golden_utts"]] - - bleu_score = bleu_metric.compute(predictions=gen_r["gen_utts"], - references=labels, - force=True) - print("bleu_metric", bleu_score) - nlg_eval["metrics"]["bleu"] = bleu_score - - else: - print("Calculate SER") - missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER( - gen_r["gen_acts"], gen_r["gen_utts"]) - - print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format( - "genTUSNLG", missing, total, hallucinate, missing/total)) - nlg_eval["metrics"]["SER"] = missing/total - - dir_name = self.model_checkpoint - json.dump(nlg_eval, - open(os.path.join(dir_name, "nlg_eval.json"), 'w'), - indent=2) - return os.path.join(dir_name, "nlg_eval.json") - - def evaluation(self, input_file=None, generated_file=None): - force_prediction = True - if generated_file: - gen_file = json.load(open(generated_file)) - force_prediction = False - if gen_file["golden"]: - force_prediction = True - - if force_prediction: - in_file = json.load(open(input_file)) - dialog_result = [] - gen_acts, golden_acts = [], [] - # scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []} - for dialog in tqdm(in_file['dialog']): - inputs = dialog["in"] - labels = self.usr._parse_output(dialog["out"]) - ans_action = self.usr._remove_illegal_action(labels["action"]) - preds = self.usr._generate_action(inputs) - preds = self.usr._parse_output(preds) - usr_action = self.usr._remove_illegal_action(preds["action"]) - - gen_acts.append(usr_action) - golden_acts.append(ans_action) - - d = {"input": inputs, - "golden_acts": ans_action, - "gen_acts": usr_action} - if "text" in preds: - d["golden_utts"] = labels["text"] - d["gen_utts"] = preds["text"] - # print("pred text", preds["text"]) - - dialog_result.append(d) - else: - gen_acts, golden_acts = [], [] - for dialog in gen_file['dialog']: - gen_acts.append(dialog["gen_acts"]) - golden_acts.append(dialog["golden_acts"]) - dialog_result = gen_file['dialog'] - - scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []} - - for gen_act, golden_act in zip(gen_acts, golden_acts): - s = f1_measure(preds=gen_act, labels=golden_act) - for metric in scores: - scores[metric].append(s[metric]) - - result = {} - for metric in scores: - result[metric] = sum(scores[metric])/len(scores[metric]) - print(f"{metric}: {result[metric]}") - - result["dialog"] = dialog_result - basename = "semantic_evaluation_result" - json.dump(result, open(os.path.join( - self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w')) - # if self.model_weight: - # json.dump(result, open(os.path.join( - # 'results', f"{basename}.json"), 'w')) - # else: - # json.dump(result, open(os.path.join( - # self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w')) - - -def f1_measure(preds, labels): - tp = 0 - score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0} - for p in preds: - if p in labels: - tp += 1.0 - if preds: - score["precision"] = tp/len(preds) - if labels: - score["recall"] = tp/len(labels) - if (score["precision"] + score["recall"]) > 0: - score["f1"] = 2*(score["precision"]*score["recall"]) / \ - (score["precision"]+score["recall"]) - if tp == len(preds) and tp == len(labels): - score["turn_acc"] = 1 - return score - - -def main(): - args = arg_parser() - eval = Evaluator(args.model_checkpoint, - args.dataset, - args.model_weight, - args.only_action) - print("model checkpoint", args.model_checkpoint) - print("generated_file", args.generated_file) - print("input_file", args.input_file) - with torch.no_grad(): - if 
args.do_semantic: - eval.evaluation(args.input_file) - if args.do_nlg: - nlg_result = eval.nlg_evaluation(input_file=args.input_file, - generated_file=args.generated_file, - golden=args.do_golden_nlg) - if args.generated_file: - generated_file = args.generated_file - else: - generated_file = nlg_result - eval.evaluation(args.input_file, - generated_file) - - -if __name__ == '__main__': - main() diff --git a/convlab/policy/genTUS/ppo/vector.py b/convlab/policy/genTUS/ppo/vector.py deleted file mode 100644 index 4c502a46f87582008ff49219f8a14844378b9ed2..0000000000000000000000000000000000000000 --- a/convlab/policy/genTUS/ppo/vector.py +++ /dev/null @@ -1,148 +0,0 @@ -import json - -import torch -from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph -from convlab.policy.genTUS.token_map import tokenMap -from convlab.policy.tus.unify.Goal import Goal -from transformers import BartTokenizer - - -class stepGenTUSVector: - def __init__(self, model_checkpoint, max_in_len=400, max_out_len=80, allow_general_intent=True): - self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint) - self.vocab = len(self.tokenizer) - self.max_in_len = max_in_len - self.max_out_len = max_out_len - self.token_map = tokenMap(tokenizer=self.tokenizer) - self.token_map.default(only_action=True) - self.kg = KnowledgeGraph(self.tokenizer) - self.mentioned_domain = [] - self.allow_general_intent = allow_general_intent - self.candidate_num = 5 - if self.allow_general_intent: - print("---> allow_general_intent") - - def init_session(self, goal: Goal): - self.goal = goal - self.mentioned_domain = [] - - def encode(self, raw_inputs, max_length, return_tensors="pt", truncation=True): - model_input = self.tokenizer(raw_inputs, - max_length=max_length, - return_tensors=return_tensors, - truncation=truncation, - padding="max_length") - return model_input - - def decode(self, generated_so_far, skip_special_tokens=True): - output = self.tokenizer.decode( - generated_so_far, skip_special_tokens=skip_special_tokens) - return output - - def state_vectorize(self, action, history, turn): - self.goal.update_user_goal(action=action) - inputs = json.dumps({"system": action, - "goal": self.goal.get_goal_list(), - "history": history, - "turn": str(turn)}) - inputs = self.encode(inputs, self.max_in_len) - s_vec, action_mask = inputs["input_ids"][0], inputs["attention_mask"][0] - - return s_vec, action_mask - - def action_vectorize(self, action, s=None): - # action: [[intent, domain, slot, value], ...] 
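# Worked illustration (action values hypothetical):
#   action = [["inform", "hotel", "area", "east"],
#             ["request", "hotel", "phone", "?"]]
# '?' values are rewritten to the '<?>' token; each quadruple is emitted as
# intent <sep_token> domain <sep_token> slot <sep_token> value, acts are
# joined with sep_act, the whole sequence is wrapped in
# <s> start_json start_act ... end_act </s>, and finally right-padded with
# token id 1 up to max_out_len.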
- vec = {"vector": torch.tensor([]), "mask": torch.tensor([])} - if s is not None: - raw_inputs = self.decode(s[0]) - self.kg.parse_input(raw_inputs) - - self._append(vec, self._get_id("<s>")) - self._append(vec, self.token_map.get_id('start_json')) - self._append(vec, self.token_map.get_id('start_act')) - - act_len = len(action) - for i, (intent, domain, slot, value) in enumerate(action): - if value == '?': - value = '<?>' - c_idx = {x: None for x in ["intent", "domain", "slot", "value"]} - - if s is not None: - c_idx["intent"] = self._candidate_id(self.kg.candidate( - "intent", allow_general_intent=self.allow_general_intent)) - c_idx["domain"] = self._candidate_id(self.kg.candidate( - "domain", intent=intent)) - c_idx["slot"] = self._candidate_id(self.kg.candidate( - "slot", intent=intent, domain=domain, is_mentioned=self.is_mentioned(domain))) - c_idx["value"] = self._candidate_id(self.kg.candidate( - "value", intent=intent, domain=domain, slot=slot)) - - self._append(vec, self._get_id(intent), c_idx["intent"]) - self._append(vec, self.token_map.get_id('sep_token')) - self._append(vec, self._get_id(domain), c_idx["domain"]) - self._append(vec, self.token_map.get_id('sep_token')) - self._append(vec, self._get_id(slot), c_idx["slot"]) - self._append(vec, self.token_map.get_id('sep_token')) - self._append(vec, self._get_id(value), c_idx["value"]) - - c_idx = [0]*self.candidate_num - c_idx[0] = self.token_map.get_id('end_act')[0] - c_idx[1] = self.token_map.get_id('sep_act')[0] - if i == act_len - 1: - x = self.token_map.get_id('end_act') - else: - x = self.token_map.get_id('sep_act') - - self._append(vec, x, c_idx) - - self._append(vec, self._get_id("</s>")) - - # pad - if len(vec["vector"]) < self.max_out_len: - pad_len = self.max_out_len-len(vec["vector"]) - self._append(vec, x=torch.tensor([1]*pad_len)) - for vec_type in vec: - vec[vec_type] = vec[vec_type].to(torch.int64) - - return vec - - def _append(self, vec, x, candidate=None): - if type(x) is list: - x = torch.tensor(x) - mask = self._mask(x, candidate) - vec["vector"] = torch.cat((vec["vector"], x), dim=-1) - vec["mask"] = torch.cat((vec["mask"], mask), dim=0) - - def _mask(self, idx, c_idx=None): - mask = torch.zeros(len(idx), self.candidate_num) - mask[:, 0] = idx - if c_idx is not None and len(c_idx) > 1: - mask[0, :] = torch.tensor(c_idx) - - return mask - - def _candidate_id(self, candidate): - if len(candidate) > self.candidate_num: - print(f"too many candidates. 
Max = {self.candidate_num}") - c_idx = [0]*self.candidate_num - for i, idx in enumerate([self._get_id(c)[0] for c in candidate[:self.candidate_num]]): - c_idx[i] = idx - return c_idx - - def _get_id(self, value): - token_id = self.tokenizer(value, add_special_tokens=False) - return token_id["input_ids"] - - def action_devectorize(self, action_id): - return self.decode(action_id) - - def update_mentioned_domain(self, semantic_act): - for act in semantic_act: - domain = act[1] - if domain not in self.mentioned_domain: - self.mentioned_domain.append(domain) - - def is_mentioned(self, domain): - if domain in self.mentioned_domain: - return True - return False diff --git a/convlab/policy/genTUS/stepGenTUS.py b/convlab/policy/genTUS/stepGenTUS.py deleted file mode 100644 index 902a54068446741b39fc93f264d9a4b06245afe0..0000000000000000000000000000000000000000 --- a/convlab/policy/genTUS/stepGenTUS.py +++ /dev/null @@ -1,656 +0,0 @@ -import json -import os - -import torch -from transformers import BartTokenizer - -from convlab.policy.genTUS.ppo.vector import stepGenTUSVector -from convlab.policy.genTUS.stepGenTUSmodel import stepGenTUSmodel -from convlab.policy.genTUS.token_map import tokenMap -from convlab.policy.genTUS.unify.Goal import Goal -from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph -from convlab.policy.policy import Policy -from convlab.task.multiwoz.goal_generator import GoalGenerator - -DEBUG = False - - -class UserActionPolicy(Policy): - def __init__(self, model_checkpoint, mode="semantic", only_action=True, max_turn=40, **kwargs): - self.mode = mode - # if mode == "semantic" and only_action: - # # only generate semantic action in prediction - print("model_checkpoint", model_checkpoint) - self.only_action = only_action - if self.only_action: - print("change mode to semantic because only_action=True") - self.mode = "semantic" - self.max_in_len = 500 - self.max_out_len = 50 if only_action else 200 - max_act_len = kwargs.get("max_act_len", 2) - print("max_act_len", max_act_len) - self.max_action_len = max_act_len - if "max_act_len" in kwargs: - self.max_out_len = 30 * self.max_action_len - print("max_act_len", self.max_out_len) - self.max_turn = max_turn - if mode not in ["semantic", "language"]: - print("Unknown user mode") - - self.reward = {"success": self.max_turn*2, - "fail": self.max_turn*-1} - self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint) - self.device = "cuda" if torch.cuda.is_available() else "cpu" - train_whole_model = kwargs.get("whole_model", True) - self.model = stepGenTUSmodel( - model_checkpoint, train_whole_model=train_whole_model) - self.model.eval() - self.model.to(self.device) - self.model.share_memory() - - self.turn_level_reward = kwargs.get("turn_level_reward", True) - self.cooperative = kwargs.get("cooperative", True) - - dataset = kwargs.get("dataset", "") - self.kg = KnowledgeGraph( - tokenizer=self.tokenizer, - dataset=dataset) - - self.goal_gen = GoalGenerator() - - self.vector = stepGenTUSVector( - model_checkpoint, self.max_in_len, self.max_out_len) - self.norm_reward = False - - self.action_penalty = kwargs.get("action_penalty", False) - self.usr_act_penalize = kwargs.get("usr_act_penalize", 0) - self.goal_list_type = kwargs.get("goal_list_type", "normal") - self.update_mode = kwargs.get("update_mode", "normal") - self.max_history = kwargs.get("max_history", 3) - self.init_session() - - def _update_seq(self, sub_seq: list, pos: int): - for x in sub_seq: - self.seq[0, pos] = x - pos += 1 - - return pos - - def 
_generate_action(self, raw_inputs, mode="max", allow_general_intent=True): - # TODO no duplicate - self.kg.parse_input(raw_inputs) - model_input = self.vector.encode(raw_inputs, self.max_in_len) - # start token - self.seq = torch.zeros(1, self.max_out_len, device=self.device).long() - pos = self._update_seq([0], 0) - pos = self._update_seq(self.token_map.get_id('start_json'), pos) - pos = self._update_seq(self.token_map.get_id('start_act'), pos) - - # get semantic actions - for act_len in range(self.max_action_len): - pos = self._get_semantic_action( - model_input, pos, mode, allow_general_intent) - - terminate, token_name = self._stop_semantic( - model_input, pos, act_len) - pos = self._update_seq(self.token_map.get_id(token_name), pos) - - if terminate: - break - - if self.only_action: - # return semantic action. Don't need to generate text - return self.vector.decode(self.seq[0, :pos]) - - # TODO remove illegal action here? - - # get text output - pos = self._update_seq(self.token_map.get_id("start_text"), pos) - - text = self._get_text(model_input, pos) - - return text - - def generate_text_from_give_semantic(self, raw_inputs, semantic_action): - self.kg.parse_input(raw_inputs) - model_input = self.vector.encode(raw_inputs, self.max_in_len) - self.seq = torch.zeros(1, self.max_out_len, device=self.device).long() - pos = self._update_seq([0], 0) - pos = self._update_seq(self.token_map.get_id('start_json'), pos) - pos = self._update_seq(self.token_map.get_id('start_act'), pos) - - if len(semantic_action) == 0: - pos = self._update_seq(self.token_map.get_id("end_act"), pos) - - for act_id, (intent, domain, slot, value) in enumerate(semantic_action): - pos = self._update_seq(self.kg._get_token_id(intent), pos) - pos = self._update_seq(self.token_map.get_id('sep_token'), pos) - pos = self._update_seq(self.kg._get_token_id(domain), pos) - pos = self._update_seq(self.token_map.get_id('sep_token'), pos) - pos = self._update_seq(self.kg._get_token_id(slot), pos) - pos = self._update_seq(self.token_map.get_id('sep_token'), pos) - pos = self._update_seq(self.kg._get_token_id(value), pos) - - if act_id == len(semantic_action) - 1: - token_name = "end_act" - else: - token_name = "sep_act" - pos = self._update_seq(self.token_map.get_id(token_name), pos) - pos = self._update_seq(self.token_map.get_id("start_text"), pos) - - raw_output = self._get_text(model_input, pos) - return self._parse_output(raw_output)["text"] - - def _get_text(self, model_input, pos): - s_pos = pos - for i in range(s_pos, self.max_out_len): - next_token_logits = self.model.get_next_token_logits( - model_input, self.seq[:1, :pos]) - next_token = torch.argmax(next_token_logits, dim=-1) - - if self._stop_text(next_token): - # text = self.vector.decode(self.seq[0, s_pos:pos]) - # text = self._norm_str(text) - # return self.vector.decode(self.seq[0, :s_pos]) + text + '"}' - break - - pos = self._update_seq([next_token], pos) - text = self.vector.decode(self.seq[0, s_pos:pos]) - text = self._norm_str(text) - return self.vector.decode(self.seq[0, :s_pos]) + text + '"}' - # TODO return None - - def _stop_text(self, next_token): - if next_token == self.token_map.get_id("end_json")[0]: - return True - elif next_token == self.token_map.get_id("end_json_2")[0]: - return True - - return False - - @staticmethod - def _norm_str(text: str): - text = text.strip('"') - text = text.replace('"', "'") - text = text.replace('\\', "") - return text - - def _stop_semantic(self, model_input, pos, act_length=0): - - outputs = 
self.model.get_next_token_logits( - model_input, self.seq[:1, :pos]) - tokens = {} - for token_name in ['sep_act', 'end_act']: - tokens[token_name] = { - "token_id": self.token_map.get_id(token_name)} - hash_id = tokens[token_name]["token_id"][0] - tokens[token_name]["score"] = outputs[:, hash_id].item() - - if tokens['end_act']["score"] > tokens['sep_act']["score"]: - terminate = True - else: - terminate = False - - if act_length >= self.max_action_len - 1: - terminate = True - - token_name = "end_act" if terminate else "sep_act" - - return terminate, token_name - - def _get_semantic_action(self, model_input, pos, mode="max", allow_general_intent=True): - - intent = self._get_intent( - model_input, self.seq[:1, :pos], mode, allow_general_intent) - pos = self._update_seq(intent["token_id"], pos) - pos = self._update_seq(self.token_map.get_id('sep_token'), pos) - - # get domain - domain = self._get_domain( - model_input, self.seq[:1, :pos], intent["token_name"], mode) - pos = self._update_seq(domain["token_id"], pos) - pos = self._update_seq(self.token_map.get_id('sep_token'), pos) - - # get slot - slot = self._get_slot( - model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode) - pos = self._update_seq(slot["token_id"], pos) - pos = self._update_seq(self.token_map.get_id('sep_token'), pos) - - # get value - - value = self._get_value( - model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], slot["token_name"], mode) - pos = self._update_seq(value["token_id"], pos) - - return pos - - def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True): - next_token_logits = self.model.get_next_token_logits( - model_input, generated_so_far) - - return self.kg.get_intent(next_token_logits, mode, allow_general_intent) - - def _get_domain(self, model_input, generated_so_far, intent, mode="max"): - next_token_logits = self.model.get_next_token_logits( - model_input, generated_so_far) - - return self.kg.get_domain(next_token_logits, intent, mode) - - def _get_slot(self, model_input, generated_so_far, intent, domain, mode="max"): - next_token_logits = self.model.get_next_token_logits( - model_input, generated_so_far) - is_mentioned = self.vector.is_mentioned(domain) - return self.kg.get_slot(next_token_logits, intent, domain, mode, is_mentioned) - - def _get_value(self, model_input, generated_so_far, intent, domain, slot, mode="max"): - next_token_logits = self.model.get_next_token_logits( - model_input, generated_so_far) - - return self.kg.get_value(next_token_logits, intent, domain, slot, mode) - - def _remove_illegal_action(self, action): - # Transform illegal action to legal action - new_action = [] - for act in action: - if len(act) == 4: - if "<?>" in act[-1]: - act = [act[0], act[1], act[2], "?"] - if act not in new_action: - new_action.append(act) - else: - print("illegal action:", action) - return new_action - - def _parse_output(self, in_str): - in_str = str(in_str) - in_str = in_str.replace('<s>', '').replace( - '<\\s>', '').replace('o"clock', "o'clock") - action = {"action": [], "text": ""} - try: - action = json.loads(in_str) - except: - print("invalid action:", in_str) - print("-"*20) - return action - - def predict(self, sys_act, mode="max", allow_general_intent=True): - # raw_sys_act = sys_act - # sys_act = sys_act[:5] - # update goal - # TODO - allow_general_intent = False - self.model.eval() - - if not self.add_sys_from_reward: - self.goal.update_user_goal(action=sys_act, char="sys") - 
self.sys_acts.append(sys_act) # for terminate conversation - - # update constraint - self.time_step += 2 - - history = [] - if self.usr_acts: - if self.max_history == 1: - history = self.usr_acts[-1] - else: - history = self.usr_acts[-1*self.max_history:] - inputs = json.dumps({"system": sys_act, - "goal": self.goal.get_goal_list(), - "history": history, - "turn": str(int(self.time_step/2))}) - with torch.no_grad(): - raw_output = self._generate_action( - raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent) - output = self._parse_output(raw_output) - self.semantic_action = self._remove_illegal_action(output["action"]) - if not self.only_action: - self.utterance = output["text"] - - # TODO - if self.is_finish(): - self.semantic_action, self.utterance = self._good_bye() - - # if self.is_finish(): - # print("terminated") - - # if self.is_finish(): - # good_bye = self._good_bye() - # self.goal.add_usr_da(good_bye) - # return good_bye - - self.goal.update_user_goal(action=self.semantic_action, char="usr") - self.vector.update_mentioned_domain(self.semantic_action) - self.usr_acts.append(self.semantic_action) - - # if self._usr_terminate(usr_action): - # print("terminated by user") - # self.terminated = True - - del inputs - - if self.mode == "language": - # print("in", sys_act) - # print("out", self.utterance) - return self.utterance - else: - return self.semantic_action - - def init_session(self, goal=None): - self.token_map = tokenMap(tokenizer=self.tokenizer) - self.token_map.default(only_action=self.only_action) - self.time_step = 0 - remove_domain = "police" # remove police domain in inference - - if not goal: - self._new_goal(remove_domain=remove_domain) - else: - self._read_goal(goal) - - self.vector.init_session(goal=self.goal) - print("="*20) - print("goal for GenTUS", self.goal) - - self.terminated = False - self.add_sys_from_reward = False - self.sys_acts = [] - self.usr_acts = [] - self.semantic_action = [] - self.utterance = "" - - def _read_goal(self, data_goal): - self.goal = Goal(goal=data_goal) - - def _new_goal(self, remove_domain="police", domain_len=None): - goal = self.goal_gen.get_user_goal() - self.goal = Goal(goal) - # keep_generate_goal = True - # # domain_len = 1 - # while keep_generate_goal: - # self.goal = Goal(goal_generator=self.goal_gen, - # goal_list_type=self.goal_list_type, - # update_mode=self.update_mode) - # if (domain_len and len(self.goal.domains) != domain_len) or \ - # (remove_domain and remove_domain in self.goal.domains): - # keep_generate_goal = True - # else: - # keep_generate_goal = False - - def load(self, model_path): - self.model.load_state_dict(torch.load( - model_path, map_location=self.device)) - # self.model = BartForConditionalGeneration.from_pretrained( - # model_checkpoint) - - def get_goal(self): - if self.goal.raw_goal is not None: - return self.goal.raw_goal - goal = {} - for domain in self.goal.domain_goals: - if domain not in goal: - goal[domain] = {} - for intent in self.goal.domain_goals[domain]: - if intent == "inform": - slot_type = "info" - elif intent == "request": - slot_type = "reqt" - elif intent == "book": - slot_type = "book" - else: - print("unknown slot type") - if slot_type not in goal[domain]: - goal[domain][slot_type] = {} - for slot, value in self.goal.domain_goals[domain][intent].items(): - goal[domain][slot_type][slot] = value - return goal - - def get_reward(self, sys_response=None): - self.add_sys_from_reward = False if sys_response is None else True - - if self.add_sys_from_reward: - 
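- # A system response passed in here is first credited to the goal before
- # scoring. Reward sketch (assuming the success/fail payouts are 40 and -20,
- # as elsewhere in this patch): terminal success or failure pays out those
- # amounts, otherwise each turn costs -1 plus optional turn-level shaping;
- # with norm_reward the result is rescaled below as (r - 20) / 60, e.g.
- # 40 -> 1/3 and -20 -> -2/3.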
self.goal.update_user_goal(action=sys_response, char="sys") - self.goal.add_sys_da(sys_response) # for evaluation - self.sys_acts.append(sys_response) # for terminate conversation - - if self.is_finish(): - if self.is_success(): - reward = self.reward["success"] - self.success = True - else: - reward = self.reward["fail"] - self.success = False - - else: - reward = -1 - if self.turn_level_reward: - reward += self.turn_reward() - - self.success = None - # if self.action_penalty: - # reward += self._system_action_penalty() - - if self.norm_reward: - reward = (reward - 20)/60 - return reward - - def _system_action_penalty(self): - free_action_len = 3 - if len(self.sys_acts) < 1: - return 0 - # TODO only penalize the slots not in user goal - # else: - # penlaty = 0 - # for i in range(len(self.sys_acts[-1])): - # penlaty += -1*i - # return penlaty - if len(self.sys_acts[-1]) > 3: - return -1*(len(self.sys_acts[-1])-free_action_len) - return 0 - - def turn_reward(self): - r = 0 - r += self._new_act_reward() - r += self._reply_reward() - r += self._usr_act_len() - return r - - def _usr_act_len(self): - last_act = self.usr_acts[-1] - penalty = 0 - if len(last_act) > 2: - penalty = (2-len(last_act))*self.usr_act_penalize - return penalty - - def _new_act_reward(self): - last_act = self.usr_acts[-1] - if last_act != self.semantic_action: - print(f"---> why? last {last_act} usr {self.semantic_action}") - new_act = [] - for act in last_act: - if len(self.usr_acts) < 2: - break - if act[1].lower() == "general": - new_act.append(0) - elif act in self.usr_acts[-2]: - new_act.append(-1) - elif act not in self.usr_acts[-2]: - new_act.append(1) - - return sum(new_act) - - def _reply_reward(self): - if self.cooperative: - return self._cooperative_reply_reward() - else: - return self._non_cooperative_reply_reward() - - def _non_cooperative_reply_reward(self): - r = [] - reqts = [] - infos = [] - reply_len = 0 - max_len = 1 - for act in self.sys_acts[-1]: - if act[0] == "request": - reqts.append([act[1], act[2]]) - for act in self.usr_acts[-1]: - if act[0] == "inform": - infos.append([act[1], act[2]]) - for req in reqts: - if req in infos: - if reply_len < max_len: - r.append(1) - elif reply_len == max_len: - r.append(0) - else: - r.append(-5) - - if r: - return sum(r) - return 0 - - def _cooperative_reply_reward(self): - r = [] - reqts = [] - infos = [] - for act in self.sys_acts[-1]: - if act[0] == "request": - reqts.append([act[1], act[2]]) - for act in self.usr_acts[-1]: - if act[0] == "inform": - infos.append([act[1], act[2]]) - for req in reqts: - if req in infos: - r.append(1) - else: - r.append(-1) - if r: - return sum(r) - return 0 - - def _usr_terminate(self): - for act in self.semantic_action: - if act[0] in ['thank', 'bye']: - return True - return False - - def is_finish(self): - # stop by model generation? - if self._finish_conversation_rule(): - self.terminated = True - return True - elif self._usr_terminate(): - self.terminated = True - return True - self.terminated = False - return False - - def is_success(self): - task_complete = self.goal.task_complete() - # goal_status = self.goal.all_mentioned() - # should mentioned all slots - if task_complete: # and goal_status["complete"] > 0.6: - return True - return False - - def _good_bye(self): - if self.is_success(): - return [['thank', 'general', 'none', 'none']], "thank you. 
bye" - # if self.mode == "semantic": - # return [['thank', 'general', 'none', 'none']] - # else: - # return "bye" - else: - return [["bye", "general", "None", "None"]], "bye" - if self.mode == "semantic": - return [["bye", "general", "None", "None"]] - return "bye" - - def _finish_conversation_rule(self): - if self.is_success(): - return True - - if self.time_step > self.max_turn: - return True - - if (len(self.sys_acts) > 4) and (self.sys_acts[-1] == self.sys_acts[-2]) and (self.sys_acts[-2] == self.sys_acts[-3]): - return True - return False - - def is_terminated(self): - # Is there any action to say? - self.is_finish() - return self.terminated - - -class UserPolicy(Policy): - def __init__(self, - model_checkpoint, - mode="semantic", - only_action=True, - sample=False, - action_penalty=False, - **kwargs): - # self.config = config - # if not os.path.exists(self.config["model_dir"]): - # os.mkdir(self.config["model_dir"]) - # model_downloader(self.config["model_dir"], - # "https://zenodo.org/record/5779832/files/default.zip") - - self.policy = UserActionPolicy( - model_checkpoint, - mode=mode, - only_action=only_action, - action_penalty=action_penalty, - **kwargs) - self.policy.load(os.path.join( - model_checkpoint, "pytorch_model.bin")) - self.sample = sample - - def predict(self, sys_act, mode="max"): - if self.sample: - mode = "sample" - else: - mode = "max" - response = self.policy.predict(sys_act, mode) - return response - - def init_session(self, goal=None): - self.policy.init_session(goal) - - def is_terminated(self): - return self.policy.is_terminated() - - def get_reward(self, sys_response=None): - return self.policy.get_reward(sys_response) - - def get_goal(self): - if hasattr(self.policy, 'get_goal'): - return self.policy.get_goal() - return None - - -if __name__ == "__main__": - import os - - from convlab.dialog_agent import PipelineAgent - # from convlab.nlu.jointBERT.multiwoz import BERTNLU - from convlab.util.custom_util import set_seed - - set_seed(20220220) - # Test semantic level behaviour - model_checkpoint = 'convlab/policy/genTUS/unify/experiments/multiwoz21_0_1.0' - usr_policy = UserPolicy( - model_checkpoint, - mode="semantic") - usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin")) - usr_nlu = None # BERTNLU() - usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user') - print(usr.policy.get_goal()) - - print(usr.response([])) - print(usr.policy.policy.goal.status) - print(usr.response([["request", "attraction", "area", "?"]])) - print(usr.policy.policy.goal.status) - print(usr.response([["request", "attraction", "area", "?"]])) - print(usr.policy.policy.goal.status) diff --git a/convlab/policy/genTUS/stepGenTUSmodel.py b/convlab/policy/genTUS/stepGenTUSmodel.py deleted file mode 100644 index e2eaf7bc5064262808120aac4a9cbe2eb007d863..0000000000000000000000000000000000000000 --- a/convlab/policy/genTUS/stepGenTUSmodel.py +++ /dev/null @@ -1,114 +0,0 @@ - -import json - -import torch -from torch.nn.functional import softmax, one_hot, cross_entropy - -from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph -from convlab.policy.genTUS.token_map import tokenMap -from convlab.policy.genTUS.utils import append_tokens -from transformers import (BartConfig, BartForConditionalGeneration, - BartTokenizer) - - -class stepGenTUSmodel(BartForConditionalGeneration): - def __init__(self, model_checkpoint, train_whole_model=True, **kwargs): - config = BartConfig.from_pretrained(model_checkpoint) - super().__init__(config, **kwargs) - - 
self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint) - self.vocab = len(self.tokenizer) - self.kg = KnowledgeGraph(self.tokenizer) - self.action_kg = KnowledgeGraph(self.tokenizer) - self.token_map = tokenMap(self.tokenizer) - # only_action doesn't matter. it is only used for get_log_prob - self.token_map.default(only_action=True) - - if not train_whole_model: - for param in self.parameters(): - param.requires_grad = False - - for param in self.model.decoder.layers[-1].fc1.parameters(): - param.requires_grad = True - for param in self.model.decoder.layers[-1].fc2.parameters(): - param.requires_grad = True - - def get_trainable_param(self): - - return filter( - lambda p: p.requires_grad, self.parameters()) - - def get_next_token_logits(self, model_input, generated_so_far): - input_ids = model_input["input_ids"].to(self.device) - attention_mask = model_input["attention_mask"].to(self.device) - outputs = self.forward( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=generated_so_far, - return_dict=True) - return outputs.logits[:, -1, :] - - def get_log_prob(self, s, a, action_mask, prob_mask): - output = self.forward(input_ids=s, - attention_mask=action_mask, - decoder_input_ids=a) - prob = self._norm_prob(a[:, 1:].long(), - output.logits[:, :-1, :], - prob_mask[:, 1:, :].long()) - return prob - - def _norm_prob(self, a, prob, mask): - prob = softmax(prob, -1) - base = self._base(prob, mask).to(self.device) # [b, seq_len] - prob = (prob*one_hot(a, num_classes=self.vocab)).sum(-1) - prob = torch.log(prob / base) - pad_mask = a != 1 - prob = prob*pad_mask.float() - return prob.sum(-1) - - @staticmethod - def _base(prob, mask): - batch_size, seq_len, dim = prob.shape - base = torch.zeros(batch_size, seq_len) - for b in range(batch_size): - for s in range(seq_len): - temp = [prob[b, s, c] for c in mask[b, s, :] if c > 0] - base[b, s] = torch.sum(torch.tensor(temp)) - return base - - -if __name__ == "__main__": - import os - from convlab.util.custom_util import set_seed - from convlab.policy.genTUS.stepGenTUS import UserActionPolicy - set_seed(0) - device = "cuda" if torch.cuda.is_available() else "cpu" - - model_checkpoint = 'results/genTUS-22-01-31-09-21/' - usr = UserActionPolicy(model_checkpoint=model_checkpoint) - usr.model.load_state_dict(torch.load( - os.path.join(model_checkpoint, "pytorch_model.bin"), map_location=device)) - usr.model.eval() - - test_file = "convlab/policy/genTUS/data/goal_status_validation_v1.json" - data = json.load(open(test_file)) - test_id = 20 - inputs = usr.tokenizer(data["dialog"][test_id]["in"], - max_length=400, - return_tensors="pt", - truncation=True) - - actions = [data["dialog"][test_id]["out"], - data["dialog"][test_id+100]["out"]] - - for action in actions: - action = json.loads(action) - vec = usr.vector.action_vectorize( - action["action"], s=inputs["input_ids"]) - - print({"action": action["action"]}) - print("get_log_prob", usr.model.get_log_prob( - inputs["input_ids"], - torch.unsqueeze(vec["vector"], 0), - inputs["attention_mask"], - torch.unsqueeze(vec["mask"], 0))) diff --git a/convlab/policy/genTUS/token_map.py b/convlab/policy/genTUS/token_map.py deleted file mode 100644 index 7825c2880928c40f68284b0c3199932cd1cfc477..0000000000000000000000000000000000000000 --- a/convlab/policy/genTUS/token_map.py +++ /dev/null @@ -1,64 +0,0 @@ -import json - - -class tokenMap: - def __init__(self, tokenizer): - self.tokenizer = tokenizer - self.token_name = {} - self.hash_map = {} - self.debug = False - self.default() - - 
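- # Usage sketch (illustrative): tm = tokenMap(tokenizer);
- # tm.get_id("start_json") returns the sub-token ids of '{"action": [',
- # and the first of those ids doubles as the hash key used when scoring
- # logits.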
def default(self, only_action=False): - self.format_tokens = { - 'start_json': '{"action": [', # 49643, 10845, 7862, 646 - 'start_act': '["', # 49329 - 'sep_token': '", "', # 1297('",'), 22 - 'sep_act': '"], ["', # 49177 - 'end_act': '"]], "', # 42248, 7479, 22 - 'start_text': 'text": "', # 29015, 7862, 22 - 'end_json': '}', # 24303 - 'end_json_2': '"}' # 48805 - } - if only_action: - self.format_tokens['end_act'] = '"]]}' - for token_name in self.format_tokens: - self.add_token( - token_name, self.format_tokens[token_name]) - - def add_token(self, token_name, value): - if token_name in self.token_name and self.debug: - print(f"---> duplicate token: {token_name}({value})!!!!!!!") - - token_id = self.tokenizer(str(value), add_special_tokens=False)[ - "input_ids"] - self.token_name[token_name] = {"value": value, "token_id": token_id} - # print(token_id) - hash_id = token_id[0] - if hash_id in self.hash_map and self.debug: - print( - f"---> conflict hash number {hash_id}: {self.hash_map[hash_id]['name']} and {token_name}") - self.hash_map[hash_id] = { - "name": token_name, "value": value, "token_id": token_id} - - def get_info(self, hash_id): - return self.hash_map[hash_id] - - def get_id(self, token_name): - # workaround - # if token_name not in self.token_name[token_name]: - # self.add_token(token_name, token_name) - return self.token_name[token_name]["token_id"] - - def get_token_value(self, token_name): - return self.token_name[token_name]["value"] - - def token_name_is_in(self, token_name): - if token_name in self.token_name: - return True - return False - - def hash_id_is_in(self, hash_id): - if hash_id in self.hash_map: - return True - return False diff --git a/convlab/policy/genTUS/train_model.py b/convlab/policy/genTUS/train_model.py deleted file mode 100644 index 2162417461d692514e3b27742dbfd477491fc24e..0000000000000000000000000000000000000000 --- a/convlab/policy/genTUS/train_model.py +++ /dev/null @@ -1,258 +0,0 @@ -import json -import os -import sys -from argparse import ArgumentParser -from datetime import datetime -from pprint import pprint -import numpy as np -import torch -import transformers -from datasets import Dataset, load_metric -from tqdm import tqdm -from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer, - BartForConditionalGeneration, BartTokenizer, - DataCollatorForSeq2Seq, Seq2SeqTrainer, - Seq2SeqTrainingArguments) - -sys.path.append(os.path.dirname(os.path.dirname( - os.path.dirname(os.path.abspath(__file__))))) - -os.environ["WANDB_DISABLED"] = "true" - -METRIC = load_metric("sacrebleu") -TOKENIZER = BartTokenizer.from_pretrained("facebook/bart-base") -TOKENIZER.add_tokens(["<?>"]) -MAX_IN_LEN = 500 -MAX_OUT_LEN = 500 - - -def arg_parser(): - parser = ArgumentParser() - # data_name, dial_ids_order, split2ratio - parser.add_argument("--model-type", type=str, default="unify", - help="unify or multiwoz") - parser.add_argument("--data-name", type=str, default="multiwoz21", - help="multiwoz21, sgd, tm1, tm2, tm3, sgd+tm, or all") - parser.add_argument("--dial-ids-order", type=int, default=0) - parser.add_argument("--split2ratio", type=float, default=1) - parser.add_argument("--batch-size", type=int, default=16) - parser.add_argument("--model-checkpoint", type=str, - default="facebook/bart-base") - return parser.parse_args() - - -def gentus_compute_metrics(eval_preds): - preds, labels = eval_preds - if isinstance(preds, tuple): - preds = preds[0] - decoded_preds = TOKENIZER.batch_decode( - preds, skip_special_tokens=True, max_length=MAX_OUT_LEN) - - # 
Replace -100 in the labels as we can't decode them. - labels = np.where(labels != -100, labels, TOKENIZER.pad_token_id) - decoded_labels = TOKENIZER.batch_decode( - labels, skip_special_tokens=True, max_length=MAX_OUT_LEN) - - act, text = postprocess_text(decoded_preds, decoded_labels) - - result = METRIC.compute( - # predictions=decoded_preds, references=decoded_labels) - predictions=text["preds"], references=text["labels"]) - result = {"bleu": result["score"]} - f1_scores = f1_measure(pred_acts=act["preds"], label_acts=act["labels"]) - for s in f1_scores: - result[s] = f1_scores[s] - - result = {k: round(v, 4) for k, v in result.items()} - return result - - -def postprocess_text(preds, labels): - act = {"preds": [], "labels": []} - text = {"preds": [], "labels": []} - - for pred, label in zip(preds, labels): - model_output = parse_output(pred.strip()) - label_output = parse_output(label.strip()) - if len(label_output["text"]) < 1: - continue - act["preds"].append(model_output.get("action", [])) - text["preds"].append(model_output.get("text", pred.strip())) - act["labels"].append(label_output["action"]) - text["labels"].append([label_output["text"]]) - - return act, text - - -def parse_output(in_str): - in_str = in_str.replace('<s>', '').replace('<\\s>', '') - try: - output = json.loads(in_str) - except: - # print(f"invalid action {in_str}") - output = {"action": [], "text": ""} - return output - - -def f1_measure(pred_acts, label_acts): - result = {"precision": [], "recall": [], "f1": []} - for pred, label in zip(pred_acts, label_acts): - r = tp_fn_fp(pred, label) - for m in result: - result[m].append(r[m]) - for m in result: - result[m] = sum(result[m])/len(result[m]) - - return result - - -def tp_fn_fp(pred, label): - tp, fn, fp = 0.0, 0.0, 0.0 - precision, recall, f1 = 0, 0, 0 - for p in pred: - if p in label: - tp += 1 - else: - fp += 1 - for l in label: - if l not in pred: - fn += 1 - if (tp+fp) > 0: - precision = tp / (tp+fp) - if (tp+fn) > 0: - recall = tp/(tp+fn) - if (precision + recall) > 0: - f1 = (2*precision*recall)/(precision+recall) - - return {"precision": precision, "recall": recall, "f1": f1} - - -class TrainerHelper: - def __init__(self, tokenizer, max_input_length=500, max_target_length=500): - print("transformers version is: ", transformers.__version__) - self.tokenizer = tokenizer - self.max_input_length = max_input_length - self.max_target_length = max_target_length - self.base_name = "convlab/policy/genTUS" - self.dir_name = "" - - def _get_data_folder(self, model_type, data_name, dial_ids_order=0, split2ratio=1): - # base_name = "convlab/policy/genTUS/unify/data" - if model_type not in ["unify", "multiwoz"]: - print("Unknown model type. 
Currently only support unify and multiwoz") - self.dir_name = f"{data_name}_{dial_ids_order}_{split2ratio}" - return os.path.join(self.base_name, model_type, 'data', self.dir_name) - - def get_model_folder(self, model_type): - folder_name = os.path.join( - self.base_name, model_type, "experiments", self.dir_name) - if not os.path.exists(folder_name): - os.makedirs(folder_name) - return folder_name - - def parse_data(self, model_type, data_name, dial_ids_order=0, split2ratio=1): - data_folder = self._get_data_folder( - model_type, data_name, dial_ids_order, split2ratio) - - raw_data = {} - for d_type in ["train", "validation", "test"]: - f_name = os.path.join(data_folder, f"{d_type}.json") - raw_data[d_type] = json.load(open(f_name)) - - tokenized_datasets = {} - for data_type, data in raw_data.items(): - tokenized_datasets[data_type] = Dataset.from_dict( - self._preprocess(data["dialog"])) - - return tokenized_datasets - - def _preprocess(self, examples): - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - if isinstance(examples, dict): - examples = [examples] - for example in tqdm(examples): - inputs = self.tokenizer(example["in"], - max_length=self.max_input_length, - truncation=True) - - # Setup the tokenizer for targets - with self.tokenizer.as_target_tokenizer(): - labels = self.tokenizer(example["out"], - max_length=self.max_target_length, - truncation=True) - for key in ["input_ids", "attention_mask"]: - model_inputs[key].append(inputs[key]) - model_inputs["labels"].append(labels["input_ids"]) - - return model_inputs - - -def train(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"): - tokenizer = TOKENIZER - - train_helper = TrainerHelper( - tokenizer=tokenizer, max_input_length=max_input_length, max_target_length=max_target_length) - data = train_helper.parse_data(model_type=model_type, - data_name=data_name, - dial_ids_order=dial_ids_order, - split2ratio=split2ratio) - - model = BartForConditionalGeneration.from_pretrained(model_checkpoint) - model.resize_token_embeddings(len(tokenizer)) - fp16 = False - if torch.cuda.is_available(): - fp16 = True - - model_dir = os.path.join( - train_helper.get_model_folder(model_type), - f"{datetime.now().strftime('%y-%m-%d-%H-%M')}") - - args = Seq2SeqTrainingArguments( - model_dir, - evaluation_strategy="epoch", - learning_rate=2e-5, - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - weight_decay=0.01, - save_total_limit=2, - num_train_epochs=5, - predict_with_generate=True, - fp16=fp16, - push_to_hub=False, - generation_max_length=max_target_length, - logging_dir=os.path.join(model_dir, 'log') - ) - data_collator = DataCollatorForSeq2Seq( - tokenizer, model=model, padding=True) - - # customize this trainer - trainer = Seq2SeqTrainer( - model=model, - args=args, - train_dataset=data["train"], - eval_dataset=data["test"], - data_collator=data_collator, - tokenizer=tokenizer, - compute_metrics=gentus_compute_metrics) - print("start training...") - trainer.train() - print("saving model...") - trainer.save_model() - - -def main(): - args = arg_parser() - print("---> data_name", args.data_name) - train(model_type=args.model_type, - data_name=args.data_name, - dial_ids_order=args.dial_ids_order, - split2ratio=args.split2ratio, - batch_size=args.batch_size, - max_input_length=MAX_IN_LEN, - max_target_length=MAX_OUT_LEN, - model_checkpoint=args.model_checkpoint) - - -if __name__ == "__main__": - 
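- # Example invocation (paths and values are illustrative):
- # python train_model.py --data-name multiwoz21 --dial-ids-order 0
- #     --split2ratio 1 --batch-size 16 --model-checkpoint facebook/bart-base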
main() - # sgd+tm: 46000 diff --git a/convlab/policy/genTUS/unify/Goal.py b/convlab/policy/genTUS/unify/Goal.py deleted file mode 100644 index f257fa19d2ee5f87b9fc8d16b806705516c864a4..0000000000000000000000000000000000000000 --- a/convlab/policy/genTUS/unify/Goal.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -The user goal for unify data format -""" -import json -from convlab.policy.tus.unify.Goal import old_goal2list -from convlab.task.multiwoz.goal_generator import GoalGenerator -from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal -from convlab.util.custom_util import slot_mapping -DEF_VAL_UNK = '?' # Unknown -DEF_VAL_DNC = 'dontcare' # Do not care -DEF_VAL_NUL = 'none' # for none -NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, ""] - -NOT_MENTIONED = "not mentioned" -FULFILLED = "fulfilled" -REQUESTED = "requested" -CONFLICT = "conflict" - - -class Goal: - """ User Goal Model Class. """ - - def __init__(self, goal=None): - """ - create new Goal from a dialog or from goal_generator - Args: - goal: can be a list (create from a dialog), an abus goal, or none - """ - self.domains = [] - self.domain_goals = {} - self.status = {} - self.invert_slot_mapping = {v: k for k, v in slot_mapping.items()} - self.raw_goal = None - - self._init_goal_from_data(goal) - self._init_status() - - def __str__(self): - return '-----Goal-----\n' + \ - json.dumps(self.domain_goals, indent=4) + \ - '\n-----Goal-----' - - def _init_goal_from_data(self, goal=None): - if not goal: - goal_gen = GoalGenerator() - old_goal = goal_gen.get_user_goal() - self.raw_goal = old_goal - goal = old_goal2list(old_goal) - - elif isinstance(goal, dict): - self.raw_goal = goal - goal = old_goal2list(goal) - - elif isinstance(goal, ABUS_Goal): - self.raw_goal = goal.domain_goals - goal = old_goal2list(goal.domain_goals) - - # be careful of this order - for domain, intent, slot, value in goal: - if domain == "none": - continue - if domain not in self.domains: - self.domains.append(domain) - self.domain_goals[domain] = {} - if intent not in self.domain_goals[domain]: - self.domain_goals[domain][intent] = {} - - if not value: - if intent == "request": - self.domain_goals[domain][intent][slot] = DEF_VAL_UNK - else: - print( - f"unknown no value intent {domain}, {intent}, {slot}") - else: - self.domain_goals[domain][intent][slot] = value - - def _init_status(self): - for domain, domain_goal in self.domain_goals.items(): - if domain not in self.status: - self.status[domain] = {} - for slot_type, sub_domain_goal in domain_goal.items(): - if slot_type not in self.status[domain]: - self.status[domain][slot_type] = {} - for slot in sub_domain_goal: - if slot not in self.status[domain][slot_type]: - self.status[domain][slot_type][slot] = {} - self.status[domain][slot_type][slot] = { - "value": str(sub_domain_goal[slot]), - "status": NOT_MENTIONED} - - def get_goal_list(self, data_goal=None): - goal_list = [] - if data_goal: - # make sure the order!!! 
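- # Each entry is [intent, domain, slot, value, status], where status is
- # one of the constants above (NOT_MENTIONED, FULFILLED, REQUESTED,
- # CONFLICT), e.g. (illustrative)
- # ["inform", "hotel", "area", "north", "not mentioned"].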
- for domain, intent, slot, _ in data_goal:
- status = self._get_status(domain, intent, slot)
- value = self.domain_goals[domain][intent][slot]
- goal_list.append([intent, domain, slot, value, status])
- return goal_list
- else:
- for domain, domain_goal in self.domain_goals.items():
- for intent, sub_goal in domain_goal.items():
- for slot, value in sub_goal.items():
- status = self._get_status(domain, intent, slot)
- goal_list.append([intent, domain, slot, value, status])
-
- return goal_list
-
- def _get_status(self, domain, intent, slot):
- if domain not in self.status:
- return NOT_MENTIONED
- if intent not in self.status[domain]:
- return NOT_MENTIONED
- if slot not in self.status[domain][intent]:
- return NOT_MENTIONED
- return self.status[domain][intent][slot]["status"]
-
- def task_complete(self):
- """
- Check whether all requests have been met.
- Returns:
- (boolean): True if the goal is accomplished.
- """
- for domain, domain_goal in self.status.items():
- if domain not in self.domain_goals:
- continue
- for slot_type, sub_domain_goal in domain_goal.items():
- if slot_type not in self.domain_goals[domain]:
- continue
- for slot, status in sub_domain_goal.items():
- if slot not in self.domain_goals[domain][slot_type]:
- continue
- # for strict success, turn this on
- if status["status"] in [NOT_MENTIONED, CONFLICT]:
- if status["status"] == CONFLICT and slot in ["arrive by", "leave at"]:
- continue
- return False
- if "?" in status["value"]:
- return False
-
- return True
-
- # TODO change to update()?
- def update_user_goal(self, action, char="usr"):
- # update request and booked
- if char == "usr":
- self._user_action_update(action)
- elif char == "sys":
- self._system_action_update(action)
- else:
- print("!!!UNKNOWN CHAR!!!")
-
- def _user_action_update(self, action):
- # no need to update user goal
- for intent, domain, slot, _ in action:
- goal_intent = self._check_slot_and_intent(domain, slot)
- if not goal_intent:
- continue
- # fulfilled by user
- if is_inform(intent):
- self._set_status(goal_intent, domain, slot, FULFILLED)
- # requested by user
- if is_request(intent):
- self._set_status(goal_intent, domain, slot, REQUESTED)
-
- def _system_action_update(self, action):
- for intent, domain, slot, value in action:
- goal_intent = self._check_slot_and_intent(domain, slot)
- if not goal_intent:
- continue
- # fulfill request by system
- if is_inform(intent) and is_request(goal_intent):
- self._set_status(goal_intent, domain, slot, FULFILLED)
- self._set_goal(goal_intent, domain, slot, value)
-
- if is_inform(intent) and is_inform(goal_intent):
- # fulfill inform by system
- if value == self.domain_goals[domain][goal_intent][slot]:
- self._set_status(goal_intent, domain, slot, FULFILLED)
- # conflict system inform
- else:
- self._set_status(goal_intent, domain, slot, CONFLICT)
- # requested by system
- if is_request(intent) and is_inform(goal_intent):
- self._set_status(goal_intent, domain, slot, REQUESTED)
-
- def _set_status(self, intent, domain, slot, status):
- self.status[domain][intent][slot]["status"] = status
-
- def _set_goal(self, intent, domain, slot, value):
- # old_value = self.domain_goals[domain][intent][slot]
- self.domain_goals[domain][intent][slot] = value
- self.status[domain][intent][slot]["value"] = value
- # print(
- # f"updating user goal {intent}-{domain}-{slot} {old_value}-> {value}")
-
- def _check_slot_and_intent(self, domain, slot):
- not_found = ""
- if domain not in self.domain_goals:
- return not_found
- for intent in self.domain_goals[domain]:
- if slot in
self.domain_goals[domain][intent]: - return intent - return not_found - - -def is_inform(intent): - if "inform" in intent: - return True - return False - - -def is_request(intent): - if "request" in intent: - return True - return False - - -def transform_data_act(data_action): - action_list = [] - for _, dialog_act in data_action.items(): - for act in dialog_act: - value = act.get("value", "") - if not value: - if "request" in act["intent"]: - value = "?" - else: - value = "none" - action_list.append( - [act["intent"], act["domain"], act["slot"], value]) - return action_list diff --git a/convlab/policy/genTUS/unify/build_data.py b/convlab/policy/genTUS/unify/build_data.py deleted file mode 100644 index 50873a1d4b6ffaa8ee49c84a4b088ae56ad13554..0000000000000000000000000000000000000000 --- a/convlab/policy/genTUS/unify/build_data.py +++ /dev/null @@ -1,211 +0,0 @@ -import json -import os -import sys -from argparse import ArgumentParser - -from tqdm import tqdm - -from convlab.policy.genTUS.unify.Goal import Goal, transform_data_act -from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset - - -sys.path.append(os.path.dirname(os.path.dirname( - os.path.dirname(os.path.abspath(__file__))))) - - -def arg_parser(): - parser = ArgumentParser() - parser.add_argument("--dataset", type=str, default="multiwoz21", - help="the dataset, such as multiwoz21, sgd, tm1, tm2, and tm3.") - parser.add_argument("--dial-ids-order", type=int, default=0) - parser.add_argument("--split2ratio", type=float, default=1) - parser.add_argument("--random-order", action="store_true") - parser.add_argument("--no-status", action="store_true") - parser.add_argument("--add-history", action="store_true") - parser.add_argument("--remove-domain", type=str, default="") - - return parser.parse_args() - -class DataBuilder: - def __init__(self, dataset='multiwoz21'): - self.dataset = dataset - - def setup_data(self, - raw_data, - random_order=False, - no_status=False, - add_history=False, - remove_domain=None): - examples = {data_split: {"dialog": []} for data_split in raw_data} - - for data_split, dialogs in raw_data.items(): - for dialog in tqdm(dialogs, ascii=True): - example = self._one_dialog(dialog=dialog, - add_history=add_history, - random_order=random_order, - no_status=no_status) - examples[data_split]["dialog"] += example - - return examples - - def _one_dialog(self, dialog, add_history=True, random_order=False, no_status=False): - example = [] - history = [] - - data_goal = self.norm_domain_goal(create_goal(dialog)) - if not data_goal: - return example - user_goal = Goal(goal=data_goal) - - for turn_id in range(0, len(dialog["turns"]), 2): - sys_act = self._get_sys_act(dialog, turn_id) - - user_goal.update_user_goal(action=sys_act, char="sys") - usr_goal_str = self._user_goal_str(user_goal, data_goal, random_order, no_status) - - usr_act = self.norm_domain(transform_data_act( - dialog["turns"][turn_id]["dialogue_acts"])) - user_goal.update_user_goal(action=usr_act, char="usr") - - # change value "?" 
to "<?>" - usr_act = self._modify_act(usr_act) - - in_str = self._dump_in_str(sys_act, usr_goal_str, history, turn_id, add_history) - out_str = self._dump_out_str(usr_act, dialog["turns"][turn_id]["utterance"]) - - history.append(usr_act) - if usr_act: - example.append({"in": in_str, "out": out_str}) - - return example - - def _get_sys_act(self, dialog, turn_id): - sys_act = [] - if turn_id > 0: - sys_act = self.norm_domain(transform_data_act( - dialog["turns"][turn_id - 1]["dialogue_acts"])) - return sys_act - - def _user_goal_str(self, user_goal, data_goal, random_order, no_status): - if random_order: - usr_goal_str = user_goal.get_goal_list() - else: - usr_goal_str = user_goal.get_goal_list(data_goal=data_goal) - - if no_status: - usr_goal_str = self._remove_status(usr_goal_str) - return usr_goal_str - - def _dump_in_str(self, sys_act, usr_goal_str, history, turn_id, add_history): - in_str = {} - in_str["system"] = self._modify_act(sys_act) - in_str["goal"] = usr_goal_str - if add_history: - h = [] - if history: - h = history[-3:] - in_str["history"] = h - in_str["turn"] = str(int(turn_id/2)) - - return json.dumps(in_str) - - def _dump_out_str(self, usr_act, text): - out_str = {"action": usr_act, "text": text} - return json.dumps(out_str) - - @staticmethod - def _norm_intent(intent): - if intent in ["inform_intent", "negate_intent", "affirm_intent", "request_alts"]: - return f"_{intent}" - return intent - - def norm_domain(self, x): - if not x: - return x - norm_result = [] - # print(x) - for intent, domain, slot, value in x: - if "_" in domain: - domain = domain.split('_')[0] - if not domain: - domain = "none" - if not slot: - slot = "none" - if not value: - if intent == "request": - value = "<?>" - else: - value = "none" - norm_result.append([self._norm_intent(intent), domain, slot, value]) - return norm_result - - def norm_domain_goal(self, x): - if not x: - return x - norm_result = [] - # take care of the order! 
- for domain, intent, slot, value in x: - if "_" in domain: - domain = domain.split('_')[0] - if not domain: - domain = "none" - if not slot: - slot = "none" - if not value: - if intent == "request": - value = "<?>" - else: - value = "none" - norm_result.append([domain, self._norm_intent(intent), slot, value]) - return norm_result - - @staticmethod - def _remove_status(goal_list): - new_list = [[goal[0], goal[1], goal[2], goal[3]] - for goal in goal_list] - return new_list - - @staticmethod - def _modify_act(act): - new_act = [] - for i, d, s, value in act: - if value == "?": - new_act.append([i, d, s, "<?>"]) - else: - new_act.append([i, d, s, value]) - return new_act - - -if __name__ == "__main__": - args = arg_parser() - - base_name = "convlab/policy/genTUS/unify/data" - dir_name = f"{args.dataset}_{args.dial_ids_order}_{args.split2ratio}" - folder_name = os.path.join(base_name, dir_name) - remove_domain = args.remove_domain - - if not os.path.exists(folder_name): - os.makedirs(folder_name) - - dataset = load_experiment_dataset( - data_name=args.dataset, - dial_ids_order=args.dial_ids_order, - split2ratio=args.split2ratio) - data_builder = DataBuilder(dataset=args.dataset) - data = data_builder.setup_data( - raw_data=dataset, - random_order=args.random_order, - no_status=args.no_status, - add_history=args.add_history, - remove_domain=remove_domain) - - for data_type in data: - if remove_domain: - file_name = os.path.join( - folder_name, - f"no{remove_domain}_{data_type}.json") - else: - file_name = os.path.join( - folder_name, - f"{data_type}.json") - json.dump(data[data_type], open(file_name, 'w'), indent=2) diff --git a/convlab/policy/genTUS/unify/knowledge_graph.py b/convlab/policy/genTUS/unify/knowledge_graph.py deleted file mode 100644 index 68af13e481fe4799dfc2a6f3763b526611eabd9c..0000000000000000000000000000000000000000 --- a/convlab/policy/genTUS/unify/knowledge_graph.py +++ /dev/null @@ -1,252 +0,0 @@ -import json -from random import choices - -from convlab.policy.genTUS.token_map import tokenMap - -from transformers import BartTokenizer - -DEBUG = False -DATASET = "unify" - - -class KnowledgeGraph: - def __init__(self, tokenizer: BartTokenizer, ontology_file=None, dataset="multiwoz21"): - print("dataset", dataset) - self.debug = DEBUG - self.tokenizer = tokenizer - - if "multiwoz" in dataset: - self.domain_intent = ["inform", "request"] - self.general_intent = ["thank", "bye"] - # use sgd dataset intents as default - else: - self.domain_intent = ["_inform_intent", - "_negate_intent", - "_affirm_intent", - "inform", - "request", - "affirm", - "negate", - "select", - "_request_alts"] - self.general_intent = ["thank_you", "goodbye"] - - self.general_domain = "none" - self.kg_map = {"intent": tokenMap(tokenizer=self.tokenizer)} - - for intent in self.domain_intent + self.general_intent: - self.kg_map["intent"].add_token(intent, intent) - - self.init() - - def init(self): - for map_type in ["domain", "slot", "value"]: - self.kg_map[map_type] = tokenMap(tokenizer=self.tokenizer) - self.add_token("<?>", "value") - - def parse_input(self, in_str): - self.init() - inputs = json.loads(in_str) - self.sys_act = inputs["system"] - self.user_goal = {} - self._add_none_domain() - for intent, domain, slot, value, _ in inputs["goal"]: - self._update_user_goal(intent, domain, slot, value, source="goal") - - for intent, domain, slot, value in self.sys_act: - self._update_user_goal(intent, domain, slot, value, source="sys") - - def _add_none_domain(self): - self.user_goal["none"] = {"none": 
"none"} - # add slot - self.add_token("none", "domain") - self.add_token("none", "slot") - self.add_token("none", "value") - - def _update_user_goal(self, intent, domain, slot, value, source="goal"): - - if value == "?": - value = "<?>" - - if intent == "request" and source == "sys": - value = "dontcare" # user can "dontcare" system request - - if source == "sys" and intent != "request": - return - - if domain not in self.user_goal: - self.user_goal[domain] = {} - self.user_goal[domain]["none"] = ["none"] - self.add_token(domain, "domain") - self.add_token("none", "slot") - self.add_token("none", "value") - - if slot not in self.user_goal[domain]: - self.user_goal[domain][slot] = [] - self.add_token(domain, "slot") - - if value not in self.user_goal[domain][slot]: - value = json.dumps(str(value))[1:-1] - self.user_goal[domain][slot].append(value) - value = value.replace('"', "'") - self.add_token(value, "value") - - def add_token(self, token_name, map_type): - if map_type == "value": - token_name = token_name.replace('"', "'") - if not self.kg_map[map_type].token_name_is_in(token_name): - self.kg_map[map_type].add_token(token_name, token_name) - - def _get_max_score(self, outputs, candidate_list, map_type): - score = {} - if not candidate_list: - print(f"ERROR: empty candidate list for {map_type}") - score[1] = {"token_id": self._get_token_id( - "none"), "token_name": "none"} - - for x in candidate_list: - hash_id = self._get_token_id(x)[0] - s = outputs[:, hash_id].item() - score[s] = {"token_id": self._get_token_id(x), - "token_name": x} - return score - - def _select(self, score, mode="max"): - probs = [s for s in score] - if mode == "max": - s = max(probs) - elif mode == "sample": - s = choices(probs, weights=probs, k=1) - s = s[0] - - else: - print("unknown select mode") - - return s - - def _get_max_domain_token(self, outputs, candidates, map_type, mode="max"): - score = self._get_max_score(outputs, candidates, map_type) - s = self._select(score, mode) - token_id = score[s]["token_id"] - token_name = score[s]["token_name"] - - return {"token_id": token_id, "token_name": token_name} - - def candidate(self, candidate_type, **kwargs): - if "intent" in kwargs: - intent = kwargs["intent"] - if candidate_type == "intent": - allow_general_intent = kwargs.get("allow_general_intent", True) - if allow_general_intent: - return self.domain_intent + self.general_intent - else: - return self.domain_intent - elif candidate_type == "domain": - if intent in self.general_intent: - return [self.general_domain] - else: - return [d for d in self.user_goal] - elif candidate_type == "slot": - if intent in self.general_intent: - return ["none"] - else: - return self._filter_slot(intent, kwargs["domain"], kwargs["is_mentioned"]) - else: - if intent in self.general_intent: - return ["none"] - elif intent.lower() == "request": - return ["<?>"] - else: - return self._filter_value(intent, kwargs["domain"], kwargs["slot"]) - - def get_intent(self, outputs, mode="max", allow_general_intent=True): - # return intent, token_id_list - # TODO request? 
- candidate_list = self.candidate(
- "intent", allow_general_intent=allow_general_intent)
- score = self._get_max_score(outputs, candidate_list, "intent")
- s = self._select(score, mode)
-
- return score[s]
-
- def get_domain(self, outputs, intent, mode="max"):
- if intent in self.general_intent:
- token_name = self.general_domain
- token_id = self.tokenizer(token_name, add_special_tokens=False)
- token_map = {"token_id": token_id['input_ids'],
- "token_name": token_name}
-
- elif intent in self.domain_intent:
- # [d for d in self.user_goal]
- domain_list = self.candidate("domain", intent=intent)
- token_map = self._get_max_domain_token(
- outputs=outputs, candidates=domain_list, map_type="domain", mode=mode)
- else:
- if self.debug:
- print("unknown intent", intent)
-
- return token_map
-
- def get_slot(self, outputs, intent, domain, mode="max", is_mentioned=False):
- if intent in self.general_intent:
- token_name = "none"
- token_id = self.tokenizer(token_name, add_special_tokens=False)
- token_map = {"token_id": token_id['input_ids'],
- "token_name": token_name}
-
- elif intent in self.domain_intent:
- slot_list = self.candidate(
- candidate_type="slot", intent=intent, domain=domain, is_mentioned=is_mentioned)
- token_map = self._get_max_domain_token(
- outputs=outputs, candidates=slot_list, map_type="slot", mode=mode)
-
- return token_map
-
- def get_value(self, outputs, intent, domain, slot, mode="max"):
- if intent in self.general_intent or slot.lower() == "none":
- token_name = "none"
- token_id = self.tokenizer(token_name, add_special_tokens=False)
- token_map = {"token_id": token_id['input_ids'],
- "token_name": token_name}
-
- elif intent.lower() == "request":
- token_name = "<?>"
- token_id = self.tokenizer(token_name, add_special_tokens=False)
- token_map = {"token_id": token_id['input_ids'],
- "token_name": token_name}
-
- elif intent in self.domain_intent:
- # TODO: should "none" values be excluded here?
- # value_list = [v for v in self.user_goal[domain][slot]]
- value_list = self.candidate(
- candidate_type="value", intent=intent, domain=domain, slot=slot)
-
- token_map = self._get_max_domain_token(
- outputs=outputs, candidates=value_list, map_type="value", mode=mode)
-
- return token_map
-
- def _filter_slot(self, intent, domain, is_mentioned=True):
- slot_list = []
- for slot in self.user_goal[domain]:
- value_list = self._filter_value(intent, domain, slot)
- if len(value_list) > 0:
- slot_list.append(slot)
- if not is_mentioned and intent.lower() != "request":
- slot_list.append("none")
- return slot_list
-
- def _filter_value(self, intent, domain, slot):
- value_list = [v for v in self.user_goal[domain][slot]]
- if "none" in value_list:
- value_list.remove("none")
- if intent.lower() != "request":
- if "?"
in value_list: - value_list.remove("?") - if "<?>" in value_list: - value_list.remove("<?>") - # print(f"{intent}-{domain}-{slot}= {value_list}") - return value_list - - def _get_token_id(self, token): - return self.tokenizer(token, add_special_tokens=False)["input_ids"] diff --git a/convlab/policy/genTUS/utils.py b/convlab/policy/genTUS/utils.py deleted file mode 100644 index 39c822dd35f985790d53a2834e8b6fe437864f24..0000000000000000000000000000000000000000 --- a/convlab/policy/genTUS/utils.py +++ /dev/null @@ -1,5 +0,0 @@ -import torch - - -def append_tokens(tokens, new_token, device): - return torch.cat((tokens, torch.tensor([new_token]).to(device)), dim=1) diff --git a/convlab/policy/ppo/semantic_level_config.json b/convlab/policy/ppo/semantic_level_config.json index 98aad8ddbcf50060c0d38a640f1ac1bb2d7a5c7e..a4a24598d7eac283bcc05d21b0f695cad90b6433 100644 --- a/convlab/policy/ppo/semantic_level_config.json +++ b/convlab/policy/ppo/semantic_level_config.json @@ -1,6 +1,6 @@ { "model": { - "load_path": "convlab/policy/ppo/pretrained_models/supervised", + "load_path": "", "use_pretrained_initialisation": false, "pretrained_load_path": "", "batchsz": 1000, @@ -40,4 +40,4 @@ } }, "usr_nlg": {} -} +} \ No newline at end of file diff --git a/convlab/policy/ppo/semantic_level_config_eval.json b/convlab/policy/ppo/semantic_level_config_eval.json deleted file mode 100644 index be03421a84f5cd48f04ef255901942225a84e6ac..0000000000000000000000000000000000000000 --- a/convlab/policy/ppo/semantic_level_config_eval.json +++ /dev/null @@ -1,43 +0,0 @@ -{ - "model": { - "load_path": "", - "use_pretrained_initialisation": false, - "pretrained_load_path": "", - "batchsz": 1000, - "seed": 0, - "epoch": 0, - "eval_frequency": 5, - "process_num": 4, - "sys_semantic_to_usr": false, - "num_eval_dialogues": 500 - }, - "vectorizer_sys": { - "uncertainty_vector_mul": { - "class_path": "convlab.policy.vector.vector_binary.VectorBinary", - "ini_params": { - "use_masking": true, - "manually_add_entity_names": false, - "seed": 0 - } - } - }, - "nlu_sys": {}, - "dst_sys": { - "RuleDST": { - "class_path": "convlab.dst.rule.multiwoz.dst.RuleDST", - "ini_params": {} - } - }, - "sys_nlg": {}, - "nlu_usr": {}, - "dst_usr": {}, - "policy_usr": { - "RulePolicy": { - "class_path": "convlab.policy.rule.multiwoz.RulePolicy", - "ini_params": { - "character": "usr" - } - } - }, - "usr_nlg": {} -} diff --git a/convlab/policy/ppo/setsumbt_end_baseline_config_eval.json b/convlab/policy/ppo/setsumbt_end_baseline_config_eval.json deleted file mode 100644 index b4fd62aa810a67e1a2219f70ba630bedab249431..0000000000000000000000000000000000000000 --- a/convlab/policy/ppo/setsumbt_end_baseline_config_eval.json +++ /dev/null @@ -1,65 +0,0 @@ -{ - "model": { - "load_path": "supervised", - "pretrained_load_path": "", - "use_pretrained_initialisation": false, - "batchsz": 1000, - "seed": 0, - "epoch": 0, - "eval_frequency": 5, - "process_num": 2, - "num_eval_dialogues": 500, - "sys_semantic_to_usr": false - }, - "vectorizer_sys": { - "uncertainty_vector_mul": { - "class_path": "convlab.policy.vector.vector_uncertainty.VectorUncertainty", - "ini_params": { - "use_masking": false, - "manually_add_entity_names": false, - "seed": 0, - "use_confidence_scores": true, - "use_state_total_uncertainty": true - } - } - }, - "nlu_sys": {}, - "dst_sys": { - "setsumbt-mul": { - "class_path": "convlab.dst.setsumbt.SetSUMBTTracker", - "ini_params": { - "model_path": "https://zenodo.org/record/5497808/files/setsumbt_end.zip", - "return_confidence_scores": true, 
- "return_belief_state_entropy": true - } - } - }, - "sys_nlg": { - "TemplateNLG": { - "class_path": "convlab.nlg.template.multiwoz.TemplateNLG", - "ini_params": { - "is_user": false - } - } - }, - "nlu_usr": {}, - "dst_usr": {}, - "policy_usr": { - "RulePolicy": { - "class_path": "convlab.policy.rule.multiwoz.RulePolicy", - "ini_params": { - "character": "usr" - } - } - }, - "usr_nlg": { - "TemplateNLG": { - "class_path": "convlab.nlg.template.multiwoz.TemplateNLG", - "ini_params": { - "is_user": true, - "label_noise": 0.0, - "text_noise": 0.0 - } - } - } -} diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py index 1d453c9ac395d981a0adebe912be5266ec6bd0e1..9f00d3ddfd635e52c0084b0989ae46763aec4de1 100755 --- a/convlab/policy/ppo/train.py +++ b/convlab/policy/ppo/train.py @@ -34,6 +34,7 @@ except RuntimeError: def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0): + """ This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple processes. @@ -108,6 +109,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0): def sample(env, policy, batchsz, process_num, seed): + """ Given batchsz number of task, the batchsz will be splited equally to each processes and when processes return, it merge all data and return @@ -135,8 +137,7 @@ def sample(env, policy, batchsz, process_num, seed): evt = mp.Event() processes = [] for i in range(process_num): - process_args = (i, queue, evt, env, policy, - process_batchsz, train_seeds[i]) + process_args = (i, queue, evt, env, policy, process_batchsz, train_seeds[i]) processes.append(mp.Process(target=sampler, args=process_args)) for p in processes: # set the process as daemon, and it will be killed once the main process is stoped. 
@@ -189,41 +190,31 @@ if __name__ == '__main__': help="Set level for logger") parser.add_argument("--save_eval_dials", type=bool, default=False, help="Flag for saving dialogue_info during evaluation") - parser.add_argument("--save_path", type=str, default=None, - help="Custom save path other than the path of this script") path = parser.parse_args().path seed = parser.parse_args().seed mode = parser.parse_args().mode save_eval = parser.parse_args().save_eval_dials - custom_save_path = parser.parse_args().save_path - if custom_save_path: - logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \ - init_logging(custom_save_path, mode) - else: - logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \ - init_logging(os.path.dirname(os.path.abspath(__file__)), mode) + logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \ + init_logging(os.path.dirname(os.path.abspath(__file__)), mode) args = [('model', 'seed', seed)] if seed is not None else list() environment_config = load_config_file(path) - save_config(vars(parser.parse_args()), - environment_config, config_save_path) + save_config(vars(parser.parse_args()), environment_config, config_save_path) conf = get_config(path, args) seed = conf['model']['seed'] logging.info('Train seed is ' + str(seed)) set_seed(seed) - policy_sys = PPO(True, seed=conf['model']['seed'], - vectorizer=conf['vectorizer_sys_activated']) + policy_sys = PPO(True, seed=conf['model']['seed'], vectorizer=conf['vectorizer_sys_activated']) # Load model if conf['model']['use_pretrained_initialisation']: logging.info("Loading supervised model checkpoint.") - policy_sys.load_from_pretrained( - conf['model'].get('pretrained_load_path', "")) + policy_sys.load_from_pretrained(conf['model'].get('pretrained_load_path', "")) elif conf['model']['load_path']: try: policy_sys.load(conf['model']['load_path']) @@ -251,8 +242,7 @@ if __name__ == '__main__': logging.info(f"Evaluating at start - {time_now}" + '-'*60) time_now = time.time() - eval_dict = eval_policy(conf, policy_sys, env, sess, - save_eval, log_save_path) + eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path) logging.info(f"Finished evaluating, time spent: {time.time() - time_now}") for key in eval_dict: @@ -267,15 +257,13 @@ if __name__ == '__main__': for i in range(conf['model']['epoch']): idx = i + 1 # print("Epoch :{}".format(str(idx))) - update(env, policy_sys, conf['model']['batchsz'], - idx, conf['model']['process_num'], seed=seed) + update(env, policy_sys, conf['model']['batchsz'], idx, conf['model']['process_num'], seed=seed) if idx % conf['model']['eval_frequency'] == 0 and idx != 0: time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) logging.info(f"Evaluating at Epoch: {idx} - {time_now}" + '-'*60) - eval_dict = eval_policy( - conf, policy_sys, env, sess, save_eval, log_save_path) + eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path) best_complete_rate, best_success_rate, best_return = \ save_best(policy_sys, best_complete_rate, best_success_rate, best_return, @@ -283,8 +271,7 @@ if __name__ == '__main__': eval_dict["avg_return"], save_path) policy_sys.save(save_path, "last") for key in eval_dict: - tb_writer.add_scalar( - key, eval_dict[key], idx * conf['model']['batchsz']) + tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['batchsz']) logging.info("End of Training: " + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) @@ -293,9 +280,5 
@@ if __name__ == '__main__': f.write(str(datetime.now() - begin_time)) f.close() - if custom_save_path: - move_finished_training(dir_path, os.path.join( - custom_save_path, "finished_experiments")) - else: - move_finished_training(dir_path, os.path.join( - os.path.dirname(os.path.abspath(__file__)), "finished_experiments")) + move_finished_training(dir_path, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "finished_experiments")) diff --git a/convlab/policy/ppo/trippy_eval.json b/convlab/policy/ppo/trippy_eval.json deleted file mode 100644 index d7f6ff5dba432bf519fc1bc1e13c425d8f897dfb..0000000000000000000000000000000000000000 --- a/convlab/policy/ppo/trippy_eval.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "model": { - "load_path": "finished_experiments/experiment_2022-09-08-14-53-25/save/best_ppo", - "pretrained_load_path": "", - "use_pretrained_initialisation": false, - "batchsz": 1000, - "seed": 0, - "epoch": 0, - "eval_frequency": 5, - "process_num": 1, - "num_eval_dialogues": 500, - "sys_semantic_to_usr": false - }, - "vectorizer_sys": { - "uncertainty_vector_mul": { - "class_path": "convlab.policy.vector.vector_binary.VectorBinary", - "ini_params": { - "use_masking": true, - "manually_add_entity_names": true, - "seed": 0 - } - } - }, - "nlu_sys": {}, - "dst_sys": { - "TripPy": { - "class_path": "convlab.dst.trippy.multiwoz.TRIPPY", - "ini_params": { - "model_type": "roberta", - "model_path": "/gpfs/project/heckmi/trippy/multiwoz21/trippy.convlab3/results.42/checkpoint-28392", - "nlu_path": "/home/heckmi/models/bert_multiwoz_all_context.zip" - } - } - }, - "sys_nlg": {}, - "nlu_usr": {}, - "dst_usr": {}, - "policy_usr": { - "RulePolicy": { - "class_path": "convlab.policy.rule.multiwoz.RulePolicy", - "ini_params": { - "character": "usr" - } - } - }, - "usr_nlg": {} -} diff --git a/convlab/policy/ppo/trippy_train.json b/convlab/policy/ppo/trippy_train.json deleted file mode 100644 index 248811acb2c22c18dd66b54787740968067a703c..0000000000000000000000000000000000000000 --- a/convlab/policy/ppo/trippy_train.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "model": { - "load_path": "/home/heckmi/models/supervised", - "pretrained_load_path": "", - "use_pretrained_initialisation": false, - "batchsz": 1000, - "seed": 0, - "epoch": 50, - "eval_frequency": 5, - "process_num": 1, - "num_eval_dialogues": 500, - "sys_semantic_to_usr": false - }, - "vectorizer_sys": { - "uncertainty_vector_mul": { - "class_path": "convlab.policy.vector.vector_binary.VectorBinary", - "ini_params": { - "use_masking": true, - "manually_add_entity_names": true, - "seed": 0 - } - } - }, - "nlu_sys": {}, - "dst_sys": { - "TripPy": { - "class_path": "convlab.dst.trippy.multiwoz.TRIPPY", - "ini_params": { - "model_type": "roberta", - "model_path": "/gpfs/project/heckmi/trippy/multiwoz21/trippy.convlab3/results.42/checkpoint-28392", - "nlu_path": "/home/heckmi/models/bert_multiwoz_all_context.zip" - } - } - }, - "sys_nlg": {}, - "nlu_usr": {}, - "dst_usr": {}, - "policy_usr": { - "RulePolicy": { - "class_path": "convlab.policy.rule.multiwoz.RulePolicy", - "ini_params": { - "character": "usr" - } - } - }, - "usr_nlg": {} -} diff --git a/convlab/policy/ppo/tus_semantic_level_config.json b/convlab/policy/ppo/tus_semantic_level_config.json index 445fa9278552a9353529d8eca6dc8b2c6c169fd9..cc0216122dd473ab270fe52617621732d5b2034c 100644 --- a/convlab/policy/ppo/tus_semantic_level_config.json +++ b/convlab/policy/ppo/tus_semantic_level_config.json @@ -1,11 +1,11 @@ { "model": { - "load_path": 
"convlab/policy/ppo/pretrained_models/supervised", + "load_path": "convlab/policy/mle/experiments/experiment_2022-05-23-14-08-43/save/supervised", "use_pretrained_initialisation": false, "pretrained_load_path": "", "batchsz": 1000, "seed": 0, - "epoch": 200, + "epoch": 50, "eval_frequency": 5, "process_num": 4, "sys_semantic_to_usr": false, @@ -35,9 +35,9 @@ "TUSPolicy": { "class_path": "convlab.policy.tus.unify.TUS.UserPolicy", "ini_params": { - "config": "convlab/policy/tus/unify/exp/multiwoz.json" + "config": "convlab/policy/tus/unify/exp/all.json" } } }, "usr_nlg": {} -} +} \ No newline at end of file diff --git a/convlab/policy/tus/unify/Goal.py b/convlab/policy/tus/unify/Goal.py index 610469bb4246cf6c16ef4271c89827cab6421437..82a6e7a9799b4537495cf35e57f6f52711e78af8 100644 --- a/convlab/policy/tus/unify/Goal.py +++ b/convlab/policy/tus/unify/Goal.py @@ -1,9 +1,6 @@ import time import json -from convlab.policy.tus.unify.util import split_slot_name, slot_name_map -from convlab.util.custom_util import slot_mapping - -from random import sample, shuffle +from convlab.policy.tus.unify.util import split_slot_name from pprint import pprint DEF_VAL_UNK = '?' # Unknown DEF_VAL_DNC = 'dontcare' # Do not care @@ -30,35 +27,6 @@ def isTimeFormat(input): return False -def old_goal2list(goal: dict, reorder=False) -> list: - goal_list = [] - for domain in goal: - for slot_type in ['info', 'book', 'reqt']: - if slot_type not in goal[domain]: - continue - temp = [] - for slot in goal[domain][slot_type]: - s = slot - if slot in slot_name_map: - s = slot_name_map[slot] - elif slot in slot_name_map[domain]: - s = slot_name_map[domain][slot] - # domain, intent, slot, value - if slot_type in ['info', 'book']: - i = "inform" - v = goal[domain][slot_type][slot] - else: - i = "request" - v = DEF_VAL_UNK - s = slot_mapping.get(s, s) - temp.append([domain, i, s, v]) - shuffle(temp) - goal_list = goal_list + temp - # shuffle_goal = goal_list[:1] + sample(goal_list[1:], len(goal_list)-1) - # return shuffle_goal - return goal_list - - class Goal(object): """ User Goal Model Class. 
""" @@ -133,7 +101,6 @@ class Goal(object): if self.domain_goals[domain]["reqt"][slot] == DEF_VAL_UNK: # print(f"not fulfilled request{domain}-{slot}") return False - return True def init_local_id(self): @@ -202,10 +169,6 @@ class Goal(object): def _update_status(self, action: list, char: str): for intent, domain, slot, value in action: - if slot == "arrive by": - slot = "arriveBy" - elif slot == "leave at": - slot = "leaveAt" if domain not in self.status: self.status[domain] = {} # update info @@ -217,10 +180,6 @@ class Goal(object): def _update_goal(self, action: list, char: str): # update requt slots in goal for intent, domain, slot, value in action: - if slot == "arrive by": - slot = "arriveBy" - elif slot == "leave at": - slot = "leaveAt" if "info" not in intent: continue if self._check_update_request(domain, slot) and value != "?": diff --git a/convlab/policy/tus/unify/TUS.py b/convlab/policy/tus/unify/TUS.py index de98e6d1dfea7d4486d56a60dee5662f745c2df9..09c50672fb889ffd3e965ec0e90d2dee30b2c360 100644 --- a/convlab/policy/tus/unify/TUS.py +++ b/convlab/policy/tus/unify/TUS.py @@ -5,18 +5,19 @@ from copy import deepcopy import torch from convlab.policy.policy import Policy +from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import ( + act_dict_to_flat_tuple, unified_format) from convlab.policy.tus.multiwoz.transformer import TransformerActionPrediction from convlab.policy.tus.unify.Goal import Goal from convlab.policy.tus.unify.usermanager import BinaryFeature -from convlab.policy.tus.unify.util import create_goal, split_slot_name +from convlab.policy.tus.unify.util import (create_goal, int2onehot, + metadata2state, parse_dialogue_act, + parse_user_goal, split_slot_name) from convlab.util import (load_dataset, relative_import_module_from_unified_datasets) from convlab.util.custom_util import model_downloader -from convlab.task.multiwoz.goal_generator import GoalGenerator -from convlab.policy.tus.unify.Goal import old_goal2list -from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal - - +from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA +from pprint import pprint reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets( 'multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value']) @@ -50,7 +51,7 @@ class UserActionPolicy(Policy): self.user = TransformerActionPrediction(self.config).to(device=DEVICE) if pretrain: model_path = os.path.join( - self.config["model_dir"], "model-non-zero")# self.config["model_name"]) + self.config["model_dir"], self.config["model_name"]) print(f"loading model from {model_path}...") self.load(model_path) self.user.eval() @@ -60,7 +61,6 @@ class UserActionPolicy(Policy): self.reward = {"success": 40, "fail": -20} self.sys_acts = [] - self.goal_gen = GoalGenerator() def _no_offer(self, system_in): for intent, domain, slot, value in system_in: @@ -127,14 +127,13 @@ class UserActionPolicy(Policy): self.topic = 'NONE' remove_domain = "police" # remove police domain in inference + # if not goal: + # self.new_goal(remove_domain=remove_domain) + # else: + # self.read_goal(goal) if not goal: - old_goal = self.goal_gen.get_user_goal() - goal_list = old_goal2list(old_goal) - goal = Goal(goal_list) - - elif type(goal) == ABUS_Goal: - goal_list = old_goal2list(goal.domain_goals) - goal = Goal(goal_list) + data = load_dataset(self.dataset, 0) + goal = Goal(create_goal(data["test"][0])) self.read_goal(goal) self.feat_handler.initFeatureHandeler(self.goal) @@ -156,15 +155,15 @@ 
class UserActionPolicy(Policy): else: self.goal = Goal(goal=data_goal) - # def new_goal(self, remove_domain="police", domain_len=None): - # keep_generate_goal = True - # while keep_generate_goal: - # self.goal = Goal(goal_generator=self.goal_gen) - # if (domain_len and len(self.goal.domains) != domain_len) or \ - # (remove_domain and remove_domain in self.goal.domains): - # keep_generate_goal = True - # else: - # keep_generate_goal = False + def new_goal(self, remove_domain="police", domain_len=None): + keep_generate_goal = True + while keep_generate_goal: + self.goal = Goal(goal_generator=self.goal_gen) + if (domain_len and len(self.goal.domains) != domain_len) or \ + (remove_domain and remove_domain in self.goal.domains): + keep_generate_goal = True + else: + keep_generate_goal = False def load(self, model_path=None): self.user.load_state_dict(torch.load(model_path, map_location=DEVICE)) @@ -404,32 +403,16 @@ class UserPolicy(Policy): self.config = json.load(open(config)) else: self.config = config - self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}/multiwoz' + self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}' if not os.path.exists(self.config["model_dir"]): # os.mkdir(self.config["model_dir"]) model_downloader(os.path.dirname(self.config["model_dir"]), "https://zenodo.org/record/5779832/files/default.zip") - self.slot2dbattr = { - 'open hours': 'openhours', - 'price range': 'pricerange', - 'arrive by': 'arriveBy', - 'leave at': 'leaveAt', - 'train id': 'trainID' - } - self.dbattr2slot = {} - for k,v in self.slot2dbattr.items(): - self.dbattr2slot[v] = k self.policy = UserActionPolicy(self.config) def predict(self, state): - raw_act = self.policy.predict(state) - act = [] - for intent, domain, slot, value in raw_act: - if slot in self.dbattr2slot: - slot = self.dbattr2slot[slot] - act.append([intent, domain, slot, value]) - return act + return self.policy.predict(state) def init_session(self, goal=None): self.policy.init_session(goal) @@ -441,6 +424,13 @@ class UserPolicy(Policy): return self.policy.get_reward() def get_goal(self): + slot2dbattr = { + 'open hours': 'openhours', + 'price range': 'pricerange', + 'arrive by': 'arriveBy', + 'leave at': 'leaveAt', + 'train id': 'trainID' + } if hasattr(self.policy, 'get_goal'): # workaround: convert goal to old format multiwoz_goal = {} @@ -459,8 +449,8 @@ class UserPolicy(Policy): multiwoz_goal[domain]["book"] = {} norm_slot = slot.split(' ')[-1] multiwoz_goal[domain]["book"][norm_slot] = value - elif slot in self.slot2dbattr: - norm_slot = self.slot2dbattr[slot] + elif slot in slot2dbattr: + norm_slot = slot2dbattr[slot] multiwoz_goal[domain][slot_type][norm_slot] = value else: multiwoz_goal[domain][slot_type][slot] = value diff --git a/convlab/policy/tus/unify/usermanager.py b/convlab/policy/tus/unify/usermanager.py index 3192d3d578ca3698424b16f47933d6863208f77c..640da7b9ee01414f04700e6aaa2976f9c916914f 100644 --- a/convlab/policy/tus/unify/usermanager.py +++ b/convlab/policy/tus/unify/usermanager.py @@ -97,7 +97,7 @@ class TUSDataManager(Dataset): action_list, user_goal, cur_state, usr_act) domain_label = self.feature_handler.domain_label( user_goal, usr_act) - # pre_state = user_goal.update(action=usr_act, char="user") # trick? 
+ pre_state = user_goal.update(action=usr_act, char="user") feature["id"].append(dialog["dialogue_id"]) feature["input"].append(input_feature) feature["mask"].append(mask) diff --git a/convlab/policy/tus/unify/util.py b/convlab/policy/tus/unify/util.py index d65f72a06e181e66bfe0d7ac0c60f0c03a56ad43..1978f489534f74839b11dee333e232cbc601f270 100644 --- a/convlab/policy/tus/unify/util.py +++ b/convlab/policy/tus/unify/util.py @@ -1,49 +1,9 @@ from convlab.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal -from convlab.util import load_dataset - import json NOT_MENTIONED = "not mentioned" -def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2ratio=1): - ratio = {'train': split2ratio, 'validation': split2ratio} - if data_name == "all" or data_name == "sgd+tm" or data_name == "tm": - print("merge all datasets...") - if data_name == "all": - all_dataset = ["multiwoz21", "sgd", "tm1", "tm2", "tm3"] - if data_name == "sgd+tm": - all_dataset = ["sgd", "tm1", "tm2", "tm3"] - if data_name == "tm": - all_dataset = ["tm1", "tm2", "tm3"] - - datasets = {} - for name in all_dataset: - datasets[name] = load_dataset( - name, - dial_ids_order=dial_ids_order, - split2ratio=ratio) - raw_data = merge_dataset(datasets, all_dataset[0]) - - else: - print(f"load single dataset {data_name}/{split2ratio}") - raw_data = load_dataset(data_name, - dial_ids_order=dial_ids_order, - split2ratio=ratio) - return raw_data - - -def merge_dataset(datasets, data_name): - data_split = [x for x in datasets[data_name]] - raw_data = {} - for data_type in data_split: - raw_data[data_type] = [] - for dataname, dataset in datasets.items(): - print(f"merge {dataname}...") - raw_data[data_type] += dataset[data_type] - return raw_data - - def int2onehot(index, output_dim=6, remove_zero=False): one_hot = [0] * output_dim if remove_zero: @@ -129,6 +89,50 @@ def get_booking_domain(slot, value, all_values, domain_list): return found +def act2slot(intent, domain, slot, value, all_values): + + if domain not in UsrDa2Goal: + # print(f"Not handle domain {domain}") + return "" + + if domain == "booking": + slot = SysDa2Goal[domain][slot] + domain = get_booking_domain(slot, value, all_values) + return f"{domain}-{slot}" + + elif domain in UsrDa2Goal: + if slot in SysDa2Goal[domain]: + slot = SysDa2Goal[domain][slot] + elif slot in UsrDa2Goal[domain]: + slot = UsrDa2Goal[domain][slot] + elif slot in SysDa2Goal["booking"]: + slot = SysDa2Goal["booking"][slot] + # else: + # print( + # f"UNSEEN ACTION IN GENERATE LABEL {intent, domain, slot, value}") + + return f"{domain}-{slot}" + + print("strange!!!") + print(intent, domain, slot, value) + + return "" + + +def get_user_history(dialog, all_values): + turn_num = len(dialog) + mentioned_slot = [] + for turn_id in range(0, turn_num, 2): + usr_act = parse_dialogue_act( + dialog[turn_id]["dialog_act"]) + for intent, domain, slot, value in usr_act: + slot_name = act2slot( + intent, domain.lower(), slot.lower(), value.lower(), all_values) + if slot_name not in mentioned_slot: + mentioned_slot.append(slot_name) + return mentioned_slot + + def update_config_file(file_name, attribute, value): with open(file_name, 'r') as config_file: config = json.load(config_file) @@ -143,7 +147,7 @@ def update_config_file(file_name, attribute, value): def create_goal(dialog) -> list: # a list of {'intent': ..., 'domain': ..., 'slot': ..., 'value': ...} dicts = [] - for turn in dialog['turns']: + for i, turn in enumerate(dialog['turns']): # print(turn['speaker']) # assert (i % 2 == 0) == 
(turn['speaker'] == 'user') # if i % 2 == 0: @@ -201,45 +205,6 @@ def split_slot_name(slot_name): return tokens[0], '-'.join(tokens[1:]) -# copy from data.unified_datasets.multiwoz21 -slot_name_map = { - 'addr': "address", - 'post': "postcode", - 'pricerange': "price range", - 'arrive': "arrive by", - 'arriveby': "arrive by", - 'leave': "leave at", - 'leaveat': "leave at", - 'depart': "departure", - 'dest': "destination", - 'fee': "entrance fee", - 'open': 'open hours', - 'car': "type", - 'car type': "type", - 'ticket': 'price', - 'trainid': 'train id', - 'id': 'train id', - 'people': 'book people', - 'stay': 'book stay', - 'none': '', - 'attraction': { - 'price': 'entrance fee' - }, - 'hospital': {}, - 'hotel': { - 'day': 'book day', 'price': "price range" - }, - 'restaurant': { - 'day': 'book day', 'time': 'book time', 'price': "price range" - }, - 'taxi': {}, - 'train': { - 'day': 'day', 'time': "duration" - }, - 'police': {}, - 'booking': {} -} - if __name__ == "__main__": print(split_slot_name("restaurant-search-location")) print(split_slot_name("sports-day.match")) diff --git a/convlab/policy/vector/vector_uncertainty.py b/convlab/policy/vector/vector_uncertainty.py deleted file mode 100644 index ee4582467925b76b2c1472f445e057808cfda9fc..0000000000000000000000000000000000000000 --- a/convlab/policy/vector/vector_uncertainty.py +++ /dev/null @@ -1,166 +0,0 @@ -# -*- coding: utf-8 -*- -import sys -import numpy as np -import logging - -from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da -from convlab.policy.vector.vector_binary import VectorBinary - - -class VectorUncertainty(VectorBinary): - """Vectorise state and state uncertainty predictions""" - - def __init__(self, - dataset_name: str = 'multiwoz21', - character: str = 'sys', - use_masking: bool = False, - manually_add_entity_names: bool = True, - seed: str = 0, - use_confidence_scores: bool = True, - confidence_thresholds: dict = None, - use_state_total_uncertainty: bool = False, - use_state_knowledge_uncertainty: bool = False): - """ - Args: - dataset_name: Name of environment dataset - character: Character of the agent (sys/usr) - use_masking: If true certain actions are masked during devectorisation - manually_add_entity_names: If true inform entity name actions are manually added - seed: Seed - use_confidence_scores: If true confidence scores are used in state vectorisation - confidence_thresholds: If true confidence thresholds are used in database querying - use_state_total_uncertainty: If true state entropy is added to the state vector - use_state_knowledge_uncertainty: If true state mutual information is added to the state vector - """ - - self.use_confidence_scores = use_confidence_scores - self.use_state_total_uncertainty = use_state_total_uncertainty - self.use_state_knowledge_uncertainty = use_state_knowledge_uncertainty - if confidence_thresholds is not None: - self.setup_uncertain_query(confidence_thresholds) - - super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed) - - def get_state_dim(self): - self.belief_state_dim = 0 - - for domain in self.ontology['state']: - for slot in self.ontology['state'][domain]: - # Dim 1 - indicator/confidence score - # Dim 2 - Entropy (Total uncertainty) / Mutual information (knowledge unc) - slot_dim = 1 if not self.use_state_total_uncertainty else 2 - slot_dim += 1 if self.use_state_knowledge_uncertainty else 0 - self.belief_state_dim += slot_dim - - self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \ - 
len(self.db_domains) + 6 * len(self.db_domains) + 1 - - # Add thresholds for db_queries - def setup_uncertain_query(self, confidence_thresholds): - self.use_confidence_scores = True - self.confidence_thresholds = confidence_thresholds - logging.info('DB Search uncertainty activated.') - - def dbquery_domain(self, domain): - """ - query entities of specified domain - Args: - domain string: - domain to query - Returns: - entities list: - list of entities of the specified domain - """ - # Get all user constraints - constraints = {slot: value for slot, value in self.state[domain].items() - if slot and value not in ['dontcare', - "do n't care", "do not care"]} if domain in self.state else dict() - - # Remove constraints for which the uncertainty is high - if self.confidence_scores is not None and self.use_confidence_scores and self.confidence_thresholds is not None: - # Collect threshold values for each domain-slot pair - threshold = self.confidence_thresholds.get(domain, dict()) - threshold = {slot: threshold.get(slot, 0.05) for slot in constraints} - # Get confidence scores for each constraint - probs = self.confidence_scores.get(domain, dict()) - probs = {slot: probs.get(slot, {}).get('inform', 1.0) for slot in constraints} - - # Filter out constraints for which confidence is lower than threshold - constraints = {slot: value for slot, value in constraints.items() if probs[slot] >= threshold[slot]} - - return self.db.query(domain, constraints.items(), topk=10) - - def vectorize_user_act(self, state): - """Return confidence scores for the user actions""" - self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None - action = state['user_action'] if self.character == 'sys' else state['system_action'] - opp_action = delexicalize_da(action, self.requestable) - opp_action = flat_da(opp_action) - opp_act_vec = np.zeros(self.da_opp_dim) - for da in opp_action: - if da in self.opp2vec: - if 'belief_state_probs' in state and self.use_confidence_scores: - domain, intent, slot, value = da.split('-') - if domain in state['belief_state_probs']: - slot = slot if slot else 'none' - if slot in state['belief_state_probs'][domain]: - prob = state['belief_state_probs'][domain][slot] - elif slot.lower() in state['belief_state_probs'][domain]: - prob = state['belief_state_probs'][domain][slot.lower()] - else: - prob = dict() - - if intent in prob: - prob = float(prob[intent]) - else: - prob = 1.0 - else: - prob = 1.0 - else: - prob = 1.0 - opp_act_vec[self.opp2vec[da]] = prob - - return opp_act_vec - - def vectorize_belief_state(self, state, domain_active_dict): - belief_state = np.zeros(self.belief_state_dim) - i = 0 - for domain in self.belief_domains: - if self.use_confidence_scores and 'belief_state_probs' in state: - for slot in state['belief_state'][domain]: - prob = None - if slot in state['belief_state_probs'][domain]: - prob = state['belief_state_probs'][domain][slot] - prob = prob['inform'] if 'inform' in prob else None - if prob: - belief_state[i] = float(prob) - i += 1 - else: - for slot, value in state['belief_state'][domain].items(): - if value and value != 'not mentioned': - belief_state[i] = 1. 
- i += 1 - - if 'active_domains' in state: - domain_active = state['active_domains'][domain] - domain_active_dict[domain] = domain_active - else: - if [slot for slot, value in state['belief_state'][domain].items() if value]: - domain_active_dict[domain] = True - - # Add knowledge and/or total uncertainty to the belief state - if self.use_state_total_uncertainty and 'entropy' in state: - for domain in self.belief_domains: - for slot in state['belief_state'][domain]: - if slot in state['entropy'][domain]: - belief_state[i] = float(state['entropy'][domain][slot]) - i += 1 - - if self.use_state_knowledge_uncertainty and 'mutual_information' in state: - for domain in self.belief_domains: - for slot in state['belief_state'][domain]['semi']: - if slot in state['mutual_information'][domain]: - belief_state[i] = float(state['mutual_information'][domain][slot]) - i += 1 - - return belief_state, domain_active_dict diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt new file mode 100644 index 0000000000000000000000000000000000000000..2f7e5e9cdf949b082b9ae752c09d2eafeb322bfb Binary files /dev/null and b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt differ diff --git a/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt new file mode 100644 index 0000000000000000000000000000000000000000..5c6193163b6a1fc788b959eff2d0153b3a4bf7cb Binary files /dev/null and b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt differ diff --git a/convlab/util/analysis_tool/analyzer.py b/convlab/util/analysis_tool/analyzer.py index 61a863ca49601344a5b3652d670ba40d217268ce..5163bee1e151b3d8855a481222d753a862f1f783 100755 --- a/convlab/util/analysis_tool/analyzer.py +++ b/convlab/util/analysis_tool/analyzer.py @@ -109,7 +109,6 @@ class Analyzer: step = 0 - print('Goal:', sess.evaluator.goal, file=flog) # print('init goal:',file=f) # # print(sess.evaluator.goal, file=f) # # pprint(sess.evaluator.goal) @@ -156,7 +155,7 @@ class Analyzer: if session_over: break - task_success = sess.evaluator.task_success(f=flog) + task_success = sess.evaluator.task_success() task_complete = sess.user_agent.policy.policy.goal.task_complete() book_rate = sess.evaluator.book_rate() stats = sess.evaluator.inform_F1() @@ -164,7 +163,6 @@ class Analyzer: if task_success: print('Dialogue succesfully completed!', file=flog) else: - print("> final goal:", sess.evaluator.goal, file=flog) print('Dialogue NOT completed succesfully!', file=flog) percentage = sess.evaluator.final_goal_analyze() diff --git a/convlab/util/custom_util.py b/convlab/util/custom_util.py index 80fa115a2b1d6bb46f8366c9cbfc6d0911f75849..aad6c4cd9f2898a57ae59c47c9cf1bd5629c8fec 100644 --- a/convlab/util/custom_util.py +++ b/convlab/util/custom_util.py @@ -25,7 +25,8 @@ import signal slot_mapping = {"pricerange": "price range", "post": "postcode", "arriveBy": "arrive by", "leaveAt": "leave at", - "Id": "train id", "ref": "reference", "trainID": "train id"} + "Id": "trainid", "ref": "reference"} + sys.path.append(os.path.dirname(os.path.dirname( os.path.dirname(os.path.abspath(__file__))))) @@ -102,8 +103,7 @@ def load_config_file(filepath: str = None) -> dict: def save_config(terminal_args, config_file_args, config_save_path, policy_config=None): config_save_path = 
os.path.join(config_save_path, f'config_saved.json') - args_dict = {"args": terminal_args, - "config": config_file_args, "policy_config": policy_config} + args_dict = {"args": terminal_args, "config": config_file_args, "policy_config": policy_config} json.dump(args_dict, open(config_save_path, 'w')) @@ -165,29 +165,26 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do goals = [] for seed in range(1000, 1000 + conf['model']['num_eval_dialogues']): set_seed(seed) - goal = create_goals(goal_generator, 1, - single_domain_goals, allowed_domains) + goal = create_goals(goal_generator, 1, single_domain_goals, allowed_domains) goals.append(goal[0]) if conf['model']['process_num'] == 1: complete_rate, success_rate, success_rate_strict, avg_return, turns, \ avg_actions, task_success, book_acts, inform_acts, request_acts, \ - select_acts, offer_acts, recommend_acts = evaluate(sess, - num_dialogues=conf['model']['num_eval_dialogues'], - sys_semantic_to_usr=conf['model'][ - 'sys_semantic_to_usr'], - save_flag=save_eval, save_path=log_save_path, goals=goals) - - total_acts = book_acts + inform_acts + request_acts + \ - select_acts + offer_acts + recommend_acts + select_acts, offer_acts, recommend_acts = evaluate(sess, + num_dialogues=conf['model']['num_eval_dialogues'], + sys_semantic_to_usr=conf['model'][ + 'sys_semantic_to_usr'], + save_flag=save_eval, save_path=log_save_path, goals=goals) + + total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts else: complete_rate, success_rate, success_rate_strict, avg_return, turns, \ avg_actions, task_success, book_acts, inform_acts, request_acts, \ select_acts, offer_acts, recommend_acts = \ evaluate_distributed(sess, list(range(1000, 1000 + conf['model']['num_eval_dialogues'])), conf['model']['process_num'], goals) - total_acts = book_acts + inform_acts + request_acts + \ - select_acts + offer_acts + recommend_acts + total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts task_success_gathered = {} for task_dict in task_success: @@ -199,18 +196,12 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do policy_sys.is_train = True - mean_complete, err_complete = np.average(complete_rate), np.std( - complete_rate) / np.sqrt(len(complete_rate)) - mean_success, err_success = np.average(success_rate), np.std( - success_rate) / np.sqrt(len(success_rate)) - mean_success_strict, err_success_strict = np.average(success_rate_strict), np.std( - success_rate_strict) / np.sqrt(len(success_rate_strict)) - mean_return, err_return = np.average(avg_return), np.std( - avg_return) / np.sqrt(len(avg_return)) - mean_turns, err_turns = np.average( - turns), np.std(turns) / np.sqrt(len(turns)) - mean_actions, err_actions = np.average(avg_actions), np.std( - avg_actions) / np.sqrt(len(avg_actions)) + mean_complete, err_complete = np.average(complete_rate), np.std(complete_rate) / np.sqrt(len(complete_rate)) + mean_success, err_success = np.average(success_rate), np.std(success_rate) / np.sqrt(len(success_rate)) + mean_success_strict, err_success_strict = np.average(success_rate_strict), np.std(success_rate_strict) / np.sqrt(len(success_rate_strict)) + mean_return, err_return = np.average(avg_return), np.std(avg_return) / np.sqrt(len(avg_return)) + mean_turns, err_turns = np.average(turns), np.std(turns) / np.sqrt(len(turns)) + mean_actions, err_actions = np.average(avg_actions), np.std(avg_actions) / np.sqrt(len(avg_actions)) 
logging.info(f"Complete: {mean_complete}+-{round(err_complete, 2)}, " f"Success: {mean_success}+-{round(err_success, 2)}, " @@ -381,9 +372,6 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False complete = sess.evaluator.complete task_succ = sess.evaluator.success task_succ_strict = sess.evaluator.success_strict - if not task_succ: - print("%s | %s %s | GOAL: %s" % - (seed - 1000, task_succ, complete, sess.evaluator.goal)) break else: complete = 0 @@ -428,11 +416,10 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False # save dialogue_info and clear mem return task_success['All_user_sim'], task_success['All_evaluator'], task_success['All_evaluator_strict'], \ - task_success['total_return'], task_success['turns'], task_success['avg_actions'], task_success, \ + task_success['total_return'], task_success['turns'], task_success['avg_actions'], task_success, \ np.average(task_success['total_booking_acts']), np.average(task_success['total_inform_acts']), \ np.average(task_success['total_request_acts']), np.average(task_success['total_select_acts']), \ - np.average(task_success['total_offer_acts']), np.average( - task_success['total_recommend_acts']) + np.average(task_success['total_offer_acts']), np.average(task_success['total_recommend_acts']) def model_downloader(download_dir, model_path): diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py index aa71884703ffc51be528db177f02e3f22e2ee89d..726079d1c2b04c304bfd7055d48b1b4ae4905856 100644 --- a/convlab/util/unified_datasets_util.py +++ b/convlab/util/unified_datasets_util.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from pprint import pprint from convlab.util.file_util import cached_path import shutil -# from sentence_transformers import SentenceTransformer, util +from sentence_transformers import SentenceTransformer, util import torch from tqdm import tqdm diff --git a/examples/agent_examples/test_heckmi_PPO.py b/examples/agent_examples/test_heckmi_PPO.py deleted file mode 100755 index 39b5a79f898a8be8178a9fbac504d674168a2cd9..0000000000000000000000000000000000000000 --- a/examples/agent_examples/test_heckmi_PPO.py +++ /dev/null @@ -1,53 +0,0 @@ - -from convlab.dst.rule.multiwoz import RuleDST -from convlab.policy.ppo import PPO -from convlab.policy.rule.multiwoz import RulePolicy - -from convlab.dialog_agent import PipelineAgent -from convlab.util.analysis_tool.analyzer import Analyzer - -import random -import numpy as np -from datetime import datetime - -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser - - -def set_seed(r_seed): - random.seed(r_seed) - np.random.seed(r_seed) - - -def test_end2end(seed=20200202, n_dialogues=1000): - - # Dialogue System - sys_nlu = None - sys_dst = RuleDST() - sys_policy = PPO(False, seed=seed) - sys_policy.load('/home/heckmi/gcp/supervised') - sys_nlg = None - sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') - - # User Simulator - user_nlu = None - user_dst = None - user_policy = RulePolicy(character='usr') - user_nlg = None - user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') - - analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') - - set_seed(seed) - now = datetime.now() - time = now.strftime("%Y%m%d%H%M%S") - name = f'PPO-Seed{seed}-{time}' - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues) - -if __name__ == '__main__': - # Get arguments - parser = 
ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--seed', help='Seed', default=20200202, type=int) - parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int) - args = parser.parse_args() - - test_end2end(seed=args.seed, n_dialogues=args.n_dialogues) diff --git a/examples/agent_examples/test_heckmi_PPO_semantic.py b/examples/agent_examples/test_heckmi_PPO_semantic.py deleted file mode 100755 index 3fe3f536673016d8383ac598ff35dbb9c97b219f..0000000000000000000000000000000000000000 --- a/examples/agent_examples/test_heckmi_PPO_semantic.py +++ /dev/null @@ -1,51 +0,0 @@ -# available NLU models - -from convlab2.dst.rule.multiwoz import RuleDST -from convlab2.policy.ppo import PPO -from convlab2.policy.rule.multiwoz import RulePolicy - -from convlab2.dialog_agent import PipelineAgent -from convlab2.util.analysis_tool.analyzer import Analyzer - -import random -import numpy as np - -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser - - -def set_seed(r_seed): - random.seed(r_seed) - np.random.seed(r_seed) - - -def test_end2end(seed=20200202, n_dialogues=1000): - - # Dialogue System, receiving user (simulator) output - sys_nlu = None - sys_dst = RuleDST() - sys_policy = PPO(False, seed=seed) - sys_policy.load('/home/heckmi/gcp/supervised') - sys_nlg = None - sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') - - # User Simulator, receiving system output - user_nlu = None - user_dst = None - user_policy = RulePolicy(character='usr') - user_nlg = None - user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') - - analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') - - set_seed(seed) - name=f'TripPy-PPO-Seed{seed}' - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues) - -if __name__ == '__main__': - # Get arguments - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--seed', help='Seed', default=20200202, type=int) - parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int) - args = parser.parse_args() - - test_end2end(seed=args.seed, n_dialogues=args.n_dialogues) diff --git a/examples/agent_examples/test_heckmi_SetSUMBT_PPO.py b/examples/agent_examples/test_heckmi_SetSUMBT_PPO.py deleted file mode 100755 index 0fd5232150e35833d26e780fdd0e3060f02bd0df..0000000000000000000000000000000000000000 --- a/examples/agent_examples/test_heckmi_SetSUMBT_PPO.py +++ /dev/null @@ -1,60 +0,0 @@ -# available NLU models - -from convlab2.nlu.jointBERT.multiwoz import BERTNLU -from convlab2.dst.setsumbt.unified_format_data.Tracker import SetSUMBTTracker -from convlab2.policy.ppo import PPO -from convlab2.policy.rule.multiwoz import RulePolicy -from convlab2.nlg.template.multiwoz import TemplateNLG - -from convlab2.dialog_agent import PipelineAgent -from convlab2.util.analysis_tool.analyzer import Analyzer - -import random -import numpy as np -from datetime import datetime - -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser - - -def set_seed(r_seed): - random.seed(r_seed) - np.random.seed(r_seed) - - -def test_end2end(seed=20200202, n_dialogues=1000): - - # Dialogue System, receiving user (simulator) output - sys_nlu = None - sys_dst = SetSUMBTTracker(model_type='roberta', - model_path="https://cloud.cs.uni-duesseldorf.de/s/Yqkzz8NW3yoMWRk/download/setsumbt_end.zip") - sys_policy = PPO(False, seed=seed) - 
sys_policy.load('/home/heckmi/gcp/supervised') - sys_nlg = TemplateNLG(is_user=False) - sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True) - - # User Simulator, receiving system output - user_nlu = None - #user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', - # model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip') - user_dst = None - user_policy = RulePolicy(character='usr') - #user_nlg = None - user_nlg = TemplateNLG(is_user=True) - user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') - - analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') - - set_seed(seed) - now = datetime.now() - time = now.strftime("%Y%m%d%H%M%S") - name = f'SetSUMBT-PPO-Rule-Seed{seed}-{time}' - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues) - -if __name__ == '__main__': - # Get arguments - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--seed', help='Seed', default=20200202, type=int) - parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int) - args = parser.parse_args() - - test_end2end(seed=args.seed, n_dialogues=args.n_dialogues) diff --git a/examples/agent_examples/test_heckmi_TripPyDummy_PPO.py b/examples/agent_examples/test_heckmi_TripPyDummy_PPO.py deleted file mode 100755 index 9be653273b1e96b4dd076614baa48532ca25752c..0000000000000000000000000000000000000000 --- a/examples/agent_examples/test_heckmi_TripPyDummy_PPO.py +++ /dev/null @@ -1,58 +0,0 @@ -# available NLU models - -from convlab.nlu.jointBERT.multiwoz import BERTNLU -from convlab.dst.trippy.multiwoz import TRIPPY -from convlab.policy.ppo import PPO -from convlab.policy.rule.multiwoz import RulePolicy -from convlab.nlg.template.multiwoz import TemplateNLG - -from convlab.dialog_agent import PipelineAgent -from convlab.util.analysis_tool.analyzer import Analyzer - -import random -import numpy as np -from datetime import datetime - -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser - - -def set_seed(r_seed): - random.seed(r_seed) - np.random.seed(r_seed) - - -def test_end2end(seed=20200202, n_dialogues=1000): - - # Dialogue System, receiving user (simulator) output - sys_nlu = None - sys_dst = TRIPPY(model_type='roberta', - model_path='/gpfs/project/heckmi/trippy/multiwoz21/trippy.convlab3/results.42/checkpoint-28392', - nlu_path='/home/heckmi/models/bert_multiwoz_all_context.zip') - sys_policy = PPO(False, seed=seed) - sys_policy.load('/home/heckmi/models/supervised') - sys_nlg = None - sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True) - - # User Simulator, receiving system output - user_nlu = None - user_dst = None - user_policy = RulePolicy(character='usr') - user_nlg = None - user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') - - analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') - - set_seed(seed) - now = datetime.now() - time = now.strftime("%Y%m%d%H%M%S") - name = f'TripPyDummy-PPO-Rule-Seed{seed}-{time}' - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues) - -if __name__ == '__main__': - # Get arguments - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--seed', help='Seed', default=20200202, type=int) - parser.add_argument('--n_dialogues', help='Number of eval 
dialogues', default=1000, type=int) - args = parser.parse_args() - - test_end2end(seed=args.seed, n_dialogues=args.n_dialogues) diff --git a/examples/agent_examples/test_heckmi_TripPyDummy_PPO2.py b/examples/agent_examples/test_heckmi_TripPyDummy_PPO2.py deleted file mode 100755 index 7a79225d4581d80abcb400bda1fbd94733d51c15..0000000000000000000000000000000000000000 --- a/examples/agent_examples/test_heckmi_TripPyDummy_PPO2.py +++ /dev/null @@ -1,58 +0,0 @@ -# available NLU models - -from convlab2.nlu.jointBERT.multiwoz import BERTNLU -from convlab2.dst.trippy.multiwoz import TRIPPY -from convlab2.policy.ppo import PPO -from convlab2.policy.rule.multiwoz import RulePolicy -from convlab2.nlg.template.multiwoz import TemplateNLG - -from convlab2.dialog_agent import PipelineAgent -from convlab2.util.analysis_tool.analyzer import Analyzer - -import random -import numpy as np -from datetime import datetime - -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser - - -def set_seed(r_seed): - random.seed(r_seed) - np.random.seed(r_seed) - - -def test_end2end(seed=20200202, n_dialogues=1000): - - # Dialogue System, receiving user (simulator) output - sys_nlu = None - sys_dst = TRIPPY(model_type='roberta', - model_path='/home/heckmi/zim/checkpoints/trippy', - nlu_path='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_all_context.zip') - sys_policy = PPO(False, seed=seed) - sys_policy.load('/home/heckmi/gcp/supervised') - sys_nlg = TemplateNLG(is_user=False) - sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True) - - # User Simulator, receiving system output - user_nlu = None - user_dst = None - user_policy = RulePolicy(character='usr') - user_nlg = TemplateNLG(is_user=True) - user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') - - analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') - - set_seed(seed) - now = datetime.now() - time = now.strftime("%Y%m%d%H%M%S") - name = f'TripPyDummy2-PPO-Rule-Seed{seed}-{time}' - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues) - -if __name__ == '__main__': - # Get arguments - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--seed', help='Seed', default=20200202, type=int) - parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int) - args = parser.parse_args() - - test_end2end(seed=args.seed, n_dialogues=args.n_dialogues) diff --git a/examples/agent_examples/test_heckmi_TripPyR_PPO.py b/examples/agent_examples/test_heckmi_TripPyR_PPO.py deleted file mode 100755 index 97b7ea1ebb7c0aba4a78dd9ddac9dc0ba20c436a..0000000000000000000000000000000000000000 --- a/examples/agent_examples/test_heckmi_TripPyR_PPO.py +++ /dev/null @@ -1,65 +0,0 @@ -# available NLU models - -from convlab2.nlu.jointBERT.multiwoz import BERTNLU -from convlab2.dst.trippyr.multiwoz import TRIPPYR -from convlab2.policy.ppo import PPO -from convlab2.policy.rule.multiwoz import RulePolicy -from convlab2.nlg.template.multiwoz import TemplateNLG - -from convlab2.dialog_agent import PipelineAgent -from convlab2.util.analysis_tool.analyzer import Analyzer - -import random -import numpy as np -from datetime import datetime - -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser - - -def set_seed(r_seed): - random.seed(r_seed) - np.random.seed(r_seed) - - -def test_end2end(seed=20200202, n_dialogues=1000): - - # Dialogue System, receiving user 
(simulator) output - sys_nlu = None - #sys_nlu = BERTNLU() - sys_dst = TRIPPYR(model_type='roberta', - model_path='/home/heckmi/zim/checkpoints/trippyr', - nlu_path='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_all_context.zip', - emb_path='/home/heckmi/zim/checkpoints/trippyr', - fp16=True) - sys_policy = PPO(False, seed=seed) - sys_policy.load('/home/heckmi/gcp/supervised') - #sys_nlg = None - sys_nlg = TemplateNLG(is_user=False) - sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True) - - # User Simulator, receiving system output - user_nlu = None - #user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', - # model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip') - user_dst = None - user_policy = RulePolicy(character='usr') - #user_nlg = None - user_nlg = TemplateNLG(is_user=True) - user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') - - analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') - - set_seed(seed) - now = datetime.now() - time = now.strftime("%Y%m%d%H%M%S") - name = f'TripPyR-PPO-Rule-Seed{seed}-{time}' - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues) - -if __name__ == '__main__': - # Get arguments - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--seed', help='Seed', default=20200202, type=int) - parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int) - args = parser.parse_args() - - test_end2end(seed=args.seed, n_dialogues=args.n_dialogues) diff --git a/examples/agent_examples/test_heckmi_TripPy_PPO.py b/examples/agent_examples/test_heckmi_TripPy_PPO.py deleted file mode 100755 index cf1bd546cb8ff8c354278e8be065556e7f18ac33..0000000000000000000000000000000000000000 --- a/examples/agent_examples/test_heckmi_TripPy_PPO.py +++ /dev/null @@ -1,66 +0,0 @@ -# available NLU models - -from convlab.nlu.jointBERT.multiwoz import BERTNLU -from convlab.dst.trippy.multiwoz import TRIPPY -from convlab.policy.ppo import PPO -from convlab.policy.rule.multiwoz import RulePolicy -from convlab.nlg.template.multiwoz import TemplateNLG - -from convlab.dialog_agent import PipelineAgent -from convlab.util.analysis_tool.analyzer import Analyzer - -import random -import numpy as np -from datetime import datetime - -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser - - -def set_seed(r_seed): - random.seed(r_seed) - np.random.seed(r_seed) - - -def test_end2end(seed=20200202, n_dialogues=1000): - - # Dialogue System, receiving user (simulator) output - sys_nlu = None - #sys_nlu = BERTNLU() - #sys_dst = TRIPPY(model_type='roberta', - # model_path='/home/heckmi/gcp/roberta_corrected/20200805-155426803/results.42', - # nlu_path='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_all_context.zip') - sys_dst = TRIPPY(model_type='roberta', - model_path='/home/heckmi/zim/checkpoints/trippy', - nlu_path='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_all_context.zip') - sys_policy = PPO(False, seed=seed) - sys_policy.load('/home/heckmi/gcp/supervised') - #sys_nlg = None - sys_nlg = TemplateNLG(is_user=False) - sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True) - - # User Simulator, receiving system output - user_nlu = None - #user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', - # 
model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip') - user_dst = None - user_policy = RulePolicy(character='usr') - #user_nlg = None - user_nlg = TemplateNLG(is_user=True) - user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') - - analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') - - set_seed(seed) - now = datetime.now() - time = now.strftime("%Y%m%d%H%M%S") - name = f'TripPy-PPO-Rule-Seed{seed}-{time}' - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues) - -if __name__ == '__main__': - # Get arguments - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--seed', help='Seed', default=20200202, type=int) - parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int) - args = parser.parse_args() - - test_end2end(seed=args.seed, n_dialogues=args.n_dialogues) diff --git a/examples/agent_examples/test_heckmi_TripPy_PPO_semantic.py b/examples/agent_examples/test_heckmi_TripPy_PPO_semantic.py deleted file mode 100755 index 3fe3f536673016d8383ac598ff35dbb9c97b219f..0000000000000000000000000000000000000000 --- a/examples/agent_examples/test_heckmi_TripPy_PPO_semantic.py +++ /dev/null @@ -1,51 +0,0 @@ -# available NLU models - -from convlab2.dst.rule.multiwoz import RuleDST -from convlab2.policy.ppo import PPO -from convlab2.policy.rule.multiwoz import RulePolicy - -from convlab2.dialog_agent import PipelineAgent -from convlab2.util.analysis_tool.analyzer import Analyzer - -import random -import numpy as np - -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser - - -def set_seed(r_seed): - random.seed(r_seed) - np.random.seed(r_seed) - - -def test_end2end(seed=20200202, n_dialogues=1000): - - # Dialogue System, receiving user (simulator) output - sys_nlu = None - sys_dst = RuleDST() - sys_policy = PPO(False, seed=seed) - sys_policy.load('/home/heckmi/gcp/supervised') - sys_nlg = None - sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') - - # User Simulator, receiving system output - user_nlu = None - user_dst = None - user_policy = RulePolicy(character='usr') - user_nlg = None - user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') - - analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') - - set_seed(seed) - name=f'TripPy-PPO-Seed{seed}' - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name, total_dialog=n_dialogues) - -if __name__ == '__main__': - # Get arguments - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument('--seed', help='Seed', default=20200202, type=int) - parser.add_argument('--n_dialogues', help='Number of eval dialogues', default=1000, type=int) - args = parser.parse_args() - - test_end2end(seed=args.seed, n_dialogues=args.n_dialogues)
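
The test_heckmi_* scripts removed above all instantiated the same end-to-end evaluation pipeline, differing only in the tracker and in hard-coded personal checkpoint paths (/home/heckmi/..., /gpfs/project/...). For reference, below is a minimal, path-agnostic sketch of that shared pattern at the semantic level (rule DST + PPO system against an agenda-based user simulator), using only the convlab APIs that appear in the deleted files; the --policy_path argument is an assumption standing in for the hard-coded checkpoints and must point at a real PPO save.

import random
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from datetime import datetime

import numpy as np

from convlab.dialog_agent import PipelineAgent
from convlab.dst.rule.multiwoz import RuleDST
from convlab.policy.ppo import PPO
from convlab.policy.rule.multiwoz import RulePolicy
from convlab.util.analysis_tool.analyzer import Analyzer


def set_seed(r_seed):
    random.seed(r_seed)
    np.random.seed(r_seed)


def test_end2end(policy_path, seed=20200202, n_dialogues=1000):
    # Dialogue system: rule-based tracker + PPO policy, no NLU/NLG,
    # so system and user exchange dialogue acts directly.
    sys_policy = PPO(False, seed=seed)
    sys_policy.load(policy_path)  # assumption: a valid PPO checkpoint path
    sys_agent = PipelineAgent(None, RuleDST(), sys_policy, None, name='sys')

    # User simulator: agenda-based rule policy on the semantic level.
    user_policy = RulePolicy(character='usr')
    user_agent = PipelineAgent(None, None, user_policy, None, name='user')

    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')

    set_seed(seed)
    name = f"PPO-Seed{seed}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=name,
                                   total_dialog=n_dialogues)


if __name__ == '__main__':
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--policy_path', help='PPO checkpoint to load',
                        required=True, type=str)
    parser.add_argument('--seed', help='Seed', default=20200202, type=int)
    parser.add_argument('--n_dialogues', help='Number of eval dialogues',
                        default=1000, type=int)
    args = parser.parse_args()

    test_end2end(args.policy_path, seed=args.seed, n_dialogues=args.n_dialogues)

Swapping RuleDST for a trained tracker (as the deleted SetSUMBT and TripPy variants did) only changes the sys_dst argument of PipelineAgent plus, for word-level evaluation, the two TemplateNLG instances; the surrounding Analyzer harness stays identical.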