import argparse
import json
import os
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tensorboard.backend.event_processing import event_accumulator
from tqdm import tqdm


def get_args():
    parser = argparse.ArgumentParser(description='Export tensorboard data')
    parser.add_argument('--dir', type=str, default="results",
                        help='Root dir for tensorboard files')
    parser.add_argument('--tb-dir', type=str, default="TB_summary",
                        help='The last dir for tensorboard files')
    parser.add_argument("--map-file", type=str, default="results/map.json")
    parser.add_argument("--out-file", type=str, default="results/fig")
    parser.add_argument("--max-dialogues", type=int, default=0)
    parser.add_argument("--fill-between", type=float, default=0.3,
                        help="the transparency of the std err area")

    args = parser.parse_args()
    return args


def read_data(exp_dir, tb_dir, map_file):
    f_map = json.load(open(map_file))
    data = {}
    for m in f_map:
        data[m["legend"]] = read_dir(exp_dir, tb_dir, m["dir"])
    return data


def read_dir(exp_dir, tb_dir, method_dir):
    dfs = []
    for dir_name in tqdm(glob(os.path.join(exp_dir, method_dir, "*")), ascii=True, desc=method_dir):
        df = read_tb_data(os.path.join(dir_name, tb_dir))
        dfs.append(df)
    return dfs


def read_tb_data(in_path):
    # load log data
    event_data = event_accumulator.EventAccumulator(in_path)
    event_data.Reload()
    keys = event_data.scalars.Keys()
    df = pd.DataFrame(columns=keys[1:])
    for key in keys:
        w_times, step_nums, vals = zip(*event_data.Scalars(key))
        df[key] = vals
        df['steps'] = step_nums
    return df


def plot(data, out_file, plot_type="complete_rate", show_image=False, fill_between=0.3, max_dialogues=0, y_label=''):

    legends = [alg for alg in data]
    clrs = sns.color_palette("husl", len(legends))
    plt.figure(plot_type)

    with sns.axes_style("darkgrid"):
        for i, alg in enumerate(legends):

            max_step = min([len(d[plot_type]) for d in data[alg]])
            if max_dialogues > 0:
                max_length = min([len([s for s in d['steps'] if s <= max_dialogues]) for d in data[alg]])
                max_step = min([max_length, max_step])
            print("max_step: ", max_step)

            value = np.array([d[plot_type][:max_step] for d in data[alg]])
            step = np.array([d['steps'][:max_step] for d in data[alg]][0])
            mean, err = np.mean(value, axis=0), np.std(value, axis=0)
            plt.plot(
                step, mean, c=clrs[i], label=alg)

            plt.fill_between(
                step, mean - err,
                mean + err, alpha=fill_between, facecolor=clrs[i])
        # locs, labels = plt.xticks()
        # plt.xticks(locs, labels)
        #plt.yticks(np.arange(10) / 10)
        #plt.yticks([0.5, 0.6, 0.7])
        plt.xlabel('Training dialogues')
        if len(y_label) > 0:
            plt.ylabel(y_label)
        else:
            plt.ylabel(plot_type)
        plt.legend(fancybox=True, shadow=False, ncol=1, loc='lower left')
        plt.savefig(out_file, bbox_inches='tight')

        if show_image:
            plt.show()


if __name__ == "__main__":
    args = get_args()

    y_label_dict = {"complete_rate": 'Complete rate', "success_rate": 'Success rate', 'turns': 'Average turns',
                    'avg_return': 'Average Return'}

    for plot_type in ["complete_rate", "success_rate", 'turns', 'avg_return']:
        file_name, file_extension = os.path.splitext(args.out_file)
        fig_name = f"{file_name}_{plot_type}{file_extension}"

        data = read_data(exp_dir=args.dir, tb_dir=args.tb_dir,
                         map_file=args.map_file)
        plot(data=data,
             out_file=fig_name,
             plot_type=plot_type,
             fill_between=args.fill_between,
             max_dialogues=args.max_dialogues,
             y_label=y_label_dict[plot_type])