diff --git a/convlab2/base_models/t5/create_data.py b/convlab2/base_models/t5/create_data.py index 71fea81e73969f74c5e962445d9143cf38e722d0..cc9e651291bcaf2284750901c8e3c5f386b0c43f 100644 --- a/convlab2/base_models/t5/create_data.py +++ b/convlab2/base_models/t5/create_data.py @@ -80,9 +80,9 @@ def create_nlg_data(dataset, data_dir, args): dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts']) if args.context_window_size>0: context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[f'{sample["speaker"]}: ']) - context = f'{dialogue_acts_seq}\n\n{context}' + context = f'{dialogue_acts_seq}\n{context}' else: - context = f'{dialogue_acts_seq}\n\n{sample["speaker"]}: ' + context = f'{dialogue_acts_seq}\n{sample["speaker"]}: ' assert equal_da_seq(sample['dialogue_acts'], dialogue_acts_seq), print(sample['dialogue_acts'], dialogue_acts_seq, deserialize_dialogue_acts(dialogue_acts_seq)) data.append(json.dumps({'context+da': context, 'response': sample['utterance']}, ensure_ascii=False)+'\n') diff --git a/convlab2/base_models/t5/dst/serialization.py b/convlab2/base_models/t5/dst/serialization.py index 6ccf25ae76048a11834566163591366ae5cdc61a..96cc0beb105b5c9ae22aeffb02881a28e0a443ae 100644 --- a/convlab2/base_models/t5/dst/serialization.py +++ b/convlab2/base_models/t5/dst/serialization.py @@ -1,32 +1,32 @@ def serialize_dialogue_state(state): - state_seqs = [] + state_dict = {} for domain in state: - for slot, value in state[domain].items(): + for slot, value in sorted(state[domain].items()): if len(value) > 0: - state_seqs.append(f'[{domain}][{slot}][{value}]') - - return ';'.join(state_seqs) + state_dict.setdefault(f'[{domain}]', []) + state_dict[f'[{domain}]'].append(f'[{slot}][{value}]') + return ';'.join([domain+'{'+','.join(slot_values)+'}' for domain, slot_values in state_dict.items()]) def deserialize_dialogue_state(state_seq): state = {} if len(state_seq) == 0: return state - state_seqs = state_seq.split('];[') + state_seqs = state_seq.split(']};[') # will consume "]}" and "[" for i, state_seq in enumerate(state_seqs): - if len(state_seq) == 0: + if len(state_seq) == 0 or len(state_seq.split(']{[')) != 2: continue if i == 0: if state_seq[0] == '[': state_seq = state_seq[1:] if i == len(state_seqs) - 1: - if state_seq[-1] == ']': - state_seq = state_seq[:-1] - s = state_seq.split('][') - if len(s) != 3: - continue - domain, slot, value = s - state.setdefault(domain, {}) - state[domain][slot] = value + if state_seq[-2:] == ']}': + state_seq = state_seq[:-2] + + domain, slot_values = state_seq.split(']{[') + for slot_value in slot_values.split('],['): + slot, value = slot_value.split('][') + state.setdefault(domain, {}) + state[domain][slot] = value return state def equal_state_seq(state, state_seq): diff --git a/convlab2/base_models/t5/nlu/serialization.py b/convlab2/base_models/t5/nlu/serialization.py index 5a620f4689519accaccdc1149a54ed6c8efb52d8..7c9a764fb9e4d04ec4036722448b4a1a00636a35 100644 --- a/convlab2/base_models/t5/nlu/serialization.py +++ b/convlab2/base_models/t5/nlu/serialization.py @@ -1,51 +1,39 @@ def serialize_dialogue_acts(dialogue_acts): - da_seqs = [] + da_dict = {} for da_type in dialogue_acts: for da in dialogue_acts[da_type]: - intent, domain, slot = da['intent'], da['domain'], da['slot'] - if da_type == 'binary': - da_seq = f'[{da_type}][{intent}][{domain}][{slot}]' - else: - value = da['value'] - da_seq = f'[{da_type}][{intent}][{domain}][{slot}][{value}]' - da_seqs.append(da_seq) - return ';'.join(da_seqs) + intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da.get('value', '') + intent_domain = f'[{intent}][{domain}]' + da_dict.setdefault(intent_domain, []) + da_dict[intent_domain].append(f'[{slot}][{value}]') + return ';'.join([intent_domain+'{'+','.join(slot_values)+'}' for intent_domain, slot_values in da_dict.items()]) def deserialize_dialogue_acts(das_seq): - dialogue_acts = {'binary': [], 'categorical': [], 'non-categorical': []} + dialogue_acts = [] if len(das_seq) == 0: return dialogue_acts - da_seqs = das_seq.split('];[') + da_seqs = das_seq.split(']};[') # will consume "]}" and "[" for i, da_seq in enumerate(da_seqs): - if len(da_seq) == 0: + if len(da_seq) == 0 or len(da_seq.split(']{[')) != 2: continue if i == 0: if da_seq[0] == '[': da_seq = da_seq[1:] if i == len(da_seqs) - 1: - if da_seq[-1] == ']': - da_seq = da_seq[:-1] - da = da_seq.split('][') - if len(da) == 0: - continue - da_type = da[0] - if len(da) == 5 and da_type in ['categorical', 'non-categorical']: - dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3], 'value': da[4]}) - elif len(da) == 4 and da_type == 'binary': - dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3]}) - else: - # invalid da format, skip - # print(das_seq) - # print(da_seq) - # print() - pass + if da_seq[-2:] == ']}': + da_seq = da_seq[:-2] + + intent_domain, slot_values = da_seq.split(']{[') + intent, domain = intent_domain.split('][') + for slot_value in slot_values.split('],['): + slot, value = slot_value.split('][') + dialogue_acts.append({'intent': intent, 'domain': domain, 'slot': slot, 'value': value}) return dialogue_acts def equal_da_seq(dialogue_acts, das_seq): predict_dialogue_acts = deserialize_dialogue_acts(das_seq) - for da_type in ['binary', 'categorical', 'non-categorical']: - das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in dialogue_acts[da_type]]) - predict_das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in predict_dialogue_acts[da_type]]) - if das != predict_das: - return False + das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da_type in ['binary', 'categorical', 'non-categorical'] for da in dialogue_acts[da_type]]) + predict_das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in predict_dialogue_acts]) + if das != predict_das: + return False return True