diff --git a/convlab/policy/lava/README.md b/convlab/policy/lava/README.md index bd06a73b9d4c2954338d448a7f7291652228fc59..6fd115cd01c6499fbf1b4815d4125685059d8297 100755 --- a/convlab/policy/lava/README.md +++ b/convlab/policy/lava/README.md @@ -1,74 +1,16 @@ ## LAVA: Latent Action Spaces via Variational Auto-encoding for Dialogue Policy Optimization -Codebase for [LAVA: Latent Action Spaces via Variational Auto-encoding for Dialogue Policy Optimization](https://), published as a long paper in COLING 2020. The code is developed based on the implementations of the [LaRL](https://arxiv.org/abs/1902.08858) paper. +ConvLab3 interface for [LAVA: Latent Action Spaces via Variational Auto-encoding for Dialogue Policy Optimization](https://aclanthology.org/2020.coling-main.41/), published as a long paper in COLING 2020. -### Requirements - python 3 - pytorch - numpy - -### Data -The pre-processed MultiWoz 2.0 data is included in data.zip. Unzip the compressed file and access the data under **data/norm-multi-woz**. - -### Over structure: -The implementation of the models, as well as training and evaluation scripts are under **latent_dialog**. -The scripts for running the experiments are under **experiment_woz**. The trained models and evaluation results are under **experiment_woz/sys_config_log_model**. +To train a LAVA model, clone and follow instructions from the [original LAVA repository](https://gitlab.cs.uni-duesseldorf.de/general/dsml/lava-public). -There are 3 types of training to achieve the final model. +With a (pre-)trained LAVA model, it is possible to evaluate or perform online RL with ConvLab3 US by loading the lava module with -### Step 1: Unsupervised training (variational auto-encoding (VAE) task) -Given a dialogue response, the model is tasked to reproduce it via latent variables. With this task we aim to unsupervisedly capture generative factors of dialogue responses. +- from convlab.policy.lava.multiwoz import LAVA - - sl_cat_ae.py: train a VAE model using categorical latent variable - - sl_gauss_ae.py: train a VAE model using continuous (Gaussian) latent variable +and using it as the policy module in the ConvLab pipeline (NLG should be set to None). -### Step 2: Supervised training (response generation task) -The supervised training step of the variational encoder-decoder model could be done 4 different ways. -1. from scratch: +Code example can be found at +- ConvLab-3/examples/agent_examples/test_LAVA.py - - sl_word: train a standard encoder decoder model using supervised learning (SL) - - sl_cat: train a latent action model with categorical latent variables using SL, - - sl_gauss: train a latent action model with continuous latent varaibles using SL, - -2. using the VAE models as pre-trained model (equivalent to LAVA_pt): - - - - finetune_cat_ae: use the VAE with categorical latent variables as weight initialization, and then fine-tune the model on response generation task - - finetune_gauss_ae: as above but with continuous latent variables - - Note: Fine-tuning can be set to be selective (only fine-tune encoder) or not (fine-tune the entire network) using the "selective_finetune" argument in config - -3. using the distribution of the VAE models to obtain informed prior (equivalent to LAVA_kl): - - - - actz_cat: initialized new encoder is combined with pre-trained VAE decoder and fine-tuned on response generation task. VAE encoder is used to obtain an informed prior of the target response and not trained further. - - actz_gauss: as above but with continuous latent variables - -4. or simultaneously from scrath with VAE task in a multi-task fashion (equivalent to LAVA_mt): - - - - mt_cat: train a model to optimize both auto-encoding and response generation in a multi-task fashion, using categorical latent variables - - mt_gauss: as above but with continuous latent variables - -No.1 and 4 can be directly trained without Step 1. No. 2 and 3 requires a pre-trained VAE model, given via a dictionary - - pretrained = {"2020-02-26-18-11-37-sl_cat_ae":100} - -### Step 3: Reinforcement Learning -The model can be further optimized with RL to maximize the dialogue success. - -Each script is used for: - - - reinforce_word: fine tune a pretrained model with word-level policy gradient (PG) - - reinforce_cat: fine tune a pretrained categorical latent action model with latent-level PG. - - reinforce_gauss: fine tune a pretrained gaussian latent action model with latent-level PG. - -The script takes a file containing list of test results from the SL step. - - f_in = "sys_config_log_model/test_files.lst" - - -### Checking the result -The evaluation result can be found at the bottom of the test_file.txt. We provide the best model in this repo. - -NOTE: when re-running the experiments some variance is to be expected in the numbers due to factors such as random seed and hardware specificiations. Some methods are more sensitive to this than others. diff --git a/convlab/policy/lava/multiwoz/latent_dialog/models_task_dev.py b/convlab/policy/lava/multiwoz/latent_dialog/models_task_dev.py deleted file mode 100644 index 78198d74777a6260832e8da0aad197993912f769..0000000000000000000000000000000000000000 --- a/convlab/policy/lava/multiwoz/latent_dialog/models_task_dev.py +++ /dev/null @@ -1,5254 +0,0 @@ -import torch as th -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import Variable -from convlab2.policy.lava.multiwoz.latent_dialog.base_models import BaseModel, frange_cycle_linear -from convlab2.policy.lava.multiwoz.latent_dialog.corpora import SYS, EOS, PAD, BOS, DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX, REQ_TOKENS -from convlab2.policy.lava.multiwoz.latent_dialog.utils import INT, FLOAT, LONG, Pack, cast_type -from convlab2.policy.lava.multiwoz.latent_dialog.enc2dec.encoders import RnnUttEncoder -from convlab2.policy.lava.multiwoz.latent_dialog.enc2dec.decoders import DecoderRNN, GEN, TEACH_FORCE -from convlab2.policy.lava.multiwoz.latent_dialog.criterions import NLLEntropy, CatKLLoss, Entropy, NormKLLoss, GaussianEntropy -from convlab2.policy.lava.multiwoz.latent_dialog import nn_lib -import numpy as np -import pdb -import json - - -class SysPerfectBD2Word(BaseModel): - def __init__(self, corpus, config): - super(SysPerfectBD2Word, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - - self.embedding = None - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - self.policy = nn.Sequential(nn.Linear(self.utt_encoder.output_size + self.db_size + self.bs_size, - config.dec_cell_size), nn.Tanh(), nn.Dropout(config.dropout)) - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=config.dec_cell_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=self.utt_encoder.output_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # pack attention context - if self.config.dec_use_attn: - attn_context = enc_outs - else: - attn_context = None - - # create decoder initial states - dec_init_state = self.policy(th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)).unsqueeze(0) - - # decode - if self.config.dec_rnn_cell == 'lstm': - # h_dec_init_state = utt_summary.squeeze(1).unsqueeze(0) - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - return ret_dict, labels - if return_latent: - return Pack(nll=self.nll(dec_outputs, labels), - latent_action=dec_init_state) - else: - return Pack(nll=self.nll(dec_outputs, labels)) - - def forward_rl(self, data_feed, max_words, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # pack attention context - if self.config.dec_use_attn: - attn_context = enc_outs - else: - attn_context = None - - # create decoder initial states - dec_init_state = self.policy(th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)).unsqueeze(0) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # decode - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=temp) - return logprobs, outs - -class SysPerfectBD2Cat(BaseModel): - def __init__(self, corpus, config): - super(SysPerfectBD2Cat, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - self.k_size = config.k_size - self.y_size = config.y_size - self.simple_posterior = config.simple_posterior - self.contextual_posterior = config.contextual_posterior - - self.embedding = None - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - if "policy_dropout" in config and config.policy_dropout: - if "policy_dropout_rate" in config: - self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size + self.db_size + self.bs_size, - config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) - else: - self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size + self.db_size + self.bs_size, - config.y_size, config.k_size, is_lstm=False) - - else: - self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size, - config.y_size, config.k_size, is_lstm=False) - self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) - self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) - if not self.simple_posterior: - if self.contextual_posterior: - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - config.y_size, config.k_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) - - if "state_for_decoding" not in self.config: - self.state_for_decoding = False - else: - self.state_for_decoding = self.config.state_for_decoding - - if self.state_for_decoding: - dec_hidden_size = config.dec_cell_size + self.utt_encoder.output_size + self.db_size + self.bs_size - else: - dec_hidden_size = config.dec_cell_size - - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=dec_hidden_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": - req_tokens = [] - for d in REQ_TOKENS.keys(): - req_tokens.extend(REQ_TOKENS[d]) - nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. for token in self.vocab])) - print("req tokens assigned with special weights") - if config.use_gpu: - nll_weight = nll_weight.cuda() - self.nll.set_weight(nll_weight) - - self.cat_kl_loss = CatKLLoss() - self.entropy_loss = Entropy() - self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) - self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) - if "kl_annealing" in self.config and config.kl_annealing=="cyclical": - self.beta = frange_cycle_linear(config.n_iter, start=self.config.beta_start, stop=self.config.beta_end, n_cycle=10) - else: - self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 - - if self.use_gpu: - self.log_uniform_y = self.log_uniform_y.cuda() - self.eye = self.eye.cuda() - - def valid_loss(self, loss, batch_cnt=None): - if isinstance(self.beta, float): - beta = self.beta - else: - if batch_cnt == None: - beta = self.beta[-1] - else: - beta = self.beta[int(batch_cnt)] - - - if self.simple_posterior or "kl_annealing" in self.config: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - if self.config.use_mi: - total_loss += (loss.b_pr * beta) - - if self.config.use_diversity: - total_loss += loss.diversity - - return total_loss - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = self.log_uniform_y - else: - logits_py, log_py = self.c2z(enc_last) # p(z|c) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - else: - logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # use prior at inference time, otherwise use posterior - if mode == GEN or (use_py is not None and use_py is True): - sample_y = self.gumbel_connector(logits_py, hard=True) - else: - sample_y = self.gumbel_connector(logits_qy, hard=False) - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.state_for_decoding: - dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) # averaged over all samples - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - - result['diversity'] = th.mean(p) - result['b_pr'] = b_pr - result['mi'] = mi - result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) - return result - - def forward_rl(self, data_feed, max_words, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_py, log_qy = self.c2z(enc_last) - else: - logits_py, log_qy = self.c2z(enc_last) - - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - - # if np.random.rand() < epsilon: # greedy exploration - # print("randomly sampling latent") - # idx = th.multinomial(th.cuda.FloatTensor(qy.shape).uniform_(), 1) - # else: # normal latent sampling - idx = th.multinomial(qy, 1).detach() - - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.state_for_decoding: - dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # decode - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - return logprobs, outs, joint_logpz, sample_y - - def sample_z(self, data_feed, n_z=1, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_py, log_qy = self.c2z(enc_last) - else: - logits_py, log_qy = self.c2z(enc_last) - - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - - zs = [] - logpzs = [] - for i in range(n_z): - idx = th.multinomial(qy, 1).detach() - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - - zs.append(sample_y) - logpzs.append(joint_logpz) - - - return th.stack(zs), th.stack(logpzs) - - def sample_z_with_exploration(self, data_feed, n_z=1, temp=0.1, epsilon=0.05): - #TODO consider deleting this function - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_py, log_qy = self.c2z(enc_last) - else: - logits_py, log_qy = self.c2z(enc_last) - - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - - zs = [] - logpzs = [] - for i in range(n_z): - if np.random.rand() < epsilon: # greedy exploration - idx = th.multinomial(th.cuda.FloatTensor(qy.shape).uniform_(), 1) - else: # normal latent sampling - idx = th.multinomial(qy, 1).detach() - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - - zs.append(sample_y) - logpzs.append(joint_logpz) - - - return th.stack(zs), th.stack(logpzs) - - def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'): - """ - generate response from latent var - """ - - if data_feed: - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - - # pack attention context - if isinstance(sample_y, np.ndarray): - sample_y = self.np2var(sample_y, FLOAT) - - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.state_for_decoding: - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - - dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) - - - #dec_init_state = self.np2var(dec_init_state, FLOAT).unsqueeze(0) - #attn_context = self.np2var(attn_context, FLOAT) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # has to be forward_rl because we don't have the golden target - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=temp) - return logprobs, outs - - def pad_to(self, max_len, tokens, do_pad): - if len(tokens) >= max_len: - # print("cutting off, ", tokens) - return tokens[: max_len-1] + [tokens[-1]] - elif do_pad: - return tokens + [0] * (max_len - len(tokens)) - else: - return tokens - -class SysEncodedBD2Cat(BaseModel): - def __init__(self, corpus, config): - super(SysEncodedBD2Cat, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - self.k_size = config.k_size - self.y_size = config.y_size - self.config = config - self.simple_posterior = config.simple_posterior - self.contextual_posterior = config.contextual_posterior - - self.embedding = None - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - if config.use_metadata_for_decoding: - self.metadata_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=int(config.embed_size / 2), - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=int(config.dec_cell_size / 2), - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - if "policy_dropout" in config and config.policy_dropout: - if "policy_dropout_rate" in config: - self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) - else: - self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - - else: - self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) - self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) - if not self.simple_posterior: - if self.contextual_posterior: - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size * 2, - config.y_size, config.k_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) - - if "state_for_decoding" not in self.config: - self.state_for_decoding = False - else: - self.state_for_decoding = self.config.state_for_decoding - - dec_hidden_size = config.dec_cell_size - if config.use_metadata_for_decoding: - if "metadata_to_decoder" not in config or config.metadata_to_decoder == "concat": - dec_hidden_size += self.metadata_encoder.output_size - if self.state_for_decoding: - dec_hidden_size += self.utt_encoder.output_size - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=dec_hidden_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": - req_tokens = [] - for d in REQ_TOKENS.keys(): - req_tokens.extend(REQ_TOKENS[d]) - nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. for token in self.vocab])) - print("req tokens assigned with special weights") - if config.use_gpu: - nll_weight = nll_weight.cuda() - self.nll.set_weight(nll_weight) - - self.cat_kl_loss = CatKLLoss() - self.entropy_loss = Entropy() - self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) - self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) - - if "kl_annealing" in self.config and config.kl_annealing=="cyclical": - self.beta = frange_cycle_linear(config.n_iter, start=self.config.beta_start, stop=self.config.beta_end, n_cycle=10) - else: - self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 - - - if self.use_gpu: - self.log_uniform_y = self.log_uniform_y.cuda() - self.eye = self.eye.cuda() - - def valid_loss(self, loss, batch_cnt=None): - if isinstance(self.beta, float): - beta = self.beta - else: - if batch_cnt == None: - beta = self.beta[-1] - else: - beta = self.beta[int(batch_cnt % self.config.n_iter)] - - if self.simple_posterior or "kl_annealing" in self.config: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - if self.config.use_mi: - total_loss += (loss.b_pr * beta) - - if self.config.use_diversity: - total_loss += loss.diversity - - return total_loss - - def extract_short_ctx(self, data_feed): - utts = [] - ctx_lens = data_feed['context_lens'] # (batch_size, ) - context = data_feed['contexts'] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(context)): - utt = [] - for t_id in range(ctx_lens[b_id]): - utt.extend(context[b_id][t_id]) - try: - utt.extend(bs[b_id] + db[b_id]) - except: - pdb.set_trace() - utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) - return np.array(utts) - - def extract_metadata(self, data_feed): - utts = [] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(bs)): - utt = [] - if "metadata_db_only" in self.config and self.config.metadata_db_only: - utt.extend(db[b_id]) - else: - utt.extend(bs[b_id] + db[b_id]) - utts.append(self.pad_to(self.config.max_metadata_len, utt, do_pad=True)) - return np.array(utts) - - def pad_to(self, max_len, tokens, do_pad): - if len(tokens) >= max_len: - # print("cutting off, ", tokens) - return tokens[: max_len-1] + [tokens[-1]] - elif do_pad: - return tokens + [0] * (max_len - len(tokens)) - else: - return tokens - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - enc_last = utt_summary.unsqueeze(1) - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = self.log_uniform_y - else: - logits_py, log_py = self.c2z(enc_last) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - logits_qy, log_qy = self.xc2z(th.cat([enc_last.squeeze(), x_h.squeeze(1)], dim=1)) - else: - logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # use prior at inference time, otherwise use posterior - if mode == GEN or (use_py is not None and use_py is True): - sample_y = self.gumbel_connector(logits_py, hard=True) - else: - sample_y = self.gumbel_connector(logits_qy, hard=False) - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - if self.config.use_metadata_for_decoding: - metadata = self.np2var(self.extract_metadata(data_feed), LONG) - metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) - if "metadata_to_decoder" in self.config: - if self.config.metadata_to_decoder == "add": - dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1) - elif self.config.metadata_to_decoder == "avg": - dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - - if self.state_for_decoding: - dec_init_state = th.cat([dec_init_state, th.transpose(enc_last.squeeze(1), 1, 0)], dim=2) - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) # averaged over all samples - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - - result['diversity'] = th.mean(p) - result['b_pr'] = b_pr - result['mi'] = mi - result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) - return result - - def forward_rl(self, data_feed, max_words, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - enc_last = utt_summary.unsqueeze(1) - # create decoder initial states - logits_py, log_qy = self.c2z(enc_last) - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - idx = th.multinomial(qy, 1).detach() - - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - if self.config.use_metadata_for_decoding: - metadata = self.np2var(self.extract_metadata(data_feed), LONG) - metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) - if "metadata_to_decoder" in self.config: - if self.config.metadata_to_decoder == "add": - dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1) - elif self.config.metadata_to_decoder == "avg": - dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - - # decode - if self.state_for_decoding: - dec_init_state = th.cat([dec_init_state, th.transpose(enc_last.squeeze(1), 1, 0)], dim=2) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - return logprobs, outs, joint_logpz, sample_y - - def sample_z(self, data_feed, n_z=1, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db - # metadata = self.np2var(self.extract_metadata(data_feed), LONG) - # out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - # metadata_summary, _, metadata_enc_outs = self.utt_encoder(metadata.unsqueeze(1)) - - - # create decoder initial states - enc_last = utt_summary.unsqueeze(1) - if self.simple_posterior: - logits_py, log_qy = self.c2z(enc_last) - else: - logits_py, log_qy = self.c2z(enc_last) - - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - - zs = [] - logpzs = [] - for i in range(n_z): - idx = th.multinomial(qy, 1).detach() - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - - zs.append(sample_y) - logpzs.append(joint_logpz) - - - return th.stack(zs), th.stack(logpzs) - - def decode_z(self, sample_y, batch_size, max_words=None, temp=1.0, gen_type='greedy'): - """ - generate response from latent var - """ - # pack attention context - metadata = self.np2var(self.extract_metadata(data_feed), LONG) - metadata_summary, _, metadata_enc_outs = self.utt_encoder(metadata.unsqueeze(1)) - - if isinstance(sample_y, np.ndarray): - sample_y = self.np2var(sample_y, FLOAT) - - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - - if self.config.use_metadata_for_decoding: - raise NotImplementedError - - if self.state_for_decoding: - dec_init_state = th.cat([dec_init_state, th.transpose(enc_last.squeeze(1), 1, 0)], dim=2) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # has to be forward_rl because we don't have the golden target - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=temp) - return logprobs, outs - -class SysAECat(BaseModel): - def __init__(self, corpus, config): - super(SysAECat, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - # self.act_size = corpus.act_size - self.k_size = config.k_size - self.y_size = config.y_size - self.simple_posterior = True # minimize kl to uninformed prior instead of dist conditioned by context - self.contextual_posterior = False # does not use context cause AE task - - self.embedding = None - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - if "ae_zero_padding" in self.config and self.config.ae_zero_padding: - # self.use_metadata = self.config.use_metadata - self.ae_zero_padding = self.config.ae_zero_padding - c2z_input_size = self.utt_encoder.output_size + self.db_size + self.bs_size - else: - # self.use_metadata = False - self.ae_zero_padding = False - c2z_input_size = self.utt_encoder.output_size - - - if "policy_dropout" in config and config.policy_dropout: - self.c2z = nn_lib.Hidden2DiscretewDropout(c2z_input_size, - config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) - else: - self.c2z = nn_lib.Hidden2Discrete(c2z_input_size, - config.y_size, config.k_size, is_lstm=False) - - self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) - self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) - # if not self.simple_posterior: #q(z|x,c) - # if self.contextual_posterior: - # # x, c, BS, and DB - # self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - # config.y_size, config.k_size, is_lstm=False) - # else: - # self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=config.dec_cell_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - self.nll = NLLEntropy(self.pad_id, config.avg_type) - if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": - req_tokens = [] - for d in REQ_TOKENS.keys(): - req_tokens.extend(REQ_TOKENS[d]) - nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. for token in self.vocab])) - print("req tokens assigned with special weights") - if config.use_gpu: - nll_weight = nll_weight.cuda() - self.nll.set_weight(nll_weight) - - - self.cat_kl_loss = CatKLLoss() - self.entropy_loss = Entropy() - self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) - self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) - self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 - if self.use_gpu: - self.log_uniform_y = self.log_uniform_y.cuda() - self.eye = self.eye.cuda() - - def valid_loss(self, loss, batch_cnt=None): - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - if self.config.use_mi: - total_loss += (loss.b_pr * self.beta) - - if self.config.use_diversity: - total_loss += loss.diversity - - return total_loss - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - if self.ae_zero_padding: - enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1) - else: - enc_last = utt_summary.squeeze(1) - - - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = self.log_uniform_y - # else: - # logits_py, log_py = self.c2z(enc_last) - # # encode response and use posterior to find q(z|x, c) - # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - # if self.contextual_posterior: - # logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - # else: - # logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # # use prior at inference time, otherwise use posterior - # if mode == GEN or (use_py is not None and use_py is True): - # sample_y = self.gumbel_connector(logits_py, hard=False) - # else: - # sample_y = self.gumbel_connector(logits_qy, hard=True) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) - return result - - def forward_rl(self, data_feed, max_words, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_py, log_qy = self.c2z(enc_last) - else: - logits_py, log_qy = self.c2z(enc_last) - - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - idx = th.multinomial(qy, 1).detach() - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # decode - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - return logprobs, outs, joint_logpz, sample_y - -class SysGroundedAECat(BaseModel): - def __init__(self, corpus, config): - super(SysGroundedAECat, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - # self.act_size = corpus.act_size - self.k_size = config.k_size - self.y_size = config.y_size - self.simple_posterior = True # minimize kl to uninformed prior instead of dist conditioned by context - self.contextual_posterior = False # does not use context cause AE task - - self.embedding = None - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - if config.use_metadata_for_decoding: - self.metadata_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=int(config.embed_size / 2), - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=int(config.dec_cell_size / 2), - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - if "policy_dropout" in config and config.policy_dropout: - self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) - else: - self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) - self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) - # if not self.simple_posterior: #q(z|x,c) # use bs and db grounding as c - # if self.contextual_posterior: - # x, c, BS, and DB - # self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - # config.y_size, config.k_size, is_lstm=False) - # else: - # self.xc2z = nn_lib.Hidden2Discrete(self.metadata_encoder.output_size, config.y_size, config.k_size, is_lstm=False) # prior network conditioned on BS+DB - - if config.use_metadata_for_decoding: - if "metadata_to_decoder" not in config or config.metadata_to_decoder == "concat": - dec_hidden_size = config.dec_cell_size + self.metadata_encoder.output_size - else: - dec_hidden_size = config.dec_cell_size - else: - dec_hidden_size = config.dec_cell_size - - - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=dec_hidden_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - self.nll = NLLEntropy(self.pad_id, config.avg_type) - if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": - req_tokens = [] - for d in REQ_TOKENS.keys(): - req_tokens.extend(REQ_TOKENS[d]) - nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. for token in self.vocab])) - print("req tokens assigned with special weights") - if config.use_gpu: - nll_weight = nll_weight.cuda() - self.nll.set_weight(nll_weight) - - - self.cat_kl_loss = CatKLLoss() - self.entropy_loss = Entropy() - self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) - self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) - self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 - if self.use_gpu: - self.log_uniform_y = self.log_uniform_y.cuda() - self.eye = self.eye.cuda() - - def valid_loss(self, loss, batch_cnt=None): - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - if self.config.use_mi: - total_loss += (loss.b_pr * self.beta) - - if self.config.use_diversity: - total_loss += loss.diversity - - return total_loss - - def pad_to(self, max_len, tokens, do_pad): - if len(tokens) >= max_len: - # print("cutting off {} to {}".format(len(tokens), max_len)) - return tokens[: max_len-1] + [tokens[-1]] - elif do_pad: - return tokens + [0] * (max_len - len(tokens)) - else: - return tokens - - def extract_AE_ctx(self, data_feed): - utts = [] - ctx_lens = data_feed['context_lens'] # (batch_size, ) - context = data_feed['outputs'] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(context)): - utt = [] - utt.extend(context[b_id]) - try: - utt.extend(bs[b_id] + db[b_id]) - except: - pdb.set_trace() - utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) - return np.array(utts) - - def extract_metadata(self, data_feed): - utts = [] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(bs)): - utt = [] - if "metadata_db_only" in self.config and self.config.metadata_db_only: - utt.extend(db[b_id]) - else: - utt.extend(bs[b_id] + db[b_id]) - utts.append(self.pad_to(self.config.max_metadata_len, utt, do_pad=True)) - return np.array(utts) - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - batch_size = len(ctx_lens) - - if self.config.use_metadata_for_encoder: - ctx_utts = self.np2var(self.extract_AE_ctx(data_feed), LONG) # contains bs and db - utt_summary, _, enc_outs = self.utt_encoder(ctx_utts.unsqueeze(1)) - else: - in_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - utt_summary, _, enc_outs = self.utt_encoder(in_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = utt_summary.squeeze(1) - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = self.log_uniform_y - # else: - # logits_py, log_py = self.c2z(enc_last) - # # encode response and use posterior to find q(z|x, c) - # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - # if self.contextual_posterior: - # logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - # else: - # logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # # use prior at inference time, otherwise use posterior - # if mode == GEN or (use_py is not None and use_py is True): - # sample_y = self.gumbel_connector(logits_py, hard=False) - # else: - # sample_y = self.gumbel_connector(logits_qy, hard=True) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - if self.config.use_metadata_for_decoding: - metadata = self.np2var(self.extract_metadata(data_feed), LONG) - metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) - if "metadata_to_decoder" in self.config: - if self.config.metadata_to_decoder == "add": - dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1) - elif self.config.metadata_to_decoder == "avg": - dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) - return result - -class SysGroundedAEGauss(BaseModel): - def __init__(self, corpus, config): - super(SysGroundedAEGauss, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - # self.act_size = corpus.act_size - self.y_size = config.y_size - self.simple_posterior = True # minimize kl to uninformed prior instead of dist conditioned by context - self.contextual_posterior = False # does not use context cause AE task - - self.embedding = None - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - self.metadata_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=int(config.embed_size / 2), - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=int(config.utt_cell_size / 2), - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - if "policy_dropout" in config and config.policy_dropout: - raise NotImplementedError - else: - self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, - config.y_size, is_lstm=False) - self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size, bias=False) - self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) - if not self.simple_posterior: - # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - # config.y_size, is_lstm=False) - if self.contextual_posterior: - self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - config.y_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) - - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=config.dec_cell_size + self.metadata_encoder.output_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - self.nll = NLLEntropy(self.pad_id, config.avg_type) - - self.gauss_kl = NormKLLoss(unit_average=True) - self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) - - def valid_loss(self, loss, batch_cnt=None): - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.config.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - return total_loss - - def pad_to(self, max_len, tokens, do_pad): - if len(tokens) >= max_len: - return tokens[: max_len-1] + [tokens[-1]] - elif do_pad: - return tokens + [0] * (max_len - len(tokens)) - else: - return tokens - - def extract_AE_ctx(self, data_feed): - utts = [] - ctx_lens = data_feed['context_lens'] # (batch_size, ) - context = data_feed['outputs'] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(context)): - utt = [] - utt.extend(context[b_id]) - try: - utt.extend(bs[b_id] + db[b_id]) - except: - pdb.set_trace() - utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) - return np.array(utts) - - def extract_metadata(self, data_feed): - utts = [] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(bs)): - utt = [] - utt.extend(bs[b_id] + db[b_id]) - utts.append(self.pad_to(self.config.max_metadata_len, utt, do_pad=True)) - return np.array(utts) - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - metadata = self.np2var(self.extract_metadata(data_feed), LONG) - batch_size = len(ctx_lens) - - if self.config.use_metadata_for_encoder: - ctx_utts = self.np2var(self.extract_AE_ctx(data_feed), LONG) # contains bs and db - utt_summary, _, enc_outs = self.utt_encoder(ctx_utts.unsqueeze(1)) - else: - in_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - utt_summary, _, enc_outs = self.utt_encoder(in_utts.unsqueeze(1)) - metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = utt_summary.squeeze(1) - if self.simple_posterior: - q_mu, q_logvar = self.c2z(enc_last) - sample_z = self.gauss_connector(q_mu, q_logvar) - p_mu, p_logvar = self.zero, self.zero - # else: - # p_mu, p_logvar = self.c2z(enc_last) - # # encode response and use posterior to find q(z|x, c) - # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - # if self.contextual_posterior: - # q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - # else: - # q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) - - # # use prior at inference time, otherwise use posterior - # if mode == GEN or use_py: - # sample_z = self.gauss_connector(p_mu, p_logvar) - # else: - # sample_z = self.gauss_connector(q_mu, q_logvar) - - # pack attention context - dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) - attn_context = None - # decode - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_z - ret_dict['q_mu'] = q_mu - ret_dict['q_logvar'] = q_logvar - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) - result['pi_kl'] = pi_kl - result['nll'] = self.nll(dec_outputs, labels) - return result - - def gaussian_logprob(self, mu, logvar, sample_z): - var = th.exp(logvar) - constant = float(-0.5 * np.log(2*np.pi)) - logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var) - return logprob - -class SysMTCat(BaseModel): - def __init__(self, corpus, config): - super(SysMTCat, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - # self.act_size = corpus.act_size - self.k_size = config.k_size - self.y_size = config.y_size - self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context - self.contextual_posterior = config.contextual_posterior # does not use context cause AE task - self.shared_train = config.shared_train - - if "use_aux_kl" in config: - self.use_aux_kl = config.use_aux_kl - else: - self.use_aux_kl = False - - self.embedding = None - self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - - self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size, - config.y_size, config.k_size, is_lstm=False) - self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) - self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) - - if not self.simple_posterior: #q(z|x,c) - if self.contextual_posterior: - # x, c, BS, and DB - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=config.dec_cell_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - self.cat_kl_loss = CatKLLoss() - self.entropy_loss = Entropy() - self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) - self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) - self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 - if "aux_pi_beta" in self.config: - self.aux_pi_beta = self.config.aux_pi_beta - else: - self.aux_pi_beta = 1.0 - if self.use_gpu: - self.log_uniform_y = self.log_uniform_y.cuda() - self.eye = self.eye.cuda() - - def valid_loss(self, loss, batch_cnt=None): - if self.shared_train: - if "selective_fine_tune" in self.config and self.config.selective_fine_tune: - total_loss = loss.nll + self.config.beta * loss.aux_pi_kl - else: - total_loss = loss.nll + loss.ae_nll + self.config.aux_pi_beta * loss.aux_pi_kl + self.config.beta * loss.aux_kl - else: - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.config.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - - return total_loss - - def encode_state(self, data_feed): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - return enc_last - - def encode_action(self, data_feed): - batch_size = data_feed.shape[0] - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1)) - - # create decoder initial states - aux_enc_last = aux_utt_summary.squeeze(1) - - return aux_enc_last - - def get_z_via_vae(self, data_feed, hard=False): - batch_size = data_feed.shape[0] - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1)) - - # create decoder initial states - aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1) - - logits_qy, log_qy = self.c2z(aux_enc_last) - aux_sample_z = self.gumbel_connector(logits_qy, hard=hard) - - return aux_sample_z - - def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'): - """ - generate response from latent var - """ - - # if data_feed: - # ctx_lens = data_feed['context_lens'] # (batch_size, ) - # short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - # bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - # db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - - # pack attention context - if isinstance(sample_y, np.ndarray): - sample_y = self.np2var(sample_y, FLOAT) - - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - # if self.state_for_decoding: - # utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - # # create decoder initial states - # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - - # dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) - - - #dec_init_state = self.np2var(dec_init_state, FLOAT).unsqueeze(0) - #attn_context = self.np2var(attn_context, FLOAT) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # has to be forward_rl because we don't have the golden target - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=temp) - return logprobs, outs - - def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - - # how to use z, alone or in combination with bs and db - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=False) - sample_y_discrete = self.gumbel_connector(logits_qy, hard=True) - log_py = self.log_uniform_y - # else: - # logits_py, log_py = self.c2z(enc_last) - # # encode response and use posterior to find q(z|x, c) - # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - # if self.contextual_posterior: - # logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - # else: - # logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # # use prior at inference time, otherwise use posterior - # if mode == GEN or (use_py is not None and use_py is True): - # sample_y = self.gumbel_connector(logits_py, hard=False) - # else: - # sample_y = self.gumbel_connector(logits_qy, hard=True) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - return result - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - short_target_utts = self.np2var(data_feed['outputs'], LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - aux_enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), aux_utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - if self.shared_train: - aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) - aux_sample_y = self.gumbel_connector(aux_logits_qy, hard=mode==GEN) - - log_py = self.log_uniform_y - else: - logits_py, log_py = self.c2z(enc_last) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - else: - logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # use prior at inference time, otherwise use posterior - if mode == GEN or (use_py is not None and use_py is True): - sample_y = self.gumbel_connector(logits_py, hard=False) - else: - sample_y = self.gumbel_connector(logits_qy, hard=True) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - - if self.shared_train: - if self.config.dec_use_attn: - aux_z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - aux_attn_context = [] - aux_temp_sample_y = aux_sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - aux_attn_context.append(th.mm(aux_temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - aux_attn_context = th.cat(aux_attn_context, dim=1) - aux_dec_init_state = th.sum(aux_attn_context, dim=1).unsqueeze(0) - else: - aux_dec_init_state = self.z_embedding(aux_sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - aux_attn_context = None - if self.config.dec_rnn_cell == 'lstm': - aux_dec_init_state = tuple([aux_dec_init_state, aux_dec_init_state]) - - - - # decode - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - if self.shared_train: - ae_dec_outputs, ae_dec_hidden_state, ae_ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=aux_dec_init_state, # tuple: (h, c) - attn_context=aux_attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - result['ae_nll'] = self.nll(ae_dec_outputs, labels) - aux_pi_kl = self.cat_kl_loss(log_qy, aux_log_qy, batch_size, unit_average=True) - aux_kl = self.cat_kl_loss(aux_log_qy, log_py, batch_size, unit_average=True) - result['aux_pi_kl'] = aux_pi_kl - result['aux_kl'] = aux_kl - - - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - return result - - def pad_to(self, max_len, tokens, do_pad): - if len(tokens) >= max_len: - # print("cutting off, ", tokens) - return tokens[: max_len-1] + [tokens[-1]] - elif do_pad: - return tokens + [0] * (max_len - len(tokens)) - else: - return tokens - -class SysGroundedMTCat(BaseModel): - def __init__(self, corpus, config): - super(SysGroundedMTCat, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - # self.act_size = corpus.act_size - self.k_size = config.k_size - self.y_size = config.y_size - self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context - self.contextual_posterior = config.contextual_posterior # does not use context cause AE task - - if "use_aux_kl" in config: - self.use_aux_kl = config.use_aux_kl - else: - self.use_aux_kl = False - - self.embedding = None - self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - if config.use_metadata_for_decoding: - self.metadata_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=int(config.embed_size / 2), - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=int(config.dec_cell_size / 2), - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - - - if "policy_dropout" in config and config.policy_dropout: - self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) - else: - self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - - - self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) - self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) - - if not self.simple_posterior: #q(z|x,c) - if self.contextual_posterior: - # x, c, BS, and DB - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) - - if config.use_metadata_for_decoding: - if "metadata_to_decoder" not in config or config.metadata_to_decoder == "concat": - dec_hidden_size = config.dec_cell_size + self.metadata_encoder.output_size - else: - dec_hidden_size = config.dec_cell_size - else: - dec_hidden_size = config.dec_cell_size - - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=dec_hidden_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - self.cat_kl_loss = CatKLLoss() - self.entropy_loss = Entropy() - self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) - self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) - self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 - if self.use_gpu: - self.log_uniform_y = self.log_uniform_y.cuda() - self.eye = self.eye.cuda() - - def valid_loss(self, loss, batch_cnt=None): - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - if self.config.use_mi: - total_loss += (loss.b_pr * self.beta) - - if self.config.use_diversity: - total_loss += loss.diversity - - if self.use_aux_kl: - try: - total_loss += loss.aux_pi_kl - except KeyError: - total_loss += 0 - - return total_loss - - def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - - ctx_lens = data_feed['context_lens'] # (batch_size, ) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - batch_size = len(ctx_lens) - - if self.config.use_metadata_for_aux_encoder: - ctx_outs = self.np2var(self.extract_AE_ctx(data_feed), LONG) # contains bs and db - utt_summary, _, _ = self.aux_encoder(ctx_outs.unsqueeze(1)) - else: - short_target_utts = self.np2var(data_feed['outputs'], LONG) - utt_summary, _, _ = self.aux_encoder(short_target_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = utt_summary.unsqueeze(1) - - # how to use z, alone or in combination with bs and db - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = self.log_uniform_y - # else: - # logits_py, log_py = self.c2z(enc_last) - # # encode response and use posterior to find q(z|x, c) - # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - # if self.contextual_posterior: - # logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - # else: - # logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # # use prior at inference time, otherwise use posterior - # if mode == GEN or (use_py is not None and use_py is True): - # sample_y = self.gumbel_connector(logits_py, hard=False) - # else: - # sample_y = self.gumbel_connector(logits_qy, hard=True) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - if self.config.use_metadata_for_decoding: - metadata = self.np2var(self.extract_metadata(data_feed), LONG) - metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) - if "metadata_to_decoder" in self.config: - if self.config.metadata_to_decoder == "add": - dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1) - elif self.config.metadata_to_decoder == "avg": - dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - return result - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = utt_summary.unsqueeze(1) - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = self.log_uniform_y - else: - logits_py, log_py = self.c2z(enc_last) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - else: - logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # use prior at inference time, otherwise use posterior - if mode == GEN or (use_py is not None and use_py is True): - sample_y = self.gumbel_connector(logits_py, hard=False) - else: - sample_y = self.gumbel_connector(logits_qy, hard=True) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - if self.config.use_metadata_for_decoding: - metadata = self.np2var(self.extract_metadata(data_feed), LONG) - metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) - if "metadata_to_decoder" in self.config: - if self.config.metadata_to_decoder == "add": - dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1) - elif self.config.metadata_to_decoder == "avg": - dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - - - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) - return result - - def shared_forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - if self.config.use_metadata_for_aux_encoder: - ctx_outs = self.np2var(self.extract_AE_ctx(data_feed), LONG) # contains bs and db - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(ctx_outs.unsqueeze(1)) - else: - short_target_utts = self.np2var(data_feed['outputs'], LONG) - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = utt_summary.unsqueeze(1) - aux_enc_last = aux_utt_summary.unsqueeze(1) - - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = self.log_uniform_y - else: - logits_py, log_py = self.c2z(enc_last) - aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - else: - logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # use prior at inference time, otherwise use posterior - if mode == GEN or (use_py is not None and use_py is True): - sample_y = self.gumbel_connector(logits_py, hard=False) - else: - sample_y = self.gumbel_connector(logits_qy, hard=True) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - if self.config.use_metadata_for_decoding: - metadata = self.np2var(self.extract_metadata(data_feed), LONG) - metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) - if "metadata_to_decoder" in self.config: - if self.config.metadata_to_decoder == "add": - dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1) - elif self.config.metadata_to_decoder == "avg": - dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - - - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - aux_pi_kl = self.cat_kl_loss(log_qy, aux_log_qy, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - result['aux_pi_kl'] = aux_pi_kl - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - return result - - def extract_metadata(self, data_feed): - utts = [] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(bs)): - utt = [] - if "metadata_db_only" in self.config and self.config.metadata_db_only: - utt.extend(db[b_id]) - else: - utt.extend(bs[b_id] + db[b_id]) - utts.append(self.pad_to(self.config.max_metadata_len, utt, do_pad=True)) - return np.array(utts) - - def extract_AE_ctx(self, data_feed): - utts = [] - ctx_lens = data_feed['context_lens'] # (batch_size, ) - context = data_feed['outputs'] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(context)): - utt = [] - utt.extend(context[b_id]) - try: - utt.extend(bs[b_id] + db[b_id]) - except: - pdb.set_trace() - utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) - return np.array(utts) - - def extract_short_ctx(self, data_feed): - utts = [] - ctx_lens = data_feed['context_lens'] # (batch_size, ) - context = data_feed['contexts'] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(context)): - utt = [] - for t_id in range(ctx_lens[b_id]): - utt.extend(context[b_id][t_id]) - try: - utt.extend(bs[b_id] + db[b_id]) - except: - pdb.set_trace() - utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) - return np.array(utts) - - def pad_to(self, max_len, tokens, do_pad): - if len(tokens) >= max_len: - return tokens[: max_len-1] + [tokens[-1]] - elif do_pad: - return tokens + [0] * (max_len - len(tokens)) - else: - return tokens - -class SysActZCat(BaseModel): - def __init__(self, corpus, config): - super(SysActZCat, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - # self.act_size = corpus.act_size - self.k_size = config.k_size - self.y_size = config.y_size - self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context - self.contextual_posterior = config.contextual_posterior # does not use context cause AE task - - if "use_aux_kl" in config: - self.use_aux_kl = config.use_aux_kl - else: - self.use_aux_kl = False - - if "use_aux_c2z" in config: - self.use_aux_c2z = config.use_aux_c2z - else: - self.use_aux_c2z = False - - - - self.embedding = None - self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - # if "policy_dropout" in config and config.policy_dropout: - # self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size + self.db_size + self.bs_size, - # config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) - # else: - self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size, - config.y_size, config.k_size, is_lstm=False) - if self.use_aux_c2z: - self.aux_c2z = nn_lib.Hidden2Discrete(self.aux_encoder.output_size, config.y_size, config.k_size, is_lstm=False) - - - self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) - self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) - - if not self.simple_posterior: #q(z|x,c) - if self.contextual_posterior: - # x, c, BS, and DB - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=config.dec_cell_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": - req_tokens = [] - for d in REQ_TOKENS.keys(): - req_tokens.extend(REQ_TOKENS[d]) - nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. for token in self.vocab])) - print("req tokens assigned with special weights") - if config.use_gpu: - nll_weight = nll_weight.cuda() - self.nll.set_weight(nll_weight) - - self.cat_kl_loss = CatKLLoss() - self.entropy_loss = Entropy() - self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) - self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) - self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 - if self.use_gpu: - self.log_uniform_y = self.log_uniform_y.cuda() - self.eye = self.eye.cuda() - - def valid_loss(self, loss, batch_cnt=None): - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - if self.config.use_mi: - total_loss += (loss.b_pr * self.beta) - - if self.config.use_diversity: - total_loss += loss.diversity - - return total_loss - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - short_target_utts = self.np2var(data_feed['outputs'], LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - if self.use_aux_c2z: - aux_logits_qy, aux_log_qy = self.aux_c2z(aux_utt_summary.squeeze(1)) - else: - aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) - - aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = aux_log_qy - else: - logits_py, log_py = self.c2z(enc_last) - aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - else: - logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # use prior at inference time, otherwise use posterior - if mode == GEN or (use_py is not None and use_py is True): - sample_y = self.gumbel_connector(logits_py, hard=True) - else: - sample_y = self.gumbel_connector(logits_qy, hard=False) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) - return result - - def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1) - - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = self.log_uniform_y - # else: - # p_mu, p_logvar = self.c2z(enc_last) - # # encode response and use posterior to find q(z|x, c) - # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - # if self.contextual_posterior: - # q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - # else: - # q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) - - # # use prior at inference time, otherwise use posterior - # if mode == GEN or use_py: - # sample_z = self.gauss_connector(p_mu, p_logvar) - # else: - # sample_z = self.gauss_connector(q_mu, q_logvar) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) - return result - - def get_z_via_vae(self, data_feed, hard=False): - batch_size = data_feed.shape[0] - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1)) - - # create decoder initial states - aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1) - - logits_qy, log_qy = self.c2z(aux_enc_last) - aux_sample_z = self.gumbel_connector(logits_qy, hard=hard) - - return aux_sample_z, logits_qy, log_qy - - def forward_rl(self, data_feed, max_words, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_py, log_qy = self.c2z(enc_last) - else: - logits_py, log_qy = self.c2z(enc_last) - - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - idx = th.multinomial(qy, 1).detach() - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # decode - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - return logprobs, outs, joint_logpz, sample_y - - def sample_z(self, data_feed, n_z=1, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_py, log_qy = self.c2z(enc_last) - else: - logits_py, log_qy = self.c2z(enc_last) - - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - - zs = [] - logpzs = [] - for i in range(n_z): - idx = th.multinomial(qy, 1).detach() - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - - zs.append(sample_y) - logpzs.append(joint_logpz) - - - return th.stack(zs), th.stack(logpzs) - - def sample_z_with_exploration(self, data_feed, n_z=1, temp=0.1, epsilon=0.05): - #TODO consider deleting this function - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_py, log_qy = self.c2z(enc_last) - else: - logits_py, log_qy = self.c2z(enc_last) - - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - - zs = [] - logpzs = [] - for i in range(n_z): - if np.random.rand() < epsilon: # greedy exploration - idx = th.multinomial(th.cuda.FloatTensor(qy.shape).uniform_(), 1) - else: # normal latent sampling - idx = th.multinomial(qy, 1).detach() - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - - zs.append(sample_y) - logpzs.append(joint_logpz) - - - return th.stack(zs), th.stack(logpzs) - - def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'): - """ - generate response from latent var - """ - # pack attention context - - if isinstance(sample_y, np.ndarray): - sample_y = self.np2var(sample_y, FLOAT) - - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - - # decode - - #dec_init_state = self.np2var(dec_init_state, FLOAT).unsqueeze(0) - #attn_context = self.np2var(attn_context, FLOAT) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # has to be forward_rl because we don't have the golden target - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=temp) - - return logprobs, outs - - def pad_to(self, max_len, tokens, do_pad): - if len(tokens) >= max_len: - # print("cutting off, ", tokens) - return tokens[: max_len-1] + [tokens[-1]] - elif do_pad: - return tokens + [0] * (max_len - len(tokens)) - else: - return tokens - -class SysGroundedActZCat(BaseModel): - def __init__(self, corpus, config): - super(SysGroundedActZCat, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - # self.act_size = corpus.act_size - self.k_size = config.k_size - self.y_size = config.y_size - self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context - self.contextual_posterior = config.contextual_posterior # does not use context cause AE task - - if "use_aux_kl" in config: - self.use_aux_kl = config.use_aux_kl - else: - self.use_aux_kl = False - - self.embedding = None - self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - if config.use_metadata_for_decoding: - self.metadata_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=int(config.embed_size / 2), - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=int(config.dec_cell_size / 2), - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - - - if "policy_dropout" in config and config.policy_dropout: - self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) - else: - self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - - self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) - self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) - - if not self.simple_posterior: #q(z|x,c) - if self.contextual_posterior: - # x, c, BS, and DB - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) - if config.use_metadata_for_decoding: - if "metadata_to_decoder" not in config or config.metadata_to_decoder == "concat": - dec_hidden_size = config.dec_cell_size + self.metadata_encoder.output_size - else: - dec_hidden_size = config.dec_cell_size - else: - dec_hidden_size = config.dec_cell_size - - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=dec_hidden_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": - req_tokens = [] - for d in REQ_TOKENS.keys(): - req_tokens.extend(REQ_TOKENS[d]) - nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. for token in self.vocab])) - print("req tokens assigned with special weights") - if config.use_gpu: - nll_weight = nll_weight.cuda() - self.nll.set_weight(nll_weight) - - self.cat_kl_loss = CatKLLoss() - self.entropy_loss = Entropy() - self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) - self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) - self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 - if self.use_gpu: - self.log_uniform_y = self.log_uniform_y.cuda() - self.eye = self.eye.cuda() - - def extract_short_ctx(self, data_feed): - utts = [] - ctx_lens = data_feed['context_lens'] # (batch_size, ) - context = data_feed['contexts'] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(context)): - utt = [] - for t_id in range(ctx_lens[b_id]): - utt.extend(context[b_id][t_id]) - try: - utt.extend(bs[b_id] + db[b_id]) - except: - pdb.set_trace() - utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) - return np.array(utts) - - def extract_metadata(self, data_feed): - utts = [] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(bs)): - utt = [] - if "metadata_db_only" in self.config and self.config.metadata_db_only: - utt.extend(db[b_id]) - else: - utt.extend(bs[b_id] + db[b_id]) - utts.append(self.pad_to(self.config.max_metadata_len, utt, do_pad=True)) - return np.array(utts) - - def extract_AE_ctx(self, data_feed): - utts = [] - ctx_lens = data_feed['context_lens'] # (batch_size, ) - context = data_feed['outputs'] - bs = data_feed['bs'] - db = data_feed['db'] - if not isinstance(bs, list): - bs = data_feed['bs'].tolist() - db = data_feed['db'].tolist() - - for b_id in range(len(context)): - utt = [] - utt.extend(context[b_id]) - try: - utt.extend(bs[b_id] + db[b_id]) - except: - pdb.set_trace() - utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) - return np.array(utts) - - def pad_to(self, max_len, tokens, do_pad): - if len(tokens) >= max_len: - return tokens[: max_len-1] + [tokens[-1]] - elif do_pad: - return tokens + [0] * (max_len - len(tokens)) - else: - return tokens - - def valid_loss(self, loss, batch_cnt=None): - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - if self.config.use_mi: - total_loss += (loss.b_pr * self.beta) - - if self.config.use_diversity: - total_loss += loss.diversity - - return total_loss - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db - metadata = self.np2var(self.extract_metadata(data_feed), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - if self.config.use_metadata_for_aux_encoder: - ctx_outs = self.np2var(self.extract_AE_ctx(data_feed), LONG) # contains bs and db - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(ctx_outs.unsqueeze(1)) - else: - short_target_utts = self.np2var(data_feed['outputs'], LONG) - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = utt_summary.unsqueeze(1) - aux_enc_last = aux_utt_summary.unsqueeze(1) - # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = aux_log_qy - else: - logits_py, log_py = self.c2z(enc_last) - aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - else: - logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # use prior at inference time, otherwise use posterior - if mode == GEN or (use_py is not None and use_py is True): - sample_y = self.gumbel_connector(logits_py, hard=False) - else: - sample_y = self.gumbel_connector(logits_qy, hard=True) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - if self.config.use_metadata_for_decoding: - metadata = self.np2var(self.extract_metadata(data_feed), LONG) - metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) - if "metadata_to_decoder" in self.config: - if self.config.metadata_to_decoder == "add": - dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1) - elif self.config.metadata_to_decoder == "avg": - dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) - return result - - def forward_rl(self, data_feed, max_words, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - # pdb.set_trace() - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - enc_last = utt_summary.unsqueeze(1) - # create decoder initial states - logits_py, log_qy = self.c2z(enc_last) - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - idx = th.multinomial(qy, 1).detach() - - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - if self.config.use_metadata_for_decoding: - metadata = self.np2var(self.extract_metadata(data_feed), LONG) - metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) - if "metadata_to_decoder" in self.config: - if self.config.metadata_to_decoder == "add": - dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1) - elif self.config.metadata_to_decoder == "avg": - dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - else: - dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) - - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - return logprobs, outs, joint_logpz, sample_y - -class SysE2ECat(BaseModel): - def __init__(self, corpus, config): - super(SysE2ECat, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - self.k_size = config.k_size - self.y_size = config.y_size - self.simple_posterior = config.simple_posterior - self.contextual_posterior = config.contextual_posterior - - self.embedding = None - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) - self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) - if not self.simple_posterior: - if self.contextual_posterior: - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - config.y_size, config.k_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=config.dec_cell_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - - self.cat_kl_loss = CatKLLoss() - self.entropy_loss = Entropy() - self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) - self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) - self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 - if self.use_gpu: - self.log_uniform_y = self.log_uniform_y.cuda() - self.eye = self.eye.cuda() - - def valid_loss(self, loss, batch_cnt=None): - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - if self.config.use_mi: - total_loss += (loss.b_pr * self.beta) - - if self.config.use_diversity: - total_loss += loss.diversity - - return total_loss - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - enc_last = utt_summary.squeeze(1) - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = self.log_uniform_y - else: - logits_py, log_py = self.c2z(enc_last) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - else: - logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # use prior at inference time, otherwise use posterior - if mode == GEN or (use_py is not None and use_py is True): - sample_y = self.gumbel_connector(logits_py, hard=False) - else: - sample_y = self.gumbel_connector(logits_qy, hard=True) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) - return result - - def forward_rl(self, data_feed, max_words, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - enc_last = utt_summary.squeeze(1) - # create decoder initial states - if self.simple_posterior: - logits_py, log_qy = self.c2z(enc_last) - else: - logits_py, log_qy = self.c2z(enc_last) - - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - idx = th.multinomial(qy, 1).detach() - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # decode - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - return logprobs, outs, joint_logpz, sample_y - -class SysE2EActZCat(BaseModel): - def __init__(self, corpus, config): - super(SysE2EActZCat, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - self.k_size = config.k_size - self.y_size = config.y_size - self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context - self.contextual_posterior = config.contextual_posterior # does not use context cause AE task - if "use_act_label" in config: - self.use_act_label = config.use_act_label - else: - self.use_act_label = False - - if "use_aux_c2z" in config: - self.use_aux_c2z = config.use_aux_c2z - else: - self.use_aux_c2z = False - - self.embedding = None - self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - - if self.use_act_label: - self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.act_size, config.y_size, config.k_size, is_lstm=False) - else: - self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - if self.use_aux_c2z: - self.aux_c2z = nn_lib.Hidden2Discrete(self.aux_encoder.output_size, config.y_size, config.k_size, is_lstm=False) - self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) - self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) - - if not self.simple_posterior: #q(z|x,c) - if self.contextual_posterior: - # x, c, BS, and DB - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, - config.y_size, config.k_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=config.dec_cell_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - self.cat_kl_loss = CatKLLoss() - self.entropy_loss = Entropy() - self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) - self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) - self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 - if self.use_gpu: - self.log_uniform_y = self.log_uniform_y.cuda() - self.eye = self.eye.cuda() - - def valid_loss(self, loss, batch_cnt=None): - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - if self.config.use_mi: - total_loss += (loss.b_pr * self.beta) - - if self.config.use_diversity: - total_loss += loss.diversity - - return total_loss - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - short_target_utts = self.np2var(data_feed['outputs'], LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = utt_summary.squeeze(1) - aux_enc_last = aux_utt_summary.squeeze(1) - # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) - # create decoder initial states - if self.simple_posterior: - logits_qy, log_qy = self.c2z(enc_last) - if self.use_aux_c2z: - aux_logits_qy, aux_log_qy = self.aux_c2z(aux_utt_summary.squeeze(1)) - else: - aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) - sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) - log_py = aux_log_qy - else: - logits_py, log_py = self.c2z(enc_last) - aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - else: - logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) - - # use prior at inference time, otherwise use posterior - if mode == GEN or (use_py is not None and use_py is True): - sample_y = self.gumbel_connector(logits_py, hard=False) - else: - sample_y = self.gumbel_connector(logits_qy, hard=True) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - # (batch_size, response_size-1) - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - # (batch_size, max_ctx_len, ctx_cell_size) - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_y - ret_dict['log_qy'] = log_qy - return ret_dict, labels - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - # regularization qy to be uniform - avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) - avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) - b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) - mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) - pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) - q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b - p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) - - result['pi_kl'] = pi_kl - result['diversity'] = th.mean(p) - result['nll'] = self.nll(dec_outputs, labels) - result['b_pr'] = b_pr - result['mi'] = mi - return result - - def forward_rl(self, data_feed, max_words, temp=0.1, enc="utt"): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - batch_size = len(ctx_lens) - if enc == "utt": - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - # create decoder initial states - enc_last = utt_summary.squeeze(1) - # create decoder initial states - if self.simple_posterior: - logits_py, log_qy = self.c2z(enc_last) - else: - logits_py, log_qy = self.c2z(enc_last) - - elif enc == "aux": - short_target_utts = self.np2var(data_feed['outputs'], LONG) - # short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['outputs'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) - if self.simple_posterior: - if self.use_aux_c2z: - aux_logits_qy, aux_log_qy = self.aux_c2z(aux_utt_summary.squeeze(1)) - else: - aux_enc_last = aux_utt_summary.squeeze(1) - aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) - logits_py = aux_logits_qy - log_qy = aux_log_qy - - - qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) - log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) - idx = th.multinomial(qy, 1).detach() - logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) - joint_logpz = th.sum(logprob_sample_z, dim=1) - sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) - sample_y.scatter_(1, idx, 1.0) - - # pack attention context - if self.config.dec_use_attn: - z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) - attn_context = [] - temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) - for z_id in range(self.y_size): - attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) - attn_context = th.cat(attn_context, dim=1) - dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) - else: - dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # decode - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - return logprobs, outs, joint_logpz, sample_y - -class SysPerfectBD2Gauss(BaseModel): - def __init__(self, corpus, config): - super(SysPerfectBD2Gauss, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - self.y_size = config.y_size - self.simple_posterior = config.simple_posterior - if "contextual posterior" in config: - self.contextual_posterior = config.contextual_posterior - else: - self.contextual_posterior = True # default value is true, i.e. q(z|x,c) - - self.embedding = None - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size, - config.y_size, is_lstm=False) - self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) - self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) - if not self.simple_posterior: - # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - # config.y_size, is_lstm=False) - if self.contextual_posterior: - self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - config.y_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) - - if "state_for_decoding" not in self.config: - self.state_for_decoding = False - else: - self.state_for_decoding = self.config.state_for_decoding - - if self.state_for_decoding: - dec_hidden_size = config.dec_cell_size + self.utt_encoder.output_size + self.db_size + self.bs_size - else: - dec_hidden_size = config.dec_cell_size - - - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=dec_hidden_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - - self.gauss_kl = NormKLLoss(unit_average=True) - self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) - - def valid_loss(self, loss, batch_cnt=None): - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.config.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - return total_loss - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - - # create decoder initial states - if self.simple_posterior: - q_mu, q_logvar = self.c2z(enc_last) - sample_z = self.gauss_connector(q_mu, q_logvar) - p_mu, p_logvar = self.zero, self.zero - else: - p_mu, p_logvar = self.c2z(enc_last) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - else: - q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) - - # use prior at inference time, otherwise use posterior - if mode == GEN or use_py: - sample_z = self.gauss_connector(p_mu, p_logvar) - else: - sample_z = self.gauss_connector(q_mu, q_logvar) - - # pack attention context - dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) - attn_context = None - - # decode - if self.state_for_decoding: - dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_z - ret_dict['q_mu'] = q_mu - ret_dict['q_logvar'] = q_logvar - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) - result['pi_kl'] = pi_kl - result['nll'] = self.nll(dec_outputs, labels) - return result - - def gaussian_logprob(self, mu, logvar, sample_z): - var = th.exp(logvar) - constant = float(-0.5 * np.log(2*np.pi)) - logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var) - return logprob - - def forward_rl(self, data_feed, max_words, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # create decoder initial states - p_mu, p_logvar = self.c2z(enc_last) - - sample_z = th.normal(p_mu, th.sqrt(th.exp(p_logvar))).detach() - logprob_sample_z = self.gaussian_logprob(p_mu, self.zero, sample_z) - joint_logpz = th.sum(logprob_sample_z, dim=1) - - # pack attention context - dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) - attn_context = None - - # decode - if self.state_for_decoding: - dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # decode - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - return logprobs, outs, joint_logpz, sample_z - - def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'): - """ - generate response from latent var - """ - - if data_feed: - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - - # pack attention context - if isinstance(sample_y, np.ndarray): - sample_y = self.np2var(sample_y, FLOAT) - - dec_init_state = self.z_embedding(sample_y.unsqueeze(0)) - if (dec_init_state != dec_init_state).any(): - pdb.set_trace() - attn_context = None - - # decode - if self.state_for_decoding: - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # decode - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - - return logprobs, outs - - def pad_to(self, max_len, tokens, do_pad): - if len(tokens) >= max_len: - # print("cutting off, ", tokens) - return tokens[: max_len-1] + [tokens[-1]] - elif do_pad: - return tokens + [0] * (max_len - len(tokens)) - else: - return tokens - -class SysAEGauss(BaseModel): - def __init__(self, corpus, config): - super(SysAEGauss, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - # self.act_size = corpus.act_size - self.y_size = config.y_size - self.simple_posterior = True - self.contextual_posterior = False - - self.embedding = None - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - # if "use_metadata" in self.config and self.config.use_metadata: - if "ae_zero_padding" in self.config and self.config.ae_zero_padding: - # self.use_metadata = self.config.use_metadata - self.ae_zero_padding = self.config.ae_zero_padding - c2z_input_size = self.utt_encoder.output_size + self.db_size + self.bs_size - else: - # self.use_metadata = False - self.ae_zero_padding = False - c2z_input_size = self.utt_encoder.output_size - - self.c2z = nn_lib.Hidden2Gaussian(c2z_input_size, - config.y_size, is_lstm=False) - - self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) - - self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) - if not self.simple_posterior: - # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - # config.y_size, is_lstm=False) - if self.contextual_posterior: - self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - config.y_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) - - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=config.dec_cell_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - - self.gauss_kl = NormKLLoss(unit_average=True) - self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) - - - if "kl_annealing" in self.config and config.kl_annealing=="cyclical": - if "n_iter" not in self.config: - config['n_iter'] = config.ckpt_step * config.max_epoch - self.beta = frange_cycle_linear(config.n_iter, start=self.config.beta_start, stop=self.config.beta_end, n_cycle=10) - else: - self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 - - def valid_loss(self, loss, batch_cnt=None): - if isinstance(self.beta, float): - beta = self.beta - else: - if batch_cnt == None: - beta = self.beta[-1] - else: - beta = self.beta[int(batch_cnt)] - - - if self.simple_posterior or "kl_annealing" in self.config: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - return total_loss - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - # print(short_ctx_utts[0]) - # print(out_utts[0]) - - - # create decoder initial states - # if self.use_metadata: - if self.ae_zero_padding: - enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1) - else: - enc_last = utt_summary.squeeze(1) - - # create decoder initial states - if self.simple_posterior: - q_mu, q_logvar = self.c2z(enc_last) - sample_z = self.gauss_connector(q_mu, q_logvar) - p_mu, p_logvar = self.zero, self.zero - # else: - # p_mu, p_logvar = self.c2z(enc_last) - # # encode response and use posterior to find q(z|x, c) - # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - # if self.contextual_posterior: - # q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - # else: - # q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) - - # # use prior at inference time, otherwise use posterior - # if mode == GEN or use_py: - # sample_z = self.gauss_connector(p_mu, p_logvar) - # else: - # sample_z = self.gauss_connector(q_mu, q_logvar) - - # pack attention context - dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_z - ret_dict['q_mu'] = q_mu - ret_dict['q_logvar'] = q_logvar - # print(labels[0]) - # print("========") - # pdb.set_trace() - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) - result['pi_kl'] = pi_kl - result['nll'] = self.nll(dec_outputs, labels) - return result - - def gaussian_logprob(self, mu, logvar, sample_z): - var = th.exp(logvar) - constant = float(-0.5 * np.log(2*np.pi)) - logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var) - return logprob - -class SysMTGauss(BaseModel): - def __init__(self, corpus, config): - super(SysMTGauss, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - self.y_size = config.y_size - self.simple_posterior = config.simple_posterior - self.contextual_posterior = config.contextual_posterior - if "shared_train" in config: - self.shared_train = config.shared_train - else: - self.shared_train = False - - if "use_aux_kl" in config: - self.use_aux_kl = config.use_aux_kl - else: - self.use_aux_kl = False - - - self.embedding = None - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size, - config.y_size, is_lstm=False) - # if self.shared_train: - # self.aux_c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, - # config.y_size, is_lstm=False) - - self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) - self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) - if not self.simple_posterior: - # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - # config.y_size, is_lstm=False) - if self.contextual_posterior: - self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - config.y_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) - - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=config.dec_cell_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - if "priornet_config_path" in config and self.config.priornet_config_path is not None: - self._init_priornet(corpus) - else: - self.priornet = None - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - self.entropy_loss = GaussianEntropy() - - self.gauss_kl = NormKLLoss(unit_average=True) - self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) - - # self.aux_pi_beta = self.config.aux_pi_beta if hasattr(self.config, 'aux_pi_beta') else 1.0 - # if hasattr(self.config, 'aux_pi_beta'): - if "aux_pi_beta" in self.config: - self.aux_pi_beta = self.config.aux_pi_beta - else: - self.aux_pi_beta = 1.0 - - - def _init_priornet(self, corpus): - priornet_config = Pack(json.load(open(self.config.priornet_config_path))) - - if "actz" in self.config.priornet_config_path: - self.priornet = SysActZGauss(corpus, priornet_config) - else: - self.priornet = SysMTGauss(corpus, priornet_config) - - priornet_model_dict = th.load(self.config.priornet_model_path, map_location=lambda storage, location: storage) - self.priornet.load_state_dict(priornet_model_dict) - - for p in self.priornet.parameters(): - p.requires_grad=False - - - def valid_loss(self, loss, batch_cnt=None): - if self.shared_train: - if "selective_fine_tune" in self.config and self.config.selective_fine_tune: - total_loss = loss.nll + self.config.beta * loss.aux_pi_kl - else: - total_loss = loss.nll + loss.ae_nll + self.aux_pi_beta * loss.aux_pi_kl + self.config.beta * loss.aux_kl - else: - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.config.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - - return total_loss - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - short_target_utts = self.np2var(data_feed['outputs'], LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # aux_enc_last = aux_utt_summary.squeeze(1) - aux_enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), aux_utt_summary.squeeze(1)], dim=1) - - # create decoder initial states - if self.simple_posterior: - q_mu, q_logvar = self.c2z(enc_last) - sample_z = self.gauss_connector(q_mu, q_logvar) - if self.shared_train: - # aux_q_mu, aux_q_logvar = self.aux_c2z(aux_enc_last) - aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last) - aux_sample_z = self.gauss_connector(aux_q_mu, aux_q_logvar) - if self.priornet is not None: - _, p_mu, p_logvar = self.priornet.get_z_via_rg(data_feed) - else: - p_mu, p_logvar = self.zero, self.zero - else: - p_mu, p_logvar = self.c2z(enc_last) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - else: - q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) - - aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last) - - # use prior at inference time, otherwise use posterior - if mode == GEN or use_py: - sample_z = self.gauss_connector(p_mu, p_logvar) - else: - sample_z = self.gauss_connector(q_mu, q_logvar) - - # pack attention context - dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) - if self.shared_train: - aux_dec_init_state = self.z_embedding(aux_sample_z.unsqueeze(0)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - if self.shared_train: - aux_dec_init_state = tuple([aux_dec_init_state, aux_dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_z - ret_dict['q_mu'] = q_mu - ret_dict['q_logvar'] = q_logvar - return ret_dict, labels - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - if self.shared_train: - ae_dec_outputs, ae_hidden_state, ae_ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - dec_init_state=aux_dec_init_state, # tuple: (h, c) - attn_context=attn_context, - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - result['ae_nll'] = self.nll(ae_dec_outputs, labels) - aux_pi_kl = self.gauss_kl(q_mu, q_logvar, aux_q_mu, aux_q_logvar) - aux_kl = self.gauss_kl(aux_q_mu, aux_q_logvar, p_mu, p_logvar) - result['aux_pi_kl'] = aux_pi_kl - result['aux_kl'] = aux_kl - # result['aux_entropy'] = self.entropy_loss(aux_q_mu, aux_q_logvar) - - - pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) - result['pi_kl'] = pi_kl - # result['pi_entropy'] = self.entropy_loss(q_mu, q_logvar) - result['nll'] = self.nll(dec_outputs, labels) - return result - - def encode_state(self, data_feed): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - return enc_last - - def encode_action(self, data_feed): - batch_size = data_feed.shape[0] - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1)) - - # create decoder initial states - aux_enc_last = aux_utt_summary.squeeze(1) - - return aux_enc_last - - def get_z_via_vae(self, responses): - batch_size = responses.shape[0] - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(responses.unsqueeze(1)) - - # create decoder initial states - aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1) - - aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last) - aux_sample_z = self.gauss_connector(aux_q_mu, aux_q_logvar) - - return aux_sample_z, aux_q_mu, aux_q_logvar - - def get_z_via_rg(self, data_feed): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - q_mu, q_logvar = self.c2z(enc_last) - - sample_z = self.gauss_connector(q_mu, q_logvar) - - return sample_z, q_mu, q_logvar - - - def gaussian_logprob(self, mu, logvar, sample_z): - var = th.exp(logvar) - constant = float(-0.5 * np.log(2*np.pi)) - logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var) - return logprob - - return logprobs, outs, joint_logpz, sample_z - - def forward_rl(self, data_feed, max_words, temp=0.1): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # create decoder initial states - p_mu, p_logvar = self.c2z(enc_last) - - # sample_z = th.normal(p_mu, th.sqrt(th.exp(p_logvar))).detach() - sample_z = self.gauss_connector(p_mu, p_logvar) - logprob_sample_z = self.gaussian_logprob(p_mu, self.zero, sample_z) - joint_logpz = th.sum(logprob_sample_z, dim=1) - - # pack attention context - dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) - attn_context = None - - # decode - # if self.state_for_decoding: - # dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # decode - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - return logprobs, outs, joint_logpz, sample_z - - def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'): - """ - generate response from latent var - """ - - if data_feed: - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - - # pack attention context - if isinstance(sample_y, np.ndarray): - sample_y = self.np2var(sample_y, FLOAT) - - dec_init_state = self.z_embedding(sample_y.unsqueeze(0)) - if (dec_init_state != dec_init_state).any(): - pdb.set_trace() - attn_context = None - - # decode - # if self.state_for_decoding: - # if not data_feed: - # raise ValueError - # utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - # # create decoder initial states - # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # decode - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - - return logprobs, outs - - def pad_to(self, max_len, tokens, do_pad): - if len(tokens) >= max_len: - # print("cutting off, ", tokens) - return tokens[: max_len-1] + [tokens[-1]] - elif do_pad: - return tokens + [0] * (max_len - len(tokens)) - else: - return tokens - - def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False, sample_z=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1) - - # create decoder initial states - if self.simple_posterior: - q_mu, q_logvar = self.c2z(enc_last) - sample_z = self.gauss_connector(q_mu, q_logvar) - p_mu, p_logvar = self.zero, self.zero - - # pack attention context - dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_z - ret_dict['q_mu'] = q_mu - ret_dict['q_logvar'] = q_logvar - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) - result['pi_kl'] = pi_kl - return result - -class SysActZGauss(BaseModel): - def __init__(self, corpus, config): - super(SysActZGauss, self).__init__(config) - self.vocab = corpus.vocab - self.vocab_dict = corpus.vocab_dict - self.vocab_size = len(self.vocab) - self.bos_id = self.vocab_dict[BOS] - self.eos_id = self.vocab_dict[EOS] - self.pad_id = self.vocab_dict[PAD] - self.bs_size = corpus.bs_size - self.db_size = corpus.db_size - self.y_size = config.y_size - self.simple_posterior = config.simple_posterior - self.contextual_posterior = config.contextual_posterior - - self.embedding = None - self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, - embedding_dim=config.embed_size, - feat_size=0, - goal_nhid=0, - rnn_cell=config.utt_rnn_cell, - utt_cell_size=config.utt_cell_size, - num_layers=config.num_layers, - input_dropout_p=config.dropout, - output_dropout_p=config.dropout, - bidirectional=config.bi_utt_cell, - variable_lengths=False, - use_attn=config.enc_use_attn, - embedding=self.embedding) - - - self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size, - config.y_size, is_lstm=False) - # self.aux_c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, - # config.y_size, is_lstm=False) - - self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) - self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) - if not self.simple_posterior: - # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - # config.y_size, is_lstm=False) - if self.contextual_posterior: - self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, - config.y_size, is_lstm=False) - else: - self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) - - - self.decoder = DecoderRNN(input_dropout_p=config.dropout, - rnn_cell=config.dec_rnn_cell, - input_size=config.embed_size, - hidden_size=config.dec_cell_size, - num_layers=config.num_layers, - output_dropout_p=config.dropout, - bidirectional=False, - vocab_size=self.vocab_size, - use_attn=config.dec_use_attn, - ctx_cell_size=config.dec_cell_size, - attn_mode=config.dec_attn_mode, - sys_id=self.bos_id, - eos_id=self.eos_id, - use_gpu=config.use_gpu, - max_dec_len=config.max_dec_len, - embedding=self.embedding) - - self.nll = NLLEntropy(self.pad_id, config.avg_type) - if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": - req_tokens = [] - for d in REQ_TOKENS.keys(): - req_tokens.extend(REQ_TOKENS[d]) - nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. for token in self.vocab])) - print("req tokens assigned with special weights") - if config.use_gpu: - nll_weight = nll_weight.cuda() - self.nll.set_weight(nll_weight) - - - - self.gauss_kl = NormKLLoss(unit_average=True) - self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) - - def valid_loss(self, loss, batch_cnt=None): - if self.simple_posterior: - total_loss = loss.nll - if self.config.use_pr > 0.0: - total_loss += self.config.beta * loss.pi_kl - else: - total_loss = loss.nll + loss.pi_kl - - if self.config.use_mi: - total_loss += (loss.b_pr * self.beta) - - if self.config.use_diversity: - total_loss += loss.diversity - - if "match_z" in self.config and self.config.match_z: - total_loss += loss.z_mse - - return total_loss - - def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - short_target_utts = self.np2var(data_feed['outputs'], LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) - # aux_enc_last = aux_utt_summary.squeeze(1) - - # create decoder initial states - if self.simple_posterior: - q_mu, q_logvar = self.c2z(enc_last) - # p_mu, p_logvar = self.aux_c2z(aux_enc_last) - p_mu, p_logvar = self.c2z(aux_enc_last) - sample_z = self.gauss_connector(q_mu, q_logvar) - aux_sample_z = self.gauss_connector(p_mu, p_logvar) - else: - p_mu, p_logvar = self.c2z(enc_last) - # encode response and use posterior to find q(z|x, c) - x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) - if self.contextual_posterior: - q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) - else: - q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) - - aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last) - - # use prior at inference time, otherwise use posterior - if mode == GEN or use_py: - sample_z = self.gauss_connector(p_mu, p_logvar) - else: - sample_z = self.gauss_connector(q_mu, q_logvar) - - # pack attention context - dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_z - ret_dict['q_mu'] = q_mu - ret_dict['q_logvar'] = q_logvar - return ret_dict, labels - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) - z_mse = F.mse_loss(aux_sample_z, sample_z) - result['pi_kl'] = pi_kl - result['z_mse'] = z_mse - # result['nll'] = self.nll(dec_outputs, labels) - return result - - def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - batch_size = len(ctx_lens) - - utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1)) - - # get decoder inputs - dec_inputs = out_utts[:, :-1] - labels = out_utts[:, 1:].contiguous() - - # create decoder initial states - enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1) - - # create decoder initial states - if self.simple_posterior: - q_mu, q_logvar = self.c2z(enc_last) - sample_z = self.gauss_connector(q_mu, q_logvar) - p_mu, p_logvar = self.zero, self.zero - - # pack attention context - dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) - attn_context = None - - # decode - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, - dec_inputs=dec_inputs, - dec_init_state=dec_init_state, # tuple: (h, c) - attn_context=attn_context, - mode=mode, - gen_type=gen_type, - beam_size=self.config.beam_size) # (batch_size, goal_nhid) - if mode == GEN: - ret_dict['sample_z'] = sample_z - ret_dict['q_mu'] = q_mu - ret_dict['q_logvar'] = q_logvar - return ret_dict, labels - - else: - result = Pack(nll=self.nll(dec_outputs, labels)) - pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) - result['pi_kl'] = pi_kl - result['nll'] = self.nll(dec_outputs, labels) - return result - - def get_z_via_vae(self, responses): - batch_size = responses.shape[0] - aux_utt_summary, _, aux_enc_outs = self.aux_encoder(responses.unsqueeze(1)) - - # create decoder initial states - aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1) - - aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last) - aux_sample_z = self.gauss_connector(aux_q_mu, aux_q_logvar) - - return aux_sample_z, aux_q_mu, aux_q_logvar - - def get_z_via_rg(self, data_feed): - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - - utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - q_mu, q_logvar = self.c2z(enc_last) - - sample_z = self.gauss_connector(q_mu, q_logvar) - - return sample_z, q_mu, q_logvar - - - def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'): - """ - generate response from latent var - """ - - if data_feed: - ctx_lens = data_feed['context_lens'] # (batch_size, ) - short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) - bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) - - # pack attention context - if isinstance(sample_y, np.ndarray): - sample_y = self.np2var(sample_y, FLOAT) - - dec_init_state = self.z_embedding(sample_y.unsqueeze(0)) - if (dec_init_state != dec_init_state).any(): - pdb.set_trace() - attn_context = None - - # decode - # if self.state_for_decoding: - # if not data_feed: - # raise ValueError - # utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) - # # create decoder initial states - # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) - # dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) - - if self.config.dec_rnn_cell == 'lstm': - dec_init_state = tuple([dec_init_state, dec_init_state]) - - # decode - logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, - dec_init_state=dec_init_state, - attn_context=attn_context, - vocab=self.vocab, - max_words=max_words, - temp=0.1) - - return logprobs, outs - - def gaussian_logprob(self, mu, logvar, sample_z): - var = th.exp(logvar) - constant = float(-0.5 * np.log(2*np.pi)) - logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var) - return logprob - - return logprobs, outs, joint_logpz, sample_z - - def pad_to(self, max_len, tokens, do_pad): - if len(tokens) >= max_len: - # print("cutting off, ", tokens) - return tokens[: max_len-1] + [tokens[-1]] - elif do_pad: - return tokens + [0] * (max_len - len(tokens)) - else: - return tokens diff --git a/convlab/policy/lava/multiwoz/lava.py b/convlab/policy/lava/multiwoz/lava.py index 76ea072ddb87f64a0144de9ae421707ea057c7e4..76d177396c4d9a9d4c7e20f18ce652f6f8852ad5 100755 --- a/convlab/policy/lava/multiwoz/lava.py +++ b/convlab/policy/lava/multiwoz/lava.py @@ -418,16 +418,12 @@ DEFAULT_CUDA_DEVICE = -1 class LAVA(Policy): def __init__(self, - model_file="/gpfs/project/lubis/public_code/LAVA/experiments_woz/sys_config_log_model/2020-05-12-14-51-49-actz_cat/rl-2020-05-18-10-50-48/reward_best.model", is_train=False): + model_file="", is_train=False): if not model_file: raise Exception("No model for LAVA is specified!") temp_path = os.path.dirname(os.path.abspath(__file__)) - # print(temp_path) - #zip_ref = zipfile.ZipFile(archive_file, 'r') - # zip_ref.extractall(temp_path) - # zip_ref.close() self.prev_state = default_state() self.prev_active_domain = None @@ -435,24 +431,7 @@ class LAVA(Policy): domain_name = 'object_division' domain_info = domain.get_domain(domain_name) self.db=Database() - # data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data/data_2.1/') - # train_data_path = os.path.join(data_path,'train_dials.json') - # if not os.path.exists(train_data_path): - # zipped_file = os.path.join(data_path, 'norm-multi-woz.zip') - # archive = zipfile.ZipFile(zipped_file, 'r') - # archive.extractall(data_path) - - # norm_multiwoz_path = data_path - # with open(os.path.join(norm_multiwoz_path, 'input_lang.index2word.json')) as f: - # self.input_lang_index2word = json.load(f) - # with open(os.path.join(norm_multiwoz_path, 'input_lang.word2index.json')) as f: - # self.input_lang_word2index = json.load(f) - # with open(os.path.join(norm_multiwoz_path, 'output_lang.index2word.json')) as f: - # self.output_lang_index2word = json.load(f) - # with open(os.path.join(norm_multiwoz_path, 'output_lang.word2index.json')) as f: - # self.output_lang_word2index = json.load(f) - - + path, _ = os.path.split(model_file) if "rl-" in model_file: rl_config_path = os.path.join(path, "rl_config.json") @@ -467,13 +446,12 @@ class LAVA(Policy): try: self.corpus = corpora_inference.NormMultiWozCorpus(config) except (FileNotFoundError, PermissionError): - train_data_path = "/gpfs/project/lubis/LAVA_code/LAVA_dev/data/norm-multi-woz/train_dials.json" + train_data_path = "/gpfs/project/lubis/NeuralDialog-LaRL/data/norm-multi-woz/train_dials.json" config['train_path'] = train_data_path config['valid_path'] = train_data_path.replace("train", "val") config['test_path'] = train_data_path.replace("train", "test") self.corpus = corpora_inference.NormMultiWozCorpus(config) - if "rl" in model_file: if "gauss" in model_file: self.model = SysPerfectBD2Gauss(self.corpus, config) @@ -510,7 +488,6 @@ class LAVA(Policy): self.rl_config["US_best_reward_model_path"] = model_file.replace( ".model", "_US.model") if "lr_rl" not in config: - #config["lr_rl"] = config["init_lr"] self.config["lr_rl"] = 0.01 self.config["gamma"] = 0.99 @@ -523,7 +500,6 @@ class LAVA(Policy): lr=self.config.lr_rl, momentum=self.config.momentum, nesterov=False) - # nesterov=(self.config.nesterov and self.config.momentum > 0)) if config.use_gpu: self.model.load_state_dict(torch.load(model_file)) @@ -585,39 +561,6 @@ class LAVA(Policy): utts.append(context[b_id, context_lens[b_id]-1]) return np.array(utts) - def get_active_domain_test(self, prev_active_domain, prev_action, action): - domains = ['hotel', 'restaurant', 'attraction', - 'train', 'taxi', 'hospital', 'police'] - active_domain = None - cur_action_keys = action.keys() - state = [] - for act in cur_action_keys: - slots = act.split('-') - action = slots[0].lower() - state.append(action) - - # print('get_active_domain') - # for domain in domains: - """for domain in range(len(domains)): - domain = domains[i] - if domain not in prev_state and domain not in state: - continue - if domain in prev_state and domain not in state: - return domain - elif domain not in prev_state and domain in state: - return domain - elif prev_state[domain] != state[domain]: - active_domain = domain - if active_domain is None: - active_domain = prev_active_domain""" - if len(state) != 0: - active_domain = state[0] - if active_domain is None: - active_domain = prev_active_domain - elif active_domain == "general": - active_domain = prev_active_domain - return active_domain - def is_masked_action(self, bs_label, db_label, response): """ check if the generated response should be masked based on belief state and db result @@ -657,8 +600,6 @@ class LAVA(Policy): else: return False - - def get_active_domain_unified(self, prev_active_domain, prev_state, state): domains = ['hotel', 'restaurant', 'attraction', 'train', 'taxi', 'hospital', 'police'] @@ -817,9 +758,6 @@ class LAVA(Policy): self.prev_output = outputs mul = False - # if self.is_masked_action(data_feed['bs'][0], data_feed['db'][0], outputs) and i < 9: # if it's the last try, accept masked action - # continue - # default lexicalization if active_domain is not None and active_domain in num_results: num_results = num_results[active_domain] @@ -847,9 +785,6 @@ class LAVA(Policy): if active_domain in ["hotel", "attraction", "train", "restaurant"] and active_domain not in top_results.keys(): # no db match for active domain if any([p in outputs for p in REQ_TOKENS[active_domain]]): - # self.fail_info_penalty += 1 - # response = "I am sorry there are no matches." - # print(outputs) response = "I am sorry, can you say that again?" database_results = {} else: @@ -868,316 +803,15 @@ class LAVA(Policy): print("can not lexicalize: ", outputs) response = "I am sorry, can you say that again?" - - # for domain in DOMAIN_REQ_TOKEN: - # """ - # mask out of domain action - # """ - # if domain != active_domain and any([p in outputs for p in REQ_TOKENS[domain]]): - # # print(f"MASK: illegal action for {active_domain}: {outputs}") - # response = "Can I help you with anything else?" - # self.wrong_domain_penalty += 1 - - - # if active_domain is not None: - # print ("===========================") - - # print(active_domain) - # print ("BS: ") - # for k, v in bstate[active_domain].items(): - # print(k, ": ", v) - # print ("DB: ") - # if len(database_results.keys()) > 0: - # for k, v in database_results[active_domain][1][0 % database_results[active_domain][0]].items(): - # print(k, ": ", v) - # print ("===========================") - # print ("input: ") - # for turn in data_feed['contexts'][0]: - # print(" ".join(self.corpus.id2sent(turn))) - # print ("system delex: ", outputs) - # print ("system: ",response) - # print ("===========================\n") - # self.num_generated_response += 1 - # break - response = response.replace("free pounds", "free") response = response.replace("pounds pounds", "pounds") - # print("final lexicalization: ", response) if any([p in response for p in ["not mentioned", "dontcare", "[", "]"]]): - # response = "I am sorry there are no matches." - # print(usr) - # print(outputs) - # print(response) - # print(active_domain, len(top_results[active_domain]), delex_bstate[active_domain]) - # pdb.set_trace() - # response = "I am sorry can you repeat that?" response = "I am sorry, can you say that again?" - # print("final response: ", response) - return response, active_domain - def populate_template_unified(self, template, top_results, num_results, state, active_domain): - # print("template:",template) - # print("top_results:",top_results) - # active_domain = None if len( - # top_results.keys()) == 0 else list(top_results.keys())[0] - - template = template.replace( - 'book [value_count] of', 'book one of') - tokens = template.split() - response = [] - for index, token in enumerate(tokens): - if token.startswith('[') and (token.endswith(']') or token.endswith('].') or token.endswith('],')): - domain = token[1:-1].split('_')[0] - slot = token[1:-1].split('_')[1] - if slot.endswith(']'): - slot = slot[:-1] - if domain == 'train' and slot == 'id': - slot = 'trainID' - elif active_domain != 'train' and slot == 'price': - slot = 'price range' - elif slot == 'reference': - slot = 'Ref' - if domain in top_results and len(top_results[domain]) > 0 and slot in top_results[domain]: - # print('{} -> {}'.format(token, top_results[domain][slot])) - response.append(top_results[domain][slot]) - elif domain == 'value': - if slot == 'count': - if "there are" in " ".join(tokens[index-2:index]) or "i have" in " ".join(tokens[index-2:index]): - response.append(str(num_results)) - # the first [value_count], the last [value_count] - elif "the" in tokens[index-2]: - response.append("one") - elif active_domain == "restaurant": - if "people" in tokens[index:index+1] or "table" in tokens[index-2:index]: - response.append( - state[active_domain]["book people"]) - elif active_domain == "train": - if "ticket" in " ".join(tokens[index-2:index+1]) or "people" in tokens[index:]: - response.append( - state[active_domain]["book people"]) - elif index+1 < len(tokens) and "minute" in tokens[index+1]: - response.append( - top_results['train']['duration'].split()[0]) - elif active_domain == "hotel": - if index+1 < len(tokens): - if "star" in tokens[index+1]: - try: - response.append(top_results['hotel']['stars']) - except: - response.append(state['hotel']['stars']) - elif "nights" in tokens[index+1]: - response.append( - state[active_domain]["book stay"]) - elif "people" in tokens[index+1]: - response.append( - state[active_domain]["book people"]) - elif active_domain == "attraction": - if index + 1 < len(tokens): - if "pounds" in tokens[index+1] and "entrance fee" in " ".join(tokens[index-3:index]): - value = top_results[active_domain]['entrance fee'] - if "?" in value: - value = "unknown" - # if "?" not in value: - # try: - # value = str(int(value)) - # except: - # value = 'free' - # else: - # value = "unknown" - response.append(value) - # if "there are" in " ".join(tokens[index-2:index]): - # response.append(str(num_results)) - # elif "the" in tokens[index-2]: # the first [value_count], the last [value_count] - # response.append("1") - else: - response.append(str(num_results)) - elif slot == 'place': - if 'arriv' in " ".join(tokens[index-2:index]) or "to" in " ".join(tokens[index-2:index]): - if active_domain == "train": - try: - response.append( - top_results[active_domain]["destination"]) - except: - response.append( - state[active_domain]['semi']["destination"]) - elif active_domain == "taxi": - response.append( - state[active_domain]["destination"]) - elif 'leav' in " ".join(tokens[index-2:index]) or "from" in tokens[index-2:index] or "depart" in " ".join(tokens[index-2:index]): - if active_domain == "train": - try: - response.append( - top_results[active_domain]["departure"]) - except: - response.append( - state[active_domain]['semi']["departure"]) - elif active_domain == "taxi": - response.append( - state[active_domain]['semi']["departure"]) - elif "hospital" in template: - response.append("Cambridge") - else: - try: - for d in state: - if d == 'history': - continue - for s in ['destination', 'departure']: - if s in state[d]: - response.append( - state[d][s]) - raise - except: - pass - else: - response.append(token) - elif slot == 'time': - if 'arrive' in ' '.join(response[-5:]) or 'arrival' in ' '.join(response[-5:]) or 'arriving' in ' '.join(response[-3:]): - if active_domain == "train" and 'arriveBy' in top_results[active_domain]: - # print('{} -> {}'.format(token, top_results[active_domain]['arriveBy'])) - response.append( - top_results[active_domain]['arriveBy']) - continue - for d in state: - if d == 'history': - continue - if 'arrive by' in state[d]: - response.append( - state[d]['arrive by']) - break - elif 'leave' in ' '.join(response[-5:]) or 'leaving' in ' '.join(response[-5:]) or 'departure' in ' '.join(response[-3:]): - if active_domain == "train" and 'leaveAt' in top_results[active_domain]: - # print('{} -> {}'.format(token, top_results[active_domain]['leaveAt'])) - response.append( - top_results[active_domain]['leaveAt']) - continue - for d in state: - if d == 'history': - continue - if 'leave at' in state[d]: - response.append( - state[d]['leave at']) - break - elif 'book' in response or "booked" in response: - if state['restaurant']['book time'] != "": - response.append( - state['restaurant']['book time']) - else: - try: - for d in state: - if d == 'history': - continue - for s in ['arrive by', 'leave at']: - if s in state[d]: - response.append( - state[d][s]) - raise - except: - pass - else: - response.append(token) - elif slot == 'price': - if active_domain == 'attraction': - # .split()[0] - value = top_results['attraction']['entrance fee'] - if "?" in value: - value = "unknown" - # if "?" not in value: - # try: - # value = str(int(value)) - # except: - # value = 'free' - # else: - # value = "unknown" - response.append(value) - elif active_domain == "train": - value = top_results[active_domain][slot].split()[0] - if state[active_domain]['book people'] not in ["", "dontcare"]: - try: - value = str(float(value) * int(state[active_domain]['book people'])) - except ValueError: - int_map = {"one": 1, "two": 2, "three": 3, "four": 4, "five": 5, "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10} - value = str(float(value) * int_map[state[active_domain]['book people']]) - response.append(value) - elif slot == "day" and active_domain in ["restaurant", "hotel"]: - if state[active_domain]['book day'] != "": - response.append( - state[active_domain]['book day']) - - else: - # slot-filling based on query results - for d in top_results: - if slot in top_results[d]: - response.append(top_results[d][slot]) - break - else: - # slot-filling based on belief state - for d in state: - if d == 'history': - continue - if slot in state[d]: - response.append(state[d][slot]) - break - else: - response.append(token) - else: - if domain == 'hospital': - if slot == 'phone': - response.append('01223216297') - elif slot == 'department': - if state['hospital']['department'] != "": - response.append(state['hospital']['department']) - else: - response.append('neurosciences critical care unit') - elif slot == 'address': - response.append("56 Lincoln street") - elif slot == "postcode": - response.append('533421') - elif domain == 'police': - if slot == 'phone': - response.append('01223358966') - elif slot == 'name': - response.append('Parkside Police Station') - elif slot == 'address': - response.append('Parkside, Cambridge') - elif slot == 'postcode': - response.append('533420') - elif domain == 'taxi': - if slot == 'phone': - response.append('01223358966') - elif slot == 'color': - # response.append(random.choice(["black","white","red","yellow","blue",'grey'])) - response.append("black") - elif slot == 'type': - # response.append(random.choice(["toyota","skoda","bmw",'honda','ford','audi','lexus','volvo','volkswagen','tesla'])) - response.append("toyota") - else: - # print(token) - response.append(token) - else: - if token == "pounds" and len(response) > 0 and ("pounds" in response[-1] or "unknown" in response[-1] or "free" in response[-1]): - pass - else: - response.append(token) - - try: - response = ' '.join(response) - except Exception as e: - # pprint(response) - raise - response = response.replace(' -s', 's') - response = response.replace(' -ly', 'ly') - response = response.replace(' .', '.') - response = response.replace(' ?', '?') - - # if "not mentioned" in response: - # pdb.set_trace() - - return response - def populate_template_unified(self, template, top_results, num_results, state, active_domain): # print("template:",template) # print("top_results:",top_results) @@ -1505,9 +1139,9 @@ class LAVA(Policy): if d == 'history': continue for s in ['destination', 'departure']: - if s in state[d]['semi']: + if s in state[d]: response.append( - state[d]['semi'][s]) + state[d][s]) raise except: pass @@ -1523,9 +1157,9 @@ class LAVA(Policy): for d in state: if d == 'history': continue - if 'arriveBy' in state[d]['semi']: + if 'arriveBy' in state[d]: response.append( - state[d]['semi']['arriveBy']) + state[d]['arrive by']) break elif 'leav' in ' '.join(response[-7:]) or 'depart' in ' '.join(response[-7:]): if active_domain is not None and 'leaveAt' in top_results[active_domain][result_idx]: @@ -1536,23 +1170,23 @@ class LAVA(Policy): for d in state: if d == 'history': continue - if 'leaveAt' in state[d]['semi']: + if 'leave at' in state[d]: response.append( - state[d]['semi']['leaveAt']) + state[d]['leave at']) break elif 'book' in response or "booked" in response: - if state['restaurant']['book']['time'] != "": + if state['restaurant']['book time'] != "": response.append( - state['restaurant']['book']['time']) + state['restaurant']['book time']) else: try: for d in state: if d == 'history': continue - for s in ['arriveBy', 'leaveAt']: - if s in state[d]['semi']: + for s in ['arrive by', 'leave at']: + if s in state[d]: response.append( - state[d]['semi'][s]) + state[d][s]) raise except: pass @@ -1574,9 +1208,9 @@ class LAVA(Policy): response.append( top_results[active_domain][result_idx][slot].split()[0]) elif slot == "day" and active_domain in ["restaurant", "hotel"]: - if state[active_domain]['book']['day'] != "": + if state[active_domain]['book day'] != "": response.append( - state[active_domain]['book']['day']) + state[active_domain]['book day']) else: # slot-filling based on query results @@ -1590,8 +1224,8 @@ class LAVA(Policy): for d in state: if d == 'history': continue - if slot in state[d]['semi']: - response.append(state[d]['semi'][slot]) + if slot in state[d]: + response.append(state[d][slot]) break else: response.append(token) @@ -1641,26 +1275,16 @@ class LAVA(Policy): return response def model_predict(self, data_feed): - # TODO use model's forward function, add null vector for the target response self.logprobs = [] logprobs, pred_labels, joint_logpz, sample_y = self.model.forward_rl( data_feed, self.model.config.max_dec_len) - # if len(data_feed['bs']) == 1: - # logprobs = [logprobs] - # for log_prob in logprobs: - # self.logprobs.extend(log_prob) self.logprobs.extend(joint_logpz) pred_labels = np.array( - [pred_labels], dtype=int) # .squeeze(-1).swapaxes(0, 1) + [pred_labels], dtype=int) de_tknize = get_detokenize() - # if pred_labels.shape[1] == self.model.config.max_utt_len: - # pdb.set_trace() pred_str = get_sent(self.model.vocab, de_tknize, pred_labels, 0) - #for b_id in range(pred_labels.shape[0]): - # only one val for pred_str now - # pred_str = get_sent(self.model.vocab, de_tknize, pred_labels, b_id) return pred_str @@ -1683,11 +1307,8 @@ class LAVA(Policy): loss = 0 # estimate the loss using one MonteCarlo rollout - # TODO better loss estimation? - # TODO better update, instead of reinforce? for lp, re in zip(logprobs, rewards): loss -= lp * re - #tmp = self.model.state_dict()['c2z.p_h.weight'].clone() self.opt.zero_grad() if "fp16" in self.config and self.config.fp16: with amp.scale_loss(loss, self.opt) as scaled_loss: @@ -1697,10 +1318,7 @@ class LAVA(Policy): loss.backward() nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip) - #self._print_grad() self.opt.step() - #tmp2 = self.model.state_dict()['c2z.p_h.weight'].clone() - #print(tmp==tmp2) def _print_grad(self): for name, p in self.model.named_parameters(): @@ -1876,7 +1494,8 @@ if __name__ == '__main__': 'history': [['sys', ''], ['user', 'Could you book a 4 stars hotel east of town for one night, 1 person?']]} - cur_model = LAVA() + model_file="path/to/model" # points to model from lava repo + cur_model = LAVA(model_file) response = cur_model.predict(state) # print(response) diff --git a/examples/agent_examples/test_LAVA.py b/examples/agent_examples/test_LAVA.py index 63bf075c5a25dcdfc8b89175d42dbf6272b23f0e..d5b592c2e09f8c7f73ebf925abdf52b1a4be24b6 100755 --- a/examples/agent_examples/test_LAVA.py +++ b/examples/agent_examples/test_LAVA.py @@ -132,8 +132,8 @@ def test_end2end(args, model_dir): #seed = 2020 set_seed(args.seed) - model_name = '{}_{}_lava_{}_tmp'.format(args.US_type, args.dst_type, model_dir) - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=model_name, total_dialog=100) + model_name = '{}_{}_lava_{}'.format(args.US_type, args.dst_type, model_dir) + analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=model_name, total_dialog=500) if __name__ == '__main__': parser = ArgumentParser()