From 2c8dc0d9a869cb7c38edab4c5d25906a6211b362 Mon Sep 17 00:00:00 2001
From: function2 <function2@qq.com>
Date: Sat, 17 Oct 2020 16:08:57 +0800
Subject: [PATCH] udpate dstc9 eval

---
 convlab2/dst/dstc9/eval_file.py  | 10 +++++-----
 convlab2/dst/dstc9/eval_model.py |  5 +++--
 convlab2/dst/dstc9/utils.py      | 18 +++++++++++++-----
 3 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py
index 50784fe..fe33b02 100644
--- a/convlab2/dst/dstc9/eval_file.py
+++ b/convlab2/dst/dstc9/eval_file.py
@@ -9,16 +9,16 @@ from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_
 
 
 def evaluate(model_dir, subtask, gt):
-    results = {}
     for i in range(1, 6):
         filepath = os.path.join(model_dir, f'submission{i}.json')
         if not os.path.exists(filepath):
             continue
         pred = json.load(open(filepath))
-        results[i] = eval_states(gt, pred, subtask)
-
-    print(json.dumps(results, indent=4, ensure_ascii=False))
-    dump_result(model_dir, 'file-results.json', results)
+        result, errors = eval_states(gt, pred, subtask)
+        print(json.dumps(result, indent=4, ensure_ascii=False))
+        dump_result(model_dir, 'file-result.json', result)
+        return
+    raise ValueError('submission file not found')
 
 
 # generate submission examples
diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py
index 207eac7..3a3e081 100644
--- a/convlab2/dst/dstc9/eval_model.py
+++ b/convlab2/dst/dstc9/eval_model.py
@@ -28,9 +28,10 @@ def evaluate(model_dir, subtask, test_data, gt):
             pred[dialog_id].append(model.update_turn(sys_utt, user_utt))
             bar.update()
     bar.close()
-    result = eval_states(gt, pred, subtask)
+
+    result, errors = eval_states(gt, pred, subtask)
     print(json.dumps(result, indent=4))
-    dump_result(model_dir, 'model-result.json', result)
+    dump_result(model_dir, 'model-result.json', result, errors, pred)
 
 
 if __name__ == '__main__':
diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py
index 4e7d967..8a37f4b 100644
--- a/convlab2/dst/dstc9/utils.py
+++ b/convlab2/dst/dstc9/utils.py
@@ -83,8 +83,6 @@ def unify_value(value, subtask):
 
     return ' '.join(value.strip().split())
 
-    return ' '.join(value.strip().split())
-
 
 def eval_states(gt, pred, subtask):
     def exception(description, **kargs):
@@ -94,7 +92,8 @@ def eval_states(gt, pred, subtask):
         }
         for k, v in kargs.items():
             ret[k] = v
-        return ret
+        return ret, None
+    errors = []
 
     joint_acc, joint_tot = 0, 0
     slot_acc, slot_tot = 0, 0
@@ -121,11 +120,13 @@ def eval_states(gt, pred, subtask):
                     gt_value = unify_value(gt_value, subtask)
                     pred_value = unify_value(pred_domain[slot_name], subtask)
                     slot_tot += 1
+
                     if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
                         slot_acc += 1
                         if gt_value:
                             tp += 1
                     else:
+                        errors.append([gt_value, pred_value])
                         turn_result = False
                         if gt_value:
                             fn += 1
@@ -145,10 +146,17 @@ def eval_states(gt, pred, subtask):
             'recall': recall,
             'f1': f1,
         }
-    }
+    }, errors
 
 
-def dump_result(model_dir, filename, result):
+def dump_result(model_dir, filename, result, errors=None, pred=None):
     output_dir = os.path.join('../results', model_dir)
     os.makedirs(output_dir, exist_ok=True)
     json.dump(result, open(os.path.join(output_dir, filename), 'w'), indent=4, ensure_ascii=False)
+    if errors:
+        import csv
+        with open(os.path.join(output_dir, 'errors.csv'), 'w') as f:
+            writer = csv.writer(f)
+            writer.writerows(errors)
+    if pred:
+        json.dump(pred, open(os.path.join(output_dir, 'pred.json'), 'w'), indent=4, ensure_ascii=False)
-- 
GitLab