From c24d5ecb2182b96240f12bd6189a47838d59a8c2 Mon Sep 17 00:00:00 2001
From: function2 <function2@qq.com>
Date: Tue, 13 Oct 2020 10:44:37 +0800
Subject: [PATCH] update eval

---
 convlab2/dst/dstc9/eval_file.py  | 15 +++++++--------
 convlab2/dst/dstc9/eval_model.py |  9 ++++-----
 convlab2/dst/dstc9/utils.py      |  2 +-
 3 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py
index 4b3e1b2..d592b27 100644
--- a/convlab2/dst/dstc9/eval_file.py
+++ b/convlab2/dst/dstc9/eval_file.py
@@ -9,24 +9,24 @@ from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_
 
 
 def evaluate(model_dir, subtask, gt):
-    subdir = get_subdir(subtask)
     results = {}
     for i in range(1, 6):
-        filepath = os.path.join(model_dir, subdir, f'submission{i}.json')
+        filepath = os.path.join(model_dir, f'submission{i}.json')
         if not os.path.exists(filepath):
             continue
         pred = json.load(open(filepath))
         results[i] = eval_states(gt, pred, subtask)
 
-    print(json.dumps(results, indent=4))
-    json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False)
+    print(json.dumps(results, indent=4, ensure_ascii=False))
+    json.dump(results, open(os.path.join(model_dir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False)
 
 
 # generate submission examples
 def dump_example(subtask, split):
     test_data = prepare_data(subtask, split)
     pred = extract_gt(test_data)
-    json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4)
+    subdir = get_subdir(subtask)
+    json.dump(pred, open(os.path.join('example', subdir, 'submission1.json'), 'w'), ensure_ascii=False, indent=4)
     import random
     for dialog_id, states in pred.items():
         for state in states:
@@ -38,13 +38,13 @@ def dump_example(subtask, split):
                     else:
                         if random.randint(0, 4) == 0:
                             domain[slot] = "2333"
-    json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4)
+    json.dump(pred, open(os.path.join('example', subdir, 'submission2.json'), 'w'), ensure_ascii=False, indent=4)
     for dialog_id, states in pred.items():
         for state in states:
             for domain in state.values():
                 for slot in domain:
                     domain[slot] = ""
-    json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission3.json'), 'w'), ensure_ascii=False, indent=4)
+    json.dump(pred, open(os.path.join('example', subdir, 'submission3.json'), 'w'), ensure_ascii=False, indent=4)
 
 
 if __name__ == '__main__':
@@ -58,4 +58,3 @@ if __name__ == '__main__':
     dump_example(subtask, split)
     test_data = prepare_data(subtask, split)
     gt = extract_gt(test_data)
-    evaluate('example', subtask, gt)
diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py
index 377ecdd..eaba8e6 100644
--- a/convlab2/dst/dstc9/eval_model.py
+++ b/convlab2/dst/dstc9/eval_model.py
@@ -7,14 +7,13 @@ import json
 import importlib
 
 from convlab2.dst import DST
-from convlab2.dst.dstc9.utils import prepare_data, eval_states, get_subdir
+from convlab2.dst.dstc9.utils import prepare_data, eval_states
 
 
 def evaluate(model_dir, subtask, test_data, gt):
-    subdir = get_subdir(subtask)
-    module = importlib.import_module(f'{model_dir}.{subdir}')
+    module = importlib.import_module(model_dir.replace('/', '.'))
     assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root'
-    model_cls = module.__getattribute__('Model')
+    model_cls = getattr(module, 'Model')
     assert issubclass(model_cls, DST), 'the model must implement DST interface'
     # load weights, set eval() on default
     model = model_cls()
@@ -24,7 +23,7 @@ def evaluate(model_dir, subtask, test_data, gt):
         pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns]
     result = eval_states(gt, pred, subtask)
     print(json.dumps(result, indent=4))
-    json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False)
+    json.dump(result, open(os.path.join(model_dir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False)
 
 
 if __name__ == '__main__':
diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py
index b685623..0987cbe 100644
--- a/convlab2/dst/dstc9/utils.py
+++ b/convlab2/dst/dstc9/utils.py
@@ -25,7 +25,7 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
                 user_utt = turns[i]['text']
                 state = {}
                 for domain_name, domain in turns[i + 1]['metadata'].items():
-                    if domain_name in ['警察机关', '医院']:
+                    if domain_name in ['警察机关', '医院', '公共汽车']:
                         continue
                     domain_state = {}
                     for slots in domain.values():
-- 
GitLab