Skip to content
Snippets Groups Projects
Commit 62a7d11b authored by function2's avatar function2
Browse files

update eval

parent 02537cf8
No related branches found
No related tags found
No related merge requests found
import os
import json
import os
import zipfile
from copy import deepcopy
from convlab2 import DATA_ROOT
......@@ -23,40 +22,41 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
for i in range(0, len(turns), 2):
sys_utt = turns[i - 1]['text'] if i else None
user_utt = turns[i]['text']
state = {}
dialog_state = {}
for domain_name, domain in turns[i + 1]['metadata'].items():
if domain_name in ['警察机关', '医院', '公共汽车']:
continue
domain_state = {}
state = {}
for slots in domain.values():
for slot_name, value in slots.items():
domain_state[slot_name] = value
state[domain_name] = domain_state
dialog_data.append((sys_utt, user_utt, state))
state[slot_name] = value
dialog_state[domain_name] = state
dialog_data.append((sys_utt, user_utt, dialog_state))
data[dialog_id] = dialog_data
else:
for dialog_id, dialog in test_data.items():
dialog_data = []
turns = dialog['messages']
selected_results = {k: [] for k in turns[1]['sys_state_init'].keys()}
selected_results = {k: [] for k in turns[1]['sys_state'].keys()}
for i in range(0, len(turns), 2):
sys_utt = turns[i - 1]['content'] if i else None
user_utt = turns[i]['content']
state = {}
for domain_name, domain_state in turns[i + 1]['sys_state_init'].items():
new_selected_results = domain_state.pop('selectedResults')
# if state has changed compared to previous turn
state_change = i == 0 or domain_state != dialog_data[-1][2][domain_name]
# clear the invalid previous selected results if state has changed
dialog_state = {}
for domain_name, state in turns[i + 1]['sys_state_init'].items():
state.pop('selectedResults')
sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults')
# if state has changed compared to previous sys state
state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name]
# clear the outdated previous selected results if state has been updated
if state_change:
selected_results[domain_name].clear()
if not domain_state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
domain_state['name'] = selected_results[domain_name][0]
state[domain_name] = domain_state
if state_change:
selected_results[domain_name] = new_selected_results
if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
state['name'] = selected_results[domain_name][0]
dialog_state[domain_name] = state
if state_change and sys_selected_results:
selected_results[domain_name] = sys_selected_results
dialog_data.append((sys_utt, user_utt, state))
dialog_data.append((sys_utt, user_utt, dialog_state))
data[dialog_id] = dialog_data
return data
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment