diff --git a/.gitignore b/.gitignore index 364e64ff7f4e3e260b5ffd3f8843519d251edbff..6012819da84535f1108916153f3116b3a609ba07 100644 --- a/.gitignore +++ b/.gitignore @@ -66,7 +66,8 @@ convlab/nlu/jointBERT_new/**/output/ convlab/nlu/milu/09* convlab/nlu/jointBERT/multiwoz/configs/multiwoz_new_usr_context.json convlab/nlu/milu/multiwoz/configs/system_without_context.jsonnet -convlab/nlu/milu/multiwoz/configs/user_without_context.jsonnet +convlab/nlu/milu/multiwoz/configs/user_without_context.jsonnet\ +*.pkl # test script *_test.py @@ -87,7 +88,6 @@ dist convlab.egg-info # configs - *experiment* *pretrained_models* .ipynb_checkpoints diff --git a/convlab/dialog_agent/agent.py b/convlab/dialog_agent/agent.py index 0815fad0abe59fb72c0c0a3bfb59b0f6ac140d67..79f61e2b18f04e702c690a33cdf313656b2615c6 100755 --- a/convlab/dialog_agent/agent.py +++ b/convlab/dialog_agent/agent.py @@ -7,6 +7,7 @@ from convlab.policy import Policy from convlab.nlg import NLG from copy import deepcopy import time +import pdb from pprint import pprint @@ -63,7 +64,7 @@ class PipelineAgent(Agent): ===== ===== ====== === == === """ - def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str, return_semantic_acts=False): + def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str): """The constructor of PipelineAgent class. Here are some special combination cases: @@ -94,7 +95,7 @@ class PipelineAgent(Agent): self.dst = dst self.policy = policy self.nlg = nlg - self.return_semantic_acts = return_semantic_acts + self.init_session() self.agent_saves = [] self.history = [] @@ -151,6 +152,7 @@ class PipelineAgent(Agent): self.input_action = self.nlu.predict( observation, context=[x[1] for x in self.history[:-1]]) + # print("system semantic action: ", self.input_action) else: self.input_action = observation self.input_action_eval = observation @@ -186,7 +188,7 @@ class PipelineAgent(Agent): if type(self.output_action) == list: for intent, domain, slot, value in self.output_action: - if intent == "book": + if intent.lower() == "book": self.dst.state['booked'][domain] = [{slot: value}] else: self.dst.state['user_action'] = self.output_action @@ -196,8 +198,6 @@ class PipelineAgent(Agent): self.history.append([self.name, model_response]) self.turn += 1 - if self.return_semantic_acts: - return self.output_action self.agent_saves.append(self.save_info()) return model_response diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py index 508d06b56080f72146d65a98286a0fd099f9b4c2..a8915301153f8e824e0a2ed91cd6eb9cd34e2605 100755 --- a/convlab/dialog_agent/env.py +++ b/convlab/dialog_agent/env.py @@ -49,6 +49,9 @@ class Environment(): dialog_act = self.sys_nlu.predict( observation) if self.sys_nlu else observation self.sys_dst.state['user_action'] = dialog_act + self.sys_dst.state['history'].append(["sys", model_response]) + self.sys_dst.state['history'].append(["user", observation]) + state = self.sys_dst.update(dialog_act) self.sys_dst.state['history'].append(["sys", model_response]) self.sys_dst.state['history'].append(["usr", observation]) diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py index ea099442ddd18d0cd36a79db13b1f47788eb4fd4..d250f29e0186732b27fb284273d9f9ff4f166d2f 100644 --- a/convlab/dst/setsumbt/do/nbt.py +++ b/convlab/dst/setsumbt/do/nbt.py @@ -20,6 +20,7 @@ import os from shutil import copy2 as copy import json from copy import deepcopy +import pdb import torch import transformers @@ -34,6 +35,7 @@ from convlab.dst.setsumbt.modeling import training from convlab.dst.setsumbt.dataset import ontology as embeddings from convlab.dst.setsumbt.utils import get_args, update_args from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble +from convlab.util.custom_util import model_downloader # Available model @@ -55,6 +57,23 @@ def main(args=None, config=None): # Set up output directory OUTPUT_DIR = args.output_dir + + # Download model if needed + if not os.path.exists(OUTPUT_DIR): + # Get path /.../convlab/dst/setsumbt/multiwoz/models + download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + download_path = os.path.join(download_path, 'models') + if not os.path.exists(download_path): + os.mkdir(download_path) + model_downloader(download_path, OUTPUT_DIR) + # Downloadable model path format http://.../model_name.zip + OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '') + OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR) + + args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1]) + args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1]) + os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders')) + if not os.path.exists(OUTPUT_DIR): os.makedirs(OUTPUT_DIR) os.mkdir(os.path.join(OUTPUT_DIR, 'database')) diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py index 6b620247fd4a36223fbed8c46c54615f7c69da98..eca7f1749369f9569d6b923312a93cd317e0701c 100644 --- a/convlab/dst/setsumbt/tracker.py +++ b/convlab/dst/setsumbt/tracker.py @@ -61,8 +61,8 @@ class SetSUMBTTracker(DST): if not os.path.exists(download_path): os.mkdir(download_path) model_downloader(download_path, self.model_path) - # Downloadable model path format http://.../setsumbt_model_name.zip - self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '') + # Downloadable model path format http://.../model_name.zip + self.model_path = self.model_path.split('/')[-1].replace('.zip', '') self.model_path = os.path.join(download_path, self.model_path) # Select model type based on the encoder diff --git a/convlab/dst/trippy/README.md b/convlab/dst/trippy/README.md new file mode 100644 index 0000000000000000000000000000000000000000..178c578287152c894c8727156a8db6e3c255b549 --- /dev/null +++ b/convlab/dst/trippy/README.md @@ -0,0 +1,59 @@ +# Introduction + +This is the TripPy DST module for ConvLab-3. + +## Supported encoders + +* RoBERTa +* BERT (full support w.i.p.) +* ELECTRA (full support w.i.p.) + +## Supported datasets + +* MultiWOZ 2.X +* Unified Data Format + +## Requirements + +transformers (tested: 4.18.0) +torch (tested: 1.8.0) + +# Parameters + +``` +model_type # Default: "roberta", Type of the model (Supported: "roberta", "bert", "electra") +model_name # Default: "roberta-base", Name of the model (Use -h to print a list of names) +model_path # Path to a model checkpoint +dataset_name # Default: "multiwoz21", Name of the dataset the model was trained on and/or is being applied to +local_files_only # Default: False, Set to True to load local files only. Useful for offline systems +nlu_usr_config # Path to a NLU config file. Only needed for internal evaluation +nlu_sys_config # Path to a NLU config file. Only needed for internal evaluation +nlu_usr_path # Path to a NLU model file. Only needed for internal evaluation +nlu_sys_path # Path to a NLU model file. Only needed for internal evaluation +no_eval # Default: True, Set to True if internal evaluation should be conducted +no_history # Default: False, Set to True if dialogue history should be omitted during inference +``` + +# Training + +TripPy can easily be trained for the abovementioned supported datasets using the original code in the official [TripPy repository](https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public). Simply clone the code and run the appropriate DO.* script to train a TripPy DST. After training, set model_path to the preferred checkpoint to use TripPy in ConvLab-3. + +# Training and evaluation with PPO policy + +Switch to the directory: +``` +cd ../../policy/ppo +``` + +Edit trippy_config.json and trippy_config_eval.json accordingly, e.g., edit paths to model checkpoints. + +For training, run +``` +train.py --path trippy_config.json +``` + +For evaluation, set training epochs to 0. + +# Paper + +[TripPy: A Triple Copy Strategy for Value Independent Neural Dialog State Tracking](https://aclanthology.org/2020.sigdial-1.4/) diff --git a/convlab/dst/trippy/__init__.py b/convlab/dst/trippy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb36114aeabe0859fdae1c4230c147a7a40305ff --- /dev/null +++ b/convlab/dst/trippy/__init__.py @@ -0,0 +1 @@ +from convlab.dst.trippy.tracker import TRIPPY diff --git a/convlab/dst/trippy/dataset_interfacer.py b/convlab/dst/trippy/dataset_interfacer.py new file mode 100644 index 0000000000000000000000000000000000000000..51a692bbaa50349a70053f05933a0cea1f8246c3 --- /dev/null +++ b/convlab/dst/trippy/dataset_interfacer.py @@ -0,0 +1,174 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import logging + + +class DatasetInterfacer(object): + _domain_map_trippy_to_udf = {} + _slot_map_trippy_to_udf = {} + _generic_referral = {} + + def __init__(self): + pass + + def map_trippy_to_udf(self, domain, slot): + d = self._domain_map_trippy_to_udf.get(domain, domain) + s = slot + if d in self._slot_map_trippy_to_udf: + s = self._slot_map_trippy_to_udf[d].get(slot, slot) + return d, s + + def get_generic_referral(self, domain, slot): + d, s = self.map_trippy_to_udf(domain, slot) + ref = "the %s %s" % (d, s) + if d in self._generic_referral: + ref = self._generic_referral[d].get(s, s) + return ref + + def normalize_values(self, text): + return text + + def normalize_text(self, text): + return text + + def normalize_prediction(self, domain, slot, value, predictions=None, config=None): + return value + + +class MultiwozInterfacer(DatasetInterfacer): + _slot_map_trippy_to_udf = { + 'hotel': { + 'pricerange': 'price range', + 'book_stay': 'book stay', + 'book_day': 'book day', + 'book_people': 'book people', + 'addr': 'address', + 'post': 'postcode', + 'price': 'price range', + 'people': 'book people' + }, + 'restaurant': { + 'pricerange': 'price range', + 'book_time': 'book time', + 'book_day': 'book day', + 'book_people': 'book people', + 'addr': 'address', + 'post': 'postcode', + 'price': 'price range', + 'people': 'book people' + }, + 'taxi': { + 'arriveBy': 'arrive by', + 'leaveAt': 'leave at', + 'arrive': 'arrive by', + 'leave': 'leave at', + 'car': 'type', + 'car type': 'type', + 'depart': 'departure', + 'dest': 'destination' + }, + 'train': { + 'arriveBy': 'arrive by', + 'leaveAt': 'leave at', + 'book_people': 'book people', + 'arrive': 'arrive by', + 'leave': 'leave at', + 'depart': 'departure', + 'dest': 'destination', + 'id': 'train id', + 'people': 'book people', + 'time': 'duration', + 'ticket': 'price', + 'trainid': 'train id' + }, + 'attraction': { + 'post': 'postcode', + 'addr': 'address', + 'fee': 'entrance fee', + 'price': 'entrance fee' + }, + 'general': {}, + 'hospital': { + 'post': 'postcode', + 'addr': 'address' + }, + 'police': { + 'post': 'postcode', + 'addr': 'address' + } + } + + _generic_referral = { + 'hotel': { + 'name': 'the hotel', + 'area': 'same area as the hotel', + 'price range': 'in the same price range as the hotel' + }, + 'restaurant': { + 'name': 'the restaurant', + 'area': 'same area as the restaurant', + 'price range': 'in the same price range as the restaurant' + }, + 'attraction': { + 'name': 'the attraction', + 'area': 'same area as the attraction' + } + } + + def normalize_values(self, text): + text = text.lower() + text_to_num = {"zero": "0", "one": "1", "me": "1", "two": "2", "three": "3", "four": "4", "five": "5", "six": "6", "seven": "7"} + text = re.sub("\s*(\W)\s*", r"\1" , text) # Re-attach special characters + text = re.sub("s'([^s])", r"s' \1", text) # Add space after plural genitive apostrophe + if text in text_to_num: + text = text_to_num[text] + return text + + def normalize_text(self, text): + norm_text = text.lower() + #norm_text = re.sub("n't", " not", norm_text) # Does not make much of a difference + norm_text = ' '.join([tok for tok in map(str.strip, re.split("(\W+)", norm_text)) if len(tok) > 0]) + return norm_text + + def normalize_prediction(self, domain, slot, value, predictions=None, class_predictions=None, config=None): + v = value + if domain == 'hotel' and slot == 'type': + # Map Boolean predictions to regular predictions. + v = "hotel" if value == "yes" else value + v = "guesthouse" if value == "no" else value + # HOTFIX: Avoid overprediction of hotel type caused by ambiguous rule based user simulator NLG. + if predictions['hotel-name'] != 'none': + v = 'none' + if config.dst_class_types[class_predictions['hotel-none']] == 'request': + v = 'none' + return v + + +DATASET_INTERFACERS = { + 'multiwoz21': MultiwozInterfacer() +} + + +def create_dataset_interfacer(dataset_name="multiwoz21"): + if dataset_name in DATASET_INTERFACERS: + return DATASET_INTERFACERS[dataset_name] + else: + logging.warn("You attempt to create a dataset interfacer for an unknown dataset '%s'. Creating generic dataset interfacer." % (dataset_name)) + return DatasetInterfacer() + + diff --git a/convlab/dst/trippy/modeling_dst.py b/convlab/dst/trippy/modeling_dst.py new file mode 100644 index 0000000000000000000000000000000000000000..2828d17ed1e97ebb60b74ab999c28efb1e7bfa88 --- /dev/null +++ b/convlab/dst/trippy/modeling_dst.py @@ -0,0 +1,218 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers import (BertModel, BertPreTrainedModel, + RobertaModel, RobertaPreTrainedModel, + ElectraModel, ElectraPreTrainedModel) + +PARENT_CLASSES = { + 'bert': BertPreTrainedModel, + 'roberta': RobertaPreTrainedModel, + 'electra': ElectraPreTrainedModel +} + +MODEL_CLASSES = { + BertPreTrainedModel: BertModel, + RobertaPreTrainedModel: RobertaModel, + ElectraPreTrainedModel: ElectraModel +} + + +class ElectraPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +def TransformerForDST(parent_name): + if parent_name not in PARENT_CLASSES: + raise ValueError("Unknown model %s" % (parent_name)) + + class TransformerForDST(PARENT_CLASSES[parent_name]): + def __init__(self, config): + assert config.model_type in PARENT_CLASSES + assert self.__class__.__bases__[0] in MODEL_CLASSES + super(TransformerForDST, self).__init__(config) + self.model_type = config.model_type + self.slot_list = config.dst_slot_list + self.class_types = config.dst_class_types + self.class_labels = config.dst_class_labels + self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable + self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable + self.stack_token_logits = config.dst_stack_token_logits + self.class_aux_feats_inform = config.dst_class_aux_feats_inform + self.class_aux_feats_ds = config.dst_class_aux_feats_ds + self.class_loss_ratio = config.dst_class_loss_ratio + + # Only use refer loss if refer class is present in dataset. + if 'refer' in self.class_types: + self.refer_index = self.class_types.index('refer') + else: + self.refer_index = -1 + + # Make sure this module has the same name as in the pretrained checkpoint you want to load! + self.add_module(self.model_type, MODEL_CLASSES[self.__class__.__bases__[0]](config)) + if self.model_type == "electra": + self.pooler = ElectraPooler(config) + + self.dropout = nn.Dropout(config.dst_dropout_rate) + self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) + + if self.class_aux_feats_inform: + self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) + if self.class_aux_feats_ds: + self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) + + aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2 + + for slot in self.slot_list: + self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels)) + self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2)) + self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) + + self.init_weights() + + def forward(self, + input_ids, + input_mask=None, + segment_ids=None, + position_ids=None, + head_mask=None, + start_pos=None, + end_pos=None, + inform_slot_id=None, + refer_id=None, + class_label_id=None, + diag_state=None): + outputs = getattr(self, self.model_type)( + input_ids, + attention_mask=input_mask, + token_type_ids=segment_ids, + position_ids=position_ids, + head_mask=head_mask + ) + + sequence_output = outputs[0] + if self.model_type == "electra": + pooled_output = self.pooler(sequence_output) + else: + pooled_output = outputs[1] + + sequence_output = self.dropout(sequence_output) + pooled_output = self.dropout(pooled_output) + + if inform_slot_id is not None: + inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() + if diag_state is not None: + diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) + + total_loss = 0 + per_slot_per_example_loss = {} + per_slot_class_logits = {} + per_slot_start_logits = {} + per_slot_end_logits = {} + per_slot_refer_logits = {} + for slot in self.slot_list: + if self.class_aux_feats_inform and self.class_aux_feats_ds: + pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) + elif self.class_aux_feats_inform: + pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) + elif self.class_aux_feats_ds: + pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) + else: + pooled_output_aux = pooled_output + class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux)) + + token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output)) + start_logits, end_logits = token_logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux)) + + per_slot_class_logits[slot] = class_logits + per_slot_start_logits[slot] = start_logits + per_slot_end_logits[slot] = end_logits + per_slot_refer_logits[slot] = refer_logits + + # If there are no labels, don't compute loss + if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: + # If we are on multi-GPU, split add a dimension + if len(start_pos[slot].size()) > 1: + start_pos[slot] = start_pos[slot].squeeze(-1) + if len(end_pos[slot].size()) > 1: + end_pos[slot] = end_pos[slot].squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) # This is a single index + start_pos[slot].clamp_(0, ignored_index) + end_pos[slot].clamp_(0, ignored_index) + + class_loss_fct = CrossEntropyLoss(reduction='none') + token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index) + refer_loss_fct = CrossEntropyLoss(reduction='none') + + if not self.stack_token_logits: + start_loss = token_loss_fct(start_logits, start_pos[slot]) + end_loss = token_loss_fct(end_logits, end_pos[slot]) + else: + start_loss = token_loss_fct(torch.cat((start_logits, end_logits), 1), start_pos[slot]) + end_loss = token_loss_fct(torch.cat((end_logits, start_logits), 1), end_pos[slot]) + + token_loss = (start_loss + end_loss) / 2.0 + + token_is_pointable = (start_pos[slot] > 0).float() + if not self.token_loss_for_nonpointable: + token_loss *= token_is_pointable + + refer_loss = refer_loss_fct(refer_logits, refer_id[slot]) + token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() + if not self.refer_loss_for_nonpointable: + refer_loss *= token_is_referrable + + class_loss = class_loss_fct(class_logits, class_label_id[slot]) + + if self.refer_index > -1: + per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + else: + per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss + + total_loss += per_example_loss.sum() + per_slot_per_example_loss[slot] = per_example_loss + + # add hidden states and attention if they are here + outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:] + + return outputs + + return TransformerForDST diff --git a/convlab/dst/trippy/tracker.py b/convlab/dst/trippy/tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..b0470266b2ce2c7d5cfba29d0270972b0c7cfa78 --- /dev/null +++ b/convlab/dst/trippy/tracker.py @@ -0,0 +1,544 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import json +import copy +import logging + +import torch +from transformers import (BertConfig, BertTokenizer, + RobertaConfig, RobertaTokenizer, + ElectraConfig, ElectraTokenizer) + +from convlab.dst.dst import DST +from convlab.dst.trippy.modeling_dst import (TransformerForDST) +from convlab.dst.trippy.dataset_interfacer import (create_dataset_interfacer) +from convlab.util import relative_import_module_from_unified_datasets + +MODEL_CLASSES = { + 'bert': (BertConfig, TransformerForDST('bert'), BertTokenizer), + 'roberta': (RobertaConfig, TransformerForDST('roberta'), RobertaTokenizer), + 'electra': (ElectraConfig, TransformerForDST('electra'), ElectraTokenizer), +} + + +class TRIPPY(DST): + def print_header(self): + logging.info(" _________ ________ ___ ________ ________ ___ ___ ") + logging.info("|\___ ___\\\ __ \|\ \|\ __ \|\ __ \|\ \ / /|") + logging.info("\|___ \ \_\ \ \|\ \ \ \ \ \|\ \ \ \|\ \ \ \/ / /") + logging.info(" \ \ \ \ \ _ _\ \ \ \ ____\ \ ____\ \ / / ") + logging.info(" \ \ \ \ \ \\\ \\\ \ \ \ \___|\ \ \___|\/ / / ") + logging.info(" \ \__\ \ \__\\\ _\\\ \__\ \__\ \ \__\ __/ / / ") + logging.info(" \|__| \|__|\|__|\|__|\|__| \|__||\___/ / ") + logging.info(" (c) 2022 Heinrich Heine University \|___|/ ") + logging.info("") + + def print_dialog(self, hst): + logging.info("Dialogue %s, turn %s:" % (self.global_diag_cnt, self.global_turn_cnt)) + for utt in hst[:-2]: + logging.info(" \033[92m%s\033[0m" % (utt)) + if len(hst) > 1: + logging.info(" %s" % (hst[-2])) + logging.info(" %s" % (hst[-1])) + + def print_inform_memory(self, inform_mem): + logging.info("Inform memory:") + is_all_none = True + for s in inform_mem: + if inform_mem[s] != 'none': + logging.info(" %s = %s" % (s, inform_mem[s])) + is_all_none = False + if is_all_none: + logging.info(" -") + + def eval_user_acts(self, user_act, user_acts): + logging.info("User acts:") + for ua in user_acts: + if ua not in user_act: + logging.info(" \033[33m%s\033[0m" % (ua)) + else: + logging.info(" \033[92m%s\033[0m" % (ua)) + for ua in user_act: + if ua not in user_acts: + logging.info(" \033[91m%s\033[0m" % (ua)) + + def eval_dialog_state(self, state_updates, new_belief_state): + logging.info("Dialogue state:") + for d in self.gt_belief_state: + logging.info(" %s:" % (d)) + for s in new_belief_state[d]: + is_printed = False + is_updated = False + if state_updates[d][s] > 0: + is_updated = True + log_str = "" + if is_updated: + log_str += "\033[3m" + if new_belief_state[d][s] != self.gt_belief_state[d][s]: + self.global_eval_stats[d][s]['FP'] += 1 + if self.gt_belief_state[d][s] == '': + log_str += " \033[33m%s: %s\033[0m" % (s, new_belief_state[d][s]) + else: + log_str += " \033[91m%s: %s\033[0m (label: %s)" % (s, new_belief_state[d][s] if new_belief_state[d][s] != '' else 'none', self.gt_belief_state[d][s]) + self.global_eval_stats[d][s]['FN'] += 1 + is_printed = True + elif new_belief_state[d][s] != '': + log_str += " \033[92m%s: %s\033[0m" % (s, new_belief_state[d][s]) + self.global_eval_stats[d][s]['TP'] += 1 + is_printed = True + if is_updated: + log_str += " (%s)" % (self.config.dst_class_types[state_updates[d][s]]) + logging.info(log_str) + elif is_printed: + logging.info(log_str) + + def eval_print_stats(self): + logging.info("Statistics:") + for d in self.global_eval_stats: + for s in self.global_eval_stats[d]: + TP = self.global_eval_stats[d][s]['TP'] + FP = self.global_eval_stats[d][s]['FP'] + FN = self.global_eval_stats[d][s]['FN'] + prec = TP / ( TP + FP + 1e-8) + rec = TP / ( TP + FN + 1e-8) + f1 = 2 * ((prec * rec) / (prec + rec + 1e-8)) + logging.info(" %s %s Recall: %.2f, Precision: %.2f, F1: %.2f" % (d, s, rec, prec, f1)) + + def __init__(self, model_type="roberta", + model_name="roberta-base", + model_path="", + dataset_name="multiwoz21", + local_files_only=False, + nlu_usr_config="", + nlu_sys_config="", + nlu_usr_path="", + nlu_sys_path="", + no_eval=True, + no_history=False): + super(TRIPPY, self).__init__() + + self.print_header() + + self.model_type = model_type.lower() + self.model_name = model_name.lower() + self.model_path = model_path + self.local_files_only = local_files_only + self.nlu_usr_config = nlu_usr_config + self.nlu_sys_config = nlu_sys_config + self.nlu_usr_path = nlu_usr_path + self.nlu_sys_path = nlu_sys_path + self.dataset_name = dataset_name + self.no_eval = no_eval + self.no_history = no_history + + assert self.model_type in ['roberta'] # TODO: ensure proper behavior for 'bert', 'electra' + assert self.dataset_name in ['multiwoz21', 'multiwoz22', 'multiwoz23'] # TODO: ensure proper behavior for other datasets + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + _ontology = relative_import_module_from_unified_datasets(self.dataset_name, 'preprocess.py', 'ontology') + self.template_state = _ontology['state'] + + self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type] + self.config = self.config_class.from_pretrained(self.model_path, local_files_only=self.local_files_only) + + self.dataset_interfacer = create_dataset_interfacer(dataset_name) + + # For internal evaluation only + self.nlu_usr = None + self.nlu_sys = None + self.global_eval_stats = copy.deepcopy(self.template_state) + for d in self.global_eval_stats: + for s in self.global_eval_stats[d]: + self.global_eval_stats[d][s] = {'TP': 0, 'FP': 0, 'FN': 0} + self.global_diag_cnt = -3 + self.global_turn_cnt = -1 + if not self.no_eval: + global BERTNLU + from convlab.nlu.jointBERT.unified_datasets import BERTNLU + self.load_nlu() + + # For semantic action pipelines only + self.nlg_usr = None + self.nlg_sys = None + + logging.info("DST INIT VARS: %s" % (vars(self))) + + self.load_weights() + + def load_weights(self): + self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name, local_files_only=self.local_files_only) # TODO: do_lower_case? + self.model = self.model_class.from_pretrained(self.model_path, config=self.config, local_files_only=self.local_files_only) + self.model.to(self.device) + self.model.eval() + logging.info("DST model weights loaded from %s" % (self.model_path)) + + def load_nlu(self): + """ Loads NLUs for internal evaluation """ + # NLU for system utterances is used in case the policy does or can not provide semantic actions. + # The sole purpose of this is to fill the inform memory. + # NLU for user utterances is used in case the user simulator does or can not provide semantic actions. + # The sole purpose of this is to enable internal DST evaluation. + if self.nlu_usr_config == self.nlu_sys_config and \ + self.nlu_usr_path == self.nlu_sys_path: + self.nlu_usr = BERTNLU(mode="all", config_file=self.nlu_usr_config, model_file=self.nlu_usr_path) + self.nlu_sys = self.nlu_usr + else: + self.nlu_usr = BERTNLU(mode="user", config_file=self.nlu_usr_config, model_file=self.nlu_usr_path) + self.nlu_sys = BERTNLU(mode="sys", config_file=self.nlu_sys_config, model_file=self.nlu_sys_path) + logging.info("DST user NLU model weights loaded from %s" % (self.nlu_usr_path)) + logging.info("DST sys NLU model weights loaded from %s" % (self.nlu_sys_path)) + + def load_nlg(self): + if self.dataset_name in ['multiwoz21', 'multiwoz22', 'multiwoz23']: + from convlab.nlg.template.multiwoz import TemplateNLG + self.nlg_usr = TemplateNLG(is_user=True) + self.nlg_sys = TemplateNLG(is_user=False) + logging.info("DST template NLG loaded for dataset %s" % (self.dataset_name)) + else: + raise Exception("DST no NLG for dataset %s available." % (self.dataset_name)) + + def init_session(self): + # Initialise empty state + self.state = {'user_action': [], + 'system_action': [], + 'belief_state': {}, + 'booked': {}, + 'request_state': {}, + 'terminated': False, + 'history': []} + self.state['belief_state'] = copy.deepcopy(self.template_state) + self.history = [] + self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} + self.gt_belief_state = copy.deepcopy(self.template_state) + self.global_diag_cnt += 1 + self.global_turn_cnt = -1 + + def update_gt_belief_state(self, user_act): + for intent, domain, slot, value in user_act: + if domain == 'police': + continue + if intent == 'inform': + if slot == 'none' or slot == '': + continue + if slot in self.gt_belief_state[domain]: + self.gt_belief_state[domain][slot] = value + + def update(self, user_act=''): + prev_state = self.state + + if not self.no_eval: + logging.info("-" * 40) + + if self.no_history: + self.history = [] + self.history.append(['sys', self.get_text(prev_state['history'][-2][1], is_user=False, normalize=True)]) + self.history.append(['user', self.get_text(prev_state['history'][-1][1], is_user=True, normalize=True)]) + + self.global_turn_cnt += 1 + if not self.no_eval: + self.print_dialog(self.history) + + # --- Get inform memory and auxiliary features --- + + # system_action is a list of semantic system actions. + # TripPy uses system actions to fill the inform memory. + # End-to-end policies like Lava produce plain text instead. + # If system_action is plain text, get acts using NLU. + if isinstance(prev_state['system_action'], str): + s_acts = self.get_acts(prev_state['system_action']) + elif isinstance(prev_state['system_action'], list): + s_acts = prev_state['system_action'] + else: + raise Exception('Unknown format for system action:', prev_state['system_action']) + + if not self.no_eval: + # user_action is a list of semantic user actions if no NLG is used + # in the pipeline, otherwise user_action is plain text. + # TripPy uses user actions to perform internal DST evaluation. + # If user_action is plain text, get acts using NLU. + if isinstance(prev_state['user_action'], str): + u_acts = self.get_acts(prev_state['user_action'], is_user=True) + elif isinstance(prev_state['user_action'], list): + u_acts = prev_state['user_action'] # This is the same as user_act + else: + raise Exception('Unknown format for user action:', prev_state['user_action']) + + # Fill the inform memory. + inform_aux, inform_mem = self.get_inform_aux(s_acts) + if not self.no_eval: + self.print_inform_memory(inform_mem) + + # --- Tokenize dialogue context and feed DST model --- + + used_ds_aux = None if not self.config.dst_class_aux_feats_ds else self.ds_aux + used_inform_aux = None if not self.config.dst_class_aux_feats_inform else inform_aux + features = self.get_features(self.history, ds_aux=used_ds_aux, inform_aux=used_inform_aux) + pred_states, pred_classes = self.predict(features, inform_mem) + + # --- Update ConvLab-style dialogue state --- + + new_belief_state = copy.deepcopy(prev_state['belief_state']) + user_acts = [] + for state, value in pred_states.items(): + value = self.dataset_interfacer.normalize_values(value) + domain, slot = state.split('-', 1) + value = self.dataset_interfacer.normalize_prediction(domain, slot, value, + predictions=pred_states, + class_predictions=pred_classes, + config=self.config) + if value == 'none': + continue + if slot in new_belief_state[domain]: + new_belief_state[domain][slot] = value + user_acts.append(['inform', domain, slot, value]) + else: + raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain)) + + if not self.no_eval: + self.update_gt_belief_state(u_acts) # For evaluation + + # BELIEF STATE UPDATE + new_state = copy.deepcopy(dict(prev_state)) + new_state['belief_state'] = new_belief_state # TripPy + + state_updates = {} + for cl in pred_classes: + cl_d, cl_s = cl.split('-') + # Some reformatting for the evaluation further down + if cl_d not in state_updates: + state_updates[cl_d] = {} + state_updates[cl_d][cl_s] = pred_classes[cl] + # We care only about the requestable slots here + if self.config.dst_class_types[pred_classes[cl]] != 'request': + continue + if cl_d != 'general' and cl_s == 'none': + user_acts.append(['inform', cl_d, '', '']) + elif cl_d == 'general': + user_acts.append([cl_s, 'general', '', '']) + else: + user_acts.append(['request', cl_d, cl_s, '']) + + # USER ACTS UPDATE + new_state['user_action'] = user_acts # TripPy + + if not self.no_eval: + self.eval_user_acts(u_acts, user_acts) + self.eval_dialog_state(state_updates, new_belief_state) + + self.state = new_state + + # Print eval statistics + if self.state['terminated'] and not self.no_eval: + logging.info("Booked: %s" % self.state['booked']) + self.eval_print_stats() + logging.info("=" * 10 + "End of the dialogue" + "=" * 10) + self.ds_aux = self.update_ds_aux(self.state['belief_state'], pred_states) + + return self.state + + def predict(self, features, inform_mem): + with torch.no_grad(): + outputs = self.model(input_ids=features['input_ids'], + input_mask=features['attention_mask'], + inform_slot_id=features['inform_slot_id'], + diag_state=features['diag_state']) + + input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked! + + per_slot_class_logits = outputs[2] + per_slot_start_logits = outputs[3] + per_slot_end_logits = outputs[4] + per_slot_refer_logits = outputs[5] + + # TODO: maybe add assert to check that batch=1 + + predictions = {} + class_predictions = {} + + for slot in self.config.dst_slot_list: + d, s = slot.split('-') + slot_udf = "%s-%s" % (self.dataset_interfacer.map_trippy_to_udf(d, s)) + + predictions[slot_udf] = 'none' + class_predictions[slot_udf] = 0 + + class_logits = per_slot_class_logits[slot][0] + start_logits = per_slot_start_logits[slot][0] + end_logits = per_slot_end_logits[slot][0] + refer_logits = per_slot_refer_logits[slot][0] + + class_prediction = int(class_logits.argmax()) + start_prediction = int(start_logits.argmax()) + end_prediction = int(end_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + if class_prediction == self.config.dst_class_types.index('dontcare'): + predictions[slot_udf] = 'dontcare' + elif class_prediction == self.config.dst_class_types.index('copy_value'): + predictions[slot_udf] = ' '.join(input_tokens[start_prediction:end_prediction + 1]) + predictions[slot_udf] = re.sub("(^| )##", "", predictions[slot_udf]) + if "\u0120" in predictions[slot_udf]: + predictions[slot_udf] = re.sub(" ", "", predictions[slot_udf]) + predictions[slot_udf] = re.sub("\u0120", " ", predictions[slot_udf]) + predictions[slot_udf] = predictions[slot_udf].strip() + elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'): + predictions[slot_udf] = "yes" # 'true' + elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'): + predictions[slot_udf] = "no" # 'false' + elif class_prediction == self.config.dst_class_types.index('inform'): + predictions[slot_udf] = inform_mem[slot_udf] + # Referral case is handled below + + # Referral case. All other slot values need to be seen first in order + # to be able to do this correctly. + for slot in self.config.dst_slot_list: + d, s = slot.split('-') + slot_udf = "%s-%s" % (self.dataset_interfacer.map_trippy_to_udf(d, s)) + + class_logits = per_slot_class_logits[slot][0] + refer_logits = per_slot_refer_logits[slot][0] + + class_prediction = int(class_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'): + # Only slots that have been mentioned before can be referred to. + # First try to resolve a reference within the same turn. (One can think of a situation + # where one slot is referred to in the same utterance. This phenomenon is however + # currently not properly covered in the training data label generation process) + # Then try to resolve a reference given the current dialogue state. + referred_slot = self.config.dst_slot_list[refer_prediction - 1] + referred_slot_d, referred_slot_s = referred_slot.split('-') + referred_slot_d, referred_slot_s = self.dataset_interfacer.map_trippy_to_udf(referred_slot_d, referred_slot_s) + referred_slot_udf = "%s-%s" % (referred_slot_d, referred_slot_s) + predictions[slot_udf] = predictions[referred_slot_udf] + if predictions[slot_udf] == 'none': + if self.state['belief_state'][referred_slot_d][referred_slot_s] != '': + predictions[slot_udf] = self.state['belief_state'][referred_slot_d][referred_slot_s] + if predictions[slot_udf] == 'none': + ref_slot = self.config.dst_slot_list[refer_prediction - 1] + ref_slot_d, ref_slot_s = ref_slot.split('-') + generic_ref = self.dataset_interfacer.get_generic_referral(ref_slot_d, ref_slot_s) + predictions[slot_udf] = generic_ref + + class_predictions[slot_udf] = class_prediction + + return predictions, class_predictions + + def get_features(self, context, ds_aux=None, inform_aux=None): + assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models + input_tokens = ['<s>'] # TODO: use tokenizer token names rather than strings + e_itr = 0 + for e_itr, e in enumerate(reversed(context)): + if e[1] not in ['null', '']: + input_tokens.append(e[1]) + if e_itr < 2: + input_tokens.append('</s> </s>') + if e_itr == 0: + input_tokens.append('</s> </s>') + input_tokens.append('</s>') + input_tokens = ' '.join(input_tokens) + + # TODO: delex sys utt currently not supported + features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=False, max_length=self.config.dst_max_seq_length) + + input_ids = torch.tensor(features['input_ids']).reshape(1,-1).to(self.device) + attention_mask = torch.tensor(features['attention_mask']).reshape(1,-1).to(self.device) + features = {'input_ids': input_ids, + 'attention_mask': attention_mask, + 'inform_slot_id': inform_aux, + 'diag_state': ds_aux} + + return features + + def update_ds_aux(self, state, pred_states, terminated=False): + ds_aux = copy.deepcopy(self.ds_aux) + for slot in self.config.dst_slot_list: + d, s = slot.split('-') + d_udf, s_udf = self.dataset_interfacer.map_trippy_to_udf(d, s) + slot_udf = "%s-%s" % (d_udf, s_udf) + if d_udf in state and s_udf in state[d_udf]: + ds_aux[slot][0] = int(state[d_udf][s_udf] != '') + else: + # Requestable slots are not found in the DS + ds_aux[slot][0] = int(pred_states[slot_udf] != 'none') + return ds_aux + + def get_inform_aux(self, state): + # Initialise auxiliary variables. + # For inform_aux, only the proper order of slots + # as defined in dst_slot_list is relevant, but not + # the actual slot names (as inform_aux will be + # converted into a simple binary list in the model) + inform_aux = {} + inform_mem = {} + for slot in self.config.dst_slot_list: + d, s = slot.split('-') + d_udf, s_udf = self.dataset_interfacer.map_trippy_to_udf(d, s) + inform_aux["%s-%s" % (d_udf, s_udf)] = torch.tensor([0]).to(self.device) + inform_mem["%s-%s" % (d_udf, s_udf)] = 'none' + for e in state: + a, d, s, v = e + # TODO: offerbook needed? booked needed? + if a in ['inform', 'recommend', 'select', 'book', 'offerbook']: + slot = "%s-%s" % (d, s) + if slot in inform_aux: + inform_aux[slot][0] = 1 + inform_mem[slot] = self.dataset_interfacer.normalize_values(v) + return inform_aux, inform_mem + + def get_acts(self, act, is_user=False): + if isinstance(act, list): + return act + context = self.state['history'] + if context[-1][0] not in ['user', 'usr']: + raise Exception("Wrong order of utterances, check your input.") + system_context = [self.get_text(t) for s,t in context[:-2]] + user_context = [self.get_text(t, is_user=True) for s,t in context[:-1]] + if is_user: + if self.nlu_usr is None: + raise Exception("You attempt to convert semantic user actions into text, but no NLU module is loaded.") + acts = self.nlu_usr.predict(act, context=user_context) + else: + if self.nlu_sys is None: + raise Exception("You attempt to convert semantic system actions into text, but no NLU module is loaded.") + acts = self.nlu_sys.predict(act, context=system_context) + for act_itr in range(len(acts)): + acts[act_itr][-1] = self.dataset_interfacer.normalize_values(acts[act_itr][-1]) + return acts + + def get_text(self, act, is_user=False, normalize=False): + if act == 'null': + return 'null' + if not isinstance(act, list): + result = act + else: + if self.nlg_usr is None or self.nlg_sys is None: + logging.warn("You attempt to input semantic actions into TripPy, which expects text.") + logging.warn("Attempting to load NLG modules in order to convert actions into text.") + self.load_nlg() + if is_user: + result = self.nlg_usr.generate(act) + else: + result = self.nlg_sys.generate(act) + if normalize: + return self.dataset_interfacer.normalize_text(result) + else: + return result diff --git a/convlab/evaluator/multiwoz_eval.py b/convlab/evaluator/multiwoz_eval.py index c89361b2a84824198e516941b71806440a9ba3a5..cb6c8feb73aea6e4481a51fa0eb2466a6a07d1c6 100755 --- a/convlab/evaluator/multiwoz_eval.py +++ b/convlab/evaluator/multiwoz_eval.py @@ -3,6 +3,7 @@ import logging import re import numpy as np +import pdb from copy import deepcopy from data.unified_datasets.multiwoz21.preprocess import reverse_da, reverse_da_slot_name_map @@ -158,6 +159,15 @@ class MultiWozEvaluator(Evaluator): list[intent, domain, slot, value] """ + new_acts = list() + for intent, domain, slot, value in da_turn: + if intent.lower() == 'book': + ref = [_value for _intent, _domain, _slot, _value in da_turn if _domain == domain and _intent.lower() == 'inform' and _slot.lower() == 'ref'] + ref = ref[0] if ref else '' + value = ref + new_acts.append([intent, domain, slot, value]) + da_turn = new_acts + da_turn = self._convert_action(da_turn) for intent, domain, slot, value in da_turn: diff --git a/convlab/nlg/template/multiwoz/manual_system_template_nlg.json b/convlab/nlg/template/multiwoz/manual_system_template_nlg.json index 414f0b80b740d0c33568ba21bab5e13e822d95df..300369b525d3d5c4781ea00587f198e56657a37b 100755 --- a/convlab/nlg/template/multiwoz/manual_system_template_nlg.json +++ b/convlab/nlg/template/multiwoz/manual_system_template_nlg.json @@ -225,7 +225,7 @@ "Do you know the name of it ?", "can you give me the name of it ?" ], - "Price": [ + "Fee": [ "any specific price range to help narrow down available options ?", "What price range would you like ?", "what is your price range for that ?", @@ -363,42 +363,6 @@ "I ' m sorry but there is no availability for #BOOKING-NOBOOK-PEOPLE# people ." ] }, - "Booking-Request": { - "Day": [ - "What day would you like your booking for ?", - "What day would you like that reservation ?", - "what day would you like the booking to be made for ?", - "What day would you like to book ?", - "Ok , what day would you like to make the reservation on ?" - ], - "Stay": [ - "How many nights will you be staying ?", - "And how many nights ?", - "for how many days ?", - "And for how many days ?", - "how many days would you like to stay ?", - "How many nights would you like to book it for ?", - "And what nights would you like me to reserve for you ?", - "How many nights are you wanting to stay ?", - "How many days will you be staying ?" - ], - "People": [ - "For how many people ?", - "How many people will be ?", - "How many people will be with you ?", - "How many people is the reservation for ?" - ], - "Time": [ - "Do you have a time preference ?", - "what time are you looking for a reservation at ?", - "For what time ?", - "What time would you like me to make your reservation ?", - "What time would you like the reservation for ?", - "what time should I make the reservation for ?", - "What time would you prefer ?", - "What time would you like the reservation for ?" - ] - }, "Hotel-Inform": { "Internet": [ "it has wifi .", @@ -697,6 +661,30 @@ "Do you need free parking ?", "Will you need parking while you 're there ?", "Will you be needing free parking ?" + ], + "Day": [ + "What day would you like your booking for ?", + "What day would you like that reservation ?", + "what day would you like the booking to be made for ?", + "What day would you like to book ?", + "Ok , what day would you like to make the reservation on ?" + ], + "Stay": [ + "How many nights will you be staying ?", + "And how many nights ?", + "for how many days ?", + "And for how many days ?", + "how many days would you like to stay ?", + "How many nights would you like to book it for ?", + "And what nights would you like me to reserve for you ?", + "How many nights are you wanting to stay ?", + "How many days will you be staying ?" + ], + "People": [ + "For how many people ?", + "How many people will be ?", + "How many people will be with you ?", + "How many people is the reservation for ?" ] }, "Restaurant-Inform": { @@ -918,6 +906,29 @@ "what is the name of the restaurant you are needing information on ?", "Do you know the name of the location ?", "Is there a certain restaurant you 're looking for ?" + ], + "Day": [ + "What day would you like your booking for ?", + "What day would you like that reservation ?", + "what day would you like the booking to be made for ?", + "What day would you like to book ?", + "Ok , what day would you like to make the reservation on ?" + ], + "People": [ + "For how many people ?", + "How many people will be ?", + "How many people will be with you ?", + "How many people is the reservation for ?" + ], + "Time": [ + "Do you have a time preference ?", + "what time are you looking for a reservation at ?", + "For what time ?", + "What time would you like me to make your reservation ?", + "What time would you like the reservation for ?", + "what time should I make the reservation for ?", + "What time would you prefer ?", + "What time would you like the reservation for ?" ] }, "Taxi-Inform": { @@ -1331,6 +1342,77 @@ "Is there a time you need to arrive by ?" ] }, + "Police-Inform": { + "Addr": [ + "it is located in #POLICE-INFORM-ADDR#", + "adress is #POLICE-INFORM-ADDR#", + "It is on #POLICE-INFORM-ADDR# .", + "their address in our system is listed as #POLICE-INFORM-ADDR# .", + "The address is #POLICE-INFORM-ADDR# .", + "it 's located at #POLICE-INFORM-ADDR# .", + "#POLICE-INFORM-ADDR# is the address", + "They are located at #POLICE-INFORM-ADDR# ." + ], + "Post": [ + "The postcode of the police is #POLICE-INFORM-POST# .", + "The post code is #POLICE-INFORM-POST# .", + "Its postcode is #POLICE-INFORM-POST# .", + "Their postcode is #POLICE-INFORM-POST# ." + ], + "Name": [ + "I think a fun place to visit is #POLICE-INFORM-NAME# .", + "#POLICE-INFORM-NAME# looks good .", + "#POLICE-INFORM-NAME# is available , would that work for you ?", + "we have #POLICE-INFORM-NAME# .", + "#POLICE-INFORM-NAME# is popular among visitors .", + "How about #POLICE-INFORM-NAME# ?", + "What about #POLICE-INFORM-NAME# ?", + "you might want to try the #POLICE-INFORM-NAME# ." + ], + "Phone": [ + "The police phone number is #POLICE-INFORM-PHONE# .", + "Here is the police phone number , #POLICE-INFORM-PHONE# ." + ] + }, + "Hospital-Inform": { + "Addr": [ + "it is located in #HOSPITAL-INFORM-ADDR#", + "adress is #HOSPITAL-INFORM-ADDR#", + "It is on #HOSPITAL-INFORM-ADDR# .", + "their address in our system is listed as #HOSPITAL-INFORM-ADDR# .", + "The address is #HOSPITAL-INFORM-ADDR# .", + "it 's located at #HOSPITAL-INFORM-ADDR# .", + "#HOSPITAL-INFORM-ADDR# is the address", + "They are located at #HOSPITAL-INFORM-ADDR# ." + ], + "Post": [ + "The postcode of the hospital is #HOSPITAL-INFORM-POST# .", + "The post code is #HOSPITAL-INFORM-POST# .", + "Its postcode is #HOSPITAL-INFORM-POST# .", + "Their postcode is #HOSPITAL-INFORM-POST# ." + ], + "Department": [ + "The department of the hospital is #HOSPITAL-INFORM-POST# .", + "The department is #HOSPITAL-INFORM-POST# .", + "Its department is #HOSPITAL-INFORM-POST# .", + "Their department is #HOSPITAL-INFORM-POST# ." + + ], + "Phone": [ + "The hospital phone number is #HOSPITAL-INFORM-PHONE# .", + "Here is the hospital phone number , #HOSPITAL-INFORM-PHONE# ." + ] + }, + "Hospital-Request": { + "Department": [ + "What is the name of the hospital department ?", + "What hospital department are you thinking about ?", + "I ' m sorry for the confusion , what hospital department are you interested in ?", + "What hospital department were you thinking of ?", + "Do you know the department of it ?", + "can you give me the department of it ?" + ] + }, "general-bye": { "none": [ "Thank you for using our services .", @@ -1378,4 +1460,4 @@ "You 're welcome . Have a good day !" ] } -} \ No newline at end of file +} diff --git a/convlab/nlg/template/multiwoz/nlg.py b/convlab/nlg/template/multiwoz/nlg.py index f83a6db4e2ad6f77f9bdc154cfda7bf2db0ff2c5..5f362ebbabd68ee0b6fab62c692caf0e8da436e9 100755 --- a/convlab/nlg/template/multiwoz/nlg.py +++ b/convlab/nlg/template/multiwoz/nlg.py @@ -31,33 +31,33 @@ def read_json(filename): # supported slot Slot2word = { - 'Fee': 'fee', + 'Fee': 'entrance fee', 'Addr': 'address', 'Area': 'area', - 'Stars': 'stars', - 'Internet': 'Internet', + 'Stars': 'number of stars', + 'Internet': 'internet', 'Department': 'department', 'Choice': 'choice', 'Ref': 'reference number', 'Food': 'food', 'Type': 'type', 'Price': 'price range', - 'Stay': 'stay', + 'Stay': 'length of the stay', 'Phone': 'phone number', 'Post': 'postcode', 'Day': 'day', 'Name': 'name', 'Car': 'car type', - 'Leave': 'leave', + 'Leave': 'departure time', 'Time': 'time', - 'Arrive': 'arrive', - 'Ticket': 'ticket', + 'Arrive': 'arrival time', + 'Ticket': 'ticket price', 'Depart': 'departure', - 'People': 'people', + 'People': 'number of people', 'Dest': 'destination', 'Parking': 'parking', - 'Open': 'open', - 'Id': 'Id', + 'Open': 'opening hours', + 'Id': 'id', # 'TrainID': 'TrainID' } @@ -271,6 +271,10 @@ class TemplateNLG(NLG): elif 'request' == intent[1]: for slot, value in slot_value_pairs: if dialog_act not in template or slot not in template[dialog_act]: + if dialog_act not in template: + print("WARNING (nlg.py): (User?: %s) dialog_act '%s' not in template!" % (self.is_user, dialog_act)) + else: + print("WARNING (nlg.py): (User?: %s) slot '%s' of dialog_act '%s' not in template!" % (self.is_user, slot, dialog_act)) sentence = 'What is the {} of {} ? '.format( slot.lower(), dialog_act.split('-')[0].lower()) sentences += self._add_random_noise(sentence) @@ -288,7 +292,7 @@ class TemplateNLG(NLG): value_lower = value.lower() if value in ["do nt care", "do n't care", "dontcare"]: sentence = 'I don\'t care about the {} of the {}'.format( - slot, dialog_act.split('-')[0]) + slot2word.get(slot, slot), dialog_act.split('-')[0]) elif self.is_user and dialog_act.split('-')[1] == 'inform' and slot == 'choice' and value_lower == 'any': # user have no preference, any choice is ok sentence = random.choice([ diff --git a/convlab/nlu/jointBERT/multiwoz/nlu.py b/convlab/nlu/jointBERT/multiwoz/nlu.py index e25fbad1227c4b1f85ae2ae42a8ac899fa61d7b8..1373919e5861156c87a2ba14d6506d13e0842204 100755 --- a/convlab/nlu/jointBERT/multiwoz/nlu.py +++ b/convlab/nlu/jointBERT/multiwoz/nlu.py @@ -74,7 +74,8 @@ class BERTNLU(NLU): for token in token_list: token = token.strip() self.nlp.tokenizer.add_special_case( - token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) + #token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) + token, [{ORTH: token}]) logging.info("BERTNLU loaded") def predict(self, utterance, context=list()): diff --git a/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_sys_context3.json b/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_sys_context3.json new file mode 100755 index 0000000000000000000000000000000000000000..dfbef5a39963f030ccf989276fcaf7efb8141cc1 --- /dev/null +++ b/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_sys_context3.json @@ -0,0 +1,27 @@ +{ + "dataset_name": "multiwoz21", + "data_dir": "unified_datasets/data/multiwoz21/system/context_window_size_3", + "output_dir": "unified_datasets/output/multiwoz21/system/context_window_size_3", + "zipped_model_path": "unified_datasets/output/multiwoz21/system/context_window_size_3/bertnlu_unified_multiwoz21_system_context3.zip", + "log_dir": "unified_datasets/output/multiwoz21/system/context_window_size_3/log", + "DEVICE": "cuda:0", + "seed": 2019, + "cut_sen_len": 40, + "use_bert_tokenizer": true, + "context_window_size": 3, + "model": { + "finetune": true, + "context": true, + "context_grad": true, + "pretrained_weights": "bert-base-uncased", + "check_step": 1000, + "max_step": 10000, + "batch_size": 128, + "learning_rate": 1e-4, + "adam_epsilon": 1e-8, + "warmup_steps": 0, + "weight_decay": 0.0, + "dropout": 0.1, + "hidden_units": 1536 + } +} diff --git a/convlab/policy/lava/README.md b/convlab/policy/lava/README.md index bd06a73b9d4c2954338d448a7f7291652228fc59..9678d76bb7e41b9a8a6035dedfc7f40d1a73395a 100755 --- a/convlab/policy/lava/README.md +++ b/convlab/policy/lava/README.md @@ -1,74 +1,17 @@ ## LAVA: Latent Action Spaces via Variational Auto-encoding for Dialogue Policy Optimization -Codebase for [LAVA: Latent Action Spaces via Variational Auto-encoding for Dialogue Policy Optimization](https://), published as a long paper in COLING 2020. The code is developed based on the implementations of the [LaRL](https://arxiv.org/abs/1902.08858) paper. +ConvLab3 interface for [LAVA: Latent Action Spaces via Variational Auto-encoding for Dialogue Policy Optimization](https://aclanthology.org/2020.coling-main.41/), published as a long paper in COLING 2020. -### Requirements - python 3 - pytorch - numpy - -### Data -The pre-processed MultiWoz 2.0 data is included in data.zip. Unzip the compressed file and access the data under **data/norm-multi-woz**. - -### Over structure: -The implementation of the models, as well as training and evaluation scripts are under **latent_dialog**. -The scripts for running the experiments are under **experiment_woz**. The trained models and evaluation results are under **experiment_woz/sys_config_log_model**. +To train a LAVA model, clone and follow instructions from the [original LAVA repository](https://gitlab.cs.uni-duesseldorf.de/general/dsml/lava-public). -There are 3 types of training to achieve the final model. +With a (pre-)trained LAVA model, it is possible to evaluate or perform online RL with ConvLab3 US by loading the lava module with -### Step 1: Unsupervised training (variational auto-encoding (VAE) task) -Given a dialogue response, the model is tasked to reproduce it via latent variables. With this task we aim to unsupervisedly capture generative factors of dialogue responses. +- from convlab.policy.lava.multiwoz import LAVA - - sl_cat_ae.py: train a VAE model using categorical latent variable - - sl_gauss_ae.py: train a VAE model using continuous (Gaussian) latent variable +and using it as the policy module in the ConvLab pipeline (NLG should be set to None). -### Step 2: Supervised training (response generation task) -The supervised training step of the variational encoder-decoder model could be done 4 different ways. -1. from scratch: +Code example can be found at +- ConvLab-3/examples/agent_examples/test_LAVA.py +A trained LAVA model can be found at https://huggingface.co/ConvLab/lava-policy-multiwoz21. - - sl_word: train a standard encoder decoder model using supervised learning (SL) - - sl_cat: train a latent action model with categorical latent variables using SL, - - sl_gauss: train a latent action model with continuous latent varaibles using SL, - -2. using the VAE models as pre-trained model (equivalent to LAVA_pt): - - - - finetune_cat_ae: use the VAE with categorical latent variables as weight initialization, and then fine-tune the model on response generation task - - finetune_gauss_ae: as above but with continuous latent variables - - Note: Fine-tuning can be set to be selective (only fine-tune encoder) or not (fine-tune the entire network) using the "selective_finetune" argument in config - -3. using the distribution of the VAE models to obtain informed prior (equivalent to LAVA_kl): - - - - actz_cat: initialized new encoder is combined with pre-trained VAE decoder and fine-tuned on response generation task. VAE encoder is used to obtain an informed prior of the target response and not trained further. - - actz_gauss: as above but with continuous latent variables - -4. or simultaneously from scrath with VAE task in a multi-task fashion (equivalent to LAVA_mt): - - - - mt_cat: train a model to optimize both auto-encoding and response generation in a multi-task fashion, using categorical latent variables - - mt_gauss: as above but with continuous latent variables - -No.1 and 4 can be directly trained without Step 1. No. 2 and 3 requires a pre-trained VAE model, given via a dictionary - - pretrained = {"2020-02-26-18-11-37-sl_cat_ae":100} - -### Step 3: Reinforcement Learning -The model can be further optimized with RL to maximize the dialogue success. - -Each script is used for: - - - reinforce_word: fine tune a pretrained model with word-level policy gradient (PG) - - reinforce_cat: fine tune a pretrained categorical latent action model with latent-level PG. - - reinforce_gauss: fine tune a pretrained gaussian latent action model with latent-level PG. - -The script takes a file containing list of test results from the SL step. - - f_in = "sys_config_log_model/test_files.lst" - - -### Checking the result -The evaluation result can be found at the bottom of the test_file.txt. We provide the best model in this repo. - -NOTE: when re-running the experiments some variance is to be expected in the numbers due to factors such as random seed and hardware specificiations. Some methods are more sensitive to this than others. diff --git a/convlab/policy/lava/multiwoz/latent_dialog/__init__.py b/convlab/policy/lava/multiwoz/latent_dialog/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5cabc5d9c605965db3e69614595ab9bb891814c9 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/__init__.py @@ -0,0 +1,2 @@ +# @Time : 10/18/18 1:55 PM +# @Author : Tiancheng Zhao \ No newline at end of file diff --git a/convlab/policy/lava/multiwoz/latent_dialog/agent_task.py b/convlab/policy/lava/multiwoz/latent_dialog/agent_task.py new file mode 100644 index 0000000000000000000000000000000000000000..8da7af0e554a09436a80537efccacf5c2b7d0bd5 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/agent_task.py @@ -0,0 +1,115 @@ +import torch as th +import torch.nn as nn +import torch.optim as optim +import numpy as np +from convlab.policy.lava.multiwoz.latent_dialog.utils import LONG, FLOAT, Pack, get_detokenize +from convlab.policy.lava.multiwoz.latent_dialog.main import get_sent +from convlab.policy.lava.multiwoz.latent_dialog.data_loaders import BeliefDbDataLoaders +from collections import deque, namedtuple, defaultdict +import random +import pdb +import dill + + +class OfflineRlAgent(object): + def __init__(self, model, corpus, args, name, tune_pi_only): + self.model = model + self.corpus = corpus + self.args = args + self.name = name + self.raw_goal = None + self.vec_goals_list = None + self.logprobs = None + print("Do we only tune the policy: {}".format(tune_pi_only)) + self.opt = optim.SGD( + [p for n, p in self.model.named_parameters() if 'c2z' in n or not tune_pi_only], + lr=self.args.rl_lr, + momentum=self.args.momentum, + nesterov=(self.args.nesterov and self.args.momentum > 0)) + # self.opt = optim.Adam(self.model.parameters(), lr=0.01) + # self.opt = optim.RMSprop(self.model.parameters(), lr=0.0005) + self.all_rewards = [] + self.all_grads = [] + self.model.train() + + def print_dialog(self, dialog, reward, stats): + for t_id, turn in enumerate(dialog): + if t_id % 2 == 0: + print("Usr: {}".format(' '.join([t for t in turn if t != '<pad>']))) + else: + print("Sys: {}".format(' '.join(turn))) + report = ['{}: {}'.format(k, v) for k, v in stats.items()] + print("Reward {}. {}".format(reward, report)) + + def run(self, batch, evaluator, max_words=None, temp=0.1): + self.logprobs = [] + self.dlg_history =[] + batch_size = len(batch['keys']) + logprobs, outs = self.model.forward_rl(batch, max_words, temp) + if batch_size == 1: + logprobs = [logprobs] + outs = [outs] + + key = batch['keys'][0] + sys_turns = [] + # construct the dialog history for printing + for turn_id, turn in enumerate(batch['contexts']): + user_input = self.corpus.id2sent(turn[-1]) + self.dlg_history.append(user_input) + sys_output = self.corpus.id2sent(outs[turn_id]) + self.dlg_history.append(sys_output) + sys_turns.append(' '.join(sys_output)) + + for log_prob in logprobs: + self.logprobs.extend(log_prob) + # compute reward here + generated_dialog = {key: sys_turns} + return evaluator.evaluateModel(generated_dialog, mode="offline_rl") + + def update(self, reward, stats): + self.all_rewards.append(reward) + # standardize the reward + r = (reward - np.mean(self.all_rewards)) / max(1e-4, np.std(self.all_rewards)) + # compute accumulated discounted reward + g = self.model.np2var(np.array([r]), FLOAT).view(1, 1) + rewards = [] + for _ in self.logprobs: + rewards.insert(0, g) + g = g * self.args.gamma + + loss = 0 + # estimate the loss using one MonteCarlo rollout + for lp, r in zip(self.logprobs, rewards): + loss -= lp * r + self.opt.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), self.args.rl_clip) + # for name, p in self.model.named_parameters(): + # print(name) + # print(p.grad) + self.opt.step() + +class OfflineLatentRlAgent(OfflineRlAgent): + def run(self, batch, evaluator, max_words=None, temp=0.1): + self.logprobs = [] + self.dlg_history =[] + batch_size = len(batch['keys']) + logprobs, outs, logprob_z, sample_z = self.model.forward_rl(batch, max_words, temp) + if batch_size == 1: + outs = [outs] + key = batch['keys'][0] + sys_turns = [] + # construct the dialog history for printing + for turn_id, turn in enumerate(batch['contexts']): + user_input = self.corpus.id2sent(turn[-1]) + self.dlg_history.append(user_input) + sys_output = self.corpus.id2sent(outs[turn_id]) + self.dlg_history.append(sys_output) + sys_turns.append(' '.join(sys_output)) + + for b_id in range(batch_size): + self.logprobs.append(logprob_z[b_id]) + # compute reward here + generated_dialog = {key: sys_turns} + return evaluator.evaluateModel(generated_dialog, mode="offline_rl") + diff --git a/convlab/policy/lava/multiwoz/latent_dialog/augpt_utils.py b/convlab/policy/lava/multiwoz/latent_dialog/augpt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf317a08e4a1f9742bfd2f2141d9c9628344184 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/augpt_utils.py @@ -0,0 +1,558 @@ +#! /usr/bin/env python +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# +# Copyright © 2021 lubis <lubis@hilbert242> +# +# Distributed under terms of the MIT license. + +""" +utils from AuGPT codebase +""" +import re +import os +import sys +import types +import shutil +import logging +import requests +import torch +import zipfile +import bisect +import random +import copy +import json +from collections import OrderedDict, defaultdict +from typing import Callable, Union, Set, Optional, List, Dict, Any, Tuple, MutableMapping # noqa: 401 +from dataclasses import dataclass +import pdb + +DATASETS_PATH = os.path.join(os.path.expanduser(os.environ.get('DATASETS_PATH', '~/datasets')), 'augpt') +pricepat = re.compile("\d{1,3}[.]\d{1,2}") + +temp_path = os.path.dirname(os.path.abspath(__file__)) +fin = open(os.path.join("/home/lubis/datasets/augpt/mapping.pair")) +replacements = [] +for line in fin.readlines(): + tok_from, tok_to = line.replace('\n', '').split('\t') + replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' ')) + + + +class AutoDatabase: + @staticmethod + def load(pretrained_model_name_or_path): + database_file = os.path.join(pretrained_model_name_or_path, 'database.zip') + + with zipfile.ZipFile(database_file) as zipf: + def _build_database(): + module = types.ModuleType('database') + exec(zipf.read('database.py').decode('utf-8'), module.__dict__) + return module.Database(zipf) + + database = _build_database() + + + return database + +class BeliefParser: + def __init__(self): + self.slotval_re = re.compile(r"(\w[\w ]*\w) = ([\w\d: |']+)") + self.domain_re = re.compile(r"(\w+) {\s*([\w,= :\d|']*)\s*}", re.IGNORECASE) + + def __call__(self, raw_belief: str): + belief = OrderedDict() + for match in self.domain_re.finditer(raw_belief): + domain, domain_bs = match.group(1), match.group(2) + belief[domain] = {} + for slot_match in self.slotval_re.finditer(domain_bs): + slot, val = slot_match.group(1), slot_match.group(2) + belief[domain][slot] = val + return belief + +class AutoLexicalizer: + @staticmethod + def load(pretrained_model_name_or_path): + lexicalizer_file = os.path.join(pretrained_model_name_or_path, 'lexicalizer.zip') + + with zipfile.ZipFile(lexicalizer_file) as zipf: + def _build_lexicalizer(): + module = types.ModuleType('lexicalizer') + exec(zipf.read('lexicalizer.py').decode('utf-8'), module.__dict__) + # return module.Lexicalizer(zipf) + return Lexicalizer(zipf) + + lexicalizer = _build_lexicalizer() + + + return lexicalizer + +def build_blacklist(items, domains=None): + for i, (dialogue, items) in enumerate(items): + if domains is not None and set(dialogue['domains']).difference(domains): + yield i + elif items[-1]['speaker'] != 'system': + yield i + +class BlacklistItemsWrapper: + def __init__(self, items, blacklist): + self.items = items + self.key2idx = items.key2idx + self._indexmap = [] + blacklist_pointer = 0 + for i in range(len(items)): + if blacklist_pointer >= len(blacklist): + self._indexmap.append(i) + elif i < blacklist[blacklist_pointer]: + self._indexmap.append(i) + elif i == blacklist[blacklist_pointer]: + blacklist_pointer += 1 + assert len(self._indexmap) == len(items) - len(blacklist) + + def __getitem__(self, idx): + if isinstance(idx, str): + idx = self.key2idx[idx] + return self.items[self._indexmap[idx]] + + def __len__(self): + return len(self._indexmap) +def split_name(dataset_name: str): + split = dataset_name.rindex('/') + return dataset_name[:split], dataset_name[split + 1:] + +@dataclass +class DialogDatasetItem: + context: Union[List[str], str] + belief: Union[Dict[str, Dict[str, str]], str] = None + database: Union[List[Tuple[str, int]], List[Tuple[str, int, Any]], None, str] = None + response: str = None + positive: bool = True + raw_belief: Any = None + raw_response: str = None + key: str = None + + def __getattribute__(self, name): + val = object.__getattribute__(self, name) + if name == 'belief' and val is None and self.raw_belief is not None: + val = format_belief(self.raw_belief) + self.belief = val + return val + +@dataclass +class DialogDataset(torch.utils.data.Dataset): + items: List[any] + database: Any = None + domains: List[str] = None + lexicalizer: Any = None + transform: Callable[[Any], Any] = None + normalize_input: Callable[[str], str] = None + ontology: Dict[Tuple[str, str], Set[str]] = None + + @staticmethod + def build_dataset_without_database(items, *args, **kwargs): + return DialogDataset(items, FakeDatabase(), *args, **kwargs) + + def __getitem__(self, index): + item = self.items[index] + if self.transform is not None: + item = self.transform(item) + return item + + def __len__(self): + return len(self.items) + + def map(self, transformation): + def trans(x): + x = self.transform(x) + x = transformation(x) + return x + return dataclasses.replace(self, transform=trans) + + def finish(self, progressbar: Union[str, bool] = False): + if self.transform is None: + return self + + ontology = defaultdict(lambda: set()) + domains = set(self.domains) if self.domains else set() + + items = [] + for i in trange(len(self), + desc=progressbar if isinstance(progressbar, str) else 'loading dataset', + disable=not progressbar): + item = self[i] + for k, bs in item.raw_belief.items(): + domains.add(k) + for k2, val in bs.items(): + ontology[(k, k2)].add(val) + items.append(item) + if self.ontology: + ontology = merge_ontologies((self.ontology, ontology)) + return dataclasses.replace(self, items=items, transform=None, domains=domains, ontology=ontology) + +class DialogueItems: + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + r.append(e + s) + s += e + return r + + def __init__(self, dialogues): + lengths = [len(x['items']) for x in dialogues] + self.keys = [x['name'] for x in dialogues] + self.key2idx = {k:i for (i, k) in enumerate(self.keys)} + self.cumulative_sizes = DialogueItems.cumsum(lengths) + self.dialogues = dialogues + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + dialogue_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dialogue_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dialogue_idx - 1] + return self.dialogues[dialogue_idx], self.dialogues[dialogue_idx]['items'][:sample_idx + 1] + + def __len__(self): + if not self.cumulative_sizes: + return 0 + return self.cumulative_sizes[-1] + +def insertSpace(token, text): + sidx = 0 + while True: + sidx = text.find(token, sidx) + if sidx == -1: + break + if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \ + re.match('[0-9]', text[sidx + 1]): + sidx += 1 + continue + if text[sidx - 1] != ' ': + text = text[:sidx] + ' ' + text[sidx:] + sidx += 1 + if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ': + text = text[:sidx + 1] + ' ' + text[sidx + 1:] + sidx += 1 + return text + +def augpt_normalize(text, delexicalize=True, online=False): + # lower case every word + text = text.lower() + + text = text.replace(" 1 ", " one ") + + # replace white spaces in front and end + text = re.sub(r'^\s*|\s*$', '', text) + + # hotel domain pfb30 + text = re.sub(r"b&b", "bed and breakfast", text) + text = re.sub(r"b and b", "bed and breakfast", text) + + # normalize phone number + ms = re.findall('\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m[0], sidx) + if text[sidx - 1] == '(': + sidx -= 1 + eidx = text.find(m[-1], sidx) + len(m[-1]) + text = text.replace(text[sidx:eidx], ''.join(m)) + + # normalize postcode + ms = re.findall('([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})', + text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m, sidx) + eidx = sidx + len(m) + text = text[:sidx] + re.sub('[,\. ]', '', m) + text[eidx:] + + # weird unicode bug + text = re.sub(u"(\u2018|\u2019)", "'", text) + + # replace time and and price + if delexicalize: + text = re.sub(pricepat, ' [price] ', text) + #text = re.sub(pricepat2, '[value_price]', text) + + # replace st. + text = text.replace(';', ',') + text = re.sub('$\/', '', text) + text = text.replace('/', ' and ') + + # replace other special characters + text = text.replace('-', ' ') + if delexicalize and not online: + text = re.sub('[\"\:<>@\(\)]', '', text) + elif delexicalize and online: + text = re.sub('[\"\<>@\(\)]', '', text) + text = re.sub("(([^0-9]):([^0-9]|.*))|(([^0-9]|.*):([^0-9]))", "\\2\\3\\5\\6", text) #only replace colons if it's not surrounded by digits. this wasn't a problem in standalone LAVA because time is delexicalized before normalization + else: + text = re.sub('[\"\<>@\(\)]', '', text) + + # insert white space before and after tokens: + for token in ['?', '.', ',', '!']: + text = insertSpace(token, text) + + # insert white space for 's + text = insertSpace('\'s', text) + + # replace it's, does't, you'd ... etc + text = re.sub('^\'', '', text) + text = re.sub('\'$', '', text) + text = re.sub('\'\s', ' ', text) + text = re.sub('\s\'', ' ', text) + for fromx, tox in replacements: + text = ' ' + text + ' ' + text = text.replace(fromx, tox)[1:-1] + + # remove multiple spaces + text = re.sub(' +', ' ', text) + + # concatenate numbers + tmp = text + tokens = text.split() + i = 1 + while i < len(tokens): + if re.match(u'^\d+$', tokens[i]) and \ + re.match(u'\d+$', tokens[i - 1]): + tokens[i - 1] += tokens[i] + del tokens[i] + else: + i += 1 + text = ' '.join(tokens) + + return text + +def load_dataset(name, use_goal=False, context_window_size=15, domains=None, **kwargs) -> DialogDataset: + name, split = split_name(name) + path = os.path.join(DATASETS_PATH, name) + with open(os.path.join(path, f'{split}.json'), 'r') as f: + data = json.load(f, object_pairs_hook=OrderedDict) + dialogues = data['dialogues'] + items = DialogueItems(dialogues) + items = BlacklistItemsWrapper(items, list(build_blacklist(items, domains))) + + def transform(x): + dialogue, items = x + context = [s['text'] for s in items[:-1]] + if context_window_size is not None and context_window_size > 0: + context = context[-context_window_size:] + belief = items[-1]['belief'] + database = items[-1]['database'] + item = DialogDatasetItem(context, + raw_belief=belief, + database=database, + response=items[-1]['delexicalised_text'], + raw_response=items[-1]['text'], + key=dialogue['name']) + if use_goal: + setattr(item, 'goal', dialogue['goal']) + # MultiWOZ evaluation uses booked domains property + if 'booked_domains' in items[-1]: + setattr(item, 'booked_domains', items[-1]['booked_domains']) + setattr(item, 'dialogue_act', items[-1]['dialogue_act']) + setattr(item, 'active_domain', items[-1]['active_domain']) + return item + + dataset = DialogDataset(items, transform=transform, domains=data['domains']) + if os.path.exists(os.path.join(path, 'database.zip')): + dataset.database = AutoDatabase.load(path) + + if os.path.exists(os.path.join(path, 'lexicalizer.zip')): + dataset.lexicalizer = AutoLexicalizer.load(path) + + return dataset + +def format_belief(belief: OrderedDict) -> str: + assert isinstance(belief, OrderedDict) + str_bs = [] + for domain, domain_bs in belief.items(): + domain_bs = ', '.join([f'{slot} = {val}' for slot, val in sorted(domain_bs.items(), key=lambda x: x[0])]) + str_bs.extend([domain, '{' + domain_bs + '}']) + return ' '.join(str_bs) + +class Lexicalizer: + def __init__(self, zipf): + self.path = zipf.filename + + placeholder_re = re.compile(r'\[(\s*[\w_\s]+)\s*\]') + number_re = re.compile(r'.*(\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve)\s$') + time_re = re.compile(r'((?:\d{1,2}[:]\d{2,3})|(?:\d{1,2} (?:am|pm)))', re.IGNORECASE) + + @staticmethod + def ends_with_number(s): + return bool(Lexicalizer.number_re.match(s)) + + @staticmethod + def extend_database_results(database_results, belief): + # Augment database results from the belief state + database_results = OrderedDict(database_results) + if belief is not None: + for i, (domain, (num_results, results)) in enumerate(database_results.items()): + if domain not in belief: + continue + if num_results == 0: + database_results[domain] = (1, [belief[domain]]) + else: + new_results = [] + for r in results: + r = dict(**r) + for k, val in belief[domain].items(): + if k not in r: + r[k] = val + new_results.append(r) + database_results[domain] = (num_results, new_results) + return database_results + + @staticmethod + def extend_empty_database(database_results, belief): + # Augment database results from the belief state + database_results = OrderedDict(database_results) + if belief is not None: + for domain in belief.keys(): + if domain not in database_results.keys(): + if any([len(v) > 0 for v in belief[domain]["semi"].values()] + [len(v) > 0 for v in belief[domain]["book"].values()]): + database_results[domain] = (1, [belief[domain]]) + + return database_results + + def __call__(self, text, database_results, belief=None, context=None): + database_results = Lexicalizer.extend_database_results(database_results, belief) + database_results = Lexicalizer.extend_empty_database(database_results, belief) + result_index = 0 + last_assignment = defaultdict(set) + + def trans(label, span, force=False, loop=100): + nonlocal result_index + nonlocal last_assignment + result_str = None + + for domain, (count, results) in database_results.items(): + if domain in ["hotel", "attraction"] and label == "price": + label = "price range" + + # if count == 0: + # pdb.set_trace() + # # continue + # if label in result['semi']: + # result_str = result['semi'][label] + # elif label is result['book']: + # result_str = result['book'][label] + # else: + if domain == "train" and "arrive by" in results[0]["semi"]: + result = results[-1] + else: + result = results[result_index % len(results)] + # if domain == "train" and label == "id": + # label = "trainID" + if label in result: + result_str = result[label] + if result_str == '?': + result_str = 'unknown' + if label == 'price range' and result_str == 'moderate' and \ + not text[span[1]:].startswith(' price range') and \ + not text[span[1]:].startswith(' in price'): + result_str = 'moderately priced' + elif label in result['book']: + result_str = result['book'][label] + elif label in result['semi']: + result_str = result['semi'][label] + + # if label == 'type': + # pdb.set_trace() + # if text[:span[0]].endswith('no ') or text[:span[0]].endswith('any ') or \ + # text[:span[0]].endswith('some ') or Lexicalizer.ends_with_number(text[:span[0]]): + # if not result_str.endswith('s'): + # result_str += 's' + if label == 'time' and ('[leave at]' in text or '[arrive by]' in text) and \ + belief is not None and 'train' in belief and \ + any([k in belief['train'] for k in ('leave at', 'arrive by')]): + # this is a specific case in which additional [time] slot needs to be lexicalised + # directly from the belief state + # "The earliest train after [time] leaves at ... and arrives by ..." + if 'leave at' in belief['train']: + result_str = belief['train']['leave at'] + else: + result_str = belief['train']['arrive by'] + # elif label == 'time' and 'restaurant' in belief and 'book' in belief['restaurant']: + # result_str = belief['restaurant']['book']['time'] + elif label == 'count': + result_str = str(count) + elif label == 'price' and domain == "train" and "total" in text[:span[0]]: + try: + num_people = int(result['book']['people']) + except: + num_people = 1 + try: + result_str = str(float(result[label].split()[0]) * num_people) + " pounds" + except: + result_str = "" + elif force: + if label == 'time': + if 'leave at' in result or 'arrive by' in result: + if 'arrive' in text and 'arrive by' in result: + result_str = result['arrive by'].lstrip('0') + elif 'leave at' in result: + result_str = result['leave at'].lstrip('0') + elif context is not None and len(context) > 0: + last_utt = context[-1] + mtch = Lexicalizer.time_re.search(last_utt) + if mtch is not None: + result_str = mtch.group(1).lstrip('0') + elif label == 'name': + result_str = "the " + domain + # if result_str == "not mentioned": + # pdb.set_trace() + if result_str is not None: + break + if force and result_str is None: + # for domains with no database or cases with failed database search + # if domain == "hospital": + # if label == 'name': + # result_str = "Addenbrookes hospital" + # elif label == "postcode": + # result_str = "cb20qq" + # elif label == "address": + # result_str = "hills rd , cambridge" + # elif label == "phone": + # result_str = "01223216297" + # else: + if label == 'reference': + result_str = 'YF86GE4J' + elif label == 'phone': + result_str = '01223358966' + elif label == 'postcode': + result_str = 'cb11jg' + elif label == 'agent': + result_str = 'Cambridge Towninfo Centre' + elif label == 'stars': + result_str = '4' + elif label == 'car': + result_str = 'black honda taxi' + elif label == 'address': + result_str = 'Parkside, Cambridge' + elif label == 'name': + result_str = "it" + + if result_str is not None and result_str.lower() in last_assignment[label] and loop > 0: + result_index += 1 + return trans(label, force=force, loop=loop - 1, span=span) + + if result_str is not None: + last_assignment[label].add(result_str.lower()) + return result_str or f'[{label}]' + + text = Lexicalizer.placeholder_re.sub(lambda m: trans(m.group(1), span=m.span()), text) + text = Lexicalizer.placeholder_re.sub(lambda m: trans(m.group(1), force=True, span=m.span()), text) + return text, database_results + + def save(self, path): + shutil.copy(self.path, os.path.join(path, os.path.split(self.path)[-1])) diff --git a/convlab/policy/lava/multiwoz/latent_dialog/base_data_loaders.py b/convlab/policy/lava/multiwoz/latent_dialog/base_data_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..e96372acf9fcdbbf17bf393dc58158751212dacc --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/base_data_loaders.py @@ -0,0 +1,183 @@ +import numpy as np +import logging + + +class BaseDataLoaders(object): + def __init__(self, name): + self.data_size = None + self.indexes = None + self.name = name + + def _shuffle_indexes(self): + np.random.shuffle(self.indexes) + + def _shuffle_batch_indexes(self): + np.random.shuffle(self.batch_indexes) + + def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False): + self.ptr = 0 + self.batch_size = config.batch_size + self.num_batch = self.data_size // config.batch_size + + if verbose: + print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch)) + + if shuffle and not fix_batch: + self._shuffle_indexes() + + self.batch_indexes = [] + for i in range(self.num_batch): + self.batch_indexes.append(self.indexes[i*self.batch_size: (i+1)*self.batch_size]) + + if shuffle and fix_batch: + self._shuffle_batch_indexes() + + if verbose: + print('%s begins with %d batches' % (self.name, self.num_batch)) + + def next_batch(self): + if self.ptr < self.num_batch: + selected_ids = self.batch_indexes[self.ptr] + self.ptr += 1 + return self._prepare_batch(selected_index=selected_ids) + else: + return None + + def _prepare_batch(self, *args, **kwargs): + raise NotImplementedError('Have to override _prepare_batch()') + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + + +class LongDataLoader(object): + """A special efficient data loader for TBPTT. Assume the data contains + N long sequences, each sequence has length k_i + + :ivar batch_size: the size of a minibatch + :ivar backward_size: how many steps in time to do BP + :ivar step_size: how fast we move the window + :ivar ptr: the current idx of batch + :ivar num_batch: the total number of batch + :ivar batch_indexes: a list of list. Each item is the IDs in this batch + :ivar grid_indexes: a list of (b_id, s_id, e_id). b_id is the index of + batch, s_id is the starting time id in that batch and e_id is the ending + time id. + :ivar indexes: a list, the ordered of sequences ID it should go through + :ivar data_size: the number of sequences, N. + :ivar data_lens: a list containing k_i + :ivar prev_alive_size: + :ivar name: the name of the this data loader + """ + logger = logging.getLogger() + + def __init__(self, name): + self.batch_size = 0 + self.backward_size = 0 + self.step_size = 0 + self.ptr = 0 + self.num_batch = None + self.batch_indexes = None # one batch is a dialog + self.grid_indexes = None # grid is the tokenized versiion + self.indexes = None + self.data_lens = None + self.data_size = None + self.name = name + + def _shuffle_batch_indexes(self): + np.random.shuffle(self.batch_indexes) + + def _shuffle_grid_indexes(self): + np.random.shuffle(self.grid_indexes) + + def _prepare_batch(self, cur_grid, prev_grid): + raise NotImplementedError("Have to override prepare batch") + + def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False): + + assert len(self.indexes) == self.data_size and \ + len(self.data_lens) == self.data_size + # make sure backward_size can be divided by step size + assert config.backward_size % config.step_size == 0 + + self.ptr = 0 + self.batch_size = config.batch_size + self.backward_size = config.backward_size + self.step_size = config.step_size + + # create batch indexes + temp_num_batch = self.data_size // config.batch_size + self.batch_indexes = [] + for i in range(temp_num_batch): + self.batch_indexes.append( + self.indexes[i * self.batch_size:(i + 1) * self.batch_size]) + + left_over = self.data_size - temp_num_batch * config.batch_size + if shuffle: + self._shuffle_batch_indexes() + + # create grid indexes + self.grid_indexes = [] + for idx, b_ids in enumerate(self.batch_indexes): + # assume the b_ids are sorted + all_lens = [self.data_lens[i] for i in b_ids] + max_len = self.data_lens[b_ids[0]] + min_len = self.data_lens[b_ids[-1]] + assert np.max(all_lens) == max_len + assert np.min(all_lens) == min_len + num_seg = (max_len - self.backward_size - self.step_size) // self.step_size + cut_start, cut_end = [], [] + if num_seg > 1: + cut_start = list(range(config.step_size, num_seg * config.step_size, config.step_size)) + cut_end = list(range(config.backward_size + config.step_size, + num_seg * config.step_size + config.backward_size, + config.step_size)) + assert cut_end[-1] < max_len + + actual_size = min(max_len, config.backward_size) + temp_end = list(range(2, actual_size, config.step_size)) + temp_start = [0] * len(temp_end) + + cut_start = temp_start + cut_start + cut_end = temp_end + cut_end + + assert len(cut_end) == len(cut_start) + new_grids = [(idx, s_id, e_id) for s_id, e_id in + zip(cut_start, cut_end) if s_id < min_len - 1] + + self.grid_indexes.extend(new_grids) + + # shuffle batch indexes + if shuffle: + self._shuffle_grid_indexes() + + self.num_batch = len(self.grid_indexes) + if verbose: + self.logger.info("%s init with %d batches with %d left over samples" % + (self.name, self.num_batch, left_over)) + + def next_batch(self): + if self.ptr < self.num_batch: + current_grid = self.grid_indexes[self.ptr] + if self.ptr > 0: + prev_grid = self.grid_indexes[self.ptr - 1] + else: + prev_grid = None + self.ptr += 1 + return self._prepare_batch(cur_grid=current_grid, + prev_grid=prev_grid) + else: + return None + + def pad_to(self, max_len, tokens, do_pad=True): + if len(tokens) >= max_len: + return tokens[0:max_len - 1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens diff --git a/convlab/policy/lava/multiwoz/latent_dialog/base_models.py b/convlab/policy/lava/multiwoz/latent_dialog/base_models.py new file mode 100644 index 0000000000000000000000000000000000000000..3aff3014678d9529dbcdae846825e02b7c8b2be8 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/base_models.py @@ -0,0 +1,110 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.autograd import Variable +import numpy as np +from convlab.policy.lava.multiwoz.latent_dialog.utils import INT, FLOAT, LONG, cast_type +import pdb + + +class BaseModel(nn.Module): + def __init__(self, config): + super(BaseModel, self).__init__() + self.use_gpu = config.use_gpu + self.config = config + self.kl_w = 0.0 + + def np2var(self, inputs, dtype): + if inputs is None: + return None + return cast_type(Variable(th.from_numpy(inputs)), + dtype, + self.use_gpu) + + def forward(self, *inputs): + raise NotImplementedError + + def backward(self, loss, batch_cnt): + total_loss = self.valid_loss(loss, batch_cnt) + total_loss.backward() + + def valid_loss(self, loss, batch_cnt=None): + total_loss = 0.0 + for k, l in loss.items(): + if l is not None: + total_loss += l + return total_loss + + def get_optimizer(self, config, verbose=True): + if config.op == 'adam': + if verbose: + print('Use Adam') + return optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=config.init_lr, weight_decay=config.l2_norm) + elif config.op == 'sgd': + print('Use SGD') + return optim.SGD(self.parameters(), lr=config.init_lr, momentum=config.momentum) + elif config.op == 'rmsprop': + print('Use RMSProp') + return optim.RMSprop(self.parameters(), lr=config.init_lr, momentum=config.momentum) + + def get_clf_optimizer(self, config): + params = [] + params.extend(self.gru_attn_encoder.parameters()) + params.extend(self.feat_projecter.parameters()) + params.extend(self.sel_classifier.parameters()) + + if config.fine_tune_op == 'adam': + print('Use Adam') + return optim.Adam(params, lr=config.fine_tune_lr) + elif config.fine_tune_op == 'sgd': + print('Use SGD') + return optim.SGD(params, lr=config.fine_tune_lr, momentum=config.fine_tune_momentum) + elif config.fine_tune_op == 'rmsprop': + print('Use RMSProp') + return optim.RMSprop(params, lr=config.fine_tune_lr, momentum=config.fine_tune_momentum) + + + def model_sel_loss(self, loss, batch_cnt): + return self.valid_loss(loss, batch_cnt) + + + def extract_short_ctx(self, context, context_lens, backward_size=1): + utts = [] + for b_id in range(context.shape[0]): + utts.append(context[b_id, context_lens[b_id]-1]) + return np.array(utts) + + def flatten_context(self, context, context_lens, align_right=False): + utts = [] + temp_lens = [] + for b_id in range(context.shape[0]): + temp = [] + for t_id in range(context_lens[b_id]): + for token in context[b_id, t_id]: + if token != 0: + temp.append(token) + temp_lens.append(len(temp)) + utts.append(temp) + max_temp_len = np.max(temp_lens) + results = np.zeros((context.shape[0], max_temp_len)) + for b_id in range(context.shape[0]): + if align_right: + results[b_id, -temp_lens[b_id]:] = utts[b_id] + else: + results[b_id, 0:temp_lens[b_id]] = utts[b_id] + + return results + +def frange_cycle_linear(n_iter, start=0.0, stop=1.0, n_cycle=4, ratio=0.5): + L = np.ones(n_iter) * stop + period = n_iter/n_cycle + step = (stop-start)/(period*ratio) # linear schedule + + for c in range(n_cycle): + v, i = start, 0 + while v <= stop and (int(i+c*period) < n_iter): + L[int(i+c*period)] = v + v += step + i += 1 + return L diff --git a/convlab/policy/lava/multiwoz/latent_dialog/corpora.py b/convlab/policy/lava/multiwoz/latent_dialog/corpora.py new file mode 100644 index 0000000000000000000000000000000000000000..4d406f394a7d43f44a55d9c3f7030cb6fa1cd103 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/corpora.py @@ -0,0 +1,419 @@ +from __future__ import unicode_literals +import numpy as np +from collections import Counter +from convlab.policy.lava.multiwoz.latent_dialog.utils import Pack, get_tokenize, get_chat_tokenize, missingdict +import json +from nltk.tokenize import WordPunctTokenizer +import logging +from collections import defaultdict +import pdb + +PAD = '<pad>' +UNK = '<unk>' +USR = 'YOU:' +SYS = 'THEM:' +BOD = '<d>' +EOD = '</d>' +BOS = '<s>' +EOS = '<eos>' +SEL = '<selection>' +SEP = "|" +REQ = "<requestable>" +INF = "<informable>" +WILD = "%s" +SPECIAL_TOKENS = [PAD, UNK, USR, SYS, BOS, BOD, EOS, EOD] +STOP_TOKENS = [EOS, SEL] +DECODING_MASKED_TOKENS = [PAD, UNK, USR, SYS, BOD] + +REQ_TOKENS = {} +DOMAIN_REQ_TOKEN = ['restaurant', 'hospital', 'hotel','attraction', 'train', 'police', 'taxi'] +ACTIVE_BS_IDX = [13, 30, 35, 61, 72, 91, 93] #indexes in the BS indicating if domain is active +NO_MATCH_DB_IDX = [-1, 0, -1, 6, 12, 18, -1] # indexes in DB pointer indicating 0 match is found for that domain, -1 mean that domain has no DB +REQ_TOKENS['attraction'] = ["[attraction_address]", "[attraction_name]", "[attraction_phone]", "[attraction_postcode]", "[attraction_reference]", "[attraction_type]"] +REQ_TOKENS['hospital'] = ["[hospital_address]", "[hospital_department]", "[hospital_name]", "[hospital_phone]", "[hospital_postcode]"] #, "[hospital_reference]" +REQ_TOKENS['hotel'] = ["[hotel_address]", "[hotel_name]", "[hotel_phone]", "[hotel_postcode]", "[hotel_reference]", "[hotel_type]"] +REQ_TOKENS['restaurant'] = ["[restaurant_name]", "[restaurant_address]", "[restaurant_phone]", "[restaurant_postcode]", "[restaurant_reference]"] +REQ_TOKENS['train'] = ["[train_id]", "[train_reference]"] +REQ_TOKENS['police'] = ["[police_address]", "[police_phone]", "[police_postcode]"] #"[police_name]", +REQ_TOKENS['taxi'] = ["[taxi_phone]", "[taxi_type]"] + +GENERIC_TOKENS = ["[value_area]", "[value_count]", "[value_day]", "[value_food]", "[value_place]", "[value_price]", "[value_pricerange]", "[value_time]"] + + + +class NormMultiWozCorpus(object): + logger = logging.getLogger() + + def __init__(self, config): + self.bs_size = 94 + self.db_size = 30 + self.bs_types =['b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b'] + self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi'] + self.info_types = ['book', 'fail_book', 'fail_info', 'info', 'reqt'] + self.config = config + self.tokenize = lambda x: x.split() + self.train_corpus, self.val_corpus, self.test_corpus = self._read_file(self.config) + self._extract_vocab() + self._extract_goal_vocab() + self.logger.info('Loading corpus finished.') + + def _read_file(self, config): + train_data = json.load(open(config.train_path)) + valid_data = json.load(open(config.valid_path)) + test_data = json.load(open(config.test_path)) + + train_data = self._process_dialogue(train_data) + valid_data = self._process_dialogue(valid_data) + test_data = self._process_dialogue(test_data) + + return train_data, valid_data, test_data + + def _process_dialogue(self, data): + new_dlgs = [] + all_sent_lens = [] + all_dlg_lens = [] + + for key, raw_dlg in data.items(): + norm_dlg = [Pack(speaker=USR, utt=[BOS, BOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)] + for t_id in range(len(raw_dlg['db'])): + usr_utt = [BOS] + self.tokenize(raw_dlg['usr'][t_id]) + [EOS] + sys_utt = [BOS] + self.tokenize(raw_dlg['sys'][t_id]) + [EOS] + norm_dlg.append(Pack(speaker=USR, utt=usr_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id])) + norm_dlg.append(Pack(speaker=SYS, utt=sys_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id])) + all_sent_lens.extend([len(usr_utt), len(sys_utt)]) + # To stop dialog + norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)) + # if self.config.to_learn == 'usr': + # norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)) + all_dlg_lens.append(len(raw_dlg['db'])) + processed_goal = self._process_goal(raw_dlg['goal']) + new_dlgs.append(Pack(dlg=norm_dlg, goal=processed_goal, key=key)) + + self.logger.info('Max utt len = %d, mean utt len = %.2f' % ( + np.max(all_sent_lens), float(np.mean(all_sent_lens)))) + self.logger.info('Max dlg len = %d, mean dlg len = %.2f' % ( + np.max(all_dlg_lens), float(np.mean(all_dlg_lens)))) + return new_dlgs + + def _extract_vocab(self): + all_words = [] + for dlg in self.train_corpus: + for turn in dlg.dlg: + all_words.extend(turn.utt) + vocab_count = Counter(all_words).most_common() + raw_vocab_size = len(vocab_count) + keep_vocab_size = min(self.config.max_vocab_size, raw_vocab_size) + oov_rate = np.sum([c for t, c in vocab_count[0:keep_vocab_size]]) / float(len(all_words)) + + self.logger.info('cut off at word {} with frequency={},\n'.format(vocab_count[keep_vocab_size - 1][0], + vocab_count[keep_vocab_size - 1][1]) + + 'OOV rate = {:.2f}%'.format(100.0 - oov_rate * 100)) + + vocab_count = vocab_count[0:keep_vocab_size] + self.vocab = SPECIAL_TOKENS + [t for t, cnt in vocab_count if t not in SPECIAL_TOKENS] + self.vocab_dict = {t: idx for idx, t in enumerate(self.vocab)} + self.unk_id = self.vocab_dict[UNK] + self.logger.info("Raw vocab size {} in train set and final vocab size {}".format(raw_vocab_size, len(self.vocab))) + + def _process_goal(self, raw_goal): + res = {} + for domain in self.domains: + all_words = [] + d_goal = raw_goal[domain] + if d_goal: + for info_type in self.info_types: + sv_info = d_goal.get(info_type, dict()) + if info_type == 'reqt' and isinstance(sv_info, list): + all_words.extend([info_type + '|' + item for item in sv_info]) + elif isinstance(sv_info, dict): + all_words.extend([info_type + '|' + k + '|' + str(v) for k, v in sv_info.items()]) + else: + print('Fatal Error!') + exit(-1) + res[domain] = all_words + return res + + def _extract_goal_vocab(self): + self.goal_vocab, self.goal_vocab_dict, self.goal_unk_id = {}, {}, {} + for domain in self.domains: + all_words = [] + for dlg in self.train_corpus: + all_words.extend(dlg.goal[domain]) + vocab_count = Counter(all_words).most_common() + raw_vocab_size = len(vocab_count) + discard_wc = np.sum([c for t, c in vocab_count]) + + self.logger.info('================= domain = {}, \n'.format(domain) + + 'goal vocab size of train set = %d, \n' % (raw_vocab_size,) + + 'cut off at word %s with frequency = %d, \n' % (vocab_count[-1][0], vocab_count[-1][1]) + + 'OOV rate = %.2f' % (1 - float(discard_wc) / len(all_words),)) + + self.goal_vocab[domain] = [UNK] + [g for g, cnt in vocab_count] + self.goal_vocab_dict[domain] = {t: idx for idx, t in enumerate(self.goal_vocab[domain])} + self.goal_unk_id[domain] = self.goal_vocab_dict[domain][UNK] + + def get_corpus(self): + id_train = self._to_id_corpus('Train', self.train_corpus) + id_val = self._to_id_corpus('Valid', self.val_corpus) + id_test = self._to_id_corpus('Test', self.test_corpus) + return id_train, id_val, id_test + + def _to_id_corpus(self, name, data): + results = [] + for dlg in data: + if len(dlg.dlg) < 1: + continue + id_dlg = [] + for turn in dlg.dlg: + id_turn = Pack(utt=self._sent2id(turn.utt), + speaker=turn.speaker, + db=turn.db, bs=turn.bs) + id_dlg.append(id_turn) + id_goal = self._goal2id(dlg.goal) + results.append(Pack(dlg=id_dlg, goal=id_goal, key=dlg.key)) + return results + + def _sent2id(self, sent): + return [self.vocab_dict.get(t, self.unk_id) for t in sent] + + def _goal2id(self, goal): + res = {} + for domain in self.domains: + d_bow = [0.0] * len(self.goal_vocab[domain]) + for word in goal[domain]: + word_id = self.goal_vocab_dict[domain].get(word, self.goal_unk_id[domain]) + d_bow[word_id] += 1.0 + res[domain] = d_bow + return res + + def id2sent(self, id_list): + return [self.vocab[i] for i in id_list] + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + +class NormMultiWozCorpusAE(object): + logger = logging.getLogger() + + def __init__(self, config): + self.bs_size = 94 + self.db_size = 30 + self.bs_types =['b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b'] + self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi'] + self.info_types = ['book', 'fail_book', 'fail_info', 'info', 'reqt'] + self.act_types = ['bye', 'inform', 'nobook', 'nooffer', 'offerbook', 'offerbooked', 'recommend', 'reqmore', 'request', 'select', 'welcome'] + self.act2id = {a:i for i, a in enumerate(self.act_types)} + self.id2act = {i:a for i, a in enumerate(self.act_types)} + self.act_size = len(self.act_types) #domain agnostic act + # self.act_size = len(domain) * len(self.act_types) #domain dependent act + self.config = config + self.tokenize = lambda x: x.split() + self.train_corpus, self.val_corpus, self.test_corpus = self._read_file(self.config) + self._extract_vocab() + self._extract_goal_vocab() + self.logger.info('Loading corpus finished.') + + def _read_file(self, config): + train_data = json.load(open(config.train_path)) + valid_data = json.load(open(config.valid_path)) + test_data = json.load(open(config.test_path)) + dacts = json.load(open(config.dact_path)) + + train_data = self._process_dialogue(train_data, dacts) + valid_data = self._process_dialogue(valid_data, dacts) + test_data = self._process_dialogue(test_data, dacts) + + return train_data, valid_data, test_data + + def _process_dialogue(self, data, dacts): + new_dlgs = [] + all_sent_lens = [] + all_dlg_lens = [] + dact_skip_count = 0 + + for key, raw_dlg in data.items(): + norm_dlg = [Pack(speaker=USR, utt=[BOS, BOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size, act=[0.0]*self.act_size)] + if key.split(".")[0].lower() in dacts: + for t_id in range(len(raw_dlg['db'])): + usr_utt = [BOS] + self.tokenize(raw_dlg['usr'][t_id]) + [EOS] + sys_utt = [BOS] + self.tokenize(raw_dlg['sys'][t_id]) + [EOS] + # sys_act = self._process_multidomain_summary_acts(dacts[key.split(".")[0].lower()][str(t_id)]) + try: + sys_act = self._process_summary_acts(dacts[key.split(".")[0].lower()][str(t_id)]) + except: + sys_act = [float(0)] * self.act_size + norm_dlg.append(Pack(speaker=USR, utt=usr_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id], act=sys_act)) + norm_dlg.append(Pack(speaker=SYS, utt=sys_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id], act=sys_act)) + all_sent_lens.extend([len(usr_utt), len(sys_utt)]) + # To stop dialog + norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size, act=[0.0]*self.act_size)) + # if self.config.to_learn == 'usr': + # norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)) + all_dlg_lens.append(len(raw_dlg['db'])) + processed_goal = self._process_goal(raw_dlg['goal']) + new_dlgs.append(Pack(dlg=norm_dlg, goal=processed_goal, key=key)) + else: + dact_skip_count += 1 + + self.logger.info('{} sessions skipped due to missing dialogue act label'.format(dact_skip_count)) + self.logger.info('Max utt len = %d, mean utt len = %.2f' % ( + np.max(all_sent_lens), float(np.mean(all_sent_lens)))) + self.logger.info('Max dlg len = %d, mean dlg len = %.2f' % ( + np.max(all_dlg_lens), float(np.mean(all_dlg_lens)))) + return new_dlgs + + def _extract_vocab(self): + all_words = [] + for dlg in self.train_corpus: + for turn in dlg.dlg: + all_words.extend(turn.utt) + vocab_count = Counter(all_words).most_common() + raw_vocab_size = len(vocab_count) + keep_vocab_size = min(self.config.max_vocab_size, raw_vocab_size) + oov_rate = np.sum([c for t, c in vocab_count[0:keep_vocab_size]]) / float(len(all_words)) + + self.logger.info('cut off at word {} with frequency={},\n'.format(vocab_count[keep_vocab_size - 1][0], + vocab_count[keep_vocab_size - 1][1]) + + 'OOV rate = {:.2f}%'.format(100.0 - oov_rate * 100)) + + vocab_count = vocab_count[0:keep_vocab_size] + self.vocab = SPECIAL_TOKENS + [t for t, cnt in vocab_count if t not in SPECIAL_TOKENS] + self.vocab_dict = {t: idx for idx, t in enumerate(self.vocab)} + self.unk_id = self.vocab_dict[UNK] + self.logger.info("Raw vocab size {} in train set and final vocab size {}".format(raw_vocab_size, len(self.vocab))) + + def _process_goal(self, raw_goal): + res = {} + for domain in self.domains: + all_words = [] + d_goal = raw_goal[domain] + if d_goal: + for info_type in self.info_types: + sv_info = d_goal.get(info_type, dict()) + if info_type == 'reqt' and isinstance(sv_info, list): + all_words.extend([info_type + '|' + item for item in sv_info]) + elif isinstance(sv_info, dict): + all_words.extend([info_type + '|' + k + '|' + str(v) for k, v in sv_info.items()]) + else: + print('Fatal Error!') + exit(-1) + res[domain] = all_words + return res + + def _process_multidomain_summary_acts(self, dact): + """ + process dialogue action dictionary into binary vector representation + each domain has its own vector, and final output is the flattened respresentation of each domain's action + """ + res = {} + # dact = {domain:{action:[slot]}, domain:{action:[slot]}} + for domain in self.domains: + res[domain] = np.zeros(len(self.act_types)) + if domain in dact.keys(): + for i in range(len(self.act_types)): + if self.act_types[i] in dact[domain].keys(): + res[domain][i] = 1 + + + # multiwoz dact = {domain-act:[[slot, value], [slot, value]]} + # for domain in self.domains: + # res[domain] = np.zeros(len(self.act_types)) + # for k in dact.keys(): + # d = k.split("-")[0].lower() + # a = k.split("-")[1].lower() + + # res[d][self.act2id[a]] = 1 + + flat_res = [act for domain in sorted(self.domains) for act in res[domain]] + return flat_res + + def _process_summary_acts(self, dact): + """ + process dialogue action dictionary into binary vector representation, ignoring domain information + """ + res = np.zeros(len(self.act_types)) + # damd dact = {domain:{action:[slot]}, domain:{action:[slot]}} + for domain in self.domains: + if domain in dact.keys(): + for i in range(len(self.act_types)): + if self.act_types[i] in dact[domain].keys(): + res[i] = 1 + + # multiwoz dact = {domain-act:[[slot, value], [slot, value]]} + # for k in dact.keys(): + # # d = k.split("-")[0].lower() + # a = k.split("-")[1].lower() + + # res[self.act2id[a]] = 1 + + return list(res) + + def _extract_goal_vocab(self): + self.goal_vocab, self.goal_vocab_dict, self.goal_unk_id = {}, {}, {} + for domain in self.domains: + all_words = [] + for dlg in self.train_corpus: + all_words.extend(dlg.goal[domain]) + vocab_count = Counter(all_words).most_common() + raw_vocab_size = len(vocab_count) + discard_wc = np.sum([c for t, c in vocab_count]) + + self.logger.info('================= domain = {}, \n'.format(domain) + + 'goal vocab size of train set = %d, \n' % (raw_vocab_size,) + + 'cut off at word %s with frequency = %d, \n' % (vocab_count[-1][0], vocab_count[-1][1]) + + 'OOV rate = %.2f' % (1 - float(discard_wc) / len(all_words),)) + + self.goal_vocab[domain] = [UNK] + [g for g, cnt in vocab_count] + self.goal_vocab_dict[domain] = {t: idx for idx, t in enumerate(self.goal_vocab[domain])} + self.goal_unk_id[domain] = self.goal_vocab_dict[domain][UNK] + + def get_corpus(self): + id_train = self._to_id_corpus('Train', self.train_corpus) + id_val = self._to_id_corpus('Valid', self.val_corpus) + id_test = self._to_id_corpus('Test', self.test_corpus) + return id_train, id_val, id_test + + def _to_id_corpus(self, name, data): + results = [] + for dlg in data: + if len(dlg.dlg) < 1: + continue + id_dlg = [] + for turn in dlg.dlg: + id_turn = Pack(utt=self._sent2id(turn.utt), + speaker=turn.speaker, + db=turn.db, bs=turn.bs, act=turn.act) + id_dlg.append(id_turn) + id_goal = self._goal2id(dlg.goal) + results.append(Pack(dlg=id_dlg, goal=id_goal, key=dlg.key)) + return results + + def _sent2id(self, sent): + return [self.vocab_dict.get(t, self.unk_id) for t in sent] + + def _goal2id(self, goal): + res = {} + for domain in self.domains: + d_bow = [0.0] * len(self.goal_vocab[domain]) + for word in goal[domain]: + word_id = self.goal_vocab_dict[domain].get(word, self.goal_unk_id[domain]) + d_bow[word_id] += 1.0 + res[domain] = d_bow + return res + + def id2sent(self, id_list): + return [self.vocab[i] for i in id_list] + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + diff --git a/convlab/policy/lava/multiwoz/latent_dialog/criterions.py b/convlab/policy/lava/multiwoz/latent_dialog/criterions.py new file mode 100644 index 0000000000000000000000000000000000000000..0bba8db136175c2ccbefea67c99c2a53771ba6b5 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/criterions.py @@ -0,0 +1,200 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss +import numpy as np +from convlab.policy.lava.multiwoz.latent_dialog import domain +from convlab.policy.lava.multiwoz.latent_dialog.utils import LONG + + +class NLLEntropy(_Loss): + def __init__(self, padding_idx, avg_type): + super(NLLEntropy, self).__init__() + self.padding_idx = padding_idx + self.avg_type = avg_type + + def forward(self, net_output, labels): + batch_size = net_output.size(0) + pred = net_output.view(-1, net_output.size(-1)) + target = labels.view(-1) + + if self.avg_type is None: + loss = F.nll_loss(pred, target, size_average=False, ignore_index=self.padding_idx) + elif self.avg_type == 'seq': + loss = F.nll_loss(pred, target, size_average=False, ignore_index=self.padding_idx) + loss = loss / batch_size + elif self.avg_type == 'real_word': + loss = F.nll_loss(pred, target, ignore_index=self.padding_idx, reduce=False) + loss = loss.view(-1, net_output.size(1)) + loss = th.sum(loss, dim=1) + word_cnt = th.sum(th.sign(labels), dim=1).float() + loss = loss / word_cnt + loss = th.mean(loss) + elif self.avg_type == 'word': + loss = F.nll_loss(pred, target, reduction='mean', ignore_index=self.padding_idx) + else: + raise ValueError('Unknown average type') + + return loss + +class WeightedNLLEntropy(_Loss): + def __init__(self, padding_idx, avg_type, weight): + super(WeightedNLLEntropy, self).__init__() + self.padding_idx = padding_idx + self.avg_type = avg_type + self.weight = weight + + def forward(self, net_output, labels): + batch_size = net_output.size(0) + pred = net_output.view(-1, net_output.size(-1)) + target = labels.view(-1) + if self.avg_type == 'slot': + loss = F.nll_loss(pred, target, weight = self.weight, reduction='mean', ignore_index=self.padding_idx) + + return loss + +class NLLEntropy4CLF(_Loss): + def __init__(self, dictionary, bad_tokens=['<disconnect>', '<disagree>'], reduction='elementwise_mean'): + super(NLLEntropy4CLF, self).__init__() + w = th.Tensor(len(dictionary)).fill_(1) + for token in bad_tokens: + w[dictionary[token]] = 0.0 + self.crit = nn.CrossEntropyLoss(w, reduction=reduction) + + def forward(self, preds, labels): + # preds: (batch_size, outcome_len, outcome_vocab_size) + # labels: (batch_size, outcome_len) + preds = preds.view(-1, preds.size(-1)) + labels = labels.view(-1) + return self.crit(preds, labels) + + +class CombinedNLLEntropy4CLF(_Loss): + def __init__(self, dictionary, corpus, np2var, bad_tokens=['<disconnect>', '<disagree>']): + super(CombinedNLLEntropy4CLF, self).__init__() + self.dictionary = dictionary + self.domain = domain.get_domain('object_division') + self.corpus = corpus + self.np2var = np2var + self.bad_tokens = bad_tokens + + def forward(self, preds, goals_id, outcomes_id): + # preds: (batch_size, outcome_len, outcome_vocab_size) + # goals_id: list of list, id, batch_size*goal_len + # outcomes_id: list of list, id, batch_size*outcome_len + batch_size = len(goals_id) + losses = [] + for bth in range(batch_size): + pred = preds[bth] # (outcome_len, outcome_vocab_size) + goal = goals_id[bth] # list, id, len=goal_len + goal_str = self.corpus.id2goal(goal) # list, str, len=goal_len + outcome = outcomes_id[bth] # list, id, len=outcome_len + outcome_str = self.corpus.id2outcome(outcome) # list, str, len=outcome_len + + if outcome_str[0] in self.bad_tokens: + continue + + # get all the possible choices + choices = self.domain.generate_choices(goal_str) + sel_outs = [pred[i] for i in range(pred.size(0))] # outcome_len*(outcome_vocab_size, ) + + choices_logits = [] # outcome_len*(option_amount, 1) + for i in range(self.domain.selection_length()): + idxs = np.array([self.dictionary[c[i]] for c in choices]) + idxs_var = self.np2var(idxs, LONG) # (option_amount, ) + choices_logits.append(th.gather(sel_outs[i], 0, idxs_var).unsqueeze(1)) + + choice_logit = th.sum(th.cat(choices_logits, 1), 1, keepdim=False) # (option_amount, ) + choice_logit = choice_logit.sub(choice_logit.max().item()) # (option_amount, ) + prob = F.softmax(choice_logit, dim=0) # (option_amount, ) + + label = choices.index(outcome_str) + target_prob = prob[label] + losses.append(-th.log(target_prob)) + return sum(losses) / float(len(losses)) + + +class CatKLLoss(_Loss): + def __init__(self): + super(CatKLLoss, self).__init__() + + def forward(self, log_qy, log_py, batch_size=None, unit_average=False): + """ + qy * log(q(y)/p(y)) + """ + qy = th.exp(log_qy) + y_kl = th.sum(qy * (log_qy - log_py), dim=1) + if unit_average: + return th.mean(y_kl) + else: + return th.sum(y_kl)/batch_size + + +class Entropy(_Loss): + def __init__(self): + super(Entropy, self).__init__() + + def forward(self, log_qy, batch_size=None, unit_average=False): + """ + -qy log(qy) + """ + if log_qy.dim() > 2: + log_qy = log_qy.squeeze() + qy = th.exp(log_qy) + h_q = th.sum(-1 * log_qy * qy, dim=1) + if unit_average: + return th.mean(h_q) + else: + return th.sum(h_q) / batch_size + + +class GaussianEntropy(_Loss): + def __init__(self): + super(GaussianEntropy, self).__init__() + + def forward(self, mu, logvar): + """ + 0.5 (log(mu*var)) + 0.5 + """ + std = th.exp(0.5 * logvar) + var = th.square(std) + h_q = 0.5 * (th.log(2 * math.pi * var)) + 0.5 + + return th.mean(h_q) + + +class BinaryNLLEntropy(_Loss): + + def __init__(self, size_average=True): + super(BinaryNLLEntropy, self).__init__() + self.size_average = size_average + + def forward(self, net_output, label_output): + """ + :param net_output: batch_size x + :param labels: + :return: + """ + batch_size = net_output.size(0) + loss = F.binary_cross_entropy_with_logits(net_output, label_output, size_average=self.size_average) + if self.size_average is False: + loss /= batch_size + return loss + + +class NormKLLoss(_Loss): + def __init__(self, unit_average=False): + super(NormKLLoss, self).__init__() + self.unit_average = unit_average + + def forward(self, recog_mu, recog_logvar, prior_mu, prior_logvar): + # find the KL divergence between two Gaussian distribution + loss = 1.0 + (recog_logvar - prior_logvar) + loss -= th.div(th.pow(prior_mu - recog_mu, 2), th.exp(prior_logvar)) + loss -= th.div(th.exp(recog_logvar), th.exp(prior_logvar)) + if self.unit_average: + kl_loss = -0.5 * th.mean(loss, dim=1) + else: + kl_loss = -0.5 * th.sum(loss, dim=1) + avg_kl_loss = th.mean(kl_loss) + return avg_kl_loss diff --git a/convlab/policy/lava/multiwoz/latent_dialog/data_loaders.py b/convlab/policy/lava/multiwoz/latent_dialog/data_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc252422ee4a14e18266463e3e598f6c9100047 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/data_loaders.py @@ -0,0 +1,278 @@ +import numpy as np +import pdb +from convlab.policy.lava.multiwoz.latent_dialog.utils import Pack +from convlab.policy.lava.multiwoz.latent_dialog.base_data_loaders import BaseDataLoaders, LongDataLoader +from convlab.policy.lava.multiwoz.latent_dialog.corpora import USR, SYS +import json + + +class BeliefDbDataLoaders(BaseDataLoaders): + def __init__(self, name, data, config): + super(BeliefDbDataLoaders, self).__init__(name) + self.max_utt_len = config.max_utt_len + self.data, self.indexes, self.batch_indexes = self.flatten_dialog(data, config.backward_size) + self.data_size = len(self.data) + self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi'] + + def flatten_dialog(self, data, backward_size): + results = [] + indexes = [] + batch_indexes = [] + resp_set = set() + for dlg in data: + goal = dlg.goal + key = dlg.key + batch_index = [] + for i in range(1, len(dlg.dlg)): + if dlg.dlg[i].speaker == USR: + continue + e_idx = i + s_idx = max(0, e_idx - backward_size) + response = dlg.dlg[i].copy() + response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False) + resp_set.add(json.dumps(response.utt)) + context = [] + for turn in dlg.dlg[s_idx: e_idx]: + turn['utt'] = self.pad_to(self.max_utt_len, turn.utt, do_pad=False) + context.append(turn) + results.append(Pack(context=context, response=response, goal=goal, key=key)) + indexes.append(len(indexes)) + batch_index.append(indexes[-1]) + if len(batch_index) > 0: + batch_indexes.append(batch_index) + print("Unique resp {}".format(len(resp_set))) + return results, indexes, batch_indexes + + def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False): + self.ptr = 0 + if fix_batch: + self.batch_size = None + self.num_batch = len(self.batch_indexes) + else: + self.batch_size = config.batch_size + self.num_batch = self.data_size // config.batch_size + self.batch_indexes = [] + for i in range(self.num_batch): + self.batch_indexes.append(self.indexes[i * self.batch_size: (i + 1) * self.batch_size]) + if verbose: + print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch)) + if shuffle: + if fix_batch: + self._shuffle_batch_indexes() + else: + self._shuffle_indexes() + + if verbose: + print('%s begins with %d batches' % (self.name, self.num_batch)) + + def _prepare_batch(self, selected_index): + rows = [self.data[idx] for idx in selected_index] + + ctx_utts, ctx_lens = [], [] + out_utts, out_lens = [], [] + + out_bs, out_db = [] , [] + goals, goal_lens = [], [[] for _ in range(len(self.domains))] + keys = [] + + for row in rows: + in_row, out_row, goal_row = row.context, row.response, row.goal + + # source context + keys.append(row.key) + batch_ctx = [] + for turn in in_row: + batch_ctx.append(self.pad_to(self.max_utt_len, turn.utt, do_pad=True)) + ctx_utts.append(batch_ctx) + ctx_lens.append(len(batch_ctx)) + + # target response + out_utt = [t for idx, t in enumerate(out_row.utt)] + out_utts.append(out_utt) + out_lens.append(len(out_utt)) + + out_bs.append(out_row.bs) + out_db.append(out_row.db) + + # goal + goals.append(goal_row) + for i, d in enumerate(self.domains): + goal_lens[i].append(len(goal_row[d])) + + batch_size = len(ctx_lens) + vec_ctx_lens = np.array(ctx_lens) # (batch_size, ), number of turns + max_ctx_len = np.max(vec_ctx_lens) + vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32) + vec_out_bs = np.array(out_bs) # (batch_size, 94) + vec_out_db = np.array(out_db) # (batch_size, 30) + vec_out_lens = np.array(out_lens) # (batch_size, ), number of tokens + max_out_len = np.max(vec_out_lens) + vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32) + + max_goal_lens, min_goal_lens = [max(ls) for ls in goal_lens], [min(ls) for ls in goal_lens] + if max_goal_lens != min_goal_lens: + print('Fatal Error!') + exit(-1) + self.goal_lens = max_goal_lens + vec_goals_list = [np.zeros((batch_size, l), dtype=np.float32) for l in self.goal_lens] + + for b_id in range(batch_size): + vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id] + vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id] + for i, d in enumerate(self.domains): + vec_goals_list[i][b_id, :] = goals[b_id][d] + + return Pack(context_lens=vec_ctx_lens, # (batch_size, ) + contexts=vec_ctx_utts, # (batch_size, max_ctx_len, max_utt_len) + output_lens=vec_out_lens, # (batch_size, ) + outputs=vec_out_utts, # (batch_size, max_out_len) + bs=vec_out_bs, # (batch_size, 94) + db=vec_out_db, # (batch_size, 30) + goals_list=vec_goals_list, # 7*(batch_size, bow_len), bow_len differs w.r.t. domain + keys=keys) + +class BeliefDbDataLoadersAE(BaseDataLoaders): + def __init__(self, name, data, config): + super(BeliefDbDataLoadersAE, self).__init__(name) + self.max_utt_len = config.max_utt_len + self.data, self.indexes, self.batch_indexes = self.flatten_dialog(data, config.backward_size) + self.data_size = len(self.data) + self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi'] + self.act_types = ['bye', 'inform', 'nobook', 'nooffer', 'offerbook', 'offerbooked', 'recommend', 'reqmore', 'request', 'select', 'welcome'] + if "ae_zero_pad" in config.keys(): + self.zero_pad = config.ae_zero_pad + else: + self.zero_pad = False + + def flatten_dialog(self, data, backward_size): + results = [] + indexes = [] + batch_indexes = [] + resp_set = set() + for dlg in data: + goal = dlg.goal + key = dlg.key + batch_index = [] + for i in range(1, len(dlg.dlg)): + if dlg.dlg[i].speaker == USR: + continue + e_idx = i + s_idx = max(0, e_idx - backward_size) + response = dlg.dlg[i].copy() + response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False) + resp_set.add(json.dumps(response.utt)) + context = [] + for turn in dlg.dlg[s_idx: e_idx]: + turn['utt'] = self.pad_to(self.max_utt_len, turn.utt, do_pad=False) + context.append(turn) + results.append(Pack(context=context, response=response, goal=goal, key=key)) + indexes.append(len(indexes)) + batch_index.append(indexes[-1]) + if len(batch_index) > 0: + batch_indexes.append(batch_index) + print("Unique resp {}".format(len(resp_set))) + return results, indexes, batch_indexes + + def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False): + self.ptr = 0 + if fix_batch: + self.batch_size = None + self.num_batch = len(self.batch_indexes) + else: + self.batch_size = config.batch_size + self.num_batch = self.data_size // config.batch_size + self.batch_indexes = [] + for i in range(self.num_batch): + self.batch_indexes.append(self.indexes[i * self.batch_size: (i + 1) * self.batch_size]) + if verbose: + print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch)) + if shuffle: + if fix_batch: + self._shuffle_batch_indexes() + else: + self._shuffle_indexes() + + if verbose: + print('%s begins with %d batches' % (self.name, self.num_batch)) + + def _prepare_batch(self, selected_index): + rows = [self.data[idx] for idx in selected_index] + + ctx_utts, ctx_lens = [], [] + out_utts, out_lens = [], [] + out_act = [] + out_bs, out_db = [] , [] + goals, goal_lens = [], [[] for _ in range(len(self.domains))] + keys = [] + + for row in rows: + in_row, out_row, goal_row = row.context, row.response, row.goal + + # source context + keys.append(row.key) + + # batch_ctx = [] + # for turn in in_row: + # batch_ctx.append(self.pad_to(self.max_utt_len, turn.utt, do_pad=True)) + # ctx_utts.append(batch_ctx) + # ctx_lens.append(len(batch_ctx)) + + # for AE, input = output + batch_ctx = [] + batch_ctx = self.pad_to(self.max_utt_len, out_row.utt, do_pad=True) + # batch_ctx = [t for idx, t in enumerate(out_row.utt)] + ctx_utts.append(batch_ctx) + ctx_lens.append(len(batch_ctx)) + + # target response + out_utt = [t for idx, t in enumerate(out_row.utt)] + out_utts.append(out_utt) + out_lens.append(len(out_utt)) + + if not self.zero_pad: + out_bs.append(out_row.bs) + out_db.append(out_row.db) + else: + out_bs.append([0] * 94) + out_db.append([0] * 30) + out_act.append(out_row.act) + + # goal + goals.append(goal_row) + for i, d in enumerate(self.domains): + goal_lens[i].append(len(goal_row[d])) + + batch_size = len(ctx_lens) + vec_ctx_lens = np.array(ctx_lens) # (batch_size, ), number of turns + max_ctx_len = np.max(vec_ctx_lens) + vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32) + vec_out_bs = np.array(out_bs) # (batch_size, 94) + vec_out_db = np.array(out_db) # (batch_size, 30) + vec_out_act = np.array(out_act) # (batch_size, 11) + vec_out_lens = np.array(out_lens) # (batch_size, ), number of tokens + max_out_len = np.max(vec_out_lens) + vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32) + + max_goal_lens, min_goal_lens = [max(ls) for ls in goal_lens], [min(ls) for ls in goal_lens] + if max_goal_lens != min_goal_lens: + print('Fatal Error!') + exit(-1) + self.goal_lens = max_goal_lens + vec_goals_list = [np.zeros((batch_size, l), dtype=np.float32) for l in self.goal_lens] + + for b_id in range(batch_size): + vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id] + vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id] + for i, d in enumerate(self.domains): + vec_goals_list[i][b_id, :] = goals[b_id][d] + + return Pack(context_lens=vec_ctx_lens, # (batch_size, ) + contexts=vec_ctx_utts, # (batch_size, max_ctx_len, max_utt_len) + output_lens=vec_out_lens, # (batch_size, ) + outputs=vec_out_utts, # (batch_size, max_out_len) + bs=vec_out_bs, # (batch_size, 94) + db=vec_out_db, # (batch_size, 30) + act=vec_out_act, #(batch_size, 11) + goals_list=vec_goals_list, # 7*(batch_size, bow_len), bow_len differs w.r.t. domain + keys=keys) + diff --git a/convlab/policy/lava/multiwoz/latent_dialog/dialog_task.py b/convlab/policy/lava/multiwoz/latent_dialog/dialog_task.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b73e86060b2573638529e6dc16866e50445eb2 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/dialog_task.py @@ -0,0 +1,129 @@ +from convlab.policy.lava.multiwoz.latent_dialog.metric import MetricsContainer +from convlab.policy.lava.multiwoz.latent_dialog.corpora import EOD, EOS +from convlab.policy.lava.multiwoz.latent_dialog import evaluators + + +class Dialog(object): + """Dialogue runner.""" + def __init__(self, agents, args): + assert len(agents) == 2 + self.agents = agents + self.system, self.user = agents + self.args = args + self.metrics = MetricsContainer() + self.dlg_evaluator = evaluators.MultiWozEvaluator('SYS_WOZ') + self._register_metrics() + + def _register_metrics(self): + """Registers valuable metrics.""" + self.metrics.register_average('dialog_len') + self.metrics.register_average('sent_len') + self.metrics.register_average('reward') + self.metrics.register_time('time') + + def _is_eod(self, out): + return len(out) == 2 and out[0] == EOD and out[1] == EOS + + def _eval _dialog(self, conv, g_key, goal): + generated_dialog = dict() + generated_dialog[g_key] = {'goal': goal, 'log': list()} + for t_id, (name, utt) in enumerate(conv): + # assert utt[-1] == EOS, utt + if t_id % 2 == 0: + assert name == 'Baozi' + utt = ' '.join(utt[:-1]) + if utt == EOD: + continue + generated_dialog[g_key]['log'].append({'text': utt}) + report, success_r, match_r = self.dlg_evaluator.evaluateModel(generated_dialog, mode='rollout') + return success_r + match_r + + def show_metrics(self): + return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()]) + + def run(self, g_key, goal): + """Runs one instance of the dialogue.""" + # initialize agents by feeding in the goal + # initialize BOD utterance for each agent + for agent in self.agents: + agent.feed_goal(goal) + agent.bod_init() + + # role assignment + reader, writer = self.system, self.user + begin_name = writer.name + print('begin_name = {}'.format(begin_name)) + + conv = [] + # reset metrics + self.metrics.reset() + nturn = 0 + while True: + nturn += 1 + # produce an utterance + out_words = writer.write() # out: list of word, str, len = max_words + print('\t{} out_words = {}'.format(writer.name, ' '.join(out_words))) + + self.metrics.record('sent_len', len(out_words)) + # self.metrics.record('%s_unique' % writer.name, out_words) + + # append the utterance to the conversation + conv.append((writer.name, out_words)) + # make the other agent to read it + reader.read(out_words) + # check if the end of the conversation was generated + if self._is_eod(out_words): + break + + if self.args.max_nego_turn > 0 and nturn >= self.args.max_nego_turn: + # return conv, 0 + break + + writer, reader = reader, writer + + # evaluate dialog and produce success + reward = self._eval_dialog(conv, g_key, goal) + print('Reward = {}'.format(reward)) + # perform update + self.system.update(reward) + self.metrics.record('time') + self.metrics.record('dialog_len', len(conv)) + self.metrics.record('reward', int(reward)) + + print('='*50) + print(self.show_metrics()) + print('='*50) + return conv, reward + + +class DialogEval(Dialog): + def run(self, g_key, goal): + """Runs one instance of the dialogue.""" + # initialize agents by feeding in the goal + # initialize BOD utterance for each agent + for agent in self.agents: + agent.feed_goal(goal) + agent.bod_init() + + # role assignment + reader, writer = self.system, self.user + conv = [] + nturn = 0 + while True: + nturn += 1 + # produce an utterance + out_words = writer.write() # out: list of word, str, len = max_words + conv.append((writer.name, out_words)) + # make the other agent to read it + reader.read(out_words) + # check if the end of the conversation was generated + if self._is_eod(out_words): + break + + writer, reader = reader, writer + if self.args.max_nego_turn > 0 and nturn >= self.args.max_nego_turn: + return conv, 0 + + # evaluate dialog and produce success + reward = self._eval_dialog(conv, g_key, goal) + return conv, reward diff --git a/convlab/policy/lava/multiwoz/latent_dialog/domain.py b/convlab/policy/lava/multiwoz/latent_dialog/domain.py new file mode 100644 index 0000000000000000000000000000000000000000..43d3ff99bcf99042bc65ec3f3fc0da96734b391b --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/domain.py @@ -0,0 +1,124 @@ +import re +import random +import json + + +def get_domain(name): + if name == 'object_division': + return ObjectDivisionDomain() + raise() + + +class ObjectDivisionDomain(object): + def __init__(self): + self.item_pattern = re.compile('^item([0-9])=([0-9\-])+$') + + def input_length(self): + return 3 + + def selection_length(self): + return 6 + + def generate_choices(self, inpt): + cnts, _ = self.parse_context(inpt) + + def gen(cnts, idx=0, choice=[]): + if idx >= len(cnts): + left_choice = ['item%d=%d' % (i, c) for i, c in enumerate(choice)] + right_choice = ['item%d=%d' % (i, n - c) for i, (n, c) in enumerate(zip(cnts, choice))] + return [left_choice + right_choice] + choices = [] + for c in range(cnts[idx] + 1): + choice.append(c) + choices += gen(cnts, idx + 1, choice) + choice.pop() + return choices + choices = gen(cnts) + choices.append(['<no_agreement>'] * self.selection_length()) + choices.append(['<disconnect>'] * self.selection_length()) + return choices + + def parse_context(self, ctx): + cnts = [int(n) for n in ctx[0::2]] + vals = [int(v) for v in ctx[1::2]] + return cnts, vals + + def _to_int(self, x): + try: + return int(x) + except: + return 0 + + def score_choices(self, choices, ctxs): + assert len(choices) == len(ctxs) + # print('choices = {}'.format(choices)) + # print('ctxs = {}'.format(ctxs)) + cnts = [int(x) for x in ctxs[0][0::2]] + agree, scores = True, [0 for _ in range(len(ctxs))] + for i, n in enumerate(cnts): + for agent_id, (choice, ctx) in enumerate(zip(choices, ctxs)): + # taken = self._to_int(choice[i+3][-1]) + taken = self._to_int(choice[i][-1]) + n -= taken + scores[agent_id] += int(ctx[2 * i + 1]) * taken + agree = agree and (n == 0) + return agree, scores + + +class ContextGenerator(object): + """Dialogue context generator. Generates contexes from the file.""" + def __init__(self, context_file): + self.ctxs = [] + with open(context_file, 'r') as f: + ctx_pair = [] + for line in f: + ctx = line.strip().split() + ctx_pair.append(ctx) + if len(ctx_pair) == 2: + self.ctxs.append(ctx_pair) + ctx_pair = [] + + def sample(self): + return random.choice(self.ctxs) + + def iter(self, nepoch=1): + for e in range(nepoch): + random.shuffle(self.ctxs) + for ctx in self.ctxs: + yield ctx + + def total_size(self, nepoch): + return nepoch*len(self.ctxs) + + +class ContextGeneratorEval(object): + """Dialogue context generator. Generates contexes from the file.""" + def __init__(self, context_file): + self.ctxs = [] + with open(context_file, 'r') as f: + ctx_pair = [] + for line in f: + ctx = line.strip().split() + ctx_pair.append(ctx) + if len(ctx_pair) == 2: + self.ctxs.append(ctx_pair) + ctx_pair = [] + + +class TaskGoalGenerator(object): + def __init__(self, goal_file): + self.goals = [] + data = json.load(open(goal_file)) + for key, raw_dlg in data.items(): + self.goals.append((key, raw_dlg['goal'])) + + def sample(self): + return random.choice(self.goals) + + def iter(self, nepoch=1): + for e in range(nepoch): + random.shuffle(self.goals) + for goal in self.goals: + yield goal + + diff --git a/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/__init__.py b/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6caafab6a74e286872551db1a2edcc848eb00fd0 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- +# Author: Tiancheng Zhao +# Date: 9/15/18 diff --git a/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/base_modules.py b/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/base_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bc7d98601596f147a25aa47332483f5502eaf9d3 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/base_modules.py @@ -0,0 +1,68 @@ +import torch as th +import torch.nn as nn +import numpy as np +from torch.nn.modules.module import _addindent + +def summary(model, show_weights=True, show_parameters=True): + """ + Summarizes torch model by showing trainable parameters and weights. + """ + tmpstr = model.__class__.__name__ + ' (\n' + total_params = 0 + for key, module in model._modules.items(): + # if it contains layers let call it recursively to get params + # and weights + if type(module) in [ + th.nn.modules.container.Container, + th.nn.modules.container.Sequential + ]: + modstr = summary(module) + else: + modstr = module.__repr__() + modstr = _addindent(modstr, 2) + + params = sum([np.prod(p.size()) for p in module.parameters()]) + weights = tuple([tuple(p.size()) for p in module.parameters()]) + total_params += params + + tmpstr += ' (' + key + '): ' + modstr + if show_weights: + tmpstr += ', weights={}'.format(weights) + if show_parameters: + tmpstr += ', parameters={}'.format(params) + tmpstr += '\n' + + tmpstr = tmpstr + ') Total Parameters={}'.format(total_params) + return tmpstr + + +class BaseRNN(nn.Module): + KEY_ATTN_SCORE = 'attention_score' + KEY_SEQUENCE = 'sequence' + + def __init__(self, input_dropout_p, rnn_cell, + input_size, hidden_size, num_layers, + output_dropout_p, bidirectional): + super(BaseRNN, self).__init__() + self.input_dropout = nn.Dropout(p=input_dropout_p) + if rnn_cell.lower() == 'lstm': + self.rnn_cell = nn.LSTM + elif rnn_cell.lower() == 'gru': + self.rnn_cell = nn.GRU + else: + raise ValueError('Unsupported RNN Cell Type: {0}'.format(rnn_cell)) + self.rnn = self.rnn_cell(input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=True, + dropout=output_dropout_p, + bidirectional=bidirectional) + + # TODO Trick for initializing LSTM gate parameters + if rnn_cell.lower() == 'lstm': + for names in self.rnn._all_weights: + for name in filter(lambda n: 'bias' in n, names): + bias = getattr(self.rnn, name) + n = bias.size(0) + start, end = n // 4, n // 2 + bias.data[start:end].fill_(1.) diff --git a/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/classifier.py b/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..96d39e2273af130da10903f0cd4588f69146c5da --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/classifier.py @@ -0,0 +1,103 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from convlab.policy.lava.multiwoz.latent_dialog.enc2dec.base_modules import BaseRNN + + +class EncoderGRUATTN(BaseRNN): + def __init__(self, input_dropout_p, rnn_cell, input_size, hidden_size, num_layers, output_dropout_p, bidirectional, variable_lengths): + super(EncoderGRUATTN, self).__init__(input_dropout_p=input_dropout_p, + rnn_cell=rnn_cell, + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_dropout_p=output_dropout_p, + bidirectional=bidirectional) + self.variable_lengths = variable_lengths + self.nhid_attn = hidden_size + self.output_size = hidden_size*2 if bidirectional else hidden_size + + # attention to combine selection hidden states + self.attn = nn.Sequential( + nn.Linear(2 * hidden_size, hidden_size), + nn.Tanh(), + nn.Linear(hidden_size, 1) + ) + + def forward(self, residual_var, input_var, turn_feat, mask=None, init_state=None, input_lengths=None): + # residual_var: (batch_size, max_dlg_len, 2*utt_cell_size) + # input_var: (batch_size, max_dlg_len, dlg_cell_size) + + # TODO switch of mask + # mask = None + + require_embed = True + if require_embed: + # input_cat = th.cat([input_var, residual_var], 2) # (batch_size, max_dlg_len, dlg_cell_size+2*utt_cell_size) + input_cat = th.cat([input_var, residual_var, turn_feat], 2) # (batch_size, max_dlg_len, dlg_cell_size+2*utt_cell_size) + else: + # input_cat = th.cat([input_var], 2) + input_cat = th.cat([input_var, turn_feat], 2) + if mask is not None: + input_mask = mask.view(input_cat.size(0), input_cat.size(1), 1) # (batch_size, max_dlg_len*max_utt_len, 1) + input_cat = th.mul(input_cat, input_mask) + embedded = self.input_dropout(input_cat) + + require_rnn = True + if require_rnn: + if init_state is not None: + h, _ = self.rnn(embedded, init_state) + else: + h, _ = self.rnn(embedded) # (batch_size, max_dlg_len, 2*nhid_attn) + + logit = self.attn(h.contiguous().view(-1, 2*self.nhid_attn)).view(h.size(0), h.size(1)) # (batch_size, max_dlg_len) + # if mask is not None: + # logit_mask = mask.view(input_cat.size(0), input_cat.size(1)) + # logit_mask = -999.0 * logit_mask + # logit = logit_mask + logit + + prob = F.softmax(logit, dim=1).unsqueeze(2).expand_as(h) # (batch_size, max_dlg_len, 2*nhid_attn) + attn = th.sum(th.mul(h, prob), 1) # (batch_size, 2*nhid_attn) + + return attn + + else: + logit = self.attn(embedded.contiguous().view(input_cat.size(0)*input_cat.size(1), -1)).view(input_cat.size(0), input_cat.size(1)) + if mask is not None: + logit_mask = mask.view(input_cat.size(0), input_cat.size(1)) + logit_mask = -999.0 * logit_mask + logit = logit_mask + logit + + prob = F.softmax(logit, dim=1).unsqueeze(2).expand_as(embedded) # (batch_size, max_dlg_len, 2*nhid_attn) + attn = th.sum(th.mul(embedded, prob), 1) # (batch_size, 2*nhid_attn) + + return attn + + +class FeatureProjecter(nn.Module): + def __init__(self, input_dropout_p, input_size, output_size): + super(FeatureProjecter, self).__init__() + self.input_dropout = nn.Dropout(p=input_dropout_p) + self.sel_encoder = nn.Sequential( + nn.Linear(input_size, output_size), + nn.Tanh() + ) + + def forward(self, goals_h, attn_outs): + h = th.cat([attn_outs, goals_h], 1) # (batch_size, 2*nhid_attn+goal_nhid) + h = self.input_dropout(h) + h = self.sel_encoder.forward(h) # (batch_size, nhid_sel) + return h + + +class SelectionClassifier(nn.Module): + def __init__(self, selection_length, input_size, output_size): + super(SelectionClassifier, self).__init__() + self.sel_decoders = nn.ModuleList() + for _ in range(selection_length): + self.sel_decoders.append(nn.Linear(input_size, output_size)) + + def forward(self, proj_outs): + outs = [decoder.forward(proj_outs).unsqueeze(1) for decoder in self.sel_decoders] # outcome_len*(batch_size, 1, outcome_vocab_size) + outs = th.cat(outs, 1) # (batch_size, outcome_len, outcome_vocab_size) + return outs diff --git a/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/decoders.py b/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb1dac5d64619e20c1df95d67fe4c2d96a12f98 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/decoders.py @@ -0,0 +1,574 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.autograd import Variable +import numpy as np +from convlab.policy.lava.multiwoz.latent_dialog.enc2dec.base_modules import BaseRNN +from convlab.policy.lava.multiwoz.latent_dialog.utils import cast_type, LONG, FLOAT +from convlab.policy.lava.multiwoz.latent_dialog.corpora import DECODING_MASKED_TOKENS, EOS + + +TEACH_FORCE = 'teacher_forcing' +TEACH_GEN = 'teacher_gen' +GEN = 'gen' +GEN_VALID = 'gen_valid' + + +class Attention(nn.Module): + def __init__(self, dec_cell_size, ctx_cell_size, attn_mode, project): + super(Attention, self).__init__() + self.dec_cell_size = dec_cell_size + self.ctx_cell_size = ctx_cell_size + self.attn_mode = attn_mode + if project: + self.linear_out = nn.Linear(dec_cell_size+ctx_cell_size, dec_cell_size) + else: + self.linear_out = None + + if attn_mode == 'general': + self.dec_w = nn.Linear(dec_cell_size, ctx_cell_size) + elif attn_mode == 'cat': + self.dec_w = nn.Linear(dec_cell_size, dec_cell_size) + self.attn_w = nn.Linear(ctx_cell_size, dec_cell_size) + self.query_w = nn.Linear(dec_cell_size, 1) + + def forward(self, output, context): + # output: (batch_size, output_seq_len, dec_cell_size) + # context: (batch_size, max_ctx_len, ctx_cell_size) + batch_size = output.size(0) + max_ctx_len = context.size(1) + + if self.attn_mode == 'dot': + attn = th.bmm(output, context.transpose(1, 2)) # (batch_size, output_seq_len, max_ctx_len) + elif self.attn_mode == 'general': + mapped_output = self.dec_w(output) # (batch_size, output_seq_len, ctx_cell_size) + attn = th.bmm(mapped_output, context.transpose(1, 2)) # (batch_size, output_seq_len, max_ctx_len) + elif self.attn_mode == 'cat': + mapped_output = self.dec_w(output) # (batch_size, output_seq_len, dec_cell_size) + mapped_attn = self.attn_w(context) # (batch_size, max_ctx_len, dec_cell_size) + tiled_output = mapped_output.unsqueeze(2).repeat(1, 1, max_ctx_len, 1) # (batch_size, output_seq_len, max_ctx_len, dec_cell_size) + tiled_attn = mapped_attn.unsqueeze(1) # (batch_size, 1, max_ctx_len, dec_cell_size) + fc1 = th.tanh(tiled_output+tiled_attn) # (batch_size, output_seq_len, max_ctx_len, dec_cell_size) + attn = self.query_w(fc1).squeeze(-1) # (batch_size, otuput_seq_len, max_ctx_len) + else: + raise ValueError('Unknown attention mode') + + # TODO mask + # if self.mask is not None: + + attn = F.softmax(attn.view(-1, max_ctx_len), dim=1).view(batch_size, -1, max_ctx_len) # (batch_size, output_seq_len, max_ctx_len) + mix = th.bmm(attn, context) # (batch_size, output_seq_len, ctx_cell_size) + combined = th.cat((mix, output), dim=2) # (batch_size, output_seq_len, dec_cell_size+ctx_cell_size) + if self.linear_out is None: + return combined, attn + else: + output = th.tanh( + self.linear_out(combined.view(-1, self.dec_cell_size+self.ctx_cell_size))).view( + batch_size, -1, self.dec_cell_size) # (batch_size, output_seq_len, dec_cell_size) + return output, attn + + +class DecoderRNN(BaseRNN): + def __init__(self, input_dropout_p, rnn_cell, input_size, hidden_size, num_layers, output_dropout_p, + bidirectional, vocab_size, use_attn, ctx_cell_size, attn_mode, sys_id, eos_id, use_gpu, + max_dec_len, embedding=None): + + super(DecoderRNN, self).__init__(input_dropout_p=input_dropout_p, + rnn_cell=rnn_cell, + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_dropout_p=output_dropout_p, + bidirectional=bidirectional) + + # TODO embedding is None or not + if embedding is None: + self.embedding = nn.Embedding(vocab_size, input_size) + else: + self.embedding = embedding + + # share parameters between encoder and decoder + # self.rnn = ctx_encoder.rnn + # self.FC = nn.Linear(input_size, utt_encoder.output_size) + + self.use_attn = use_attn + if self.use_attn: + self.attention = Attention(dec_cell_size=hidden_size, + ctx_cell_size=ctx_cell_size, + attn_mode=attn_mode, + project=True) + + self.dec_cell_size = hidden_size + self.output_size = vocab_size + self.project = nn.Linear(self.dec_cell_size, self.output_size) + self.log_softmax = F.log_softmax + + self.sys_id = sys_id + self.eos_id = eos_id + self.use_gpu = use_gpu + self.max_dec_len = max_dec_len + + def forward(self, batch_size, dec_inputs, dec_init_state, attn_context, mode, gen_type, beam_size, goal_hid=None): + # dec_inputs: (batch_size, response_size-1) + # attn_context: (batch_size, max_ctx_len, ctx_cell_size) + # goal_hid: (batch_size, goal_nhid) + + ret_dict = dict() + + if self.use_attn: + ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list() + + if mode == GEN: + dec_inputs = None + + if gen_type != 'beam': + beam_size = 1 + + if dec_inputs is not None: + decoder_input = dec_inputs + else: + # prepare the BOS inputs + with th.no_grad(): + bos_var = Variable(th.LongTensor([self.sys_id])) + bos_var = cast_type(bos_var, LONG, self.use_gpu) + decoder_input = bos_var.expand(batch_size*beam_size, 1) # (batch_size, 1) + + if mode == GEN and gen_type == 'beam': + # TODO if beam search, repeat the initial states of the RNN + pass + else: + decoder_hidden_state = dec_init_state + + prob_outputs = [] # list of logprob | max_dec_len*(batch_size, 1, vocab_size) + symbol_outputs = [] # list of word ids | max_dec_len*(batch_size, 1) + # back_pointers = [] + # lengths = blabla... + + def decode(step, cum_sum, step_output, step_attn): + prob_outputs.append(step_output) + step_output_slice = step_output.squeeze(1) # (batch_size, vocab_size) + if self.use_attn: + ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn) + + if gen_type == 'greedy': + _, symbols = step_output_slice.topk(1) # (batch_size, 1) + elif gen_type == 'sample': + # TODO FIXME + # symbols = self.gumbel_max(step_output_slice) + pass + elif gen_type == 'beam': + # TODO + pass + else: + raise ValueError('Unsupported decoding mode') + + symbol_outputs.append(symbols) + + return cum_sum, symbols + + if mode == TEACH_FORCE: + prob_outputs, decoder_hidden_state, attn = self.forward_step(input_var=decoder_input, hidden_state=decoder_hidden_state, encoder_outputs=attn_context, goal_hid=goal_hid) + else: + # do free running here + cum_sum = None + for step in range(self.max_dec_len): + # Input: + # decoder_input: (batch_size, 1) + # decoder_hidden_state: tuple: (h, c) + # attn_context: (batch_size, max_ctx_len, ctx_cell_size) + # goal_hid: (batch_size, goal_nhid) + # Output: + # decoder_output: (batch_size, 1, vocab_size) + # decoder_hidden_state: tuple: (h, c) + # step_attn: (batch_size, 1, max_ctx_len) + decoder_output, decoder_hidden_state, step_attn = self.forward_step(decoder_input, decoder_hidden_state, attn_context, goal_hid=goal_hid) + cum_sum, symbols = decode(step, cum_sum, decoder_output, step_attn) + decoder_input = symbols + + prob_outputs = th.cat(prob_outputs, dim=1) # (batch_size, max_dec_len, vocab_size) + + # back tracking to recover the 1-best in beam search + # if gen_type == 'beam': + + ret_dict[DecoderRNN.KEY_SEQUENCE] = symbol_outputs + + # prob_outputs: (batch_size, max_dec_len, vocab_size) + # decoder_hidden_state: tuple: (h, c) + # ret_dict[DecoderRNN.KEY_ATTN_SCORE]: max_dec_len*(batch_size, 1, max_ctx_len) + # ret_dict[DecoderRNN.KEY_SEQUENCE]: max_dec_len*(batch_size, 1) + return prob_outputs, decoder_hidden_state, ret_dict + + def forward_step(self, input_var, hidden_state, encoder_outputs, goal_hid): + # input_var: (batch_size, response_size-1 i.e. output_seq_len) + # hidden_state: tuple: (h, c) + # encoder_outputs: (batch_size, max_ctx_len, ctx_cell_size) + # goal_hid: (batch_size, goal_nhid) + batch_size, output_seq_len = input_var.size() + embedded = self.embedding(input_var) # (batch_size, output_seq_len, embedding_dim) + + # add goals + if goal_hid is not None: + goal_hid = goal_hid.view(goal_hid.size(0), 1, goal_hid.size(1)) # (batch_size, 1, goal_nhid) + goal_rep = goal_hid.repeat(1, output_seq_len, 1) # (batch_size, output_seq_len, goal_nhid) + embedded = th.cat([embedded, goal_rep], dim=2) # (batch_size, output_seq_len, embedding_dim+goal_nhid) + + embedded = self.input_dropout(embedded) + + # ############ + # embedded = self.FC(embedded.view(-1, embedded.size(-1))).view(batch_size, output_seq_len, -1) + + # output: (batch_size, output_seq_len, dec_cell_size) + # hidden: tuple: (h, c) + output, hidden_s = self.rnn(embedded, hidden_state) + + attn = None + if self.use_attn: + # output: (batch_size, output_seq_len, dec_cell_size) + # encoder_outputs: (batch_size, max_ctx_len, ctx_cell_size) + # attn: (batch_size, output_seq_len, max_ctx_len) + output, attn = self.attention(output, encoder_outputs) + + logits = self.project(output.contiguous().view(-1, self.dec_cell_size)) # (batch_size*output_seq_len, vocab_size) + prediction = self.log_softmax(logits, dim=logits.dim()-1).view(batch_size, output_seq_len, -1) # (batch_size, output_seq_len, vocab_size) + return prediction, hidden_s, attn + + # special for rl + def _step(self, input_var, hidden_state, encoder_outputs, goal_hid): + # input_var: (1, 1) + # hidden_state: tuple: (h, c) + # encoder_outputs: (1, max_dlg_len, dlg_cell_size) + # goal_hid: (1, goal_nhid) + batch_size, output_seq_len = input_var.size() + embedded = self.embedding(input_var) # (1, 1, embedding_dim) + + if goal_hid is not None: + goal_hid = goal_hid.view(goal_hid.size(0), 1, goal_hid.size(1)) # (1, 1, goal_nhid) + goal_rep = goal_hid.repeat(1, output_seq_len, 1) # (1, 1, goal_nhid) + embedded = th.cat([embedded, goal_rep], dim=2) # (1, 1, embedding_dim+goal_nhid) + + embedded = self.input_dropout(embedded) + + # ############ + # embedded = self.FC(embedded.view(-1, embedded.size(-1))).view(batch_size, output_seq_len, -1) + + # output: (1, 1, dec_cell_size) + # hidden: tuple: (h, c) + output, hidden_s = self.rnn(embedded, hidden_state) + + attn = None + if self.use_attn: + # output: (1, 1, dec_cell_size) + # encoder_outputs: (1, max_dlg_len, dlg_cell_size) + # attn: (1, 1, max_dlg_len) + output, attn = self.attention(output, encoder_outputs) + + logits = self.project(output.view(-1, self.dec_cell_size)) # (1*1, vocab_size) + prediction = logits.view(batch_size, output_seq_len, -1) # (1, 1, vocab_size) + # prediction = self.log_softmax(logits, dim=logits.dim()-1).view(batch_size, output_seq_len, -1) # (batch_size, output_seq_len, vocab_size) + return prediction, hidden_s + + # special for rl + def write(self, input_var, hidden_state, encoder_outputs, max_words, vocab, stop_tokens, goal_hid=None, mask=True, + decoding_masked_tokens=DECODING_MASKED_TOKENS): + # input_var: (1, 1) + # hidden_state: tuple: (h, c) + # encoder_outputs: max_dlg_len*(1, 1, dlg_cell_size) + # goal_hid: (1, goal_nhid) + logprob_outputs = [] # list of logprob | max_dec_len*(1, ) + symbol_outputs = [] # list of word ids | max_dec_len*(1, ) + decoder_input = input_var + decoder_hidden_state = hidden_state + if type(encoder_outputs) is list: + encoder_outputs = th.cat(encoder_outputs, 1) # (1, max_dlg_len, dlg_cell_size) + # print('encoder_outputs.size() = {}'.format(encoder_outputs.size())) + + if mask: + special_token_mask = Variable(th.FloatTensor([-999. if token in decoding_masked_tokens else 0. for token in vocab])) + special_token_mask = cast_type(special_token_mask, FLOAT, self.use_gpu) # (vocab_size, ) + + def _sample(dec_output, num_i): + # dec_output: (1, 1, vocab_size), need to softmax and log_softmax + dec_output = dec_output.view(-1) # (vocab_size, ) + # TODO temperature + prob = F.softmax(dec_output/0.6, dim=0) # (vocab_size, ) + logprob = F.log_softmax(dec_output, dim=0) # (vocab_size, ) + symbol = prob.multinomial(num_samples=1).detach() # (1, ) + # _, symbol = prob.topk(1) # (1, ) + _, tmp_symbol = prob.topk(1) # (1, ) + # print('multinomial symbol = {}, prob = {}'.format(symbol, prob[symbol.item()])) + # print('topk symbol = {}, prob = {}'.format(tmp_symbol, prob[tmp_symbol.item()])) + logprob = logprob.gather(0, symbol) # (1, ) + return logprob, symbol + + for i in range(max_words): + decoder_output, decoder_hidden_state = self._step(decoder_input, decoder_hidden_state, encoder_outputs, goal_hid) + # disable special tokens from being generated in a normal turn + if mask: + decoder_output += special_token_mask.expand(1, 1, -1) + logprob, symbol = _sample(decoder_output, i) + logprob_outputs.append(logprob) + symbol_outputs.append(symbol) + decoder_input = symbol.view(1, -1) + + if vocab[symbol.item()] in stop_tokens: + break + + assert len(logprob_outputs) == len(symbol_outputs) + # logprob_list = [t.item() for t in logprob_outputs] + logprob_list = logprob_outputs + symbol_list = [t.item() for t in symbol_outputs] + return logprob_list, symbol_list + + # For MultiWoz RL + def forward_rl(self, batch_size, dec_init_state, attn_context, vocab, max_words, goal_hid=None, mask=True, temp=0.1): + # prepare the BOS inputs + with th.no_grad(): + bos_var = Variable(th.LongTensor([self.sys_id])) + bos_var = cast_type(bos_var, LONG, self.use_gpu) + decoder_input = bos_var.expand(batch_size, 1) # (1, 1) + decoder_hidden_state = dec_init_state # tuple: (h, c) + encoder_outputs = attn_context # (1, ctx_len, ctx_cell_size) + + logprob_outputs = [] # list of logprob | max_dec_len*(1, ) + symbol_outputs = [] # list of word ids | max_dec_len*(1, ) + + if mask: + special_token_mask = Variable(th.FloatTensor([-999. if token in DECODING_MASKED_TOKENS else 0. for token in vocab])) + special_token_mask = cast_type(special_token_mask, FLOAT, self.use_gpu) # (vocab_size, ) + + def _sample(dec_output, num_i): + # dec_output: (1, 1, vocab_size), need to softmax and log_softmax + dec_output = dec_output.view(batch_size, -1) # (batch_size, vocab_size, ) + prob = F.softmax(dec_output/temp, dim=1) # (batch_size, vocab_size, ) + logprob = F.log_softmax(dec_output, dim=1) # (batch_size, vocab_size, ) + symbol = prob.multinomial(num_samples=1).detach() # (batch_size, 1) + # _, symbol = prob.topk(1) # (1, ) + _, tmp_symbol = prob.topk(1) # (1, ) + # print('multinomial symbol = {}, prob = {}'.format(symbol, prob[symbol.item()])) + # print('topk symbol = {}, prob = {}'.format(tmp_symbol, prob[tmp_symbol.item()])) + logprob = logprob.gather(1, symbol) # (1, ) + return logprob, symbol + + stopped_samples = set() + for i in range(max_words): + decoder_output, decoder_hidden_state = self._step(decoder_input, decoder_hidden_state, encoder_outputs, goal_hid) + # disable special tokens from being generated in a normal turn + if mask: + decoder_output += special_token_mask.expand(1, 1, -1) + logprob, symbol = _sample(decoder_output, i) + logprob_outputs.append(logprob) + symbol_outputs.append(symbol) + decoder_input = symbol.view(batch_size, -1) + for b_id in range(batch_size): + if vocab[symbol[b_id].item()] == EOS: + stopped_samples.add(b_id) + + if len(stopped_samples) == batch_size: + break + + assert len(logprob_outputs) == len(symbol_outputs) + symbol_outputs = th.cat(symbol_outputs, dim=1).cpu().data.numpy().tolist() + logprob_outputs = th.cat(logprob_outputs, dim=1) + logprob_list = [] + symbol_list = [] + for b_id in range(batch_size): + b_logprob = [] + b_symbol = [] + for t_id in range(logprob_outputs.shape[1]): + symbol = symbol_outputs[b_id][t_id] + if vocab[symbol] == EOS and t_id != 0: + break + + b_symbol.append(symbol_outputs[b_id][t_id]) + b_logprob.append(logprob_outputs[b_id][t_id]) + + logprob_list.append(b_logprob) + symbol_list.append(b_symbol) + + # TODO backward compatible, if batch_size == 1, we remove the nested structure + if batch_size == 1: + logprob_list = logprob_list[0] + symbol_list = symbol_list[0] + + return logprob_list, symbol_list + +class DecoderPointerGen(BaseRNN): + + def __init__(self, vocab_size, max_len, input_size, hidden_size, sos_id, + eos_id, n_layers=1, rnn_cell='lstm', input_dropout_p=0, + dropout_p=0, attn_mode='cat', attn_size=None, use_gpu=True, + embedding=None): + + super(DecoderPointerGen, self).__init__(vocab_size, input_size, + hidden_size, input_dropout_p, + dropout_p, n_layers, rnn_cell, False) + + self.output_size = vocab_size + self.max_length = max_len + self.eos_id = eos_id + self.sos_id = sos_id + self.use_gpu = use_gpu + self.attn_size = attn_size + + if embedding is None: + self.embedding = nn.Embedding(self.output_size, self.input_size) + else: + self.embedding = embedding + + self.attention = Attention(self.hidden_size, attn_size, attn_mode, + project=True) + + self.project = nn.Linear(self.hidden_size, self.output_size) + self.sentinel = nn.Parameter(torch.randn((1, 1, attn_size)), requires_grad=True) + self.register_parameter('sentinel', self.sentinel) + + def forward_step(self, input_var, hidden, attn_ctxs, attn_words, ctx_embed=None): + """ + attn_size: number of context to attend + :param input_var: + :param hidden: + :param attn_ctxs: batch_size x attn_size+1 x ctx_size. If None, then leave it empty + :param attn_words: batch_size x attn_size + :return: + """ + # we enable empty attention context + batch_size = input_var.size(0) + seq_len = input_var.size(1) + embedded = self.embedding(input_var) + if ctx_embed is not None: + embedded += ctx_embed + + embedded = self.input_dropout(embedded) + output, hidden = self.rnn(embedded, hidden) + + if attn_ctxs is None: + # pointer network here + logits = self.project(output.contiguous().view(-1, self.hidden_size)) + predicted_softmax = F.log_softmax(logits, dim=1) + return predicted_softmax, None, hidden, None, None + else: + attn_size = attn_words.size(1) + combined_output, attn = self.attention(output, attn_ctxs) + + # output: batch_size x seq_len x hidden_size + # attn: batch_size x seq_len x (attn_size+1) + + # pointer network here + rnn_softmax = F.softmax(self.project(output.view(-1, self.hidden_size)), dim=1) + g = attn[:, :, 0].contiguous() + ptr_attn = attn[:, :, 1:].contiguous() + ptr_softmax = Variable(torch.zeros((batch_size * seq_len * attn_size, self.vocab_size))) + ptr_softmax = cast_type(ptr_softmax, FLOAT, self.use_gpu) + + # convert words and ids into 1D + flat_attn_words = attn_words.unsqueeze(1).repeat(1, seq_len, 1).view(-1, 1) + flat_attn = ptr_attn.view(-1, 1) + + # fill in the attention into ptr_softmax + ptr_softmax = ptr_softmax.scatter_(1, flat_attn_words, flat_attn) + ptr_softmax = ptr_softmax.view(batch_size * seq_len, attn_size, self.vocab_size) + ptr_softmax = torch.sum(ptr_softmax, dim=1) + + # mix the softmax from rnn and pointer + mixture_softmax = rnn_softmax * g.view(-1, 1) + ptr_softmax + + # take the log to get logsoftmax + logits = torch.log(mixture_softmax.clamp(min=1e-8)) + predicted_softmax = logits.view(batch_size, seq_len, -1) + ptr_softmax = ptr_softmax.view(batch_size, seq_len, -1) + + return predicted_softmax, ptr_softmax, hidden, ptr_attn, g + + def forward(self, batch_size, attn_context, attn_words, + inputs=None, init_state=None, mode=TEACH_FORCE, + gen_type='greedy', ctx_embed=None): + + # sanity checks + ret_dict = dict() + + if mode == GEN: + inputs = None + + if inputs is not None: + decoder_input = inputs + else: + # prepare the BOS inputs + bos_var = Variable(torch.LongTensor([self.sos_id]), volatile=True) + bos_var = cast_type(bos_var, LONG, self.use_gpu) + decoder_input = bos_var.expand(batch_size, 1) + + # append sentinel to the attention + if attn_context is not None: + attn_context = torch.cat([self.sentinel.expand(batch_size, 1, self.attn_size), + attn_context], dim=1) + + decoder_hidden = init_state + decoder_outputs = [] # a list of logprob + sequence_symbols = [] # a list word ids + attentions = [] + pointer_gs = [] + pointer_outputs = [] + lengths = np.array([self.max_length] * batch_size) + + def decode(step, step_output): + decoder_outputs.append(step_output) + step_output_slice = step_output.squeeze(1) + + if gen_type == 'greedy': + symbols = step_output_slice.topk(1)[1] + elif gen_type == 'sample': + symbols = self.gumbel_max(step_output_slice) + else: + raise ValueError("Unsupported decoding mode") + + sequence_symbols.append(symbols) + + eos_batches = symbols.data.eq(self.eos_id) + if eos_batches.dim() > 0: + eos_batches = eos_batches.cpu().view(-1).numpy() + update_idx = ((lengths > di) & eos_batches) != 0 + lengths[update_idx] = len(sequence_symbols) + return symbols + + # Manual unrolling is used to support random teacher forcing. + # If teacher_forcing_ratio is True or False instead of a probability, + # the unrolling can be done in graph + if mode == TEACH_FORCE: + pred_softmax, ptr_softmax, decoder_hidden, attn, step_g = self.forward_step( + decoder_input, decoder_hidden, attn_context, attn_words, ctx_embed) + + # in teach forcing mode, we don't need symbols. + attentions = attn + decoder_outputs = pred_softmax + pointer_gs = step_g + pointer_outputs = ptr_softmax + + else: + # do free running here + for di in range(self.max_length): + pred_softmax, ptr_softmax, decoder_hidden, step_attn, step_g = self.forward_step( + decoder_input, decoder_hidden, attn_context, attn_words, ctx_embed) + + symbols = decode(di, pred_softmax) + + # append the results into ctx dictionary + attentions.append(step_attn) + pointer_gs.append(step_g) + pointer_outputs.append(ptr_softmax) + decoder_input = symbols + + # make list be a tensor + decoder_outputs = torch.cat(decoder_outputs, dim=1) + pointer_outputs = torch.cat(pointer_outputs, dim=1) + pointer_gs = torch.cat(pointer_gs, dim=1) + + # save the decoded sequence symbols and sequence length + ret_dict[self.KEY_ATTN_SCORE] = attentions + ret_dict[self.KEY_SEQUENCE] = sequence_symbols + ret_dict[self.KEY_LENGTH] = lengths + ret_dict[self.KEY_G] = pointer_gs + ret_dict[self.KEY_PTR_SOFTMAX] = pointer_outputs + ret_dict[self.KEY_PTR_CTX] = attn_words + + return decoder_outputs, decoder_hidden, ret_dict diff --git a/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/encoders.py b/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..33c754e55c4e60b668357c0f4790b4595d3b4d30 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/enc2dec/encoders.py @@ -0,0 +1,215 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from convlab.policy.lava.multiwoz.latent_dialog.enc2dec.base_modules import BaseRNN + + +class EncoderRNN(BaseRNN): + def __init__(self, input_dropout_p, rnn_cell, input_size, hidden_size, num_layers, output_dropout_p, bidirectional, variable_lengths): + super(EncoderRNN, self).__init__(input_dropout_p=input_dropout_p, + rnn_cell=rnn_cell, + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_dropout_p=output_dropout_p, + bidirectional=bidirectional) + self.variable_lengths = variable_lengths + self.output_size = hidden_size*2 if bidirectional else hidden_size + + def forward(self, input_var, init_state=None, input_lengths=None, goals=None): + # add goals + if goals is not None: + batch_size, max_ctx_len, ctx_nhid = input_var.size() + goals = goals.view(goals.size(0), 1, goals.size(1)) + goals_rep = goals.repeat(1, max_ctx_len, 1).view(batch_size, max_ctx_len, -1) # (batch_size, max_ctx_len, goal_nhid) + input_var = th.cat([input_var, goals_rep], dim=2) + + embedded = self.input_dropout(input_var) + + if self.variable_lengths: + embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, + batch_first=True) + if init_state is not None: + output, hidden = self.rnn(embedded, init_state) + else: + output, hidden = self.rnn(embedded) + if self.variable_lengths: + output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True) + + return output, hidden + + +class RnnUttEncoder(nn.Module): + def __init__(self, vocab_size, embedding_dim, feat_size, goal_nhid, rnn_cell, + utt_cell_size, num_layers, input_dropout_p, output_dropout_p, + bidirectional, variable_lengths, use_attn, embedding=None): + super(RnnUttEncoder, self).__init__() + if embedding is None: + self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim) + else: + self.embedding = embedding + + self.rnn = EncoderRNN(input_dropout_p=input_dropout_p, + rnn_cell=rnn_cell, + input_size=embedding_dim+feat_size+goal_nhid, + hidden_size=utt_cell_size, + num_layers=num_layers, + output_dropout_p=output_dropout_p, + bidirectional=bidirectional, + variable_lengths=variable_lengths) + + self.utt_cell_size = utt_cell_size + self.multiplier = 2 if bidirectional else 1 + self.output_size = self.multiplier * self.utt_cell_size + self.use_attn = use_attn + if self.use_attn: + self.key_w = nn.Linear(self.output_size, self.utt_cell_size) + self.query = nn.Linear(self.utt_cell_size, 1) + + def forward(self, utterances, feats=None, init_state=None, goals=None): + batch_size, max_ctx_len, max_utt_len = utterances.size() + # get word embeddings + flat_words = utterances.view(-1, max_utt_len) # (batch_size*max_ctx_len, max_utt_len) + word_embeddings = self.embedding(flat_words) # (batch_size*max_ctx_len, max_utt_len, embedding_dim) + flat_mask = th.sign(flat_words).float() + # add features + if feats is not None: + flat_feats = feats.view(-1, 1) # (batch_size*max_ctx_len, 1) + flat_feats = flat_feats.unsqueeze(1).repeat(1, max_utt_len, 1) # (batch_size*max_ctx_len, max_utt_len, 1) + word_embeddings = th.cat([word_embeddings, flat_feats], dim=2) # (batch_size*max_ctx_len, max_utt_len, embedding_dim+1) + + # add goals + if goals is not None: + goals = goals.view(goals.size(0), 1, 1, goals.size(1)) + goals_rep = goals.repeat(1, max_ctx_len, max_utt_len, 1).view(batch_size*max_ctx_len, max_utt_len, -1) # (batch_size*max_ctx_len, max_utt_len, goal_nhid) + word_embeddings = th.cat([word_embeddings, goals_rep], dim=2) + + # enc_outs: (batch_size*max_ctx_len, max_utt_len, num_directions*utt_cell_size) + # enc_last: (num_layers*num_directions, batch_size*max_ctx_len, utt_cell_size) + enc_outs, enc_last = self.rnn(word_embeddings, init_state=init_state) + + if self.use_attn: + fc1 = th.tanh(self.key_w(enc_outs)) # (batch_size*max_ctx_len, max_utt_len, utt_cell_size) + attn = self.query(fc1).squeeze(2) + # (batch_size*max_ctx_len, max_utt_len) + attn = F.softmax(attn, attn.dim()-1) # (batch_size*max_ctx_len, max_utt_len, 1) + attn = attn * flat_mask + attn = (attn / (th.sum(attn, dim=1, keepdim=True)+1e-10)).unsqueeze(2) + utt_embedded = attn * enc_outs # (batch_size*max_ctx_len, max_utt_len, num_directions*utt_cell_size) + utt_embedded = th.sum(utt_embedded, dim=1) # (batch_size*max_ctx_len, num_directions*utt_cell_size) + else: + # FIXME bug for multi-layer + attn = None + utt_embedded = enc_last.transpose(0, 1).contiguous() # (batch_size*max_ctx_lens, num_layers*num_directions, utt_cell_size) + utt_embedded = utt_embedded.view(-1, self.output_size) # (batch_size*max_ctx_len*num_layers, num_directions*utt_cell_size) + + utt_embedded = utt_embedded.view(batch_size, max_ctx_len, self.output_size) + return utt_embedded, word_embeddings.contiguous().view(batch_size, max_ctx_len*max_utt_len, -1), \ + enc_outs.contiguous().view(batch_size, max_ctx_len*max_utt_len, -1) + + +class MlpGoalEncoder(nn.Module): + def __init__(self, goal_vocab_size, k, nembed, nhid, init_range): + super(MlpGoalEncoder, self).__init__() + + # create separate embedding for counts and values + self.cnt_enc = nn.Embedding(goal_vocab_size, nembed) + self.val_enc = nn.Embedding(goal_vocab_size, nembed) + + self.encoder = nn.Sequential( + nn.Tanh(), + nn.Linear(k*nembed, nhid) + ) + + self.cnt_enc.weight.data.uniform_(-init_range, init_range) + self.val_enc.weight.data.uniform_(-init_range, init_range) + self._init_cont(self.encoder, init_range) + + def _init_cont(self, cont, init_range): + """initializes a container uniformly.""" + for m in cont: + if hasattr(m, 'weight'): + m.weight.data.uniform_(-init_range, init_range) + if hasattr(m, 'bias'): + m.bias.data.fill_(0) + + def forward(self, goal): + # goal: (batch_size, goal_len) + goal = goal.transpose(0, 1).contiguous() # (goal_len, batch_size) + idx = np.arange(goal.size(0) // 2) + + # extract counts and values + cnt_idx = Variable(th.from_numpy(2 * idx + 0)) + val_idx = Variable(th.from_numpy(2 * idx + 1)) + + if goal.is_cuda: + cnt_idx = cnt_idx.type(th.cuda.LongTensor) + val_idx = val_idx.type(th.cuda.LongTensor) + else: + cnt_idx = cnt_idx.type(th.LongTensor) + val_idx = val_idx.type(th.LongTensor) + + cnt = goal.index_select(0, cnt_idx) # (3, batch_size) + val = goal.index_select(0, val_idx) # (3, batch_size) + + # embed counts and values + cnt_emb = self.cnt_enc(cnt) # (3, batch_size, nembed) + val_emb = self.val_enc(val) # (3, batch_size, nembed) + + # element wise multiplication to get a hidden state + h = th.mul(cnt_emb, val_emb) # (3, batch_size, nembed) + # run the hidden state through the MLP + h = h.transpose(0, 1).contiguous().view(goal.size(1), -1) # (batch_size, 3*nembed) + goal_h = self.encoder(h) # (batch_size, nhid) + + return goal_h + + +class TaskMlpGoalEncoder(nn.Module): + def __init__(self, goal_vocab_sizes, nhid, init_range): + super(TaskMlpGoalEncoder, self).__init__() + + self.encoder = nn.ModuleList() + for v_size in goal_vocab_sizes: + domain_encoder = nn.Sequential( + nn.Linear(v_size, nhid), + nn.Tanh() + ) + self._init_cont(domain_encoder, init_range) + self.encoder.append(domain_encoder) + + def _init_cont(self, cont, init_range): + """initializes a container uniformly.""" + for m in cont: + if hasattr(m, 'weight'): + m.weight.data.uniform_(-init_range, init_range) + if hasattr(m, 'bias'): + m.bias.data.fill_(0) + + def forward(self, goals_list): + # goals_list: list of tensor, 7*(batch_size, goal_len), goal_len varies among differnet domains + outs = [encoder.forward(goal) for goal, encoder in zip(goals_list, self.encoder)] # 7*(batch_size, goal_nhid) + outs = th.sum(th.stack(outs), dim=0) # (batch_size, goal_nhid) + return outs + + +class SelfAttn(nn.Module): + def __init__(self, hidden_size): + super(SelfAttn, self).__init__() + self.query = nn.Linear(hidden_size, 1) + + def forward(self, keys, values, attn_mask=None): + """ + :param attn_inputs: batch_size x time_len x hidden_size + :param attn_mask: batch_size x time_len + :return: summary state + """ + alpha = F.softmax(self.query(keys), dim=1) + if attn_mask is not None: + alpha = alpha * attn_mask.unsqueeze(2) + alpha = alpha / th.sum(alpha, dim=1, keepdim=True) + + summary = th.sum(values * alpha, dim=1) + return summary diff --git a/convlab/policy/lava/multiwoz/latent_dialog/evaluators.py b/convlab/policy/lava/multiwoz/latent_dialog/evaluators.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b03076fa90648a42a5cd669ef6bfe073272353 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/evaluators.py @@ -0,0 +1,1325 @@ +from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction +import math +import numpy as np +import convlab.policy.lava.multiwoz.latent_dialog.normalizer.delexicalize as delex +from convlab.policy.lava.multiwoz.latent_dialog.utils import get_tokenize, get_detokenize +from collections import Counter, defaultdict +from nltk.util import ngrams +from convlab.policy.lava.multiwoz.latent_dialog.corpora import SYS, USR, BOS, EOS +from sklearn.feature_extraction.text import CountVectorizer +import json +from convlab.policy.lava.multiwoz.latent_dialog.normalizer.delexicalize import normalize +import sqlite3 +import os +import random +import logging +import pdb +from sklearn.multiclass import OneVsRestClassifier +from sklearn.linear_model import SGDClassifier +from sklearn import metrics +from nltk.translate import bleu_score +from nltk.translate.bleu_score import SmoothingFunction +from scipy.stats import gmean + + + +class BaseEvaluator(object): + def initialize(self): + raise NotImplementedError + + def add_example(self, ref, hyp): + raise NotImplementedError + + def get_report(self, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def _get_prec_recall(tp, fp, fn): + precision = tp / (tp + fp + 10e-20) + recall = tp / (tp + fn + 10e-20) + f1 = 2 * precision * recall / (precision + recall + 1e-20) + return precision, recall, f1 + + @staticmethod + def _get_tp_fp_fn(label_list, pred_list): + tp = len([t for t in pred_list if t in label_list]) + fp = max(0, len(pred_list) - tp) + fn = max(0, len(label_list) - tp) + return tp, fp, fn + + +class BLEUScorer(object): + ## BLEU score calculator via GentScorer interface + ## it calculates the BLEU-4 by taking the entire corpus in + ## Calulate based multiple candidates against multiple references + def score(self, hypothesis, corpus, n=1): + # containers + count = [0, 0, 0, 0] + clip_count = [0, 0, 0, 0] + r = 0 + c = 0 + weights = [0.25, 0.25, 0.25, 0.25] + + # accumulate ngram statistics + for hyps, refs in zip(hypothesis, corpus): + # if type(hyps[0]) is list: + # hyps = [hyp.split() for hyp in hyps[0]] + # else: + # hyps = [hyp.split() for hyp in hyps] + + # refs = [ref.split() for ref in refs] + hyps = [hyps] + # Shawn's evaluation + # refs[0] = [u'GO_'] + refs[0] + [u'EOS_'] + # hyps[0] = [u'GO_'] + hyps[0] + [u'EOS_'] + + for idx, hyp in enumerate(hyps): + for i in range(4): + # accumulate ngram counts + hypcnts = Counter(ngrams(hyp, i + 1)) + cnt = sum(hypcnts.values()) + count[i] += cnt + + # compute clipped counts + max_counts = {} + for ref in refs: + refcnts = Counter(ngrams(ref, i + 1)) + for ng in hypcnts: + max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) + clipcnt = dict((ng, min(count, max_counts[ng])) \ + for ng, count in hypcnts.items()) + clip_count[i] += sum(clipcnt.values()) + + # accumulate r & c + bestmatch = [1000, 1000] + for ref in refs: + if bestmatch[0] == 0: break + diff = abs(len(ref) - len(hyp)) + if diff < bestmatch[0]: + bestmatch[0] = diff + bestmatch[1] = len(ref) + r += bestmatch[1] + c += len(hyp) + if n == 1: + break + # computing bleu score + p0 = 1e-7 + bp = 1 if c > r else math.exp(1 - float(r) / float(c)) + p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ + for i in range(4)] + s = math.fsum(w * math.log(p_n) \ + for w, p_n in zip(weights, p_ns) if p_n) + bleu = bp * math.exp(s) + return bleu + + +class BleuEvaluator(BaseEvaluator): + def __init__(self, data_name): + self.data_name = data_name + self.labels = list() + self.hyps = list() + + def initialize(self): + self.labels = list() + self.hyps = list() + + def add_example(self, ref, hyp): + self.labels.append(ref) + self.hyps.append(hyp) + + def get_report(self): + tokenize = get_tokenize() + print('Generate report for {} samples'.format(len(self.hyps))) + refs, hyps = [], [] + for label, hyp in zip(self.labels, self.hyps): + # label = label.replace(EOS, '') + # hyp = hyp.replace(EOS, '') + # ref_tokens = tokenize(label)[1:] + # hyp_tokens = tokenize(hyp)[1:] + ref_tokens = tokenize(label) + hyp_tokens = tokenize(hyp) + refs.append([ref_tokens]) + hyps.append(hyp_tokens) + bleu = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1) + report = '\n===== BLEU = %f =====\n' % (bleu,) + return '\n===== REPORT FOR DATASET {} ====={}'.format(self.data_name, report) + + +class MultiWozDB(object): + # loading databases + domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital'] # , 'police'] + dbs = {} + CUR_DIR = os.path.dirname(__file__).replace('latent_dialog', '') + + for domain in domains: + db = os.path.join(CUR_DIR, 'data/norm-multi-woz/db/{}-dbase.db'.format(domain)) + conn = sqlite3.connect(db) + c = conn.cursor() + dbs[domain] = c + + def queryResultVenues(self, domain, turn, real_belief=False): + # query the db + sql_query = "select * from {}".format(domain) + + if real_belief == True: + items = turn.items() + else: + items = turn['metadata'][domain]['semi'].items() + + flag = True + for key, val in items: + if val == "" or val == "dontcare" or val == 'not mentioned' or val == "don't care" or val == "dont care" or val == "do n't care": + pass + else: + if flag: + sql_query += " where " + val2 = val.replace("'", "''") + val2 = normalize(val2) + if key == 'leaveAt': + sql_query += r" " + key + " > " + r"'" + val2 + r"'" + elif key == 'arriveBy': + sql_query += r" " + key + " < " + r"'" + val2 + r"'" + else: + sql_query += r" " + key + "=" + r"'" + val2 + r"'" + flag = False + else: + val2 = val.replace("'", "''") + val2 = normalize(val2) + if key == 'leaveAt': + sql_query += r" and " + key + " > " + r"'" + val2 + r"'" + elif key == 'arriveBy': + sql_query += r" and " + key + " < " + r"'" + val2 + r"'" + else: + sql_query += r" and " + key + "=" + r"'" + val2 + r"'" + + try: # "select * from attraction where name = 'queens college'" + return self.dbs[domain].execute(sql_query).fetchall() + except: + return [] # TODO test it + + +class MultiWozEvaluator(BaseEvaluator): + CUR_DIR = os.path.dirname(__file__).replace('latent_dialog', '') + logger = logging.getLogger() + def __init__(self, data_name): + self.data_name = data_name + self.slot_dict = delex.prepareSlotValuesIndependent() + self.delex_dialogues = json.load(open(os.path.join(self.CUR_DIR, 'data/norm-multi-woz/delex.json'))) + self.db = MultiWozDB() + self.labels = list() + self.hyps = list() + + def initialize(self): + self.labels = list() + self.hyps = list() + + def add_example(self, ref, hyp): + self.labels.append(ref) + self.hyps.append(hyp) + + def _parseGoal(self, goal, d, domain): + """Parses user goal into dictionary format.""" + goal[domain] = {} + goal[domain] = {'informable': [], 'requestable': [], 'booking': []} + if 'info' in d['goal'][domain]: + # if d['goal'][domain].has_key('info'): + if domain == 'train': + # we consider dialogues only where train had to be booked! + if 'book' in d['goal'][domain]: + # if d['goal'][domain].has_key('book'): + goal[domain]['requestable'].append('reference') + if 'reqt' in d['goal'][domain]: + # if d['goal'][domain].has_key('reqt'): + if 'trainID' in d['goal'][domain]['reqt']: + goal[domain]['requestable'].append('id') + else: + if 'reqt' in d['goal'][domain]: + # if d['goal'][domain].has_key('reqt'): + for s in d['goal'][domain]['reqt']: # addtional requests: + if s in ['phone', 'address', 'postcode', 'reference', 'id']: + # ones that can be easily delexicalized + goal[domain]['requestable'].append(s) + if 'book' in d['goal'][domain]: + # if d['goal'][domain].has_key('book'): + goal[domain]['requestable'].append("reference") + + goal[domain]["informable"] = d['goal'][domain]['info'] + if 'book' in d['goal'][domain]: + # if d['goal'][domain].has_key('book'): + goal[domain]["booking"] = d['goal'][domain]['book'] + + return goal + + def _evaluateGeneratedDialogue(self, dialog, goal, realDialogue, real_requestables, soft_acc=False): + """Evaluates the dialogue created by the model. + First we load the user goal of the dialogue, then for each turn + generated by the system we look for key-words. + For the Inform rate we look whether the entity was proposed. + For the Success rate we look for requestables slots""" + # for computing corpus success + requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + + # CHECK IF MATCH HAPPENED + provided_requestables = {} + venue_offered = {} + domains_in_goal = [] + + for domain in goal.keys(): + venue_offered[domain] = [] + provided_requestables[domain] = [] + domains_in_goal.append(domain) + + for t, sent_t in enumerate(dialog): + for domain in goal.keys(): + # for computing success + if '[' + domain + '_name]' in sent_t or '_id' in sent_t: # undo delexicalization if system generates [domain_name] or [domain_id] + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION + # in this case, look for the actual offered venues based on true belief state + venues = self.db.queryResultVenues(domain, realDialogue['log'][t * 2 + 1]) + + # if venue has changed + if len(venue_offered[domain]) == 0 and venues: + venue_offered[domain] = random.sample(venues, 1) + else: + flag = False + for ven in venues: + if venue_offered[domain][0] == ven: + flag = True + break + if not flag and venues: # sometimes there are no results so sample won't work + # print venues + venue_offered[domain] = random.sample(venues, 1) + else: # not limited so we can provide one + venue_offered[domain] = '[' + domain + '_name]' + + # ATTENTION: assumption here - we didn't provide phone or address twice! etc + for requestable in requestables: + if requestable == 'reference': + if domain + '_reference' in sent_t: + if 'restaurant_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][ + -5] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + elif 'hotel_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][ + -3] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + elif 'train_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][ + -1] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + else: + provided_requestables[domain].append('reference') + else: + if '[' + domain + '_' + requestable + ']' in sent_t: + provided_requestables[domain].append(requestable) + + # if name was given in the task + for domain in goal.keys(): + # if name was provided for the user, the match is being done automatically + # assumption doesn't always hold, maybe it's better if name is provided by user that it is ignored? + if 'info' in realDialogue['goal'][domain]: + if 'name' in realDialogue['goal'][domain]['info']: + venue_offered[domain] = '[' + domain + '_name]' + + # special domains - entity does not need to be provided + if domain in ['taxi', 'police', 'hospital']: + venue_offered[domain] = '[' + domain + '_name]' + + if domain == 'train': + if not venue_offered[domain]: + # if realDialogue['goal'][domain].has_key('reqt') and 'id' not in realDialogue['goal'][domain]['reqt']: + if 'reqt' in realDialogue['goal'][domain] and 'id' not in realDialogue['goal'][domain]['reqt']: + venue_offered[domain] = '[' + domain + '_name]' + + """ + Given all inform and requestable slots + we go through each domain from the user goal + and check whether right entity was provided and + all requestable slots were given to the user. + The dialogue is successful if that's the case for all domains. + """ + # HARD EVAL + stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0], + 'taxi': [0, 0, 0], + 'hospital': [0, 0, 0], 'police': [0, 0, 0]} + + match = 0 + success = 0 + # MATCH (between offered venue by generated dialogue and venue actually fitting to the criteria) + for domain in goal.keys(): + match_stat = 0 + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + goal_venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True) + # if venue offered is not dict + if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]: # yields false positive, does not match what is offered with real dialogue? + match += 1 + match_stat = 1 + # if venue offered is dict + elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues: # actually checks the offered venue + match += 1 + match_stat = 1 + # other domains + else: + if domain + '_name]' in venue_offered[domain]: # yields false positive, in terms of occurence and correctness + match += 1 + match_stat = 1 + + stats[domain][0] = match_stat + stats[domain][2] = 1 + + if soft_acc: + match = float(match)/len(goal.keys()) + else: + # only count success if all domain has matches + if match == len(goal.keys()): + match = 1.0 + else: + match = 0.0 + + # SUCCESS (whether the requestable info in realDialogue is generated by the system) + # if no match, then success is assumed to be 0 + if match == 1.0: + for domain in domains_in_goal: + success_stat = 0 + domain_success = 0 + if len(real_requestables[domain]) == 0: # if there is no requestable, assume to be succesful. incorrect, cause does not count false positives. + success += 1 + success_stat = 1 + stats[domain][1] = success_stat + continue + # if values in sentences are super set of requestables + for request in set(provided_requestables[domain]): + if request in real_requestables[domain]: + domain_success += 1 + + if domain_success >= len(real_requestables[domain]): + success += 1 + success_stat = 1 + + stats[domain][1] = success_stat + + # final eval + if soft_acc: + success = float(success)/len(real_requestables) + else: + if success >= len(real_requestables): + success = 1 + else: + success = 0 + + # rint requests, 'DIFF', requests_real, 'SUCC', success + return success, match, stats + + def _evaluateGeneratedDialogue_new(self, dialog, goal, realDialogue, real_requestables, soft_acc=False): + """Evaluates the dialogue created by the model. + First we load the user goal of the dialogue, then for each turn + generated by the system we look for key-words. + For the Inform rate we look whether the entity was proposed. + For the Success rate we look for requestables slots""" + # for computing corpus success + requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + + # CHECK IF MATCH HAPPENED + provided_requestables = {} + venue_offered = {} + domains_in_goal = [] + + for domain in goal.keys(): + venue_offered[domain] = [] + provided_requestables[domain] = [] + domains_in_goal.append(domain) + + for t, sent_t in enumerate(dialog): + for domain in goal.keys(): + # for computing success + if '[' + domain + '_name]' in sent_t or '_id' in sent_t: + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION + venues = self.db.queryResultVenues(domain, realDialogue['log'][t * 2 + 1]) + + # if venue has changed + if len(venue_offered[domain]) == 0 and venues: + venue_offered[domain] = random.sample(venues, 1) + else: + flag = False + for ven in venues: + if venue_offered[domain][0] == ven: + flag = True + break + if not flag and venues: # sometimes there are no results so sample won't work + # print venues + venue_offered[domain] = random.sample(venues, 1) + else: # not limited so we can provide one + venue_offered[domain] = '[' + domain + '_name]' + + # ATTENTION: assumption here - we didn't provide phone or address twice! etc + for requestable in requestables: + if requestable == 'reference': + if domain + '_reference' in sent_t: + if 'restaurant_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][ + -5] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + elif 'hotel_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][ + -3] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + elif 'train_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][ + -1] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + else: + provided_requestables[domain].append('reference') + else: + if domain + '_' + requestable + ']' in sent_t: + provided_requestables[domain].append(requestable) + + # if name was given in the task + for domain in goal.keys(): + # if name was provided for the user, the match is being done automatically + # if realDialogue['goal'][domain].has_key('info'): + if 'info' in realDialogue['goal'][domain]: + # if realDialogue['goal'][domain]['info'].has_key('name'): + if 'name' in realDialogue['goal'][domain]['info']: + venue_offered[domain] = '[' + domain + '_name]' + + # special domains - entity does not need to be provided + if domain in ['taxi', 'police', 'hospital']: + venue_offered[domain] = '[' + domain + '_name]' + + # the original method + # if domain == 'train': + # if not venue_offered[domain]: + # # if realDialogue['goal'][domain].has_key('reqt') and 'id' not in realDialogue['goal'][domain]['reqt']: + # if 'reqt' in realDialogue['goal'][domain] and 'id' not in realDialogue['goal'][domain]['reqt']: + # venue_offered[domain] = '[' + domain + '_name]' + + # Wrong one in HDSA + # if domain == 'train': + # if not venue_offered[domain]: + # if goal[domain]['requestable'] and 'id' not in goal[domain]['requestable']: + # venue_offered[domain] = '[' + domain + '_name]' + + # if id was not requested but train was found we dont want to override it to check if we booked the right train + if domain == 'train' and (not venue_offered[domain] and 'id' not in goal['train']['requestable']): + venue_offered[domain] = '[' + domain + '_name]' + + """ + Given all inform and requestable slots + we go through each domain from the user goal + and check whether right entity was provided and + all requestable slots were given to the user. + The dialogue is successful if that's the case for all domains. + """ + # HARD EVAL + stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0], + 'taxi': [0, 0, 0], + 'hospital': [0, 0, 0], 'police': [0, 0, 0]} + + match = 0 + success = 0 + # MATCH + for domain in goal.keys(): + match_stat = 0 + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + goal_venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True) + if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]: + match += 1 + match_stat = 1 + elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues: + match += 1 + match_stat = 1 + else: + if domain + '_name]' in venue_offered[domain]: + match += 1 + match_stat = 1 + + stats[domain][0] = match_stat + stats[domain][2] = 1 + + if soft_acc: + match = float(match)/len(goal.keys()) + else: + if match == len(goal.keys()): + match = 1.0 + else: + match = 0.0 + + # SUCCESS + if match == 1.0: + for domain in domains_in_goal: + success_stat = 0 + domain_success = 0 + if len(real_requestables[domain]) == 0: + success += 1 + success_stat = 1 + stats[domain][1] = success_stat + continue + # if values in sentences are super set of requestables + for request in set(provided_requestables[domain]): + if request in real_requestables[domain]: + domain_success += 1 + + if domain_success >= len(real_requestables[domain]): + success += 1 + success_stat = 1 + + stats[domain][1] = success_stat + + # final eval + if soft_acc: + success = float(success)/len(real_requestables) + else: + if success >= len(real_requestables): + success = 1 + else: + success = 0 + + # rint requests, 'DIFF', requests_real, 'SUCC', success + return success, match, stats + + def _evaluateRealDialogue(self, dialog, filename): + """Evaluation of the real dialogue from corpus. + First we loads the user goal and then go through the dialogue history. + Similar to evaluateGeneratedDialogue above.""" + domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police'] + requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + + # get the list of domains in the goal + domains_in_goal = [] + goal = {} + for domain in domains: + if dialog['goal'][domain]: + goal = self._parseGoal(goal, dialog, domain) + domains_in_goal.append(domain) + + # compute corpus success + real_requestables = {} + provided_requestables = {} + venue_offered = {} + for domain in goal.keys(): + provided_requestables[domain] = [] + venue_offered[domain] = [] + real_requestables[domain] = goal[domain]['requestable'] + + # iterate each turn + m_targetutt = [turn['text'] for idx, turn in enumerate(dialog['log']) if idx % 2 == 1] + for t in range(len(m_targetutt)): + for domain in domains_in_goal: + sent_t = m_targetutt[t] + # for computing match - where there are limited entities + if domain + '_name' in sent_t or '_id' in sent_t: + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION + venues = self.db.queryResultVenues(domain, dialog['log'][t * 2 + 1]) + + # if venue has changed + if len(venue_offered[domain]) == 0 and venues: + venue_offered[domain] = random.sample(venues, 1) + else: + flag = False + for ven in venues: + if venue_offered[domain][0] == ven: + flag = True + break + if not flag and venues: # sometimes there are no results so sample won't work + # print venues + venue_offered[domain] = random.sample(venues, 1) + else: # not limited so we can provide one + venue_offered[domain] = '[' + domain + '_name]' + + for requestable in requestables: + # check if reference could be issued + if requestable == 'reference': + if domain + '_reference' in sent_t: + if 'restaurant_reference' in sent_t: + if dialog['log'][t * 2]['db_pointer'][-5] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + elif 'hotel_reference' in sent_t: + if dialog['log'][t * 2]['db_pointer'][-3] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + # return goal, 0, match, real_requestables + elif 'train_reference' in sent_t: + if dialog['log'][t * 2]['db_pointer'][-1] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + else: + provided_requestables[domain].append('reference') + else: + if domain + '_' + requestable in sent_t: + provided_requestables[domain].append(requestable) + + # offer was made? + for domain in domains_in_goal: + # if name was provided for the user, the match is being done automatically + # if dialog['goal'][domain].has_key('info'): + if 'info' in dialog['goal'][domain]: + # if dialog['goal'][domain]['info'].has_key('name'): + if 'name' in dialog['goal'][domain]['info']: + venue_offered[domain] = '[' + domain + '_name]' + + # special domains - entity does not need to be provided + if domain in ['taxi', 'police', 'hospital']: + venue_offered[domain] = '[' + domain + '_name]' + + # if id was not requested but train was found we dont want to override it to check if we booked the right train + if domain == 'train' and (not venue_offered[domain] and 'id' not in goal['train']['requestable']): + venue_offered[domain] = '[' + domain + '_name]' + + # HARD (0-1) EVAL + stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0], + 'taxi': [0, 0, 0], + 'hospital': [0, 0, 0], 'police': [0, 0, 0]} + + match, success = 0, 0 + # MATCH + for domain in goal.keys(): + match_stat = 0 + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + goal_venues = self.db.queryResultVenues(domain, dialog['goal'][domain]['info'], real_belief=True) + # print(goal_venues) + if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]: + match += 1 + match_stat = 1 + elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues: + match += 1 + match_stat = 1 + + else: + if domain + '_name' in venue_offered[domain]: + match += 1 + match_stat = 1 + + stats[domain][0] = match_stat + stats[domain][2] = 1 + + if match == len(goal.keys()): + match = 1 + else: + match = 0 + + # SUCCESS + if match: + for domain in domains_in_goal: + domain_success = 0 + success_stat = 0 + if len(real_requestables[domain]) == 0: + # check that + success += 1 + success_stat = 1 + stats[domain][1] = success_stat + continue + # if values in sentences are super set of requestables + for request in set(provided_requestables[domain]): + if request in real_requestables[domain]: + domain_success += 1 + + if domain_success >= len(real_requestables[domain]): + success += 1 + success_stat = 1 + + stats[domain][1] = success_stat + + # final eval + if success >= len(real_requestables): + success = 1 + else: + success = 0 + + return goal, success, match, real_requestables, stats + + def _evaluateRolloutDialogue(self, dialog): + domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police'] + requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + + # get the list of domains in the goal + domains_in_goal = [] + goal = {} + for domain in domains: + if dialog['goal'][domain]: + goal = self._parseGoal(goal, dialog, domain) + domains_in_goal.append(domain) + + # compute corpus success + real_requestables = {} + provided_requestables = {} + venue_offered = {} + for domain in goal.keys(): + provided_requestables[domain] = [] + venue_offered[domain] = [] + real_requestables[domain] = goal[domain]['requestable'] + + # iterate each turn + m_targetutt = [turn['text'] for idx, turn in enumerate(dialog['log']) if idx % 2 == 1] + for t in range(len(m_targetutt)): + for domain in domains_in_goal: + sent_t = m_targetutt[t] + # for computing match - where there are limited entities + if domain + '_name' in sent_t or domain+'_id' in sent_t: + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + venue_offered[domain] = '[' + domain + '_name]' + """ + venues = self.db.queryResultVenues(domain, dialog['log'][t * 2 + 1]) + if len(venue_offered[domain]) == 0 and venues: + venue_offered[domain] = random.sample(venues, 1) + else: + flag = False + for ven in venues: + if venue_offered[domain][0] == ven: + flag = True + break + if not flag and venues: # sometimes there are no results so sample won't work + # print venues + venue_offered[domain] = random.sample(venues, 1) + """ + else: # not limited so we can provide one + venue_offered[domain] = '[' + domain + '_name]' + + for requestable in requestables: + # check if reference could be issued + if requestable == 'reference': + if domain + '_reference' in sent_t: + if 'restaurant_reference' in sent_t: + if True or dialog['log'][t * 2]['db_pointer'][-5] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + elif 'hotel_reference' in sent_t: + if True or dialog['log'][t * 2]['db_pointer'][-3] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + # return goal, 0, match, real_requestables + elif 'train_reference' in sent_t: + if True or dialog['log'][t * 2]['db_pointer'][-1] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + else: + provided_requestables[domain].append('reference') + else: + if domain + '_' + requestable in sent_t: + provided_requestables[domain].append(requestable) + + # offer was made? + for domain in domains_in_goal: + # if name was provided for the user, the match is being done automatically + # if dialog['goal'][domain].has_key('info'): + if 'info' in dialog['goal'][domain]: + # if dialog['goal'][domain]['info'].has_key('name'): + if 'name' in dialog['goal'][domain]['info']: + venue_offered[domain] = '[' + domain + '_name]' + + # special domains - entity does not need to be provided + if domain in ['taxi', 'police', 'hospital']: + venue_offered[domain] = '[' + domain + '_name]' + + # if id was not requested but train was found we dont want to override it to check if we booked the right train + if domain == 'train' and (not venue_offered[domain] and 'id' not in goal['train']['requestable']): + venue_offered[domain] = '[' + domain + '_name]' + + # REWARD CALCULATION + stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0], + 'taxi': [0, 0, 0], 'hospital': [0, 0, 0], 'police': [0, 0, 0]} + match, success = 0.0, 0.0 + # MATCH + for domain in goal.keys(): + match_stat = 0 + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + goal_venues = self.db.queryResultVenues(domain, dialog['goal'][domain]['info'], real_belief=True) + if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]: + match += 1 + match_stat = 1 + elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues: + match += 1 + match_stat = 1 + else: + if domain + '_name' in venue_offered[domain]: + match += 1 + match_stat = 1 + + stats[domain][0] = match_stat + stats[domain][2] = 1 + + match = min(1.0, float(match) / len(goal.keys())) + + # SUCCESS + if match: + for domain in domains_in_goal: + domain_success = 0 + success_stat = 0 + if len(real_requestables[domain]) == 0: + # check that + success += 1 + success_stat = 1 + stats[domain][1] = success_stat + continue + # if values in sentences are super set of requestables + for request in set(provided_requestables[domain]): + if request in real_requestables[domain]: + domain_success += 1 + + if domain_success >= len(real_requestables[domain]): + success += 1 + success_stat = 1 + + stats[domain][1] = success_stat + + # final eval + success = min(1.0, float(success) / len(real_requestables)) + + return success, match, stats + + def _parse_entities(self, tokens): + entities = [] + for t in tokens: + if '[' in t and ']' in t: + entities.append(t) + return entities + + def evaluateModel(self, dialogues, mode='valid', new_version=False): + """Gathers statistics for the whole sets.""" + delex_dialogues = self.delex_dialogues + # pdb.set_trace() + successes, matches = 0, 0 + corpus_successes, corpus_matches = 0, 0 + total = 0 + + gen_stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0], + 'taxi': [0, 0, 0], + 'hospital': [0, 0, 0], 'police': [0, 0, 0]} + sng_gen_stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0], + 'taxi': [0, 0, 0], 'hospital': [0, 0, 0], 'police': [0, 0, 0]} + + for filename, dial in dialogues.items(): + if mode == 'rollout': + success, match, stats = self._evaluateRolloutDialogue(dial) + else: + # data is ground truth, dial is generated + data = delex_dialogues[filename] + goal, success, match, requestables, _ = self._evaluateRealDialogue(data, filename) # only goal and requestables are kept + corpus_successes += success + corpus_matches += match + if new_version: + success, match, stats = self._evaluateGeneratedDialogue_new(dial, goal, data, requestables, + soft_acc=mode =='offline_rl') + else: + success, match, stats = self._evaluateGeneratedDialogue(dial, goal, data, requestables, + soft_acc=mode =='offline_rl') + + successes += success + matches += match + total += 1 + + for domain in gen_stats.keys(): + gen_stats[domain][0] += stats[domain][0] + gen_stats[domain][1] += stats[domain][1] + gen_stats[domain][2] += stats[domain][2] + + if 'SNG' in filename: + for domain in gen_stats.keys(): + sng_gen_stats[domain][0] += stats[domain][0] + sng_gen_stats[domain][1] += stats[domain][1] + sng_gen_stats[domain][2] += stats[domain][2] + + report = "" + report += '{} Corpus Matches : {:2.2f}%, Groundtruth {} Matches : {:2.2f}%'.format(mode, (matches / float(total) * 100), mode, (corpus_matches / float(total) * 100)) + "\n" + report += '{} Corpus Success : {:2.2f}%, Groundtruth {} Success : {:2.2f}%'.format(mode, (successes / float(total) * 100), mode, (corpus_successes / float(total) * 100)) + "\n" + report += 'Total number of dialogues: %s, new version=%s ' % (total, new_version) + + self.logger.info(report) + return report, successes/float(total), matches/float(total) + + def get_report(self): + tokenize = lambda x: x.split() + print('Generate report for {} samples'.format(len(self.hyps))) + refs, hyps = [], [] + tp, fp, fn = 0, 0, 0 + for label, hyp in zip(self.labels, self.hyps): + ref_tokens = [BOS] + tokenize(label.replace(SYS, '').replace(USR, '').strip()) + [EOS] + hyp_tokens = [BOS] + tokenize(hyp.replace(SYS, '').replace(USR, '').strip()) + [EOS] + refs.append([ref_tokens]) + hyps.append(hyp_tokens) + + ref_entities = self._parse_entities(ref_tokens) + hyp_entities = self._parse_entities(hyp_tokens) + tpp, fpp, fnn = self._get_tp_fp_fn(ref_entities, hyp_entities) + tp += tpp + fp += fpp + fn += fnn + + # bleu = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1) + bleu = BLEUScorer().score(hyps, refs) + prec, rec, f1 = self._get_prec_recall(tp, fp, fn) + report = "\nBLEU score {}\nEntity precision {:.4f} recall {:.4f} and f1 {:.4f}\n".format(bleu, prec, rec, f1) + return report, bleu, prec, rec, f1 + + def get_groundtruth_report(self): + tokenize = lambda x: x.split() + print('Generate report for {} samples'.format(len(self.hyps))) + refs, hyps = [], [] + tp, fp, fn = 0, 0, 0 + for label, hyp in zip(self.labels, self.hyps): + ref_tokens = [BOS] + tokenize(label.replace(SYS, '').replace(USR, '').strip()) + [EOS] + refs.append([ref_tokens]) + + ref_entities = self._parse_entities(ref_tokens) + tpp, fpp, fnn = self._get_tp_fp_fn(ref_entities, ref_entities) + tp += tpp + fp += fpp + fn += fnn + + # bleu = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1) + # bleu = BLEUScorer().score(refs, refs) + prec, rec, f1 = self._get_prec_recall(tp, fp, fn) + # report = "\nGroundtruth BLEU score {}\nEntity precision {:.4f} recall {:.4f} and f1 {:.4f}\n".format(bleu, prec, rec, f1) + report = "\nGroundtruth\nEntity precision {:.4f} recall {:.4f} and f1 {:.4f}\n".format(prec, rec, f1) + return report, 0, prec, rec, f1 + +class SimDialEvaluator(BaseEvaluator): + CUR_DIR = os.path.dirname(__file__).replace('latent_dialog', '') + logger = logging.getLogger() + def __init__(self, data_name): + self.data_name = data_name + self.slot_dict = delex.prepareSlotValuesIndependent() + # self.delex_dialogues = json.load(open(os.path.join(self.CUR_DIR, 'data/norm-multi-woz/delex.json'))) + # self.db = MultiWozDB() + self.labels = list() + self.hyps = list() + + def initialize(self): + self.labels = list() + self.hyps = list() + + def add_example(self, ref, hyp): + self.labels.append(ref) + self.hyps.append(hyp) + + def _parse_entities(self, tokens): + entities = [] + for t in tokens: + if '[' in t and ']' in t: + entities.append(t) + return entities + + def get_report(self): + tokenize = lambda x: x.split() + print('Generate report for {} samples'.format(len(self.hyps))) + refs, hyps = [], [] + tp, fp, fn = 0, 0, 0 + for label, hyp in zip(self.labels, self.hyps): + ref_tokens = [BOS] + tokenize(label.replace(SYS, '').replace(USR, '').strip()) + [EOS] + hyp_tokens = [BOS] + tokenize(hyp.replace(SYS, '').replace(USR, '').strip()) + [EOS] + refs.append([ref_tokens]) + hyps.append(hyp_tokens) + + ref_entities = self._parse_entities(ref_tokens) + hyp_entities = self._parse_entities(hyp_tokens) + tpp, fpp, fnn = self._get_tp_fp_fn(ref_entities, hyp_entities) + tp += tpp + fp += fpp + fn += fnn + + # bleu = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1) + bleu = BLEUScorer().score(hyps, refs) + prec, rec, f1 = self._get_prec_recall(tp, fp, fn) + report = "\nBLEU score {}\nEntity precision {:.4f} recall {:.4f} and f1 {:.4f}\n".format(bleu, prec, rec, f1) + return report, bleu, prec, rec, f1 + + def get_groundtruth_report(self): + tokenize = lambda x: x.split() + print('Generate report for {} samples'.format(len(self.hyps))) + refs, hyps = [], [] + tp, fp, fn = 0, 0, 0 + for label, hyp in zip(self.labels, self.hyps): + ref_tokens = [BOS] + tokenize(label.replace(SYS, '').replace(USR, '').strip()) + [EOS] + refs.append([ref_tokens]) + + ref_entities = self._parse_entities(ref_tokens) + tpp, fpp, fnn = self._get_tp_fp_fn(ref_entities, ref_entities) + tp += tpp + fp += fpp + fn += fnn + + # bleu = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1) + # bleu = BLEUScorer().score(refs, refs) + prec, rec, f1 = self._get_prec_recall(tp, fp, fn) + # report = "\nGroundtruth BLEU score {}\nEntity precision {:.4f} recall {:.4f} and f1 {:.4f}\n".format(bleu, prec, rec, f1) + report = "\nGroundtruth\nEntity precision {:.4f} recall {:.4f} and f1 {:.4f}\n".format(prec, rec, f1) + return report, 0, prec, rec, f1 + +class TurnEvaluator(BaseEvaluator): + """ + Use string matching to find the F-1 score of slots + Use logistic regression to find F-1 score of acts + Use string matching to find F-1 score of KB_SEARCH + """ + CLF = "clf" + REPRESENTATION = "rep" + ID2TAG = "id2tag" + TAG2ID = "tag2id" + logger = logging.getLogger() + + def __init__(self, data_name, turn_corpus, domain_meta): + self.data_name = data_name + # train a dialog act classifier + domain2ids = defaultdict(list) + for d_id, d in enumerate(turn_corpus): + domain2ids[d.domain].append(d_id) + selected_ids = [v[0:1000] for v in domain2ids.values()] + corpus = [turn_corpus[idx] for idxs in selected_ids for idx in idxs] + + self.model = self.get_intent_tagger(corpus) + + # get entity value vocabulary + self.domain_id2ent = self.get_entity_dict_from_meta(domain_meta) + + # Initialize containers + self.domain_labels = defaultdict(list) + self.domain_hyps = defaultdict(list) + + def get_entity_dict_from_meta(self, domain_meta): + # get entity value vocabulary + domain_id2ent = defaultdict(set) + for domain, meta in domain_meta.items(): + domain_id2ent[domain].add("QUERY") + domain_id2ent[domain].add("GOALS") + for slot, vocab in meta.sys_slots.items(): + domain_id2ent[domain].add(slot) + for v in vocab: + domain_id2ent[domain].add(v) + + for slot, vocab in meta.usr_slots.items(): + domain_id2ent[domain].add(slot) + for v in vocab: + domain_id2ent[domain].add(v) + + domain_id2ent = {k: list(v) for k, v in domain_id2ent.items()} + return domain_id2ent + + def get_entity_dict(self, turn_corpus): + utt2act = {} + for msg in turn_corpus: + utt2act[" ".join(msg.utt[1:-1])] = msg + + detokenize = get_detokenize() + utt2act = {detokenize(k.split()): v for k, v in utt2act.items()} + self.logger.info("Compress utt2act from {}->{}".format(len(turn_corpus), len(utt2act))) + + # get entity value vocabulary + domain_id2ent = defaultdict(set) + for utt, msg in utt2act.items(): + for act in msg.actions: + paras = act['parameters'] + intent = act['act'] + if intent == 'inform': + for v in paras[0].values(): + domain_id2ent[msg.domain].add(str(v)) + elif intent == 'query': + for v in paras[0].values(): + domain_id2ent[msg.domain].add(v) + else: + for k, v in paras: + if v: + domain_id2ent[msg.domain].add(v) + domain_id2ent = {k: list(v) for k, v in domain_id2ent.items()} + return domain_id2ent + + def get_intent_tagger(self, corpus): + """ + :return: train a dialog act tagger for system utterances + """ + self.logger.info("Train a new intent tagger") + all_tags, utts, tags = [], [], [] + de_tknize = get_detokenize() + for msg in corpus: + utts.append(de_tknize(msg.utt[1:-1])) + tags.append([a['act'] for a in msg.actions]) + all_tags.extend([a['act'] for a in msg.actions]) + + most_common = Counter(all_tags).most_common() + self.logger.info(most_common) + tag_set = [t for t, c, in most_common] + rev_tag_set = {t: i for i, t in enumerate(tag_set)} + + # create train and test set: + data_size = len(corpus) + train_size = int(data_size * 0.7) + train_utts = utts[0:train_size] + test_utts = utts[train_size:] + + # create y: + sparse_y = np.zeros([data_size, len(tag_set)]) + for idx, utt_tags in enumerate(tags): + for tag in utt_tags: + sparse_y[idx, rev_tag_set[tag]] = 1 + train_y = sparse_y[0:train_size, :] + test_y = sparse_y[train_size:, :] + + # train classifier + representation = CountVectorizer(ngram_range=[1, 2]).fit(train_utts) + train_x = representation.transform(train_utts) + test_x = representation.transform(test_utts) + + clf = OneVsRestClassifier(SGDClassifier(loss='hinge', max_iter=10)).fit(train_x, train_y) + pred_test_y = clf.predict(test_x) + + def print_report(score_name, scores, names): + for s, n in zip(scores, names): + self.logger.info("%s: %s -> %f" % (score_name, n, s)) + + print_report('F1', metrics.f1_score(test_y, pred_test_y, average=None), + tag_set) + + x = representation.transform(utts) + clf = OneVsRestClassifier(SGDClassifier(loss='hinge', max_iter=20)) \ + .fit(x, sparse_y) + + model_dump = {self.CLF: clf, self.REPRESENTATION: representation, + self.ID2TAG: tag_set, + self.TAG2ID: rev_tag_set} + # pkl.dump(model_dump, open("{}.pkl".format(self.data_name), "wb")) + return model_dump + + def pred_ents(self, sentence, tokenize, domain): + pred_ents = [] + padded_hyp = "/{}/".format("/".join(tokenize(sentence))) + for e in self.domain_id2ent[domain]: + count = padded_hyp.count("/{}/".format(e)) + if domain =='movie' and e == 'I': + continue + pred_ents.extend([e] * count) + return pred_ents + + def pred_acts(self, utts): + test_x = self.model[self.REPRESENTATION].transform(utts) + pred_test_y = self.model[self.CLF].predict(test_x) + pred_tags = [] + for ys in pred_test_y: + temp = [] + for i in range(len(ys)): + if ys[i] == 1: + temp.append(self.model[self.ID2TAG][i]) + pred_tags.append(temp) + return pred_tags + + """ + Public Functions + """ + def initialize(self): + self.domain_labels = defaultdict(list) + self.domain_hyps = defaultdict(list) + + def add_example(self, ref, hyp, domain='default'): + self.domain_labels[domain].append(ref) + self.domain_hyps[domain].append(hyp) + + def get_report(self, include_error=False): + reports = [] + + errors = [] + + for domain, labels in self.domain_labels.items(): + intent2refs = defaultdict(list) + intent2hyps = defaultdict(list) + + predictions = self.domain_hyps[domain] + self.logger.info("Generate report for {} for {} samples".format(domain, len(predictions))) + + # find entity precision, recall and f1 + tp, fp, fn = 0.0, 0.0, 0.0 + + # find intent precision recall f1 + itp, ifp, ifn = 0.0, 0.0, 0.0 + + # backend accuracy + btp, bfp, bfn = 0.0, 0.0, 0.0 + + # BLEU score + refs, hyps = [], [] + + pred_intents = self.pred_acts(predictions) + label_intents = self.pred_acts(labels) + + tokenize = get_tokenize() + bad_predictions = [] + + for label, hyp, label_ints, pred_ints in zip(labels, predictions, label_intents, pred_intents): + refs.append([label.split()]) + hyps.append(hyp.split()) + + # pdb.set_trace() + + label_ents = self.pred_ents(label, tokenize, domain) + pred_ents = self.pred_ents(hyp, tokenize, domain) + + for intent in label_ints: + intent2refs[intent].append([label.split()]) + intent2hyps[intent].append(hyp.split()) + + # update the intent + ttpp, ffpp, ffnn = self._get_tp_fp_fn(label_ints, pred_ints) + itp += ttpp + ifp += ffpp + ifn += ffnn + + # entity or KB search + ttpp, ffpp, ffnn = self._get_tp_fp_fn(label_ents, pred_ents) + if ffpp > 0 or ffnn > 0: + bad_predictions.append((label, hyp)) + + if "query" in label_ints: + btp += ttpp + bfp += ffpp + bfn += ffnn + else: + tp += ttpp + fp += ffpp + fn += ffnn + + # compute corpus level scores + bleu = bleu_score.corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1) + ent_precision, ent_recall, ent_f1 = self._get_prec_recall(tp, fp, fn) + int_precision, int_recall, int_f1 = self._get_prec_recall(itp, ifp, ifn) + back_precision, back_recall, back_f1 = self._get_prec_recall(btp, bfp, bfn) + + # compute BLEU w.r.t intents + intent_report = [] + for intent in intent2refs.keys(): + i_bleu = bleu_score.corpus_bleu(intent2refs[intent], intent2hyps[intent], + smoothing_function=SmoothingFunction().method1) + intent_report.append("{}: {}".format(intent, i_bleu)) + + intent_report = "\n".join(intent_report) + + # create bad cases + error = '' + if include_error: + error = '\nDomain {} errors\n'.format(domain) + error += "\n".join(['True: {} ||| Pred: {}'.format(r, h) + for r, h in bad_predictions]) + report = "\nDomain: %s\n" \ + "Entity precision %f recall %f and f1 %f\n" \ + "Intent precision %f recall %f and f1 %f\n" \ + "KB precision %f recall %f and f1 %f\n" \ + "BLEU %f BEAK %f\n\n%s\n" \ + % (domain, + ent_precision, ent_recall, ent_f1, + int_precision, int_recall, int_f1, + back_precision, back_recall, back_f1, + bleu, gmean([ent_f1, int_f1, back_f1, bleu]), + intent_report) + reports.append(report) + errors.append(error) + + if include_error: + return "\n==== REPORT===={error}\n========\n {report}".format(error="========".join(errors), + report="========".join(reports)) + else: + return "\n==== REPORT===={report}".format(report="========".join(reports)) diff --git a/convlab/policy/lava/multiwoz/latent_dialog/main.py b/convlab/policy/lava/multiwoz/latent_dialog/main.py new file mode 100644 index 0000000000000000000000000000000000000000..44c1b053c3728098dc161da6af96135b40b6b5ad --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/main.py @@ -0,0 +1,708 @@ +import os +import sys +import numpy as np +import torch as th +from torch import nn +from collections import defaultdict, Counter +from convlab.policy.lava.multiwoz.latent_dialog.enc2dec.base_modules import summary +from convlab.policy.lava.multiwoz.latent_dialog.enc2dec.decoders import TEACH_FORCE, GEN, DecoderRNN +from datetime import datetime +from convlab.policy.lava.multiwoz.latent_dialog.utils import get_detokenize, LONG, FLOAT +from convlab.policy.lava.multiwoz.latent_dialog.corpora import EOS, PAD +from convlab.policy.lava.multiwoz.latent_dialog.data_loaders import BeliefDbDataLoaders +from convlab.policy.lava.multiwoz.latent_dialog import evaluators +from convlab.policy.lava.multiwoz.latent_dialog.record import record, record_task, UniquenessSentMetric, UniquenessWordMetric +import logging +import pdb + +logger = logging.getLogger() + +class LossManager(object): + def __init__(self): + self.losses = defaultdict(list) + self.backward_losses = [] + + def add_loss(self, loss): + for key, val in loss.items(): + # print('key = %s\nval = %s' % (key, val)) + if val is not None and type(val) is not bool: + self.losses[key].append(val.item()) + + def pprint(self, name, window=None, prefix=None): + str_losses = [] + for key, loss in self.losses.items(): + if loss is None: + continue + aver_loss = np.average(loss) if window is None else np.average(loss[-window:]) + if 'nll' in key: + str_losses.append('{} PPL {:.3f}'.format(key, np.exp(aver_loss))) + else: + str_losses.append('{} {:.3f}'.format(key, aver_loss)) + + + if prefix: + return '{}: {} {}'.format(prefix, name, ' '.join(str_losses)) + else: + return '{} {}'.format(name, ' '.join(str_losses)) + + def clear(self): + self.losses = defaultdict(list) + self.backward_losses = [] + + def add_backward_loss(self, loss): + self.backward_losses.append(loss.item()) + + def avg_loss(self): + return np.mean(self.backward_losses) + +class OfflineTaskReinforce(object): + def __init__(self, agent, corpus, sv_config, sys_model, rl_config, generate_func): + self.agent = agent + self.corpus = corpus + self.sv_config = sv_config + self.sys_model = sys_model + self.rl_config = rl_config + # training func for supervised learning + self.train_func = task_train_single_batch + self.record_func = record_task + self.validate_func = validate + + # prepare data loader + train_dial, val_dial, test_dial = self.corpus.get_corpus() + self.train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config) + self.sl_train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config) + self.val_data = BeliefDbDataLoaders('Val', val_dial, self.sv_config) + self.test_data = BeliefDbDataLoaders('Test', test_dial, self.sv_config) + + # create log files + if self.rl_config.record_freq > 0: + self.learning_exp_file = open(os.path.join(self.rl_config.record_path, 'offline-learning.tsv'), 'w') + self.ppl_val_file = open(os.path.join(self.rl_config.record_path, 'val-ppl.tsv'), 'w') + self.rl_val_file = open(os.path.join(self.rl_config.record_path, 'val-rl.tsv'), 'w') + self.ppl_test_file = open(os.path.join(self.rl_config.record_path, 'test-ppl.tsv'), 'w') + self.rl_test_file = open(os.path.join(self.rl_config.record_path, 'test-rl.tsv'), 'w') + # evaluation + self.evaluator = evaluators.MultiWozEvaluator('SYS_WOZ') + self.generate_func = generate_func + + def run(self): + n = 0 + best_valid_loss = np.inf + best_rewards = -1 * np.inf + + # BEFORE RUN, RECORD INITIAL PERFORMANCE + test_loss = self.validate_func(self.sys_model, self.test_data, self.sv_config, use_py=True) + t_success, t_match, t_bleu, t_f1 = self.generate_func(self.sys_model, self.test_data, self.sv_config, + self.evaluator, None, verbose=False) + + self.ppl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, np.exp(test_loss), t_bleu, t_f1)) + self.ppl_test_file.flush() + self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, (t_success + t_match), t_success, t_match)) + self.rl_test_file.flush() + + self.sys_model.train() + try: + for epoch_id in range(self.rl_config.nepoch): + self.train_data.epoch_init(self.sv_config, shuffle=True, verbose=epoch_id == 0, fix_batch=True) + while True: + if n % self.rl_config.episode_repeat == 0: + batch = self.train_data.next_batch() + + if batch is None: + break + + n += 1 + if n % 50 == 0: + print("Reinforcement Learning {}/{} episode".format(n, self.train_data.num_batch*self.rl_config.nepoch)) + self.learning_exp_file.write( + '{}\t{}\n'.format(n, np.mean(self.agent.all_rewards[-50:]))) + self.learning_exp_file.flush() + + # reinforcement learning + # make sure it's the same dialo + assert len(set(batch['keys'])) == 1 + task_report, success, match = self.agent.run(batch, self.evaluator, max_words=self.rl_config.max_words, temp=self.rl_config.temperature) + reward = float(success) # + float(match) + stats = {'Match': match, 'Success': success} + self.agent.update(reward, stats) + + # supervised learning + if self.rl_config.sv_train_freq > 0 and n % self.rl_config.sv_train_freq == 0: + self.train_func(self.sys_model, self.sl_train_data, self.sv_config) + + # record model performance in terms of several evaluation metrics + if self.rl_config.record_freq > 0 and n % self.rl_config.record_freq == 0: + self.agent.print_dialog(self.agent.dlg_history, reward, stats) + print('-'*15, 'Recording start', '-'*15) + # save train reward + self.learning_exp_file.write('{}\t{}\n'.format(n, np.mean(self.agent.all_rewards[-self.rl_config.record_freq:]))) + self.learning_exp_file.flush() + + # PPL & reward on validation + valid_loss = self.validate_func(self.sys_model, self.val_data, self.sv_config, use_py=True) + v_success, v_match, v_bleu, v_f1 = self.generate_func(self.sys_model, self.val_data, self.sv_config, self.evaluator, None, verbose=False) + self.ppl_val_file.write('{}\t{}\t{}\t{}\n'.format(n, np.exp(valid_loss), v_bleu, v_f1)) + self.ppl_val_file.flush() + self.rl_val_file.write('{}\t{}\t{}\t{}\n'.format(n, (v_success + v_match), v_success, v_match)) + self.rl_val_file.flush() + + test_loss = self.validate_func(self.sys_model, self.test_data, self.sv_config, use_py=True) + t_success, t_match, t_bleu, t_f1 = self.generate_func(self.sys_model, self.test_data, self.sv_config, self.evaluator, None, verbose=False) + self.ppl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, np.exp(test_loss), t_bleu, t_f1)) + self.ppl_test_file.flush() + self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, (t_success + t_match), t_success, t_match)) + self.rl_test_file.flush() + + # save model is needed + if v_success+v_match > best_rewards: + print("Model saved with success {} match {}".format(v_success, v_match)) + th.save(self.sys_model.state_dict(), self.rl_config.reward_best_model_path) + best_rewards = v_success+v_match + + + self.sys_model.train() + print('-'*15, 'Recording end', '-'*15) + except KeyboardInterrupt: + print("RL training stopped from keyboard") + + print("$$$ Load {}-model".format(self.rl_config.reward_best_model_path)) + self.sv_config.batch_size = 32 + self.sys_model.load_state_dict(th.load(self.rl_config.reward_best_model_path)) + + validate(self.sys_model, self.val_data, self.sv_config, use_py=True) + validate(self.sys_model, self.test_data, self.sv_config, use_py=True) + + with open(os.path.join(self.rl_config.record_path, 'valid_file.txt'), 'w') as f: + self.generate_func(self.sys_model, self.val_data, self.sv_config, self.evaluator, num_batch=None, dest_f=f) + + with open(os.path.join(self.rl_config.record_path, 'test_file.txt'), 'w') as f: + self.generate_func(self.sys_model, self.test_data, self.sv_config, self.evaluator, num_batch=None, dest_f=f) + +def validate_rl(dialog_eval, ctx_gen, num_episode=200): + print("Validate on training goals for {} episode".format(num_episode)) + reward_list = [] + agree_list = [] + sent_metric = UniquenessSentMetric() + word_metric = UniquenessWordMetric() + for _ in range(num_episode): + ctxs = ctx_gen.sample() + conv, agree, rewards = dialog_eval.run(ctxs) + true_reward = rewards[0] if agree else 0 + reward_list.append(true_reward) + agree_list.append(float(agree if agree is not None else 0.0)) + for turn in conv: + if turn[0] == 'Elder': + sent_metric.record(turn[1]) + word_metric.record(turn[1]) + results = {'sys_rew': np.average(reward_list), + 'avg_agree': np.average(agree_list), + 'sys_sent_unique': sent_metric.value(), + 'sys_unique': word_metric.value()} + return results + +def train_single_batch(model, train_data, config): + batch_cnt = 0 + optimizer = model.get_optimizer(config, verbose=False) + model.train() + + # decoding CE + train_data.epoch_init(config, shuffle=True, verbose=False) + for i in range(16): + batch = train_data.next_batch() + if batch is None: + train_data.epoch_init(config, shuffle=True, verbose=False) + batch = train_data.next_batch() + optimizer.zero_grad() + loss = model(batch, mode=TEACH_FORCE) + model.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + +def task_train_single_batch(model, train_data, config): + batch_cnt = 0 + optimizer = model.get_optimizer(config, verbose=False) + model.train() + + # decoding CE + train_data.epoch_init(config, shuffle=True, verbose=False) + for i in range(16): + batch = train_data.next_batch() + if batch is None: + train_data.epoch_init(config, shuffle=True, verbose=False) + batch = train_data.next_batch() + optimizer.zero_grad() + loss = model(batch, mode=TEACH_FORCE) + model.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + +def train(model, train_data, val_data, test_data, config, evaluator, gen=None): + patience = 10 + valid_loss_threshold = np.inf + best_valid_loss = np.inf + batch_cnt = 0 + optimizer = model.get_optimizer(config) + done_epoch = 0 + best_epoch = 0 + train_loss = LossManager() + model.train() + logger.info(summary(model, show_weights=False)) + saved_models = [] + last_n_model = config.last_n_model if hasattr(config, 'last_n_model') else 5 + + logger.info('***** Training Begins at {} *****'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('***** Epoch 0/{} *****'.format(config.max_epoch)) + while True: + train_data.epoch_init(config, shuffle=True, verbose=done_epoch==0, fix_batch=config.fix_train_batch) + while True: + batch = train_data.next_batch() + if batch is None: + break + + optimizer.zero_grad() + loss = model(batch, mode=TEACH_FORCE) + model.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + batch_cnt += 1 + train_loss.add_loss(loss) + + if batch_cnt % config.print_step == 0: + # print('Print step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info(train_loss.pprint('Train', + window=config.print_step, + prefix='{}/{}-({:.3f})'.format(batch_cnt%config.ckpt_step, config.ckpt_step, model.kl_w))) + sys.stdout.flush() + + if batch_cnt % config.ckpt_step == 0: + logger.info('Checkpoint step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('==== Evaluating Model ====') + logger.info(train_loss.pprint('Train')) + done_epoch += 1 + logger.info('done epoch {} -> {}'.format(done_epoch-1, done_epoch)) + + # generation + if gen is not None: + gen(model, val_data, config, evaluator, num_batch=config.preview_batch_num) + + # validation + valid_loss = validate(model, val_data, config, batch_cnt) + _ = validate(model, test_data, config, batch_cnt) + + # update early stopping stats + if valid_loss < best_valid_loss: + if valid_loss <= valid_loss_threshold * config.improve_threshold: + patience = max(patience, done_epoch*config.patient_increase) + valid_loss_threshold = valid_loss + logger.info('Update patience to {}'.format(patience)) + + if config.save_model: + cur_time = datetime.now().strftime("%Y-%m-%d %H-%M-%S") + logger.info('!!Model Saved with loss = {},at {}.'.format(valid_loss, cur_time)) + th.save(model.state_dict(), os.path.join(config.saved_path, '{}-model'.format(done_epoch))) + best_epoch = done_epoch + saved_models.append(done_epoch) + if len(saved_models) > last_n_model: + remove_model = saved_models[0] + saved_models = saved_models[-last_n_model:] + os.remove(os.path.join(config.saved_path, "{}-model".format(remove_model))) + + best_valid_loss = valid_loss + + if done_epoch >= config.max_epoch \ + or config.early_stop and patience <= done_epoch: + if done_epoch < config.max_epoch: + logger.info('!!!!! Early stop due to run out of patience !!!!!') + print('Best validation loss = %f' % (best_valid_loss, )) + return best_epoch + + # exit eval model + model.train() + train_loss.clear() + logger.info('\n***** Epoch {}/{} *****'.format(done_epoch, config.max_epoch)) + sys.stdout.flush() + +def mt_train(model, train_data, val_data, test_data, aux_train_data, aux_val_data, aux_test_data, config, evaluator, gen=None): + patience = 10 + valid_loss_threshold = np.inf + best_valid_loss = np.inf + batch_cnt = 0 + optimizer = model.get_optimizer(config) + done_epoch = 0 + best_epoch = 0 + train_loss = LossManager() + model.train() + logger.info(summary(model, show_weights=False)) + saved_models = [] + last_n_model = config.last_n_model if hasattr(config, 'last_n_model') else 5 + + logger.info('***** Training Begins at {} *****'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('***** Epoch 0/{} *****'.format(config.max_epoch)) + while True: + train_data.epoch_init(config, shuffle=True, verbose=done_epoch==0, fix_batch=config.fix_train_batch) + while True: + batch = train_data.next_batch() + if batch is None: + break + + optimizer.zero_grad() + loss = model(batch, mode=TEACH_FORCE) + model.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + batch_cnt += 1 + train_loss.add_loss(loss) + + if batch_cnt % config.print_step == 0: + # print('Print step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info(train_loss.pprint('Train', + window=config.print_step, + prefix='{}/{}-({:.3f})'.format(batch_cnt%config.ckpt_step, config.ckpt_step, model.kl_w))) + sys.stdout.flush() + + if batch_cnt % config.ckpt_step == 0: + logger.info('Checkpoint step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('==== Evaluating Model ====') + logger.info(train_loss.pprint('Train')) + done_epoch += 1 + logger.info('done epoch {} -> {}'.format(done_epoch-1, done_epoch)) + + # generation + if gen is not None: + gen(model, val_data, config, evaluator, num_batch=config.preview_batch_num) + + # validation + valid_loss = validate(model, val_data, config, batch_cnt) + _ = validate(model, test_data, config, batch_cnt) + + # update early stopping stats + if valid_loss < best_valid_loss: + if valid_loss <= valid_loss_threshold * config.improve_threshold: + patience = max(patience, done_epoch*config.patient_increase) + valid_loss_threshold = valid_loss + logger.info('Update patience to {}'.format(patience)) + + if config.save_model: + cur_time = datetime.now().strftime("%Y-%m-%d %H-%M-%S") + logger.info('!!Model Saved with loss = {},at {}.'.format(valid_loss, cur_time)) + th.save(model.state_dict(), os.path.join(config.saved_path, '{}-model'.format(done_epoch))) + best_epoch = done_epoch + saved_models.append(done_epoch) + if len(saved_models) > last_n_model: + remove_model = saved_models[0] + saved_models = saved_models[-last_n_model:] + os.remove(os.path.join(config.saved_path, "{}-model".format(remove_model))) + + best_valid_loss = valid_loss + + if done_epoch >= config.max_epoch \ + or config.early_stop and patience <= done_epoch: + if done_epoch < config.max_epoch: + logger.info('!!!!! Early stop due to run out of patience !!!!!') + print('Best validation loss = %f' % (best_valid_loss, )) + return best_epoch + + + if done_epoch % config.aux_train_freq == 0: + model.train() + train_aux(model, aux_train_data, aux_val_data, aux_test_data, config, evaluator) + + # exit eval model + model.train() + train_loss.clear() + + logger.info('\n***** Epoch {}/{} *****'.format(done_epoch, config.max_epoch)) + sys.stdout.flush() + +def train_aux(model, train_data, val_data, test_data, config, evaluator, gen=None): + patience = 10 + valid_loss_threshold = np.inf + best_valid_loss = np.inf + batch_cnt = 0 + optimizer = model.get_optimizer(config) + done_epoch = 0 + best_epoch = 0 + train_loss = LossManager() + model.train() + logger.info(summary(model, show_weights=False)) + saved_models = [] + last_n_model = config.last_n_model if hasattr(config, 'last_n_model') else 5 + + logger.info('+++++ Aux Training Begins at {} +++++'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('+++++ Epoch 0/{} +++++'.format(config.aux_max_epoch)) + while True: + train_data.epoch_init(config, shuffle=True, verbose=done_epoch==0, fix_batch=config.fix_train_batch) + while True: + batch = train_data.next_batch() + if batch is None: + break + + optimizer.zero_grad() + loss = model.forward_aux(batch, mode=TEACH_FORCE) + model.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + batch_cnt += 1 + train_loss.add_loss(loss) + + if batch_cnt % config.print_step == 0: + # print('Print step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info(train_loss.pprint('Train', + window=config.print_step, + prefix='{}/{}-({:.3f})'.format(batch_cnt%config.ckpt_step, config.ckpt_step, model.kl_w))) + sys.stdout.flush() + + if batch_cnt % config.ckpt_step == 0: + logger.info('Checkpoint step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('++++ Evaluating Model ++++') + logger.info(train_loss.pprint('Aux train')) + done_epoch += 1 + logger.info('done epoch {} -> {}'.format(done_epoch-1, done_epoch)) + + # generation + if gen is not None: + gen(model, val_data, config, evaluator, num_batch=config.preview_batch_num) + + # validation + valid_loss = aux_validate(model, val_data, config, batch_cnt) + _ = aux_validate(model, test_data, config, batch_cnt) + + # update early stopping stats + if valid_loss < best_valid_loss: + if valid_loss <= valid_loss_threshold * config.improve_threshold: + patience = max(patience, done_epoch*config.patient_increase) + valid_loss_threshold = valid_loss + logger.info('Update patience to {}'.format(patience)) + + if config.save_model: + cur_time = datetime.now().strftime("%Y-%m-%d %H-%M-%S") + logger.info('!!New best model with loss = {},at {}.'.format(valid_loss, cur_time)) + th.save(model.state_dict(), os.path.join(config.saved_path, 'aux-{}-model'.format(done_epoch))) + best_epoch = done_epoch + saved_models.append(done_epoch) + if len(saved_models) > last_n_model: + remove_model = saved_models[0] + saved_models = saved_models[-last_n_model:] + os.remove(os.path.join(config.saved_path, "{}-model".format(remove_model))) + + best_valid_loss = valid_loss + + if done_epoch >= config.aux_max_epoch \ + or config.early_stop and patience <= done_epoch: + if done_epoch < config.aux_max_epoch: + logger.info('!!!!! Early stop due to run out of patience !!!!!') + print('Best validation loss = %f' % (best_valid_loss, )) + return best_epoch + + # exit eval model + model.train() + train_loss.clear() + logger.info('\n***** Epoch {}/{} *****'.format(done_epoch, config.aux_max_epoch)) + sys.stdout.flush() + +def validate(model, val_data, config, batch_cnt=None, use_py=None): + model.eval() + val_data.epoch_init(config, shuffle=False, verbose=False) + losses = LossManager() + while True: + batch = val_data.next_batch() + if batch is None: + break + if use_py is not None: + # loss = model(batch, mode=TEACH_FORCE, use_py=use_py) + loss = model(batch, mode=TEACH_FORCE) + else: + loss = model(batch, mode=TEACH_FORCE) + + losses.add_loss(loss) + losses.add_backward_loss(model.model_sel_loss(loss, batch_cnt)) + + valid_loss = losses.avg_loss() + # print('Validation finished at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info(losses.pprint(val_data.name)) + logger.info('Total valid loss = {}'.format(valid_loss)) + sys.stdout.flush() + return valid_loss + +def validate_actz(model, val_data, config, enc="utt", batch_cnt=None, use_py=None): + model.eval() + val_data.epoch_init(config, shuffle=False, verbose=False) + losses = LossManager() + while True: + batch = val_data.next_batch() + if batch is None: + break + if use_py is not None: + # loss = model(batch, mode=TEACH_FORCE, use_py=use_py) + if enc=="aux": + loss = model.forward_aez(batch, mode=TEACH_FORCE) + else: + loss = model(batch, mode=TEACH_FORCE) + else: + if enc=="aux": + loss = model.forward_aez(batch, mode=TEACH_FORCE) + else: + loss = model(batch, mode=TEACH_FORCE) + + losses.add_loss(loss) + losses.add_backward_loss(model.model_sel_loss(loss, batch_cnt)) + + valid_loss = losses.avg_loss() + # print('Validation finished at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info(losses.pprint(val_data.name)) + logger.info('Total valid loss = {}'.format(valid_loss)) + sys.stdout.flush() + return valid_loss + +def validate_mt(model, val_data, config, batch_cnt=None, use_py=None): + model.eval() + val_data.epoch_init(config, shuffle=False, verbose=False) + losses = LossManager() + while True: + batch = val_data.next_batch() + if batch is None: + break + loss = model(batch, mode=TEACH_FORCE) + losses.add_loss(loss) + losses.add_backward_loss(model.model_sel_loss(loss, batch_cnt)) + + valid_loss = losses.avg_loss() + # print('Validation finished at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info(losses.pprint(val_data.name)) + logger.info('Total valid loss = {}'.format(valid_loss)) + sys.stdout.flush() + return valid_loss + +def aux_validate(model, val_data, config, batch_cnt=None, use_py=None): + model.eval() + val_data.epoch_init(config, shuffle=False, verbose=False) + losses = LossManager() + while True: + batch = val_data.next_batch() + if batch is None: + break + loss = model.forward_aux(batch, mode=TEACH_FORCE) + + losses.add_loss(loss) + losses.add_backward_loss(model.model_sel_loss(loss, batch_cnt)) + + valid_loss = losses.avg_loss() + # print('Validation finished at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info(losses.pprint(val_data.name)) + logger.info('Total aux valid loss = {}'.format(valid_loss)) + sys.stdout.flush() + return valid_loss + +def generate(model, data, config, evaluator, num_batch, dest_f=None): + + def write(msg): + if msg is None or msg == '': + return + if dest_f is None: + print(msg) + else: + dest_f.write(msg + '\n') + + model.eval() + de_tknize = get_detokenize() + data.epoch_init(config, shuffle=num_batch is not None, verbose=False) + evaluator.initialize() + logger.info('Generation: {} batches'.format(data.num_batch + if num_batch is None + else num_batch)) + batch_cnt = 0 + print_cnt = 0 + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None or (num_batch is not None and data.ptr > num_batch): + break + outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type) + + # move from GPU to CPU + labels = labels.cpu() + pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + true_labels = labels.data.numpy() # (batch_size, output_seq_len) + + # get attention if possible + if config.dec_use_attn: + pred_attns = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_ATTN_SCORE]] + pred_attns = np.array(pred_attns, dtype=float).squeeze(2).swapaxes(0, 1) # (batch_size, max_dec_len, max_ctx_len) + else: + pred_attns = None + # get context + ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len) + ctx_len = batch.get('context_lens') # (batch_size, ) + + for b_id in range(pred_labels.shape[0]): + # TODO attn + pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) + true_str = get_sent(model.vocab, de_tknize, true_labels, b_id) + prev_ctx = '' + if ctx is not None: + ctx_str = [] + for t_id in range(ctx_len[b_id]): + temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False) + # print('temp_str = %s' % (temp_str, )) + # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], )) + ctx_str.append(temp_str) + ctx_str = '|'.join(ctx_str)[-200::] + prev_ctx = 'Source context: {}'.format(ctx_str) + + evaluator.add_example(true_str, pred_str) + + if num_batch is None or batch_cnt < 2: + print_cnt += 1 + write('prev_ctx = %s' % (prev_ctx, )) + write('True: {}'.format(true_str, )) + write('Pred: {}'.format(pred_str, )) + write('='*30) + if num_batch is not None and print_cnt > 10: + break + + write(evaluator.get_report()) + # write(evaluator.get_groundtruth_report()) + write('Generation Done') + +def get_sent(vocab, de_tknize, data, b_id, stop_eos=True, stop_pad=True): + ws = [] + for t_id in range(data.shape[1]): + w = vocab[data[b_id, t_id]] + # TODO EOT + if (stop_eos and w == EOS) or (stop_pad and w == PAD): + break + if w != PAD: + ws.append(w) + + return de_tknize(ws) + +def most_frequent(List): + occ_count = Counter(List) + return occ_count.most_common(1)[0][0] + +def generate_with_name(model, data, config): + model.eval() + de_tknize = get_detokenize() + data.epoch_init(config, shuffle=False, verbose=False) + logger.info('Generation With Name: {} batches.'.format(data.num_batch)) + + from collections import defaultdict + res = defaultdict(dict) + while True: + batch = data.next_batch() + if batch is None: + break + keys, outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type) + + pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + true_labels = labels.cpu().data.numpy() # (batch_size, output_seq_len) + + for b_id in range(pred_labels.shape[0]): + pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) + true_str = get_sent(model.vocab, de_tknize, true_labels, b_id) + dlg_name, dlg_turn = keys[b_id] + res[dlg_name][dlg_turn] = {'pred': pred_str, 'true': true_str} + + return res diff --git a/convlab/policy/lava/multiwoz/latent_dialog/metric.py b/convlab/policy/lava/multiwoz/latent_dialog/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..433c19e369fd7c8274cd1629a053137278f7079b --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/metric.py @@ -0,0 +1,151 @@ +import time +from collections import OrderedDict + + +class NumericMetric(object): + """Base class for a numeric metric.""" + def __init__(self): + self.k = 0 + self.n = 0 + + def reset(self): + pass + + def record(self, k, n=1): + self.k += k + self.n += n + + def value(self): + self.n = max(1, self.n) + return 1.0 * self.k / self.n + + +class AverageMetric(NumericMetric): + """Average.""" + def show(self): + return '%.2f' % (1. * self.value()) + + +class PercentageMetric(NumericMetric): + """Percentage.""" + def show(self): + return '%2.2f%%' % (100. * self.value()) + + +class TimeMetric(object): + """Time based metric.""" + def __init__(self): + self.t = 0 + self.n = 0 + + def reset(self): + self.last_t = time.time() + + def record(self, n=1): + self.t += time.time() - self.last_t + self.n += 1 + + def value(self): + self.n = max(1, self.n) + return 1.0 * self.t / self.n + + def show(self): + return '%.3fs' % (1. * self.value()) + + +class UniquenessMetric(object): + """Metric that evaluates the number of unique sentences.""" + def __init__(self): + self.seen = set() + + def reset(self): + pass + + def record(self, sen): + self.seen.add(' '.join(sen)) + + def value(self): + return len(self.seen) + + def show(self): + return str(self.value()) + + +class TextMetric(object): + """Text based metric.""" + def __init__(self, text): + self.text = text + self.k = 0 + self.n = 0 + + def reset(self): + pass + + def value(self): + self.n = max(1, self.n) + return 1. * self.k / self.n + + def show(self): + return '%.2f' % (1. * self.value()) + + +class NGramMetric(TextMetric): + """Metric that evaluates n grams.""" + def __init__(self, text, ngram=-1): + super(NGramMetric, self).__init__(text) + self.ngram = ngram + + def record(self, sen): + n = len(sen) if self.ngram == -1 else self.ngram + for i in range(len(sen) - n + 1): + self.n += 1 + target = ' '.join(sen[i:i + n]) + if self.text.find(target) != -1: + self.k += 1 + + +class MetricsContainer(object): + """A container that stores and updates several metrics.""" + def __init__(self): + self.metrics = OrderedDict() + + def _register(self, name, ty, *args, **kwargs): + name = name.lower() + assert name not in self.metrics + self.metrics[name] = ty(*args, **kwargs) + + def register_average(self, name, *args, **kwargs): + self._register(name, AverageMetric, *args, **kwargs) + + def register_time(self, name, *args, **kwargs): + self._register(name, TimeMetric, *args, **kwargs) + + def register_percentage(self, name, *args, **kwargs): + self._register(name, PercentageMetric, *args, **kwargs) + + def register_ngram(self, name, *args, **kwargs): + self._register(name, NGramMetric, *args, **kwargs) + + def register_uniqueness(self, name, *args, **kwargs): + self._register(name, UniquenessMetric, *args, **kwargs) + + def record(self, name, *args, **kwargs): + name = name.lower() + assert name in self.metrics + self.metrics[name].record(*args, **kwargs) + + def reset(self): + for m in self.metrics.values(): + m.reset() + + def value(self, name): + return self.metrics[name].value() + + def show(self): + return ' '.join(['%s=%s' % (k, v.show()) for k, v in self.metrics.iteritems()]) + + def dict(self): + d = OrderedDict() + for k, v in self.metrics.items(): + d[k] = v.show() + return d diff --git a/convlab/policy/lava/multiwoz/latent_dialog/models_task.py b/convlab/policy/lava/multiwoz/latent_dialog/models_task.py new file mode 100644 index 0000000000000000000000000000000000000000..84a1eb986f91301d6c463ada7c180ba0b6227927 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/models_task.py @@ -0,0 +1,3958 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from convlab.policy.lava.multiwoz.latent_dialog.base_models import BaseModel +from convlab.policy.lava.multiwoz.latent_dialog.corpora import SYS, EOS, PAD, BOS +from convlab.policy.lava.multiwoz.latent_dialog.utils import INT, FLOAT, LONG, Pack, cast_type +from convlab.policy.lava.multiwoz.latent_dialog.enc2dec.encoders import RnnUttEncoder +from convlab.policy.lava.multiwoz.latent_dialog.enc2dec.decoders import DecoderRNN, GEN, TEACH_FORCE +from convlab.policy.lava.multiwoz.latent_dialog.criterions import NLLEntropy, CatKLLoss, Entropy, NormKLLoss, WeightedNLLEntropy +from convlab.policy.lava.multiwoz.latent_dialog import nn_lib +import numpy as np +import pdb + + +class SysPerfectBD2Word(BaseModel): + def __init__(self, corpus, config): + super(SysPerfectBD2Word, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.policy = nn.Sequential(nn.Linear(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.dec_cell_size), nn.Tanh(), nn.Dropout(config.dropout)) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=self.utt_encoder.output_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # pack attention context + if self.config.dec_use_attn: + attn_context = enc_outs + else: + attn_context = None + + # create decoder initial states + dec_init_state = self.policy(th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)).unsqueeze(0) + + # decode + if self.config.dec_rnn_cell == 'lstm': + # h_dec_init_state = utt_summary.squeeze(1).unsqueeze(0) + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + return ret_dict, labels + if return_latent: + return Pack(nll=self.nll(dec_outputs, labels), + latent_action=dec_init_state) + else: + return Pack(nll=self.nll(dec_outputs, labels)) + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # pack attention context + if self.config.dec_use_attn: + attn_context = enc_outs + else: + attn_context = None + + # create decoder initial states + dec_init_state = self.policy(th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)).unsqueeze(0) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=temp) + return logprobs, outs + +class SysPerfectBD2Cat(BaseModel): + def __init__(self, corpus, config): + super(SysPerfectBD2Cat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior + self.contextual_posterior = config.contextual_posterior + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False) + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + if not self.simple_posterior: + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + if config.avg_type == "slot": + # give slot_weight:1 ratio between slot tokens and other words + self.loss_weight = th.tensor([(config.slot_weight - 1) * int(k[0] == '[' and k[-1] == ']' ) + 1 for k in self.vocab_dict.keys()]).type(th.FloatTensor) + if self.use_gpu: + self.loss_weight = self.loss_weight.cuda() + self.nll = WeightedNLLEntropy(self.pad_id, config.avg_type, self.loss_weight) + else: + self.nll = NLLEntropy(self.pad_id, config.avg_type) + + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + else: + logits_py, log_py = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=False) + else: + sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_py, log_qy = self.c2z(enc_last) + else: + logits_py, log_qy = self.c2z(enc_last) + + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + idx = th.multinomial(qy, 1).detach() + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_y + +class SysAECat(BaseModel): + def __init__(self, corpus, config): + super(SysAECat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + # self.bs_size = corpus.bs_size + # self.db_size = corpus.db_size + # self.act_size = corpus.act_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = True # minimize kl to uninformed prior instead of dist conditioned by context + self.contextual_posterior = False # does not use context cause AE task + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + # if not self.simple_posterior: #q(z|x,c) + # if self.contextual_posterior: + # # x, c, BS, and DB + # self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + # config.y_size, config.k_size, is_lstm=False) + # else: + # self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + if config.avg_type == "slot": + # give slot_weight:1 ratio between slot tokens and other words + self.loss_weight = th.tensor([(config.slot_weight - 1) * int(k[0] == '[' and k[-1] == ']' ) + 1 for k in self.vocab_dict.keys()]).type(th.FloatTensor) + if self.use_gpu: + self.loss_weight = self.loss_weight.cuda() + self.nll = WeightedNLLEntropy(self.pad_id, config.avg_type, self.loss_weight) + else: + self.nll = NLLEntropy(self.pad_id, config.avg_type) + + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + # bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + # db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = utt_summary.squeeze(1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + # else: + # logits_py, log_py = self.c2z(enc_last) + # # encode response and use posterior to find q(z|x, c) + # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + # if self.contextual_posterior: + # logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + # else: + # logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # # use prior at inference time, otherwise use posterior + # if mode == GEN or (use_py is not None and use_py is True): + # sample_y = self.gumbel_connector(logits_py, hard=False) + # else: + # sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_py, log_qy = self.c2z(enc_last) + else: + logits_py, log_qy = self.c2z(enc_last) + + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + idx = th.multinomial(qy, 1).detach() + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_y + +class SysMTCat(BaseModel): + def __init__(self, corpus, config): + super(SysMTCat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + # self.act_size = corpus.act_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context + self.contextual_posterior = config.contextual_posterior # does not use context cause AE task + + if "use_aux_kl" in config: + self.use_aux_kl = config.use_aux_kl + else: + self.use_aux_kl = False + + self.embedding = None + self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False) + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + + if not self.simple_posterior: #q(z|x,c) + if self.contextual_posterior: + # x, c, BS, and DB + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + if config.avg_type == "slot": + # give slot_weight:1 ratio between slot tokens and other words + self.loss_weight = th.tensor([(config.slot_weight - 1) * int(k[0] == '[' and k[-1] == ']' ) + 1 for k in self.vocab_dict.keys()]).type(th.FloatTensor) + if self.use_gpu: + self.loss_weight = self.loss_weight.cuda() + self.nll = WeightedNLLEntropy(self.pad_id, config.avg_type, self.loss_weight) + else: + self.nll = NLLEntropy(self.pad_id, config.avg_type) + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + if self.use_aux_kl: + try: + total_loss += loss.aux_pi_kl + except KeyError: + total_loss += 0 + + return total_loss + + def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + # how to use z, alone or in combination with bs and db + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + # else: + # logits_py, log_py = self.c2z(enc_last) + # # encode response and use posterior to find q(z|x, c) + # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + # if self.contextual_posterior: + # logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + # else: + # logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # # use prior at inference time, otherwise use posterior + # if mode == GEN or (use_py is not None and use_py is True): + # sample_y = self.gumbel_connector(logits_py, hard=False) + # else: + # sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + else: + logits_py, log_py = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=False) + else: + sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def shared_forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + short_target_utts = self.np2var(data_feed['outputs'], LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + else: + logits_py, log_py = self.c2z(enc_last) + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=False) + else: + sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + aux_pi_kl = self.cat_kl_loss(log_qy, aux_log_qy, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['aux_pi_kl'] = aux_pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + +class SysActZCat(BaseModel): + def __init__(self, corpus, config): + super(SysActZCat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + # self.act_size = corpus.act_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context + self.contextual_posterior = config.contextual_posterior # does not use context cause AE task + + if "use_aux_kl" in config: + self.use_aux_kl = config.use_aux_kl + else: + self.use_aux_kl = False + + self.embedding = None + self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False) + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + + if not self.simple_posterior: #q(z|x,c) + if self.contextual_posterior: + # x, c, BS, and DB + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + if config.avg_type == "slot": + # give slot_weight:1 ratio between slot tokens and other words + self.loss_weight = th.tensor([(config.slot_weight - 1) * int(k[0] == '[' and k[-1] == ']' ) + 1 for k in self.vocab_dict.keys()]).type(th.FloatTensor) + if self.use_gpu: + self.loss_weight = self.loss_weight.cuda() + self.nll = WeightedNLLEntropy(self.pad_id, config.avg_type, self.loss_weight) + else: + self.nll = NLLEntropy(self.pad_id, config.avg_type) + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + short_target_utts = self.np2var(data_feed['outputs'], LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = aux_log_qy + else: + logits_py, log_py = self.c2z(enc_last) + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=False) + else: + sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_py, log_qy = self.c2z(enc_last) + else: + logits_py, log_qy = self.c2z(enc_last) + + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + idx = th.multinomial(qy, 1).detach() + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_y + +class SysE2ECat(BaseModel): + def __init__(self, corpus, config): + super(SysE2ECat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior + self.contextual_posterior = config.contextual_posterior + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + if not self.simple_posterior: + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + if config.avg_type == "slot": + # give slot_weight:1 ratio between slot tokens and other words + self.loss_weight = th.tensor([(config.slot_weight - 1) * int(k[0] == '[' and k[-1] == ']' ) + 1 for k in self.vocab_dict.keys()]).type(th.FloatTensor) + if self.use_gpu: + self.loss_weight = self.loss_weight.cuda() + self.nll = WeightedNLLEntropy(self.pad_id, config.avg_type, self.loss_weight) + else: + self.nll = NLLEntropy(self.pad_id, config.avg_type) + + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + enc_last = utt_summary.squeeze(1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + else: + logits_py, log_py = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=False) + else: + sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + enc_last = utt_summary.squeeze(1) + # create decoder initial states + if self.simple_posterior: + logits_py, log_qy = self.c2z(enc_last) + else: + logits_py, log_qy = self.c2z(enc_last) + + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + idx = th.multinomial(qy, 1).detach() + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_y + +class SysE2EActZCat(BaseModel): + def __init__(self, corpus, config): + super(SysE2EActZCat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context + self.contextual_posterior = config.contextual_posterior # does not use context cause AE task + if "use_act_label" in config: + self.use_act_label = config.use_act_label + else: + self.use_act_label = False + + if "use_aux_c2z" in config: + self.use_aux_c2z = config.use_aux_c2z + else: + self.use_aux_c2z = False + + self.embedding = None + self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + + if self.use_act_label: + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.act_size, config.y_size, config.k_size, is_lstm=False) + else: + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + if self.use_aux_c2z: + self.aux_c2z = nn_lib.Hidden2Discrete(self.aux_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + + if not self.simple_posterior: #q(z|x,c) + if self.contextual_posterior: + # x, c, BS, and DB + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + if config.avg_type == "slot": + # give slot_weight:1 ratio between slot tokens and other words + self.loss_weight = th.tensor([(config.slot_weight - 1) * int(k[0] == '[' and k[-1] == ']' ) + 1 for k in self.vocab_dict.keys()]).type(th.FloatTensor) + if self.use_gpu: + self.loss_weight = self.loss_weight.cuda() + self.nll = WeightedNLLEntropy(self.pad_id, config.avg_type, self.loss_weight) + else: + self.nll = NLLEntropy(self.pad_id, config.avg_type) + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + short_target_utts = self.np2var(data_feed['outputs'], LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = utt_summary.squeeze(1) + aux_enc_last = aux_utt_summary.squeeze(1) + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + if self.use_aux_c2z: + aux_logits_qy, aux_log_qy = self.aux_c2z(aux_utt_summary.squeeze(1)) + else: + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = aux_log_qy + else: + logits_py, log_py = self.c2z(enc_last) + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=False) + else: + sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def forward_rl(self, data_feed, max_words, temp=0.1, enc="utt"): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + batch_size = len(ctx_lens) + if enc == "utt": + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + # create decoder initial states + enc_last = utt_summary.squeeze(1) + # create decoder initial states + if self.simple_posterior: + logits_py, log_qy = self.c2z(enc_last) + else: + logits_py, log_qy = self.c2z(enc_last) + + elif enc == "aux": + short_target_utts = self.np2var(data_feed['outputs'], LONG) + # short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['outputs'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + if self.simple_posterior: + if self.use_aux_c2z: + aux_logits_qy, aux_log_qy = self.aux_c2z(aux_utt_summary.squeeze(1)) + else: + aux_enc_last = aux_utt_summary.squeeze(1) + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + logits_py = aux_logits_qy + log_qy = aux_log_qy + + + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + idx = th.multinomial(qy, 1).detach() + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_y + +class SysPerfectBD2Gauss(BaseModel): + def __init__(self, corpus, config): + super(SysPerfectBD2Gauss, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior + if "contextual posterior" in config: + self.contextual_posterior = config.contextual_posterior + else: + self.contextual_posterior = True # default value is true, i.e. q(z|x,c) + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) + self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) + if not self.simple_posterior: + # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + # config.y_size, is_lstm=False) + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) + + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + if config.avg_type == "slot": + # give slot_weight:1 ratio between slot tokens and other words + self.loss_weight = th.tensor([(config.slot_weight - 1) * int(k[0] == '[' and k[-1] == ']' ) + 1 for k in self.vocab_dict.keys()]).type(th.FloatTensor) + if self.use_gpu: + self.loss_weight = self.loss_weight.cuda() + self.nll = WeightedNLLEntropy(self.pad_id, config.avg_type, self.loss_weight) + else: + self.nll = NLLEntropy(self.pad_id, config.avg_type) + + self.gauss_kl = NormKLLoss(unit_average=True) + self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.config.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + # create decoder initial states + if self.simple_posterior: + q_mu, q_logvar = self.c2z(enc_last) + sample_z = self.gauss_connector(q_mu, q_logvar) + p_mu, p_logvar = self.zero, self.zero + else: + p_mu, p_logvar = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or use_py: + sample_z = self.gauss_connector(p_mu, p_logvar) + else: + sample_z = self.gauss_connector(q_mu, q_logvar) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_z + ret_dict['q_mu'] = q_mu + ret_dict['q_logvar'] = q_logvar + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) + result['pi_kl'] = pi_kl + result['nll'] = self.nll(dec_outputs, labels) + return result + + def gaussian_logprob(self, mu, logvar, sample_z): + var = th.exp(logvar) + constant = float(-0.5 * np.log(2*np.pi)) + logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var) + return logprob + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + p_mu, p_logvar = self.c2z(enc_last) + + sample_z = th.normal(p_mu, th.sqrt(th.exp(p_logvar))).detach() + logprob_sample_z = self.gaussian_logprob(p_mu, self.zero, sample_z) + joint_logpz = th.sum(logprob_sample_z, dim=1) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_z + +class SysAEGauss(BaseModel): + def __init__(self, corpus, config): + super(SysAEGauss, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + # self.bs_size = corpus.bs_size + # self.db_size = corpus.db_size + # self.act_size = corpus.act_size + self.y_size = config.y_size + self.simple_posterior = True + self.contextual_posterior = False + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + + self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, + config.y_size, is_lstm=False) + self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) + + self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) + if not self.simple_posterior: + # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + # config.y_size, is_lstm=False) + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) + + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + if config.avg_type == "slot": + # give slot_weight:1 ratio between slot tokens and other words + self.loss_weight = th.tensor([(config.slot_weight - 1) * int(k[0] == '[' and k[-1] == ']' ) + 1 for k in self.vocab_dict.keys()]).type(th.FloatTensor) + if self.use_gpu: + self.loss_weight = self.loss_weight.cuda() + self.nll = WeightedNLLEntropy(self.pad_id, config.avg_type, self.loss_weight) + else: + self.nll = NLLEntropy(self.pad_id, config.avg_type) + + self.gauss_kl = NormKLLoss(unit_average=True) + self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.config.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + # bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + # db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + enc_last = utt_summary.squeeze(1) + + + # create decoder initial states + if self.simple_posterior: + q_mu, q_logvar = self.c2z(enc_last) + sample_z = self.gauss_connector(q_mu, q_logvar) + p_mu, p_logvar = self.zero, self.zero + # else: + # p_mu, p_logvar = self.c2z(enc_last) + # # encode response and use posterior to find q(z|x, c) + # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + # if self.contextual_posterior: + # q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + # else: + # q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) + + # # use prior at inference time, otherwise use posterior + # if mode == GEN or use_py: + # sample_z = self.gauss_connector(p_mu, p_logvar) + # else: + # sample_z = self.gauss_connector(q_mu, q_logvar) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_z + ret_dict['q_mu'] = q_mu + ret_dict['q_logvar'] = q_logvar + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) + result['pi_kl'] = pi_kl + result['nll'] = self.nll(dec_outputs, labels) + return result + + def gaussian_logprob(self, mu, logvar, sample_z): + var = th.exp(logvar) + constant = float(-0.5 * np.log(2*np.pi)) + logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var) + return logprob + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + p_mu, p_logvar = self.c2z(enc_last) + + sample_z = th.normal(p_mu, th.sqrt(th.exp(p_logvar))).detach() + logprob_sample_z = self.gaussian_logprob(p_mu, self.zero, sample_z) + joint_logpz = th.sum(logprob_sample_z, dim=1) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_z + +class SysMTGauss(BaseModel): + def __init__(self, corpus, config): + super(SysMTGauss, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior + self.contextual_posterior = config.contextual_posterior + + if "use_aux_kl" in config: + self.use_aux_kl = config.use_aux_kl + else: + self.use_aux_kl = False + + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) + self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) + if not self.simple_posterior: + # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + # config.y_size, is_lstm=False) + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) + + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + if config.avg_type == "slot": + # give slot_weight:1 ratio between slot tokens and other words + self.loss_weight = th.tensor([(config.slot_weight - 1) * int(k[0] == '[' and k[-1] == ']' ) + 1 for k in self.vocab_dict.keys()]).type(th.FloatTensor) + if self.use_gpu: + self.loss_weight = self.loss_weight.cuda() + self.nll = WeightedNLLEntropy(self.pad_id, config.avg_type, self.loss_weight) + else: + self.nll = NLLEntropy(self.pad_id, config.avg_type) + + self.gauss_kl = NormKLLoss(unit_average=True) + self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.config.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.use_aux_kl: + try: + total_loss += loss.aux_pi_kl + except KeyError: + total_loss += 0 + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + + # create decoder initial states + if self.simple_posterior: + q_mu, q_logvar = self.c2z(enc_last) + sample_z = self.gauss_connector(q_mu, q_logvar) + p_mu, p_logvar = self.zero, self.zero + else: + p_mu, p_logvar = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or use_py: + sample_z = self.gauss_connector(p_mu, p_logvar) + else: + sample_z = self.gauss_connector(q_mu, q_logvar) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_z + ret_dict['q_mu'] = q_mu + ret_dict['q_logvar'] = q_logvar + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) + result['pi_kl'] = pi_kl + result['nll'] = self.nll(dec_outputs, labels) + return result + + def shared_forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + short_target_utts = self.np2var(data_feed['outputs'], LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) + + # create decoder initial states + if self.simple_posterior: + q_mu, q_logvar = self.c2z(enc_last) + aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last) + sample_z = self.gauss_connector(q_mu, q_logvar) + p_mu, p_logvar = self.zero, self.zero + else: + p_mu, p_logvar = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) + + aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last) + + # use prior at inference time, otherwise use posterior + if mode == GEN or use_py: + sample_z = self.gauss_connector(p_mu, p_logvar) + else: + sample_z = self.gauss_connector(q_mu, q_logvar) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_z + ret_dict['q_mu'] = q_mu + ret_dict['q_logvar'] = q_logvar + return ret_dict, labels + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) + aux_pi_kl = self.gauss_kl(q_mu, q_logvar, aux_q_mu, aux_q_logvar) + result['pi_kl'] = pi_kl + result['aux_pi_kl'] = aux_pi_kl + result['nll'] = self.nll(dec_outputs, labels) + return result + + def gaussian_logprob(self, mu, logvar, sample_z): + var = th.exp(logvar) + constant = float(-0.5 * np.log(2*np.pi)) + logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var) + return logprob + + return logprobs, outs, joint_logpz, sample_z + + def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + # create decoder initial states + if self.simple_posterior: + q_mu, q_logvar = self.c2z(enc_last) + sample_z = self.gauss_connector(q_mu, q_logvar) + p_mu, p_logvar = self.zero, self.zero + # else: + # p_mu, p_logvar = self.c2z(enc_last) + # # encode response and use posterior to find q(z|x, c) + # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + # if self.contextual_posterior: + # q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + # else: + # q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) + + # # use prior at inference time, otherwise use posterior + # if mode == GEN or use_py: + # sample_z = self.gauss_connector(p_mu, p_logvar) + # else: + # sample_z = self.gauss_connector(q_mu, q_logvar) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_z + ret_dict['q_mu'] = q_mu + ret_dict['q_logvar'] = q_logvar + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) + result['pi_kl'] = pi_kl + result['nll'] = self.nll(dec_outputs, labels) + return result + +class SysActZGauss(BaseModel): + def __init__(self, corpus, config): + super(SysActZGauss, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior + self.contextual_posterior = config.contextual_posterior + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) + self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) + if not self.simple_posterior: + # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + # config.y_size, is_lstm=False) + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) + + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + if config.avg_type == "slot": + # give slot_weight:1 ratio between slot tokens and other words + self.loss_weight = th.tensor([(config.slot_weight - 1) * int(k[0] == '[' and k[-1] == ']' ) + 1 for k in self.vocab_dict.keys()]).type(th.FloatTensor) + if self.use_gpu: + self.loss_weight = self.loss_weight.cuda() + self.nll = WeightedNLLEntropy(self.pad_id, config.avg_type, self.loss_weight) + else: + self.nll = NLLEntropy(self.pad_id, config.avg_type) + + self.gauss_kl = NormKLLoss(unit_average=True) + self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.config.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + short_target_utts = self.np2var(data_feed['outputs'], LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) + + # create decoder initial states + if self.simple_posterior: + q_mu, q_logvar = self.c2z(enc_last) + p_mu, p_logvar = self.c2z(aux_enc_last) + sample_z = self.gauss_connector(q_mu, q_logvar) + else: + p_mu, p_logvar = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) + + aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last) + + # use prior at inference time, otherwise use posterior + if mode == GEN or use_py: + sample_z = self.gauss_connector(p_mu, p_logvar) + else: + sample_z = self.gauss_connector(q_mu, q_logvar) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_z + ret_dict['q_mu'] = q_mu + ret_dict['q_logvar'] = q_logvar + return ret_dict, labels + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) + result['pi_kl'] = pi_kl + result['nll'] = self.nll(dec_outputs, labels) + return result + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_py, log_qy = self.c2z(enc_last) + else: + logits_py, log_qy = self.c2z(enc_last) + + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + idx = th.multinomial(qy, 1).detach() + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_y + + +# Grounded Models + +class SysEncodedBD2Cat(BaseModel): + def __init__(self, corpus, config): + super(SysEncodedBD2Cat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior + self.contextual_posterior = config.contextual_posterior + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + if config.use_metadata_for_decoding: + self.metadata_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=int(config.embed_size / 2), + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=int(config.utt_cell_size / 2), + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + if "policy_dropout" in config and config.policy_dropout: + if "policy_dropout_rate" in config: + self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) + else: + self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + + else: + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + if not self.simple_posterior: + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + if config.use_metadata_for_decoding: + dec_hidden_size = config.dec_cell_size + self.metadata_encoder.output_size + else: + dec_hidden_size = config.dec_cell_size + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=dec_hidden_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": + req_tokens = [] + for d in REQ_TOKENS.keys(): + req_tokens.extend(REQ_TOKENS[d]) + nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. for token in self.vocab])) + print("req tokens assigned with special weights") + if config.use_gpu: + nll_weight = nll_weight.cuda() + self.nll.set_weight(nll_weight) + + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def extract_short_ctx(self, data_feed): + utts = [] + ctx_lens = data_feed['context_lens'] # (batch_size, ) + context = data_feed['contexts'] + bs = data_feed['bs'] + db = data_feed['db'] + if not isinstance(bs, list): + bs = data_feed['bs'].tolist() + db = data_feed['db'].tolist() + + for b_id in range(len(context)): + utt = [] + for t_id in range(ctx_lens[b_id]): + utt.extend(context[b_id][t_id]) + try: + utt.extend(bs[b_id] + db[b_id]) + except: + pdb.set_trace() + utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) + return np.array(utts) + + def extract_metadata(self, data_feed): + utts = [] + bs = data_feed['bs'] + db = data_feed['db'] + if not isinstance(bs, list): + bs = data_feed['bs'].tolist() + db = data_feed['db'].tolist() + + for b_id in range(len(bs)): + utt = [] + utt.extend(bs[b_id] + db[b_id]) + utts.append(self.pad_to(self.config.max_metadata_len, utt, do_pad=True)) + return np.array(utts) + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + # print("cutting off, ", tokens) + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + enc_last = utt_summary.unsqueeze(1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + else: + logits_py, log_py = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=False) + else: + sample_y = self.gumbel_connector(logits_qy, hard=True) + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + if self.config.use_metadata_for_decoding: + metadata = self.np2var(self.extract_metadata(data_feed), LONG) + metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) + dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) # averaged over all samples + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + + result['diversity'] = th.mean(p) + result['b_pr'] = b_pr + result['mi'] = mi + result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) + return result + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + # pdb.set_trace() + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + enc_last = utt_summary.unsqueeze(1) + # create decoder initial states + logits_py, log_qy = self.c2z(enc_last) + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + idx = th.multinomial(qy, 1).detach() + + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + if self.config.use_metadata_for_decoding: + metadata = self.np2var(self.extract_metadata(data_feed), LONG) + metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) + dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_y + + def sample_z(self, data_feed, n_z=1, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db + # metadata = self.np2var(self.extract_metadata(data_feed), LONG) + # out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + # metadata_summary, _, metadata_enc_outs = self.utt_encoder(metadata.unsqueeze(1)) + + + # create decoder initial states + enc_last = utt_summary.unsqueeze(1) + if self.simple_posterior: + logits_py, log_qy = self.c2z(enc_last) + else: + logits_py, log_qy = self.c2z(enc_last) + + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + + zs = [] + logpzs = [] + for i in range(n_z): + idx = th.multinomial(qy, 1).detach() + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + + zs.append(sample_y) + logpzs.append(joint_logpz) + + + return th.stack(zs), th.stack(logpzs) + + def decode_z(self, sample_y, batch_size, max_words=None, temp=0.1, gen_type='greedy'): + """ + generate response from latent var + """ + # pack attention context + metadata = self.np2var(self.extract_metadata(data_feed), LONG) + metadata_summary, _, metadata_enc_outs = self.utt_encoder(metadata.unsqueeze(1)) + + if isinstance(sample_y, np.ndarray): + sample_y = self.np2var(sample_y, FLOAT) + + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # has to be forward_rl because we don't have the golden target + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=temp) + return logprobs, outs + +class SysGroundedActZCat(BaseModel): + def __init__(self, corpus, config): + super(SysGroundedActZCat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + # self.act_size = corpus.act_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context + self.contextual_posterior = config.contextual_posterior # does not use context cause AE task + + if "use_aux_kl" in config: + self.use_aux_kl = config.use_aux_kl + else: + self.use_aux_kl = False + + self.embedding = None + self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + if config.use_metadata_for_decoding: + self.metadata_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=int(config.embed_size / 2), + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=int(config.utt_cell_size / 2), + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + + + if "policy_dropout" in config and config.policy_dropout: + self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) + else: + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + + if not self.simple_posterior: #q(z|x,c) + if self.contextual_posterior: + # x, c, BS, and DB + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + if config.use_metadata_for_decoding: + dec_hidden_size = config.dec_cell_size + self.metadata_encoder.output_size + else: + dec_hidden_size = config.dec_cell_size + + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=dec_hidden_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": + req_tokens = [] + for d in REQ_TOKENS.keys(): + req_tokens.extend(REQ_TOKENS[d]) + nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. for token in self.vocab])) + print("req tokens assigned with special weights") + if config.use_gpu: + nll_weight = nll_weight.cuda() + self.nll.set_weight(nll_weight) + + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def extract_short_ctx(self, data_feed): + utts = [] + ctx_lens = data_feed['context_lens'] # (batch_size, ) + context = data_feed['contexts'] + bs = data_feed['bs'] + db = data_feed['db'] + if not isinstance(bs, list): + bs = data_feed['bs'].tolist() + db = data_feed['db'].tolist() + + for b_id in range(len(context)): + utt = [] + for t_id in range(ctx_lens[b_id]): + utt.extend(context[b_id][t_id]) + try: + utt.extend(bs[b_id] + db[b_id]) + except: + pdb.set_trace() + utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) + return np.array(utts) + + def extract_metadata(self, data_feed): + utts = [] + bs = data_feed['bs'] + db = data_feed['db'] + if not isinstance(bs, list): + bs = data_feed['bs'].tolist() + db = data_feed['db'].tolist() + + for b_id in range(len(bs)): + utt = [] + if "metadata_db_only" in config and self.config.metadata_db_only: + utt.extend(db[b_id]) + else: + utt.extend(bs[b_id] + db[b_id]) + utts.append(self.pad_to(self.config.max_metadata_len, utt, do_pad=True)) + return np.array(utts) + + def extract_AE_ctx(self, data_feed): + utts = [] + ctx_lens = data_feed['context_lens'] # (batch_size, ) + context = data_feed['outputs'] + bs = data_feed['bs'] + db = data_feed['db'] + if not isinstance(bs, list): + bs = data_feed['bs'].tolist() + db = data_feed['db'].tolist() + + for b_id in range(len(context)): + utt = [] + utt.extend(context[b_id]) + try: + utt.extend(bs[b_id] + db[b_id]) + except: + pdb.set_trace() + utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) + return np.array(utts) + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db + metadata = self.np2var(self.extract_metadata(data_feed), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + if self.config.use_metadata_for_aux_encoder: + ctx_outs = self.np2var(self.extract_AE_ctx(data_feed), LONG) # contains bs and db + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(ctx_outs.unsqueeze(1)) + else: + short_target_utts = self.np2var(data_feed['outputs'], LONG) + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = utt_summary.unsqueeze(1) + aux_enc_last = aux_utt_summary.unsqueeze(1) + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = aux_log_qy + else: + logits_py, log_py = self.c2z(enc_last) + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=False) + else: + sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + if self.config.use_metadata_for_decoding: + metadata = self.np2var(self.extract_metadata(data_feed), LONG) + metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) + dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) + return result + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + # pdb.set_trace() + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + enc_last = utt_summary.unsqueeze(1) + # create decoder initial states + logits_py, log_qy = self.c2z(enc_last) + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + idx = th.multinomial(qy, 1).detach() + + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + if self.config.use_metadata_for_decoding: + metadata = self.np2var(self.extract_metadata(data_feed), LONG) + metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) + dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_y + +class SysGroundedMTCat(BaseModel): + def __init__(self, corpus, config): + super(SysGroundedMTCat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + # self.act_size = corpus.act_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context + self.contextual_posterior = config.contextual_posterior # does not use context cause AE task + + if "use_aux_kl" in config: + self.use_aux_kl = config.use_aux_kl + else: + self.use_aux_kl = False + + self.embedding = None + self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + if config.use_metadata_for_decoding: + self.metadata_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=int(config.embed_size / 2), + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=int(config.utt_cell_size / 2), + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + + + if "policy_dropout" in config and config.policy_dropout: + self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) + else: + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + + + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + + if not self.simple_posterior: #q(z|x,c) + if self.contextual_posterior: + # x, c, BS, and DB + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + if self.use_aux_kl: + try: + total_loss += loss.aux_pi_kl + except KeyError: + total_loss += 0 + + return total_loss + + def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + + ctx_lens = data_feed['context_lens'] # (batch_size, ) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + batch_size = len(ctx_lens) + + if self.config.use_metadata_for_aux_encoder: + ctx_outs = self.np2var(self.extract_AE_ctx(data_feed), LONG) # contains bs and db + utt_summary, _, _ = self.aux_encoder(ctx_outs.unsqueeze(1)) + else: + short_target_utts = self.np2var(data_feed['outputs'], LONG) + utt_summary, _, _ = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = utt_summary.unsqueeze(1) + + # how to use z, alone or in combination with bs and db + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + # else: + # logits_py, log_py = self.c2z(enc_last) + # # encode response and use posterior to find q(z|x, c) + # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + # if self.contextual_posterior: + # logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + # else: + # logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # # use prior at inference time, otherwise use posterior + # if mode == GEN or (use_py is not None and use_py is True): + # sample_y = self.gumbel_connector(logits_py, hard=False) + # else: + # sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = utt_summary.unsqueeze(1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + else: + logits_py, log_py = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=False) + else: + sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def shared_forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + if self.config.use_metadata_for_aux_encoder: + ctx_outs = self.np2var(self.extract_AE_ctx(data_feed), LONG) # contains bs and db + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(ctx_outs.unsqueeze(1)) + else: + short_target_utts = self.np2var(data_feed['outputs'], LONG) + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = utt_summary.unsqueeze(1) + aux_enc_last = aux_utt_summary.unsqueeze(1) + + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + else: + logits_py, log_py = self.c2z(enc_last) + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=False) + else: + sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + aux_pi_kl = self.cat_kl_loss(log_qy, aux_log_qy, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['aux_pi_kl'] = aux_pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def extract_metadata(self, data_feed): + utts = [] + bs = data_feed['bs'] + db = data_feed['db'] + if not isinstance(bs, list): + bs = data_feed['bs'].tolist() + db = data_feed['db'].tolist() + + for b_id in range(len(bs)): + utt = [] + if "metadata_db_only" in self.config and self.config.metadata_db_only: + utt.extend(db[b_id]) + else: + utt.extend(bs[b_id] + db[b_id]) + utts.append(self.pad_to(self.config.max_metadata_len, utt, do_pad=True)) + return np.array(utts) + + def extract_AE_ctx(self, data_feed): + utts = [] + ctx_lens = data_feed['context_lens'] # (batch_size, ) + context = data_feed['outputs'] + bs = data_feed['bs'] + db = data_feed['db'] + if not isinstance(bs, list): + bs = data_feed['bs'].tolist() + db = data_feed['db'].tolist() + + for b_id in range(len(context)): + utt = [] + utt.extend(context[b_id]) + try: + utt.extend(bs[b_id] + db[b_id]) + except: + pdb.set_trace() + utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) + return np.array(utts) + + def extract_short_ctx(self, data_feed): + utts = [] + ctx_lens = data_feed['context_lens'] # (batch_size, ) + context = data_feed['contexts'] + bs = data_feed['bs'] + db = data_feed['db'] + if not isinstance(bs, list): + bs = data_feed['bs'].tolist() + db = data_feed['db'].tolist() + + for b_id in range(len(context)): + utt = [] + for t_id in range(ctx_lens[b_id]): + utt.extend(context[b_id][t_id]) + try: + utt.extend(bs[b_id] + db[b_id]) + except: + pdb.set_trace() + utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) + return np.array(utts) + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + # pdb.set_trace() + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + enc_last = utt_summary.unsqueeze(1) + # create decoder initial states + logits_py, log_qy = self.c2z(enc_last) + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + idx = th.multinomial(qy, 1).detach() + + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + if self.config.use_metadata_for_decoding: + metadata = self.np2var(self.extract_metadata(data_feed), LONG) + metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) + dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_y diff --git a/convlab/policy/lava/multiwoz/latent_dialog/nn_lib.py b/convlab/policy/lava/multiwoz/latent_dialog/nn_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..e843684c9b6ac3303425074378c5a368076b9b82 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/nn_lib.py @@ -0,0 +1,182 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.autograd import Variable +from convlab.policy.lava.multiwoz.latent_dialog.utils import cast_type, FLOAT + + +class IdentityConnector(nn.Module): + def __init(self): + super(IdentityConnector, self).__init__() + + def forward(self, hidden_state): + return hidden_state + + +class Bi2UniConnector(nn.Module): + def __init__(self, rnn_cell, num_layer, hidden_size, output_size): + super(Bi2UniConnector, self).__init__() + if rnn_cell == 'lstm': + self.fch = nn.Linear(hidden_size*2*num_layer, output_size) + self.fcc = nn.Linear(hidden_size*2*num_layer, output_size) + else: + self.fc = nn.Linear(hidden_size*2*num_layer, output_size) + + self.rnn_cell = rnn_cell + self.hidden_size = hidden_size + self.output_size = output_size + + def forward(self, hidden_state): + """ + :param hidden_state: [num_layer, batch_size, feat_size] + :param inputs: [batch_size, feat_size] + :return: + """ + if self.rnn_cell == 'lstm': + h, c = hidden_state + num_layer = h.size()[0] + flat_h = h.transpose(0, 1).contiguous() + flat_c = c.transpose(0, 1).contiguous() + new_h = self.fch(flat_h.view(-1, self.hidden_size*num_layer)) + new_c = self.fch(flat_c.view(-1, self.hidden_size*num_layer)) + return (new_h.view(1, -1, self.output_size), + new_c.view(1, -1, self.output_size)) + else: + # FIXME fatal error here! + num_layer = hidden_state.size()[0] + new_s = self.fc(hidden_state.view(-1, self.hidden_size*num_layer)) + new_s = new_s.view(1, -1, self.output_size) + return new_s + + +class Hidden2Gaussian(nn.Module): + def __init__(self, input_size, output_size, is_lstm=False, has_bias=True): + super(Hidden2Gaussian, self).__init__() + if is_lstm: + self.mu_h = nn.Linear(input_size, output_size, bias=has_bias) + self.logvar_h = nn.Linear(input_size, output_size, bias=has_bias) + + self.mu_c = nn.Linear(input_size, output_size, bias=has_bias) + self.logvar_c = nn.Linear(input_size, output_size, bias=has_bias) + else: + self.mu = nn.Linear(input_size, output_size, bias=has_bias) + self.logvar = nn.Linear(input_size, output_size, bias=has_bias) + + self.is_lstm = is_lstm + + def forward(self, inputs): + """ + :param inputs: batch_size x input_size + :return: + """ + if self.is_lstm: + h, c= inputs + if h.dim() == 3: + h = h.squeeze(0) + c = c.squeeze(0) + + mu_h, mu_c = self.mu_h(h), self.mu_c(c) + logvar_h, logvar_c = self.logvar_h(h), self.logvar_c(c) + return mu_h+mu_c, logvar_h+logvar_c + else: + # if inputs.dim() == 3: + # inputs = inputs.squeeze(0) + mu = self.mu(inputs) + logvar = self.logvar(inputs) + return mu, logvar + + +class Hidden2Discrete(nn.Module): + def __init__(self, input_size, y_size, k_size, is_lstm=False, has_bias=True): + super(Hidden2Discrete, self).__init__() + self.y_size = y_size + self.k_size = k_size + latent_size = self.k_size*self.y_size + if is_lstm: + self.p_h = nn.Linear(input_size, latent_size, bias=has_bias) + + self.p_c = nn.Linear(input_size, latent_size, bias=has_bias) + else: + self.p_h = nn.Linear(input_size, latent_size, bias=has_bias) + + self.is_lstm = is_lstm + + def forward(self, inputs): + """ + :param inputs: batch_size x input_size + :return: + """ + if self.is_lstm: + h, c= inputs + if h.dim() == 3: + h = h.squeeze(0) + c = c.squeeze(0) + logits = self.p_h(h) + self.p_c(c) + else: + logits = self.p_h(inputs) + logits = logits.view(-1, self.k_size) + log_qy = F.log_softmax(logits, dim=1) + return logits, log_qy + + +class GaussianConnector(nn.Module): + def __init__(self, use_gpu): + super(GaussianConnector, self).__init__() + self.use_gpu = use_gpu + + def forward(self, mu, logvar): + """ + Sample a sample from a multivariate Gaussian distribution with a diagonal covariance matrix using the + reparametrization trick. + TODO: this should be better be a instance method in a Gaussian class. + :param mu: a tensor of size [batch_size, variable_dim]. Batch_size can be None to support dynamic batching + :param logvar: a tensor of size [batch_size, variable_dim]. Batch_size can be None. + :return: + """ + epsilon = th.randn(logvar.size()) + epsilon = cast_type(Variable(epsilon), FLOAT, self.use_gpu) + std = th.exp(0.5 * logvar) + z = mu + std * epsilon + return z + + +class GumbelConnector(nn.Module): + def __init__(self, use_gpu): + super(GumbelConnector, self).__init__() + self.use_gpu = use_gpu + + def sample_gumbel(self, logits, use_gpu, eps=1e-20): + u = th.rand(logits.size()) + sample = Variable(-th.log(-th.log(u + eps) + eps)) + sample = cast_type(sample, FLOAT, use_gpu) + return sample + + def gumbel_softmax_sample(self, logits, temperature, use_gpu): + """ Draw a sample from the Gumbel-Softmax distribution""" + eps = self.sample_gumbel(logits, use_gpu) + y = logits + eps + return F.softmax(y / temperature, dim=y.dim()-1) + + def forward(self, logits, temperature=1.0, hard=False, + return_max_id=False): + """ + :param logits: [batch_size, n_class] unnormalized log-prob + :param temperature: non-negative scalar + :param hard: if True take argmax + :param return_max_id + :return: [batch_size, n_class] sample from gumbel softmax + """ + y = self.gumbel_softmax_sample(logits, temperature, self.use_gpu) + _, y_hard = th.max(y, dim=1, keepdim=True) + if hard: + y_onehot = cast_type(Variable(th.zeros(y.size())), FLOAT, self.use_gpu) + y_onehot.scatter_(1, y_hard, 1.0) + y = y_onehot + if return_max_id: + return y, y_hard + else: + return y + + + diff --git a/convlab/policy/lava/multiwoz/latent_dialog/normalizer/__init__.py b/convlab/policy/lava/multiwoz/latent_dialog/normalizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/convlab/policy/lava/multiwoz/latent_dialog/normalizer/delexicalize.py b/convlab/policy/lava/multiwoz/latent_dialog/normalizer/delexicalize.py new file mode 100644 index 0000000000000000000000000000000000000000..df4636b24d0083bffb9b6d1303dc05d447720970 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/normalizer/delexicalize.py @@ -0,0 +1,283 @@ +import re +import os +# import simplejson as json +import json + + +digitpat = re.compile('\d+') +timepat = re.compile("\d{1,2}[:]\d{1,2}") +pricepat = re.compile("\d{1,3}[.]\d{1,2}") + +CUR_PATH = os.path.join(os.path.dirname(__file__)) +fin = open(os.path.join(CUR_PATH, 'mapping.pair'), 'r') +replacements = [] +for line in fin.readlines(): + tok_from, tok_to = line.replace('\n', '').split('\t') + replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' ')) + +# FORMAT +# domain_value +# restaurant_postcode +# restaurant_address +# taxi_car8 +# taxi_number +# train_id etc.. + +def insertSpace(token, text): + sidx = 0 + while True: + sidx = text.find(token, sidx) + if sidx == -1: + break + if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \ + re.match('[0-9]', text[sidx + 1]): + sidx += 1 + continue + if text[sidx - 1] != ' ': + text = text[:sidx] + ' ' + text[sidx:] + sidx += 1 + if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ': + text = text[:sidx + 1] + ' ' + text[sidx + 1:] + sidx += 1 + return text + + +def normalize(text): + # lower case every word + text = text.lower() + + # replace white spaces in front and end + text = re.sub(r'^\s*|\s*$', '', text) + + # hotel domain pfb30 + text = re.sub(r"b&b", "bed and breakfast", text) + text = re.sub(r"b and b", "bed and breakfast", text) + + # normalize phone number + ms = re.findall('\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m[0], sidx) + if text[sidx - 1] == '(': + sidx -= 1 + eidx = text.find(m[-1], sidx) + len(m[-1]) + text = text.replace(text[sidx:eidx], ''.join(m)) + + # normalize postcode + ms = re.findall('([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})', + text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m, sidx) + eidx = sidx + len(m) + text = text[:sidx] + re.sub('[,\. ]', '', m) + text[eidx:] + + # weird unicode bug + text = re.sub(u"(\u2018|\u2019)", "'", text) + + # replace time and and price + text = re.sub(timepat, ' [value_time] ', text) + text = re.sub(pricepat, ' [value_price] ', text) + + # replace st. + text = text.replace(';', ',') + text = re.sub('$\/', '', text) + text = text.replace('/', ' and ') + + # replace other special characters + text = text.replace('-', ' ') + text = re.sub('[\":\<>@\(\)]', '', text) + + # insert white space before and after tokens: + for token in ['?', '.', ',', '!']: + text = insertSpace(token, text) + + # insert white space for 's + text = insertSpace('\'s', text) + + # replace it's, does't, you'd ... etc + text = re.sub('^\'', '', text) + text = re.sub('\'$', '', text) + text = re.sub('\'\s', ' ', text) + text = re.sub('\s\'', ' ', text) + for fromx, tox in replacements: + text = ' ' + text + ' ' + text = text.replace(fromx, tox)[1:-1] + + # remove multiple spaces + text = re.sub(' +', ' ', text) + + # concatenate numbers + tmp = text + tokens = text.split() + i = 1 + while i < len(tokens): + if re.match(u'^\d+$', tokens[i]) and \ + re.match(u'\d+$', tokens[i - 1]): + tokens[i - 1] += tokens[i] + del tokens[i] + else: + i += 1 + text = ' '.join(tokens) + + return text + + +def prepareSlotValuesIndependent(): + domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police'] + requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + dic = [] + dic_area = [] + dic_food = [] + dic_price = [] + + # read databases + for domain in domains: + try: + fin = file(os.path.join(CUR_PATH.replace('latent_dialog/normalizer', ''), 'data/norm-multi-woz/' + domain + '_db.json')) + db_json = json.load(fin) + fin.close() + + for ent in db_json: + for key, val in ent.items(): + if val == '?' or val == 'free': + pass + elif key == 'address': + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + if "road" in val: + val = val.replace("road", "rd") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif "rd" in val: + val = val.replace("rd", "road") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif "st" in val: + val = val.replace("st", "street") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif "street" in val: + val = val.replace("street", "st") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif key == 'name': + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + if "b & b" in val: + val = val.replace("b & b", "bed and breakfast") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif "bed and breakfast" in val: + val = val.replace("bed and breakfast", "b & b") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif "hotel" in val and 'gonville' not in val: + val = val.replace("hotel", "") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif "restaurant" in val: + val = val.replace("restaurant", "") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif key == 'postcode': + dic.append((normalize(val), '[' + domain + '_' + 'postcode' + ']')) + elif key == 'phone': + dic.append((val, '[' + domain + '_' + 'phone' + ']')) + elif key == 'trainID': + dic.append((normalize(val), '[' + domain + '_' + 'id' + ']')) + elif key == 'department': + dic.append((normalize(val), '[' + domain + '_' + 'department' + ']')) + + # NORMAL DELEX + elif key == 'area': + dic_area.append((normalize(val), '[' + 'value' + '_' + 'area' + ']')) + elif key == 'food': + dic_food.append((normalize(val), '[' + 'value' + '_' + 'food' + ']')) + elif key == 'pricerange': + dic_price.append((normalize(val), '[' + 'value' + '_' + 'pricerange' + ']')) + else: + pass + # TODO car type? + except: + pass + + if domain == 'hospital': + dic.append((normalize('Hills Rd'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('Hills Road'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('CB20QQ'), '[' + domain + '_' + 'postcode' + ']')) + dic.append(('01223245151', '[' + domain + '_' + 'phone' + ']')) + dic.append(('1223245151', '[' + domain + '_' + 'phone' + ']')) + dic.append(('0122324515', '[' + domain + '_' + 'phone' + ']')) + dic.append((normalize('Addenbrookes Hospital'), '[' + domain + '_' + 'name' + ']')) + + elif domain == 'police': + dic.append((normalize('Parkside'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('CB11JG'), '[' + domain + '_' + 'postcode' + ']')) + dic.append(('01223358966', '[' + domain + '_' + 'phone' + ']')) + dic.append(('1223358966', '[' + domain + '_' + 'phone' + ']')) + dic.append((normalize('Parkside Police Station'), '[' + domain + '_' + 'name' + ']')) + + # add at the end places from trains + fin = open(os.path.join(CUR_PATH.replace('latent_dialog/normalizer', ''), 'data/norm-multi-woz/' + 'train' + '_db.json')) + db_json = json.load(fin) + fin.close() + + for ent in db_json: + for key, val in ent.items(): + if key == 'departure' or key == 'destination': + dic.append((normalize(val), '[' + 'value' + '_' + 'place' + ']')) + + # add specific values: + for key in ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']: + dic.append((normalize(key), '[' + 'value' + '_' + 'day' + ']')) + + # more general values add at the end + dic.extend(dic_area) + dic.extend(dic_food) + dic.extend(dic_price) + + return dic + + +def delexicalise(utt, dictionary): + for key, val in dictionary: + utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + utt = utt[1:-1] # why this? + + return utt + + +def delexicaliseReferenceNumber(sent, metadata): + """Based on the belief state, we can find reference number that + during data gathering was created randomly.""" + domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital'] # , 'police'] + if metadata: + for domain in domains: + if metadata[domain]['book']['booked']: + for slot in metadata[domain]['book']['booked'][0]: + if slot == 'reference': + val = '[' + domain + '_' + slot + ']' + else: + val = '[' + domain + '_' + slot + ']' + key = normalize(metadata[domain]['book']['booked'][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + + # try reference with hashtag + key = normalize("#" + metadata[domain]['book']['booked'][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + + # try reference with ref# + key = normalize("ref#" + metadata[domain]['book']['booked'][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + return sent + + +def delexicalse_num(sent): + # changes to numbers only here + digitpat = re.compile('\d+') + sent = re.sub(digitpat, '[value_count]', sent) + return sent + + +def e2e_delecalise(utt, dictionary, metadata): + utt = normalize(utt) + utt = delexicalise(utt, dictionary) + utt = delexicaliseReferenceNumber(utt, metadata) + return delexicalse_num(utt) + + +if __name__ == '__main__': + prepareSlotValuesIndependent() diff --git a/convlab/policy/lava/multiwoz/latent_dialog/normalizer/mapping.pair b/convlab/policy/lava/multiwoz/latent_dialog/normalizer/mapping.pair new file mode 100644 index 0000000000000000000000000000000000000000..34df41d01e93ce27039e721e1ffb55bf9267e5a2 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/normalizer/mapping.pair @@ -0,0 +1,83 @@ +it's it is +don't do not +doesn't does not +didn't did not +you'd you would +you're you are +you'll you will +i'm i am +they're they are +that's that is +what's what is +couldn't could not +i've i have +we've we have +can't cannot +i'd i would +i'd i would +aren't are not +isn't is not +wasn't was not +weren't were not +won't will not +there's there is +there're there are +. . . +restaurants restaurant -s +hotels hotel -s +laptops laptop -s +cheaper cheap -er +dinners dinner -s +lunches lunch -s +breakfasts breakfast -s +expensively expensive -ly +moderately moderate -ly +cheaply cheap -ly +prices price -s +places place -s +venues venue -s +ranges range -s +meals meal -s +locations location -s +areas area -s +policies policy -s +children child -s +kids kid -s +kidfriendly kid friendly +cards card -s +upmarket expensive +inpricey cheap +inches inch -s +uses use -s +dimensions dimension -s +driverange drive range +includes include -s +computers computer -s +machines machine -s +families family -s +ratings rating -s +constraints constraint -s +pricerange price range +batteryrating battery rating +requirements requirement -s +drives drive -s +specifications specification -s +weightrange weight range +harddrive hard drive +batterylife battery life +businesses business -s +hours hour -s +one 1 +two 2 +three 3 +four 4 +five 5 +six 6 +seven 7 +eight 8 +nine 9 +ten 10 +eleven 11 +twelve 12 +anywhere any where +good bye goodbye diff --git a/convlab/policy/lava/multiwoz/latent_dialog/offlinerl_utils.py b/convlab/policy/lava/multiwoz/latent_dialog/offlinerl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..85f76eb767b518063751db6d9d1d8ac8e2bbb367 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/offlinerl_utils.py @@ -0,0 +1,845 @@ +#! /usr/bin/env python +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# +# Copyright © 2021 lubis <lubis@hilbert50> +# +# Distributed under terms of the MIT license. + +""" + +""" + +import torch as th +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import numpy as np +import pdb +import random +import copy +from collections import namedtuple, deque +from torch.autograd import Variable +from convlab.policy.lava.multiwoz.latent_dialog.enc2dec.encoders import RnnUttEncoder +from convlab.policy.lava.multiwoz.latent_dialog.utils import get_detokenize, cast_type, extract_short_ctx, np2var, LONG, FLOAT +from convlab.policy.lava.multiwoz.latent_dialog.corpora import SYS, EOS, PAD, BOS, DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX, REQ_TOKENS +import dill + +class Actor(nn.Module): + def __init__(self, model, corpus, config): + super(Actor, self).__init__() + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.config = config + + self.use_gpu = config.use_gpu + + self.embedding = None + self.is_stochastic = config.is_stochastic + self.y_size = config.y_size + if 'k_size' in config: + self.k_size = config.k_size + self.is_gauss = False + else: + self.max_action = config.max_action if "max_action" in config else None + self.is_gauss = True + + self.utt_encoder = copy.deepcopy(model.utt_encoder) + self.c2z = copy.deepcopy(model.c2z) + if not self.is_gauss: + self.gumbel_connector = copy.deepcopy(model.gumbel_connector) + else: + self.gauss_connector = copy.deepcopy(model.gauss_connector) + self.gaussian_logprob = model.gaussian_logprob + self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) + + # self.l1 = nn.Linear(self.utt_encoder.output_size, 400) + # self.l2 = nn.Linear(400, 300) + # self.l3 = nn.Linear(300, config.y_size * config.k_size) + + def forward(self, data_feed, hard=False): + short_ctx_utts = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + if self.is_gauss: + q_mu, q_logvar = self.c2z(enc_last) + # sample_z = q_mu + if self.is_stochastic: + sample_z = self.gauss_connector(q_mu, q_logvar) + else: + sample_z = q_mu + logprob_sample_z = self.gaussian_logprob(q_mu, q_logvar, sample_z) + # joint_logpz = th.sum(logprob_sample_z, dim=1) + # return self.max_action * torch.tanh(z) + else: + logits_qy, log_qy = self.c2z(enc_last) + qy = F.softmax(logits_qy / 1.0, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_qy, dim=1) # (batch_size, vocab_size, ) + + if self.is_stochastic: + idx = th.multinomial(qy, 1).detach() + soft_z = self.gumbel_connector(logits_qy, hard=False) + else: + idx = th.argmax(th.exp(log_qy), dim=1, keepdim=True) + soft_z = th.exp(log_qy) + sample_z = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_z.scatter_(1, idx, 1.0) + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + + joint_logpz = th.sum(logprob_sample_z, dim=1) + # for i in range(logprob_sample_z.shape[0]): + # print(logprob_sample_z[i]) + # print(joint_logpz[i]) + return joint_logpz, sample_z + +class DeterministicGaussianActor(nn.Module): + def __init__(self, model, corpus, config): + super(DeterministicGaussianActor, self).__init__() + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.config = config + + self.use_gpu = config.use_gpu + + self.embedding = None + self.y_size = config.y_size + self.max_action = config.max_action if "max_action" in config else None + self.is_gauss = True + + self.utt_encoder = copy.deepcopy(model.utt_encoder) + + self.policy = copy.deepcopy(model.c2z) + # self.gauss_connector = copy.deepcopy(model.gauss_connector) + + def forward(self, data_feed): + short_ctx_utts = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + mu, logvar = self.policy(enc_last) + z = mu + if self.max_action is not None: + z = self.max_action * th.tanh(z) + + return z, mu, logvar + +class StochasticGaussianActor(nn.Module): + def __init__(self, model, corpus, config): + super(StochasticGaussianActor, self).__init__() + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.config = config + + self.use_gpu = config.use_gpu + + self.embedding = None + self.y_size = config.y_size + self.max_action = config.max_action if "max_action" in config else None + self.is_gauss = True + + self.utt_encoder = copy.deepcopy(model.utt_encoder) + self.policy = copy.deepcopy(model.c2z) + self.gauss_connector = copy.deepcopy(model.gauss_connector) + + def forward(self, data_feed, n_z=1): + short_ctx_utts = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + q_mu, q_logvar = self.policy(enc_last) + if n_z > 1: + z = [self.gauss_connector(q_mu, q_logvar) for _ in range(n_z)] + else: + z = self.gauss_connector(q_mu, q_logvar) + + return z, q_mu, q_logvar + +class RecurrentCritic(nn.Module): + def __init__(self,cvae, corpus, config, args): + super(RecurrentLatentCritic, self).__init__() + + # self.vocab = corpus.vocab + # self.vocab_dict = corpus.vocab_dict + # self.vocab_size = len(self.vocab) + # self.bos_id = self.vocab_dict[BOS] + # self.eos_id = self.vocab_dict[EOS] + # self.pad_id = self.vocab_dict[PAD] + + self.embedding = None + self.word_plas = args.word_plas + self.state_dim = cvae.utt_encoder.output_size + if self.word_plas: + self.action_dim = cvae.aux_encoder.output_size + else: + self.action_dim = config.y_size #TODO adjust for categorical + # if "k_size" in config: + # if args.embed_z_for_critic: + # self.action_dim = config.dec_cell_size # for categorical, the action can be embedded + # else: + # self.action_dim *= config.k_size + + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.input_dim = self.state_dim + self.bs_size + self.db_size + self.action_dim + # self.input_dim = self.state_dim + 50 + self.action_dim + self.goal_to_critic = args.goal_to_critic + if self.goal_to_critic: + raise NotImplementedError + + self.use_gpu = config.use_gpu + + self.state_encoder = copy.deepcopy(cvae.utt_encoder) + if self.word_plas: + self.action_encoder = copy.deepcopy(cvae.aux_encoder) + else: + self.action_encoder = None + + # self.q11 = nn.Linear(self.state_dim + self.action_dim + 50, 500) + self.q11 = nn.Linear(self.input_dim, 500) + self.q12 = nn.Linear(500, 300) + self.q13 = nn.Linear(300, 100) + self.q14 = nn.Linear(100, 20) + self.q15 = nn.Linear(20, 1) + + self.q21 = nn.Linear(self.input_dim, 500) + self.q22 = nn.Linear(500, 300) + self.q23 = nn.Linear(300, 100) + self.q24 = nn.Linear(100, 20) + self.q25 = nn.Linear(20, 1) + + def forward(self, data_feed, act): + ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1)) + if self.word_plas: + resp_summary, _, _ = self.action_encoder(act.unsqueeze(1)) + sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(1)], dim=1) + else: + sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1) + + q1 = self.q11(sa) + #- + # q1 = F.relu(self.q12(th.cat([q1, metadata_summary], dim=1))) + q1 = F.relu(self.q12(q1)) + # q1 = self.q12(q1) + #- + # q1 = th.sigmoid(self.q13(q1)) + q1 = F.relu(self.q13(q1)) + # q1 = F.softmax(self.q13(q1)) + #- + # q1 = th.sigmoid(self.q14(q1)) + q1 = F.relu(self.q14(q1)) + # q1 = self.q14(q1) + #- + q1 = self.q15(q1) + + + q2 = self.q21(sa) + #- + # q2 = F.relu(self.lq22(th.cat([q2, metadata_summary], dim=1))) + q2 = F.relu(self.q22(q2)) + # q2 = self.q22(q2) + #- + # q2 = th.sigmoid(self.q23(q2)) + q2 = F.relu(self.q23(q2)) + # q2 = F.softmax(self.q23(q2)) + #- + # q2 = th.sigmoid(self.q24(q2)) + q2 = F.relu(self.q24(q2)) + # q2 = self.q24(q2) + #- + q2 = self.q25(q2) + + return q1, q2 + + def q1(self, data_feed, act): + ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1)) + if self.word_plas: + resp_summary, _, _ = self.action_encoder(act.unsqueeze(1)) + sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(0)], dim=1) + else: + sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1) + + q1 = self.q11(sa) + #- + # q1 = F.relu(self.q12(th.cat([q1, metadata_summary], dim=1))) + q1 = F.relu(self.q12(q1)) + # q1 = self.q12(q1) + #- + # q1 = th.sigmoid(self.q13(q1)) + q1 = F.relu(self.q13(q1)) + # q1 = F.softmax(self.q13(q1)) + #- + # q1 = th.sigmoid(self.q14(q1)) + q1 = F.relu(self.q14(q1)) + # q1 = self.q14(q1) + #- + q1 = self.q15(q1) + + return q1 + +class SingleRecurrentCritic(nn.Module): + def __init__(self, cvae, corpus, config, args): + super(SingleRecurrentCritic, self).__init__() + + # self.vocab = corpus.vocab + # self.vocab_dict = corpus.vocab_dict + # self.vocab_size = len(self.vocab) + # self.bos_id = self.vocab_dict[BOS] + # self.eos_id = self.vocab_dict[EOS] + # self.pad_id = self.vocab_dict[PAD] + + if "gauss" in args.sv_config_path: + self.is_gauss = True + else: + self.is_gauss = False + self.embedding = None + self.word_plas = args.word_plas + self.state_dim = cvae.utt_encoder.output_size + if self.word_plas: + self.action_dim = cvae.aux_encoder.output_size + else: + if self.is_gauss: + self.action_dim = config.y_size + else: + if args.embed_z_for_critic: + self.action_dim = config.dec_cell_size + else: + self.action_dim = config.y_size * config.k_size + + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.input_dim = self.state_dim + self.bs_size + self.db_size + self.action_dim + + self.goal_to_critic = args.goal_to_critic + if self.goal_to_critic: + self.goal_size = corpus.goal_size + self.input_dim += self.goal_size + + # self.input_dim = self.state_dim + 50 + self.action_dim + + self.use_gpu = config.use_gpu + + self.state_encoder = copy.deepcopy(cvae.utt_encoder) + if self.word_plas: + self.action_encoder = copy.deepcopy(cvae.aux_encoder) + else: + self.action_encoder = None + + # if self.goal_to_critic: + # self.q11 = nn.Linear(self.input_dim, 500) + # self.q12 = nn.Linear(500, 1) + # else: + self.q11 = nn.Linear(self.input_dim, 1) + self.activation_function = args.critic_actf if "critic_actf" in args else "none" + + self.critic_dropout = args.critic_dropout + if self.critic_dropout: + self.d = th.nn.Dropout(p=args.critic_dropout_rate, inplace=False) + else: + self.d = th.nn.Dropout(p=0.0, inplace=False) + + def forward(self, data_feed, act): + ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1)) + if self.word_plas: + resp_summary, _, _ = self.action_encoder(act.unsqueeze(1)) + sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(1)], dim=1) + else: + sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1) + + if self.goal_to_critic: + try: + goals = np2var(data_feed['goals'], FLOAT, self.use_gpu) + except KeyError: + goals = [] + for turn_id in range(len(ctx_summary)): + goals.append(np.concatenate([data_feed['goals_list'][d][turn_id] for d in range(7)])) + goals = np2var(np.asarray(goals), FLOAT, self.use_gpu) + sa = th.cat([sa, goals], dim = 1) + + # metadata_summary = self.metadata_encoder(th.cat([bs_label, db_label], dim=1)) + # sa = th.cat([ctx_summary.squeeze(1), metadata_summary, act], dim=1) + # if self.is_gauss: + # q1 = F.relu(self.q11(self.d(sa))) + # else: + # q1 = F.sigmoid(self.q11(self.d(sa))) + q1 = self.q11(self.d(sa)) + # if self.goal_to_critic: + # q1 = self.q12(q1) + + if self.activation_function == "relu": + q1 = F.relu(q1) + elif self.activation_function == "sigmoid": + q1 = F.sigmoid(q1) + + return q1 + +class SingleHierarchicalRecurrentCritic(nn.Module): + def __init__(self, cvae, corpus, config, args): + super(SingleHierarchicalRecurrentCritic, self).__init__() + + # self.vocab = corpus.vocab + # self.vocab_dict = corpus.vocab_dict + # self.vocab_size = len(self.vocab) + # self.bos_id = self.vocab_dict[BOS] + # self.eos_id = self.vocab_dict[EOS] + # self.pad_id = self.vocab_dict[PAD] + + self.hidden_size = 500 + + if "gauss" in args.sv_config_path: + self.is_gauss = True + else: + self.is_gauss = False + self.embedding = None + self.word_plas = args.word_plas + self.state_dim = cvae.utt_encoder.output_size + if self.word_plas: + self.action_dim = cvae.aux_encoder.output_size + else: + if self.is_gauss: + self.action_dim = config.y_size + else: + if args.embed_z_for_critic: + self.action_dim = config.dec_cell_size + else: + self.action_dim = config.y_size * config.k_size + + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.input_dim = self.state_dim + self.bs_size + self.db_size + self.action_dim + # self.input_dim = self.state_dim + 50 + self.action_dim + + self.goal_to_critic = args.goal_to_critic + self.add_goal = args.add_goal + if self.goal_to_critic: + self.goal_size = corpus.goal_size + if self.add_goal == "early": + self.input_dim += self.goal_size + + + self.use_gpu = config.use_gpu + + self.state_encoder = copy.deepcopy(cvae.utt_encoder) + if self.word_plas: + self.action_encoder = copy.deepcopy(cvae.aux_encoder) + else: + self.action_encoder = None + + self.dialogue_encoder = nn.LSTM( + input_size = self.input_dim, + hidden_size = self.hidden_size, + dropout=0.1 + ) + + if self.add_goal=="late": + self.q11 = nn.Linear(self.hidden_size + self.goal_size, 1) + else: + self.q11 = nn.Linear(self.hidden_size, 1) + self.activation_function = args.critic_actf if "critic_actf" in args else "none" + + self.critic_dropout = args.critic_dropout + if self.critic_dropout: + self.d = th.nn.Dropout(p=args.critic_dropout_rate, inplace=False) + else: + self.d = th.nn.Dropout(p=0.0, inplace=False) + + if args.critic_actf == "tanh" or args.critic_actf == "sigmoid": + self.maxq = args.critic_maxq + else: + self.maxq = None + + def forward(self, data_feed, act): + ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1)) + if self.word_plas: + resp_summary, _, _ = self.action_encoder(act.unsqueeze(1)) + sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(1)], dim=1) + else: + sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1) + + if self.goal_to_critic: + try: + goals = np2var(data_feed['goals'], FLOAT, self.use_gpu) + except KeyError: + goals = [] + for turn_id in range(len(ctx_summary)): + goals.append(np.concatenate([data_feed['goals_list'][d][turn_id] for d in range(7)])) + goals = np2var(np.asarray(goals), FLOAT, self.use_gpu) + + #OPTION 1 add goal to encoder for each time step + if self.goal_to_critic and self.add_goal=="early": + sa = th.cat([sa, goals], dim = 1) + + output, (hn, cn) = self.dialogue_encoder(self.d(sa.unsqueeze(1))) + + #OPTION 2 add goal combined with hidden state to predict final score + if self.goal_to_critic and self.add_goal=="late": + output = th.cat([output, goals.unsqueeze(1)], dim = 2) + + q1 = self.q11(output.squeeze(1)) + + if self.activation_function == "relu": + q1 = F.relu(q1) + elif self.activation_function == "sigmoid": + q1 = th.sigmoid(q1) + elif self.activation_function == "tanh": + q1 = F.tanh(q1) + + if self.maxq is not None: + q1 *= self.maxq + + return q1 + + def forward_target(self, data_feed, act, corpus_act): + ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1)) + q1s =[] + for i in range(bs_label.shape[0]): + if self.word_plas: + corpus_resp_summary, _, _ = self.action_encoder(corpus_act[:-i].unsqueeze(1)) + actor_resp_summary, _, _ = self.action_encoder(act[i].unsqueeze(1)) + sa = th.cat([ctx_summary[:i+1].squeeze(1), bs_label[:i+1], db_label[:i+1], th.cat([corpus_resp_summary[:i], actor_resp_summary[i]], dim=0).squeeze(1)], dim=1) + else: + sa = th.cat([ctx_summary[:i+1].squeeze(1), bs_label[:i+1], db_label[:i+1], th.cat([corpus_act[:i], act[i].unsqueeze(0)], dim=0)], dim=1) + + if self.goal_to_critic: + try: + goals = np2var(data_feed['goals'][:i+1], FLOAT, self.use_gpu) + except KeyError: + goals = [] + for turn_id in range(i+1): + goals.append(np.concatenate([data_feed['goals_list'][d][turn_id] for d in range(7)])) + goals = np2var(np.asarray(goals), FLOAT, self.use_gpu) + + #OPTION 1 add goal to encoder for each time step + if self.goal_to_critic and self.add_goal=="early": + sa = th.cat([sa, goals], dim = 1) + + output, (hn, cn) = self.dialogue_encoder(self.d(sa.unsqueeze(1))) + + #OPTION 2 add goal combined with hidden state to predict final score + if self.goal_to_critic and self.add_goal=="late": + output = th.cat([output, goals.unsqueeze(1)], dim = 2) + + q1 = self.q11(output.squeeze(1)) + + if self.activation_function == "relu": + q1 = F.relu(q1) + elif self.activation_function == "sigmoid": + q1 = F.sigmoid(q1) + elif self.activation_function == "tanh": + q1 = F.tanh(q1) * self.maxq + + q1s.append(q1[-1]) + + return th.cat(q1s, dim=0).unsqueeze(1) + +class SingleTransformersCritic(nn.Module): + def __init__(self, cvae, corpus, config, args): + super(SingleTransformersCritic, self).__init__() + + # self.vocab = corpus.vocab + # self.vocab_dict = corpus.vocab_dict + # self.vocab_size = len(self.vocab) + # self.bos_id = self.vocab_dict[BOS] + # self.eos_id = self.vocab_dict[EOS] + # self.pad_id = self.vocab_dict[PAD] + + self.hidden_size = 128 + + if "gauss" in args.sv_config_path: + self.is_gauss = True + else: + self.is_gauss = False + self.embedding = None + self.word_plas = args.word_plas + self.state_dim = cvae.utt_encoder.output_size + if self.word_plas: + self.action_dim = cvae.aux_encoder.output_size + else: + if self.is_gauss: + self.action_dim = config.y_size + else: + if args.embed_z_for_critic: + self.action_dim = config.dec_cell_size + else: + self.action_dim = config.y_size * config.k_size + + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.input_dim = self.state_dim + self.bs_size + self.db_size + self.action_dim + self.db_embedding = nn.Linear(self.db_size, config.embed_size) + self.bs_embedding = nn.Linear(self.bs_size, config.embed_size) + + self.goal_to_critic = args.goal_to_critic + if self.goal_to_critic: + raise NotImplementedError + + + self.use_gpu = config.use_gpu + + self.state_encoder = copy.deepcopy(cvae.utt_encoder) + if self.word_plas: + self.action_encoder = copy.deepcopy(cvae.aux_encoder) + else: + self.action_encoder = None + + self.trans_encoder_layer = nn.TransformerEncoderLayer(nhead=8, d_model=config.embed_size) + self.trans_encoder = nn.TransformerEncoder(self.trans_encoder_layer, num_layers=4) + + self.dialogue_encoder = nn.LSTM( + input_size = config.embed_size, + hidden_size = self.hidden_size, + dropout=0.1 + ) + if not self.word_plas: + self.act_embedding = nn.Linear(self.action_dim, config.embed_size) + self.bs_encoder = nn.Linear(self.db_size, config.embed_size) + self.db_encoder = nn.Linear(self.db_size, config.embed_size) + + self.q11 = nn.Linear(self.hidden_size, 1) + + self.critic_dropout = args.critic_dropout + if self.critic_dropout: + self.d = th.nn.Dropout(p=args.critic_dropout_rate, inplace=False) + else: + self.d = th.nn.Dropout(p=0.0, inplace=False) + + def forward(self, data_feed, act): + ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + ctx_summary, word_emb, enc_outs = self.state_encoder(ctx.unsqueeze(1)) + # word_emb : (batch_size, max_len, 256) + # enc_outs : (batch_size, max_len, 600) + metadata_embedding = th.cat([self.bs_embedding(bs_label).unsqueeze(1), self.db_embedding(db_label).unsqueeze(1)], dim=1) + + if self.word_plas: + resp_summary, resp_word_emb, resp_enc_outs = self.action_encoder(act.unsqueeze(1)) + act_embedding = resp_word_emb + else: + act_embedding = self.act_embedding(act).unsqueeze(1) + + sa = th.cat([word_emb, metadata_embedding, act_embedding], dim=1) + sa = self.trans_encoder(self.d(sa)) + output, (hn, cn) = self.dialogue_encoder(self.d(sa)) + q1 = F.sigmoid(self.q11(output[:, -1].squeeze(1))) + # q1 = self.q11(q1[:, 0]) + + + return q1 + + +class CatActor(nn.Module): + def __init__(self, model, corpus, config): + super(CatActor, self).__init__() + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.config = config + + self.use_gpu = config.use_gpu + + self.embedding = None + self.y_size = config.y_size + self.k_size = config.k_size + # self.max_action = config.max_action + self.is_gauss = False + self.is_stochastic = config.is_stochastic + + self.utt_encoder = copy.deepcopy(model.utt_encoder) + + self.policy = copy.deepcopy(model.c2z) + if self.is_stochastic: + self.gumbel_connector = copy.deepcopy(model.gumbel_connector) + + def forward(self, data_feed): + short_ctx_utts = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + + logits_qy, log_qy = self.policy(enc_last) + if self.is_stochastic: + z = self.gumbel_connector(logits_qy, hard=True) + soft_z = self.gumbel_connector(logits_qy, hard=False) + else: + z_idx = th.argmax(th.exp(log_qy), dim=1, keepdim=True) + z = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + z.scatter_(1, z_idx, 1.0) + soft_z = th.exp(log_qy) + + return z, soft_z, log_qy + + + +class ReplayBuffer(object): + """ + Buffer to store experiences, to be used in off-policy learning + """ + def __init__(self, config): + # true_responses = id2sent(next_state) + # pred_responses = model.z2x(action) + + self.batch_size = config.batch_size + self.fix_episode = config.fix_episode + + self.experiences = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "next_action", "done", "Return"]) + self.memory = deque() + self.seed = random.seed(config.random_seed) + # self.reinforce_data = config.reinforce_data + + def add(self, states, actions, rewards, next_states, next_actions, dones, Returns): + if self.fix_episode: + self._add_episode(states, actions, rewards, next_states, next_actions, dones, Returns) + else: + for i in range(len(states)): + self._add(states[i], actions[i], rewards[i], next_states[i], next_actions[i], dones[i], Returns[i]) + + def _add(self, state, action, reward, next_state, next_action, done, Return): + e = self.experiences(state, action, reward, next_state, next_action, done, Return) + self.memory.append(e) + + def _add_episode(self, states, actions, rewards, next_states, next_actions, dones, Returns): + ep = [] + for s, a, r, s_, a_, d, R in zip(states, actions, rewards, next_states, next_actions, dones, Returns): + ep.append(self.experiences(s, a, r, s_, a_, d, R)) + self.memory.append(ep) + + + def sample(self): + if self.fix_episode: + return self._sample_episode() + else: + return self._sample() + + def _sample(self): + experiences = random.sample(self.memory, k = self.batch_size) + + states = {} + states['contexts'] = np.asarray([e.state['contexts'] for e in experiences]) + states['bs'] = np.asarray([e.state['bs'] for e in experiences]) + states['db'] = np.asarray([e.state['db'] for e in experiences]) + states['context_lens'] = np.asarray([e.state['context_lens'] for e in experiences]) + states['goals'] = np.asarray([e.state['goals'] for e in experiences]) + + actions = np.asarray([e.action for e in experiences if e is not None]) + rewards = np.asarray([e.reward for e in experiences if e is not None]) + + next_states = {} + next_states['contexts'] = np.asarray([e.next_state['contexts'] for e in experiences]) + next_states['bs'] = np.asarray([e.next_state['bs'] for e in experiences]) + next_states['db'] = np.asarray([e.next_state['db'] for e in experiences]) + next_states['context_lens'] = np.asarray([e.next_state['context_lens'] for e in experiences]) + next_states['goals'] = np.asarray([e.next_state['goals'] for e in experiences]) + + next_actions = np.asarray([e.next_action for e in experiences if e is not None]) + + dones = np.asarray([e.done for e in experiences if e is not None]) + returns = np.asarray([e.Return for e in experiences if e is not None]) + # if self.reinforce_data: + # rewards = dones * 10 + 1 # give positive rewards to all actions taken in the data + + return (states, actions, rewards, next_states, next_actions, dones, returns) + # return experiences + + def _sample_episode(self): + # episodes = random.sample(self.memory, k = self.batch_size) + episodes = random.sample(self.memory, k = 1) + + for experiences in episodes: + states = {} + states['contexts'] = np.asarray([e.state['contexts'] for e in experiences]) + states['bs'] = np.asarray([e.state['bs'] for e in experiences]) + states['db'] = np.asarray([e.state['db'] for e in experiences]) + states['keys'] = [e.state['keys'] for e in experiences] + states['context_lens'] = np.asarray([e.state['context_lens'] for e in experiences]) + states['goals'] = np.asarray([e.state['goals'] for e in experiences]) + + actions = np.asarray([e.action for e in experiences if e is not None]) + rewards = np.asarray([e.reward for e in experiences if e is not None]) + + next_states = {} + next_states['contexts'] = np.asarray([e.next_state['contexts'] for e in experiences]) + next_states['bs'] = np.asarray([e.next_state['bs'] for e in experiences]) + next_states['db'] = np.asarray([e.next_state['db'] for e in experiences]) + next_states['keys'] = [e.next_state['keys'] for e in experiences] + next_states['context_lens'] = np.asarray([e.next_state['context_lens'] for e in experiences]) + next_states['goals'] = np.asarray([e.next_state['goals'] for e in experiences]) + + next_actions = np.asarray([e.next_action for e in experiences if e is not None]) + + dones = np.asarray([e.done for e in experiences if e is not None]) + returns = np.asarray([e.Return for e in experiences if e is not None]) + # if self.reinforce_data: + # rewards = dones * 10 + 1 # give positive rewards to all actions taken in the data + + return (states, actions, rewards, next_states, next_actions, dones, returns) + # return experiences + + def __len__(self): + return len(self.memory) + + def save(self, path): + with open(path, 'wb') as f: + dill.dump(self.memory, f) + + def load(self, path): + with open(path, 'rb') as f: + self.memory = dill.load(f) diff --git a/convlab/policy/lava/multiwoz/latent_dialog/record.py b/convlab/policy/lava/multiwoz/latent_dialog/record.py new file mode 100644 index 0000000000000000000000000000000000000000..88a072b78670f92787b360ecf63e7fa9fe49c62a --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/record.py @@ -0,0 +1,169 @@ +import numpy as np +from convlab.policy.lava.multiwoz.latent_dialog.enc2dec.decoders import TEACH_FORCE, GEN, DecoderRNN, GEN_VALID +from collections import Counter + + +class UniquenessSentMetric(object): + """Metric that evaluates the number of unique sentences.""" + def __init__(self): + self.seen = set() + self.all_sents = [] + + def record(self, sen): + self.seen.add(' '.join(sen)) + self.all_sents.append(' '.join(sen)) + + def value(self): + return len(self.seen) + + def top_n(self, n): + return Counter(self.all_sents).most_common(n) + + +class UniquenessWordMetric(object): + """Metric that evaluates the number of unique sentences.""" + def __init__(self): + self.seen = set() + + def record(self, word_list): + self.seen.update(word_list) + + def value(self): + return len(self.seen) + + +def record_task(n_epsd, model, val_data, config, ppl_f, dialog, ctx_gen_eval, rl_f): + record_ppl(n_epsd, model, val_data, config, ppl_f) + record_rl_task(n_epsd, dialog, ctx_gen_eval, rl_f) + + +def record(n_epsd, model, val_data, sv_config, lm_model, ppl_f, dialog, ctx_gen_eval, rl_f): + record_ppl_with_lm(n_epsd, model, val_data, sv_config, lm_model, ppl_f) + record_rl(n_epsd, dialog, ctx_gen_eval, rl_f) + + +def record_ppl_with_lm(n_epsd, model, data, config, lm_model, ppl_f): + model.eval() + loss_list = [] + data.epoch_init(config, shuffle=False, verbose=True) + while True: + batch = data.next_batch() + if batch is None: + break + for i in range(1): + loss = model(batch, mode=TEACH_FORCE, use_py=True) + loss_list.append(loss.nll.item()) + + # USE LM to test generation performance + data.epoch_init(config, shuffle=False, verbose=False) + gen_loss_list = [] + # first generate + while True: + batch = data.next_batch() + if batch is None: + break + + outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type) + # move from GPU to CPU + labels = labels.cpu() + pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + # clean up the pred labels + clean_pred_labels = np.zeros((pred_labels.shape[0], pred_labels.shape[1]+1)) + clean_pred_labels[:, 0] = model.sys_id + for b_id in range(pred_labels.shape[0]): + for t_id in range(pred_labels.shape[1]): + token = pred_labels[b_id, t_id] + clean_pred_labels[b_id, t_id + 1] = token + if token in [model.eos_id] or t_id == pred_labels.shape[1]-1: + break + + pred_out_lens = np.sum(np.sign(clean_pred_labels), axis=1) + max_pred_lens = np.max(pred_out_lens) + clean_pred_labels = clean_pred_labels[:, 0:int(max_pred_lens)] + batch['outputs'] = clean_pred_labels + batch['output_lens'] = pred_out_lens + loss = lm_model(batch, mode=TEACH_FORCE) + gen_loss_list.append(loss.nll.item()) + + avg_loss = np.average(loss_list) + avg_ppl = np.exp(avg_loss) + gen_avg_loss = np.average(gen_loss_list) + gen_avg_ppl = np.exp(gen_avg_loss) + + ppl_f.write('{}\t{}\t{}\n'.format(n_epsd, avg_ppl, gen_avg_ppl)) + ppl_f.flush() + model.train() + + +def record_ppl(n_epsd, model, val_data, config, ppl_f): + model.eval() + loss_list = [] + val_data.epoch_init(config, shuffle=False, verbose=True) + while True: + batch = val_data.next_batch() + if batch is None: + break + loss = model(batch, mode=TEACH_FORCE, use_py=True) + loss_list.append(loss.nll.item()) + aver_loss = np.average(loss_list) + aver_ppl = np.exp(aver_loss) + ppl_f.write('{}\t{}\n'.format(n_epsd, aver_ppl)) + ppl_f.flush() + model.train() + + +def record_rl(n_epsd, dialog, ctx_gen, rl_f): + conv_list = [] + reward_list = [] + agree_list = [] + sent_metric = UniquenessSentMetric() + word_metric = UniquenessWordMetric() + + for ctxs in ctx_gen.ctxs: + conv, agree, rewards = dialog.run(ctxs) + true_reward = rewards[0] if agree else 0 + reward_list.append(true_reward) + conv_list.append(conv) + agree_list.append(float(agree) if agree is not None else 0.0) + for turn in conv: + if turn[0] == 'Elder': + sent_metric.record(turn[1]) + word_metric.record(turn[1]) + + # json.dump(conv_list, text_f, indent=4) + aver_reward = np.average(reward_list) + aver_agree = np.average(agree_list) + unique_sent_num = sent_metric.value() + unique_word_num = word_metric.value() + print(sent_metric.top_n(10)) + + rl_f.write('{}\t{}\t{}\t{}\t{}\n'.format(n_epsd, aver_reward, aver_agree, unique_sent_num, unique_word_num)) + rl_f.flush() + + +def record_rl_task(n_epsd, dialog, goal_gen, rl_f): + conv_list = [] + reward_list = [] + sent_metric = UniquenessSentMetric() + word_metric = UniquenessWordMetric() + print("Begin RL testing") + cnt = 0 + for g_key, goal in goal_gen.iter(1): + cnt += 1 + conv, success = dialog.run(g_key, goal) + true_reward = success + reward_list.append(true_reward) + conv_list.append(conv) + for turn in conv: + if turn[0] == 'Elder': + sent_metric.record(turn[1]) + word_metric.record(turn[1]) + + # json.dump(conv_list, text_f, indent=4) + aver_reward = np.average(reward_list) + unique_sent_num = sent_metric.value() + unique_word_num = word_metric.value() + rl_f.write('{}\t{}\t{}\t{}\n'.format(n_epsd, aver_reward, unique_sent_num, unique_word_num)) + rl_f.flush() + print("End RL testing") diff --git a/convlab/policy/lava/multiwoz/latent_dialog/utils.py b/convlab/policy/lava/multiwoz/latent_dialog/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..764f95a1a5ce436781579b89f9fc8d6ce4bcf292 --- /dev/null +++ b/convlab/policy/lava/multiwoz/latent_dialog/utils.py @@ -0,0 +1,132 @@ +import os +import numpy as np +import random +import torch as th +from nltk import RegexpTokenizer +from torch.autograd import Variable +from nltk.tokenize.treebank import TreebankWordDetokenizer +import logging +import sys +from collections import defaultdict + +INT = 0 +LONG = 1 +FLOAT = 2 + + +class Pack(dict): + def __getattr__(self, name): + return self[name] + + def add(self, **kwargs): + for k, v in kwargs.items(): + self[k] = v + + def copy(self): + pack = Pack() + for k, v in self.items(): + if type(v) is list: + pack[k] = list(v) + else: + pack[k] = v + return pack + + @staticmethod + def msg_from_dict(dictionary, tokenize, speaker2id, bos_id, eos_id, include_domain=False): + pack = Pack() + for k, v in dictionary.items(): + pack[k] = v + pack['speaker'] = speaker2id[pack.speaker] + pack['conf'] = dictionary.get('conf', 1.0) + utt = pack['utt'] + if 'QUERY' in utt or "RET" in utt: + utt = str(utt) + # utt = utt.translate(None, ''.join([':', '"', "{", "}", "]", "["])) + utt = utt.translate(str.maketrans('', '', ''.join([':', '"', "{", "}", "]", "["]))) + utt = str(utt) + if include_domain: + pack['utt'] = [bos_id, pack['speaker'], pack['domain']] + tokenize(utt) + [eos_id] + else: + pack['utt'] = [bos_id, pack['speaker']] + tokenize(utt) + [eos_id] + return pack + +def get_tokenize(): + return RegexpTokenizer(r'\w+|#\w+|<\w+>|%\w+|[^\w\s]+').tokenize + +def get_detokenize(): + return lambda x: TreebankWordDetokenizer().detokenize(x) + +def cast_type(var, dtype, use_gpu): + if use_gpu: + if dtype == INT: + var = var.type(th.cuda.IntTensor) + elif dtype == LONG: + var = var.type(th.cuda.LongTensor) + elif dtype == FLOAT: + var = var.type(th.cuda.FloatTensor) + else: + raise ValueError('Unknown dtype') + else: + if dtype == INT: + var = var.type(th.IntTensor) + elif dtype == LONG: + var = var.type(th.LongTensor) + elif dtype == FLOAT: + var = var.type(th.FloatTensor) + else: + raise ValueError('Unknown dtype') + return var + +def read_lines(file_name): + """Reads all the lines from the file.""" + assert os.path.exists(file_name), 'file does not exists %s' % file_name + lines = [] + with open(file_name, 'r') as f: + for line in f: + lines.append(line.strip()) + return lines + +def set_seed(seed): + """Sets random seed everywhere.""" + th.manual_seed(seed) + if th.cuda.is_available(): + th.cuda.manual_seed(seed) + np.random.seed(seed) + +def prepare_dirs_loggers(config, script=""): + logFormatter = logging.Formatter("%(message)s") + rootLogger = logging.getLogger() + rootLogger.setLevel(logging.DEBUG) + + consoleHandler = logging.StreamHandler(sys.stdout) + consoleHandler.setLevel(logging.DEBUG) + consoleHandler.setFormatter(logFormatter) + #rootLogger.addHandler(consoleHandler) + + if hasattr(config, 'forward_only') and config.forward_only: + return + + fileHandler = logging.FileHandler(os.path.join(config.saved_path,'session.log')) + fileHandler.setLevel(logging.DEBUG) + fileHandler.setFormatter(logFormatter) + rootLogger.addHandler(fileHandler) + +def get_chat_tokenize(): + return nltk.RegexpTokenizer(r'\w+|<sil>|[^\w\s]+').tokenize + +class missingdict(defaultdict): + def __missing__(self, key): + return self.default_factory() + +def extract_short_ctx(context, context_lens, backward_size=1): + utts = [] + for b_id in range(context.shape[0]): + utts.append(context[b_id, context_lens[b_id]-1]) + return np.array(utts) + +def np2var(inputs, dtype, use_gpu): + if inputs is None: + return None + return cast_type(Variable(th.from_numpy(inputs)), + dtype, + use_gpu) diff --git a/convlab/policy/lava/multiwoz/lava.py b/convlab/policy/lava/multiwoz/lava.py index dad3f6b58a262f7ff6d6aa28c6b8842a8c14fd34..76d177396c4d9a9d4c7e20f18ce652f6f8852ad5 100755 --- a/convlab/policy/lava/multiwoz/lava.py +++ b/convlab/policy/lava/multiwoz/lava.py @@ -10,7 +10,8 @@ from convlab.policy.lava.multiwoz.latent_dialog.models_task import * from convlab.policy import Policy from convlab.util.file_util import cached_path from convlab.util.multiwoz.state import default_state -from convlab.util.multiwoz.dbquery import Database +# from convlab.util.multiwoz.dbquery import Database +from data.unified_datasets.multiwoz21.database import Database from copy import deepcopy import json import os @@ -156,7 +157,7 @@ def get_relevant_domains(state): for domain in state.keys(): # print("--", domain, "--") - for slot, value in state[domain]['semi'].items(): + for slot, value in state[domain].items(): if len(value) > 0: # print(slot, value) domains.append(domain) @@ -174,7 +175,8 @@ def addDBPointer(state, db): num_entities = {} for domain in domains: # entities = dbPointer.queryResultVenues(domain, {'metadata': state}) - entities = db.query(domain, state[domain]['semi'].items()) + constraints = [[slot, value] for slot, value in state[domain].items() if value] if domain in state else [] + entities = db.query(domain, constraints, topk=10) num_entities[domain] = len(entities) if len(entities) > 0: # fields = dbPointer.table_schema(domain) @@ -233,36 +235,37 @@ def delexicaliseReferenceNumber(sent, state): during data gathering was created randomly.""" domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital'] # , 'police'] - for domain in domains: - if state[domain]['book']['booked']: - for slot in state[domain]['book']['booked'][0]: - if slot == 'reference': - val = '[' + domain + '_' + slot + ']' - else: + + if state['history'][-1][0]=="sys": + # print(state["booked"]) + for domain in domains: + if state['booked'][domain]: + for slot in state['booked'][domain][0]: val = '[' + domain + '_' + slot + ']' - key = normalize(state[domain]['book']['booked'][0][slot]) - sent = (' ' + sent + ' ').replace(' ' + - key + ' ', ' ' + val + ' ') - - # try reference with hashtag - key = normalize("#" + state[domain]['book']['booked'][0][slot]) - sent = (' ' + sent + ' ').replace(' ' + - key + ' ', ' ' + val + ' ') - - # try reference with ref# - key = normalize( - "ref#" + state[domain]['book']['booked'][0][slot]) - sent = (' ' + sent + ' ').replace(' ' + - key + ' ', ' ' + val + ' ') + key = normalize(state['booked'][domain][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + + key + ' ', ' ' + val + ' ') + + # try reference with hashtag + key = normalize("#" + state['booked'][domain][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + + key + ' ', ' ' + val + ' ') + + # try reference with ref# + key = normalize( + "ref#" + state['booked'][domain][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + + key + ' ', ' ' + val + ' ') + return sent def domain_mark_not_mentioned(state, active_domain): - if active_domain not in ['police', 'hospital', 'taxi', 'train', 'attraction', 'restaurant', 'hotel'] or active_domain is None: + if active_domain not in ['hospital', 'taxi', 'train', 'attraction', 'restaurant', 'hotel'] or active_domain is None: return - for s in state[active_domain]['semi']: - if state[active_domain]['semi'][s] == '': - state[active_domain]['semi'][s] = 'not mentioned' + for s in state[active_domain]: + if state[active_domain][s] == '': + state[active_domain][s] = 'not mentioned' def mark_not_mentioned(state): for domain in state: @@ -274,9 +277,9 @@ def mark_not_mentioned(state): # for s in state[domain]['semi']: # if s != 'book' and state[domain]['semi'][s] == '': # state[domain]['semi'][s] = 'not mentioned' - for s in state[domain]['semi']: - if state[domain]['semi'][s] == '': - state[domain]['semi'][s] = 'not mentioned' + for s in state[domain]: + if state[domain][s] == '': + state[domain][s] = 'not mentioned' except Exception as e: # print(str(e)) # pprint(state[domain]) @@ -331,22 +334,96 @@ def get_summary_bstate(bstate): assert len(summary_bstate) == 94 return summary_bstate +def get_summary_bstate_unifiedformat(state): + """Based on the mturk annotations we form multi-domain belief state""" + domains = [u'taxi', u'restaurant', u'hospital', + u'hotel', u'attraction', u'train']#, u'police'] + bstate = state['belief_state'] + # booked = state['booked'] + # how to make empty book this format instead of an empty dictionary? + #TODO fix booked info update in state! + booked = { + "taxi": [], + "hotel": [], + "restaurant": [], + "train": [], + "attraction": [], + "hospital": [] + } + + summary_bstate = [] + + for domain in domains: + domain_active = False + + booking = [] + if len(booked[domain]) > 0: + booking.append(1) + else: + booking.append(0) + if domain == 'train': + if not bstate[domain]['book people']: + booking.append(0) + else: + booking.append(1) + if booked[domain] and 'ticket' in booked[domain][0].keys(): + booking.append(1) + else: + booking.append(0) + summary_bstate += booking + + if domain == "restaurant": + book_slots = ['book day', 'book people', 'book time'] + elif domain == "hotel": + book_slots = ['book day', 'book people', 'book stay'] + else: + book_slots = [] + for slot in book_slots: + if bstate[domain][slot] == '': + summary_bstate.append(0) + else: + summary_bstate.append(1) + + for slot in [s for s in bstate[domain] if "book" not in s]: + slot_enc = [0, 0, 0] + if bstate[domain][slot] == 'not mentioned': + slot_enc[0] = 1 + elif bstate[domain][slot] == 'dont care' or bstate[domain][slot] == 'dontcare' or bstate[domain][slot] == "don't care": + slot_enc[1] = 1 + elif bstate[domain][slot]: + slot_enc[2] = 1 + if slot_enc != [0, 0, 0]: + domain_active = True + summary_bstate += slot_enc + + # quasi domain-tracker + if domain_active: # 7 domains + summary_bstate += [1] + else: + summary_bstate += [0] + + + # add manually from action as police is not tracked anymore in unified format + if "Police" in [act[1] for act in state['user_action']]: + summary_bstate += [0, 1] + else: + summary_bstate += [0, 0] + + assert len(summary_bstate) == 94 + return summary_bstate + DEFAULT_CUDA_DEVICE = -1 class LAVA(Policy): def __init__(self, - model_file="/gpfs/project/lubis/public_code/LAVA/experiments_woz/sys_config_log_model/2020-05-12-14-51-49-actz_cat/rl-2020-05-18-10-50-48/reward_best.model", is_train=False): + model_file="", is_train=False): if not model_file: raise Exception("No model for LAVA is specified!") temp_path = os.path.dirname(os.path.abspath(__file__)) - # print(temp_path) - #zip_ref = zipfile.ZipFile(archive_file, 'r') - # zip_ref.extractall(temp_path) - # zip_ref.close() self.prev_state = default_state() self.prev_active_domain = None @@ -354,24 +431,7 @@ class LAVA(Policy): domain_name = 'object_division' domain_info = domain.get_domain(domain_name) self.db=Database() - # data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data/data_2.1/') - # train_data_path = os.path.join(data_path,'train_dials.json') - # if not os.path.exists(train_data_path): - # zipped_file = os.path.join(data_path, 'norm-multi-woz.zip') - # archive = zipfile.ZipFile(zipped_file, 'r') - # archive.extractall(data_path) - - # norm_multiwoz_path = data_path - # with open(os.path.join(norm_multiwoz_path, 'input_lang.index2word.json')) as f: - # self.input_lang_index2word = json.load(f) - # with open(os.path.join(norm_multiwoz_path, 'input_lang.word2index.json')) as f: - # self.input_lang_word2index = json.load(f) - # with open(os.path.join(norm_multiwoz_path, 'output_lang.index2word.json')) as f: - # self.output_lang_index2word = json.load(f) - # with open(os.path.join(norm_multiwoz_path, 'output_lang.word2index.json')) as f: - # self.output_lang_word2index = json.load(f) - - + path, _ = os.path.split(model_file) if "rl-" in model_file: rl_config_path = os.path.join(path, "rl_config.json") @@ -386,13 +446,12 @@ class LAVA(Policy): try: self.corpus = corpora_inference.NormMultiWozCorpus(config) except (FileNotFoundError, PermissionError): - train_data_path = "/gpfs/project/lubis/LAVA_code/LAVA_dev/data/norm-multi-woz/train_dials.json" + train_data_path = "/gpfs/project/lubis/NeuralDialog-LaRL/data/norm-multi-woz/train_dials.json" config['train_path'] = train_data_path config['valid_path'] = train_data_path.replace("train", "val") config['test_path'] = train_data_path.replace("train", "test") self.corpus = corpora_inference.NormMultiWozCorpus(config) - if "rl" in model_file: if "gauss" in model_file: self.model = SysPerfectBD2Gauss(self.corpus, config) @@ -429,7 +488,6 @@ class LAVA(Policy): self.rl_config["US_best_reward_model_path"] = model_file.replace( ".model", "_US.model") if "lr_rl" not in config: - #config["lr_rl"] = config["init_lr"] self.config["lr_rl"] = 0.01 self.config["gamma"] = 0.99 @@ -442,7 +500,6 @@ class LAVA(Policy): lr=self.config.lr_rl, momentum=self.config.momentum, nesterov=False) - # nesterov=(self.config.nesterov and self.config.momentum > 0)) if config.use_gpu: self.model.load_state_dict(torch.load(model_file)) @@ -504,39 +561,6 @@ class LAVA(Policy): utts.append(context[b_id, context_lens[b_id]-1]) return np.array(utts) - def get_active_domain_test(self, prev_active_domain, prev_action, action): - domains = ['hotel', 'restaurant', 'attraction', - 'train', 'taxi', 'hospital', 'police'] - active_domain = None - cur_action_keys = action.keys() - state = [] - for act in cur_action_keys: - slots = act.split('-') - action = slots[0].lower() - state.append(action) - - # print('get_active_domain') - # for domain in domains: - """for domain in range(len(domains)): - domain = domains[i] - if domain not in prev_state and domain not in state: - continue - if domain in prev_state and domain not in state: - return domain - elif domain not in prev_state and domain in state: - return domain - elif prev_state[domain] != state[domain]: - active_domain = domain - if active_domain is None: - active_domain = prev_active_domain""" - if len(state) != 0: - active_domain = state[0] - if active_domain is None: - active_domain = prev_active_domain - elif active_domain == "general": - active_domain = prev_active_domain - return active_domain - def is_masked_action(self, bs_label, db_label, response): """ check if the generated response should be masked based on belief state and db result @@ -563,15 +587,20 @@ class LAVA(Policy): # print("MASK: inform no offer mentioning criteria") # return True # always only inform "no match for your criteria" w/o mentioning them explicitly elif any([p in response for p in REQ_TOKENS[domain]]) or "i have [value_count]" in response or "there are [value_count]" in response: # if requestable token is present - # TODO also check for i have [value_count] match, not only the requestable tokens if db_idx >= 0 and int(db_label[db_idx]) == 1: # and domain has a DB to be queried and there are no matches # print("MASK: inform match when there are no DB match on domain {}".format(domain)) return True return False + def is_active(self, domain, state): - def get_active_domain(self, prev_active_domain, prev_state, state): + if domain in [act[1] for act in state['user_action']]: + return True + else: + return False + + def get_active_domain_unified(self, prev_active_domain, prev_state, state): domains = ['hotel', 'restaurant', 'attraction', 'train', 'taxi', 'hospital', 'police'] active_domain = None @@ -580,30 +609,17 @@ class LAVA(Policy): # print("NEW_STATE",state) # print() for domain in domains: - if domain not in prev_state and domain not in state: - continue - if domain in prev_state and domain not in state: + if not self.is_active(domain, prev_state) and self.is_active(domain, state): #print("case 1:",domain) return domain - elif domain not in prev_state and domain in state: - #print("case 2:",domain) + elif self.is_active(domain, prev_state) and self.is_active(domain, state): return domain - elif prev_state[domain] != state[domain]: + # elif self.is_active(domain, prev_state) and not self.is_active(domain, state): + #print("case 2:",domain) + # return domain + # elif prev_state['belief_state'][domain] != state['belief_state'][domain]: #print("case 3:",domain) - active_domain = domain - if active_domain is None: - active_domain = prev_active_domain - return active_domain - - def get_active_domain_new(self, prev_active_domain, prev_state, state): - domains = ['hotel', 'restaurant', 'attraction', - 'train', 'taxi', 'hospital', 'police'] - active_domain = None - i = 0 - for domain in domains: - if prev_state[domain] != state[domain]: - active_domain = domain - i += 1 + # active_domain = domain if active_domain is None: active_domain = prev_active_domain return active_domain @@ -642,7 +658,7 @@ class LAVA(Policy): else: return False - def predict_response(self, state): + def predict_response (self, state): # input state is in convlab format history = [] for i in range(len(state['history'])): @@ -675,8 +691,9 @@ class LAVA(Policy): # mark_not_mentioned(prev_state) #active_domain = self.get_active_domain_convlab(self.prev_active_domain, prev_bstate, bstate) - active_domain = self.get_active_domain(self.prev_active_domain, prev_bstate, bstate) - #print(active_domain) + active_domain = self.get_active_domain_unified(self.prev_active_domain, self.prev_state, state) + # print("---------") + # print("active domain: ", active_domain) # if active_domain is not None: # print(f"DST on {active_domain}: {bstate[active_domain]}") @@ -685,16 +702,23 @@ class LAVA(Policy): top_results, num_results = None, None for t_id in range(len(context)): usr = context[t_id] + # print(usr) if t_id == 0: #system turns if usr == "null": usr = "<d>" + # booked = {"taxi": [], + # "restaurant": [], + # "hospital": [], + # "hotel": [], + # "attraction": [], + # "train": []} words = usr.split() usr = delexicalize.delexicalise(' '.join(words).lower(), self.dic) # parsing reference number GIVEN belief state - usr = delexicaliseReferenceNumber(usr, bstate) + usr = delexicaliseReferenceNumber(usr, state) # changes to numbers only here digitpat = re.compile('(^| )\d+( |$)') @@ -702,11 +726,13 @@ class LAVA(Policy): # add database pointer pointer_vector, top_results, num_results = addDBPointer(bstate,self.db) - #print(top_results) + if state['history'][-1][0] == "sys": + booked = state['booked'] + #print(top_results) # add booking pointer pointer_vector = addBookingPointer(bstate, pointer_vector) - belief_summary = get_summary_bstate(bstate) + belief_summary = get_summary_bstate_unifiedformat(state) usr_utt = [BOS] + usr.split() + [EOS] packed_val = {} @@ -725,15 +751,13 @@ class LAVA(Policy): # data_feed is in LaRL format data_feed = prepare_batch_gen(results, self.config) + # print(belief_summary) - for i in range(10): + for i in range(1): outputs = self.model_predict(data_feed) self.prev_output = outputs mul = False - if self.is_masked_action(data_feed['bs'][0], data_feed['db'][0], outputs) and i < 9: # if it's the last try, accept masked action - continue - # default lexicalization if active_domain is not None and active_domain in num_results: num_results = num_results[active_domain] @@ -748,26 +772,23 @@ class LAVA(Policy): top_results = {active_domain: top_results[active_domain]} else: if active_domain == 'train': #special case, where we want the last match instead of the first - if bstate['train']['semi']['arriveBy'] != "not mentioned" and len(bstate['train']['semi']['arriveBy']) > 0: + if bstate['train']['arrive by'] != "not mentioned" and len(bstate['train']['arrive by']) > 0: top_results = {active_domain: top_results[active_domain][-1]} # closest to arrive by else: top_results = {active_domain: top_results[active_domain][0]} else: - top_results = {active_domain: top_results[active_domain][0]} + top_results = {active_domain: top_results[active_domain][0]} # if active domain is wrong, this becomes the wrong entity else: top_results = {} state_with_history = deepcopy(bstate) state_with_history['history'] = deepcopy(state_history) - if active_domain in ["hotel", "attraction", "train", "restaurant"] and len(top_results.keys()) == 0: # no db match for active domain + if active_domain in ["hotel", "attraction", "train", "restaurant"] and active_domain not in top_results.keys(): # no db match for active domain if any([p in outputs for p in REQ_TOKENS[active_domain]]): - # self.fail_info_penalty += 1 - # response = "I am sorry there are no matches." - # print(outputs) response = "I am sorry, can you say that again?" database_results = {} else: - response = self.populate_template( + response = self.populate_template_unified( outputs, top_results, num_results, state_with_history, active_domain) # print(response) @@ -776,60 +797,22 @@ class LAVA(Policy): response = self.populate_template_options(outputs, top_results, num_results, state_with_history) else: try: - response = self.populate_template( + response = self.populate_template_unified( outputs, top_results, num_results, state_with_history, active_domain) except: - print(outputs) - + print("can not lexicalize: ", outputs) + response = "I am sorry, can you say that again?" - - for domain in DOMAIN_REQ_TOKEN: - """ - mask out of domain action - """ - if domain != active_domain and any([p in outputs for p in REQ_TOKENS[domain]]): - # print(f"MASK: illegal action for {active_domain}: {outputs}") - response = "Can I help you with anything else?" - self.wrong_domain_penalty += 1 - - - # if active_domain is not None: - # print ("===========================") - - # print(active_domain) - # print ("BS: ") - # for k, v in bstate[active_domain].items(): - # print(k, ": ", v) - # print ("DB: ") - # if len(database_results.keys()) > 0: - # for k, v in database_results[active_domain][1][0 % database_results[active_domain][0]].items(): - # print(k, ": ", v) - # print ("===========================") - # print ("input: ") - # for turn in data_feed['contexts'][0]: - # print(" ".join(self.corpus.id2sent(turn))) - # print ("system delex: ", outputs) - # print ("system: ",response) - # print ("===========================\n") - self.num_generated_response += 1 - break - response = response.replace("free pounds", "free") response = response.replace("pounds pounds", "pounds") if any([p in response for p in ["not mentioned", "dontcare", "[", "]"]]): - # response = "I am sorry there are no matches." - # print(usr) - # print(outputs) - # print(response) - # print(active_domain, len(top_results[active_domain]), delex_bstate[active_domain]) - # pdb.set_trace() - # response = "I am sorry can you repeat that?" response = "I am sorry, can you say that again?" return response, active_domain - def populate_template(self, template, top_results, num_results, state, active_domain): + + def populate_template_unified(self, template, top_results, num_results, state, active_domain): # print("template:",template) # print("top_results:",top_results) # active_domain = None if len( @@ -848,7 +831,7 @@ class LAVA(Policy): if domain == 'train' and slot == 'id': slot = 'trainID' elif active_domain != 'train' and slot == 'price': - slot = 'pricerange' + slot = 'price range' elif slot == 'reference': slot = 'Ref' if domain in top_results and len(top_results[domain]) > 0 and slot in top_results[domain]: @@ -864,27 +847,24 @@ class LAVA(Policy): elif active_domain == "restaurant": if "people" in tokens[index:index+1] or "table" in tokens[index-2:index]: response.append( - state[active_domain]["book"]["people"]) + state[active_domain]["book people"]) elif active_domain == "train": if "ticket" in " ".join(tokens[index-2:index+1]) or "people" in tokens[index:]: response.append( - state[active_domain]["book"]["people"]) + state[active_domain]["book people"]) elif index+1 < len(tokens) and "minute" in tokens[index+1]: response.append( top_results['train']['duration'].split()[0]) elif active_domain == "hotel": if index+1 < len(tokens): if "star" in tokens[index+1]: - try: - response.append(top_results['hotel']['stars']) - except: - response.append(state['hotel']['semi']['stars']) + response.append(top_results['hotel']['stars']) elif "nights" in tokens[index+1]: response.append( - state[active_domain]["book"]["stay"]) + state[active_domain]["book stay"]) elif "people" in tokens[index+1]: response.append( - state[active_domain]["book"]["people"]) + state[active_domain]["book people"]) elif active_domain == "attraction": if index + 1 < len(tokens): if "pounds" in tokens[index+1] and "entrance fee" in " ".join(tokens[index-3:index]): @@ -912,14 +892,14 @@ class LAVA(Policy): top_results[active_domain]["destination"]) elif active_domain == "taxi": response.append( - state[active_domain]['semi']["destination"]) + state[active_domain]["destination"]) elif 'leav' in " ".join(tokens[index-2:index]) or "from" in tokens[index-2:index] or "depart" in " ".join(tokens[index-2:index]): if active_domain == "train": response.append( top_results[active_domain]["departure"]) elif active_domain == "taxi": response.append( - state[active_domain]['semi']["departure"]) + state[active_domain]["departure"]) elif "hospital" in template: response.append("Cambridge") else: @@ -928,9 +908,9 @@ class LAVA(Policy): if d == 'history': continue for s in ['destination', 'departure']: - if s in state[d]['semi']: + if s in state[d]: response.append( - state[d]['semi'][s]) + state[d][s]) raise except: pass @@ -938,7 +918,7 @@ class LAVA(Policy): response.append(token) elif slot == 'time': if 'arrive' in ' '.join(response[-5:]) or 'arrival' in ' '.join(response[-5:]) or 'arriving' in ' '.join(response[-3:]): - if active_domain is "train" and 'arriveBy' in top_results[active_domain]: + if active_domain == "train" and 'arriveBy' in top_results[active_domain]: # print('{} -> {}'.format(token, top_results[active_domain]['arriveBy'])) response.append( top_results[active_domain]['arriveBy']) @@ -946,12 +926,12 @@ class LAVA(Policy): for d in state: if d == 'history': continue - if 'arriveBy' in state[d]['semi']: + if 'arrive by' in state[d]: response.append( - state[d]['semi']['arriveBy']) + state[d]['arrive by']) break elif 'leave' in ' '.join(response[-5:]) or 'leaving' in ' '.join(response[-5:]) or 'departure' in ' '.join(response[-3:]): - if active_domain is "train" and 'leaveAt' in top_results[active_domain]: + if active_domain == "train" and 'leaveAt' in top_results[active_domain]: # print('{} -> {}'.format(token, top_results[active_domain]['leaveAt'])) response.append( top_results[active_domain]['leaveAt']) @@ -959,23 +939,23 @@ class LAVA(Policy): for d in state: if d == 'history': continue - if 'leaveAt' in state[d]['semi']: + if 'leave at' in state[d]: response.append( - state[d]['semi']['leaveAt']) + state[d]['leave at']) break elif 'book' in response or "booked" in response: - if state['restaurant']['book']['time'] != "": + if state['restaurant']['book time'] != "": response.append( - state['restaurant']['book']['time']) + state['restaurant']['book time']) else: try: for d in state: if d == 'history': continue - for s in ['arriveBy', 'leaveAt']: - if s in state[d]['semi']: + for s in ['arrive by', 'leave at']: + if s in state[d]: response.append( - state[d]['semi'][s]) + state[d][s]) raise except: pass @@ -999,9 +979,9 @@ class LAVA(Policy): response.append( top_results[active_domain][slot].split()[0]) elif slot == "day" and active_domain in ["restaurant", "hotel"]: - if state[active_domain]['book']['day'] != "": + if state[active_domain]['book day'] != "": response.append( - state[active_domain]['book']['day']) + state[active_domain]['book day']) else: # slot-filling based on query results @@ -1014,8 +994,8 @@ class LAVA(Policy): for d in state: if d == 'history': continue - if slot in state[d]['semi']: - response.append(state[d]['semi'][slot]) + if slot in state[d]: + response.append(state[d][slot]) break else: response.append(token) @@ -1028,7 +1008,7 @@ class LAVA(Policy): elif slot == 'address': response.append("56 Lincoln street") elif slot == "postcode": - response.append('533421') + response.append('cb1p3') elif domain == 'police': if slot == 'phone': response.append('01223358966') @@ -1037,7 +1017,7 @@ class LAVA(Policy): elif slot == 'address': response.append('Parkside, Cambridge') elif slot == 'postcode': - response.append('533420') + response.append('cb3l3') elif domain == 'taxi': if slot == 'phone': response.append('01223358966') @@ -1068,6 +1048,7 @@ class LAVA(Policy): # if "not mentioned" in response: # pdb.set_trace() + # print("lexicalized: ", response) return response @@ -1158,9 +1139,9 @@ class LAVA(Policy): if d == 'history': continue for s in ['destination', 'departure']: - if s in state[d]['semi']: + if s in state[d]: response.append( - state[d]['semi'][s]) + state[d][s]) raise except: pass @@ -1176,9 +1157,9 @@ class LAVA(Policy): for d in state: if d == 'history': continue - if 'arriveBy' in state[d]['semi']: + if 'arriveBy' in state[d]: response.append( - state[d]['semi']['arriveBy']) + state[d]['arrive by']) break elif 'leav' in ' '.join(response[-7:]) or 'depart' in ' '.join(response[-7:]): if active_domain is not None and 'leaveAt' in top_results[active_domain][result_idx]: @@ -1189,23 +1170,23 @@ class LAVA(Policy): for d in state: if d == 'history': continue - if 'leaveAt' in state[d]['semi']: + if 'leave at' in state[d]: response.append( - state[d]['semi']['leaveAt']) + state[d]['leave at']) break elif 'book' in response or "booked" in response: - if state['restaurant']['book']['time'] != "": + if state['restaurant']['book time'] != "": response.append( - state['restaurant']['book']['time']) + state['restaurant']['book time']) else: try: for d in state: if d == 'history': continue - for s in ['arriveBy', 'leaveAt']: - if s in state[d]['semi']: + for s in ['arrive by', 'leave at']: + if s in state[d]: response.append( - state[d]['semi'][s]) + state[d][s]) raise except: pass @@ -1227,9 +1208,9 @@ class LAVA(Policy): response.append( top_results[active_domain][result_idx][slot].split()[0]) elif slot == "day" and active_domain in ["restaurant", "hotel"]: - if state[active_domain]['book']['day'] != "": + if state[active_domain]['book day'] != "": response.append( - state[active_domain]['book']['day']) + state[active_domain]['book day']) else: # slot-filling based on query results @@ -1243,8 +1224,8 @@ class LAVA(Policy): for d in state: if d == 'history': continue - if slot in state[d]['semi']: - response.append(state[d]['semi'][slot]) + if slot in state[d]: + response.append(state[d][slot]) break else: response.append(token) @@ -1294,26 +1275,16 @@ class LAVA(Policy): return response def model_predict(self, data_feed): - # TODO use model's forward function, add null vector for the target response self.logprobs = [] logprobs, pred_labels, joint_logpz, sample_y = self.model.forward_rl( data_feed, self.model.config.max_dec_len) - # if len(data_feed['bs']) == 1: - # logprobs = [logprobs] - # for log_prob in logprobs: - # self.logprobs.extend(log_prob) self.logprobs.extend(joint_logpz) pred_labels = np.array( - [pred_labels], dtype=int) # .squeeze(-1).swapaxes(0, 1) + [pred_labels], dtype=int) de_tknize = get_detokenize() - # if pred_labels.shape[1] == self.model.config.max_utt_len: - # pdb.set_trace() pred_str = get_sent(self.model.vocab, de_tknize, pred_labels, 0) - #for b_id in range(pred_labels.shape[0]): - # only one val for pred_str now - # pred_str = get_sent(self.model.vocab, de_tknize, pred_labels, b_id) return pred_str @@ -1336,11 +1307,8 @@ class LAVA(Policy): loss = 0 # estimate the loss using one MonteCarlo rollout - # TODO better loss estimation? - # TODO better update, instead of reinforce? for lp, re in zip(logprobs, rewards): loss -= lp * re - #tmp = self.model.state_dict()['c2z.p_h.weight'].clone() self.opt.zero_grad() if "fp16" in self.config and self.config.fp16: with amp.scale_loss(loss, self.opt) as scaled_loss: @@ -1350,10 +1318,7 @@ class LAVA(Policy): loss.backward() nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip) - #self._print_grad() self.opt.step() - #tmp2 = self.model.state_dict()['c2z.p_h.weight'].clone() - #print(tmp==tmp2) def _print_grad(self): for name, p in self.model.named_parameters(): @@ -1529,7 +1494,8 @@ if __name__ == '__main__': 'history': [['sys', ''], ['user', 'Could you book a 4 stars hotel east of town for one night, 1 person?']]} - cur_model = LAVA() + model_file="path/to/model" # points to model from lava repo + cur_model = LAVA(model_file) response = cur_model.predict(state) # print(response) diff --git a/convlab/policy/mle/loader.py b/convlab/policy/mle/loader.py index bb898ab4a30ad957db632b4eb74fca581f05f05b..9c10d2a7a3841efb55ba418e5cf56708e4e94b7d 100755 --- a/convlab/policy/mle/loader.py +++ b/convlab/policy/mle/loader.py @@ -56,13 +56,50 @@ class PolicyDataVectorizer: state['belief_state'] = data_point['context'][-1]['state'] state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts']) - else: + elif "setsumbt" in str(self.dst): last_system_utt = data_point['context'][-2]['utterance'] if len(data_point['context']) > 1 else '' self.dst.state['history'].append(['sys', last_system_utt]) usr_utt = data_point['context'][-1]['utterance'] state = deepcopy(self.dst.update(usr_utt)) self.dst.state['history'].append(['usr', usr_utt]) + elif "trippy" in str(self.dst): + # Get last system acts and text. + # System acts are used to fill the inform memory. + last_system_acts = [] + last_system_utt = '' + if len(data_point['context']) > 1: + last_system_acts = [] + for act_type in data_point['context'][-2]['dialogue_acts']: + for act in data_point['context'][-2]['dialogue_acts'][act_type]: + value = '' + if 'value' not in act: + if act['intent'] == 'request': + value = '?' + elif act['intent'] == 'inform': + value = 'yes' + else: + value = act['value'] + last_system_acts.append([act['intent'], act['domain'], act['slot'], value]) + last_system_utt = data_point['context'][-2]['utterance'] + + # Get current user acts and text. + # User acts are used for internal evaluation. + usr_acts = [] + for act_type in data_point['context'][-1]['dialogue_acts']: + for act in data_point['context'][-1]['dialogue_acts'][act_type]: + usr_acts.append([act['intent'], act['domain'], act['slot'], act['value'] if 'value' in act else '']) + usr_utt = data_point['context'][-1]['utterance'] + + # Update the state for DST, then update the state via DST. + self.dst.state['system_action'] = last_system_acts + self.dst.state['user_action'] = usr_acts + self.dst.state['history'].append(['sys', last_system_utt]) + self.dst.state['history'].append(['usr', usr_utt]) + state = deepcopy(self.dst.update(usr_utt)) + else: + raise NameError(f"Tracker: {self.dst} not implemented.") + last_system_act = data_point['context'][-2]['dialogue_acts'] if len(data_point['context']) > 1 else {} state['system_action'] = flatten_acts(last_system_act) state['terminated'] = data_point['terminated'] diff --git a/convlab/policy/mle/train.py b/convlab/policy/mle/train.py index c2477760c8a189cbacd978151755fa790996a421..5253f95d9709fe19071ab7e36ed5a7339597bbe9 100755 --- a/convlab/policy/mle/train.py +++ b/convlab/policy/mle/train.py @@ -12,6 +12,7 @@ from convlab.util.custom_util import set_seed, init_logging, save_config from convlab.util.train_util import to_device from convlab.policy.rlmodule import MultiDiscretePolicy from convlab.policy.vector.vector_binary import VectorBinary +from convlab.policy.vector.vector_binary_fuzzy import VectorBinaryFuzzy root_dir = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) @@ -195,8 +196,15 @@ if __name__ == '__main__': use_state_knowledge_uncertainty=dst.return_belief_state_mutual_info) else: vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking) + elif args.dst == "trippy": + dst_args = [arg.split('=', 1) for arg in args.dst_args.split(', ') + if '=' in arg] if args.dst_args is not None else [] + dst_args = {key: eval(value) for key, value in dst_args} + from convlab.dst.trippy import TRIPPY + dst = TRIPPY(**dst_args) + vector = VectorBinaryFuzzy(dataset_name=args.dataset_name, use_masking=args.use_masking) else: - raise NameError(f"Tracker: {args.tracker} not implemented.") + raise NameError(f"Tracker: {args.dst} not implemented.") manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector, dst=dst) agent = MLE_Trainer(manager, vector, cfg) diff --git a/convlab/policy/ppo/trippy_config.json b/convlab/policy/ppo/trippy_config.json new file mode 100644 index 0000000000000000000000000000000000000000..41b1c3623aca944312c6389e55e34d72422fb6e0 --- /dev/null +++ b/convlab/policy/ppo/trippy_config.json @@ -0,0 +1,73 @@ +{ + "model": { + "load_path": "/path/to/model/checkpoint", + "pretrained_load_path": "", + "use_pretrained_initialisation": false, + "batchsz": 1000, + "seed": 0, + "epoch": 50, + "eval_frequency": 5, + "process_num": 2, + "num_eval_dialogues": 500, + "sys_semantic_to_usr": false + }, + "vectorizer_sys": { + "fuzzy_vector_mul": { + "class_path": "convlab.policy.vector.vector_binary_fuzzy.VectorBinaryFuzzy", + "ini_params": { + "use_masking": true, + "manually_add_entity_names": true, + "seed": 0 + } + } + }, + "nlu_sys": {}, + "dst_sys": { + "TripPy": { + "class_path": "convlab.dst.trippy.TRIPPY", + "ini_params": { + "model_type": "roberta", + "model_name": "roberta-base", + "model_path": "/path/to/model/checkpoint", + "dataset_name": "multiwoz21" + } + } + }, + "sys_nlg": { + "TemplateNLG": { + "class_path": "convlab.nlg.template.multiwoz.TemplateNLG", + "ini_params": { + "is_user": false + } + } + }, + "nlu_usr": { + "BERTNLU": { + "class_path": "convlab.nlu.jointBERT.unified_datasets.BERTNLU", + "ini_params": { + "mode": "sys", + "config_file": "multiwoz21_sys_context3.json", + "model_file": "/path/to/model/checkpoint.zip" + } + } + }, + "dst_usr": {}, + "policy_usr": { + "RulePolicy": { + "class_path": "convlab.policy.rule.multiwoz.RulePolicy", + "ini_params": { + "character": "usr" + } + } + }, + "usr_nlg": { + "TemplateNLG": { + "class_path": "convlab.nlg.template.multiwoz.TemplateNLG", + "ini_params": { + "is_user": true, + "label_noise": 0.0, + "text_noise": 0.0 + } + } + } +} diff --git a/convlab/policy/vector/vector_binary_fuzzy.py b/convlab/policy/vector/vector_binary_fuzzy.py new file mode 100755 index 0000000000000000000000000000000000000000..5314cf1303f99f0ac2d0fbc5a2bca34c4e3d74c4 --- /dev/null +++ b/convlab/policy/vector/vector_binary_fuzzy.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +import sys +import numpy as np +from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da +from .vector_binary import VectorBinary + + +class VectorBinaryFuzzy(VectorBinary): + + def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True, + seed=0): + + super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed) + + def dbquery_domain(self, domain): + """ + query entities of specified domain + Args: + domain string: + domain to query + Returns: + entities list: + list of entities of the specified domain + """ + # Get all user constraints + constraints = [[slot, value] for slot, value in self.state[domain].items() if value] \ + if domain in self.state else [] + xx = self.db.query(domain=domain, state=[], soft_contraints=constraints, fuzzy_match_ratio=100, topk=10) + yy = self.db.query(domain=domain, state=constraints, topk=10) + #print("STRICT:", yy) + #print("FUZZY :", xx) + #if len(yy) == 1 and len(xx) > 1: + # import pdb + # pdb.set_trace() + return xx + #return self.db.query(domain=domain, state=[], soft_contraints=constraints, fuzzy_match_ratio=100, topk=10) diff --git a/convlab/util/multiwoz/state.py b/convlab/util/multiwoz/state.py index 5b65ba066b33cf9cd1e2fb49515406d962f03588..116403ba98ea9604f011aa5c9eb1bf03cda2d708 100755 --- a/convlab/util/multiwoz/state.py +++ b/convlab/util/multiwoz/state.py @@ -1,7 +1,14 @@ def default_state(): state = dict(user_action=[], system_action=[], - belief_state={}, + belief_state={ + 'attraction': {'type': '', 'name': '', 'area': ''}, + 'hotel': {'name': '', 'area': '', 'parking': '', 'price range': '', 'stars': '4', 'internet': 'yes', 'type': 'hotel', 'book stay': '', 'book day': '', 'book people': ''}, + 'restaurant': {'food': '', 'price range': '', 'name': '', 'area': '', 'book time': '', 'book day': '', 'book people': ''}, + 'taxi': {'leave at': '', 'destination': '', 'departure': '', 'arrive by': ''}, + 'train': {'leave at': '', 'destination': '', 'day': '', 'arrive by': '', 'departure': '', 'book people': ''}, + 'hospital': {'department': ''} + }, booked={}, request_state={}, terminated=False, diff --git a/data/unified_datasets/multiwoz21/database.py b/data/unified_datasets/multiwoz21/database.py index a23908ee13b460614213ecaad3f747fc791deff3..f599f95acfb9e89efc11572f3fb779f25900604e 100644 --- a/data/unified_datasets/multiwoz21/database.py +++ b/data/unified_datasets/multiwoz21/database.py @@ -57,18 +57,23 @@ class Database(BaseDatabase): for key, val in state: if key == 'department': department = val + if not department: + for key, val in soft_contraints: + if key == 'department': + department = val if not department: return deepcopy(self.dbs['hospital']) else: return [deepcopy(x) for x in self.dbs['hospital'] if x['department'].lower() == department.strip().lower()] state = list(map(lambda ele: (self.slot2dbattr.get(ele[0], ele[0]), ele[1]) if not(ele[0] == 'area' and ele[1] == 'center') else ('area', 'centre'), state)) + soft_contraints = list(map(lambda ele: (self.slot2dbattr.get(ele[0], ele[0]), ele[1]) if not(ele[0] == 'area' and ele[1] == 'center') else ('area', 'centre'), soft_contraints)) found = [] for i, record in enumerate(self.dbs[domain]): constraints_iterator = zip(state, [False] * len(state)) soft_contraints_iterator = zip(soft_contraints, [True] * len(soft_contraints)) for (key, val), fuzzy_match in chain(constraints_iterator, soft_contraints_iterator): - if val in ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"]: + if val in ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care", "do not care"]: pass else: try: diff --git a/examples/agent_examples/test_LAVA.py b/examples/agent_examples/test_LAVA.py index 38379072d5c3611e349925a64cefa2c8b9560be6..d5b592c2e09f8c7f73ebf925abdf52b1a4be24b6 100755 --- a/examples/agent_examples/test_LAVA.py +++ b/examples/agent_examples/test_LAVA.py @@ -8,13 +8,14 @@ import json import random from pprint import pprint from argparse import ArgumentParser -from convlab.nlu.jointBERT.multiwoz import BERTNLU +from convlab.nlu.jointBERT.unified_datasets import BERTNLU +# from convlab.nlu.jointBERT.multiwoz import BERTNLU as BERTNLU_woz # from convlab.nlu.milu.multiwoz import MILU # available DST models from convlab.dst.rule.multiwoz import RuleDST # from convlab.dst.mdbt.multiwoz import MDBT # from convlab.dst.sumbt.multiwoz import SUMBT -from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker +# from convlab.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker # from convlab.dst.trippy.multiwoz import TRIPPY # from convlab.dst.trade.multiwoz import TRADE # from convlab.dst.comer.multiwoz import COMER @@ -31,7 +32,7 @@ from convlab.policy.rule.multiwoz import RulePolicy from convlab.policy.lava.multiwoz import LAVA # available NLG models from convlab.nlg.template.multiwoz import TemplateNLG -from convlab.nlg.sclstm.multiwoz import SCLSTM +# from convlab.nlg.sclstm.multiwoz import SCLSTM # available E2E models # from convlab.e2e.sequicity.multiwoz import Sequicity # from convlab.e2e.damd.multiwoz import Damd @@ -60,7 +61,7 @@ def set_seed(r_seed): def test_end2end(args, model_dir): # BERT nlu if args.dst_type=="bertnlu_rule": - sys_nlu = BERTNLU() + sys_nlu = BERTNLU("user", config_file="multiwoz21_user_context3.json", model_file="bertnlu_unified_multiwoz21_user_context3") elif args.dst_type in ["setsumbt", "trippy"]: sys_nlu = None @@ -79,10 +80,10 @@ def test_end2end(args, model_dir): # where the models are saved from training - lava_dir = "/gpfs/project/lubis/ConvLab-3/convlab/policy/lava/multiwoz/experiments_woz/sys_config_log_model/" + # lava_dir = "/gpfs/project/lubis/ConvLab-3/convlab/policy/lava/multiwoz/experiments_woz/sys_config_log_model/" + lava_dir = "/gpfs/project/lubis/LAVA_code/LAVA_published/experiments_woz/sys_config_log_model/" if "rl" in model_dir: - # lava_path = "{}/{}/reward_best.model".format(lava_dir, model_path[args.lava_model_type]) lava_path = "{}/{}/reward_best.model".format(lava_dir, model_dir) else: # default saved model format @@ -96,14 +97,9 @@ def test_end2end(args, model_dir): # template NLG sys_nlg = None - # assemble - sys_agent = PipelineAgent( - sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') - sys_agent.add_booking_info = False # BERT nlu trained on sys utterance - user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', - model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip') + user_nlu = BERTNLU("sys", config_file="multiwoz21_system_context3_new.json", model_file="bertnlu_unified_multiwoz21_system_context3") if args.US_type == "ABUS": # not use dst user_dst = None @@ -122,9 +118,14 @@ def test_end2end(args, model_dir): user_policy = UserPolicy(user_config) # template NLG user_nlg = TemplateNLG(is_user=True) - # assemble + # assemble agents user_agent = PipelineAgent( user_nlu, user_dst, user_policy, user_nlg, name='user') + sys_agent = PipelineAgent( + sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') + + sys_agent.add_booking_info = False + analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') @@ -132,15 +133,15 @@ def test_end2end(args, model_dir): set_seed(args.seed) model_name = '{}_{}_lava_{}'.format(args.US_type, args.dst_type, model_dir) - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=model_name, total_dialog=1000) + analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=model_name, total_dialog=500) if __name__ == '__main__': parser = ArgumentParser() - parser.add_argument("--lava_dir", type=str, default="2020-05-12-14-51-49-actz_cat") + parser.add_argument("--lava_dir", type=str, default="2020-05-12-14-51-49-actz_cat/rl-2020-05-18-10-50-48") parser.add_argument("--US_trained", type=bool, default=False, help="whether to use model trained on US or not") parser.add_argument("--seed", type=int, default=20200202, help="seed for random processes") parser.add_argument("--US_type", type=str, default="ABUS", help="which user simulator to us, ABUS or TUS") - parser.add_argument("--dst_type", type=str, default="setsumbt", help="which DST to use, bertnlu_rule, setsumbt, or trippy") + parser.add_argument("--dst_type", type=str, default="bertnlu_rule", help="which DST to use, bertnlu_rule, setsumbt, or trippy") args = parser.parse_args() test_end2end(args, args.lava_dir)