From e0ee065c04a5d496be3402d13209cc5c2da0c59e Mon Sep 17 00:00:00 2001
From: function2 <function2@qq.com>
Date: Sun, 1 Nov 2020 17:13:53 +0800
Subject: [PATCH] add correct name labem argument

---
 convlab2/dst/dstc9/eval_file.py  |  5 +++--
 convlab2/dst/dstc9/eval_model.py |  9 +++++----
 convlab2/dst/dstc9/utils.py      | 26 +++++++++++++++++++++-----
 3 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py
index fe33b02..38b7577 100644
--- a/convlab2/dst/dstc9/eval_file.py
+++ b/convlab2/dst/dstc9/eval_file.py
@@ -51,10 +51,11 @@ if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser()
     parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
-    parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val'])
+    parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val', 'dstc9-250'])
+    parser.add_argument('correct_name_label', action='store_true')
     args = parser.parse_args()
     subtask = args.subtask
     split = args.split
     dump_example(subtask, split)
-    test_data = prepare_data(subtask, split)
+    test_data = prepare_data(subtask, split, correct_name_label=args.correct_name_label)
     gt = extract_gt(test_data)
diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py
index 0876187..4544128 100644
--- a/convlab2/dst/dstc9/eval_model.py
+++ b/convlab2/dst/dstc9/eval_model.py
@@ -34,9 +34,9 @@ def evaluate(model_dir, subtask, test_data, gt):
     dump_result(model_dir, 'model-result.json', result, errors, pred)
 
 
-def eval_team(team):
+def eval_team(team, correct_name_label):
     for subtask in ['multiwoz', 'crosswoz']:
-        test_data = prepare_data(subtask, 'dstc9')
+        test_data = prepare_data(subtask, 'dstc9', correct_name_label=correct_name_label)
         gt = extract_gt(test_data)
         for i in range(1, 6):
             model_dir = os.path.join(team, f'{subtask}-dst', f'submission{i}')
@@ -50,12 +50,13 @@ if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser()
     parser.add_argument('--teams', type=str, nargs='*')
+    parser.add_argument('correct_name_label', action='store_true')
     args = parser.parse_args()
     if not args.teams:
         for team in os.listdir('.'):
             if not os.path.isdir(team):
                 continue
-            eval_team(team)
+            eval_team(team, args.correct_name_label)
     else:
         for team in args.teams:
-            eval_team(team)
+            eval_team(team, args.correct_name_label)
diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py
index 588e90c..29a951a 100644
--- a/convlab2/dst/dstc9/utils.py
+++ b/convlab2/dst/dstc9/utils.py
@@ -11,7 +11,7 @@ def get_subdir(subtask):
     return subdir
 
 
-def prepare_data(subtask, split, data_root=DATA_ROOT):
+def prepare_data(subtask, split, data_root=DATA_ROOT, correct_name_label=False):
     data_dir = os.path.join(data_root, get_subdir(subtask))
     zip_filename = os.path.join(data_dir, f'{split}.json.zip')
     test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json'))
@@ -38,15 +38,31 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
         for dialog_id, dialog in test_data.items():
             dialog_data = []
             turns = dialog['messages']
+            if correct_name_label:
+                selected_results = {domain_name: [] for domain_name in turns[1]['sys_state_init']}
             for i in range(0, len(turns), 2):
                 sys_utt = turns[i - 1]['content'] if i else None
                 user_utt = turns[i]['content']
                 dialog_state = {}
                 for domain_name, state in turns[i + 1]['sys_state_init'].items():
-                    selected_results = state.pop('selectedResults')
-                    if selected_results and 'name' in state and not state['name']:
-                        state['name'] = selected_results
-                    dialog_state[domain_name] = state
+                    if correct_name_label:
+                        state.pop('selectedResults')
+                        sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults')
+                        # if state has changed compared to previous sys state
+                        state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name]
+                        # clear the outdated previous selected results if state has been updated
+                        if state_change:
+                            selected_results[domain_name].clear()
+                        if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
+                            state['name'] = selected_results[domain_name][0]
+                        dialog_state[domain_name] = state
+                        if state_change and sys_selected_results:
+                            selected_results[domain_name] = sys_selected_results
+                    else:
+                        selected_results = state.pop('selectedResults')
+                        if selected_results and 'name' in state and not state['name']:
+                            state['name'] = selected_results
+                        dialog_state[domain_name] = state
                 dialog_data.append((sys_utt, user_utt, dialog_state))
             data[dialog_id] = dialog_data
 
-- 
GitLab