diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 7c888882db3a60ca47d77221506db96c202036ae..78bd9fef0fa4596521c20dd5b944f9ca8ca514e6 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -1,6 +1,7 @@ -import json import os +import json import zipfile +from copy import deepcopy from convlab2 import DATA_ROOT @@ -22,16 +23,16 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): for i in range(0, len(turns), 2): sys_utt = turns[i - 1]['text'] if i else None user_utt = turns[i]['text'] - dialog_state = {} + state = {} for domain_name, domain in turns[i + 1]['metadata'].items(): if domain_name in ['警察机关', '医院', '公共汽车']: continue - state = {} + domain_state = {} for slots in domain.values(): for slot_name, value in slots.items(): - state[slot_name] = value - dialog_state[domain_name] = state - dialog_data.append((sys_utt, user_utt, dialog_state)) + domain_state[slot_name] = value + state[domain_name] = domain_state + dialog_data.append((sys_utt, user_utt, state)) data[dialog_id] = dialog_data else: for dialog_id, dialog in test_data.items():