Skip to content
Snippets Groups Projects
Commit 25b2ee25 authored by Carel van Niekerk's avatar Carel van Niekerk :computer:
Browse files

Setsumbt

parent f6519348
No related branches found
No related tags found
No related merge requests found
Showing
with 462 additions and 149 deletions
......@@ -66,7 +66,8 @@ convlab/nlu/jointBERT_new/**/output/
convlab/nlu/milu/09*
convlab/nlu/jointBERT/multiwoz/configs/multiwoz_new_usr_context.json
convlab/nlu/milu/multiwoz/configs/system_without_context.jsonnet
convlab/nlu/milu/multiwoz/configs/user_without_context.jsonnet
convlab/nlu/milu/multiwoz/configs/user_without_context.jsonnet
*.pkl
# test script
*_test.py
......@@ -87,7 +88,6 @@ dist
convlab.egg-info
# configs
*experiment*
*pretrained_models*
.ipynb_checkpoints
......
import json
import os
from pprint import pprint
import numpy as np
def load_results(predict_results):
    """
    Load and concatenate predictions from a comma separated list of JSON files.

    Args:
        predict_results (str): Comma separated list of paths to JSON prediction
            files. Paths that do not exist are silently skipped.

    Returns:
        list: Concatenation of the prediction lists loaded from each file.
    """
    files = [file.strip() for file in predict_results.split(',')]
    files = [file for file in files if os.path.isfile(file)]

    predictions = []
    for file in files:
        # Context manager guarantees the handle is closed even if json.load raises
        with open(file, 'r') as reader:
            predictions += json.load(reader)

    return predictions
def evaluate(predict_result):
predict_result = json.load(open(predict_result))
predict_result = load_results(predict_result)
metrics = {'TP': 0, 'FP': 0, 'FN': 0}
jga = []
......
......@@ -383,3 +383,25 @@ def change_batch_size(loader: DataLoader, batch_size: int) -> DataLoader:
loader = DataLoader(loader.dataset, sampler=sampler, batch_size=batch_size)
return loader
def dataloader_sample_dialogues(loader: DataLoader, sample_size: int) -> DataLoader:
    """
    Sample a subset of the dialogues in a dataloader.

    Args:
        loader (DataLoader): Dataloader to train and evaluate the setsumbt model
        sample_size (int): Number of dialogues to sample

    Returns:
        loader (DataLoader): Dataloader containing only the sampled dialogues
    """
    # NOTE(review): assumes dataset.resample returns a new resampled dataset — confirm
    dataset = loader.dataset.resample(sample_size)

    # DataLoader forbids reassigning its `dataset` attribute after construction
    # (raises ValueError), so build a fresh loader around the resampled dataset
    # instead of mutating the existing one in place.
    if isinstance(loader.sampler, SequentialSampler):
        sampler = SequentialSampler(dataset)
    else:
        sampler = RandomSampler(dataset)

    return DataLoader(dataset, sampler=sampler, batch_size=loader.batch_size)
