diff --git a/convlab2/dst/evaluate.py b/convlab2/dst/evaluate.py index 842dde702c2402470eee14ad2bbe8d672b10127f..04263a168dbf9c08c0ca3bf1c0e2b247730256b9 100755 --- a/convlab2/dst/evaluate.py +++ b/convlab2/dst/evaluate.py @@ -56,13 +56,24 @@ def reformat_state(state): state = state['belief_state'] new_state = [] for domain in state.keys(): - domain_data = state[domain] - if 'semi' in domain_data: - domain_data = domain_data['semi'] + domain_data_all = state[domain] + if 'semi' in domain_data_all: + domain_data = domain_data_all['semi'] for slot in domain_data.keys(): val = domain_data[slot] if val is not None and val not in ['', 'not mentioned', '未提及', '未提到', '没有提到']: new_state.append(domain + '-' + slot + '-' + val) + if 'book' in domain_data_all: + domain_data = domain_data_all['book'] + for slot in domain_data.keys(): + if slot == 'booked': + continue + elif domain == 'bus' and slot == 'people': + continue + else: + val = domain_data[slot] + if val is not None and val not in ['', 'not mentioned', '未提及', '未提到', '没有提到']: + new_state.append(domain+'_book' + '-' + slot + '-' + val) # lower new_state = [item.lower() for item in new_state] return new_state