diff --git a/metric_bert_dst.py b/metric_bert_dst.py index d2df63af7146e0feef28b90ee9bc11000c8fa337..e7818c1628bc96fcbd2e2ee9dac31375e5326144 100644 --- a/metric_bert_dst.py +++ b/metric_bert_dst.py @@ -171,9 +171,6 @@ def get_joint_slot_correctness(fp, class_types, label_maps, class_correctness[turn_gt_class].append(1.0) class_correctness[-1].append(1.0) c_tp[turn_gt_class] += 1 - for cc in range(len(class_types)): - if cc != turn_gt_class: - c_tn[cc] += 1 # Only where there is a span, we check its per turn correctness if turn_gt_class == class_types.index('copy_value'): if gt_start_pos == pd_start_pos and gt_end_pos == pd_end_pos: @@ -198,6 +195,9 @@ def get_joint_slot_correctness(fp, class_types, label_maps, confusion_matrix[turn_gt_class][turn_pd_class].append(1.0) c_fn[turn_gt_class] += 1 c_fp[turn_pd_class] += 1 + for cc in range(len(class_types)): + if cc != turn_gt_class and cc != turn_pd_class: + c_tn[cc] += 1 # Check the joint slot correctness. # If the value label is not none, then we need to have a value prediction.