From c761fc7b4e126bb83cc91edd4025027a6da724f4 Mon Sep 17 00:00:00 2001 From: function2 <function2@qq.com> Date: Thu, 29 Oct 2020 10:56:15 +0800 Subject: [PATCH] Revert "update eval" This reverts commit 02537cf8f6474a33bb2d35e640e7f9d9b5b86f52. --- convlab2/dst/dstc9/utils.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 06615fb..7c88888 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -37,26 +37,16 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): for dialog_id, dialog in test_data.items(): dialog_data = [] turns = dialog['messages'] - selected_results = {k: [] for k in turns[1]['sys_state'].keys()} for i in range(0, len(turns), 2): sys_utt = turns[i - 1]['content'] if i else None user_utt = turns[i]['content'] - dialog_state = {} - for domain_name, state in turns[i + 1]['sys_state_init'].items(): - state.pop('selectedResults') - sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults') - # if state has changed compared to previous sys state - state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name] - # clear the outdated previous selected results if state has been updated - if state_change: - selected_results[domain_name].clear() - if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1: - state['name'] = selected_results[domain_name][0] - dialog_state[domain_name] = state - if state_change and sys_selected_results: - selected_results[domain_name] = sys_selected_results - - dialog_data.append((sys_utt, user_utt, dialog_state)) + state = {} + for domain_name, domain_state in turns[i + 1]['sys_state_init'].items(): + selected_results = domain_state.pop('selectedResults') + if selected_results and 'name' in domain_state and not domain_state['name']: + domain_state['name'] = selected_results + state[domain_name] = domain_state + dialog_data.append((sys_utt, user_utt, state)) data[dialog_id] = dialog_data return data @@ -72,6 +62,12 @@ def extract_gt(test_data): # for unifying values with the same meaning to the same expression def unify_value(value, subtask): + if isinstance(value, list): + ret = deepcopy(value) + for i, v in enumerate(ret): + ret[i] = unify_value(v, subtask) + return ret + value = value.lower() value = { 'multiwoz': { @@ -126,7 +122,7 @@ def eval_states(gt, pred, subtask): pred_value = unify_value(pred_domain[slot_name], subtask) slot_tot += 1 - if gt_value == pred_value: + if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: slot_acc += 1 if gt_value: tp += 1 -- GitLab