Skip to content
Snippets Groups Projects
Commit e0f5210b authored by function2's avatar function2
Browse files

use selectedResults for missing name

parent 0ac85ac3
Branches
No related tags found
No related merge requests found
...@@ -2,9 +2,8 @@ ...@@ -2,9 +2,8 @@
evaluate output file evaluate output file
""" """
import os
import json import json
from copy import deepcopy import os
from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir
......
...@@ -41,12 +41,10 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): ...@@ -41,12 +41,10 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
sys_utt = turns[i - 1]['content'] if i else None sys_utt = turns[i - 1]['content'] if i else None
user_utt = turns[i]['content'] user_utt = turns[i]['content']
state = {} state = {}
for domain_name, domain in turns[i + 1]['sys_state_init'].items(): for domain_name, domain_state in turns[i + 1]['sys_state_init'].items():
domain_state = {} selected_results = domain_state.pop('selectedResults')
for slot_name, value in domain.items(): if selected_results and 'name' in domain_state and not domain_state['name']:
if slot_name == 'selectedResults': domain_state['name'] = selected_results
continue
domain_state[slot_name] = value
state[domain_name] = domain_state state[domain_name] = domain_state
dialog_data.append((sys_utt, user_utt, state)) dialog_data.append((sys_utt, user_utt, state))
data[dialog_id] = dialog_data data[dialog_id] = dialog_data
...@@ -66,14 +64,16 @@ def eval_states(gt, pred, subtask): ...@@ -66,14 +64,16 @@ def eval_states(gt, pred, subtask):
# for unifying values with the same meaning to the same expression # for unifying values with the same meaning to the same expression
value_unifier = { value_unifier = {
'multiwoz': { 'multiwoz': {
'未提及': '',
}, },
'crosswoz': { 'crosswoz': {
'未提及': '',
} }
}[subtask] }[subtask]
def unify_value(value): def unify_value(value):
if isinstance(value, list):
return list(map(unify_value, value))
return value_unifier.get(value, value) return value_unifier.get(value, value)
def exception(description, **kargs): def exception(description, **kargs):
...@@ -110,7 +110,7 @@ def eval_states(gt, pred, subtask): ...@@ -110,7 +110,7 @@ def eval_states(gt, pred, subtask):
gt_value = unify_value(gt_value) gt_value = unify_value(gt_value)
pred_value = unify_value(pred_domain[slot_name]) pred_value = unify_value(pred_domain[slot_name])
slot_tot += 1 slot_tot += 1
if gt_value == pred_value: if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
slot_acc += 1 slot_acc += 1
if gt_value: if gt_value:
tp += 1 tp += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment