Commit 1e2f4dc0 authored by Carel van Niekerk's avatar Carel van Niekerk

Merge remote-tracking branch 'origin/lava_unifiedformat2' into github_master

# Conflicts:
#	convlab/dst/setsumbt/dataset/unified_format.py
#	convlab/dst/setsumbt/do/nbt.py
#	convlab/dst/setsumbt/modeling/training.py
#	convlab/policy/ppo/setsumbt_config.json
#	convlab/policy/ppo/train.py
parents 1452c5a7 04fa5948
Showing with 4430 additions and 72 deletions
@@ -66,7 +66,8 @@ convlab/nlu/jointBERT_new/**/output/
convlab/nlu/milu/09*
convlab/nlu/jointBERT/multiwoz/configs/multiwoz_new_usr_context.json
convlab/nlu/milu/multiwoz/configs/system_without_context.jsonnet
convlab/nlu/milu/multiwoz/configs/user_without_context.jsonnet
convlab/nlu/milu/multiwoz/configs/user_without_context.jsonnet\
*.pkl
# test script
*_test.py
@@ -87,7 +88,6 @@ dist
convlab.egg-info
# configs
*experiment*
*pretrained_models*
.ipynb_checkpoints
@@ -7,6 +7,7 @@ from convlab.policy import Policy
from convlab.nlg import NLG
from copy import deepcopy
import time
import pdb
from pprint import pprint
@@ -63,7 +64,7 @@ class PipelineAgent(Agent):
===== ===== ====== === == ===
"""
def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str, return_semantic_acts=False):
def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str):
"""The constructor of PipelineAgent class.
Here are some special combination cases:
@@ -94,7 +95,7 @@ class PipelineAgent(Agent):
self.dst = dst
self.policy = policy
self.nlg = nlg
self.return_semantic_acts = return_semantic_acts
self.init_session()
self.agent_saves = []
self.history = []
@@ -151,6 +152,7 @@ class PipelineAgent(Agent):
self.input_action = self.nlu.predict(
observation, context=[x[1] for x in self.history[:-1]])
# print("system semantic action: ", self.input_action)
else:
self.input_action = observation
self.input_action_eval = observation
@@ -186,7 +188,7 @@ class PipelineAgent(Agent):
if type(self.output_action) == list:
for intent, domain, slot, value in self.output_action:
if intent == "book":
if intent.lower() == "book":
self.dst.state['booked'][domain] = [{slot: value}]
else:
self.dst.state['user_action'] = self.output_action
@@ -196,8 +198,6 @@ class PipelineAgent(Agent):
self.history.append([self.name, model_response])
self.turn += 1
if self.return_semantic_acts:
return self.output_action
self.agent_saves.append(self.save_info())
return model_response
@@ -3,6 +3,7 @@
import logging
import re
import numpy as np
import pdb
from copy import deepcopy
from data.unified_datasets.multiwoz21.preprocess import reverse_da, reverse_da_slot_name_map
@@ -158,6 +159,15 @@ class MultiWozEvaluator(Evaluator):
list[intent, domain, slot, value]
"""
new_acts = list()
for intent, domain, slot, value in da_turn:
if intent.lower() == 'book':
ref = [_value for _intent, _domain, _slot, _value in da_turn if _domain == domain and _intent.lower() == 'inform' and _slot.lower() == 'ref']
ref = ref[0] if ref else ''
value = ref
new_acts.append([intent, domain, slot, value])
da_turn = new_acts
da_turn = self._convert_action(da_turn)
for intent, domain, slot, value in da_turn:
## LAVA: Latent Action Spaces via Variational Auto-encoding for Dialogue Policy Optimization
Codebase for [LAVA: Latent Action Spaces via Variational Auto-encoding for Dialogue Policy Optimization](https://), published as a long paper in COLING 2020. The code is developed based on the implementations of the [LaRL](https://arxiv.org/abs/1902.08858) paper.
ConvLab3 interface for [LAVA: Latent Action Spaces via Variational Auto-encoding for Dialogue Policy Optimization](https://aclanthology.org/2020.coling-main.41/), published as a long paper in COLING 2020.
### Requirements
python 3
pytorch
numpy
To train a LAVA model, clone and follow instructions from the [original LAVA repository](https://gitlab.cs.uni-duesseldorf.de/general/dsml/lava-public).
### Data
The pre-processed MultiWoz 2.0 data is included in data.zip. Unzip the compressed file and access the data under **data/norm-multi-woz**.
With a (pre-)trained LAVA model, it is possible to evaluate it or perform online RL with the ConvLab3 user simulator by loading the LAVA module with
### Overall structure:
The implementation of the models, as well as training and evaluation scripts are under **latent_dialog**.
The scripts for running the experiments are under **experiment_woz**. The trained models and evaluation results are under **experiment_woz/sys_config_log_model**.
- from convlab.policy.lava.multiwoz import LAVA
There are 3 types of training to achieve the final model.
and using it as the policy module in the ConvLab pipeline (NLG should be set to None).
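A minimal sketch of this setup follows (the checkpoint argument and the choice of nlu/dst modules are assumptions of this sketch, not prescribed by this repo; only the LAVA import, the PipelineAgent interface, and setting NLG to None come from the lines above):

from convlab.dialog_agent import PipelineAgent
from convlab.policy.lava.multiwoz import LAVA

policy = LAVA("path/to/trained_lava_checkpoint")  # hypothetical checkpoint argument
sys_agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=None, name="sys")  # nlu/dst: your modules of choice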
### Step 1: Unsupervised training (variational auto-encoding (VAE) task)
Given a dialogue response, the model is tasked with reproducing it via latent variables. With this task we aim to capture the generative factors of dialogue responses in an unsupervised manner.
- sl_cat_ae.py: train a VAE model using categorical latent variable
- sl_gauss_ae.py: train a VAE model using continuous (Gaussian) latent variable
A code example can be found at
- ConvLab-3/examples/agent_examples/test_LAVA.py
### Step 2: Supervised training (response generation task)
The supervised training step of the variational encoder-decoder model can be done in 4 different ways.
1. from scratch:
- sl_word: train a standard encoder decoder model using supervised learning (SL)
- sl_cat: train a latent action model with categorical latent variables using SL
- sl_gauss: train a latent action model with continuous (Gaussian) latent variables using SL
2. using the VAE models as pre-trained model (equivalent to LAVA_pt):
- finetune_cat_ae: use the VAE with categorical latent variables as weight initialization, and then fine-tune the model on response generation task
- finetune_gauss_ae: as above but with continuous latent variables
- Note: Fine-tuning can be set to be selective (only fine-tune encoder) or not (fine-tune the entire network) using the "selective_finetune" argument in config
3. using the distribution of the VAE models to obtain an informed prior (equivalent to LAVA_kl):
- actz_cat: a newly initialized encoder is combined with the pre-trained VAE decoder and fine-tuned on the response generation task. The VAE encoder is used to obtain an informed prior over the target response and is not trained further.
- actz_gauss: as above but with continuous latent variables
4. simultaneously from scratch with the VAE task in a multi-task fashion (equivalent to LAVA_mt):
- mt_cat: train a model to optimize both auto-encoding and response generation in a multi-task fashion, using categorical latent variables
- mt_gauss: as above but with continuous latent variables
Nos. 1 and 4 can be trained directly without Step 1. Nos. 2 and 3 require a pre-trained VAE model, given via a dictionary
pretrained = {"2020-02-26-18-11-37-sl_cat_ae":100}
### Step 3: Reinforcement Learning
The model can be further optimized with RL to maximize the dialogue success.
Each script is used for:
- reinforce_word: fine-tune a pretrained model with word-level policy gradient (PG)
- reinforce_cat: fine-tune a pretrained categorical latent action model with latent-level PG
- reinforce_gauss: fine-tune a pretrained Gaussian latent action model with latent-level PG
Each script takes a file containing a list of test results from the SL step, e.g.
f_in = "sys_config_log_model/test_files.lst"
### Checking the result
The evaluation result can be found at the bottom of test_file.txt. We provide the best model in this repo.
NOTE: when re-running the experiments, some variance in the numbers is to be expected due to factors such as random seed and hardware specifications. Some methods are more sensitive to this than others.
# @Time : 10/18/18 1:55 PM
# @Author : Tiancheng Zhao
\ No newline at end of file
import torch as th
import torch.nn as nn
import torch.optim as optim
import numpy as np
from convlab.policy.lava.multiwoz.latent_dialog.utils import LONG, FLOAT, Pack, get_detokenize
from convlab.policy.lava.multiwoz.latent_dialog.main import get_sent
from convlab.policy.lava.multiwoz.latent_dialog.data_loaders import BeliefDbDataLoaders
from collections import deque, namedtuple, defaultdict
import random
import pdb
import dill
class OfflineRlAgent(object):
def __init__(self, model, corpus, args, name, tune_pi_only):
self.model = model
self.corpus = corpus
self.args = args
self.name = name
self.raw_goal = None
self.vec_goals_list = None
self.logprobs = None
print("Do we only tune the policy: {}".format(tune_pi_only))
self.opt = optim.SGD(
[p for n, p in self.model.named_parameters() if 'c2z' in n or not tune_pi_only],
lr=self.args.rl_lr,
momentum=self.args.momentum,
nesterov=(self.args.nesterov and self.args.momentum > 0))
# self.opt = optim.Adam(self.model.parameters(), lr=0.01)
# self.opt = optim.RMSprop(self.model.parameters(), lr=0.0005)
self.all_rewards = []
self.all_grads = []
self.model.train()
def print_dialog(self, dialog, reward, stats):
for t_id, turn in enumerate(dialog):
if t_id % 2 == 0:
print("Usr: {}".format(' '.join([t for t in turn if t != '<pad>'])))
else:
print("Sys: {}".format(' '.join(turn)))
report = ['{}: {}'.format(k, v) for k, v in stats.items()]
print("Reward {}. {}".format(reward, report))
def run(self, batch, evaluator, max_words=None, temp=0.1):
self.logprobs = []
self.dlg_history =[]
batch_size = len(batch['keys'])
logprobs, outs = self.model.forward_rl(batch, max_words, temp)
if batch_size == 1:
logprobs = [logprobs]
outs = [outs]
key = batch['keys'][0]
sys_turns = []
# construct the dialog history for printing
for turn_id, turn in enumerate(batch['contexts']):
user_input = self.corpus.id2sent(turn[-1])
self.dlg_history.append(user_input)
sys_output = self.corpus.id2sent(outs[turn_id])
self.dlg_history.append(sys_output)
sys_turns.append(' '.join(sys_output))
for log_prob in logprobs:
self.logprobs.extend(log_prob)
# compute reward here
generated_dialog = {key: sys_turns}
return evaluator.evaluateModel(generated_dialog, mode="offline_rl")
def update(self, reward, stats):
self.all_rewards.append(reward)
# standardize the reward
r = (reward - np.mean(self.all_rewards)) / max(1e-4, np.std(self.all_rewards))
# compute accumulated discounted reward
g = self.model.np2var(np.array([r]), FLOAT).view(1, 1)
rewards = []
for _ in self.logprobs:
rewards.insert(0, g)
g = g * self.args.gamma
loss = 0
# estimate the loss using one MonteCarlo rollout
for lp, r in zip(self.logprobs, rewards):
loss -= lp * r
self.opt.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), self.args.rl_clip)
# for name, p in self.model.named_parameters():
# print(name)
# print(p.grad)
self.opt.step()
class OfflineLatentRlAgent(OfflineRlAgent):
def run(self, batch, evaluator, max_words=None, temp=0.1):
self.logprobs = []
self.dlg_history =[]
batch_size = len(batch['keys'])
logprobs, outs, logprob_z, sample_z = self.model.forward_rl(batch, max_words, temp)
if batch_size == 1:
outs = [outs]
key = batch['keys'][0]
sys_turns = []
# construct the dialog history for printing
for turn_id, turn in enumerate(batch['contexts']):
user_input = self.corpus.id2sent(turn[-1])
self.dlg_history.append(user_input)
sys_output = self.corpus.id2sent(outs[turn_id])
self.dlg_history.append(sys_output)
sys_turns.append(' '.join(sys_output))
for b_id in range(batch_size):
self.logprobs.append(logprob_z[b_id])
# compute reward here
generated_dialog = {key: sys_turns}
return evaluator.evaluateModel(generated_dialog, mode="offline_rl")
This diff is collapsed.
import numpy as np
import logging
class BaseDataLoaders(object):
def __init__(self, name):
self.data_size = None
self.indexes = None
self.name = name
def _shuffle_indexes(self):
np.random.shuffle(self.indexes)
def _shuffle_batch_indexes(self):
np.random.shuffle(self.batch_indexes)
def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False):
self.ptr = 0
self.batch_size = config.batch_size
self.num_batch = self.data_size // config.batch_size
if verbose:
print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch))
if shuffle and not fix_batch:
self._shuffle_indexes()
self.batch_indexes = []
for i in range(self.num_batch):
self.batch_indexes.append(self.indexes[i*self.batch_size: (i+1)*self.batch_size])
if shuffle and fix_batch:
self._shuffle_batch_indexes()
if verbose:
print('%s begins with %d batches' % (self.name, self.num_batch))
def next_batch(self):
if self.ptr < self.num_batch:
selected_ids = self.batch_indexes[self.ptr]
self.ptr += 1
return self._prepare_batch(selected_index=selected_ids)
else:
return None
def _prepare_batch(self, *args, **kwargs):
raise NotImplementedError('Have to override _prepare_batch()')
def pad_to(self, max_len, tokens, do_pad):
if len(tokens) >= max_len:
return tokens[: max_len-1] + [tokens[-1]]
elif do_pad:
return tokens + [0] * (max_len - len(tokens))
else:
return tokens
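# Illustrative sketch (not part of the original file): how pad_to truncates or pads.
# The token values below are arbitrary; truncation keeps the final token (e.g. the EOS id).
def _demo_pad_to():
    loader = BaseDataLoaders('demo')
    assert loader.pad_to(5, [7, 8, 9], do_pad=True) == [7, 8, 9, 0, 0]   # pad with zeros
    assert loader.pad_to(3, [7, 8, 9, 10], do_pad=True) == [7, 8, 10]    # truncate, keep last token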
class LongDataLoader(object):
"""A special efficient data loader for TBPTT. Assume the data contains
N long sequences, each sequence has length k_i
:ivar batch_size: the size of a minibatch
:ivar backward_size: how many steps in time to do BP
:ivar step_size: how fast we move the window
:ivar ptr: the current batch index
:ivar num_batch: the total number of batches
:ivar batch_indexes: a list of list. Each item is the IDs in this batch
:ivar grid_indexes: a list of (b_id, s_id, e_id). b_id is the index of
batch, s_id is the starting time id in that batch and e_id is the ending
time id.
:ivar indexes: a list giving the order of sequence IDs to go through
:ivar data_size: the number of sequences, N.
:ivar data_lens: a list containing k_i
:ivar prev_alive_size:
:ivar name: the name of this data loader
"""
logger = logging.getLogger()
def __init__(self, name):
self.batch_size = 0
self.backward_size = 0
self.step_size = 0
self.ptr = 0
self.num_batch = None
self.batch_indexes = None # one batch is a dialog
self.grid_indexes = None # grid is the tokenized version
self.indexes = None
self.data_lens = None
self.data_size = None
self.name = name
def _shuffle_batch_indexes(self):
np.random.shuffle(self.batch_indexes)
def _shuffle_grid_indexes(self):
np.random.shuffle(self.grid_indexes)
def _prepare_batch(self, cur_grid, prev_grid):
raise NotImplementedError("Have to override prepare batch")
def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False):
assert len(self.indexes) == self.data_size and \
len(self.data_lens) == self.data_size
# make sure backward_size can be divided by step size
assert config.backward_size % config.step_size == 0
self.ptr = 0
self.batch_size = config.batch_size
self.backward_size = config.backward_size
self.step_size = config.step_size
# create batch indexes
temp_num_batch = self.data_size // config.batch_size
self.batch_indexes = []
for i in range(temp_num_batch):
self.batch_indexes.append(
self.indexes[i * self.batch_size:(i + 1) * self.batch_size])
left_over = self.data_size - temp_num_batch * config.batch_size
if shuffle:
self._shuffle_batch_indexes()
# create grid indexes
self.grid_indexes = []
for idx, b_ids in enumerate(self.batch_indexes):
# assume the b_ids are sorted
all_lens = [self.data_lens[i] for i in b_ids]
max_len = self.data_lens[b_ids[0]]
min_len = self.data_lens[b_ids[-1]]
assert np.max(all_lens) == max_len
assert np.min(all_lens) == min_len
num_seg = (max_len - self.backward_size - self.step_size) // self.step_size
cut_start, cut_end = [], []
if num_seg > 1:
cut_start = list(range(config.step_size, num_seg * config.step_size, config.step_size))
cut_end = list(range(config.backward_size + config.step_size,
num_seg * config.step_size + config.backward_size,
config.step_size))
assert cut_end[-1] < max_len
actual_size = min(max_len, config.backward_size)
temp_end = list(range(2, actual_size, config.step_size))
temp_start = [0] * len(temp_end)
cut_start = temp_start + cut_start
cut_end = temp_end + cut_end
assert len(cut_end) == len(cut_start)
new_grids = [(idx, s_id, e_id) for s_id, e_id in
zip(cut_start, cut_end) if s_id < min_len - 1]
self.grid_indexes.extend(new_grids)
# shuffle batch indexes
if shuffle:
self._shuffle_grid_indexes()
self.num_batch = len(self.grid_indexes)
if verbose:
self.logger.info("%s init with %d batches with %d left over samples" %
(self.name, self.num_batch, left_over))
def next_batch(self):
if self.ptr < self.num_batch:
current_grid = self.grid_indexes[self.ptr]
if self.ptr > 0:
prev_grid = self.grid_indexes[self.ptr - 1]
else:
prev_grid = None
self.ptr += 1
return self._prepare_batch(cur_grid=current_grid,
prev_grid=prev_grid)
else:
return None
def pad_to(self, max_len, tokens, do_pad=True):
if len(tokens) >= max_len:
return tokens[0:max_len - 1] + [tokens[-1]]
elif do_pad:
return tokens + [0] * (max_len - len(tokens))
else:
return tokens
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from convlab.policy.lava.multiwoz.latent_dialog.utils import INT, FLOAT, LONG, cast_type
import pdb
class BaseModel(nn.Module):
def __init__(self, config):
super(BaseModel, self).__init__()
self.use_gpu = config.use_gpu
self.config = config
self.kl_w = 0.0
def np2var(self, inputs, dtype):
if inputs is None:
return None
return cast_type(Variable(th.from_numpy(inputs)),
dtype,
self.use_gpu)
def forward(self, *inputs):
raise NotImplementedError
def backward(self, loss, batch_cnt):
total_loss = self.valid_loss(loss, batch_cnt)
total_loss.backward()
def valid_loss(self, loss, batch_cnt=None):
total_loss = 0.0
for k, l in loss.items():
if l is not None:
total_loss += l
return total_loss
def get_optimizer(self, config, verbose=True):
if config.op == 'adam':
if verbose:
print('Use Adam')
return optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=config.init_lr, weight_decay=config.l2_norm)
elif config.op == 'sgd':
print('Use SGD')
return optim.SGD(self.parameters(), lr=config.init_lr, momentum=config.momentum)
elif config.op == 'rmsprop':
print('Use RMSProp')
return optim.RMSprop(self.parameters(), lr=config.init_lr, momentum=config.momentum)
def get_clf_optimizer(self, config):
params = []
params.extend(self.gru_attn_encoder.parameters())
params.extend(self.feat_projecter.parameters())
params.extend(self.sel_classifier.parameters())
if config.fine_tune_op == 'adam':
print('Use Adam')
return optim.Adam(params, lr=config.fine_tune_lr)
elif config.fine_tune_op == 'sgd':
print('Use SGD')
return optim.SGD(params, lr=config.fine_tune_lr, momentum=config.fine_tune_momentum)
elif config.fine_tune_op == 'rmsprop':
print('Use RMSProp')
return optim.RMSprop(params, lr=config.fine_tune_lr, momentum=config.fine_tune_momentum)
def model_sel_loss(self, loss, batch_cnt):
return self.valid_loss(loss, batch_cnt)
def extract_short_ctx(self, context, context_lens, backward_size=1):
utts = []
for b_id in range(context.shape[0]):
utts.append(context[b_id, context_lens[b_id]-1])
return np.array(utts)
def flatten_context(self, context, context_lens, align_right=False):
utts = []
temp_lens = []
for b_id in range(context.shape[0]):
temp = []
for t_id in range(context_lens[b_id]):
for token in context[b_id, t_id]:
if token != 0:
temp.append(token)
temp_lens.append(len(temp))
utts.append(temp)
max_temp_len = np.max(temp_lens)
results = np.zeros((context.shape[0], max_temp_len))
for b_id in range(context.shape[0]):
if align_right:
results[b_id, -temp_lens[b_id]:] = utts[b_id]
else:
results[b_id, 0:temp_lens[b_id]] = utts[b_id]
return results
def frange_cycle_linear(n_iter, start=0.0, stop=1.0, n_cycle=4, ratio=0.5):
L = np.ones(n_iter) * stop
period = n_iter/n_cycle
step = (stop-start)/(period*ratio) # linear schedule
for c in range(n_cycle):
v, i = start, 0
while v <= stop and (int(i+c*period) < n_iter):
L[int(i+c*period)] = v
v += step
i += 1
return L
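# Illustrative sketch (not part of the original file): frange_cycle_linear builds a cyclical
# KL-annealing schedule. With the (assumed) settings below, 1000 training steps are split into
# 4 cycles; within each cycle the weight ramps linearly from 0 to 1 over the first half and is
# held at 1 for the second half.
def _demo_kl_schedule():
    kl_weights = frange_cycle_linear(1000, start=0.0, stop=1.0, n_cycle=4, ratio=0.5)
    # e.g. scale the KL term of the ELBO at training step t by kl_weights[t]
    return kl_weights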
This diff is collapsed.
import math  # needed by GaussianEntropy below
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
import numpy as np
from convlab.policy.lava.multiwoz.latent_dialog import domain
from convlab.policy.lava.multiwoz.latent_dialog.utils import LONG
class NLLEntropy(_Loss):
def __init__(self, padding_idx, avg_type):
super(NLLEntropy, self).__init__()
self.padding_idx = padding_idx
self.avg_type = avg_type
def forward(self, net_output, labels):
batch_size = net_output.size(0)
pred = net_output.view(-1, net_output.size(-1))
target = labels.view(-1)
if self.avg_type is None:
loss = F.nll_loss(pred, target, size_average=False, ignore_index=self.padding_idx)
elif self.avg_type == 'seq':
loss = F.nll_loss(pred, target, size_average=False, ignore_index=self.padding_idx)
loss = loss / batch_size
elif self.avg_type == 'real_word':
loss = F.nll_loss(pred, target, ignore_index=self.padding_idx, reduce=False)
loss = loss.view(-1, net_output.size(1))
loss = th.sum(loss, dim=1)
word_cnt = th.sum(th.sign(labels), dim=1).float()
loss = loss / word_cnt
loss = th.mean(loss)
elif self.avg_type == 'word':
loss = F.nll_loss(pred, target, reduction='mean', ignore_index=self.padding_idx)
else:
raise ValueError('Unknown average type')
return loss
class WeightedNLLEntropy(_Loss):
def __init__(self, padding_idx, avg_type, weight):
super(WeightedNLLEntropy, self).__init__()
self.padding_idx = padding_idx
self.avg_type = avg_type
self.weight = weight
def forward(self, net_output, labels):
batch_size = net_output.size(0)
pred = net_output.view(-1, net_output.size(-1))
target = labels.view(-1)
if self.avg_type == 'slot':
loss = F.nll_loss(pred, target, weight = self.weight, reduction='mean', ignore_index=self.padding_idx)
return loss
class NLLEntropy4CLF(_Loss):
def __init__(self, dictionary, bad_tokens=['<disconnect>', '<disagree>'], reduction='elementwise_mean'):
super(NLLEntropy4CLF, self).__init__()
w = th.Tensor(len(dictionary)).fill_(1)
for token in bad_tokens:
w[dictionary[token]] = 0.0
self.crit = nn.CrossEntropyLoss(w, reduction=reduction)
def forward(self, preds, labels):
# preds: (batch_size, outcome_len, outcome_vocab_size)
# labels: (batch_size, outcome_len)
preds = preds.view(-1, preds.size(-1))
labels = labels.view(-1)
return self.crit(preds, labels)
class CombinedNLLEntropy4CLF(_Loss):
def __init__(self, dictionary, corpus, np2var, bad_tokens=['<disconnect>', '<disagree>']):
super(CombinedNLLEntropy4CLF, self).__init__()
self.dictionary = dictionary
self.domain = domain.get_domain('object_division')
self.corpus = corpus
self.np2var = np2var
self.bad_tokens = bad_tokens
def forward(self, preds, goals_id, outcomes_id):
# preds: (batch_size, outcome_len, outcome_vocab_size)
# goals_id: list of list, id, batch_size*goal_len
# outcomes_id: list of list, id, batch_size*outcome_len
batch_size = len(goals_id)
losses = []
for bth in range(batch_size):
pred = preds[bth] # (outcome_len, outcome_vocab_size)
goal = goals_id[bth] # list, id, len=goal_len
goal_str = self.corpus.id2goal(goal) # list, str, len=goal_len
outcome = outcomes_id[bth] # list, id, len=outcome_len
outcome_str = self.corpus.id2outcome(outcome) # list, str, len=outcome_len
if outcome_str[0] in self.bad_tokens:
continue
# get all the possible choices
choices = self.domain.generate_choices(goal_str)
sel_outs = [pred[i] for i in range(pred.size(0))] # outcome_len*(outcome_vocab_size, )
choices_logits = [] # outcome_len*(option_amount, 1)
for i in range(self.domain.selection_length()):
idxs = np.array([self.dictionary[c[i]] for c in choices])
idxs_var = self.np2var(idxs, LONG) # (option_amount, )
choices_logits.append(th.gather(sel_outs[i], 0, idxs_var).unsqueeze(1))
choice_logit = th.sum(th.cat(choices_logits, 1), 1, keepdim=False) # (option_amount, )
choice_logit = choice_logit.sub(choice_logit.max().item()) # (option_amount, )
prob = F.softmax(choice_logit, dim=0) # (option_amount, )
label = choices.index(outcome_str)
target_prob = prob[label]
losses.append(-th.log(target_prob))
return sum(losses) / float(len(losses))
class CatKLLoss(_Loss):
def __init__(self):
super(CatKLLoss, self).__init__()
def forward(self, log_qy, log_py, batch_size=None, unit_average=False):
"""
qy * log(q(y)/p(y))
"""
qy = th.exp(log_qy)
y_kl = th.sum(qy * (log_qy - log_py), dim=1)
if unit_average:
return th.mean(y_kl)
else:
return th.sum(y_kl)/batch_size
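# Illustrative sketch (not part of the original file): CatKLLoss expects log-probabilities of
# the posterior q(y) and the prior p(y). The helper below shows the intended call pattern with
# arbitrary example distributions.
def _demo_cat_kl():
    log_qy = th.log(th.tensor([[0.7, 0.2, 0.1]]))
    log_py = th.log(th.tensor([[1.0 / 3, 1.0 / 3, 1.0 / 3]]))
    # KL(q || uniform prior), averaged over this batch of one sample
    return CatKLLoss()(log_qy, log_py, unit_average=True)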
class Entropy(_Loss):
def __init__(self):
super(Entropy, self).__init__()
def forward(self, log_qy, batch_size=None, unit_average=False):
"""
-qy log(qy)
"""
if log_qy.dim() > 2:
log_qy = log_qy.squeeze()
qy = th.exp(log_qy)
h_q = th.sum(-1 * log_qy * qy, dim=1)
if unit_average:
return th.mean(h_q)
else:
return th.sum(h_q) / batch_size
class GaussianEntropy(_Loss):
def __init__(self):
super(GaussianEntropy, self).__init__()
def forward(self, mu, logvar):
"""
0.5 * log(2 * pi * var) + 0.5
"""
std = th.exp(0.5 * logvar)
var = th.square(std)
h_q = 0.5 * (th.log(2 * math.pi * var)) + 0.5
return th.mean(h_q)
class BinaryNLLEntropy(_Loss):
def __init__(self, size_average=True):
super(BinaryNLLEntropy, self).__init__()
self.size_average = size_average
def forward(self, net_output, label_output):
"""
:param net_output: batch_size x
:param labels:
:return:
"""
batch_size = net_output.size(0)
loss = F.binary_cross_entropy_with_logits(net_output, label_output, size_average=self.size_average)
if self.size_average is False:
loss /= batch_size
return loss
class NormKLLoss(_Loss):
def __init__(self, unit_average=False):
super(NormKLLoss, self).__init__()
self.unit_average = unit_average
def forward(self, recog_mu, recog_logvar, prior_mu, prior_logvar):
# find the KL divergence between two Gaussian distribution
loss = 1.0 + (recog_logvar - prior_logvar)
loss -= th.div(th.pow(prior_mu - recog_mu, 2), th.exp(prior_logvar))
loss -= th.div(th.exp(recog_logvar), th.exp(prior_logvar))
if self.unit_average:
kl_loss = -0.5 * th.mean(loss, dim=1)
else:
kl_loss = -0.5 * th.sum(loss, dim=1)
avg_kl_loss = th.mean(kl_loss)
return avg_kl_loss
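# Reference sketch (not part of the original file): NormKLLoss above computes the standard
# closed-form KL between two diagonal Gaussians,
# KL(q || p) = 0.5 * sum(log(var_p / var_q) + (var_q + (mu_q - mu_p)^2) / var_p - 1).
# The equivalent direct formulation, with the same batch averaging as unit_average=False:
def _reference_gaussian_kl(recog_mu, recog_logvar, prior_mu, prior_logvar):
    var_q, var_p = th.exp(recog_logvar), th.exp(prior_logvar)
    kl = 0.5 * th.sum(prior_logvar - recog_logvar
                      + (var_q + (recog_mu - prior_mu) ** 2) / var_p - 1.0, dim=1)
    return th.mean(kl)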
import numpy as np
import pdb
from convlab.policy.lava.multiwoz.latent_dialog.utils import Pack
from convlab.policy.lava.multiwoz.latent_dialog.base_data_loaders import BaseDataLoaders, LongDataLoader
from convlab.policy.lava.multiwoz.latent_dialog.corpora import USR, SYS
import json
class BeliefDbDataLoaders(BaseDataLoaders):
def __init__(self, name, data, config):
super(BeliefDbDataLoaders, self).__init__(name)
self.max_utt_len = config.max_utt_len
self.data, self.indexes, self.batch_indexes = self.flatten_dialog(data, config.backward_size)
self.data_size = len(self.data)
self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi']
def flatten_dialog(self, data, backward_size):
results = []
indexes = []
batch_indexes = []
resp_set = set()
for dlg in data:
goal = dlg.goal
key = dlg.key
batch_index = []
for i in range(1, len(dlg.dlg)):
if dlg.dlg[i].speaker == USR:
continue
e_idx = i
s_idx = max(0, e_idx - backward_size)
response = dlg.dlg[i].copy()
response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False)
resp_set.add(json.dumps(response.utt))
context = []
for turn in dlg.dlg[s_idx: e_idx]:
turn['utt'] = self.pad_to(self.max_utt_len, turn.utt, do_pad=False)
context.append(turn)
results.append(Pack(context=context, response=response, goal=goal, key=key))
indexes.append(len(indexes))
batch_index.append(indexes[-1])
if len(batch_index) > 0:
batch_indexes.append(batch_index)
print("Unique resp {}".format(len(resp_set)))
return results, indexes, batch_indexes
def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False):
self.ptr = 0
if fix_batch:
self.batch_size = None
self.num_batch = len(self.batch_indexes)
else:
self.batch_size = config.batch_size
self.num_batch = self.data_size // config.batch_size
self.batch_indexes = []
for i in range(self.num_batch):
self.batch_indexes.append(self.indexes[i * self.batch_size: (i + 1) * self.batch_size])
if verbose:
print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch))
if shuffle:
if fix_batch:
self._shuffle_batch_indexes()
else:
self._shuffle_indexes()
if verbose:
print('%s begins with %d batches' % (self.name, self.num_batch))
def _prepare_batch(self, selected_index):
rows = [self.data[idx] for idx in selected_index]
ctx_utts, ctx_lens = [], []
out_utts, out_lens = [], []
out_bs, out_db = [] , []
goals, goal_lens = [], [[] for _ in range(len(self.domains))]
keys = []
for row in rows:
in_row, out_row, goal_row = row.context, row.response, row.goal
# source context
keys.append(row.key)
batch_ctx = []
for turn in in_row:
batch_ctx.append(self.pad_to(self.max_utt_len, turn.utt, do_pad=True))
ctx_utts.append(batch_ctx)
ctx_lens.append(len(batch_ctx))
# target response
out_utt = [t for idx, t in enumerate(out_row.utt)]
out_utts.append(out_utt)
out_lens.append(len(out_utt))
out_bs.append(out_row.bs)
out_db.append(out_row.db)
# goal
goals.append(goal_row)
for i, d in enumerate(self.domains):
goal_lens[i].append(len(goal_row[d]))
batch_size = len(ctx_lens)
vec_ctx_lens = np.array(ctx_lens) # (batch_size, ), number of turns
max_ctx_len = np.max(vec_ctx_lens)
vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32)
vec_out_bs = np.array(out_bs) # (batch_size, 94)
vec_out_db = np.array(out_db) # (batch_size, 30)
vec_out_lens = np.array(out_lens) # (batch_size, ), number of tokens
max_out_len = np.max(vec_out_lens)
vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32)
max_goal_lens, min_goal_lens = [max(ls) for ls in goal_lens], [min(ls) for ls in goal_lens]
if max_goal_lens != min_goal_lens:
print('Fatal Error!')
exit(-1)
self.goal_lens = max_goal_lens
vec_goals_list = [np.zeros((batch_size, l), dtype=np.float32) for l in self.goal_lens]
for b_id in range(batch_size):
vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
for i, d in enumerate(self.domains):
vec_goals_list[i][b_id, :] = goals[b_id][d]
return Pack(context_lens=vec_ctx_lens, # (batch_size, )
contexts=vec_ctx_utts, # (batch_size, max_ctx_len, max_utt_len)
output_lens=vec_out_lens, # (batch_size, )
outputs=vec_out_utts, # (batch_size, max_out_len)
bs=vec_out_bs, # (batch_size, 94)
db=vec_out_db, # (batch_size, 30)
goals_list=vec_goals_list, # 7*(batch_size, bow_len), bow_len differs w.r.t. domain
keys=keys)
class BeliefDbDataLoadersAE(BaseDataLoaders):
def __init__(self, name, data, config):
super(BeliefDbDataLoadersAE, self).__init__(name)
self.max_utt_len = config.max_utt_len
self.data, self.indexes, self.batch_indexes = self.flatten_dialog(data, config.backward_size)
self.data_size = len(self.data)
self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi']
self.act_types = ['bye', 'inform', 'nobook', 'nooffer', 'offerbook', 'offerbooked', 'recommend', 'reqmore', 'request', 'select', 'welcome']
if "ae_zero_pad" in config.keys():
self.zero_pad = config.ae_zero_pad
else:
self.zero_pad = False
def flatten_dialog(self, data, backward_size):
results = []
indexes = []
batch_indexes = []
resp_set = set()
for dlg in data:
goal = dlg.goal
key = dlg.key
batch_index = []
for i in range(1, len(dlg.dlg)):
if dlg.dlg[i].speaker == USR:
continue
e_idx = i
s_idx = max(0, e_idx - backward_size)
response = dlg.dlg[i].copy()
response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False)
resp_set.add(json.dumps(response.utt))
context = []
for turn in dlg.dlg[s_idx: e_idx]:
turn['utt'] = self.pad_to(self.max_utt_len, turn.utt, do_pad=False)
context.append(turn)
results.append(Pack(context=context, response=response, goal=goal, key=key))
indexes.append(len(indexes))
batch_index.append(indexes[-1])
if len(batch_index) > 0:
batch_indexes.append(batch_index)
print("Unique resp {}".format(len(resp_set)))
return results, indexes, batch_indexes
def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False):
self.ptr = 0
if fix_batch:
self.batch_size = None
self.num_batch = len(self.batch_indexes)
else:
self.batch_size = config.batch_size
self.num_batch = self.data_size // config.batch_size
self.batch_indexes = []
for i in range(self.num_batch):
self.batch_indexes.append(self.indexes[i * self.batch_size: (i + 1) * self.batch_size])
if verbose:
print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch))
if shuffle:
if fix_batch:
self._shuffle_batch_indexes()
else:
self._shuffle_indexes()
if verbose:
print('%s begins with %d batches' % (self.name, self.num_batch))
def _prepare_batch(self, selected_index):
rows = [self.data[idx] for idx in selected_index]
ctx_utts, ctx_lens = [], []
out_utts, out_lens = [], []
out_act = []
out_bs, out_db = [] , []
goals, goal_lens = [], [[] for _ in range(len(self.domains))]
keys = []
for row in rows:
in_row, out_row, goal_row = row.context, row.response, row.goal
# source context
keys.append(row.key)
# batch_ctx = []
# for turn in in_row:
# batch_ctx.append(self.pad_to(self.max_utt_len, turn.utt, do_pad=True))
# ctx_utts.append(batch_ctx)
# ctx_lens.append(len(batch_ctx))
# for AE, input = output
batch_ctx = []
batch_ctx = self.pad_to(self.max_utt_len, out_row.utt, do_pad=True)
# batch_ctx = [t for idx, t in enumerate(out_row.utt)]
ctx_utts.append(batch_ctx)
ctx_lens.append(len(batch_ctx))
# target response
out_utt = [t for idx, t in enumerate(out_row.utt)]
out_utts.append(out_utt)
out_lens.append(len(out_utt))
if not self.zero_pad:
out_bs.append(out_row.bs)
out_db.append(out_row.db)
else:
out_bs.append([0] * 94)
out_db.append([0] * 30)
out_act.append(out_row.act)
# goal
goals.append(goal_row)
for i, d in enumerate(self.domains):
goal_lens[i].append(len(goal_row[d]))
batch_size = len(ctx_lens)
vec_ctx_lens = np.array(ctx_lens) # (batch_size, ), number of turns
max_ctx_len = np.max(vec_ctx_lens)
vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32)
vec_out_bs = np.array(out_bs) # (batch_size, 94)
vec_out_db = np.array(out_db) # (batch_size, 30)
vec_out_act = np.array(out_act) # (batch_size, 11)
vec_out_lens = np.array(out_lens) # (batch_size, ), number of tokens
max_out_len = np.max(vec_out_lens)
vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32)
max_goal_lens, min_goal_lens = [max(ls) for ls in goal_lens], [min(ls) for ls in goal_lens]
if max_goal_lens != min_goal_lens:
print('Fatal Error!')
exit(-1)
self.goal_lens = max_goal_lens
vec_goals_list = [np.zeros((batch_size, l), dtype=np.float32) for l in self.goal_lens]
for b_id in range(batch_size):
vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
for i, d in enumerate(self.domains):
vec_goals_list[i][b_id, :] = goals[b_id][d]
return Pack(context_lens=vec_ctx_lens, # (batch_size, )
contexts=vec_ctx_utts, # (batch_size, max_ctx_len, max_utt_len)
output_lens=vec_out_lens, # (batch_size, )
outputs=vec_out_utts, # (batch_size, max_out_len)
bs=vec_out_bs, # (batch_size, 94)
db=vec_out_db, # (batch_size, 30)
act=vec_out_act, #(batch_size, 11)
goals_list=vec_goals_list, # 7*(batch_size, bow_len), bow_len differs w.r.t. domain
keys=keys)
from convlab.policy.lava.multiwoz.latent_dialog.metric import MetricsContainer
from convlab.policy.lava.multiwoz.latent_dialog.corpora import EOD, EOS
from convlab.policy.lava.multiwoz.latent_dialog import evaluators
class Dialog(object):
"""Dialogue runner."""
def __init__(self, agents, args):
assert len(agents) == 2
self.agents = agents
self.system, self.user = agents
self.args = args
self.metrics = MetricsContainer()
self.dlg_evaluator = evaluators.MultiWozEvaluator('SYS_WOZ')
self._register_metrics()
def _register_metrics(self):
"""Registers valuable metrics."""
self.metrics.register_average('dialog_len')
self.metrics.register_average('sent_len')
self.metrics.register_average('reward')
self.metrics.register_time('time')
def _is_eod(self, out):
return len(out) == 2 and out[0] == EOD and out[1] == EOS
def _eval_dialog(self, conv, g_key, goal):
generated_dialog = dict()
generated_dialog[g_key] = {'goal': goal, 'log': list()}
for t_id, (name, utt) in enumerate(conv):
# assert utt[-1] == EOS, utt
if t_id % 2 == 0:
assert name == 'Baozi'
utt = ' '.join(utt[:-1])
if utt == EOD:
continue
generated_dialog[g_key]['log'].append({'text': utt})
report, success_r, match_r = self.dlg_evaluator.evaluateModel(generated_dialog, mode='rollout')
return success_r + match_r
def show_metrics(self):
return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])
def run(self, g_key, goal):
"""Runs one instance of the dialogue."""
# initialize agents by feeding in the goal
# initialize BOD utterance for each agent
for agent in self.agents:
agent.feed_goal(goal)
agent.bod_init()
# role assignment
reader, writer = self.system, self.user
begin_name = writer.name
print('begin_name = {}'.format(begin_name))
conv = []
# reset metrics
self.metrics.reset()
nturn = 0
while True:
nturn += 1
# produce an utterance
out_words = writer.write() # out: list of word, str, len = max_words
print('\t{} out_words = {}'.format(writer.name, ' '.join(out_words)))
self.metrics.record('sent_len', len(out_words))
# self.metrics.record('%s_unique' % writer.name, out_words)
# append the utterance to the conversation
conv.append((writer.name, out_words))
# make the other agent read it
reader.read(out_words)
# check if the end of the conversation was generated
if self._is_eod(out_words):
break
if self.args.max_nego_turn > 0 and nturn >= self.args.max_nego_turn:
# return conv, 0
break
writer, reader = reader, writer
# evaluate dialog and produce success
reward = self._eval_dialog(conv, g_key, goal)
print('Reward = {}'.format(reward))
# perform update
self.system.update(reward)
self.metrics.record('time')
self.metrics.record('dialog_len', len(conv))
self.metrics.record('reward', int(reward))
print('='*50)
print(self.show_metrics())
print('='*50)
return conv, reward
class DialogEval(Dialog):
def run(self, g_key, goal):
"""Runs one instance of the dialogue."""
# initialize agents by feeding in the goal
# initialize BOD utterance for each agent
for agent in self.agents:
agent.feed_goal(goal)
agent.bod_init()
# role assignment
reader, writer = self.system, self.user
conv = []
nturn = 0
while True:
nturn += 1
# produce an utterance
out_words = writer.write() # out: list of word, str, len = max_words
conv.append((writer.name, out_words))
# make the other agent read it
reader.read(out_words)
# check if the end of the conversation was generated
if self._is_eod(out_words):
break
writer, reader = reader, writer
if self.args.max_nego_turn > 0 and nturn >= self.args.max_nego_turn:
return conv, 0
# evaluate dialog and produce success
reward = self._eval_dialog(conv, g_key, goal)
return conv, reward
import re
import random
import json
def get_domain(name):
if name == 'object_division':
return ObjectDivisionDomain()
raise ValueError('unknown domain: {}'.format(name))
class ObjectDivisionDomain(object):
def __init__(self):
self.item_pattern = re.compile(r'^item([0-9])=([0-9\-])+$')
def input_length(self):
return 3
def selection_length(self):
return 6
def generate_choices(self, inpt):
cnts, _ = self.parse_context(inpt)
def gen(cnts, idx=0, choice=[]):
if idx >= len(cnts):
left_choice = ['item%d=%d' % (i, c) for i, c in enumerate(choice)]
right_choice = ['item%d=%d' % (i, n - c) for i, (n, c) in enumerate(zip(cnts, choice))]
return [left_choice + right_choice]
choices = []
for c in range(cnts[idx] + 1):
choice.append(c)
choices += gen(cnts, idx + 1, choice)
choice.pop()
return choices
choices = gen(cnts)
choices.append(['<no_agreement>'] * self.selection_length())
choices.append(['<disconnect>'] * self.selection_length())
return choices
def parse_context(self, ctx):
cnts = [int(n) for n in ctx[0::2]]
vals = [int(v) for v in ctx[1::2]]
return cnts, vals
def _to_int(self, x):
try:
return int(x)
except:
return 0
def score_choices(self, choices, ctxs):
assert len(choices) == len(ctxs)
# print('choices = {}'.format(choices))
# print('ctxs = {}'.format(ctxs))
cnts = [int(x) for x in ctxs[0][0::2]]
agree, scores = True, [0 for _ in range(len(ctxs))]
for i, n in enumerate(cnts):
for agent_id, (choice, ctx) in enumerate(zip(choices, ctxs)):
# taken = self._to_int(choice[i+3][-1])
taken = self._to_int(choice[i][-1])
n -= taken
scores[agent_id] += int(ctx[2 * i + 1]) * taken
agree = agree and (n == 0)
return agree, scores
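# Illustrative sketch (not part of the original file): for a context with item counts 1, 2, 1
# (and arbitrary values 4, 1, 2), generate_choices enumerates every split of the items plus the
# two special outcomes, i.e. (1+1)*(2+1)*(1+1) + 2 = 14 candidate selections.
def _demo_generate_choices():
    dom = ObjectDivisionDomain()
    choices = dom.generate_choices(['1', '4', '2', '1', '1', '2'])
    assert len(choices) == 14
    return choices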
class ContextGenerator(object):
"""Dialogue context generator. Generates contexes from the file."""
def __init__(self, context_file):
self.ctxs = []
with open(context_file, 'r') as f:
ctx_pair = []
for line in f:
ctx = line.strip().split()
ctx_pair.append(ctx)
if len(ctx_pair) == 2:
self.ctxs.append(ctx_pair)
ctx_pair = []
def sample(self):
return random.choice(self.ctxs)
def iter(self, nepoch=1):
for e in range(nepoch):
random.shuffle(self.ctxs)
for ctx in self.ctxs:
yield ctx
def total_size(self, nepoch):
return nepoch*len(self.ctxs)
class ContextGeneratorEval(object):
"""Dialogue context generator. Generates contexes from the file."""
def __init__(self, context_file):
self.ctxs = []
with open(context_file, 'r') as f:
ctx_pair = []
for line in f:
ctx = line.strip().split()
ctx_pair.append(ctx)
if len(ctx_pair) == 2:
self.ctxs.append(ctx_pair)
ctx_pair = []
class TaskGoalGenerator(object):
def __init__(self, goal_file):
self.goals = []
data = json.load(open(goal_file))
for key, raw_dlg in data.items():
self.goals.append((key, raw_dlg['goal']))
def sample(self):
return random.choice(self.goals)
def iter(self, nepoch=1):
for e in range(nepoch):
random.shuffle(self.goals)
for goal in self.goals:
yield goal
# -*- coding: utf-8 -*-
# Author: Tiancheng Zhao
# Date: 9/15/18
import torch as th
import torch.nn as nn
import numpy as np
from torch.nn.modules.module import _addindent
def summary(model, show_weights=True, show_parameters=True):
"""
Summarizes torch model by showing trainable parameters and weights.
"""
tmpstr = model.__class__.__name__ + ' (\n'
total_params = 0
for key, module in model._modules.items():
# if it contains layers let call it recursively to get params
# and weights
if type(module) in [
th.nn.modules.container.Container,
th.nn.modules.container.Sequential
]:
modstr = summary(module)
else:
modstr = module.__repr__()
modstr = _addindent(modstr, 2)
params = sum([np.prod(p.size()) for p in module.parameters()])
weights = tuple([tuple(p.size()) for p in module.parameters()])
total_params += params
tmpstr += ' (' + key + '): ' + modstr
if show_weights:
tmpstr += ', weights={}'.format(weights)
if show_parameters:
tmpstr += ', parameters={}'.format(params)
tmpstr += '\n'
tmpstr = tmpstr + ') Total Parameters={}'.format(total_params)
return tmpstr
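# Illustrative sketch (not part of the original file): summary() walks model._modules and
# reports per-module parameter counts. Any nn.Module works; the toy network is arbitrary.
def _demo_summary():
    net = nn.Sequential(nn.Linear(10, 20), nn.Tanh(), nn.Linear(20, 5))
    return summary(net, show_weights=False)  # string listing each layer and its parameter count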
class BaseRNN(nn.Module):
KEY_ATTN_SCORE = 'attention_score'
KEY_SEQUENCE = 'sequence'
def __init__(self, input_dropout_p, rnn_cell,
input_size, hidden_size, num_layers,
output_dropout_p, bidirectional):
super(BaseRNN, self).__init__()
self.input_dropout = nn.Dropout(p=input_dropout_p)
if rnn_cell.lower() == 'lstm':
self.rnn_cell = nn.LSTM
elif rnn_cell.lower() == 'gru':
self.rnn_cell = nn.GRU
else:
raise ValueError('Unsupported RNN Cell Type: {0}'.format(rnn_cell))
self.rnn = self.rnn_cell(input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
dropout=output_dropout_p,
bidirectional=bidirectional)
# Initialize the LSTM forget-gate biases to 1: PyTorch packs gate biases as [input, forget, cell, output], so the n//4:n//2 slice addresses the forget gate
if rnn_cell.lower() == 'lstm':
for names in self.rnn._all_weights:
for name in filter(lambda n: 'bias' in n, names):
bias = getattr(self.rnn, name)
n = bias.size(0)
start, end = n // 4, n // 2
bias.data[start:end].fill_(1.)
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from convlab.policy.lava.multiwoz.latent_dialog.enc2dec.base_modules import BaseRNN
class EncoderGRUATTN(BaseRNN):
def __init__(self, input_dropout_p, rnn_cell, input_size, hidden_size, num_layers, output_dropout_p, bidirectional, variable_lengths):
super(EncoderGRUATTN, self).__init__(input_dropout_p=input_dropout_p,
rnn_cell=rnn_cell,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
output_dropout_p=output_dropout_p,
bidirectional=bidirectional)
self.variable_lengths = variable_lengths
self.nhid_attn = hidden_size
self.output_size = hidden_size*2 if bidirectional else hidden_size
# attention to combine selection hidden states
self.attn = nn.Sequential(
nn.Linear(2 * hidden_size, hidden_size),
nn.Tanh(),
nn.Linear(hidden_size, 1)
)
def forward(self, residual_var, input_var, turn_feat, mask=None, init_state=None, input_lengths=None):
# residual_var: (batch_size, max_dlg_len, 2*utt_cell_size)
# input_var: (batch_size, max_dlg_len, dlg_cell_size)
# TODO switch of mask
# mask = None
require_embed = True
if require_embed:
# input_cat = th.cat([input_var, residual_var], 2) # (batch_size, max_dlg_len, dlg_cell_size+2*utt_cell_size)
input_cat = th.cat([input_var, residual_var, turn_feat], 2) # (batch_size, max_dlg_len, dlg_cell_size+2*utt_cell_size)
else:
# input_cat = th.cat([input_var], 2)
input_cat = th.cat([input_var, turn_feat], 2)
if mask is not None:
input_mask = mask.view(input_cat.size(0), input_cat.size(1), 1) # (batch_size, max_dlg_len*max_utt_len, 1)
input_cat = th.mul(input_cat, input_mask)
embedded = self.input_dropout(input_cat)
require_rnn = True
if require_rnn:
if init_state is not None:
h, _ = self.rnn(embedded, init_state)
else:
h, _ = self.rnn(embedded) # (batch_size, max_dlg_len, 2*nhid_attn)
logit = self.attn(h.contiguous().view(-1, 2*self.nhid_attn)).view(h.size(0), h.size(1)) # (batch_size, max_dlg_len)
# if mask is not None:
# logit_mask = mask.view(input_cat.size(0), input_cat.size(1))
# logit_mask = -999.0 * logit_mask
# logit = logit_mask + logit
prob = F.softmax(logit, dim=1).unsqueeze(2).expand_as(h) # (batch_size, max_dlg_len, 2*nhid_attn)
attn = th.sum(th.mul(h, prob), 1) # (batch_size, 2*nhid_attn)
return attn
else:
logit = self.attn(embedded.contiguous().view(input_cat.size(0)*input_cat.size(1), -1)).view(input_cat.size(0), input_cat.size(1))
if mask is not None:
logit_mask = mask.view(input_cat.size(0), input_cat.size(1))
logit_mask = -999.0 * logit_mask
logit = logit_mask + logit
prob = F.softmax(logit, dim=1).unsqueeze(2).expand_as(embedded) # (batch_size, max_dlg_len, 2*nhid_attn)
attn = th.sum(th.mul(embedded, prob), 1) # (batch_size, 2*nhid_attn)
return attn
class FeatureProjecter(nn.Module):
def __init__(self, input_dropout_p, input_size, output_size):
super(FeatureProjecter, self).__init__()
self.input_dropout = nn.Dropout(p=input_dropout_p)
self.sel_encoder = nn.Sequential(
nn.Linear(input_size, output_size),
nn.Tanh()
)
def forward(self, goals_h, attn_outs):
h = th.cat([attn_outs, goals_h], 1) # (batch_size, 2*nhid_attn+goal_nhid)
h = self.input_dropout(h)
h = self.sel_encoder.forward(h) # (batch_size, nhid_sel)
return h
class SelectionClassifier(nn.Module):
def __init__(self, selection_length, input_size, output_size):
super(SelectionClassifier, self).__init__()
self.sel_decoders = nn.ModuleList()
for _ in range(selection_length):
self.sel_decoders.append(nn.Linear(input_size, output_size))
def forward(self, proj_outs):
outs = [decoder.forward(proj_outs).unsqueeze(1) for decoder in self.sel_decoders] # outcome_len*(batch_size, 1, outcome_vocab_size)
outs = th.cat(outs, 1) # (batch_size, outcome_len, outcome_vocab_size)
return outs
This diff is collapsed.
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from convlab.policy.lava.multiwoz.latent_dialog.enc2dec.base_modules import BaseRNN
class EncoderRNN(BaseRNN):
def __init__(self, input_dropout_p, rnn_cell, input_size, hidden_size, num_layers, output_dropout_p, bidirectional, variable_lengths):
super(EncoderRNN, self).__init__(input_dropout_p=input_dropout_p,
rnn_cell=rnn_cell,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
output_dropout_p=output_dropout_p,
bidirectional=bidirectional)
self.variable_lengths = variable_lengths
self.output_size = hidden_size*2 if bidirectional else hidden_size
def forward(self, input_var, init_state=None, input_lengths=None, goals=None):
# add goals
if goals is not None:
batch_size, max_ctx_len, ctx_nhid = input_var.size()
goals = goals.view(goals.size(0), 1, goals.size(1))
goals_rep = goals.repeat(1, max_ctx_len, 1).view(batch_size, max_ctx_len, -1) # (batch_size, max_ctx_len, goal_nhid)
input_var = th.cat([input_var, goals_rep], dim=2)
embedded = self.input_dropout(input_var)
if self.variable_lengths:
embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths,
batch_first=True)
if init_state is not None:
output, hidden = self.rnn(embedded, init_state)
else:
output, hidden = self.rnn(embedded)
if self.variable_lengths:
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
return output, hidden
class RnnUttEncoder(nn.Module):
def __init__(self, vocab_size, embedding_dim, feat_size, goal_nhid, rnn_cell,
utt_cell_size, num_layers, input_dropout_p, output_dropout_p,
bidirectional, variable_lengths, use_attn, embedding=None):
super(RnnUttEncoder, self).__init__()
if embedding is None:
self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
else:
self.embedding = embedding
self.rnn = EncoderRNN(input_dropout_p=input_dropout_p,
rnn_cell=rnn_cell,
input_size=embedding_dim+feat_size+goal_nhid,
hidden_size=utt_cell_size,
num_layers=num_layers,
output_dropout_p=output_dropout_p,
bidirectional=bidirectional,
variable_lengths=variable_lengths)
self.utt_cell_size = utt_cell_size
self.multiplier = 2 if bidirectional else 1
self.output_size = self.multiplier * self.utt_cell_size
self.use_attn = use_attn
if self.use_attn:
self.key_w = nn.Linear(self.output_size, self.utt_cell_size)
self.query = nn.Linear(self.utt_cell_size, 1)
def forward(self, utterances, feats=None, init_state=None, goals=None):
batch_size, max_ctx_len, max_utt_len = utterances.size()
# get word embeddings
flat_words = utterances.view(-1, max_utt_len) # (batch_size*max_ctx_len, max_utt_len)
word_embeddings = self.embedding(flat_words) # (batch_size*max_ctx_len, max_utt_len, embedding_dim)
flat_mask = th.sign(flat_words).float()
# add features
if feats is not None:
flat_feats = feats.view(-1, 1) # (batch_size*max_ctx_len, 1)
flat_feats = flat_feats.unsqueeze(1).repeat(1, max_utt_len, 1) # (batch_size*max_ctx_len, max_utt_len, 1)
word_embeddings = th.cat([word_embeddings, flat_feats], dim=2) # (batch_size*max_ctx_len, max_utt_len, embedding_dim+1)
# add goals
if goals is not None:
goals = goals.view(goals.size(0), 1, 1, goals.size(1))
goals_rep = goals.repeat(1, max_ctx_len, max_utt_len, 1).view(batch_size*max_ctx_len, max_utt_len, -1) # (batch_size*max_ctx_len, max_utt_len, goal_nhid)
word_embeddings = th.cat([word_embeddings, goals_rep], dim=2)
# enc_outs: (batch_size*max_ctx_len, max_utt_len, num_directions*utt_cell_size)
# enc_last: (num_layers*num_directions, batch_size*max_ctx_len, utt_cell_size)
enc_outs, enc_last = self.rnn(word_embeddings, init_state=init_state)
if self.use_attn:
fc1 = th.tanh(self.key_w(enc_outs)) # (batch_size*max_ctx_len, max_utt_len, utt_cell_size)
attn = self.query(fc1).squeeze(2)
# (batch_size*max_ctx_len, max_utt_len)
attn = F.softmax(attn, attn.dim()-1) # (batch_size*max_ctx_len, max_utt_len, 1)
attn = attn * flat_mask
attn = (attn / (th.sum(attn, dim=1, keepdim=True)+1e-10)).unsqueeze(2)
utt_embedded = attn * enc_outs # (batch_size*max_ctx_len, max_utt_len, num_directions*utt_cell_size)
utt_embedded = th.sum(utt_embedded, dim=1) # (batch_size*max_ctx_len, num_directions*utt_cell_size)
else:
# FIXME bug for multi-layer
attn = None
utt_embedded = enc_last.transpose(0, 1).contiguous() # (batch_size*max_ctx_lens, num_layers*num_directions, utt_cell_size)
utt_embedded = utt_embedded.view(-1, self.output_size) # (batch_size*max_ctx_len*num_layers, num_directions*utt_cell_size)
utt_embedded = utt_embedded.view(batch_size, max_ctx_len, self.output_size)
return utt_embedded, word_embeddings.contiguous().view(batch_size, max_ctx_len*max_utt_len, -1), \
enc_outs.contiguous().view(batch_size, max_ctx_len*max_utt_len, -1)
class MlpGoalEncoder(nn.Module):
def __init__(self, goal_vocab_size, k, nembed, nhid, init_range):
super(MlpGoalEncoder, self).__init__()
# create separate embedding for counts and values
self.cnt_enc = nn.Embedding(goal_vocab_size, nembed)
self.val_enc = nn.Embedding(goal_vocab_size, nembed)
self.encoder = nn.Sequential(
nn.Tanh(),
nn.Linear(k*nembed, nhid)
)
self.cnt_enc.weight.data.uniform_(-init_range, init_range)
self.val_enc.weight.data.uniform_(-init_range, init_range)
self._init_cont(self.encoder, init_range)
def _init_cont(self, cont, init_range):
"""initializes a container uniformly."""
for m in cont:
if hasattr(m, 'weight'):
m.weight.data.uniform_(-init_range, init_range)
if hasattr(m, 'bias'):
m.bias.data.fill_(0)
def forward(self, goal):
# goal: (batch_size, goal_len)
goal = goal.transpose(0, 1).contiguous() # (goal_len, batch_size)
idx = np.arange(goal.size(0) // 2)
# extract counts and values
cnt_idx = Variable(th.from_numpy(2 * idx + 0))
val_idx = Variable(th.from_numpy(2 * idx + 1))
if goal.is_cuda:
cnt_idx = cnt_idx.type(th.cuda.LongTensor)
val_idx = val_idx.type(th.cuda.LongTensor)
else:
cnt_idx = cnt_idx.type(th.LongTensor)
val_idx = val_idx.type(th.LongTensor)
cnt = goal.index_select(0, cnt_idx) # (3, batch_size)
val = goal.index_select(0, val_idx) # (3, batch_size)
# embed counts and values
cnt_emb = self.cnt_enc(cnt) # (3, batch_size, nembed)
val_emb = self.val_enc(val) # (3, batch_size, nembed)
# element wise multiplication to get a hidden state
h = th.mul(cnt_emb, val_emb) # (3, batch_size, nembed)
# run the hidden state through the MLP
h = h.transpose(0, 1).contiguous().view(goal.size(1), -1) # (batch_size, 3*nembed)
goal_h = self.encoder(h) # (batch_size, nhid)
return goal_h
class TaskMlpGoalEncoder(nn.Module):
def __init__(self, goal_vocab_sizes, nhid, init_range):
super(TaskMlpGoalEncoder, self).__init__()
self.encoder = nn.ModuleList()
for v_size in goal_vocab_sizes:
domain_encoder = nn.Sequential(
nn.Linear(v_size, nhid),
nn.Tanh()
)
self._init_cont(domain_encoder, init_range)
self.encoder.append(domain_encoder)
def _init_cont(self, cont, init_range):
"""initializes a container uniformly."""
for m in cont:
if hasattr(m, 'weight'):
m.weight.data.uniform_(-init_range, init_range)
if hasattr(m, 'bias'):
m.bias.data.fill_(0)
def forward(self, goals_list):
# goals_list: list of tensor, 7*(batch_size, goal_len), goal_len varies among different domains
outs = [encoder.forward(goal) for goal, encoder in zip(goals_list, self.encoder)] # 7*(batch_size, goal_nhid)
outs = th.sum(th.stack(outs), dim=0) # (batch_size, goal_nhid)
return outs
class SelfAttn(nn.Module):
def __init__(self, hidden_size):
super(SelfAttn, self).__init__()
self.query = nn.Linear(hidden_size, 1)
def forward(self, keys, values, attn_mask=None):
"""
:param keys: batch_size x time_len x hidden_size
:param values: batch_size x time_len x hidden_size
:param attn_mask: batch_size x time_len
:return: summary state
"""
alpha = F.softmax(self.query(keys), dim=1)
if attn_mask is not None:
alpha = alpha * attn_mask.unsqueeze(2)
alpha = alpha / th.sum(alpha, dim=1, keepdim=True)
summary = th.sum(values * alpha, dim=1)
return summary
This diff is collapsed.