From b7896c8ed0a6506378353accfbd95c67a66e20ff Mon Sep 17 00:00:00 2001
From: Michael Heck <michael.heck@hhu.de>
Date: Wed, 16 Sep 2020 15:46:27 +0200
Subject: [PATCH] Fix stat bug

---
 metric_bert_dst.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/metric_bert_dst.py b/metric_bert_dst.py
index d2df63a..e7818c1 100644
--- a/metric_bert_dst.py
+++ b/metric_bert_dst.py
@@ -171,9 +171,6 @@ def get_joint_slot_correctness(fp, class_types, label_maps,
                 class_correctness[turn_gt_class].append(1.0)
                 class_correctness[-1].append(1.0)
                 c_tp[turn_gt_class] += 1
-                for cc in range(len(class_types)):
-                    if cc != turn_gt_class:
-                        c_tn[cc] += 1
                 # Only where there is a span, we check its per turn correctness
                 if turn_gt_class == class_types.index('copy_value'):
                     if gt_start_pos == pd_start_pos and gt_end_pos == pd_end_pos:
@@ -198,6 +195,9 @@ def get_joint_slot_correctness(fp, class_types, label_maps,
                 confusion_matrix[turn_gt_class][turn_pd_class].append(1.0)
                 c_fn[turn_gt_class] += 1
                 c_fp[turn_pd_class] += 1
+            for cc in range(len(class_types)):
+                if cc != turn_gt_class and cc != turn_pd_class:
+                    c_tn[cc] += 1
 
             # Check the joint slot correctness.
             # If the value label is not none, then we need to have a value prediction.
-- 
GitLab