Commit e6fa98e2 authored by Carel van Niekerk

Commit unified training for setsumbt

parent 2c41a4eb
Showing with 96 additions and 8052 deletions
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calibration Plot plotting script"""
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import torch
from matplotlib import pyplot as plt
def main():
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--data_dir', help='Location of the belief states', required=True)
parser.add_argument('--output', help='Output image path', default='calibration_plot.png')
parser.add_argument('--n_bins', help='Number of bins', default=10, type=int)
args = parser.parse_args()
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
path = args.data_dir
models = os.listdir(path)
models = [os.path.join(path, model, 'test.belief') for model in models]
fig = plt.figure(figsize=(14,8))
font=20
plt.tick_params(labelsize=font-2)
linestyle = ['-', ':', (0, (3, 5, 1, 5)), '-.', (0, (5, 10))]
for i, model in enumerate(models):
conf, acc = get_calibration(model, device, n_bins=args.n_bins)
name = model.split('/')[-2].strip()
print(name, conf, acc)
plt.plot(conf, acc, label=name, linestyle=linestyle[i], linewidth=3)
plt.plot(torch.tensor([0,1]), torch.tensor([0,1]), linestyle='--', color='black', linewidth=3)
plt.xlabel('Confidence', fontsize=font)
plt.ylabel('Joint Goal Accuracy', fontsize=font)
plt.legend(fontsize=font)
plt.savefig(args.output)
def get_calibration(path, device, n_bins=10, temperature=1.00):
logits = torch.load(path, map_location=device)
y_true = logits['labels']
logits = logits['belief_states']
    # Predicted value for each slot is the argmax of its belief distribution
    y_pred = {slot: logits[slot].reshape(-1, logits[slot].size(-1)).argmax(-1) for slot in logits}
    # A turn attains the joint goal if every slot prediction matches its label
    goal_acc = {slot: (y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred}
    goal_acc = sum([goal_acc[slot] for slot in goal_acc])
    goal_acc = (goal_acc == len(y_true)).int()
    # Turn confidence is the smallest maximum slot probability
    scores = [logits[slot].reshape(-1, logits[slot].size(-1)).max(-1)[0].unsqueeze(0) for slot in logits]
    scores = torch.cat(scores, 0).min(0)[0]
step = 1.0 / float(n_bins)
bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
bins = []
for b in range(n_bins):
lower, upper = bin_ranges[b], bin_ranges[b + 1]
if b == 0:
ids = torch.where((scores >= lower) * (scores <= upper))[0]
else:
ids = torch.where((scores > lower) * (scores <= upper))[0]
bins.append(ids)
    # Mean confidence per bin (-1 marks empty bins, filtered out below)
    conf = [0.0]
    for b in bins:
        if b.size(0) > 0:
            conf.append(scores[b].mean())
        else:
            conf.append(-1)
    conf = torch.tensor(conf)
    # Use the first slot's labels to identify padded turns (label < 0)
    slot = [s for s in y_true][0]
    acc = [0.0]
    for b in bins:
        if b.size(0) > 0:
            acc_ = goal_acc[b]
            acc_ = acc_[y_true[slot].reshape(-1)[b] >= 0]
            if acc_.size(0) > 0:
                acc.append(acc_.float().mean())
            else:
                acc.append(-1)
        else:
            acc.append(-1)
acc = torch.tensor(acc)
conf = conf[acc != -1]
acc = acc[acc != -1]
return conf, acc
if __name__ == '__main__':
main()
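A minimal, self-contained sketch (editorial, not part of the commit) of the input layout the plotting script expects: one sub-directory per model inside --data_dir, each holding a test.belief file saved with torch.save and carrying the 'labels' and 'belief_states' keys read by get_calibration. All names below, including the script filename, are placeholders.

import os
import torch

data_dir = 'calibration_demo'                       # hypothetical directory
model_dir = os.path.join(data_dir, 'setsumbt-demo')
os.makedirs(model_dir, exist_ok=True)

# [dialogues, turns, values] belief distributions and value-index labels
belief_states = {'hotel-area': torch.softmax(torch.randn(8, 12, 5), -1)}
labels = {'hotel-area': torch.randint(0, 5, (8, 12))}
torch.save({'belief_states': belief_states, 'labels': labels},
           os.path.join(model_dir, 'test.belief'))
# Then: python plot_calibration.py --data_dir calibration_demo --output plot.png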
#!/bin/bash
# Train an ensemble of SetSUMBT models on resampled training subsets and set
# up a self-contained dataloader/database directory for each member.
ENSEMBLE_SIZE=10
DATA_SIZE=7500
SEED=$1
OUT=$2
python run.py --run_nbt \
--output_dir $OUT \
--use_descriptions --set_similarity \
--ensemble_size $ENSEMBLE_SIZE \
--data_sampling_size $DATA_SIZE \
--seed $SEED
ENSEMBLE_SIZE=$(($ENSEMBLE_SIZE-1))
for e in $(seq 0 $ENSEMBLE_SIZE);do
mkdir -p "$OUT/ensemble-$e/dataloaders"
mv "$OUT/ensemble-$e/train.dataloader" "$OUT/ensemble-$e/dataloaders/"
cp "$OUT/dataloaders/dev.dataloader" "$OUT/ensemble-$e/dataloaders/"
cp "$OUT/dataloaders/test.dataloader" "$OUT/ensemble-$e/dataloaders/"
cp -r $OUT/database "$OUT/ensemble-$e/"
done
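# --- Editorial note (not part of the commit): after this script, each
# ensemble member is expected to own a self-contained directory of the form
#   $OUT/ensemble-<e>/dataloaders/{train,dev,test}.dataloader
#   $OUT/ensemble-<e>/database/
# so that each member can be trained and evaluated independently.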
#!/bin/bash
# Collect a trained model checkpoint together with its ontology files.
IN=$1
IN_DATA=$2
OUT=$3
mkdir -p $OUT
cp "$IN/database/test.db" "$OUT/ontology.db"
cp "$IN_DATA/ontology_test.db" "$OUT/ontology.json"
cp "$IN/pytorch_model.bin" "$OUT/pytorch_model.bin"
cp "$IN/config.json" "$OUT/config.json"
#!/bin/bash
# Distil the ensemble into a single model using the mean of the member
# distributions (reduction 'mean', loss 'distillation').
ENSEMBLE_SIZE=10
SEED=$1
OUT=$2
# Loop over members 0..N-1 without mutating ENSEMBLE_SIZE, so the full
# ensemble size is passed to distillation_setup.py below.
for e in $(seq 0 $((ENSEMBLE_SIZE - 1)));do
cp "$OUT/ensemble-$e/pytorch_model.bin" "$OUT/pytorch_model_$e.bin"
done
cp "$OUT/ensemble-0/config.json" "$OUT/config.json"
for SET in "train" "dev" "test";do
python distillation_setup.py --get_ensemble_distributions \
--model_path $OUT \
--model_type roberta \
--set_type $SET \
--ensemble_size $ENSEMBLE_SIZE \
--reduction mean
done
python distillation_setup.py --build_dataloaders \
--model_path $OUT \
--set_type train \
--batch_size 3
for SET in "dev" "test";do
python distillation_setup.py --build_dataloaders \
--model_path $OUT \
--set_type $SET \
--batch_size 16
done
python run.py --run_nbt \
--output_dir $OUT \
--loss_function distillation \
--use_descriptions --set_similarity \
--do_train --do_eval \
--seed $SEED
#!/bin/bash
# Distil the ensemble into a single model while keeping each member's full
# distribution (reduction 'none', loss 'distribution_distillation').
ENSEMBLE_SIZE=10
SEED=$1
OUT=$2
# Loop over members 0..N-1 without mutating ENSEMBLE_SIZE, so the full
# ensemble size is passed to distillation_setup.py below.
for e in $(seq 0 $((ENSEMBLE_SIZE - 1)));do
cp "$OUT/ensemble-$e/pytorch_model.bin" "$OUT/pytorch_model_$e.bin"
done
cp "$OUT/ensemble-0/config.json" "$OUT/config.json"
for SET in "train" "dev" "test";do
python distillation_setup.py --get_ensemble_distributions \
--model_path $OUT \
--model_type roberta \
--set_type $SET \
--ensemble_size $ENSEMBLE_SIZE \
--reduction none
done
python distillation_setup.py --build_dataloaders \
--model_path $OUT \
--set_type train \
--batch_size 3
for SET in "dev" "test";do
python distillation_setup.py --build_dataloaders \
--model_path $OUT \
--set_type $SET \
--batch_size 16
done
python run.py --run_nbt \
--output_dir $OUT \
--loss_function "distribution_distillation" \
--use_descriptions --set_similarity \
--do_train --do_eval \
--seed $SEED
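# --- Editorial note (not part of the commit): this script mirrors the one
# above; the only differences are the reduction ('none' keeps every ensemble
# member's distribution instead of averaging them) and the loss function
# ('distribution_distillation' instead of 'distillation').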
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import os
import torch
import transformers
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import RobertaConfig, BertConfig
from tqdm import tqdm
import convlab2
from convlab2.dst.setsumbt.multiwoz.dataset.multiwoz21 import EnsembleMultiWoz21
from convlab2.dst.setsumbt.modeling import EnsembleSetSUMBT
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def args_parser():
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--model_path', type=str)
parser.add_argument('--model_type', type=str)
parser.add_argument('--set_type', type=str)
parser.add_argument('--batch_size', type=int)
parser.add_argument('--ensemble_size', type=int)
parser.add_argument('--reduction', type=str, default='mean')
parser.add_argument('--get_ensemble_distributions', action='store_true')
parser.add_argument('--build_dataloaders', action='store_true')
return parser.parse_args()
def main():
args = args_parser()
if args.get_ensemble_distributions:
get_ensemble_distributions(args)
elif args.build_dataloaders:
path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
data = torch.load(path)
loader = get_loader(data, args.set_type, args.batch_size)
path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
torch.save(loader, path)
else:
        raise NotImplementedError('Specify either --get_ensemble_distributions or --build_dataloaders')
def get_loader(data, set_type='train', batch_size=3):
data = flatten_data(data)
data = do_label_padding(data)
data = EnsembleMultiWoz21(data)
if set_type == 'train':
sampler = RandomSampler(data)
else:
sampler = SequentialSampler(data)
loader = DataLoader(data, sampler=sampler, batch_size=batch_size)
return loader
def do_label_padding(data):
if 'attention_mask' in data:
dialogs, turns = torch.where(data['attention_mask'].sum(-1) == 0.0)
else:
dialogs, turns = torch.where(data['input_ids'].sum(-1) == 0.0)
for key in data:
if key not in ['input_ids', 'attention_mask', 'token_type_ids']:
data[key][dialogs, turns] = -1
return data
map_dict = {'belief_state': 'belief', 'greeting_act_belief': 'goodbye_belief',
'state_labels': 'labels', 'request_labels': 'request',
'domain_labels': 'active', 'greeting_labels': 'goodbye'}
def flatten_data(data):
data_new = dict()
for label, feats in data.items():
label = map_dict.get(label, label)
        if isinstance(feats, dict):
for label_, feats_ in feats.items():
data_new[label + '-' + label_] = feats_
else:
data_new[label] = feats
return data_new
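# --- Editorial note (not part of the commit): flatten_data renames the nested
# label dictionaries into the flat 'prefix-slot' keys that EnsembleMultiWoz21
# expects, using map_dict above, e.g. (hypothetical shapes):
#   flatten_data({'belief_state': {'hotel-area': torch.zeros(2, 3, 5)}})
#   # -> {'belief-hotel-area': tensor of shape [2, 3, 5]}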
def get_ensemble_distributions(args):
if args.model_type == 'roberta':
config = RobertaConfig
elif args.model_type == 'bert':
config = BertConfig
config = config.from_pretrained(args.model_path)
config.ensemble_size = args.ensemble_size
device = DEVICE
model = EnsembleSetSUMBT.from_pretrained(args.model_path)
model = model.to(device)
print('Model Loaded!')
dataloader = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
database = os.path.join(args.model_path, 'database', f'{args.set_type}.db')
dataloader = torch.load(dataloader)
database = torch.load(database)
# Get slot and value embeddings
slots = {slot: val for slot, val in database.items()}
values = {slot: val[1] for slot, val in database.items()}
del database
# Load model ontology
model.add_slot_candidates(slots)
for slot in model.informable_slot_ids:
model.add_value_candidates(slot, values[slot], replace=True)
del slots, values
print('Environment set up.')
input_ids = []
token_type_ids = []
attention_mask = []
state_labels = {slot: [] for slot in model.informable_slot_ids}
request_labels = {slot: [] for slot in model.requestable_slot_ids}
domain_labels = {domain: [] for domain in model.domain_ids}
greeting_labels = []
belief_state = {slot: [] for slot in model.informable_slot_ids}
request_belief = {slot: [] for slot in model.requestable_slot_ids}
domain_belief = {domain: [] for domain in model.domain_ids}
greeting_act_belief = []
model.eval()
for batch in tqdm(dataloader, desc='Batch:'):
ids = batch['input_ids']
tt_ids = batch['token_type_ids'] if 'token_type_ids' in batch else None
mask = batch['attention_mask'] if 'attention_mask' in batch else None
input_ids.append(ids)
token_type_ids.append(tt_ids)
attention_mask.append(mask)
ids = ids.to(device)
tt_ids = tt_ids.to(device) if tt_ids is not None else None
mask = mask.to(device) if mask is not None else None
for slot in state_labels:
state_labels[slot].append(batch['labels-' + slot])
if model.config.predict_intents:
for slot in request_labels:
request_labels[slot].append(batch['request-' + slot])
for domain in domain_labels:
domain_labels[domain].append(batch['active-' + domain])
greeting_labels.append(batch['goodbye'])
with torch.no_grad():
p, p_req, p_dom, p_bye, _ = model(ids, mask, tt_ids,
reduction=args.reduction)
for slot in belief_state:
belief_state[slot].append(p[slot].cpu())
if model.config.predict_intents:
for slot in request_belief:
request_belief[slot].append(p_req[slot].cpu())
for domain in domain_belief:
domain_belief[domain].append(p_dom[domain].cpu())
greeting_act_belief.append(p_bye.cpu())
input_ids = torch.cat(input_ids, 0) if input_ids[0] is not None else None
token_type_ids = torch.cat(token_type_ids, 0) if token_type_ids[0] is not None else None
attention_mask = torch.cat(attention_mask, 0) if attention_mask[0] is not None else None
state_labels = {slot: torch.cat(l, 0) for slot, l in state_labels.items()}
if model.config.predict_intents:
request_labels = {slot: torch.cat(l, 0) for slot, l in request_labels.items()}
domain_labels = {domain: torch.cat(l, 0) for domain, l in domain_labels.items()}
greeting_labels = torch.cat(greeting_labels, 0)
belief_state = {slot: torch.cat(p, 0) for slot, p in belief_state.items()}
if model.config.predict_intents:
request_belief = {slot: torch.cat(p, 0) for slot, p in request_belief.items()}
domain_belief = {domain: torch.cat(p, 0) for domain, p in domain_belief.items()}
greeting_act_belief = torch.cat(greeting_act_belief, 0)
data = {'input_ids': input_ids}
if token_type_ids is not None:
data['token_type_ids'] = token_type_ids
if attention_mask is not None:
data['attention_mask'] = attention_mask
data['state_labels'] = state_labels
data['belief_state'] = belief_state
if model.config.predict_intents:
data['request_labels'] = request_labels
data['domain_labels'] = domain_labels
data['greeting_labels'] = greeting_labels
data['request_belief'] = request_belief
data['domain_belief'] = domain_belief
data['greeting_act_belief'] = greeting_act_belief
file = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
torch.save(data, file)
if __name__ == "__main__":
main()
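For orientation, a hedged sketch (editorial, not part of the commit) of the dictionary that get_ensemble_distributions saves per data split, as assembled above:

# Hypothetical structure of <model_path>/dataloaders/<set_type>.data:
# {
#     'input_ids':      Tensor[dialogues, turns, seq_len],
#     'attention_mask': Tensor,            # only if the tokenizer returns one
#     'state_labels':   {slot: Tensor},    # ground-truth value indices
#     'belief_state':   {slot: Tensor},    # ensemble predictive distributions
#     'request_labels', 'domain_labels', 'greeting_labels',
#     'request_belief', 'domain_belief', 'greeting_act_belief'
#                                          # only when config.predict_intents
# }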
# -*- coding: utf-8 -*-
# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,28 +21,19 @@ import os
from shutil import copy2 as copy
import torch
from torch.nn import DataParallel
from transformers import (BertModel, BertConfig, BertTokenizer,
RobertaModel, RobertaConfig, RobertaTokenizer,
AdamW, get_linear_schedule_with_warmup)
from tqdm import tqdm, trange
import numpy as np
RobertaModel, RobertaConfig, RobertaTokenizer)
from tensorboardX import SummaryWriter
from convlab2.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
from convlab2.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
from convlab2.dst.setsumbt.multiwoz import multiwoz21
from convlab2.dst.setsumbt.unified_format_data import unified_format
from convlab2.dst.setsumbt.modeling import training
from convlab2.dst.setsumbt.multiwoz import ontology as embeddings
from convlab2.dst.setsumbt.unified_format_data import ontology as embeddings
from convlab2.dst.setsumbt.utils import get_args, update_args
from convlab2.dst.setsumbt.modeling import ensemble_utils
# Datasets
DATASETS = {
'multiwoz21': multiwoz21
}
# from convlab2.dst.setsumbt.modeling import ensemble_utils
# Available models
MODELS = {
'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
@@ -54,12 +45,6 @@ def main(args=None, config=None):
if args is None:
args, config = get_args(MODELS)
# Select Dataset object
if args.dataset in DATASETS:
Dataset = DATASETS[args.dataset]
else:
raise NameError('NotImplemented')
if args.model_type in MODELS:
SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
else:
@@ -107,20 +92,6 @@
args = update_args(args, config)
# Set up data directory
DATA_DIR = args.data_dir
Dataset.set_datadir(DATA_DIR)
embeddings.set_datadir(DATA_DIR)
# If using shrunk domains, remove bus and hospital domains from the training data and model ontology
if args.shrink_active_domains and args.dataset == 'multiwoz21':
Dataset.set_active_domains(
['attraction', 'hotel', 'restaurant', 'taxi', 'train'])
# Download and preprocess
Dataset.create_examples(
args.max_turn_len, args.predict_actions, args.force_processing)
# Create TensorboardX writer
tb_writer = SummaryWriter(logdir=args.tensorboard_path)
@@ -169,88 +140,85 @@
training.set_seed(args)
embeddings.set_seed(args)
if args.ensemble_size > 1:
ensemble_utils.set_logger(logger, tb_writer)
ensemble.set_seed(args)
logger.info('Building %i resampled dataloaders each of size %i' % (args.ensemble_size,
args.data_sampling_size))
dataloaders = ensemble_utils.build_train_loaders(args, tokenizer, Dataset)
logger.info('Dataloaders built.')
for i, loader in enumerate(dataloaders):
path = os.path.join(OUTPUT_DIR, 'ensemble-%i' % i)
if not os.path.exists(path):
os.mkdir(path)
path = os.path.join(path, 'train.dataloader')
torch.save(loader, path)
logger.info('Dataloaders saved.')
train_slots = embeddings.get_slot_candidate_embeddings(
'train', args, tokenizer, encoder)
dev_slots = embeddings.get_slot_candidate_embeddings(
'dev', args, tokenizer, encoder)
test_slots = embeddings.get_slot_candidate_embeddings(
'test', args, tokenizer, encoder)
train_dataloader = Dataset.get_dataloader(
'train', args.train_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
    torch.save(train_dataloader, os.path.join(
        OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
dev_dataloader = Dataset.get_dataloader(
'dev', args.dev_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
torch.save(dev_dataloader, os.path.join(
OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
test_dataloader = Dataset.get_dataloader(
'test', args.test_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
torch.save(test_dataloader, os.path.join(
OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
# Do not perform standard training after ensemble setup is created
return 0
# if args.ensemble_size > 1:
# ensemble_utils.set_logger(logger, tb_writer)
# ensemble.set_seed(args)
# logger.info('Building %i resampled dataloaders each of size %i' % (args.ensemble_size,
# args.data_sampling_size))
# dataloaders = ensemble_utils.build_train_loaders(args, tokenizer, Dataset)
# logger.info('Dataloaders built.')
# for i, loader in enumerate(dataloaders):
# path = os.path.join(OUTPUT_DIR, 'ensemble-%i' % i)
# if not os.path.exists(path):
# os.mkdir(path)
# path = os.path.join(path, 'train.dataloader')
# torch.save(loader, path)
# logger.info('Dataloaders saved.')
#
# train_slots = embeddings.get_slot_candidate_embeddings(
# 'train', args, tokenizer, encoder)
# dev_slots = embeddings.get_slot_candidate_embeddings(
# 'dev', args, tokenizer, encoder)
# test_slots = embeddings.get_slot_candidate_embeddings(
# 'test', args, tokenizer, encoder)
#
# train_dataloader = Dataset.get_dataloader(
# 'train', args.train_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
# torch.save(train_dataloader, os.path.join(
#     OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
# dev_dataloader = Dataset.get_dataloader(
# 'dev', args.dev_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
# torch.save(dev_dataloader, os.path.join(
# OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
# test_dataloader = Dataset.get_dataloader(
# 'test', args.test_batch_size, tokenizer, args.max_dialogue_len, config.max_turn_len)
# torch.save(test_dataloader, os.path.join(
# OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
#
# # Do not perform standard training after ensemble setup is created
# return 0
# Perform tasks
# TRAINING
if args.do_train:
# Get training batch loaders and ontology embeddings
if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')):
train_slots = torch.load(os.path.join(
OUTPUT_DIR, 'database', 'train.db'))
else:
train_slots = embeddings.get_slot_candidate_embeddings(
'train', args, tokenizer, encoder)
if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
dev_slots = torch.load(os.path.join(
OUTPUT_DIR, 'database', 'dev.db'))
else:
dev_slots = embeddings.get_slot_candidate_embeddings(
'dev', args, tokenizer, encoder)
exists = False
if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader')):
train_dataloader = torch.load(os.path.join(
OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
train_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
if train_dataloader.batch_size == args.train_batch_size:
exists = True
if not exists:
if args.data_sampling_size <= 0:
args.data_sampling_size = None
train_dataloader = Dataset.get_dataloader('train', args.train_batch_size, tokenizer, args.max_dialogue_len,
train_dataloader = unified_format.get_dataloader(args.dataset, 'train',
args.train_batch_size, tokenizer, args.max_dialogue_len,
config.max_turn_len, resampled_size=args.data_sampling_size)
torch.save(train_dataloader, os.path.join(
OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
torch.save(train_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'train.dataloader'))
# Get training batch loaders and ontology embeddings
if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'train.db')):
train_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'train.db'))
else:
train_slots = embeddings.get_slot_candidate_embeddings(train_dataloader.dataset.ontology,
'train', args, tokenizer, encoder)
# Get development set batch loaders and ontology embeddings
if args.do_eval:
exists = False
if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
dev_dataloader = torch.load(os.path.join(
OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
if dev_dataloader.batch_size == args.dev_batch_size:
exists = True
if not exists:
dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len,
dev_dataloader = unified_format.get_dataloader(args.dataset, 'validation',
args.dev_batch_size, tokenizer, args.max_dialogue_len,
config.max_turn_len)
torch.save(dev_dataloader, os.path.join(
OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
else:
dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
'dev', args, tokenizer, encoder)
else:
dev_dataloader = None
dev_slots = None
@@ -275,37 +243,33 @@ def main(args=None, config=None):
os.path.join(OUTPUT_DIR, 'config.json'))
# Load best model for evaluation
model = SumbtModel.from_pretrained(OUTPUT_DIR)
model = SetSumbtModel.from_pretrained(OUTPUT_DIR)
model = model.to(device)
# Evaluation on the development set
if args.do_eval:
# Get development set batch loaders and ontology embeddings
if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
dev_slots = torch.load(os.path.join(
OUTPUT_DIR, 'database', 'dev.db'))
else:
dev_slots = embeddings.get_slot_candidate_embeddings(
'dev', args, tokenizer, encoder)
exists = False
if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader')):
dev_dataloader = torch.load(os.path.join(
OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
dev_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
if dev_dataloader.batch_size == args.dev_batch_size:
exists = True
if not exists:
dev_dataloader = Dataset.get_dataloader('dev', args.dev_batch_size, tokenizer, args.max_dialogue_len,
dev_dataloader = unified_format.get_dataloader(args.dataset, 'validation',
args.dev_batch_size, tokenizer, args.max_dialogue_len,
config.max_turn_len)
torch.save(dev_dataloader, os.path.join(
OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
torch.save(dev_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'dev.dataloader'))
if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'dev.db')):
dev_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'dev.db'))
else:
dev_slots = embeddings.get_slot_candidate_embeddings(dev_dataloader.dataset.ontology,
'dev', args, tokenizer, encoder)
# Load model ontology
training.set_ontology_embeddings(model, dev_slots)
# EVALUATION
jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate(
args, model, device, dev_dataloader)
jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate(args, model, device, dev_dataloader)
if req_f1:
logger.info('Development loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f'
% (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
@@ -315,32 +279,28 @@ def main(args=None, config=None):
# Evaluation on the test set
if args.do_test:
# Get test set batch loaders and ontology embeddings
if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
test_slots = torch.load(os.path.join(
OUTPUT_DIR, 'database', 'test.db'))
else:
test_slots = embeddings.get_slot_candidate_embeddings(
'test', args, tokenizer, encoder)
exists = False
if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
test_dataloader = torch.load(os.path.join(
OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
if test_dataloader.batch_size == args.test_batch_size:
exists = True
if not exists:
test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
test_dataloader = unified_format.get_dataloader(args.dataset, 'test',
args.test_batch_size, tokenizer, args.max_dialogue_len,
config.max_turn_len)
torch.save(test_dataloader, os.path.join(
OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
else:
test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology,
'test', args, tokenizer, encoder)
# Load model ontology
training.set_ontology_embeddings(model, test_slots)
# TESTING
jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate(
args, model, device, test_dataloader)
jg_acc, sl_acc, req_f1, dom_f1, bye_f1, loss = training.evaluate(args, model, device, test_dataloader)
if req_f1:
logger.info('Test loss: %f, Joint Goal Accuracy: %f, Slot Accuracy: %f, Request F1 Score: %f, Domain F1 Score: %f, Goodbye F1 Score: %f'
% (loss, jg_acc, sl_acc, req_f1, dom_f1, bye_f1))
......
from convlab2.dst.setsumbt.multiwoz.dataset import multiwoz21, ontology
from convlab2.dst.setsumbt.multiwoz.Tracker import SetSUMBTTracker
\ No newline at end of file
it's it is
don't do not
doesn't does not
didn't did not
you'd you would
you're you are
you'll you will
i'm i am
they're they are
that's that is
what's what is
couldn't could not
i've i have
we've we have
can't cannot
i'd i would
i'd i would
aren't are not
isn't is not
wasn't was not
weren't were not
won't will not
there's there is
there're there are
. . .
restaurants restaurant -s
hotels hotel -s
laptops laptop -s
cheaper cheap -er
dinners dinner -s
lunches lunch -s
breakfasts breakfast -s
expensively expensive -ly
moderately moderate -ly
cheaply cheap -ly
prices price -s
places place -s
venues venue -s
ranges range -s
meals meal -s
locations location -s
areas area -s
policies policy -s
children child -s
kids kid -s
kidfriendly kid friendly
cards card -s
upmarket expensive
inpricey cheap
inches inch -s
uses use -s
dimensions dimension -s
driverange drive range
includes include -s
computers computer -s
machines machine -s
families family -s
ratings rating -s
constraints constraint -s
pricerange price range
batteryrating battery rating
requirements requirement -s
drives drive -s
specifications specification -s
weightrange weight range
harddrive hard drive
batterylife battery life
businesses business -s
hours hour -s
one 1
two 2
three 3
four 4
five 5
six 6
seven 7
eight 8
nine 9
ten 10
eleven 11
twelve 12
anywhere any where
good bye goodbye
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MultiWOZ 2.1/2.3 Dialogue Dataset"""
import os
import json
import requests
import zipfile
import io
from shutil import copy2 as copy
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from convlab2.dst.setsumbt.multiwoz.dataset.utils import (clean_text, ACTIVE_DOMAINS, get_domains, set_util_domains,
fix_delexicalisation, extract_dialogue, PRICERANGE,
BOOLEAN, DAYS, QUANTITIES, TIME, VALUE_MAP, map_values)
# Set up global data_directory
def set_datadir(path):
    global DATA_DIR
    DATA_DIR = path
def set_active_domains(domains):
global ACTIVE_DOMAINS
ACTIVE_DOMAINS = [d for d in domains if d in ACTIVE_DOMAINS]
set_util_domains(ACTIVE_DOMAINS)
# MultiWOZ2.1 download link
URL = 'https://github.com/budzianowski/multiwoz/raw/master/data/MultiWOZ_2.1.zip'
def set_url(url):
global URL
URL = url
# Create Dialogue examples from the dataset
def create_examples(max_utt_len, get_requestable_slots=False, force_processing=False):
    # Load or download raw data
    if not os.path.exists(DATA_DIR):
        os.mkdir(DATA_DIR)

    if not os.path.exists(os.path.join(DATA_DIR, 'data_raw.json')):
        # Download data archive and extract
        archive = _download()
        data = _extract(archive)
        with open(os.path.join(DATA_DIR, 'data_raw.json'), 'w') as writer:
            json.dump(data, writer, indent=2)
        del archive
    else:
        with open(os.path.join(DATA_DIR, 'data_raw.json'), 'r') as reader:
            data = json.load(reader)

    if force_processing or not os.path.exists(os.path.join(DATA_DIR, 'data_train.json')):
        # Preprocess all dialogues
        data_processed = _process(data['data'], data['system_acts'])
        # Format data and split into train, development and test sets
        train, dev, test = _split_data(data_processed, data['testListFile'],
                                       data['valListFile'], max_utt_len)

        # Write data (closing each file handle rather than only the last one)
        with open(os.path.join(DATA_DIR, 'data_train.json'), 'w') as writer:
            json.dump(train, writer, indent=2)
        with open(os.path.join(DATA_DIR, 'data_test.json'), 'w') as writer:
            json.dump(test, writer, indent=2)
        with open(os.path.join(DATA_DIR, 'data_dev.json'), 'w') as writer:
            json.dump(dev, writer, indent=2)
# Extract slots and slot value candidates from the dataset
for set_type in ['train', 'dev', 'test']:
_get_ontology(set_type, get_requestable_slots)
script_path = os.path.abspath(__file__).replace('/multiwoz21.py', '')
file_name = 'mwoz21_ont_request.json' if get_requestable_slots else 'mwoz21_ont.json'
copy(os.path.join(script_path, file_name), os.path.join(DATA_DIR, 'ontology_test.json'))
copy(os.path.join(script_path, 'mwoz21_slot_descriptions.json'), os.path.join(DATA_DIR, 'slot_descriptions.json'))
# Extract slots and slot value candidates from the dataset
def _get_ontology(set_type, get_requestable_slots=False):
datasets = ['train']
if set_type in ['test', 'dev']:
datasets.append('dev')
datasets.append('test')
# Load examples
data = []
for dataset in datasets:
        with open(os.path.join(DATA_DIR, 'data_%s.json' % dataset), 'r') as reader:
            data += json.load(reader)
ontology = dict()
for dial in data:
for turn in dial['dialogue']:
for state in turn['dialogue_state']:
slot, value = state
value = map_values(value)
if slot not in ontology:
ontology[slot] = [value]
else:
ontology[slot].append(value)
requestable_slots = []
if get_requestable_slots:
for dial in data:
for turn in dial['dialogue']:
for act, dom, slot, val in turn['user_acts']:
if act == 'request':
requestable_slots.append(f'{dom}-{slot}')
requestable_slots = list(set(requestable_slots))
for slot in ontology:
if 'price' in slot:
ontology[slot] = PRICERANGE
if 'parking' in slot or 'internet' in slot:
ontology[slot] = BOOLEAN
if 'day' in slot:
ontology[slot] = DAYS
if 'people' in slot or 'duration' in slot or 'stay' in slot:
ontology[slot] = QUANTITIES
if 'time' in slot or 'leave' in slot or 'arrive' in slot:
ontology[slot] = TIME
if 'stars' in slot:
ontology[slot] += [str(i) for i in range(5)]
# Sort slot values and add none and dontcare values
for slot in ontology:
ontology[slot] = list(set(ontology[slot]))
ontology[slot] = ['none', 'do not care'] + sorted([s for s in ontology[slot] if s not in ['none', 'do not care']])
for slot in requestable_slots:
if slot in ontology:
ontology[slot].append('request')
else:
ontology[slot] = ['request']
writer = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'w')
json.dump(ontology, writer, indent=2)
writer.close()
# Convert dialogue examples to model input features and labels
def convert_examples_to_features(set_type, tokenizer, max_turns=12, max_seq_len=64):
features = dict()
# Load examples
    with open(os.path.join(DATA_DIR, 'data_%s.json' % set_type), 'r') as reader:
        data = json.load(reader)
# Get encoder input for system, user utterance pairs
input_feats = []
for dial in data:
dial_feats = []
for turn in dial['dialogue']:
if len(turn['system_transcript']) == 0:
usr = turn['transcript']
dial_feats.append(tokenizer.encode_plus(usr, add_special_tokens = True,
max_length = max_seq_len, padding='max_length',
truncation = 'longest_first'))
else:
usr = turn['transcript']
sys = turn['system_transcript']
dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens = True,
max_length = max_seq_len, padding='max_length',
truncation = 'longest_first'))
if len(dial_feats) >= max_turns:
break
input_feats.append(dial_feats)
del dial_feats
# Perform turn level padding
input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats]
if 'token_type_ids' in input_feats[0][0]:
token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats]
else:
token_type_ids = None
if 'attention_mask' in input_feats[0][0]:
attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) for dial in input_feats]
else:
attention_mask = None
del input_feats
# Create torch data tensors
features['input_ids'] = torch.tensor(input_ids)
features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
del input_ids, token_type_ids, attention_mask
# Load ontology
reader = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r')
ontology = json.load(reader)
reader.close()
informable_slots = [slot for slot, values in ontology.items() if values != ['request']]
requestable_slots = [slot for slot, values in ontology.items() if 'request' in values]
for slot in requestable_slots:
ontology[slot].remove('request')
domains = list(set(informable_slots + requestable_slots))
domains = list(set([slot.split('-', 1)[0] for slot in domains]))
# Create slot labels
for slot in informable_slots:
labels = []
for dial in data:
labs = []
for turn in dial['dialogue']:
slots_active = [s for s, v in turn['dialogue_state']]
if slot in slots_active:
value = [v for s, v in turn['dialogue_state'] if s == slot][0]
else:
value = 'none'
if value in ontology[slot]:
value = ontology[slot].index(value)
else:
value = map_values(value)
if value in ontology[slot]:
value = ontology[slot].index(value)
else:
value = -1 # If value is not in ontology then we do not penalise the model
labs.append(value)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['labels-' + slot] = labels
for slot in requestable_slots:
labels = []
for dial in data:
labs = []
for turn in dial['dialogue']:
slots_active = [[d, s] for i, d, s, v in turn['user_acts']]
if slot.split('-', 1) in slots_active:
act_ = [i for i, d, s, v in turn['user_acts'] if f"{d}-{s}" == slot][0]
if act_ == 'request':
labs.append(1)
else:
labs.append(0)
else:
labs.append(0)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['request-' + slot] = labels
# Greeting act labels (0-no greeting, 1-goodbye, 2-thank you)
labels = []
for dial in data:
labs = []
for turn in dial['dialogue']:
greeting_active = [i for i, d, s, v in turn['user_acts'] if i in ['bye', 'thank']]
if greeting_active:
if 'bye' in greeting_active:
labs.append(1)
                else:
labs.append(2)
else:
labs.append(0)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['goodbye'] = labels
for domain in domains:
labels = []
for dial in data:
labs = []
for turn in dial['dialogue']:
if domain == turn['domain']:
labs.append(1)
else:
labs.append(0)
if len(labs) >= max_turns:
break
labs = labs + [-1] * (max_turns - len(labs))
labels.append(labs)
labels = torch.tensor(labels)
features['active-' + domain] = labels
del labels
return features
# MultiWOZ2.1 Dataset object
class MultiWoz21(Dataset):
def __init__(self, set_type, tokenizer, max_turns=12, max_seq_len=64):
self.features = convert_examples_to_features(set_type, tokenizer, max_turns, max_seq_len)
def __getitem__(self, index):
return {label: self.features[label][index] for label in self.features
if self.features[label] is not None}
def __len__(self):
return self.features['input_ids'].size(0)
def resample(self, size=None):
n_dialogues = self.__len__()
if not size:
size = n_dialogues
dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
self.features = {label: self.features[label][dialogues] for label in self.features
if self.features[label] is not None}
return self
def to(self, device):
self.device = device
self.features = {label: self.features[label].to(device) for label in self.features
if self.features[label] is not None}
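# --- Editorial note (not part of the commit): resample draws dialogues with
# replacement, which is how the ensemble scripts above build bootstrap
# training subsets, e.g. (hypothetical):
#   data = MultiWoz21('train', tokenizer).resample(7500)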
# Ensemble MultiWOZ2.1 Dataset object (features are supplied pre-computed)
class EnsembleMultiWoz21(Dataset):
def __init__(self, data):
self.features = data
def __getitem__(self, index):
return {label: self.features[label][index] for label in self.features
if self.features[label] is not None}
def __len__(self):
return self.features['input_ids'].size(0)
def resample(self, size=None):
n_dialogues = self.__len__()
if not size:
size = n_dialogues
dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
self.features = {label: self.features[label][dialogues] for label in self.features
if self.features[label] is not None}
def to(self, device):
self.device = device
self.features = {label: self.features[label].to(device) for label in self.features
if self.features[label] is not None}
# Module to create torch dataloaders
def get_dataloader(set_type, batch_size, tokenizer, max_turns=12, max_seq_len=64, device=None, resampled_size=None):
data = MultiWoz21(set_type, tokenizer, max_turns, max_seq_len)
data.to('cpu')
if resampled_size:
data.resample(resampled_size)
if set_type in ['test', 'dev']:
sampler = SequentialSampler(data)
else:
sampler = RandomSampler(data)
loader = DataLoader(data, sampler=sampler, batch_size=batch_size)
return loader
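# --- Editorial note (not part of the commit): hypothetical end-to-end usage
# of the functions above with a RoBERTa tokenizer:
#   from transformers import RobertaTokenizer
#   tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
#   set_datadir('data/multiwoz21')            # placeholder path
#   create_examples(max_utt_len=128)          # download + preprocess once
#   loader = get_dataloader('train', 3, tokenizer)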
def _download(chunk_size=1048576):
"""Download data archive.
Parameters:
chunk_size (int): Download chunk size. (default=1048576)
Returns:
archive: ZipFile archive object.
"""
# Download the archive byte string
req = requests.get(URL, stream=True)
archive = b''
for n_chunks, chunk in tqdm(enumerate(req.iter_content(chunk_size=chunk_size)), desc='Download Chunk'):
if chunk:
archive += chunk
# Convert the bytestring into a zipfile object
archive = io.BytesIO(archive)
archive = zipfile.ZipFile(archive)
return archive
def _extract(archive):
"""Extract the json dictionaries from the archive.
Parameters:
archive: ZipFile archive object.
Returns:
data: Data dictionary.
"""
files = [file for file in archive.filelist if ('.json' in file.filename or '.txt' in file.filename)
and 'MACOSX' not in file.filename]
objects = []
for file in tqdm(files, desc='File'):
data = archive.open(file).read()
# Get data objects from the files
try:
data = json.loads(data)
except json.decoder.JSONDecodeError:
data = data.decode().split('\n')
objects.append(data)
files = [file.filename.split('/')[-1].split('.')[0] for file in files]
data = {file: data for file, data in zip(files, objects)}
return data
# Process files
def _process(dialogue_data, acts_data):
print('Processing Dialogues')
out = {}
for dial_name in tqdm(dialogue_data):
dialogue = dialogue_data[dial_name]
prev_dom = ''
for turn_id, turn in enumerate(dialogue['log']):
dialogue['log'][turn_id]['text'] = clean_text(turn['text'])
if len(turn['metadata']) != 0:
crnt_dom = get_domains(dialogue['log'], turn_id, prev_dom)
prev_dom = crnt_dom
dialogue['log'][turn_id - 1]['domain'] = crnt_dom
dialogue['log'][turn_id] = fix_delexicalisation(turn)
out[dial_name] = dialogue
return out
# Split data (train, dev, test)
def _split_data(dial_data, test, dev, max_utt_len):
train_dials, test_dials, dev_dials = [], [], []
print('Formatting and Splitting Data')
for name in tqdm(dial_data):
dialogue = dial_data[name]
domains = []
dial = extract_dialogue(dialogue, max_utt_len)
if dial:
dialogue = dict()
dialogue['dialogue_idx'] = name
dialogue['domains'] = []
dialogue['dialogue'] = []
for turn_id, turn in enumerate(dial):
turn_dialog = dict()
turn_dialog['system_transcript'] = dial[turn_id - 1]['sys'] if turn_id > 0 else ''
turn_dialog['turn_idx'] = turn_id
turn_dialog['dialogue_state'] = turn['ds']
turn_dialog['transcript'] = turn['usr']
# turn_dialog['system_acts'] = dial[turn_id - 1]['sys_a'] if turn_id > 0 else []
turn_dialog['user_acts'] = turn['usr_a']
turn_dialog['domain'] = turn['domain']
dialogue['domains'].append(turn['domain'])
dialogue['dialogue'].append(turn_dialog)
dialogue['domains'] = [d for d in list(set(dialogue['domains'])) if d != '']
            # Drop dialogues that touch domains outside the active domain set
            if any(dom not in ACTIVE_DOMAINS for dom in dialogue['domains']):
                dialogue['domains'] = []
            dialogue['domains'] = [dom for dom in dialogue['domains'] if dom in ACTIVE_DOMAINS]
if dialogue['domains']:
if name in test:
test_dials.append(dialogue)
elif name in dev:
dev_dials.append(dialogue)
else:
train_dials.append(dialogue)
print('Number of Dialogues:\nTrain: %i\nDev: %i\nTest: %i' % (len(train_dials), len(dev_dials), len(test_dials)))
return train_dials, dev_dials, test_dials
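For orientation, a hedged sketch (editorial, not part of the commit) of one record in the data_train.json written by create_examples, with all values hypothetical:

example_dialogue = {
    'dialogue_idx': 'MUL0001.json',               # placeholder dialogue name
    'domains': ['hotel'],
    'dialogue': [{
        'system_transcript': '',                  # empty for the first turn
        'turn_idx': 0,
        'dialogue_state': [['hotel-area', 'north']],
        'transcript': 'i need a hotel in the north',
        'user_acts': [['inform', 'hotel', 'area', 'north']],
        'domain': 'hotel',
    }],
}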
{
"hotel-price range": "preferred cost or price of the hotel",
"hotel-type": "what is the type of the hotel",
"hotel-parking": "does the hotel have parking",
"hotel-book stay": "number of nights for the hotel reservation",
"hotel-book day": "starting day of the hotel booking",
"hotel-book people": "number of people for the hotel booking",
"hotel-area": "area or place of the hotel",
"hotel-stars": "star rating of the hotel",
"hotel-internet": "does the hotel have internet or wifi",
"hotel-name": "name of the hotel",
"hotel-phone": "phone number of the hotel",
"hotel-postcode": "postcode of the hotel",
"hotel-reference": "booking reference of the hotel booking",
"hotel-address": "street address of the hotel",
"train-destination": "train station you want to travel to",
"train-day": "day of the train booking",
"train-departure": "train station you want to leave from",
"train-arrive by": "arrival time of the train",
"train-book people": "number of people for the train booking",
"train-leave at": "departure time for the train",
"train-duration": "duration of the train journey",
"train-trainid": "train identifier or number",
"train-price": "how much does the train trip cost",
"train-reference": "booking reference of the train booking",
"attraction-type": "type of attraction or point of interest",
"attraction-area": "area or place of the attraction",
"attraction-name": "name of the attraction",
"attraction-phone": "phone number of the attraction",
"attraction-entrance fee": "entrace fee at the attraction",
"attraction-address": "street address of the attraction",
"attraction-postcode": "postcode of the attraction",
"restaurant-book people": "number of people for the restaurant booking",
"restaurant-book day": "weekday for the restaurant booking",
"restaurant-book time": "time of the restaurant booking",
"restaurant-food": "type of food served at the restaurant",
"restaurant-price range": "preferred cost or price of the restaurant",
"restaurant-name": "name of the restaurant",
"restaurant-area": "area or place of the restaurant",
"restaurant-postcode": "postcode of the restaurant",
"restaurant-phone": "phone number of the restaurant",
"restaurant-address": "street address of the restaurant",
"restaurant-reference": "booking reference of the hotel booking",
"taxi-leave at": "what time you want the taxi to leave by",
"taxi-destination": "where you want the taxi to drop you off",
"taxi-departure": "where you want the taxi to pick you up",
"taxi-arrive by": "what time you to arrive at your destination",
"taxi-taxi types": "vehicle type of the taxi",
"taxi-taxi phone": "phone number of the taxi",
"hospital-department": "name of hospital department",
"hospital-address": "street address of the hospital",
"hospital-phone": "phone number of the hospital",
"hospital-postcode": "postcode of the hospital",
"police-postcode": "postcode of the police station",
"police-address": "street address of the police station",
"police-phone": "phone number of the police station"
}
\ No newline at end of file
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create Ontology Embeddings"""
import json
import os
import random
import torch
import numpy as np
# Slot mapping table for description extractions
# SLOT_NAME_MAPPINGS = {
# 'arrive at': 'arriveAt',
# 'arrive by': 'arriveBy',
# 'leave at': 'leaveAt',
# 'leave by': 'leaveBy',
# 'arriveby': 'arriveBy',
# 'arriveat': 'arriveAt',
# 'leaveat': 'leaveAt',
# 'leaveby': 'leaveBy',
# 'price range': 'pricerange'
# }
# Set up global data directory
def set_datadir(path):
    global DATA_DIR
    DATA_DIR = path
# Set seeds
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
# Get embeddings for slots and candidates
def get_slot_candidate_embeddings(set_type, args, tokenizer, embedding_model, save_to_file=True):
    # Get the data set's slots and value candidates
reader = open(os.path.join(DATA_DIR, 'ontology_%s.json' % set_type), 'r')
ontology = json.load(reader)
reader.close()
reader = open(os.path.join(DATA_DIR, 'slot_descriptions.json'), 'r')
slot_descriptions = json.load(reader)
reader.close()
embedding_model.eval()
slots = dict()
for slot in ontology:
if args.use_descriptions:
# d, s = slot.split('-', 1)
# s = SLOT_NAME_MAPPINGS[s] if s in SLOT_NAME_MAPPINGS else s
# s = d + '-' + s
# if slot in slot_descriptions:
desc = slot_descriptions[slot]
# elif slot.lower() in slot_descriptions:
# desc = slot_descriptions[s.lower()]
# else:
# desc = slot.replace('-', ' ')
else:
desc = slot
# Tokenize slot and get embeddings
feats = tokenizer.encode_plus(desc, add_special_tokens = True,
max_length = args.max_slot_len, padding='max_length',
truncation = 'longest_first')
with torch.no_grad():
input_ids = torch.tensor([feats['input_ids']]).to(embedding_model.device) # [1, max_slot_len]
if 'token_type_ids' in feats:
token_type_ids = torch.tensor([feats['token_type_ids']]).to(embedding_model.device) # [1, max_slot_len]
if 'attention_mask' in feats:
attention_mask = torch.tensor([feats['attention_mask']]).to(embedding_model.device) # [1, max_slot_len]
embedded_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask)
attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, embedded_feats.last_hidden_state.size(-1)))
feats = embedded_feats.last_hidden_state * attention_mask # [1, max_slot_len, hidden_dim]
else:
embedded_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids)
else:
if 'attention_mask' in feats:
attention_mask = torch.tensor([feats['attention_mask']]).to(embedding_model.device)
embedded_feats = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, embedded_feats.last_hidden_state.size(-1)))
feats = embedded_feats.last_hidden_state * attention_mask # [1, max_slot_len, hidden_dim]
else:
embedded_feats = embedding_model(input_ids=input_ids) # [1, max_slot_len, hidden_dim]
if args.set_similarity:
slot_emb = feats[0, :, :].detach().cpu() # [seq_len, hidden_dim]
else:
            if args.candidate_pooling == 'cls' and embedded_feats.pooler_output is not None:
                slot_emb = embedded_feats.pooler_output[0, :].detach().cpu()  # [hidden_dim]
elif args.candidate_pooling == 'mean':
feats = feats.sum(1)
feats = torch.nn.functional.layer_norm(feats, feats.size())
slot_emb = feats[0, :].detach().cpu() # [hidden_dim]
# Tokenize value candidates and get embeddings
values = ontology[slot]
is_requestable = False
if 'request' in values:
is_requestable = True
values.remove('request')
if values:
feats = [tokenizer.encode_plus(val, add_special_tokens = True,
max_length = args.max_candidate_len, padding='max_length',
truncation = 'longest_first')
for val in values]
with torch.no_grad():
input_ids = torch.tensor([f['input_ids'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len]
if 'token_type_ids' in feats[0]:
token_type_ids = torch.tensor([f['token_type_ids'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len]
if 'attention_mask' in feats[0]:
attention_mask = torch.tensor([f['attention_mask'] for f in feats]).to(embedding_model.device) # [num_candidates, max_candidate_len]
embedded_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask)
attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, embedded_feats.last_hidden_state.size(-1)))
feats = embedded_feats.last_hidden_state * attention_mask # [num_candidates, max_candidate_len, hidden_dim]
else:
embedded_feats = embedding_model(input_ids=input_ids, token_type_ids=token_type_ids) # [num_candidates, max_candidate_len, hidden_dim]
else:
if 'attention_mask' in feats[0]:
attention_mask = torch.tensor([f['attention_mask'] for f in feats]).to(embedding_model.device)
embedded_feats = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
attention_mask = attention_mask.unsqueeze(-1).repeat((1, 1, embedded_feats.last_hidden_state.size(-1)))
feats = embedded_feats.last_hidden_state * attention_mask # [num_candidates, max_candidate_len, hidden_dim]
else:
embedded_feats = embedding_model(input_ids=input_ids) # [num_candidates, max_candidate_len, hidden_dim]
if args.set_similarity:
feats = feats.detach().cpu() # [num_candidates, max_candidate_len, hidden_dim]
else:
                if args.candidate_pooling == 'cls' and embedded_feats.pooler_output is not None:
                    feats = embedded_feats.pooler_output.detach().cpu()
elif args.candidate_pooling == "mean":
feats = feats.sum(1)
feats = torch.nn.functional.layer_norm(feats, feats.size())
feats = feats.detach().cpu()
else:
feats = None
slots[slot] = (slot_emb, feats, is_requestable)
# Dump tensors for use in training
if save_to_file:
writer = os.path.join(args.output_dir, 'database', '%s.db' % set_type)
torch.save(slots, writer)
return slots
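For orientation (an editorial sketch, not part of the commit): the .db files written above are plain torch pickles mapping each slot to a (slot_embedding, value_embeddings, is_requestable) tuple, which is the structure distillation_setup.py unpacks when rebuilding the model ontology. Paths below are placeholders.

import torch

db = torch.load('models/setsumbt/database/train.db')
for slot, (slot_emb, value_embs, is_requestable) in db.items():
    print(slot, tuple(slot_emb.shape),
          None if value_embs is None else tuple(value_embs.shape),
          is_requestable)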
# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
# Code adapted from the TRADE preprocessing code (https://github.com/jasonwu0731/trade-dst)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MultiWOZ2.1/3 data processing utilities"""
import re
import os
from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
from convlab2.dst.rule.multiwoz import normalize_value
# ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train']
ACTIVE_DOMAINS = ['attraction', 'hotel', 'restaurant', 'taxi', 'train', 'hospital', 'police']
def set_util_domains(domains):
global ACTIVE_DOMAINS
ACTIVE_DOMAINS = [d for d in domains if d in ACTIVE_DOMAINS]
MAPPING_PATH = os.path.abspath(__file__).replace('utils.py', 'mapping.pair')
# Read replacement pairs from the mapping.pair file
REPLACEMENTS = []
with open(MAPPING_PATH, 'r') as fp:
    for line in fp:
        tok_from, tok_to = line.replace('\n', '').split('\t')
        REPLACEMENTS.append((' ' + tok_from + ' ', ' ' + tok_to + ' '))
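# --- Editorial note (not part of the commit): the pairs are padded with
# spaces so that only whole tokens are replaced, e.g. " don't " -> " do not ";
# REPLACEMENTS is presumably consumed by clean_text further down this file
# (truncated here).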
# Extract belief state from mturk annotations
def build_dialoguestate(metadata, get_domains=False):
domains_list = [dom for dom in ACTIVE_DOMAINS if dom in metadata]
dialogue_state, domains = [], []
for domain in domains_list:
active = False
# Extract booking information
booking = []
for slot in sorted(metadata[domain]['book'].keys()):
if slot != 'booked':
if metadata[domain]['book'][slot] == 'not mentioned':
continue
if metadata[domain]['book'][slot] != '':
val = ['%s-book %s' % (domain, slot.strip().lower()), clean_text(metadata[domain]['book'][slot])]
dialogue_state.append(val)
active = True
for slot in metadata[domain]['semi']:
if metadata[domain]['semi'][slot] == 'not mentioned':
continue
elif metadata[domain]['semi'][slot] in ['dont care', 'dontcare', "don't care", 'don not care',
'do not care', 'does not care']:
dialogue_state.append(['%s-%s' % (domain, slot.strip().lower()), 'do not care'])
active = True
elif metadata[domain]['semi'][slot]:
dialogue_state.append(['%s-%s' % (domain, slot.strip().lower()), clean_text(metadata[domain]['semi'][slot])])
active = True
if active:
domains.append(domain)
if get_domains:
return domains
return clean_dialoguestate(dialogue_state)
PRICERANGE = ['do not care', 'cheap', 'moderate', 'expensive']
BOOLEAN = ['do not care', 'yes', 'no']
DAYS = ['do not care', 'monday', 'tuesday', 'wednesday', 'thursday',
'friday', 'saterday', 'sunday']
QUANTITIES = ['do not care', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10 or more']
TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)]
TIME = ['do not care'] + ['%02i:%02i' % t for l in TIME for t in l]
VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast',
'cityroomz': 'city roomz', ' ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott',
'worth house': 'the worth house', 'alesbray lodge guest house': 'aylesbray lodge',
'huntingdon hotel': 'huntingdon marriott hotel', 'huntingd': 'huntingdon marriott hotel',
'jamaicanchinese': 'chinese', 'barbequemodern european': 'modern european',
'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"}
def map_values(value):
for old, new in VALUE_MAP.items():
value = value.replace(old, new)
return value
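# --- Editorial note (not part of the commit): map_values applies plain
# substring substitutions from VALUE_MAP, e.g.
#   map_values('guesthouse')  # -> 'guest house'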
def clean_dialoguestate(states, is_acts=False):
# path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))))
# path = os.path.join(path, 'data/multiwoz/value_dict.json')
# value_dict = json.load(open(path))
clean_state = []
for slot, value in states:
if 'pricerange' in slot:
d, s = slot.split('-', 1)
s = 'price range'
slot = f'{d}-{s}'
if value in PRICERANGE:
clean_state.append([slot, value])
            elif any(v in value for v in PRICERANGE):
value = [v for v in PRICERANGE if v in value][0]
clean_state.append([slot, value])
elif value == '?' and is_acts:
clean_state.append([slot, value])
else:
continue
elif 'parking' in slot or 'internet' in slot:
if value in BOOLEAN:
clean_state.append([slot, value])
if value == 'free':
value = 'yes'
clean_state.append([slot, value])
            elif any(v in value for v in BOOLEAN):
value = [v for v in BOOLEAN if v in value][0]
clean_state.append([slot, value])
elif value == '?' and is_acts:
clean_state.append([slot, value])
else:
continue
elif 'day' in slot:
if value in DAYS:
clean_state.append([slot, value])
            elif any(v in value for v in DAYS):
value = [v for v in DAYS if v in value][0]
clean_state.append([slot, value])
else:
continue
elif 'people' in slot or 'duration' in slot or 'stay' in slot:
if value in QUANTITIES:
clean_state.append([slot, value])
            elif any(v in value for v in QUANTITIES):
value = [v for v in QUANTITIES if v in value][0]
clean_state.append([slot, value])
elif value == '?' and is_acts:
clean_state.append([slot, value])
else:
try:
value = int(value)
if value >= 10:
value = '10 or more'
clean_state.append([slot, value])
else:
continue
except:
continue
        elif 'time' in slot or 'leaveat' in slot or 'arriveby' in slot:
            if 'leaveat' in slot:
                d = slot.split('-', 1)[0]
                slot = f'{d}-leave at'
            if 'arriveby' in slot:
                d = slot.split('-', 1)[0]
                slot = f'{d}-arrive by'
            if value in TIME:
                clean_state.append([slot, value])
            elif re.match(r'^\d{1,2}:\d{2}$', value):
                # Round times that are off the 5-minute grid to the nearest 5 minutes
                h, m = value.split(':')
                m = round(int(m) / 5) * 5
                h = int(h)
                if m == 60:
                    m = 0
                    h += 1
                if h >= 24:
                    h -= 24
                clean_state.append([slot, '%02i:%02i' % (h, m)])
            elif any(v in value for v in TIME[1:]):
                value = [v for v in TIME[1:] if v in value][0]
                clean_state.append([slot, value])
            elif value == '?' and is_acts:
                clean_state.append([slot, value])
            else:
                continue
elif 'stars' in slot:
if len(value) == 1 or value == 'do not care':
clean_state.append([slot, value])
elif value == '?' and is_acts:
clean_state.append([slot, value])
            elif len(value) > 1:
                try:
                    value = str(int(value[0]))
                    clean_state.append([slot, value])
                except ValueError:
                    continue
elif 'area' in slot:
if '|' in value:
value = value.split('|', 1)[0]
clean_state.append([slot, value])
else:
if '|' in value:
value = value.split('|', 1)[0]
value = map_values(value)
# d, s = slot.split('-', 1)
# value = normalize_value(value_dict, d, s, value)
clean_state.append([slot, value])
return clean_state
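# Example (illustrative, assuming the constants above):
#     clean_dialoguestate([['hotel-pricerange', 'cheap'],
#                          ['hotel-parking', 'free'],
#                          ['train-leaveat', '09:17']])
#     -> [['hotel-price range', 'cheap'], ['hotel-parking', 'yes'],
#         ['train-leave at', '09:15']]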
# Process a single dialogue and check its validity
def process_dialogue(dialogue, max_utt_len=128):
if len(dialogue['log']) % 2 != 0:
return None
    # Extract user and system utterances
    usr_utts, sys_utts = [], []
    # Ignore dialogues whose average utterance length exceeds max_utt_len
    avg_len = sum(len(utt['text'].split(' ')) for utt in dialogue['log'])
    avg_len = avg_len / len(dialogue['log'])
    if avg_len > max_utt_len:
        return None
    # If the first turn is a system turn, ignore the dialogue
if dialogue['log'][0]['metadata']:
return None
usr, sys = None, None
for turn in dialogue['log']:
if not is_ascii(turn['text']):
return None
if not usr or not sys:
if len(turn['metadata']) == 0:
usr = turn
else:
sys = turn
if usr and sys:
            states = build_dialoguestate(sys['metadata'], get_domains=False)
sys['dialogue_states'] = states
usr_utts.append(usr)
sys_utts.append(sys)
usr, sys = None, None
dial_clean = dict()
dial_clean['usr_log'] = usr_utts
dial_clean['sys_log'] = sys_utts
return dial_clean
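# Example (illustrative): a valid dialogue is returned as
#     {'usr_log': [user turns], 'sys_log': [system turns]}
# where each system turn carries its cleaned 'dialogue_states'; dialogues
# with an odd turn count, an over-long average utterance length, non-ASCII
# text, or a system first turn yield None.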
# Get new domains
def get_act_domains(prev, crnt):
diff = {}
if not prev or not crnt:
return diff
for ((prev_dom, prev_val), (crnt_dom, crnt_val)) in zip(prev.items(), crnt.items()):
assert prev_dom == crnt_dom
if prev_val != crnt_val:
diff[crnt_dom] = crnt_val
return diff
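# Example (illustrative, with m1 != m2 as metadata dicts):
#     get_act_domains({'hotel': m1, 'train': m1}, {'hotel': m1, 'train': m2})
#     -> {'train': m2}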
# Get current domains
def get_domains(dial_log, turn_id, prev_domain):
if turn_id == 1:
active = build_dialoguestate(dial_log[turn_id]['metadata'], get_domains=True)
acts = format_acts(dial_log[turn_id].get('dialog_act', {})) if not active else []
acts = [domain for intent, domain, slot, value in acts if domain not in ['', 'general']]
active += acts
crnt = active[0] if active else ''
else:
active = get_act_domains(dial_log[turn_id - 2]['metadata'], dial_log[turn_id]['metadata'])
active = list(active.keys())
acts = format_acts(dial_log[turn_id].get('dialog_act', {})) if not active else []
acts = [domain for intent, domain, slot, value in acts if domain not in ['', 'general']]
active += acts
crnt = [prev_domain] if not active else active
crnt = crnt[0]
return crnt
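# Example (illustrative): for turn_id > 1 the current domain is inferred by
# diffing consecutive system metadata snapshots; if nothing changed and the
# turn has no non-general act domains, the previous domain is kept:
#     get_domains(dial_log, 3, 'hotel') -> 'hotel'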
# Function to extract dialogue info from data
def extract_dialogue(dialogue, max_utt_len=50):
dialogue = process_dialogue(dialogue, max_utt_len)
if not dialogue:
return None
usr_utts = [turn['text'] for turn in dialogue['usr_log']]
sys_utts = [turn['text'] for turn in dialogue['sys_log']]
# sys_acts = [format_acts(turn['dialog_act']) if 'dialog_act' in turn else [] for turn in dialogue['sys_log']]
usr_acts = [format_acts(turn['dialog_act']) if 'dialog_act' in turn else [] for turn in dialogue['usr_log']]
dialogue_states = [turn['dialogue_states'] for turn in dialogue['sys_log']]
domains = [turn['domain'] for turn in dialogue['usr_log']]
# dial = [{'usr': u,'sys': s, 'usr_a': ua, 'sys_a': a, 'domain': d, 'ds': v}
# for u, s, ua, a, d, v in zip(usr_utts, sys_utts, usr_acts, sys_acts, domains, dialogue_states)]
dial = [{'usr': u,'sys': s, 'usr_a': ua, 'domain': d, 'ds': v}
for u, s, ua, d, v in zip(usr_utts, sys_utts, usr_acts, domains, dialogue_states)]
return dial
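# Example (illustrative): each element of the returned list has the form
#     {'usr': '...', 'sys': '...', 'usr_a': [user acts],
#      'domain': 'hotel', 'ds': [[slot, value], ...]}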
def format_acts(acts):
new_acts = []
for key, item in acts.items():
domain, intent = key.split('-', 1)
if domain.lower() in ACTIVE_DOMAINS + ['general']:
state = []
for slot, value in item:
slot = str(REF_SYS_DA[domain].get(slot, slot)).lower() if domain in REF_SYS_DA else slot
value = clean_text(value)
slot = slot.replace('_', ' ').replace('ref', 'reference')
state.append([f'{domain.lower()}-{slot}', value])
state = clean_dialoguestate(state, is_acts=True)
if domain == 'general':
if intent in ['thank', 'bye']:
state = [['general-none', 'none']]
else:
state = []
for slot, value in state:
if slot not in ['train-people']:
slot = slot.split('-', 1)[-1]
new_acts.append([intent.lower(), domain.lower(), slot, value])
return new_acts
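# Example (illustrative, assuming 'hotel' is in ACTIVE_DOMAINS and
# REF_SYS_DA maps 'Area' to 'area'):
#     format_acts({'Hotel-Inform': [['Area', 'centre']]})
#     -> [['inform', 'hotel', 'area', 'centre']]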
# Fix act labels
def fix_delexicalisation(turn):
if 'dialog_act' in turn:
for dom, act in turn['dialog_act'].items():
if 'Attraction' in dom:
if 'restaurant_' in turn['text']:
turn['text'] = turn['text'].replace("restaurant", "attraction")
if 'hotel_' in turn['text']:
turn['text'] = turn['text'].replace("hotel", "attraction")
if 'Hotel' in dom:
if 'attraction_' in turn['text']:
turn['text'] = turn['text'].replace("attraction", "hotel")
if 'restaurant_' in turn['text']:
turn['text'] = turn['text'].replace("restaurant", "hotel")
if 'Restaurant' in dom:
if 'attraction_' in turn['text']:
turn['text'] = turn['text'].replace("attraction", "restaurant")
if 'hotel_' in turn['text']:
turn['text'] = turn['text'].replace("hotel", "restaurant")
return turn
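# Example (illustrative): a turn labelled with an Attraction act whose text
# contains the placeholder 'restaurant_name' is retagged, e.g.
#     'the restaurant_name is popular' -> 'the attraction_name is popular'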
# Check that a string contains only ASCII characters
def is_ascii(s):
return all(ord(c) < 128 for c in s)
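# Example: is_ascii('cafe') -> True, is_ascii('café') -> False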
# Insert white space
def separate_token(token, text):
sidx = 0
while True:
# Find next instance of token
sidx = text.find(token, sidx)
if sidx == -1:
break
        # Skip tokens that appear inside a number (e.g. the '.' in '3.30')
        if 0 < sidx and sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
                re.match('[0-9]', text[sidx + 1]):
            sidx += 1
            continue
# Create white space separation around token
if text[sidx - 1] != ' ':
text = text[:sidx] + ' ' + text[sidx:]
sidx += 1
        if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
            text = text[:sidx + 1] + ' ' + text[sidx + 1:]
        # Always advance past this occurrence so the loop terminates
        sidx += 1
return text
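# Example (illustrative):
#     separate_token('?', 'is it open?') -> 'is it open ?'
#     separate_token('.', 'at 3.30pm') -> 'at 3.30pm'  (kept intact inside numbers)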
def clean_text(text):
    # Remove leading and trailing whitespace and lowercase
    text = text.strip().lower()
    # Replace 'b&b' or 'b and b' with 'bed and breakfast'
    text = re.sub(r"b&b", "bed and breakfast", text)
    text = re.sub(r"b and b", "bed and breakfast", text)
    # Fix apostrophes
text = re.sub(u"(\u2018|\u2019)", "'", text)
    # Correct punctuation
    text = text.replace(';', ',')
    text = re.sub(r'/$', '', text)
    text = text.replace('/', ' and ')
# Replace special characters
text = text.replace('-', ' ')
    text = re.sub(r'["<>@()]', '', text)
# Insert white space around special tokens:
for token in ['?', '.', ',', '!']:
text = separate_token(token, text)
# insert white space for 's
text = separate_token('\'s', text)
    # Strip apostrophes at token boundaries (it's, doesn't, you'd, etc.)
    text = re.sub(r"^'", '', text)
    text = re.sub(r"'$", '', text)
    text = re.sub(r"'\s", ' ', text)
    text = re.sub(r"\s'", ' ', text)
# Perform pair replacements listed in the mapping.pair file
for fromx, tox in REPLACEMENTS:
text = ' ' + text + ' '
text = text.replace(fromx, tox)[1:-1]
# Remove multiple spaces
text = re.sub(' +', ' ', text)
    # Concatenate adjacent numbers, e.g. '1 3' -> '13'
tokens = text.split()
i = 1
while i < len(tokens):
if re.match(u'^\d+$', tokens[i]) and \
re.match(u'\d+$', tokens[i - 1]):
tokens[i - 1] += tokens[i]
del tokens[i]
else:
i += 1
text = ' '.join(tokens)
return text
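# Example (illustrative, assuming no mapping.pair rule applies to these tokens):
#     clean_text('Hello, I need a B&B!') -> 'hello , i need a bed and breakfast !'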
from convlab2.dst.setsumbt.unified_format_data.dataset import unified_format, ontology
from convlab2.dst.setsumbt.unified_format_data.Tracker import SetSUMBTTracker
\ No newline at end of file