diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 78bd9fef0fa4596521c20dd5b944f9ca8ca514e6..61d0a5f9b9362e8d8cdbdf6d501d6aa8ab45a89b 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -38,15 +38,24 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): for dialog_id, dialog in test_data.items(): dialog_data = [] turns = dialog['messages'] + selected_results = {k: [] for k in turns[1]['sys_state_init'].keys()} for i in range(0, len(turns), 2): sys_utt = turns[i - 1]['content'] if i else None user_utt = turns[i]['content'] state = {} for domain_name, domain_state in turns[i + 1]['sys_state_init'].items(): - selected_results = domain_state.pop('selectedResults') - if selected_results and 'name' in domain_state and not domain_state['name']: - domain_state['name'] = selected_results + new_selected_results = domain_state.pop('selectedResults') + # if state has changed compared to previous turn + state_change = i == 0 or domain_state != dialog_data[-1][2][domain_name] + # clear the invalid previous selected results if state has changed + if state_change: + selected_results[domain_name].clear() + if not domain_state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1: + domain_state['name'] = selected_results[domain_name][0] state[domain_name] = domain_state + if state_change: + selected_results[domain_name] = new_selected_results + dialog_data.append((sys_utt, user_utt, state)) data[dialog_id] = dialog_data @@ -63,12 +72,6 @@ def extract_gt(test_data): # for unifying values with the same meaning to the same expression def unify_value(value, subtask): - if isinstance(value, list): - ret = deepcopy(value) - for i, v in enumerate(ret): - ret[i] = unify_value(v, subtask) - return ret - value = value.lower() value = { 'multiwoz': { @@ -123,7 +126,7 @@ def eval_states(gt, pred, subtask): pred_value = unify_value(pred_domain[slot_name], subtask) slot_tot += 1 - if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: + if gt_value == pred_value: slot_acc += 1 if gt_value: tp += 1