Skip to content
Snippets Groups Projects
Commit 02537cf8 authored by function2
Browse files

update eval

parent 497a8d96
No related branches found
No related tags found
No related merge requests found
......@@ -38,15 +38,24 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
for dialog_id, dialog in test_data.items():
dialog_data = []
turns = dialog['messages']
selected_results = {k: [] for k in turns[1]['sys_state_init'].keys()}
for i in range(0, len(turns), 2):
sys_utt = turns[i - 1]['content'] if i else None
user_utt = turns[i]['content']
state = {}
for domain_name, domain_state in turns[i + 1]['sys_state_init'].items():
selected_results = domain_state.pop('selectedResults')
if selected_results and 'name' in domain_state and not domain_state['name']:
domain_state['name'] = selected_results
new_selected_results = domain_state.pop('selectedResults')
# if state has changed compared to previous turn
state_change = i == 0 or domain_state != dialog_data[-1][2][domain_name]
# clear the invalid previous selected results if state has changed
if state_change:
selected_results[domain_name].clear()
if not domain_state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
domain_state['name'] = selected_results[domain_name][0]
state[domain_name] = domain_state
if state_change:
selected_results[domain_name] = new_selected_results
dialog_data.append((sys_utt, user_utt, state))
data[dialog_id] = dialog_data
......@@ -63,12 +72,6 @@ def extract_gt(test_data):
# for unifying values with the same meaning to the same expression
def unify_value(value, subtask):
if isinstance(value, list):
ret = deepcopy(value)
for i, v in enumerate(ret):
ret[i] = unify_value(v, subtask)
return ret
value = value.lower()
value = {
'multiwoz': {
......@@ -123,7 +126,7 @@ def eval_states(gt, pred, subtask):
pred_value = unify_value(pred_domain[slot_name], subtask)
slot_tot += 1
if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
if gt_value == pred_value:
slot_acc += 1
if gt_value:
tp += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment