# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re

from utils_dst import (DSTExample, convert_to_unicode)


LABEL_MAPS = {} # Loaded from file
LABEL_FIX = {'centre': 'center', 'areas': 'area', 'phone number': 'number', 'price range': 'price_range'}


def delex_utt(utt, values):
    utt_norm = utt.copy()
    for s, v in values.items():
        if v != 'none':
            v_norm = tokenize(v)
            v_len = len(v_norm)
            for i in range(len(utt_norm) + 1 - v_len):
                if utt_norm[i:i + v_len] == v_norm:
                    utt_norm[i:i + v_len] = ['[UNK]'] * v_len
    return utt_norm


def get_token_pos(tok_list, label):
    find_pos = []
    found = False
    label_list = [item for item in map(str.strip, re.split("(\W+)", label)) if len(item) > 0]
    len_label = len(label_list)
    for i in range(len(tok_list) + 1 - len_label):
        if tok_list[i:i + len_label] == label_list:
            find_pos.append((i, i + len_label))  # start, exclusive_end
            found = True
    return found, find_pos


def check_label_existence(label, usr_utt_tok, sys_utt_tok):
    in_usr, usr_pos = get_token_pos(usr_utt_tok, label)
    if not in_usr and label in LABEL_MAPS:
        for tmp_label in LABEL_MAPS[label]:
            in_usr, usr_pos = get_token_pos(usr_utt_tok, tmp_label)
            if in_usr:
                break
    in_sys, sys_pos = get_token_pos(sys_utt_tok, label)
    if not in_sys and label in LABEL_MAPS:
        for tmp_label in LABEL_MAPS[label]:
            in_sys, sys_pos = get_token_pos(sys_utt_tok, tmp_label)
            if in_sys:
                break
    return in_usr, usr_pos, in_sys, sys_pos


def get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence):
    usr_utt_tok_label = [0 for _ in usr_utt_tok]
    if label == 'none' or label == 'dontcare':
        class_type = label
    else:
        in_usr, usr_pos, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
        if in_usr:
            class_type = 'copy_value'
            if slot_last_occurrence:
                (s, e) = usr_pos[-1]
                for i in range(s, e):
                    usr_utt_tok_label[i] = 1
            else:
                for (s, e) in usr_pos:
                    for i in range(s, e):
                        usr_utt_tok_label[i] = 1
        elif in_sys:
            class_type = 'inform'
        else:
            class_type = 'unpointable'
    return usr_utt_tok_label, class_type


def tokenize(utt):
    utt_lower = convert_to_unicode(utt).lower()
    utt_tok = [tok for tok in map(str.strip, re.split("(\W+)", utt_lower)) if len(tok) > 0]
    return utt_tok


def create_examples(input_file, set_type, slot_list,
                    label_maps={},
                    append_history=False,
                    use_history_labels=False,
                    swap_utterances=False,
                    label_value_repetitions=False,
                    delexicalize_sys_utts=False,
                    analyze=False):
    """Read a DST json file into a list of DSTExample."""
    with open(input_file, "r", encoding='utf-8') as reader:
        input_data = json.load(reader)

    global LABEL_MAPS
    LABEL_MAPS = label_maps

    examples = []
    for entry in input_data:
        diag_seen_slots_dict = {}
        diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
        diag_state = {slot: 'none' for slot in slot_list}
        sys_utt_tok = []
        sys_utt_tok_delex = []
        usr_utt_tok = []
        hst_utt_tok = []
        hst_utt_tok_label_dict = {slot: [] for slot in slot_list}
        for turn in entry['dialogue']:
            sys_utt_tok_label_dict = {}
            usr_utt_tok_label_dict = {}
            inform_dict = {slot: 'none' for slot in slot_list}
            inform_slot_dict = {slot: 0 for slot in slot_list}
            referral_dict = {}
            class_type_dict = {}

            # Collect turn data
            if append_history:
                if swap_utterances:
                    if delexicalize_sys_utts:
                        hst_utt_tok = usr_utt_tok + sys_utt_tok_delex + hst_utt_tok
                    else:
                        hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
                else:
                    if delexicalize_sys_utts:
                        hst_utt_tok = sys_utt_tok_delex + usr_utt_tok + hst_utt_tok
                    else:
                        hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok

            sys_utt_tok = tokenize(turn['system_transcript'])
            usr_utt_tok = tokenize(turn['transcript'])
            turn_label = {LABEL_FIX.get(s.strip(), s.strip()): LABEL_FIX.get(v.strip(), v.strip()) for s, v in turn['turn_label']}

            guid = '%s-%s-%s' % (set_type, str(entry['dialogue_idx']), str(turn['turn_idx']))

            # Create delexicalized sys utterances.
            if delexicalize_sys_utts:
                delex_dict = {}
                for slot in slot_list:
                    delex_dict[slot] = 'none'
                    label = 'none'
                    if slot in turn_label:
                        label = turn_label[slot]
                    elif label_value_repetitions and slot in diag_seen_slots_dict:
                        label = diag_seen_slots_value_dict[slot]
                    if label != 'none' and label != 'dontcare':
                        _, _, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
                        if in_sys:
                            delex_dict[slot] = label
                sys_utt_tok_delex = delex_utt(sys_utt_tok, delex_dict)

            new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
            new_diag_state = diag_state.copy()
            for slot in slot_list:
                label = 'none'
                if slot in turn_label:
                    label = turn_label[slot]
                elif label_value_repetitions and slot in diag_seen_slots_dict:
                    label = diag_seen_slots_value_dict[slot]

                (usr_utt_tok_label,
                 class_type) = get_turn_label(label,
                                              sys_utt_tok,
                                              usr_utt_tok,
                                              slot_last_occurrence=True)

                if class_type == 'inform':
                    inform_dict[slot] = label
                    if label != 'none':
                        inform_slot_dict[slot] = 1

                referral_dict[slot] = 'none' # Referral is not present in woz2 data

                # Generally don't use span prediction on sys utterance (but inform prediction instead).
                if delexicalize_sys_utts:
                    sys_utt_tok_label = [0 for _ in sys_utt_tok_delex]
                else:
                    sys_utt_tok_label = [0 for _ in sys_utt_tok]

                # Determine what to do with value repetitions.
                # If value is unique in seen slots, then tag it, otherwise not,
                # since correct slot assignment can not be guaranteed anymore.
                if label_value_repetitions and slot in diag_seen_slots_dict:
                    if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(label) > 1:
                        class_type = 'none'
                        usr_utt_tok_label = [0 for _ in usr_utt_tok_label]

                sys_utt_tok_label_dict[slot] = sys_utt_tok_label
                usr_utt_tok_label_dict[slot] = usr_utt_tok_label

                if append_history:
                    if use_history_labels:
                        if swap_utterances:
                            new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
                        else:
                            new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
                    else:
                        new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]]

                # For now, we map all occurences of unpointable slot values
                # to none. However, since the labels will still suggest
                # a presence of unpointable slot values, the task of the
                # DST is still to find those values. It is just not
                # possible to do that via span prediction on the current input.
                if class_type == 'unpointable':
                    class_type_dict[slot] = 'none'
                elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
                    # If slot has seen before and its class type did not change, label this slot a not present,
                    # assuming that the slot has not actually been mentioned in this turn.
                    # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
                    # this must mean there is evidence in the original labels, therefore consider
                    # them as mentioned again.
                    class_type_dict[slot] = 'none'
                    referral_dict[slot] = 'none'
                else:
                    class_type_dict[slot] = class_type
                # Remember that this slot was mentioned during this dialog already.
                if class_type != 'none':
                    diag_seen_slots_dict[slot] = class_type
                    diag_seen_slots_value_dict[slot] = label
                    new_diag_state[slot] = class_type
                    # Unpointable is not a valid class, therefore replace with
                    # some valid class for now...
                    if class_type == 'unpointable':
                        new_diag_state[slot] = 'copy_value'

            if swap_utterances:
                txt_a = usr_utt_tok
                if delexicalize_sys_utts:
                    txt_b = sys_utt_tok_delex
                else:
                    txt_b = sys_utt_tok
                txt_a_lbl = usr_utt_tok_label_dict
                txt_b_lbl = sys_utt_tok_label_dict
            else:
                if delexicalize_sys_utts:
                    txt_a = sys_utt_tok_delex
                else:
                    txt_a = sys_utt_tok
                txt_b = usr_utt_tok
                txt_a_lbl = sys_utt_tok_label_dict
                txt_b_lbl = usr_utt_tok_label_dict
            examples.append(DSTExample(
                guid=guid,
                text_a=txt_a,
                text_b=txt_b,
                history=hst_utt_tok,
                text_a_label=txt_a_lbl,
                text_b_label=txt_b_lbl,
                history_label=hst_utt_tok_label_dict,
                values=diag_seen_slots_value_dict.copy(),
                inform_label=inform_dict,
                inform_slot_label=inform_slot_dict,
                refer_label=referral_dict,
                diag_state=diag_state,
                class_label=class_type_dict))

            # Update some variables.
            hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy()
            diag_state = new_diag_state.copy()
            
    return examples