......@@ -23,42 +23,13 @@ import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from convlab.dst.setsumbt.dataset.unified_format import UnifiedFormatDataset
from convlab.dst.setsumbt.dataset.unified_format import UnifiedFormatDataset, change_batch_size
from convlab.dst.setsumbt.modeling import EnsembleSetSUMBT
from convlab.dst.setsumbt.modeling import training
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def main():
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--model_path', type=str)
parser.add_argument('--model_type', type=str)
parser.add_argument('--set_type', type=str)
parser.add_argument('--batch_size', type=int)
parser.add_argument('--reduction', type=str, default='mean')
parser.add_argument('--get_ensemble_distributions', action='store_true')
parser.add_argument('--build_dataloaders', action='store_true')
args = parser.parse_args()
if args.get_ensemble_distributions:
get_ensemble_distributions(args)
elif args.build_dataloaders:
path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
data = torch.load(path)
reader = open(os.path.join(args.model_path, 'database', f'{args.set_type}.json'), 'r')
ontology = json.load(reader)
reader.close()
loader = get_loader(data, ontology, args.set_type, args.batch_size)
path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
torch.save(loader, path)
else:
raise NameError("NotImplemented")
def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size: int = 3) -> DataLoader:
"""
Build dataloader from ensemble prediction data
......@@ -145,6 +116,9 @@ def get_ensemble_distributions(args):
dataloader = torch.load(dataloader)
database = torch.load(database)
if dataloader.batch_size != args.batch_size:
dataloader = change_batch_size(dataloader, args.batch_size)
training.set_ontology_embeddings(model, database)
print('Environment set up.')
......@@ -185,9 +159,9 @@ def get_ensemble_distributions(args):
if model.config.predict_actions:
for slot in request_labels:
request_labels[slot].append(batch['request_labels-' + slot])
for domain in domain_labels:
domain_labels[domain].append(batch['active_domain_labels-' + domain])
greeting_labels.append(batch['general_act_labels'])
for domain in active_domain_labels:
active_domain_labels[domain].append(batch['active_domain_labels-' + domain])
general_act_labels.append(batch['general_act_labels'])
with torch.no_grad():
p, p_req, p_dom, p_gen, _ = model(ids, mask, tt_ids, reduction=args.reduction)
......@@ -239,5 +213,65 @@ def get_ensemble_distributions(args):
torch.save(data, file)
def ensemble_distribution_data_to_predictions_format(model_path: str, set_type: str):
    """
    Convert ensemble predictions to predictions file format.

    Args:
        model_path: Path to ensemble location.
        set_type: Evaluation dataset (train/dev/test).
    """
    data = torch.load(os.path.join(model_path, 'dataloaders', f"{set_type}.data"))

    # Oracle labels: action label sets exist only when the ensemble predicted actions
    has_action_predictions = 'request_probs' in data
    if has_action_predictions:
        label_keys = ['state_labels', 'request_labels', 'active_domain_labels', 'general_act_labels']
    else:
        label_keys = ['state_labels']
    data_new = {key: data[key] for key in label_keys}

    # Marginalise across ensemble members to obtain mean predictive distributions
    data_new['belief_states'] = {slot: dist.mean(-2)
                                 for slot, dist in data['belief_state'].items()}
    if has_action_predictions:
        data_new['request_probs'] = {slot: dist.mean(-1)
                                     for slot, dist in data['request_probs'].items()}
        data_new['active_domain_probs'] = {dom: dist.mean(-1)
                                           for dom, dist in data['active_domain_probs'].items()}
        data_new['general_act_probs'] = data['general_act_probs'].mean(-2)

    # Write the predictions file alongside the model
    predictions_dir = os.path.join(model_path, 'predictions')
    if not os.path.exists(predictions_dir):
        os.mkdir(predictions_dir)
    torch.save(data_new, os.path.join(predictions_dir, f"{set_type}.predictions"))
if __name__ == "__main__":
main()
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--model_path', type=str)
parser.add_argument('--set_type', type=str)
parser.add_argument('--batch_size', type=int, default=3)
parser.add_argument('--reduction', type=str, default='none')
parser.add_argument('--get_ensemble_distributions', action='store_true')
parser.add_argument('--convert_distributions_to_predictions', action='store_true')
parser.add_argument('--build_dataloaders', action='store_true')
args = parser.parse_args()
if args.get_ensemble_distributions:
get_ensemble_distributions(args)
if args.convert_distributions_to_predictions:
ensemble_distribution_data_to_predictions_format(args.model_path, args.set_type)
if args.build_dataloaders:
path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
data = torch.load(path)
reader = open(os.path.join(args.model_path, 'database', f'{args.set_type}.json'), 'r')
ontology = json.load(reader)
reader.close()
loader = get_loader(data, ontology, args.set_type, args.batch_size)
path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
torch.save(loader, path)
......@@ -19,8 +19,10 @@ import logging
import os
from shutil import copy2 as copy
import json
from copy import deepcopy
import torch
import transformers
from transformers import (BertModel, BertConfig, BertTokenizer,
RobertaModel, RobertaConfig, RobertaTokenizer)
from tensorboardX import SummaryWriter
......@@ -30,6 +32,7 @@ from convlab.dst.setsumbt.dataset import unified_format
from convlab.dst.setsumbt.modeling import training
from convlab.dst.setsumbt.dataset import ontology as embeddings
from convlab.dst.setsumbt.utils import get_args, update_args
from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble
# Available model
......@@ -97,6 +100,7 @@ def main(args=None, config=None):
args.fp16 = False
# Initialise Model
transformers.utils.logging.set_verbosity_info()
model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
model = model.to(device)
......@@ -109,29 +113,9 @@ def main(args=None, config=None):
training.set_seed(args)
embeddings.set_seed(args)
transformers.utils.logging.set_verbosity_error()
if args.ensemble_size > 1:
logger.info('Building %i resampled dataloaders each of size %i' % (args.ensemble_size,
args.data_sampling_size))
dataloaders = [unified_format.get_dataloader(args.dataset,
'train',
args.train_batch_size,
tokenizer,
args.max_dialogue_len,
args.max_turn_len,
resampled_size=args.data_sampling_size,
train_ratio=args.dataset_train_ratio,
seed=args.seed)
for _ in range(args.ensemble_size)]
logger.info('Dataloaders built.')
for i, loader in enumerate(dataloaders):
path = os.path.join(OUTPUT_DIR, 'ens-%i' % i)
if not os.path.exists(path):
os.mkdir(path)
path = os.path.join(path, 'train.dataloader')
torch.save(loader, path)
logger.info('Dataloaders saved.')
# Build all dataloaders
train_dataloader = unified_format.get_dataloader(args.dataset,
'train',
args.train_batch_size,
......@@ -144,7 +128,7 @@ def main(args=None, config=None):
torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
dev_dataloader = unified_format.get_dataloader(args.dataset,
'validation',
args.train_batch_size,
args.dev_batch_size,
tokenizer,
args.max_dialogue_len,
args.max_turn_len,
......@@ -154,7 +138,7 @@ def main(args=None, config=None):
torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
test_dataloader = unified_format.get_dataloader(args.dataset,
'test',
args.train_batch_size,
args.test_batch_size,
tokenizer,
args.max_dialogue_len,
args.max_turn_len,
......@@ -167,6 +151,21 @@ def main(args=None, config=None):
embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology, 'dev', args, tokenizer, encoder)
embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, 'test', args, tokenizer, encoder)
setup_ensemble(OUTPUT_DIR, args.ensemble_size)
logger.info(f'Building {args.ensemble_size} resampled dataloaders each of size {args.data_sampling_size}.')
dataloaders = [unified_format.dataloader_sample_dialogues(deepcopy(train_dataloader), args.data_sampling_size)
for _ in range(args.ensemble_size)]
logger.info('Dataloaders built.')
for i, loader in enumerate(dataloaders):
path = os.path.join(OUTPUT_DIR, 'ens-%i' % i)
if not os.path.exists(path):
os.mkdir(path)
path = os.path.join(path, 'dataloaders', 'train.dataloader')
torch.save(loader, path)
logger.info('Dataloaders saved.')
# Do not perform standard training after ensemble setup is created
return 0
......
......@@ -201,7 +201,7 @@ class RKLDirichletMediatorLoss(Module):
ensemble_stats = compute_ensemble_stats(targets)
alphas, precision = get_dirichlet_parameters(logits, self.parametrization, self.model_offset)
alphas, precision = get_dirichlet_parameters(logits, self.parameterization, self.model_offset)
normalized_probs = alphas / precision.unsqueeze(1)
......@@ -270,7 +270,7 @@ class BinaryRKLDirichletMediatorLoss(RKLDirichletMediatorLoss):
# Convert single target probability p to distribution [1-p, p]
targets = targets.reshape(-1, targets.size(-1), 1)
targets = torch.cat([1 - targets, targets], -1)
targets[targets[:, 1] == self.ignore_index] = self.ignore_index
targets[targets[:, 0, 1] == self.ignore_index] = self.ignore_index
# Convert input logits into predictive distribution [1-z, z]
logits = torch.sigmoid(logits).unsqueeze(1)
......
......@@ -16,6 +16,7 @@
"""Ensemble SetSUMBT"""
import os
from shutil import copy2 as copy
import torch
from torch.nn import Module
......@@ -51,7 +52,7 @@ class EnsembleSetSUMBT(Module):
"""
for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
idx = attr.split('_', 1)[-1]
state_dict = torch.load(os.path.join(path, f'pytorch_model_{idx}.bin'))
state_dict = torch.load(os.path.join(path, f'ens-{idx}/pytorch_model.bin'))
getattr(self, attr).load_state_dict(state_dict)
def add_slot_candidates(self, slot_candidates: tuple):
......@@ -67,7 +68,7 @@ class EnsembleSetSUMBT(Module):
getattr(self, attr).add_slot_candidates(slot_candidates)
self.requestable_slot_ids = self.model_0.setsumbt.requestable_slot_ids
self.informable_slot_ids = self.model_0.setsumbt.informable_slot_ids
self.domain_ids = self.setsumbt.model_0.domain_ids
self.domain_ids = self.model_0.setsumbt.domain_ids
def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
"""
......@@ -96,9 +97,9 @@ class EnsembleSetSUMBT(Module):
Returns:
"""
belief_state_probs = {slot: [] for slot in self.model_0.informable_slot_ids}
request_probs = {slot: [] for slot in self.model_0.requestable_slot_ids}
active_domain_probs = {dom: [] for dom in self.model_0.domain_ids}
belief_state_probs = {slot: [] for slot in self.informable_slot_ids}
request_probs = {slot: [] for slot in self.requestable_slot_ids}
active_domain_probs = {dom: [] for dom in self.domain_ids}
general_act_probs = []
for attr in [f'model_{i}' for i in range(self.config.ensemble_size)]:
# Prediction from each ensemble member
......@@ -107,7 +108,7 @@ class EnsembleSetSUMBT(Module):
attention_mask=attention_mask)
for slot in belief_state_probs:
belief_state_probs[slot].append(b[slot].unsqueeze(-2))
if self.config.predict_intents:
if self.config.predict_actions:
for slot in request_probs:
request_probs[slot].append(r[slot].unsqueeze(-1))
for dom in active_domain_probs:
......@@ -115,7 +116,7 @@ class EnsembleSetSUMBT(Module):
general_act_probs.append(g.unsqueeze(-2))
belief_state_probs = {slot: torch.cat(l, -2) for slot, l in belief_state_probs.items()}
if self.config.predict_intents:
if self.config.predict_actions:
request_probs = {slot: torch.cat(l, -1) for slot, l in request_probs.items()}
active_domain_probs = {dom: torch.cat(l, -1) for dom, l in active_domain_probs.items()}
general_act_probs = torch.cat(general_act_probs, -2)
......@@ -138,17 +139,42 @@ class EnsembleSetSUMBT(Module):
@classmethod
def from_pretrained(cls, path):
if not os.path.exists(os.path.join(path, 'config.json')):
config_path = os.path.join(path, 'ens-0', 'config.json')
if not os.path.exists(config_path):
raise(NameError('Could not find config.json in model path.'))
if not os.path.exists(os.path.join(path, 'pytorch_model_0.bin')):
raise(NameError('Could not find a model binary in the model path.'))
try:
config = RobertaConfig.from_pretrained(path)
config = RobertaConfig.from_pretrained(config_path)
except:
config = BertConfig.from_pretrained(path)
config = BertConfig.from_pretrained(config_path)
config.ensemble_size = len([dir for dir in os.listdir(path) if 'ens-' in dir])
model = cls(config)
model.load(path)
model._load(path)
return model
def setup_ensemble(model_path: str, ensemble_size: int):
    """
    Setup ensemble model directory structure.

    Args:
        model_path: Path to ensemble model directory
        ensemble_size: Number of ensemble members
    """
    for member in range(ensemble_size):
        member_dir = os.path.join(model_path, f'ens-{member}')
        # Skip members whose directory was already created by an earlier run
        if os.path.exists(member_dir):
            continue
        os.mkdir(member_dir)
        os.mkdir(os.path.join(member_dir, 'dataloaders'))
        os.mkdir(os.path.join(member_dir, 'database'))

        # Each member gets its own copy of the development set dataloader
        for set_type in ('dev',):
            copy(os.path.join(model_path, 'dataloaders', f'{set_type}.dataloader'),
                 os.path.join(member_dir, 'dataloaders', f'{set_type}.dataloader'))

        # ... and of the training and development set ontology databases
        for set_type in ('train', 'dev'):
            copy(os.path.join(model_path, 'database', f'{set_type}.db'),
                 os.path.join(member_dir, 'database', f'{set_type}.db'))
