Skip to content
Snippets Groups Projects
Commit 413f971c authored by Carel van Niekerk's avatar Carel van Niekerk :desktop:
Browse files

Bug fix

parent 4d5ae7ab
No related branches found
No related tags found
No related merge requests found
...@@ -91,7 +91,6 @@ def main(args=None, config=None): ...@@ -91,7 +91,6 @@ def main(args=None, config=None):
args.fp16 = False args.fp16 = False
# Set up model training/evaluation # Set up model training/evaluation
evaluation_utils.set_logger(logger, None)
evaluation_utils.set_seed(args) evaluation_utils.set_seed(args)
# Perform tasks # Perform tasks
...@@ -173,7 +172,7 @@ def main(args=None, config=None): ...@@ -173,7 +172,7 @@ def main(args=None, config=None):
jg_acc /= n_turns jg_acc /= n_turns
logger.info('Joint Goal Accuracy: %f, Slot Accuracy %f' % (jg_acc, sl_acc)) logger.info(f'Joint Goal Accuracy: {jg_acc}')
l2 = l2_acc(belief_states, state_labels, remove_belief=False) l2 = l2_acc(belief_states, state_labels, remove_belief=False)
logger.info(f'Model L2 Norm Goal Accuracy: {l2}') logger.info(f'Model L2 Norm Goal Accuracy: {l2}')
...@@ -250,7 +249,7 @@ def main(args=None, config=None): ...@@ -250,7 +249,7 @@ def main(args=None, config=None):
p = p.unsqueeze(-1) p = p.unsqueeze(-1)
p = torch.cat((1 - p, p), -1) p = torch.cat((1 - p, p), -1)
active_domain_probs[dom] = p active_domain_probs[dom] = p
jg = jg_ece(active_domain_probs, domain_labels, 10) jg = jg_ece(active_domain_probs, active_domain_labels, 10)
logger.info('Domain Joint Goal ECE: %f' % jg) logger.info('Domain Joint Goal ECE: %f' % jg)
tp = ((general_act_probs.argmax(-1) > 0) * tp = ((general_act_probs.argmax(-1) > 0) *
...@@ -284,12 +283,12 @@ def main(args=None, config=None): ...@@ -284,12 +283,12 @@ def main(args=None, config=None):
l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=True) l2 = l2_acc(active_domain_probs, active_domain_labels, remove_belief=True)
logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}') logger.info(f'Binary Model L2 Norm Domain Accuracy: {l2}')
active_domain_labels = {'general': active_domain_labels} general_act_labels = {'general': general_act_labels}
general_act_probs = {'general': general_act_probs} general_act_probs = {'general': general_act_probs}
l2 = l2_acc(general_act_probs, active_domain_labels, remove_belief=False) l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False)
logger.info(f'Model L2 Norm General Act Accuracy: {l2}') logger.info(f'Model L2 Norm General Act Accuracy: {l2}')
l2 = l2_acc(general_act_probs, active_domain_labels, remove_belief=False) l2 = l2_acc(general_act_probs, general_act_labels, remove_belief=False)
logger.info(f'Binary Model L2 Norm General Act Accuracy: {l2}') logger.info(f'Binary Model L2 Norm General Act Accuracy: {l2}')
......
...@@ -53,7 +53,7 @@ def get_predictions(args, model, device: torch.device, dataloader: torch.utils.d ...@@ -53,7 +53,7 @@ def get_predictions(args, model, device: torch.device, dataloader: torch.utils.d
active_domain_probs = {dom: [] for dom in model.setsumbt.domain_ids} active_domain_probs = {dom: [] for dom in model.setsumbt.domain_ids}
general_act_probs = [] general_act_probs = []
state_labels = {slot: [] for slot in model.setsumbt.informable_slot_ids} state_labels = {slot: [] for slot in model.setsumbt.informable_slot_ids}
request_labels = {slot: [] for slot in model.requestable_slot_ids} request_labels = {slot: [] for slot in model.setsumbt.requestable_slot_ids}
active_domain_labels = {dom: [] for dom in model.setsumbt.domain_ids} active_domain_labels = {dom: [] for dom in model.setsumbt.domain_ids}
general_act_labels = [] general_act_labels = []
epoch_iterator = tqdm(dataloader, desc="Iteration") epoch_iterator = tqdm(dataloader, desc="Iteration")
...@@ -83,13 +83,13 @@ def get_predictions(args, model, device: torch.device, dataloader: torch.utils.d ...@@ -83,13 +83,13 @@ def get_predictions(args, model, device: torch.device, dataloader: torch.utils.d
for domain in active_domain_probs: for domain in active_domain_probs:
p_ = p_dom[domain] p_ = p_dom[domain]
labs = batch['active-' + domain].to(device) labs = batch['active_domain_labels-' + domain].to(device)
active_domain_probs[domain].append(p_) active_domain_probs[domain].append(p_)
active_domain_labels[domain].append(labs) active_domain_labels[domain].append(labs)
general_act_probs.append(p_bye) general_act_probs.append(p_gen)
general_act_labels.append(batch['goodbye'].to(device)) general_act_labels.append(batch['general_act_labels'].to(device))
for slot in belief_states: for slot in belief_states:
belief_states[slot] = torch.cat(belief_states[slot], 0) belief_states[slot] = torch.cat(belief_states[slot], 0)
...@@ -107,6 +107,6 @@ def get_predictions(args, model, device: torch.device, dataloader: torch.utils.d ...@@ -107,6 +107,6 @@ def get_predictions(args, model, device: torch.device, dataloader: torch.utils.d
request_probs, request_labels, active_domain_probs, active_domain_labels = [None] * 4 request_probs, request_labels, active_domain_probs, active_domain_labels = [None] * 4
general_act_probs, general_act_labels = [None] * 2 general_act_probs, general_act_labels = [None] * 2
out = (belief_states, state_labels, request_belief, request_labels) out = (belief_states, state_labels, request_probs, request_labels)
out += (active_domain_probs, active_domain_labels, general_act_probs, general_act_labels) out += (active_domain_probs, active_domain_labels, general_act_probs, general_act_labels)
return out return out
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment