Skip to content
Snippets Groups Projects
Commit c761fc7b authored by function2's avatar function2
Browse files

Revert "update eval"

This reverts commit 02537cf8.
parent 1211183d
Branches
Tags 1.1.4
No related merge requests found
...@@ -37,26 +37,16 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): ...@@ -37,26 +37,16 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
for dialog_id, dialog in test_data.items(): for dialog_id, dialog in test_data.items():
dialog_data = [] dialog_data = []
turns = dialog['messages'] turns = dialog['messages']
selected_results = {k: [] for k in turns[1]['sys_state'].keys()}
for i in range(0, len(turns), 2): for i in range(0, len(turns), 2):
sys_utt = turns[i - 1]['content'] if i else None sys_utt = turns[i - 1]['content'] if i else None
user_utt = turns[i]['content'] user_utt = turns[i]['content']
dialog_state = {} state = {}
for domain_name, state in turns[i + 1]['sys_state_init'].items(): for domain_name, domain_state in turns[i + 1]['sys_state_init'].items():
state.pop('selectedResults') selected_results = domain_state.pop('selectedResults')
sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults') if selected_results and 'name' in domain_state and not domain_state['name']:
# if state has changed compared to previous sys state domain_state['name'] = selected_results
state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name] state[domain_name] = domain_state
# clear the outdated previous selected results if state has been updated dialog_data.append((sys_utt, user_utt, state))
if state_change:
selected_results[domain_name].clear()
if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
state['name'] = selected_results[domain_name][0]
dialog_state[domain_name] = state
if state_change and sys_selected_results:
selected_results[domain_name] = sys_selected_results
dialog_data.append((sys_utt, user_utt, dialog_state))
data[dialog_id] = dialog_data data[dialog_id] = dialog_data
return data return data
...@@ -72,6 +62,12 @@ def extract_gt(test_data): ...@@ -72,6 +62,12 @@ def extract_gt(test_data):
# for unifying values with the same meaning to the same expression # for unifying values with the same meaning to the same expression
def unify_value(value, subtask): def unify_value(value, subtask):
if isinstance(value, list):
ret = deepcopy(value)
for i, v in enumerate(ret):
ret[i] = unify_value(v, subtask)
return ret
value = value.lower() value = value.lower()
value = { value = {
'multiwoz': { 'multiwoz': {
...@@ -126,7 +122,7 @@ def eval_states(gt, pred, subtask): ...@@ -126,7 +122,7 @@ def eval_states(gt, pred, subtask):
pred_value = unify_value(pred_domain[slot_name], subtask) pred_value = unify_value(pred_domain[slot_name], subtask)
slot_tot += 1 slot_tot += 1
if gt_value == pred_value: if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
slot_acc += 1 slot_acc += 1
if gt_value: if gt_value:
tp += 1 tp += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment