Skip to content
Snippets Groups Projects
Commit bc767e5d authored by zqwerty's avatar zqwerty
Browse files

make nlu,dst serialization more compact

parent d4a98242
Branches
No related tags found
No related merge requests found
...@@ -80,9 +80,9 @@ def create_nlg_data(dataset, data_dir, args): ...@@ -80,9 +80,9 @@ def create_nlg_data(dataset, data_dir, args):
dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts']) dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts'])
if args.context_window_size>0: if args.context_window_size>0:
context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[f'{sample["speaker"]}: ']) context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[f'{sample["speaker"]}: '])
context = f'{dialogue_acts_seq}\n\n{context}' context = f'{dialogue_acts_seq}\n{context}'
else: else:
context = f'{dialogue_acts_seq}\n\n{sample["speaker"]}: ' context = f'{dialogue_acts_seq}\n{sample["speaker"]}: '
assert equal_da_seq(sample['dialogue_acts'], dialogue_acts_seq), print(sample['dialogue_acts'], dialogue_acts_seq, deserialize_dialogue_acts(dialogue_acts_seq)) assert equal_da_seq(sample['dialogue_acts'], dialogue_acts_seq), print(sample['dialogue_acts'], dialogue_acts_seq, deserialize_dialogue_acts(dialogue_acts_seq))
data.append(json.dumps({'context+da': context, 'response': sample['utterance']}, ensure_ascii=False)+'\n') data.append(json.dumps({'context+da': context, 'response': sample['utterance']}, ensure_ascii=False)+'\n')
......
def serialize_dialogue_state(state): def serialize_dialogue_state(state):
state_seqs = [] state_dict = {}
for domain in state: for domain in state:
for slot, value in state[domain].items(): for slot, value in sorted(state[domain].items()):
if len(value) > 0: if len(value) > 0:
state_seqs.append(f'[{domain}][{slot}][{value}]') state_dict.setdefault(f'[{domain}]', [])
state_dict[f'[{domain}]'].append(f'[{slot}][{value}]')
return ';'.join(state_seqs) return ';'.join([domain+'{'+','.join(slot_values)+'}' for domain, slot_values in state_dict.items()])
def deserialize_dialogue_state(state_seq): def deserialize_dialogue_state(state_seq):
state = {} state = {}
if len(state_seq) == 0: if len(state_seq) == 0:
return state return state
state_seqs = state_seq.split('];[') state_seqs = state_seq.split(']};[') # will consume "]}" and "["
for i, state_seq in enumerate(state_seqs): for i, state_seq in enumerate(state_seqs):
if len(state_seq) == 0: if len(state_seq) == 0 or len(state_seq.split(']{[')) != 2:
continue continue
if i == 0: if i == 0:
if state_seq[0] == '[': if state_seq[0] == '[':
state_seq = state_seq[1:] state_seq = state_seq[1:]
if i == len(state_seqs) - 1: if i == len(state_seqs) - 1:
if state_seq[-1] == ']': if state_seq[-2:] == ']}':
state_seq = state_seq[:-1] state_seq = state_seq[:-2]
s = state_seq.split('][')
if len(s) != 3: domain, slot_values = state_seq.split(']{[')
continue for slot_value in slot_values.split('],['):
domain, slot, value = s slot, value = slot_value.split('][')
state.setdefault(domain, {}) state.setdefault(domain, {})
state[domain][slot] = value state[domain][slot] = value
return state return state
......
def serialize_dialogue_acts(dialogue_acts): def serialize_dialogue_acts(dialogue_acts):
da_seqs = [] da_dict = {}
for da_type in dialogue_acts: for da_type in dialogue_acts:
for da in dialogue_acts[da_type]: for da in dialogue_acts[da_type]:
intent, domain, slot = da['intent'], da['domain'], da['slot'] intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da.get('value', '')
if da_type == 'binary': intent_domain = f'[{intent}][{domain}]'
da_seq = f'[{da_type}][{intent}][{domain}][{slot}]' da_dict.setdefault(intent_domain, [])
else: da_dict[intent_domain].append(f'[{slot}][{value}]')
value = da['value'] return ';'.join([intent_domain+'{'+','.join(slot_values)+'}' for intent_domain, slot_values in da_dict.items()])
da_seq = f'[{da_type}][{intent}][{domain}][{slot}][{value}]'
da_seqs.append(da_seq)
return ';'.join(da_seqs)
def deserialize_dialogue_acts(das_seq): def deserialize_dialogue_acts(das_seq):
dialogue_acts = {'binary': [], 'categorical': [], 'non-categorical': []} dialogue_acts = []
if len(das_seq) == 0: if len(das_seq) == 0:
return dialogue_acts return dialogue_acts
da_seqs = das_seq.split('];[') da_seqs = das_seq.split(']};[') # will consume "]}" and "["
for i, da_seq in enumerate(da_seqs): for i, da_seq in enumerate(da_seqs):
if len(da_seq) == 0: if len(da_seq) == 0 or len(da_seq.split(']{[')) != 2:
continue continue
if i == 0: if i == 0:
if da_seq[0] == '[': if da_seq[0] == '[':
da_seq = da_seq[1:] da_seq = da_seq[1:]
if i == len(da_seqs) - 1: if i == len(da_seqs) - 1:
if da_seq[-1] == ']': if da_seq[-2:] == ']}':
da_seq = da_seq[:-1] da_seq = da_seq[:-2]
da = da_seq.split('][')
if len(da) == 0: intent_domain, slot_values = da_seq.split(']{[')
continue intent, domain = intent_domain.split('][')
da_type = da[0] for slot_value in slot_values.split('],['):
if len(da) == 5 and da_type in ['categorical', 'non-categorical']: slot, value = slot_value.split('][')
dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3], 'value': da[4]}) dialogue_acts.append({'intent': intent, 'domain': domain, 'slot': slot, 'value': value})
elif len(da) == 4 and da_type == 'binary':
dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3]})
else:
# invalid da format, skip
# print(das_seq)
# print(da_seq)
# print()
pass
return dialogue_acts return dialogue_acts
def equal_da_seq(dialogue_acts, das_seq): def equal_da_seq(dialogue_acts, das_seq):
predict_dialogue_acts = deserialize_dialogue_acts(das_seq) predict_dialogue_acts = deserialize_dialogue_acts(das_seq)
for da_type in ['binary', 'categorical', 'non-categorical']: das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da_type in ['binary', 'categorical', 'non-categorical'] for da in dialogue_acts[da_type]])
das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in dialogue_acts[da_type]]) predict_das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in predict_dialogue_acts])
predict_das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in predict_dialogue_acts[da_type]])
if das != predict_das: if das != predict_das:
return False return False
return True return True
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment