Skip to content
Snippets Groups Projects
Select Git revision
  • 2d5229ef81969488565b7c1a04c3d630488134ae
  • master default protected
  • update-goal-generator
  • add-overrides==4.1.2
  • eval-v1
  • eval-v2
  • dev
7 results

utils.py

Blame
  • user avatar
    function2 authored
    2d5229ef
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    utils.py 4.93 KiB
    import os
    import json
    import zipfile
    
    from convlab2 import DATA_ROOT
    
    
    def prepare_data(subtask, split, data_root=DATA_ROOT):
        """Load one data split and turn it into per-turn evaluation triples.

        Reads ``<data_root>/<subdir>/<split>.json.zip`` (``subdir`` comes from
        ``get_subdir``). For every dialog, it pairs each user utterance with the
        preceding system utterance and with the dialog state recorded on the
        following system turn.

        Args:
            subtask: ``'multiwoz'`` selects the MultiWOZ-style layout (``log`` /
                ``text`` / ``metadata``). Any other value selects the
                CrossWOZ-style layout (``messages`` / ``content`` /
                ``sys_state_init``).
            split: split name, used for both the zip file name and the JSON
                member name inside it.
            data_root: directory containing the per-subtask data folders.

        Returns:
            ``{dialog_id: [(sys_utt, user_utt, state), ...]}``. ``sys_utt`` is
            ``None`` for the first turn, and ``state`` is
            ``{domain_name: {slot_name: value}}``.

        NOTE(review): assumes turns alternate user/system, starting with a user
        turn and ending with a system turn. An odd-length dialog raises
        ``IndexError`` on ``turns[i + 1]``.
        """
        data_dir = os.path.join(data_root, get_subdir(subtask))
        zip_filename = os.path.join(data_dir, f'{split}.json.zip')
        # Context managers release the archive and member handles once the
        # JSON has been parsed.
        with zipfile.ZipFile(zip_filename) as archive:
            with archive.open(f'{split}.json') as member:
                raw_data = json.load(member)

        if subtask == 'multiwoz':
            turns_key, text_key, turn_state = 'log', 'text', _multiwoz_turn_state
        else:
            turns_key, text_key, turn_state = 'messages', 'content', _crosswoz_turn_state

        data = {}
        for dialog_id, dialog in raw_data.items():
            turns = dialog[turns_key]
            dialog_data = []
            # Even indices are user turns; the state is read from the next
            # (system) turn.
            for i in range(0, len(turns), 2):
                sys_utt = turns[i - 1][text_key] if i else None
                user_utt = turns[i][text_key]
                state = turn_state(turns[i + 1])
                dialog_data.append((sys_utt, user_utt, state))
            data[dialog_id] = dialog_data

        return data


    def _multiwoz_turn_state(sys_turn):
        """Flatten a MultiWOZ system turn's ``metadata`` into ``{domain: {slot: value}}``.

        All sub-dicts of a domain are merged into one slot mapping; on a
        duplicate slot name, the later sub-dict wins. The '警察机关' (police)
        and '医院' (hospital) domains are skipped.
        """
        state = {}
        for domain_name, domain in sys_turn['metadata'].items():
            if domain_name in ('警察机关', '医院'):
                continue
            domain_state = {}
            for slots in domain.values():
                domain_state.update(slots)
            state[domain_name] = domain_state
        return state


    def _crosswoz_turn_state(sys_turn):
        """Copy a CrossWOZ system turn's ``sys_state_init``, dropping ``selectedResults``."""
        return {
            domain_name: {
                slot_name: value
                for slot_name, value in domain.items()
                if slot_name != 'selectedResults'
            }
            for domain_name, domain in sys_turn['sys_state_init'].items()
        }
    
    
    def extract_gt(test_data):
        """Keep only the dialog states from prepared data.

        Args:
            test_data: ``{dialog_id: [(sys_utt, user_utt, state), ...]}``, as
                produced by ``prepare_data``. Each entry must unpack into
                exactly three items.

        Returns:
            ``{dialog_id: [state, ...]}``, with the turn order preserved.
        """
        gt = {}
        for dialog_id, turns in test_data.items():
            states = []
            for _sys_utt, _user_utt, state in turns:
                states.append(state)
            gt[dialog_id] = states
        return gt
    
    
    def eval_states(gt, pred, subtask):
        """Score predicted dialog states against ground-truth states.

        Args:
            gt: ``{dialog_id: [state, ...]}``, where each state is
                ``{domain_name: {slot_name: value}}`` (the layout built by
                ``extract_gt``).
            pred: a mapping with the same layout holding the predicted states.
            subtask: ``'multiwoz'`` or ``'crosswoz'``; selects the value
                normalization table. Any other value raises ``KeyError``.

        Returns:
            On success: ``{'status': 'ok', 'joint accuracy': ..., 'slot':
            {'accuracy', 'precision', 'recall', 'f1'}}``.

            If ``pred`` is structurally incomplete (missing dialog, wrong turn
            count, missing domain or slot), the first problem found is returned
            as ``{'status': 'exception', 'description': str, ...}``. That dict
            also carries the offending ``dialog_id`` / ``turn_id`` / ``domain`` /
            ``slot``.

        Only domains and slots present in ``gt`` are scored; extra entries in
        ``pred`` are ignored. A slot is a true positive when both normalized
        values are equal and non-empty. Any metric whose denominator is zero
        is reported as 1.
        """
        # for unifying values with the same meaning to the same expression
        value_unifier = {
            'multiwoz': {},
            'crosswoz': {
                '未提及': '',
            },
        }[subtask]

        def unify_value(value):
            # Lists/dicts (e.g. copied sub-structures of a domain) are
            # unhashable and cannot be looked up in value_unifier directly.
            # Normalize their elements instead so comparison still works.
            if isinstance(value, list):
                return [unify_value(v) for v in value]
            if isinstance(value, dict):
                return {k: unify_value(v) for k, v in value.items()}
            try:
                return value_unifier.get(value, value)
            except TypeError:  # any other unhashable value: leave as-is
                return value

        def exception(description, **kwargs):
            # Error payload describing why the predictions could not be scored.
            return {'status': 'exception', 'description': description, **kwargs}

        joint_acc, joint_tot = 0, 0
        slot_acc, slot_tot = 0, 0
        tp, fp, fn = 0, 0, 0
        for dialog_id, gt_states in gt.items():
            if dialog_id not in pred:
                return exception('some dialog not found', dialog_id=dialog_id)

            pred_states = pred[dialog_id]
            if len(gt_states) != len(pred_states):
                return exception(f'turns number incorrect, {len(gt_states)} expected, {len(pred_states)} found', dialog_id=dialog_id)

            for turn_id, (gt_state, pred_state) in enumerate(zip(gt_states, pred_states)):
                joint_tot += 1
                turn_result = True
                for domain_name, gt_domain in gt_state.items():
                    if domain_name not in pred_state:
                        return exception('domain missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name)

                    pred_domain = pred_state[domain_name]
                    for slot_name, gt_value in gt_domain.items():
                        if slot_name not in pred_domain:
                            return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name)
                        gt_value = unify_value(gt_value)
                        pred_value = unify_value(pred_domain[slot_name])
                        slot_tot += 1
                        if gt_value == pred_value:
                            slot_acc += 1
                            # Agreement on an empty value is correct but not a
                            # positive.
                            if gt_value:
                                tp += 1
                        else:
                            turn_result = False
                            if gt_value:
                                fn += 1
                            if pred_value:
                                fp += 1
                joint_acc += turn_result

        # A zero denominator means there was nothing to get wrong, so every
        # metric defaults to 1 instead of raising ZeroDivisionError.
        precision = tp / (tp + fp) if tp + fp else 1
        recall = tp / (tp + fn) if tp + fn else 1
        f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 1
        return {
            'status': 'ok',
            'joint accuracy': joint_acc / joint_tot if joint_tot else 1,
            'slot': {
                'accuracy': slot_acc / slot_tot if slot_tot else 1,
                'precision': precision,
                'recall': recall,
                'f1': f1,
            }
        }
    
    
    def get_subdir(subtask):
        """Return the data sub-directory name for ``subtask``.

        ``'multiwoz'`` maps to ``'multiwoz_zh'``; every other value maps to
        ``'crosswoz_en'``.
        """
        if subtask == 'multiwoz':
            return 'multiwoz_zh'
        return 'crosswoz_en'