Skip to content
Snippets Groups Projects
Commit d7ba8e0f authored by Carel van Niekerk's avatar Carel van Niekerk :desktop:
Browse files

Save dialogue ids in prediction file

parent 396b9598
No related branches found
No related tags found
No related merge requests found
...@@ -27,7 +27,7 @@ from convlab.util import load_dataset ...@@ -27,7 +27,7 @@ from convlab.util import load_dataset
from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values, from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values,
get_values_from_data, ontology_add_requestable_slots, get_values_from_data, ontology_add_requestable_slots,
get_requestable_slots, load_dst_data, extract_dialogues, get_requestable_slots, load_dst_data, extract_dialogues,
combine_value_sets) combine_value_sets, IdTensor)
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
...@@ -77,6 +77,11 @@ def convert_examples_to_features(data: list, ...@@ -77,6 +77,11 @@ def convert_examples_to_features(data: list,
del dial_feats del dial_feats
# Perform turn level padding # Perform turn level padding
dial_ids = list()
for dial in data:
_ids = [turn['dialogue_id'] for turn in dial][:max_turns]
_ids += [''] * (max_turns - len(_ids))
dial_ids.append(_ids)
input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial)) input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
for dial in input_feats] for dial in input_feats]
if 'token_type_ids' in input_feats[0][0]: if 'token_type_ids' in input_feats[0][0]:
...@@ -92,6 +97,7 @@ def convert_examples_to_features(data: list, ...@@ -92,6 +97,7 @@ def convert_examples_to_features(data: list,
del input_feats del input_feats
# Create torch data tensors # Create torch data tensors
features['dialogue_ids'] = IdTensor(dial_ids)
features['input_ids'] = torch.tensor(input_ids) features['input_ids'] = torch.tensor(input_ids)
features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
# limitations under the License. # limitations under the License.
"""Convlab3 Unified dataset data processing utilities""" """Convlab3 Unified dataset data processing utilities"""
import numpy
from convlab.util import load_ontology, load_dst_data, load_nlu_data from convlab.util import load_ontology, load_dst_data, load_nlu_data
from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
...@@ -228,12 +230,13 @@ def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict ...@@ -228,12 +230,13 @@ def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict
return ontology_slots return ontology_slots
def extract_turns(dialogue: list, dataset_name: str) -> list: def extract_turns(dialogue: list, dataset_name: str, dialogue_id: str) -> list:
""" """
Extract the required information from the data provided by unified loader Extract the required information from the data provided by unified loader
Args: Args:
dialogue (list): List of turns within a dialogue dialogue (list): List of turns within a dialogue
dataset_name (str): Name of the dataset to which the dialogue belongs dataset_name (str): Name of the dataset to which the dialogue belongs
        dialogue_id (str): ID of the dialogue
Returns: Returns:
turns (list): List of turns within a dialogue turns (list): List of turns within a dialogue
...@@ -261,6 +264,7 @@ def extract_turns(dialogue: list, dataset_name: str) -> list: ...@@ -261,6 +264,7 @@ def extract_turns(dialogue: list, dataset_name: str) -> list:
turn_info['state'] = turn['state'] turn_info['state'] = turn['state']
turn_info['dataset_name'] = dataset_name turn_info['dataset_name'] = dataset_name
turn_info['dialogue_id'] = dialogue_id
if 'system_utterance' in turn_info and 'user_utterance' in turn_info: if 'system_utterance' in turn_info and 'user_utterance' in turn_info:
turns.append(turn_info) turns.append(turn_info)
...@@ -399,6 +403,17 @@ def get_active_domains(turns: list) -> list: ...@@ -399,6 +403,17 @@ def get_active_domains(turns: list) -> list:
return turns return turns
class IdTensor:
    """Minimal tensor-like wrapper for (string) dialogue IDs.

    Lets a batch of dialogue-ID strings sit alongside torch tensors in a
    features dict: supports integer/slice indexing and a no-op ``to(device)``
    so generic batching code can treat it like a tensor.
    """

    def __init__(self, values):
        # asarray keeps list input behaviour identical to numpy.array here;
        # IDs are stored as a numpy array so indexing works uniformly.
        self.values = numpy.asarray(values)

    def __getitem__(self, index: int):
        # Convert back to plain Python lists/strings on access, so callers
        # never see numpy scalar types.
        return self.values[index].tolist()

    def to(self, device):
        # Device placement is meaningless for string IDs; return self
        # unchanged to mirror the torch.Tensor.to interface.
        return self
def extract_dialogues(data: list, dataset_name: str) -> list: def extract_dialogues(data: list, dataset_name: str) -> list:
""" """
Extract all dialogues from dataset Extract all dialogues from dataset
...@@ -411,7 +426,8 @@ def extract_dialogues(data: list, dataset_name: str) -> list: ...@@ -411,7 +426,8 @@ def extract_dialogues(data: list, dataset_name: str) -> list:
""" """
dialogues = [] dialogues = []
for dial in data: for dial in data:
turns = extract_turns(dial['turns'], dataset_name) dial_id = dial['dialogue_id']
turns = extract_turns(dial['turns'], dataset_name, dial_id)
turns = clean_states(turns) turns = clean_states(turns)
turns = get_active_domains(turns) turns = get_active_domains(turns)
dialogues.append(turns) dialogues.append(turns)
......
...@@ -65,13 +65,13 @@ def main(args=None, config=None): ...@@ -65,13 +65,13 @@ def main(args=None, config=None):
paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else [] paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
if 'pytorch_model.bin' in paths and 'config.json' in paths: if 'pytorch_model.bin' in paths and 'config.json' in paths:
args.model_name_or_path = args.output_dir args.model_name_or_path = args.output_dir
config = ConfigClass.from_pretrained(args.model_name_or_path) config = ConfigClass.from_pretrained(args.model_name_or_path, local_files_only=True)
else: else:
paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p] paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
if paths: if paths:
paths = paths[0] paths = paths[0]
args.model_name_or_path = paths args.model_name_or_path = paths
config = ConfigClass.from_pretrained(args.model_name_or_path) config = ConfigClass.from_pretrained(args.model_name_or_path, local_files_only=True)
args = update_args(args, config) args = update_args(args, config)
...@@ -102,12 +102,12 @@ def main(args=None, config=None): ...@@ -102,12 +102,12 @@ def main(args=None, config=None):
# Initialise Model # Initialise Model
transformers.utils.logging.set_verbosity_info() transformers.utils.logging.set_verbosity_info()
model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config) model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config, local_files_only=True)
model = model.to(device) model = model.to(device)
# Create Tokenizer and embedding model for Data Loaders and ontology # Create Tokenizer and embedding model for Data Loaders and ontology
encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name) encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name, local_files_only=True)
tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config) tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config, local_files_only=True)
# Set up model training/evaluation # Set up model training/evaluation
training.set_logger(logger, tb_writer) training.set_logger(logger, tb_writer)
......
...@@ -615,6 +615,8 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train ...@@ -615,6 +615,8 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
num_inform_slots += (state_labels != -1).float().reshape(-1) num_inform_slots += (state_labels != -1).float().reshape(-1)
if return_eval_output: if return_eval_output:
for sample in eval_output_batch:
sample['dial_idx'] = batch['dialogue_ids'][sample['utt_idx']][sample['dial_idx']]
evaluation_output += deepcopy(eval_output_batch) evaluation_output += deepcopy(eval_output_batch)
eval_output_batch = [] eval_output_batch = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment