diff --git a/convlab/dst/evaluate_unified_datasets.py b/convlab/dst/evaluate_unified_datasets.py
index d4e0720dc02bf90efcfebb1780630211f0722f7f..d4b3ad2464be6c8de0d6f1f51ecdf1bf6bfdbd0e 100644
--- a/convlab/dst/evaluate_unified_datasets.py
+++ b/convlab/dst/evaluate_unified_datasets.py
@@ -7,7 +7,6 @@ def evaluate(predict_result):
     metrics = {'TP':0, 'FP':0, 'FN':0}
     acc = []
-
     for sample in predict_result:
         pred_state = sample['predictions']['state']
         gold_state = sample['state']
@@ -37,7 +36,7 @@ def evaluate(predict_result):
                 flag = False
         acc.append(flag)
-
+
     TP = metrics.pop('TP')
     FP = metrics.pop('FP')
     FN = metrics.pop('FN')
diff --git a/convlab/dst/setsumbt/configs/setsumbt_multitask.json b/convlab/dst/setsumbt/configs/setsumbt_multitask.json
new file mode 100644
index 0000000000000000000000000000000000000000..c076a557cb3e1d567784c70559fb1922fe05c545
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/setsumbt_multitask.json
@@ -0,0 +1,11 @@
+{
+    "model_type": "SetSUMBT",
+    "dataset": "multiwoz21+sgd+tm1+tm2+tm3",
+    "no_action_prediction": true,
+    "model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
+    "transformers_local_files_only": true,
+    "train_batch_size": 3,
+    "dev_batch_size": 8,
+    "test_batch_size": 8,
+    "run_nbt": true
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json b/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json
new file mode 100644
index 0000000000000000000000000000000000000000..0bff751c16f0bdcdf61f04ce33d616370c0d32d8
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/setsumbt_multiwoz21.json
@@ -0,0 +1,11 @@
+{
+    "model_type": "SetSUMBT",
+    "dataset": "multiwoz21",
+    "no_action_prediction": true,
+    "model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
+    "transformers_local_files_only": true,
+    "train_batch_size": 3,
+    "dev_batch_size": 16,
+    "test_batch_size": 16,
+    "run_nbt": true
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/configs/setsumbt_pretrain.json b/convlab/dst/setsumbt/configs/setsumbt_pretrain.json
new file mode 100644
index 0000000000000000000000000000000000000000..fdc22d157840e7494b0266d0bd99f8a99d242969
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/setsumbt_pretrain.json
@@ -0,0 +1,11 @@
+{
+    "model_type": "SetSUMBT",
+    "dataset": "sgd+tm1+tm2+tm3",
+    "no_action_prediction": true,
+    "model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
+    "transformers_local_files_only": true,
+    "train_batch_size": 3,
+    "dev_batch_size": 12,
+    "test_batch_size": 12,
+    "run_nbt": true
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/configs/setsumbt_sgd.json b/convlab/dst/setsumbt/configs/setsumbt_sgd.json
new file mode 100644
index 0000000000000000000000000000000000000000..97f5818334af4c7984ec24448861b627315820e3
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/setsumbt_sgd.json
@@ -0,0 +1,11 @@
+{
+    "model_type": "SetSUMBT",
+    "dataset": "sgd",
+    "no_action_prediction": true,
+    "model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
+    "transformers_local_files_only": true,
+    "train_batch_size": 3,
+    "dev_batch_size": 6,
+    "test_batch_size": 3,
+    "run_nbt": true
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/configs/setsumbt_tm.json b/convlab/dst/setsumbt/configs/setsumbt_tm.json
new file mode 100644
index 0000000000000000000000000000000000000000..138f84c358067389d5f7b478ae94c3eb2aa90ea3
--- /dev/null
+++ b/convlab/dst/setsumbt/configs/setsumbt_tm.json
@@ -0,0 +1,11 @@
+{
+    "model_type": "SetSUMBT",
+    "dataset": "tm1+tm2+tm3",
+    "no_action_prediction": true,
+    "model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
+    "transformers_local_files_only": true,
+    "train_batch_size": 3,
+    "dev_batch_size": 8,
+    "test_batch_size": 8,
+    "run_nbt": true
+}
\ No newline at end of file
diff --git a/convlab/dst/setsumbt/dataset/unified_format.py b/convlab/dst/setsumbt/dataset/unified_format.py
index ca5793f6a42d01a9eef9769f363033aa798e16b4..1c3a68c3b2e627ac60f555a642dfa837734249b6 100644
--- a/convlab/dst/setsumbt/dataset/unified_format.py
+++ b/convlab/dst/setsumbt/dataset/unified_format.py
@@ -27,7 +27,7 @@ from convlab.util import load_dataset
 from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values,
                                                 get_values_from_data, ontology_add_requestable_slots,
                                                 get_requestable_slots, load_dst_data, extract_dialogues,
-                                                combine_value_sets)
+                                                combine_value_sets, IdTensor)
 
 transformers.logging.set_verbosity_error()
 
@@ -77,6 +77,11 @@ def convert_examples_to_features(data: list,
     del dial_feats
 
     # Perform turn level padding
+    dial_ids = list()
+    for dial in data:
+        _ids = [turn['dialogue_id'] for turn in dial][:max_turns]
+        _ids += [''] * (max_turns - len(_ids))
+        dial_ids.append(_ids)
     input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
                  for dial in input_feats]
     if 'token_type_ids' in input_feats[0][0]:
@@ -92,6 +97,7 @@ def convert_examples_to_features(data: list,
     del input_feats
 
     # Create torch data tensors
+    features['dialogue_ids'] = IdTensor(dial_ids)
     features['input_ids'] = torch.tensor(input_ids)
     features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
     features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
diff --git a/convlab/dst/setsumbt/dataset/utils.py b/convlab/dst/setsumbt/dataset/utils.py
index 1b601f027b8d8b02df7423e8b3d5fc351deca724..96773d6b9b181925b3004e4971e440d9c7720bfb 100644
--- a/convlab/dst/setsumbt/dataset/utils.py
+++ b/convlab/dst/setsumbt/dataset/utils.py
@@ -15,6 +15,9 @@
 # limitations under the License.
 
 """Convlab3 Unified dataset data processing utilities"""
+import numpy
+import pdb
+
 from convlab.util import load_ontology, load_dst_data, load_nlu_data
 from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
 
@@ -66,7 +69,9 @@ def get_values_from_data(dataset: dict, data_split: str = "train") -> dict:
     data = load_dst_data(dataset, data_split='all', speaker='user')
 
     # Remove test data from the data when building training/validation ontology
-    if data_split in ['train', 'validation']:
+    if data_split == 'train':
+        data = {key: itm for key, itm in data.items() if key == 'train'}
+    elif data_split == 'validation':
         data = {key: itm for key, itm in data.items() if key in ['train', 'validation']}
 
     value_sets = {}
@@ -74,13 +79,14 @@ def get_values_from_data(dataset: dict, data_split: str = "train") -> dict:
         for turn in dataset:
             for domain, substate in turn['state'].items():
                 domain_name = DOMAINS_MAP.get(domain, domain.lower())
-                if domain not in value_sets:
+                if domain_name not in value_sets:
                     value_sets[domain_name] = {}
                 for slot, value in substate.items():
                     if slot not in value_sets[domain_name]:
                         value_sets[domain_name][slot] = []
                     if value and value not in value_sets[domain_name][slot]:
                         value_sets[domain_name][slot].append(value)
+    # pdb.set_trace()
 
     return clean_values(value_sets)
 
@@ -163,6 +169,9 @@ def ontology_add_values(ontology_slots: dict, value_sets: dict, data_split: str
         if data_split in ['train', 'validation']:
            if domain not in value_sets:
                continue
+            possible_values = [v for slot, vals in value_sets[domain].items() for v in vals]
+            if len(possible_values) == 0:
+                continue
         ontology[domain] = {}
         for slot in sorted(ontology_slots[domain]):
             if not ontology_slots[domain][slot]['possible_values']:
@@ -228,12 +237,13 @@ def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict
     return ontology_slots
 
 
-def extract_turns(dialogue: list, dataset_name: str) -> list:
+def extract_turns(dialogue: list, dataset_name: str, dialogue_id: str) -> list:
     """
     Extract the required information from the data provided by unified loader
 
     Args:
         dialogue (list): List of turns within a dialogue
         dataset_name (str): Name of the dataset to which the dialogue belongs
+        dialogue_id (str): ID of the dialogue
 
     Returns:
         turns (list): List of turns within a dialogue
@@ -261,6 +271,7 @@
         turn_info['state'] = turn['state']
         turn_info['dataset_name'] = dataset_name
+        turn_info['dialogue_id'] = dialogue_id
 
         if 'system_utterance' in turn_info and 'user_utterance' in turn_info:
             turns.append(turn_info)
@@ -399,6 +410,17 @@ def get_active_domains(turns: list) -> list:
     return turns
 
 
+class IdTensor:
+    def __init__(self, values):
+        self.values = numpy.array(values)
+
+    def __getitem__(self, index: int):
+        return self.values[index].tolist()
+
+    def to(self, device):
+        return self
+
+
 def extract_dialogues(data: list, dataset_name: str) -> list:
     """
     Extract all dialogues from dataset
@@ -411,7 +433,8 @@ def extract_dialogues(data: list, dataset_name: str) -> list:
     """
     dialogues = []
     for dial in data:
-        turns = extract_turns(dial['turns'], dataset_name)
+        dial_id = dial['dialogue_id']
+        turns = extract_turns(dial['turns'], dataset_name, dial_id)
         turns = clean_states(turns)
         turns = get_active_domains(turns)
         dialogues.append(turns)
diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py
index 276d13f2eda393ed7c0af45d546a9e27e26bce2f..ea099442ddd18d0cd36a79db13b1f47788eb4fd4 100644
--- a/convlab/dst/setsumbt/do/nbt.py
+++ b/convlab/dst/setsumbt/do/nbt.py
@@ -65,13 +65,15 @@ def main(args=None, config=None):
         paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
         if 'pytorch_model.bin' in paths and 'config.json' in paths:
             args.model_name_or_path = args.output_dir
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
+            config = ConfigClass.from_pretrained(args.model_name_or_path,
+                                                 local_files_only=args.transformers_local_files_only)
         else:
             paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
             if paths:
                 paths = paths[0]
                 args.model_name_or_path = paths
-                config = ConfigClass.from_pretrained(args.model_name_or_path)
+                config = ConfigClass.from_pretrained(args.model_name_or_path,
+                                                     local_files_only=args.transformers_local_files_only)
 
     args = update_args(args, config)
@@ -102,12 +104,15 @@ def main(args=None, config=None):
     # Initialise Model
     transformers.utils.logging.set_verbosity_info()
-    model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
+    model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config,
+                                          local_files_only=args.transformers_local_files_only)
     model = model.to(device)
 
     # Create Tokenizer and embedding model for Data Loaders and ontology
-    encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
-    tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config)
+    encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name,
+                                                    local_files_only=args.transformers_local_files_only)
+    tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config,
+                                          local_files_only=args.transformers_local_files_only)
 
     # Set up model training/evaluation
     training.set_logger(logger, tb_writer)
diff --git a/convlab/dst/setsumbt/get_golden_labels.py b/convlab/dst/setsumbt/get_golden_labels.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fb2841d0d503181119c791a7046fd7e0025d236
--- /dev/null
+++ b/convlab/dst/setsumbt/get_golden_labels.py
@@ -0,0 +1,138 @@
+import json
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+import os
+
+from tqdm import tqdm
+
+from convlab.util import load_dataset
+from convlab.util import load_dst_data
+from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
+
+
+def extract_data(dataset_names: str) -> list:
+    dataset_dicts = [load_dataset(dataset_name=name) for name in dataset_names.split('+')]
+    data = []
+    for dataset_dict in dataset_dicts:
+        dataset = load_dst_data(dataset_dict, data_split='test', speaker='all', dialogue_acts=True, split_to_turn=False)
+        for dial in dataset['test']:
+            data.append(dial)
+
+    return data
+
+def clean_state(state):
+    clean_state = dict()
+    for domain, subset in state.items():
+        clean_state[domain] = {}
+        for slot, value in subset.items():
+            # Remove pipe separated values
+            value = value.split('|')
+
+            # Map values using value_map
+            for old, new in VALUE_MAP.items():
+                value = [val.replace(old, new) for val in value]
+            value = '|'.join(value)
+
+            # Map dontcare to "do not care" and empty to 'none'
+            value = value.replace('dontcare', 'do not care')
+            value = value if value else 'none'
+
+            # Map quantity values to the integer quantity value
+            if 'people' in slot or 'duration' in slot or 'stay' in slot:
+                try:
+                    if value not in ['do not care', 'none']:
+                        value = int(value)
+                        value = str(value) if value < 10 else QUANTITIES[-1]
+                except:
+                    value = value
+            # Map time values to the most appropriate value in the standard time set
+            elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
+                try:
+                    if value not in ['do not care', 'none']:
+                        # Strip after/before from time value
+                        value = value.replace('after ', '').replace('before ', '')
+                        # Extract hours and minutes from different possible formats
+                        if ':' not in value and len(value) == 4:
+                            h, m = value[:2], value[2:]
+                        elif len(value) == 1:
+                            h = int(value)
+                            m = 0
+                        elif 'pm' in value:
+                            h = int(value.replace('pm', '')) + 12
+                            m = 0
+                        elif 'am' in value:
+                            h = int(value.replace('am', ''))
+                            m = 0
+                        elif ':' in value:
+                            h, m = value.split(':')
+                        elif ';' in value:
+                            h, m = value.split(';')
+                        # Map to closest 5 minutes
+                        if int(m) % 5 != 0:
+                            m = round(int(m) / 5) * 5
+                            h = int(h)
+                            if m == 60:
+                                m = 0
+                                h += 1
+                            if h >= 24:
+                                h -= 24
+                        # Set in standard 24 hour format
+                        h, m = int(h), int(m)
+                        value = '%02i:%02i' % (h, m)
+                except:
+                    value = value
+            # Map boolean slots to yes/no value
+            elif 'parking' in slot or 'internet' in slot:
+                if value not in ['do not care', 'none']:
+                    if value == 'free':
+                        value = 'yes'
+                    elif True in [v in value.lower() for v in ['yes', 'no']]:
+                        value = [v for v in ['yes', 'no'] if v in value][0]
+
+            value = value if value != 'none' else ''
+
+            clean_state[domain][slot] = value
+
+    return clean_state
+
+def extract_states(data):
+    states_data = {}
+    for dial in data:
+        states = []
+        for turn in dial['turns']:
+            if 'state' in turn:
+                state = clean_state(turn['state'])
+                states.append(state)
+        states_data[dial['dialogue_id']] = states
+
+    return states_data
+
+
+def get_golden_state(prediction, data):
+    state = data[prediction['dial_idx']][prediction['utt_idx']]
+    pred = prediction['predictions']['state']
+    pred = {domain: {slot: pred.get(DOMAINS_MAP.get(domain, domain.lower()), dict()).get(slot, '')
+                     for slot in state[domain]} for domain in state}
+    prediction['state'] = state
+    prediction['predictions']['state'] = pred
+
+    return prediction
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
+    parser.add_argument('--model_path', type=str, help='Path to model dir')
+    args = parser.parse_args()
+
+    data = extract_data(args.dataset_name)
+    data = extract_states(data)
+
+    reader = open(os.path.join(args.model_path, "predictions", "test.json"), 'r')
+    predictions = json.load(reader)
+    reader.close()
+
+    predictions = [get_golden_state(pred, data) for pred in tqdm(predictions)]
+
+    writer = open(os.path.join(args.model_path, "predictions", f"test_{args.dataset_name}.json"), 'w')
+    json.dump(predictions, writer)
+    writer.close()
diff --git a/convlab/dst/setsumbt/modeling/training.py b/convlab/dst/setsumbt/modeling/training.py
index 77f41dc3d0104768d531e1cd77f61496faad4146..590b2ac7372b26262625d08691a8528ffddd82d2 100644
--- a/convlab/dst/setsumbt/modeling/training.py
+++ b/convlab/dst/setsumbt/modeling/training.py
@@ -590,17 +590,16 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
                     for sample in eval_output_batch:
                         dom, slt = slot.split('-', 1)
                         lab = state_labels[sample['dial_idx']][sample['utt_idx']].item()
-                        if lab != -1:
-                            lab = ontology[dom][slt]['possible_values'][lab]
-                            pred = prediction[sample['dial_idx']][sample['utt_idx']].item()
-                            pred = ontology[dom][slt]['possible_values'][pred]
+                        lab = ontology[dom][slt]['possible_values'][lab] if lab != -1 else 'NOT_IN_ONTOLOGY'
+                        pred = prediction[sample['dial_idx']][sample['utt_idx']].item()
+                        pred = ontology[dom][slt]['possible_values'][pred]
 
-                            if dom not in sample['state']:
-                                sample['state'][dom] = dict()
-                                sample['predictions']['state'][dom] = dict()
+                        if dom not in sample['state']:
+                            sample['state'][dom] = dict()
+                            sample['predictions']['state'][dom] = dict()
 
-                            sample['state'][dom][slt] = lab if lab != 'none' else ''
-                            sample['predictions']['state'][dom][slt] = pred if pred != 'none' else ''
+                        sample['state'][dom][slt] = lab if lab != 'none' else ''
+                        sample['predictions']['state'][dom][slt] = pred if pred != 'none' else ''
 
             if args.temp_scaling > 0.0:
                 p_ = torch.log(p_ + 1e-10) / args.temp_scaling
@@ -615,7 +614,9 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
             num_inform_slots += (state_labels != -1).float().reshape(-1)
 
             if return_eval_output:
-                evaluation_output += deepcopy(eval_output_batch)
+                for sample in eval_output_batch:
+                    sample['dial_idx'] = batch['dialogue_ids'][sample['utt_idx']][sample['dial_idx']]
+                    evaluation_output.append(deepcopy(sample))
                 eval_output_batch = []
 
             if model.config.predict_actions:
@@ -708,16 +709,6 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
         req_f1, dom_f1, gen_f1 = None, None, None
 
     if return_eval_output:
-        dial_idx = 0
-        for sample in evaluation_output:
-            if dial_idx == 0 and sample['dial_idx'] == 0 and sample['utt_idx'] == 0:
-                dial_idx = 0
-            elif dial_idx == 0 and sample['dial_idx'] != 0 and sample['utt_idx'] == 0:
-                dial_idx += 1
-            elif sample['utt_idx'] == 0:
-                dial_idx += 1
-            sample['dial_idx'] = dial_idx
-
         return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), evaluation_output
     if is_train:
         return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), stats
diff --git a/convlab/dst/setsumbt/utils.py b/convlab/dst/setsumbt/utils.py
index 5183955254140c41dc164befb302f25ad6985a6a..ff374116a3f8e88e6219fdc8b134d40b0bee7caf 100644
--- a/convlab/dst/setsumbt/utils.py
+++ b/convlab/dst/setsumbt/utils.py
@@ -16,6 +16,7 @@
 """SetSUMBT utils"""
 
 import os
+import json
 import shutil
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
 from datetime import datetime
@@ -27,6 +28,9 @@ def get_args(base_models: dict):
     # Get arguments
     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
 
+    # Config file usage
+    parser.add_argument('--starting_config_name', default=None, type=str)
+
     # Optional
     parser.add_argument('--tensorboard_path', help='Path to tensorboard', default='')
     parser.add_argument('--logging_path', help='Path for log file', default='')
@@ -54,6 +58,8 @@ def get_args(base_models: dict):
     parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.', default=None)
     parser.add_argument('--candidate_embedding_model_name', default=None,
                         help='Name of the pretrained candidate embedding model.')
+    parser.add_argument('--transformers_local_files_only', help='Use local files only for huggingface transformers',
+                        action='store_true')
 
     # Architecture
     parser.add_argument('--freeze_encoder', help='No training performed on the turn encoder Bert Model',
@@ -143,6 +149,12 @@ def get_args(base_models: dict):
     parser.add_argument('--do_test', help='Evaluate model on test data', action='store_true')
     args = parser.parse_args()
 
+    if args.starting_config_name:
+        args = get_starting_config(args)
+
+    if args.do_train:
+        args.do_eval = True
+
     # Simplify args
     args.set_similarity = not args.no_set_similarity
     args.use_descriptions = not args.no_descriptions
@@ -217,6 +229,31 @@ def get_args(base_models: dict):
     return args, config
 
 
+def get_starting_config(args):
+    path = os.path.dirname(os.path.realpath(__file__))
+    path = os.path.join(path, 'configs', f"{args.starting_config_name}.json")
+    reader = open(path, 'r')
+    config = json.load(reader)
+    reader.close()
+
+    if "model_type" in config:
+        if config["model_type"].lower() == 'setsumbt':
+            config["model_type"] = 'roberta'
+            config["no_set_similarity"] = False
+            config["no_descriptions"] = False
+        elif config["model_type"].lower() == 'sumbt':
+            config["model_type"] = 'bert'
+            config["no_set_similarity"] = True
+            config["no_descriptions"] = False
+
+    variables = vars(args).keys()
+    for key, value in config.items():
+        if key in variables:
+            setattr(args, key, value)
+
+    return args
+
+
 def get_git_info():
     repo = Repo(os.path.dirname(os.path.realpath(__file__)), search_parent_directories=True)
     branch_name = repo.active_branch.name
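
Usage sketch (editor's note, not part of the patch above): assuming a SetSUMBT model whose test predictions have been written to <model_dir>/predictions/test.json, the pieces introduced in this diff chain together roughly as follows. The flags --starting_config_name, --output_dir, --do_train, --do_test, --dataset_name and --model_path all appear in the diff itself; running convlab/dst/setsumbt/do/nbt.py as the training entry point and the --predict_result flag of convlab/dst/evaluate_unified_datasets.py are assumptions, not verified commands.

    # 1. Train/evaluate starting from one of the new JSON configs (hypothetical invocation)
    python -m convlab.dst.setsumbt.do.nbt --starting_config_name setsumbt_multiwoz21 \
        --output_dir models/setsumbt_multiwoz21 --do_train --do_test

    # 2. Attach golden states from the unified test set to the prediction file
    python convlab/dst/setsumbt/get_golden_labels.py --dataset_name multiwoz21 \
        --model_path models/setsumbt_multiwoz21

    # 3. Score the resulting predictions/test_multiwoz21.json with the DST evaluation script
    python convlab/dst/evaluate_unified_datasets.py \
        --predict_result models/setsumbt_multiwoz21/predictions/test_multiwoz21.json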