import json
import os
import zipfile

from convlab2 import DATA_ROOT


def get_subdir(subtask):
    """Return the dataset sub-directory name used for ``subtask``.

    ``'multiwoz'`` selects ``'multiwoz_zh'``; every other value falls back to
    ``'crosswoz_en'``.
    """
    if subtask == 'multiwoz':
        return 'multiwoz_zh'
    return 'crosswoz_en'


def prepare_data(subtask, split, data_root=DATA_ROOT, correct_name_label=False):
    """Load a dialog split and flatten it into per-turn ``(sys, user, state)`` triples.

    The split is read from ``<data_root>/<get_subdir(subtask)>/<split>.json.zip``,
    archive member ``<split>.json``.

    Args:
        subtask: ``'multiwoz'`` parses the MultiWOZ layout (``log`` turns with
            ``metadata``); any other value parses the CrossWOZ layout
            (``messages`` turns with ``sys_state_init``/``sys_state``).
        split: split name, used for both the zip file name and the member name.
        data_root: directory containing the per-subtask dataset directories.
        correct_name_label: CrossWOZ only. When true, an empty ``name`` slot is
            filled with the single entry of the system's ``selectedResults``
            remembered from the most recent turn at which that domain's state
            changed (only if exactly one result was remembered). When false, an
            empty ``name`` slot is filled with the current turn's whole
            ``selectedResults`` value (if non-empty).

    Returns:
        dict mapping dialog id -> list of ``(sys_utt, user_utt, dialog_state)``
        tuples, one per user turn. ``sys_utt`` is ``None`` for the first turn and
        ``dialog_state`` maps domain name -> {slot name: value}.
    """
    data_dir = os.path.join(data_root, get_subdir(subtask))
    zip_filename = os.path.join(data_dir, f'{split}.json.zip')
    # Context managers close both the member stream and the archive handle,
    # which the previous one-liner leaked.
    with zipfile.ZipFile(zip_filename) as archive:
        with archive.open(f'{split}.json') as member:
            test_data = json.load(member)

    if subtask == 'multiwoz':
        return _prepare_multiwoz(test_data)
    return _prepare_crosswoz(test_data, correct_name_label)


def _prepare_multiwoz(test_data):
    """Flatten MultiWOZ-layout dialogs (``log`` + ``metadata``) into per-turn triples."""
    data = {}
    for dialog_id, dialog in test_data.items():
        dialog_data = []
        turns = dialog['log']
        # Even indices are user turns, odd indices system turns; the system turn
        # following a user turn carries the state (``metadata``) after that turn.
        for i in range(0, len(turns), 2):
            sys_utt = turns[i - 1]['text'] if i else None
            user_utt = turns[i]['text']
            dialog_state = {}
            for domain_name, domain in turns[i + 1]['metadata'].items():
                # Skip the police, hospital and bus domains.
                if domain_name in ['警察机关', '医院', '公共汽车']:
                    continue
                # Each domain holds several groups of slots (sub-dicts); merge
                # them into one flat mapping (later groups win on name clashes).
                state = {}
                for slots in domain.values():
                    for slot_name, value in slots.items():
                        state[slot_name] = value
                dialog_state[domain_name] = state
            dialog_data.append((sys_utt, user_utt, dialog_state))
        data[dialog_id] = dialog_data
    return data


def _prepare_crosswoz(test_data, correct_name_label):
    """Flatten CrossWOZ-layout dialogs (``messages`` + ``sys_state_init``) into per-turn triples."""
    data = {}
    for dialog_id, dialog in test_data.items():
        dialog_data = []
        turns = dialog['messages']
        if correct_name_label:
            # Per-domain system results remembered from the last turn at which
            # that domain's state changed.
            carried_results = {domain_name: [] for domain_name in turns[1]['sys_state_init']}
        for i in range(0, len(turns), 2):
            sys_utt = turns[i - 1]['content'] if i else None
            user_utt = turns[i]['content']
            dialog_state = {}
            for domain_name, state in turns[i + 1]['sys_state_init'].items():
                if correct_name_label:
                    state.pop('selectedResults')
                    # Popping from this turn's ``sys_state`` also strips the key
                    # before the next iteration compares against turns[i - 1].
                    sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults')
                    # Has the state changed compared to the previous system state?
                    state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name]
                    # Drop outdated remembered results once the state is updated.
                    if state_change:
                        carried_results[domain_name].clear()
                    # A missing 'name' key defaults to a non-empty value, so only an
                    # explicitly empty name is filled in.
                    if not state.get('name', 'something nonempty') and len(carried_results[domain_name]) == 1:
                        state['name'] = carried_results[domain_name][0]
                    dialog_state[domain_name] = state
                    if state_change and sys_selected_results:
                        carried_results[domain_name] = sys_selected_results
                else:
                    turn_results = state.pop('selectedResults')
                    # Fall back to this turn's selected results for an empty name.
                    if turn_results and 'name' in state and not state['name']:
                        state['name'] = turn_results
                    dialog_state[domain_name] = state
            dialog_data.append((sys_utt, user_utt, dialog_state))
        data[dialog_id] = dialog_data
    return data


def extract_gt(test_data):
    """Drop the utterances from prepared data, keeping only the per-turn states.

    Args:
        test_data: mapping of dialog id -> sequence of
            ``(sys_utt, user_utt, state)`` triples (the layout produced by
            ``prepare_data``).

    Returns:
        dict mapping dialog id -> list of states, in turn order.
    """
    gt = {}
    for dialog_id, turns in test_data.items():
        states = []
        for _sys_utt, _user_utt, state in turns:
            states.append(state)
        gt[dialog_id] = states
    return gt


