diff --git a/metric_bert_dst.py b/metric_bert_dst.py
index d2df63af7146e0feef28b90ee9bc11000c8fa337..e7818c1628bc96fcbd2e2ee9dac31375e5326144 100644
--- a/metric_bert_dst.py
+++ b/metric_bert_dst.py
@@ -171,9 +171,6 @@ def get_joint_slot_correctness(fp, class_types, label_maps,
                 class_correctness[turn_gt_class].append(1.0)
                 class_correctness[-1].append(1.0)
                 c_tp[turn_gt_class] += 1
-                for cc in range(len(class_types)):
-                    if cc != turn_gt_class:
-                        c_tn[cc] += 1
                 # Only where there is a span, we check its per turn correctness
                 if turn_gt_class == class_types.index('copy_value'):
                     if gt_start_pos == pd_start_pos and gt_end_pos == pd_end_pos:
@@ -198,6 +195,9 @@ def get_joint_slot_correctness(fp, class_types, label_maps,
                 confusion_matrix[turn_gt_class][turn_pd_class].append(1.0)
                 c_fn[turn_gt_class] += 1
                 c_fp[turn_pd_class] += 1
+            for cc in range(len(class_types)):
+                if cc != turn_gt_class and cc != turn_pd_class:
+                    c_tn[cc] += 1
 
             # Check the joint slot correctness.
             # If the value label is not none, then we need to have a value prediction.