diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 78bd9fef0fa4596521c20dd5b944f9ca8ca514e6..588e90cff00c9b672d0c2bc23c248a503dd91914 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -23,16 +23,16 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): for i in range(0, len(turns), 2): sys_utt = turns[i - 1]['text'] if i else None user_utt = turns[i]['text'] - state = {} + dialog_state = {} for domain_name, domain in turns[i + 1]['metadata'].items(): if domain_name in ['警察机关', '医院', '公共汽车']: continue - domain_state = {} + state = {} for slots in domain.values(): for slot_name, value in slots.items(): - domain_state[slot_name] = value - state[domain_name] = domain_state - dialog_data.append((sys_utt, user_utt, state)) + state[slot_name] = value + dialog_state[domain_name] = state + dialog_data.append((sys_utt, user_utt, dialog_state)) data[dialog_id] = dialog_data else: for dialog_id, dialog in test_data.items(): @@ -41,13 +41,13 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): for i in range(0, len(turns), 2): sys_utt = turns[i - 1]['content'] if i else None user_utt = turns[i]['content'] - state = {} - for domain_name, domain_state in turns[i + 1]['sys_state_init'].items(): - selected_results = domain_state.pop('selectedResults') - if selected_results and 'name' in domain_state and not domain_state['name']: - domain_state['name'] = selected_results - state[domain_name] = domain_state - dialog_data.append((sys_utt, user_utt, state)) + dialog_state = {} + for domain_name, state in turns[i + 1]['sys_state_init'].items(): + selected_results = state.pop('selectedResults') + if selected_results and 'name' in state and not state['name']: + state['name'] = selected_results + dialog_state[domain_name] = state + dialog_data.append((sys_utt, user_utt, dialog_state)) data[dialog_id] = dialog_data return data