Skip to content
Snippets Groups Projects
Commit 1cff01b4 authored by Carel van Niekerk's avatar Carel van Niekerk :computer:
Browse files

Merge New SetSUMBT Code

parent bfb58117
Branches
No related tags found
No related merge requests found
Showing with 286 additions and 32 deletions
......@@ -7,7 +7,6 @@ def evaluate(predict_result):
metrics = {'TP':0, 'FP':0, 'FN':0}
acc = []
for sample in predict_result:
pred_state = sample['predictions']['state']
gold_state = sample['state']
......
{
"model_type": "SetSUMBT",
"dataset": "multiwoz21+sgd+tm1+tm2+tm3",
"no_action_prediction": true,
"model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
"transformers_local_files_only": true,
"train_batch_size": 3,
"dev_batch_size": 8,
"test_batch_size": 8,
"run_nbt": true
}
\ No newline at end of file
{
"model_type": "SetSUMBT",
"dataset": "multiwoz21",
"no_action_prediction": true,
"model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
"transformers_local_files_only": true,
"train_batch_size": 3,
"dev_batch_size": 16,
"test_batch_size": 16,
"run_nbt": true
}
\ No newline at end of file
{
"model_type": "SetSUMBT",
"dataset": "sgd+tm1+tm2+tm3",
"no_action_prediction": true,
"model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
"transformers_local_files_only": true,
"train_batch_size": 3,
"dev_batch_size": 12,
"test_batch_size": 12,
"run_nbt": true
}
\ No newline at end of file
{
"model_type": "SetSUMBT",
"dataset": "sgd",
"no_action_prediction": true,
"model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
"transformers_local_files_only": true,
"train_batch_size": 3,
"dev_batch_size": 6,
"test_batch_size": 3,
"run_nbt": true
}
\ No newline at end of file
{
"model_type": "SetSUMBT",
"dataset": "tm1+tm2+tm3",
"no_action_prediction": true,
"model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
"transformers_local_files_only": true,
"train_batch_size": 3,
"dev_batch_size": 8,
"test_batch_size": 8,
"run_nbt": true
}
\ No newline at end of file
......@@ -27,7 +27,7 @@ from convlab.util import load_dataset
from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values,
get_values_from_data, ontology_add_requestable_slots,
get_requestable_slots, load_dst_data, extract_dialogues,
combine_value_sets)
combine_value_sets, IdTensor)
transformers.logging.set_verbosity_error()
......@@ -77,6 +77,11 @@ def convert_examples_to_features(data: list,
del dial_feats
# Perform turn level padding
dial_ids = list()
for dial in data:
_ids = [turn['dialogue_id'] for turn in dial][:max_turns]
_ids += [''] * (max_turns - len(_ids))
dial_ids.append(_ids)
input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
for dial in input_feats]
if 'token_type_ids' in input_feats[0][0]:
......@@ -92,6 +97,7 @@ def convert_examples_to_features(data: list,
del input_feats
# Create torch data tensors
features['dialogue_ids'] = IdTensor(dial_ids)
features['input_ids'] = torch.tensor(input_ids)
features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
......
......@@ -15,6 +15,9 @@
# limitations under the License.
"""Convlab3 Unified dataset data processing utilities"""
import numpy
import pdb
from convlab.util import load_ontology, load_dst_data, load_nlu_data
from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
......@@ -66,7 +69,9 @@ def get_values_from_data(dataset: dict, data_split: str = "train") -> dict:
data = load_dst_data(dataset, data_split='all', speaker='user')
# Remove test data from the data when building training/validation ontology
if data_split in ['train', 'validation']:
if data_split == 'train':
data = {key: itm for key, itm in data.items() if key == 'train'}
elif data_split == 'validation':
data = {key: itm for key, itm in data.items() if key in ['train', 'validation']}
value_sets = {}
......@@ -74,13 +79,14 @@ def get_values_from_data(dataset: dict, data_split: str = "train") -> dict:
for turn in dataset:
for domain, substate in turn['state'].items():
domain_name = DOMAINS_MAP.get(domain, domain.lower())
if domain not in value_sets:
if domain_name not in value_sets:
value_sets[domain_name] = {}
for slot, value in substate.items():
if slot not in value_sets[domain_name]:
value_sets[domain_name][slot] = []
if value and value not in value_sets[domain_name][slot]:
value_sets[domain_name][slot].append(value)
# pdb.set_trace()
return clean_values(value_sets)
......@@ -163,6 +169,9 @@ def ontology_add_values(ontology_slots: dict, value_sets: dict, data_split: str
if data_split in ['train', 'validation']:
if domain not in value_sets:
continue
possible_values = [v for slot, vals in value_sets[domain].items() for v in vals]
if len(possible_values) == 0:
continue
ontology[domain] = {}
for slot in sorted(ontology_slots[domain]):
if not ontology_slots[domain][slot]['possible_values']:
......@@ -228,12 +237,13 @@ def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict
return ontology_slots
def extract_turns(dialogue: list, dataset_name: str) -> list:
def extract_turns(dialogue: list, dataset_name: str, dialogue_id: str) -> list:
"""
Extract the required information from the data provided by unified loader
Args:
dialogue (list): List of turns within a dialogue
dataset_name (str): Name of the dataset to which the dialogue belongs
dialogue_id (str): ID of the dialogue
Returns:
turns (list): List of turns within a dialogue
......@@ -261,6 +271,7 @@ def extract_turns(dialogue: list, dataset_name: str) -> list:
turn_info['state'] = turn['state']
turn_info['dataset_name'] = dataset_name
turn_info['dialogue_id'] = dialogue_id
if 'system_utterance' in turn_info and 'user_utterance' in turn_info:
turns.append(turn_info)
......@@ -399,6 +410,17 @@ def get_active_domains(turns: list) -> list:
return turns
class IdTensor:
    """Minimal tensor-like wrapper around an array of dialogue id strings.

    Supports just enough of the ``torch.Tensor`` interface (indexing and
    ``.to``) for id arrays to travel through batch-handling code unchanged.
    """

    def __init__(self, values):
        # Back the ids with a numpy array so indexing mirrors tensor indexing.
        self.values = numpy.array(values)

    def __getitem__(self, index: int):
        # Convert back to plain python objects on access.
        selected = self.values[index]
        return selected.tolist()

    def to(self, device):
        # String ids live on the host; moving devices is a deliberate no-op.
        return self
def extract_dialogues(data: list, dataset_name: str) -> list:
"""
Extract all dialogues from dataset
......@@ -411,7 +433,8 @@ def extract_dialogues(data: list, dataset_name: str) -> list:
"""
dialogues = []
for dial in data:
turns = extract_turns(dial['turns'], dataset_name)
dial_id = dial['dialogue_id']
turns = extract_turns(dial['turns'], dataset_name, dial_id)
turns = clean_states(turns)
turns = get_active_domains(turns)
dialogues.append(turns)
......
......@@ -65,13 +65,15 @@ def main(args=None, config=None):
paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
if 'pytorch_model.bin' in paths and 'config.json' in paths:
args.model_name_or_path = args.output_dir
config = ConfigClass.from_pretrained(args.model_name_or_path)
config = ConfigClass.from_pretrained(args.model_name_or_path,
local_files_only=args.transformers_local_files_only)
else:
paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
if paths:
paths = paths[0]
args.model_name_or_path = paths
config = ConfigClass.from_pretrained(args.model_name_or_path)
config = ConfigClass.from_pretrained(args.model_name_or_path,
local_files_only=args.transformers_local_files_only)
args = update_args(args, config)
......@@ -102,12 +104,15 @@ def main(args=None, config=None):
# Initialise Model
transformers.utils.logging.set_verbosity_info()
model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config,
local_files_only=args.transformers_local_files_only)
model = model.to(device)
# Create Tokenizer and embedding model for Data Loaders and ontology
encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config)
encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name,
local_files_only=args.transformers_local_files_only)
tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config,
local_files_only=args.transformers_local_files_only)
# Set up model training/evaluation
training.set_logger(logger, tb_writer)
......
import json
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import os
from tqdm import tqdm
from convlab.util import load_dataset
from convlab.util import load_dst_data
from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
def extract_data(dataset_names: str) -> list:
    """Load the test-split dialogues of every '+'-separated dataset name.

    Args:
        dataset_names (str): Dataset names joined by '+', e.g. "sgd+tm1".

    Returns:
        list: All test dialogues from the named datasets, unsplit by turn.
    """
    data = []
    for name in dataset_names.split('+'):
        dataset_dict = load_dataset(dataset_name=name)
        split = load_dst_data(dataset_dict, data_split='test', speaker='all',
                              dialogue_acts=True, split_to_turn=False)
        data += split['test']
    return data
def clean_state(state):
    """Normalise a gold dialogue state so it can be compared with predictions.

    Each slot value is mapped through ``VALUE_MAP``, quantities are clipped to
    the standard quantity set, time values are snapped to a 5-minute grid in
    zero-padded 24-hour format, and boolean slots are mapped to yes/no.

    Args:
        state (dict): Mapping of domain -> slot -> raw string value.

    Returns:
        dict: Mapping of domain -> slot -> normalised string value.
    """
    # Renamed from the original local ``clean_state`` to avoid shadowing
    # the function's own name.
    cleaned = dict()
    for domain, subset in state.items():
        cleaned[domain] = {}
        for slot, value in subset.items():
            # Apply VALUE_MAP to each pipe-separated alternative value.
            alternatives = value.split('|')
            for old, new in VALUE_MAP.items():
                alternatives = [val.replace(old, new) for val in alternatives]
            value = '|'.join(alternatives)
            # Map dontcare to "do not care" and empty to 'none'
            value = value.replace('dontcare', 'do not care')
            value = value if value else 'none'
            if 'people' in slot or 'duration' in slot or 'stay' in slot:
                # Map quantity values to the integer quantity value.
                try:
                    if value not in ['do not care', 'none']:
                        value = int(value)
                        value = str(value) if value < 10 else QUANTITIES[-1]
                except (ValueError, TypeError):
                    pass  # Non-numeric quantities are left unchanged.
            elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
                # Map time values to the standard 5-minute-grid time set.
                try:
                    if value not in ['do not care', 'none']:
                        value = _normalise_time_value(value)
                except (ValueError, TypeError):
                    pass  # Unparsable time strings are left unchanged.
            elif 'parking' in slot or 'internet' in slot:
                # Map boolean slots to yes/no values.
                if value not in ['do not care', 'none']:
                    if value == 'free':
                        value = 'yes'
                    elif any(v in value.lower() for v in ['yes', 'no']):
                        # BUG FIX: the original matched on the non-lowercased
                        # value here, raising IndexError on inputs like 'Yes'.
                        value = [v for v in ['yes', 'no'] if v in value.lower()][0]
            value = value if value != 'none' else ''
            cleaned[domain][slot] = value
    return cleaned


def _normalise_time_value(value: str) -> str:
    """Map a raw time string to zero-padded 24-hour 'HH:MM' on a 5-min grid.

    Args:
        value (str): Raw time value, e.g. '0930', '5', '3pm', '17:32'.

    Returns:
        str: The normalised time string.

    Raises:
        ValueError: If the string does not match any recognised time format.
    """
    # Strip after/before qualifiers from the time value.
    value = value.replace('after ', '').replace('before ', '')
    # Extract hours and minutes from the different possible formats.
    if ':' not in value and len(value) == 4:
        h, m = value[:2], value[2:]
    elif len(value) == 1:
        h, m = int(value), 0
    elif 'pm' in value:
        h, m = int(value.replace('pm', '')) + 12, 0
    elif 'am' in value:
        # BUG FIX: the original stripped 'pm' here, so 'am' times always
        # failed int() and were silently left unnormalised.
        h, m = int(value.replace('am', '')), 0
    elif ':' in value:
        h, m = value.split(':')
    elif ';' in value:
        h, m = value.split(';')
    else:
        raise ValueError(f'Unrecognised time format: {value}')
    # Snap minutes to the closest 5 minutes, carrying overflow into hours.
    h, m = int(h), int(m)
    if m % 5 != 0:
        m = round(m / 5) * 5
        if m == 60:
            m = 0
            h += 1
    # Wrap into the 24-hour range (e.g. '12pm' -> 24 -> 0).
    if h >= 24:
        h -= 24
    return '%02i:%02i' % (h, m)
def extract_states(data):
    """Collect the cleaned gold state of every turn, keyed by dialogue id.

    Args:
        data (list): Dialogues as returned by ``extract_data``.

    Returns:
        dict: Mapping of dialogue id -> list of cleaned per-turn states.
    """
    states_data = {}
    for dial in data:
        cleaned = [clean_state(turn['state'])
                   for turn in dial['turns'] if 'state' in turn]
        states_data[dial['dialogue_id']] = cleaned
    return states_data
def get_golden_state(prediction, data):
    """Attach the gold state to a prediction and align predicted slots to it.

    Args:
        prediction (dict): A prediction sample with 'dial_idx' / 'utt_idx'.
        data (dict): Gold states keyed by dialogue id (see extract_states).

    Returns:
        dict: The prediction with 'state' set to the gold state and its
        predicted state restricted to the gold state's domains and slots.
    """
    gold = data[prediction['dial_idx']][prediction['utt_idx']]
    predicted = prediction['predictions']['state']
    aligned = {}
    for domain in gold:
        # Predictions use mapped domain names; fall back to lowercase.
        domain_preds = predicted.get(DOMAINS_MAP.get(domain, domain.lower()), dict())
        aligned[domain] = {slot: domain_preds.get(slot, '') for slot in gold[domain]}
    prediction['state'] = gold
    prediction['predictions']['state'] = aligned
    return prediction
if __name__ == "__main__":
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
parser.add_argument('--model_path', type=str, help='Path to model dir')
args = parser.parse_args()
data = extract_data(args.dataset_name)
data = extract_states(data)
reader = open(os.path.join(args.model_path, "predictions", "test.json"), 'r')
predictions = json.load(reader)
reader.close()
predictions = [get_golden_state(pred, data) for pred in tqdm(predictions)]
writer = open(os.path.join(args.model_path, "predictions", f"test_{args.dataset_name}.json"), 'w')
json.dump(predictions, writer)
writer.close()
......@@ -590,8 +590,7 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
for sample in eval_output_batch:
dom, slt = slot.split('-', 1)
lab = state_labels[sample['dial_idx']][sample['utt_idx']].item()
if lab != -1:
lab = ontology[dom][slt]['possible_values'][lab]
lab = ontology[dom][slt]['possible_values'][lab] if lab != -1 else 'NOT_IN_ONTOLOGY'
pred = prediction[sample['dial_idx']][sample['utt_idx']].item()
pred = ontology[dom][slt]['possible_values'][pred]
......@@ -615,7 +614,9 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
num_inform_slots += (state_labels != -1).float().reshape(-1)
if return_eval_output:
evaluation_output += deepcopy(eval_output_batch)
for sample in eval_output_batch:
sample['dial_idx'] = batch['dialogue_ids'][sample['utt_idx']][sample['dial_idx']]
evaluation_output.append(deepcopy(sample))
eval_output_batch = []
if model.config.predict_actions:
......@@ -708,16 +709,6 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
req_f1, dom_f1, gen_f1 = None, None, None
if return_eval_output:
dial_idx = 0
for sample in evaluation_output:
if dial_idx == 0 and sample['dial_idx'] == 0 and sample['utt_idx'] == 0:
dial_idx = 0
elif dial_idx == 0 and sample['dial_idx'] != 0 and sample['utt_idx'] == 0:
dial_idx += 1
elif sample['utt_idx'] == 0:
dial_idx += 1
sample['dial_idx'] = dial_idx
return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), evaluation_output
if is_train:
return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), stats
......
......@@ -16,6 +16,7 @@
"""SetSUMBT utils"""
import os
import json
import shutil
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from datetime import datetime
......@@ -27,6 +28,9 @@ def get_args(base_models: dict):
# Get arguments
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
# Config file usage
parser.add_argument('--starting_config_name', default=None, type=str)
# Optional
parser.add_argument('--tensorboard_path', help='Path to tensorboard', default='')
parser.add_argument('--logging_path', help='Path for log file', default='')
......@@ -54,6 +58,8 @@ def get_args(base_models: dict):
parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.', default=None)
parser.add_argument('--candidate_embedding_model_name', default=None,
help='Name of the pretrained candidate embedding model.')
parser.add_argument('--transformers_local_files_only', help='Use local files only for huggingface transformers',
action='store_true')
# Architecture
parser.add_argument('--freeze_encoder', help='No training performed on the turn encoder Bert Model',
......@@ -143,6 +149,12 @@ def get_args(base_models: dict):
parser.add_argument('--do_test', help='Evaluate model on test data', action='store_true')
args = parser.parse_args()
if args.starting_config_name:
args = get_starting_config(args)
if args.do_train:
args.do_eval = True
# Simplify args
args.set_similarity = not args.no_set_similarity
args.use_descriptions = not args.no_descriptions
......@@ -217,6 +229,31 @@ def get_args(base_models: dict):
return args, config
def get_starting_config(args):
    """Overlay a named JSON starting config onto the parsed arguments.

    Loads ``configs/<args.starting_config_name>.json`` relative to this file,
    translates the user-facing ``model_type`` ('setsumbt'/'sumbt') into the
    underlying encoder type plus architecture flags, and copies every config
    key that matches an existing argument onto ``args``.

    Args:
        args: Parsed argparse namespace with ``starting_config_name`` set.

    Returns:
        The updated argparse namespace.
    """
    path = os.path.dirname(os.path.realpath(__file__))
    path = os.path.join(path, 'configs', f"{args.starting_config_name}.json")
    # With-statement replaces the original manual open/close so the file
    # handle is released even if json parsing fails.
    with open(path, 'r') as reader:
        config = json.load(reader)

    # Translate the high-level model type into encoder type + flags.
    if "model_type" in config:
        if config["model_type"].lower() == 'setsumbt':
            config["model_type"] = 'roberta'
            config["no_set_similarity"] = False
            config["no_descriptions"] = False
        elif config["model_type"].lower() == 'sumbt':
            config["model_type"] = 'bert'
            config["no_set_similarity"] = True
            config["no_descriptions"] = False

    # Only keys that name an existing argument are applied; the rest are
    # silently ignored.
    variables = vars(args).keys()
    for key, value in config.items():
        if key in variables:
            setattr(args, key, value)

    return args
def get_git_info():
repo = Repo(os.path.dirname(os.path.realpath(__file__)), search_parent_directories=True)
branch_name = repo.active_branch.name
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment