From 62a7d11bd0b3a86bc145099c5c7e4098a0b62a35 Mon Sep 17 00:00:00 2001 From: function2 <function2@qq.com> Date: Wed, 28 Oct 2020 22:15:18 +0800 Subject: [PATCH] update eval --- convlab2/dst/dstc9/utils.py | 40 ++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 61d0a5f..06615fb 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -1,7 +1,6 @@ -import os import json +import os import zipfile -from copy import deepcopy from convlab2 import DATA_ROOT @@ -23,40 +22,41 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): for i in range(0, len(turns), 2): sys_utt = turns[i - 1]['text'] if i else None user_utt = turns[i]['text'] - state = {} + dialog_state = {} for domain_name, domain in turns[i + 1]['metadata'].items(): if domain_name in ['警察机关', '医院', '公共汽车']: continue - domain_state = {} + state = {} for slots in domain.values(): for slot_name, value in slots.items(): - domain_state[slot_name] = value - state[domain_name] = domain_state - dialog_data.append((sys_utt, user_utt, state)) + state[slot_name] = value + dialog_state[domain_name] = state + dialog_data.append((sys_utt, user_utt, dialog_state)) data[dialog_id] = dialog_data else: for dialog_id, dialog in test_data.items(): dialog_data = [] turns = dialog['messages'] - selected_results = {k: [] for k in turns[1]['sys_state_init'].keys()} + selected_results = {k: [] for k in turns[1]['sys_state'].keys()} for i in range(0, len(turns), 2): sys_utt = turns[i - 1]['content'] if i else None user_utt = turns[i]['content'] - state = {} - for domain_name, domain_state in turns[i + 1]['sys_state_init'].items(): - new_selected_results = domain_state.pop('selectedResults') - # if state has changed compared to previous turn - state_change = i == 0 or domain_state != dialog_data[-1][2][domain_name] - # clear the invalid previous selected results if state has changed + dialog_state = {} + for domain_name, state in turns[i + 1]['sys_state_init'].items(): + state.pop('selectedResults') + sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults') + # if state has changed compared to previous sys state + state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name] + # clear the outdated previous selected results if state has been updated if state_change: selected_results[domain_name].clear() - if not domain_state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1: - domain_state['name'] = selected_results[domain_name][0] - state[domain_name] = domain_state - if state_change: - selected_results[domain_name] = new_selected_results + if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1: + state['name'] = selected_results[domain_name][0] + dialog_state[domain_name] = state + if state_change and sys_selected_results: + selected_results[domain_name] = sys_selected_results - dialog_data.append((sys_utt, user_utt, state)) + dialog_data.append((sys_utt, user_utt, dialog_state)) data[dialog_id] = dialog_data return data -- GitLab