# for unifying values with the same meaning to the same expression
def unify_value(value, subtask):
    if isinstance(value, list):
        for i, v in enumerate(value):
            value[i] = unify_value(v, subtask)
        return value

    value = value.lower()
    value = {
        'multiwoz': {
            '未提及': '',
            'none': '',
            '是的': '有',
            '不是': '没有',
        },
        'crosswoz': {
            'none': '',
            'free admission': 'free',
        }
    }[subtask].get(value, value)

    return ''.join(value.strip().split())


def eval_states(gt, pred, subtask):
    """Score predicted dialog states against the ground truth.

    Args:
        gt: dialog id -> list of per-turn states, each a mapping of
            domain name -> {slot name: value} (e.g. the output of ``extract_gt``).
        pred: predictions in the same layout. Every dialog, turn, domain and
            slot present in ``gt`` must also be present here; extra domains or
            slots in ``pred`` are ignored.
        subtask: forwarded to ``unify_value`` to normalize both sides.

    Returns:
        ``(result, errors)``.

        On success, ``result['status'] == 'ok'`` and ``result`` contains:

        - the joint (whole-turn) accuracy;
        - slot accuracy, precision, recall and F1, where precision/recall/F1
          count only non-empty values.

        ``errors`` is a table whose first row is a header, followed by one row
        per mismatched slot.

        If the inputs do not line up, or there is nothing to score,
        ``result['status'] == 'exception'`` with a description and the
        offending location, and ``errors`` is ``None``.
    """
    def exception(description, **kwargs):
        # Structured failure: status and description first, then location details.
        return {'status': 'exception', 'description': description, **kwargs}, None

    errors = [['dialog id', 'turn id', 'domain name', 'slot name', 'ground truth', 'predict']]

    joint_acc, joint_tot = 0, 0
    slot_acc, slot_tot = 0, 0
    tp, fp, fn = 0, 0, 0
    for dialog_id, gt_states in gt.items():
        if dialog_id not in pred:
            return exception('some dialog not found', dialog_id=dialog_id)

        pred_states = pred[dialog_id]
        if len(gt_states) != len(pred_states):
            return exception(f'turns number incorrect, {len(gt_states)} expected, {len(pred_states)} found',
                             dialog_id=dialog_id)

        for turn_id, (gt_state, pred_state) in enumerate(zip(gt_states, pred_states)):
            joint_tot += 1
            turn_correct = True
            for domain_name, gt_domain in gt_state.items():
                if domain_name not in pred_state:
                    return exception('domain missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name)

                pred_domain = pred_state[domain_name]
                for slot_name, gt_value in gt_domain.items():
                    if slot_name not in pred_domain:
                        return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id,
                                         domain=domain_name, slot=slot_name)
                    gt_value = unify_value(gt_value, subtask)
                    pred_value = unify_value(pred_domain[slot_name], subtask)
                    slot_tot += 1

                    # A list-valued ground truth also accepts any single member.
                    if gt_value == pred_value or (isinstance(gt_value, list) and pred_value in gt_value):
                        slot_acc += 1
                        if gt_value:
                            tp += 1
                    else:
                        errors.append([dialog_id, turn_id, domain_name, slot_name, gt_value, pred_value])
                        turn_correct = False
                        if gt_value:
                            fn += 1
                        if pred_value:
                            fp += 1
            joint_acc += turn_correct

    if not joint_tot:
        # Empty ground truth: report it instead of dividing by zero.
        return exception('no turns to evaluate')

    # Empty denominators count as perfect, consistently for every ratio.
    slot_accuracy = slot_acc / slot_tot if slot_tot else 1
    precision = tp / (tp + fp) if tp + fp else 1
    recall = tp / (tp + fn) if tp + fn else 1
    f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 1
    return {
        'status': 'ok',
        'joint accuracy': joint_acc / joint_tot,
        'slot': {
            'accuracy': slot_accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
        }
    }, errors


def dump_result(model_dir, filename, result, errors=None, pred=None, output_root='../results'):
    """Write evaluation outputs into ``<output_root>/<model_dir>``.

    Args:
        model_dir: sub-directory for this model's outputs; created if missing.
        filename: name of the JSON file that receives ``result``.
        result: JSON-serializable metrics, written with 4-space indentation.
        errors: optional list of rows written to ``errors.csv``; skipped when falsy.
        pred: optional JSON-serializable predictions written to ``pred.json``;
            skipped when falsy.
        output_root: parent directory for all model result directories. The
            default ``'../results'`` is resolved against the current working
            directory.
    """
    output_dir = os.path.join(output_root, model_dir)
    os.makedirs(output_dir, exist_ok=True)

    def _dump_json(obj, name):
        # ensure_ascii=False writes non-ASCII text (e.g. Chinese values) verbatim,
        # so the file must be UTF-8 regardless of the platform default encoding.
        with open(os.path.join(output_dir, name), 'w', encoding='utf-8') as f:
            json.dump(obj, f, indent=4, ensure_ascii=False)

    _dump_json(result, filename)
    if errors:
        import csv
        # newline='' lets the csv module control line endings (no blank rows on Windows).
        with open(os.path.join(output_dir, 'errors.csv'), 'w', newline='', encoding='utf-8') as f:
            csv.writer(f).writerows(errors)
    if pred:
        _dump_json(pred, 'pred.json')