......@@ -61,7 +61,11 @@ def set_ontology_embeddings(model, slots, load_slots=True):
if load_slots:
slots = {slot: embs for slot, embs in slots.items()}
model.add_slot_candidates(slots)
for slot in model.setsumbt.informable_slot_ids:
try:
informable_slot_ids = model.setsumbt.informable_slot_ids
except:
informable_slot_ids = model.informable_slot_ids
for slot in informable_slot_ids:
model.add_value_candidates(slot, values[slot], replace=True)
......@@ -79,14 +83,16 @@ def log_info(global_step, loss, jg_acc=None, sl_acc=None, req_f1=None, dom_f1=No
gen_f1: General action prediction F1 score
stats: Uncertainty measure statistics of model
"""
info = f"{global_step} steps complete, " if type(global_step) == int else ""
if global_step == 'training_complete':
info += f"Training Complete"
if type(global_step) == int:
info = f"{global_step} steps complete, "
info += f"Loss since last update: {loss}. Validation set stats: "
if global_step == 'dev':
info += f"Validation set stats: Loss: {loss}, "
if global_step == 'test':
info += f"Test set stats: Loss: {loss}, "
elif global_step == 'training_complete':
info = f"Training Complete"
info += f"Validation set stats: "
elif global_step == 'dev':
info = f"Validation set stats: Loss: {loss}, "
elif global_step == 'test':
info = f"Test set stats: Loss: {loss}, "
info += f"Joint Goal Acc: {jg_acc}, Slot Acc: {sl_acc}, "
if req_f1 is not None:
info += f"Request F1 Score: {req_f1}, Active Domain F1 Score: {dom_f1}, "
......@@ -134,8 +140,9 @@ def get_input_dict(batch: dict,
input_dict['token_type_ids'] = batch['token_type_ids'].to(device) if 'token_type_ids' in batch else None
input_dict['attention_mask'] = batch['attention_mask'].to(device) if 'attention_mask' in batch else None
if 'belief_state' in batch:
input_dict['state_labels'] = {slot: batch['belief_state-' + slot].to(device) for slot in model_informable_slot_ids
if any('belief_state' in key for key in batch):
input_dict['state_labels'] = {slot: batch['belief_state-' + slot].to(device)
for slot in model_informable_slot_ids
if ('belief_state-' + slot) in batch}
if predict_actions:
input_dict['request_labels'] = {slot: batch['request_probs-' + slot].to(device)
......@@ -644,9 +651,12 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
sl_acc = sum(jg_acc / len(model.setsumbt.informable_slot_ids)).float()
jg_acc = sum((jg_acc == len(model.setsumbt.informable_slot_ids)).int()).float()
req_tp = sum(req_tp / len(model.setsumbt.requestable_slot_ids)).float() if req_tp is not None else torch.tensor(0.0)
req_fp = sum(req_fp / len(model.setsumbt.requestable_slot_ids)).float() if req_fp is not None else torch.tensor(0.0)
req_fn = sum(req_fn / len(model.setsumbt.requestable_slot_ids)).float() if req_fn is not None else torch.tensor(0.0)
if req_tp is not None and model.setsumbt.requestable_slot_ids:
req_tp = sum(req_tp / len(model.setsumbt.requestable_slot_ids)).float()
req_fp = sum(req_fp / len(model.setsumbt.requestable_slot_ids)).float()
req_fn = sum(req_fn / len(model.setsumbt.requestable_slot_ids)).float()
else:
req_tp, req_fp, req_fn = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
dom_tp = sum(dom_tp / len(model.setsumbt.domain_ids)).float() if dom_tp is not None else torch.tensor(0.0)
dom_fp = sum(dom_fp / len(model.setsumbt.domain_ids)).float() if dom_fp is not None else torch.tensor(0.0)
dom_fn = sum(dom_fn / len(model.setsumbt.domain_ids)).float() if dom_fn is not None else torch.tensor(0.0)
......
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Predict dataset user action using SetSUMBT Model"""
from copy import deepcopy
import os
import json
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from convlab.util.custom_util import flatten_acts as flatten
from convlab.util import load_dataset, load_policy_data
from convlab.dst.setsumbt import SetSUMBTTracker
def flatten_acts(acts: dict) -> list:
    """
    Flatten dictionary actions.

    Args:
        acts: Dictionary acts

    Returns:
        flat_acts: Flattened actions as [intent, domain, slot, value] lists
    """
    # 'none' placeholders become empty strings; values are lowercased
    return [[intent,
             domain,
             '' if slot == 'none' else slot,
             '' if value == 'none' else value.lower()]
            for intent, domain, slot, value in flatten(acts)]
def get_user_actions(context: list, system_acts: list) -> list:
    """
    Extract user actions from the data.

    Args:
        context: Previous dialogue turns.
        system_acts: List of flattened system actions.

    Returns:
        user_acts: List of flattened user actions.
    """
    user_acts = context[-1]['dialogue_acts']
    user_acts = flatten_acts(user_acts)
    # With a 3-turn context window the previous user turn sits at context[-3];
    # recover implicit inform acts from state changes between the two user turns.
    if len(context) == 3:
        prev_state = context[-3]['state']
        cur_state = context[-1]['state']
        for domain, substate in cur_state.items():
            for slot, value in substate.items():
                if prev_state[domain][slot] != value:
                    act = ['inform', domain, slot, value]
                    # Only add state-change informs not already expressed by either side
                    if act not in user_acts and act not in system_acts:
                        user_acts.append(act)
    return user_acts
def extract_dataset(dataset: str = 'multiwoz21') -> list:
    """
    Extract acts and utterances from the dataset.

    Args:
        dataset: Dataset name

    Returns:
        data: Extracted data, a list of dialogues, each a list of turn dicts
    """
    corpus = load_dataset(dataset_name=dataset)
    raw_data = load_policy_data(corpus, data_split='test', context_window_size=3)['test']

    data = list()
    dialogue = list()
    for turn in raw_data:
        context = turn['context']
        has_system_turn = len(context) > 1

        state = dict()
        state['system_utterance'] = context[-2]['utterance'] if has_system_turn else ''
        state['utterance'] = context[-1]['utterance']
        system_acts = context[-2]['dialogue_acts'] if has_system_turn else {}
        state['system_actions'] = flatten_acts(system_acts)
        state['user_actions'] = get_user_actions(context, state['system_actions'])

        dialogue.append(state)
        # A terminated turn closes the current dialogue
        if turn['terminated']:
            data.append(dialogue)
            dialogue = list()
    return data
def unflatten_acts(acts: list) -> dict:
    """
    Convert acts from flat list format to dict format.

    Args:
        acts: List of flat actions.

    Returns:
        unflat_acts: Dictionary of acts.
    """
    binary_acts = []
    cat_acts = []
    for intent, domain, slot, value in acts:
        # General-domain acts are always kept; otherwise a concrete slot is required
        if domain != 'general' and slot == 'none':
            continue
        act = {'intent': intent,
               'domain': domain,
               'slot': '' if slot == 'none' else slot}
        # Acts without a value (or request acts) are binary; the rest are categorical
        if value in ('', 'none') or intent == 'request':
            binary_acts.append(act)
        else:
            act['value'] = value
            cat_acts.append(act)

    return {'categorical': cat_acts, 'binary': binary_acts, 'non-categorical': list()}
def predict_user_acts(data: list, tracker: SetSUMBTTracker) -> list:
    """
    Predict the user actions using the SetSUMBT Tracker.

    Args:
        data: List of dialogues.
        tracker: SetSUMBT Tracker

    Returns:
        predict_result: List of turns containing predictions and true user actions.
    """
    tracker.init_session()
    predict_result = []
    for dial_idx, dialogue in enumerate(data):
        for turn_idx, state in enumerate(dialogue):
            sample = {'dial_idx': dial_idx, 'turn_idx': turn_idx}

            # Feed the system utterance into the history before decoding the user turn
            tracker.state['history'].append(['sys', state['system_utterance']])
            # deepcopy so later tracker mutations cannot alias the stored prediction
            predicted_state = deepcopy(tracker.update(state['utterance']))
            tracker.state['history'].append(['usr', state['utterance']])
            tracker.state['system_action'] = state['system_actions']

            # Store predicted and gold user acts side by side in dict act format
            sample['predictions'] = {'dialogue_acts': unflatten_acts(predicted_state['user_action'])}
            sample['dialogue_acts'] = unflatten_acts(state['user_actions'])

            predict_result.append(sample)

        # Reset tracker state between dialogues
        tracker.init_session()

    return predict_result
if __name__ =="__main__":
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
    parser.add_argument('--model_path', type=str, help='Path to model dir')
    args = parser.parse_args()

    # Extract gold acts/utterances, run the tracker over them, and dump results
    dataset = extract_dataset(args.dataset_name)
    tracker = SetSUMBTTracker(args.model_path)
    predict_results = predict_user_acts(dataset, tracker)

    with open(os.path.join(args.model_path, 'predictions', 'test_nlu.json'), 'w') as writer:
        json.dump(predict_results, writer, indent=2)
        # NOTE(review): close() is redundant inside the `with` block
        writer.close()
......@@ -13,6 +13,7 @@ from convlab.dst.dst import DST
from convlab.util.custom_util import model_downloader
USE_CUDA = torch.cuda.is_available()
transformers.logging.set_verbosity_error()
class SetSUMBTTracker(DST):
......@@ -139,14 +140,17 @@ class SetSUMBTTracker(DST):
def init_session(self):
self.state = dict()
self.state['belief_state'] = dict()
self.state['booked'] = dict()
for domain, substate in self.ontology.items():
self.state['belief_state'][domain] = dict()
for slot, slot_info in substate.items():
if slot_info['possible_values'] and slot_info['possible_values'] != ['?']:
self.state['belief_state'][domain][slot] = ''
self.state['booked'][domain] = list()
self.state['history'] = []
self.state['system_action'] = []
self.state['user_action'] = []
self.state['terminated'] = False
self.active_domains = {}
self.hidden_states = None
self.info_dict = {}
......
......@@ -2,6 +2,9 @@ import os
import pickle
import torch
import torch.utils.data as data
from copy import deepcopy
from tqdm import tqdm
from convlab.policy.vector.vector_binary import VectorBinary
from convlab.util import load_policy_data, load_dataset
......@@ -12,18 +15,20 @@ from convlab.policy.vector.dataset import ActDataset
class PolicyDataVectorizer:
def __init__(self, dataset_name='multiwoz21', vector=None):
def __init__(self, dataset_name='multiwoz21', vector=None, dst=None):
self.dataset_name = dataset_name
if vector is None:
self.vector = VectorBinary(dataset_name)
else:
self.vector = vector
self.dst = dst
self.process_data()
def process_data(self):
processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
f'processed_data/{self.dataset_name}_{type(self.vector).__name__}')
name = f"{self.dataset_name}_"
name += f"{type(self.dst).__name__}_" if self.dst is not None else ""
name += f"{type(self.vector).__name__}"
processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), name)
if os.path.exists(processed_dir):
print('Load processed data file')
self._load_data(processed_dir)
......@@ -42,15 +47,27 @@ class PolicyDataVectorizer:
self.data[split] = []
raw_data = data_split[split]
for data_point in raw_data:
if self.dst is not None:
self.dst.init_session()
for data_point in tqdm(raw_data):
if self.dst is None:
state = default_state()
state['belief_state'] = data_point['context'][-1]['state']
state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts'])
last_system_act = data_point['context'][-2]['dialogue_acts'] \
if len(data_point['context']) > 1 else {}
else:
last_system_utt = data_point['context'][-2]['utterance'] if len(data_point['context']) > 1 else ''
self.dst.state['history'].append(['sys', last_system_utt])
usr_utt = data_point['context'][-1]['utterance']
state = deepcopy(self.dst.update(usr_utt))
self.dst.state['history'].append(['usr', usr_utt])
last_system_act = data_point['context'][-2]['dialogue_acts'] if len(data_point['context']) > 1 else {}
state['system_action'] = flatten_acts(last_system_act)
state['terminated'] = data_point['terminated']
if self.dst is not None and state['terminated']:
self.dst.init_session()
state['booked'] = data_point['booked']
dialogue_act = flatten_acts(data_point['dialogue_acts'])
......
......@@ -137,15 +137,6 @@ class MLE_Trainer(MLE_Trainer_Abstract):
def __init__(self, manager, vector, cfg):
self._init_data(manager, cfg)
try:
self.use_entropy = manager.use_entropy
self.use_mutual_info = manager.use_mutual_info
self.use_confidence_scores = manager.use_confidence_scores
except:
self.use_entropy = False
self.use_mutual_info = False
self.use_confidence_scores = False
# override the loss defined in the MLE_Trainer_Abstract to support pos_weight
pos_weight = cfg['pos_weight'] * torch.ones(vector.da_dim).to(device=DEVICE)
self.multi_entropy_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
......@@ -161,6 +152,10 @@ def arg_parser():
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--eval_freq", type=int, default=1)
parser.add_argument("--dataset_name", type=str, default="multiwoz21")
parser.add_argument("--use_masking", action='store_true')
parser.add_argument("--dst", type=str, default=None)
parser.add_argument("--dst_args", type=str, default=None)
args = parser.parse_args()
return args
......@@ -181,8 +176,28 @@ if __name__ == '__main__':
set_seed(args.seed)
logging.info(f"Seed used: {args.seed}")
vector = VectorBinary(dataset_name=args.dataset_name, use_masking=False)
manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector)
if args.dst is None:
vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking)
dst = None
elif args.dst == "setsumbt":
dst_args = [arg.split('=', 1) for arg in args.dst_args.split(', ')
if '=' in arg] if args.dst_args is not None else []
dst_args = {key: eval(value) for key, value in dst_args}
from convlab.dst.setsumbt import SetSUMBTTracker
dst = SetSUMBTTracker(**dst_args)
if dst.return_confidence_scores:
from convlab.policy.vector.vector_uncertainty import VectorUncertainty
vector = VectorUncertainty(dataset_name=args.dataset_name, use_masking=args.use_masking,
manually_add_entity_names=False,
use_confidence_scores=dst.return_confidence_scores,
confidence_thresholds=dst.confidence_thresholds,
use_state_total_uncertainty=dst.return_belief_state_entropy,
use_state_knowledge_uncertainty=dst.return_belief_state_mutual_info)
else:
vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking)
else:
raise NameError(f"Tracker: {args.tracker} not implemented.")
manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector, dst=dst)
agent = MLE_Trainer(manager, vector, cfg)
logging.info('Start training')
......
{
"model": {
"load_path": "supervised",
"load_path": "/gpfs/project/niekerk/src/ConvLab3/convlab/policy/mle/experiments/experiment_2022-08-11-23-55-55/save/supervised",
"pretrained_load_path": "",
"use_pretrained_initialisation": false,
"batchsz": 1000,
"seed": 0,
"epoch": 50,
"eval_frequency": 5,
"process_num": 4,
"process_num": 2,
"num_eval_dialogues": 500,
"sys_semantic_to_usr": false
"sys_semantic_to_usr": true
},
"vectorizer_sys": {
"uncertainty_vector_mul": {
"class_path": "convlab.policy.vector.vector_multiwoz_uncertainty.MultiWozVector",
"class_path": "convlab.policy.vector.vector_uncertainty.VectorUncertainty",
"ini_params": {
"use_masking": false,
"manually_add_entity_names": false,
"seed": 0
"manually_add_entity_names": true,
"seed": 0,
"use_confidence_scores": true,
"use_state_total_uncertainty": true
}
}
},
"nlu_sys": {},
"dst_sys": {
"setsumbt-mul": {
"class_path": "convlab.dst.setsumbt.multiwoz.Tracker.SetSUMBTTracker",
"class_path": "convlab.dst.setsumbt.SetSUMBTTracker",
"ini_params": {
"model_path": "https://zenodo.org/record/5497808/files/setsumbt_end.zip",
"get_confidence_scores": true,
"return_mutual_info": false,
"return_entropy": true
"model_path": "/gpfs/project/niekerk/src/SetSUMBT/models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-labelsmoothing-Seed0-10-08-22-12-42",
"return_confidence_scores": true,
"return_belief_state_entropy": true
}
}
},
......@@ -41,16 +42,7 @@
}
}
},
"nlu_usr": {
"BERTNLU": {
"class_path": "convlab.nlu.jointBERT.multiwoz.BERTNLU",
"ini_params": {
"mode": "sys",
"config_file": "multiwoz_sys_context.json",
"model_file": "https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip"
}
}
},
"nlu_usr": {},
"dst_usr": {},
"policy_usr": {
"RulePolicy": {
......
......@@ -228,14 +228,6 @@ if __name__ == '__main__':
env, sess = env_config(conf, policy_sys)
# Setup uncertainty thresholding
if env.sys_dst:
try:
if env.sys_dst.use_confidence_scores:
policy_sys.vector.setup_uncertain_query(env.sys_dst.thresholds)
except:
logging.info('Uncertainty threshold not set.')
policy_sys.current_time = current_time
policy_sys.log_dir = config_save_path.replace('configs', 'logs')
policy_sys.save_dir = save_path
......
......@@ -205,6 +205,14 @@ def env_config(conf, policy_sys, check_book_constraints=True):
policy_usr = conf['policy_usr_activated']
usr_nlg = conf['usr_nlg_activated']
# Setup uncertainty thresholding
if dst_sys:
try:
if dst_sys.return_confidence_scores:
policy_sys.vector.setup_uncertain_query(dst_sys.confidence_thresholds)
except:
logging.info('Uncertainty threshold not set.')
simulator = PipelineAgent(nlu_usr, dst_usr, policy_usr, usr_nlg, 'user')
system_pipeline = PipelineAgent(nlu_sys, dst_sys, policy_sys, sys_nlg,
'sys', return_semantic_acts=conf['model']['sys_semantic_to_usr'])
......@@ -531,18 +539,21 @@ def get_config(filepath, args) -> dict:
vec_name = [model for model in conf['vectorizer_sys']]
vec_name = vec_name[0] if vec_name else None
if dst_name and 'setsumbt' in dst_name.lower():
if 'get_confidence_scores' in conf['dst_sys'][dst_name]['ini_params']:
conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = conf['dst_sys'][dst_name]['ini_params']['get_confidence_scores']
if 'return_confidence_scores' in conf['dst_sys'][dst_name]['ini_params']:
param = conf['dst_sys'][dst_name]['ini_params']['return_confidence_scores']
conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = param
else:
conf['vectorizer_sys'][vec_name]['ini_params']['use_confidence_scores'] = False
if 'return_mutual_info' in conf['dst_sys'][dst_name]['ini_params']:
conf['vectorizer_sys'][vec_name]['ini_params']['use_mutual_info'] = conf['dst_sys'][dst_name]['ini_params']['return_mutual_info']
if 'return_belief_state_mutual_info' in conf['dst_sys'][dst_name]['ini_params']:
param = conf['dst_sys'][dst_name]['ini_params']['return_belief_state_mutual_info']
conf['vectorizer_sys'][vec_name]['ini_params']['use_state_knowledge_uncertainty'] = param
else:
conf['vectorizer_sys'][vec_name]['ini_params']['use_mutual_info'] = False
if 'return_entropy' in conf['dst_sys'][dst_name]['ini_params']:
conf['vectorizer_sys'][vec_name]['ini_params']['use_entropy'] = conf['dst_sys'][dst_name]['ini_params']['return_entropy']
conf['vectorizer_sys'][vec_name]['ini_params']['use_state_knowledge_uncertainty'] = False
if 'return_belief_state_entropy' in conf['dst_sys'][dst_name]['ini_params']:
param = conf['dst_sys'][dst_name]['ini_params']['return_belief_state_entropy']
conf['vectorizer_sys'][vec_name]['ini_params']['use_state_total_uncertainty'] = param
else:
conf['vectorizer_sys'][vec_name]['ini_params']['use_entropy'] = False
conf['vectorizer_sys'][vec_name]['ini_params']['use_state_total_uncertainty'] = False
from convlab.nlu import NLU
from convlab.dst import DST
......@@ -571,8 +582,7 @@ def get_config(filepath, args) -> dict:
cls_path = infos.get('class_path', '')
cls = map_class(cls_path)
conf[unit + '_class'] = cls
conf[unit + '_activated'] = conf[unit +
'_class'](**conf[unit][model]['ini_params'])
conf[unit + '_activated'] = conf[unit + '_class'](**conf[unit][model]['ini_params'])
print("Loaded " + model + " for " + unit)
return conf
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment