diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py index a89c2e3e8795608b1dfec26652b4d8d2c9ca5800..4b3e1b26106c318bd139a401d2fbeeba13bbe481 100644 --- a/convlab2/dst/dstc9/eval_file.py +++ b/convlab2/dst/dstc9/eval_file.py @@ -2,9 +2,8 @@ evaluate output file """ -import os import json -from copy import deepcopy +import os from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index b378fd2e79486c0c9936f887a307ca1746e4e1aa..c926103684d23dc11adef2cb266025b64849ab7f 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -41,12 +41,10 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): sys_utt = turns[i - 1]['content'] if i else None user_utt = turns[i]['content'] state = {} - for domain_name, domain in turns[i + 1]['sys_state_init'].items(): - domain_state = {} - for slot_name, value in domain.items(): - if slot_name == 'selectedResults': - continue - domain_state[slot_name] = value + for domain_name, domain_state in turns[i + 1]['sys_state_init'].items(): + selected_results = domain_state.pop('selectedResults') + if selected_results and 'name' in domain_state and not domain_state['name']: + domain_state['name'] = selected_results state[domain_name] = domain_state dialog_data.append((sys_utt, user_utt, state)) data[dialog_id] = dialog_data @@ -66,14 +64,16 @@ def eval_states(gt, pred, subtask): # for unifying values with the same meaning to the same expression value_unifier = { 'multiwoz': { - + '未提及': '', }, 'crosswoz': { - '未提及': '', + } }[subtask] def unify_value(value): + if isinstance(value, list): + return list(map(unify_value, value)) return value_unifier.get(value, value) def exception(description, **kargs): @@ -110,7 +110,7 @@ def eval_states(gt, pred, subtask): gt_value = unify_value(gt_value) pred_value = unify_value(pred_domain[slot_name]) slot_tot += 1 - if gt_value == pred_value: + if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: slot_acc += 1 if gt_value: tp += 1