From d0ec9473d25c44f34de45f770437e4ec1bffcfb8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=BD=97=E5=B4=9A=E9=AA=81?= <function2@qq.com>
Date: Tue, 22 Sep 2020 19:08:34 +0800
Subject: [PATCH] revise xldst evaluation (#124)

* update SUMBT translation training results with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example
---
 convlab2/dst/dstc9/.gitignore                 |  1 +
 convlab2/dst/dstc9/__init__.py                |  2 +
 convlab2/dst/dstc9/eval_file.py               | 54 ++++++++++++-------
 convlab2/dst/dstc9/eval_model.py              | 30 +++++------
 convlab2/dst/dstc9/example/.gitignore         |  1 -
 .../dst/dstc9/example/multiwoz_zh/model.py    |  4 +-
 convlab2/dst/dstc9/utils.py                   | 31 ++++++-----
 7 files changed, 74 insertions(+), 49 deletions(-)
 create mode 100644 convlab2/dst/dstc9/.gitignore
 delete mode 100644 convlab2/dst/dstc9/example/.gitignore

diff --git a/convlab2/dst/dstc9/.gitignore b/convlab2/dst/dstc9/.gitignore
new file mode 100644
index 0000000..5e8e7e0
--- /dev/null
+++ b/convlab2/dst/dstc9/.gitignore
@@ -0,0 +1 @@
+**/*.json
diff --git a/convlab2/dst/dstc9/__init__.py b/convlab2/dst/dstc9/__init__.py
index e69de29..1abf4a7 100644
--- a/convlab2/dst/dstc9/__init__.py
+++ b/convlab2/dst/dstc9/__init__.py
@@ -0,0 +1,2 @@
+from .eval_file import evaluate as eval_file
+from .eval_model import evaluate as eval_model
diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py
index fa3187d..b53b46a 100644
--- a/convlab2/dst/dstc9/eval_file.py
+++ b/convlab2/dst/dstc9/eval_file.py
@@ -2,28 +2,46 @@
     evaluate output file
 """
 
-from convlab2.dst.dstc9.utils import prepare_data, eval_states
+import os
+import json
 
-if __name__ == '__main__':
-    import os
-    import json
-    from argparse import ArgumentParser
-    parser = ArgumentParser()
-    parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
-    args = parser.parse_args()
+from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir
 
-    gt = {
-        dialog_id: [state for _, _, state in turns]
-        for dialog_id, turns in prepare_data(args.subtask).items()
-    }
-    # json.dump(gt, open('gt-crosswoz.json', 'w'), ensure_ascii=False, indent=4)
 
+def evaluate(model_dir, subtask, gt):
+    subdir = get_subdir(subtask)
     results = {}
     for i in range(1, 6):
-        filename = f'submission{i}.json'
-        if not os.path.exists(filename):
+        filepath = os.path.join(model_dir, subdir, f'submission{i}.json')
+        if not os.path.exists(filepath):
             continue
-        pred = json.load(open(filename))
-        results[filename] = eval_states(gt, pred)
+        pred = json.load(open(filepath))
+        results[i] = eval_states(gt, pred)
+
+    json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False)
+
 
-    json.dump(results, open('results.json', 'w'), indent=4, ensure_ascii=False)
+def dump_example(subtask, split):
+    test_data = prepare_data(subtask, split)
+    gt = extract_gt(test_data)
+    json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4)
+    for dialog_id, states in gt.items():
+        for state in states:
+            for domain in state.values():
+                for slot in domain:
+                    domain[slot] = ""
+    json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4)
+
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
+    parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val'])
+    args = parser.parse_args()
+    subtask = args.subtask
+    split = args.split
+    dump_example(subtask, split)
+    test_data = prepare_data(subtask, split)
+    gt = extract_gt(test_data)
+    evaluate('example', subtask, gt)
diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py
index 9b86d05..1b84ea0 100644
--- a/convlab2/dst/dstc9/eval_model.py
+++ b/convlab2/dst/dstc9/eval_model.py
@@ -7,36 +7,36 @@ import json
 import importlib
 
 from convlab2.dst import DST
-from convlab2.dst.dstc9.utils import prepare_data, eval_states
+from convlab2.dst.dstc9.utils import prepare_data, eval_states, get_subdir
 
 
-def evaluate(model_name, subtask):
-    subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en'
-    module = importlib.import_module(f'{model_name}.{subdir}')
+def evaluate(model_dir, subtask, test_data, gt):
+    subdir = get_subdir(subtask)
+    module = importlib.import_module(f'{model_dir}.{subdir}')
     assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root'
     model_cls = module.__getattribute__('Model')
     assert issubclass(model_cls, DST), 'the model must implement DST interface'
     # load weights, set eval() on default
     model = model_cls()
-    gt = {}
     pred = {}
-    for dialog_id, turns in prepare_data(subtask).items():
-        gt_dialog = []
-        pred_dialog = []
+    for dialog_id, turns in test_data.items():
         model.init_session()
-        for sys_utt, user_utt, gt_turn in turns:
-            gt_dialog.append(gt_turn)
-            pred_dialog.append(model.update_turn(sys_utt, user_utt))
-        gt[dialog_id] = gt_dialog
-        pred[dialog_id] = pred_dialog
+        pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns]
     result = eval_states(gt, pred)
     print(result)
-    json.dump(result, open(os.path.join(model_name, subdir, 'result.json'), 'w'), indent=4, ensure_ascii=False)
+    json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False)
 
 
 if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser()
     parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
+    parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val'])
     args = parser.parse_args()
-    evaluate('example', args.subtask)
+    subtask = args.subtask
+    test_data = prepare_data(subtask, args.split)
+    gt = {
+        dialog_id: [state for _, _, state in turns]
+        for dialog_id, turns in test_data.items()
+    }
+    evaluate('example', subtask, test_data, gt)
diff --git a/convlab2/dst/dstc9/example/.gitignore b/convlab2/dst/dstc9/example/.gitignore
deleted file mode 100644
index 4b544e7..0000000
--- a/convlab2/dst/dstc9/example/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-*/result.json
diff --git a/convlab2/dst/dstc9/example/multiwoz_zh/model.py b/convlab2/dst/dstc9/example/multiwoz_zh/model.py
index 1a81cdb..cd6ef36 100644
--- a/convlab2/dst/dstc9/example/multiwoz_zh/model.py
+++ b/convlab2/dst/dstc9/example/multiwoz_zh/model.py
@@ -51,8 +51,8 @@ class ExampleModel(DST):
                 "出发时间": "",
                 "目的地": "",
                 "日期": "",
-                "到达时间": "未提及",
-                "出发地": "未提及",
+                "到达时间": "",
+                "出发地": "",
             },
         }
 
diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py
index 75b6077..80cd25d 100644
--- a/convlab2/dst/dstc9/utils.py
+++ b/convlab2/dst/dstc9/utils.py
@@ -2,21 +2,13 @@ import os
 import json
 import zipfile
 
+from convlab2 import DATA_ROOT
 
-def load_test_data(subtask):
-    from convlab2 import DATA_ROOT
-    data_dir = os.path.join(DATA_ROOT, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en')
-    # test public data currently
-    # to check if this script works properly with your code when label information is
-    # not available, you may need to fill the missing fields yourself (with any value)
-    zip_filename = os.path.join(data_dir, 'dstc9-test-250.zip')
-    test_data = json.load(zipfile.ZipFile(zip_filename).open('data.json'))
-    assert len(test_data) == 250
-    return test_data
 
-
-def prepare_data(subtask):
-    test_data = load_test_data(subtask)
+def prepare_data(subtask, split, data_root=DATA_ROOT):
+    data_dir = os.path.join(data_root, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en')
+    zip_filename = os.path.join(data_dir, f'{split}.json.zip')
+    test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json'))
     data = {}
     if subtask == 'multiwoz':
         for dialog_id, dialog in test_data.items():
@@ -57,6 +49,14 @@ def prepare_data(subtask):
     return data
 
 
+def extract_gt(test_data):
+    gt = {
+        dialog_id: [state for _, _, state in turns]
+        for dialog_id, turns in test_data.items()
+    }
+    return gt
+
+
 def eval_states(gt, pred):
     def exception(description, **kargs):
         ret = {
@@ -116,3 +116,8 @@ def eval_states(gt, pred):
         #     'f1': f1,
         # }
     }
+
+
+def get_subdir(subtask):
+    subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en'
+    return subdir
-- 
GitLab