"""
==================
utility functions
==================
"""

import os
import sys
import time
import warnings
from typing import Union, Any
from multiprocessing import cpu_count

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from easy_mpl import pie
from easy_mpl import hist
from easy_mpl import plot
from easy_mpl import regplot
from easy_mpl import bar_chart
from easy_mpl.utils import to_1d_array
from easy_mpl.utils import despine_axes
from easy_mpl.utils import AddMarginalPlots
from easy_mpl.utils import make_cols_from_cmap

import matplotlib.colors as mcolors
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import seaborn as sns

from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import LabelEncoder

from SeqMetrics import RegressionMetrics

from ai4water.utils.utils import get_version_info

# %%
# When True, ``bar_pie`` saves its figure even if it is called with ``save=False``.
SAVE = False

# Raw column / feature name -> human-readable label used when plotting
# (looked up by ``bar_pie``; names missing here are shown unchanged).
LABEL_MAP = dict([
    ('time (min)', "Time (min)"),
    ('qe', "Removal Efficiency (%)"),
    ('O', "Oxygen (%)"),
    ('Au', "Gold (%)"),
    ('Fe', "Iron (%)"),
    ('Bi', "Bismuth (%)"),
    ('solution pH', 'Solution pH'),
    ('Surface area', 'Surf. Area (m2/g)'),
])

# %%

# Output directory for saved figures (``bar_pie`` writes into it).
# ``exist_ok=True`` replaces a separate existence check, which could race with
# another process creating the directory in between.
os.makedirs('figures', exist_ok=True)

# %%

def hardware_info()->dict:
    """
    Return basic CPU and memory information about the host.

    :return:
        dict with keys

        - ``tot_cpus``: ``multiprocessing.cpu_count()``
        - ``avail_cpus``: number of CPUs this process may run on, from
          ``os.sched_getaffinity``; falls back to ``tot_cpus`` where that
          function does not exist (it is only available on some platforms)
        - ``mem_gib``: total physical memory in GiB computed from
          ``os.sysconf``; ``None`` where ``os.sysconf`` or the required
          names are unavailable (e.g. on Windows)
    """
    tot_cpus = cpu_count()

    # sched_getaffinity reflects the process' CPU affinity mask, but it is not
    # provided on every OS, so fall back to the total count there.
    if hasattr(os, 'sched_getaffinity'):
        avail_cpus = len(os.sched_getaffinity(0))
    else:
        avail_cpus = tot_cpus

    try:
        mem_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')  # e.g. 4015976448
    except (AttributeError, ValueError, OSError):
        # AttributeError: no os.sysconf; ValueError: name unknown on this OS;
        # OSError: the value could not be queried.
        mem_gib = None
    else:
        mem_gib = mem_bytes / (1024. ** 3)  # e.g. 3.74

    return dict(
        tot_cpus=tot_cpus,
        avail_cpus=avail_cpus,
        mem_gib=mem_gib,
    )

def print_version_info(
        include_run_time:bool=True,
        include_hardware_info:bool = True
):
    """
    Print the versions of the packages in use, one ``name value`` pair per line.

    :param include_run_time:
        also print the time (``time.asctime()``) at which this was called
    :param include_hardware_info:
        also print the entries returned by ``hardware_info``
    """
    report = version_info()

    if include_run_time:
        report.update({'Script Executed on: ': time.asctime()})

    if include_hardware_info:
        report.update(hardware_info())

    for name in report:
        print(name, report[name])

# %%

def version_info()->dict:
    """
    Return a dict mapping package names to the versions being used.

    Starts from ``ai4water``'s ``get_version_info()`` and adds ``bnlearn``
    when that package can be imported.
    """
    versions = get_version_info()

    try:
        import bnlearn
        versions['bnlearn'] = bnlearn.__version__
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        pass

    return versions

# %%

def distribution_plot(ax, 
                      data, 
                      scatter_fc='#045568',
                      box_facecolor='#e6e6e6',
                      ridge_lc = '#045568',
                      width=0.8,
                      add_hist=True,
                      add_ridge=True,
                      strip_offset=0.0):
    """
    Draw a horizontal box plot of ``data`` overlaid with jittered points and,
    optionally, a marginal histogram and/or ridge line above it.

    :param ax:
        axes to draw on; if None, seaborn's default (current) axes is used
    :param data:
        values to plot. When a marginal plot is requested,
        ``data.values.reshape(-1,)`` is used, so presumably a pandas
        Series/DataFrame is expected.
    :param scatter_fc:
        color of the strip-plot points
    :param box_facecolor:
        face color of the box and of the marginal histogram/fill
    :param ridge_lc:
        line color of the marginal ridge line
    :param width:
        width of the box
    :param add_hist:
        add a marginal histogram
    :param add_ridge:
        add a marginal ridge line
    :param strip_offset:
        vertical shift (in data units) applied to the strip-plot points after
        they are drawn. The default ``0.0`` leaves them where seaborn places
        them.
    :return:
        the axes holding the box and strip plots
    """
    ax = sns.boxplot(orient='h', data=data, saturation=1, showfliers=False,
                     width=width, boxprops={'zorder': 3, 'facecolor': box_facecolor},
                     ax=ax,
                     )

    # Count the collections *before* the strip plot so that the slice below
    # selects exactly the point collections added by ``stripplot``.
    n_collections_before = len(ax.collections)

    ax = sns.stripplot(orient='h', x=data,
                       edgecolor="gray",
                       alpha=0.5,
                       c=scatter_fc,
                       size=1.5,
                       ax=ax,
                       jitter=0.2)

    if strip_offset:
        for dots in ax.collections[n_collections_before:]:
            dots.set_offsets(dots.get_offsets() + np.array([0, strip_offset]))

    despine_axes(ax, keep=['bottom', 'left', 'right'])

    if add_hist or add_ridge:
        aa = AddMarginalPlots(ax=ax,
                              hist=add_hist,
                              ridge=add_ridge,
                              hist_kws=dict(bins=20, color=box_facecolor),
                              fill_kws=dict(color=box_facecolor),
                              ridge_line_kws=dict(color=ridge_lc))
        aa.divider = make_axes_locatable(ax)
        axHistx = aa.add_ax_marg_x(data.values.reshape(-1,), hist_kws=aa.HIST_KWS[0], ax=None)
        # the marginal axes shares the x range with ``ax``; hide its duplicate labels
        plt.setp(axHistx.get_xticklabels(),visible=False)

    return ax

# %%

def merge_uniques(
        series:pd.Series,
        n_to_keep:int=5,
        replace_with="Rest"
):
    """
    Collapse all but the most frequent values of ``series`` into one label.

    Values are ranked with ``series.value_counts()``; every value whose rank
    (0-based) is ``n_to_keep`` or higher is replaced by ``replace_with``.
    A new Series is returned; ``series`` itself is not modified.

    :param series:
        values to process
    :param n_to_keep:
        number of most frequent values to leave untouched
    :param replace_with:
        label substituted for all other values
    """
    ranked_values = series.value_counts().index
    rare_values = [value for rank, value in enumerate(ranked_values)
                   if rank >= n_to_keep]
    return series.replace(rare_values, replace_with)

# %%

def pie_from_series(
        data:pd.Series,
        cmap="tab20",
        label_percent:bool = True,
        n_to_merge:int = None,
        leg_pos=None,
        show=True,
        fontsize=14
):
    """
    Draw a pie chart of the value counts of ``data`` with a legend.

    :param data:
        categorical values; each distinct value becomes one slice
    :param cmap:
        colormap name passed to ``make_cols_from_cmap``
    :param label_percent:
        if True, legend entries show each slice's percentage, otherwise its
        absolute count
    :param n_to_merge:
        if given, only the ``n_to_merge`` most frequent values keep their own
        slice; all others are merged into a single "Rest" slice via
        ``merge_uniques``. None (default) plots every value.
    :param leg_pos:
        ``bbox_to_anchor`` of the legend; defaults to ``(1.1, 1.)``
    :param show:
        call ``tight_layout`` and ``plt.show``
    :param fontsize:
        legend font size
    """
    # NOTE: this argument used to be accepted but silently ignored.
    if n_to_merge is not None:
        data = merge_uniques(data, n_to_keep=n_to_merge)

    d:pd.Series = data.value_counts()
    labels = d.index.tolist()
    vals = d.values
    colors = make_cols_from_cmap(cm=cmap, num_cols=len(vals))
    percent = 100. * vals / vals.sum()

    patches, _texts = pie(fractions=percent, autopct=None,
                          colors=colors, show=False)

    if label_percent:
        labels = ['{0}: {1:1.2f} %'.format(i, j) for i, j in zip(labels, percent)]
    else:
        labels = ['{0} (n={1:4})'.format(i, j) for i, j in zip(labels, vals)]

    # order legend entries by decreasing count
    patches, labels, _ = zip(*sorted(zip(patches, labels, vals),
                                     key=lambda x: x[2],
                                     reverse=True))

    plt.legend(patches, labels, bbox_to_anchor=leg_pos or (1.1, 1.),
               fontsize=fontsize)

    if show:
        plt.tight_layout()
        plt.show()
    return


def _ohe_column(df:pd.DataFrame, col_name:str) -> tuple:
    """
    One-hot encode column ``col_name`` of ``df`` in place.

    The column is replaced by new columns named ``{col_name}_0``,
    ``{col_name}_1``, ... (one per encoded category, numbered by position).

    :return:
        ``(df, new_column_names, fitted_encoder)``; ``df`` is the same,
        mutated DataFrame that was passed in.
    """
    assert isinstance(col_name, str)

    encoder = OneHotEncoder(sparse_output=False)
    as_2d = df[col_name].values.reshape(-1, 1)
    encoded = encoder.fit_transform(as_2d)

    n_categories = encoded.shape[-1]
    new_columns = [f"{col_name}_{idx}" for idx in range(n_categories)]
    df[new_columns] = encoded

    del df[col_name]

    return df, new_columns, encoder


def prepare_data(
        encoding:str='le',
        exclude_cf:bool = True,
        )->tuple:
    """
    Load ``../data/data.xlsx``, tidy its categorical labels and encode them.

    The path is relative to the current working directory.

    :param encoding:
        ``"ohe"`` one-hot encodes the ``Catalyst type`` and ``Anions`` columns;
        any other value label-encodes them.
    :param exclude_cf:
        drop the ``Cf (mg/L)`` column so it is not used as an input feature.
    :return:
        ``(data, ct_encoder, anion_encoder)``: the prepared DataFrame with the
        ``Efficiency (%)`` target as its last column, and the fitted encoders
        of ``Catalyst type`` and ``Anions``.
    """
    df = pd.read_excel(os.path.join("../data/data.xlsx"))

    # shorter / normalized catalyst names (no new name equals another old one,
    # so replacing them all at once is the same as replacing one by one)
    catalyst_names = {
        'commercial TiO2': 'TiO2',
        '0.25 wt% Au-BFO': '0.25 wt%',
        '0.5 wt% Au-BFO': '0.5 wt%',
        '1 wt% Au-BFO': '1 wt%',
        '2 wt% Au-BFO': '2 wt%',
        'no catalyst': 'None',
        'pure BFO': 'Pure BFO',
    }
    df['Catalyst type'] = df['Catalyst type'].replace(catalyst_names)

    one_hot = encoding == "ohe"

    no_anion_label = 'No Anions' if one_hot else 'None'
    df['Anions'] = df['Anions'].replace('Without Anions', no_anion_label)

    if one_hot:
        data, _, ct_encoder = _ohe_column(df, 'Catalyst type')
        data, _, anion_encoder = _ohe_column(data, 'Anions')
    else:
        data, ct_encoder = le_column(df, 'Catalyst type')
        data, anion_encoder = le_column(data, 'Anions')

    if exclude_cf:
        # Cf must not be used as input feature
        del data['Cf (mg/L)']

    # pop and re-insert the target so that it becomes the last column
    data['Efficiency (%)'] = data.pop('Efficiency (%)')

    return data, ct_encoder, anion_encoder


def le_column(df:pd.DataFrame, col_name)->tuple:
    """
    Label-encode column ``col_name`` of ``df`` in place.

    :return:
        ``(df, fitted_label_encoder)``; ``df`` is the same DataFrame that was
        passed in, with the column overwritten by integer codes.
    """
    label_encoder = LabelEncoder()
    codes = label_encoder.fit_transform(df[col_name])
    df[col_name] = codes
    return df, label_encoder


# %%

def evaluate_model(true, predicted):
    """
    Print a fixed set of regression metrics comparing ``predicted`` to ``true``.

    Each line is ``<metric name> <value>``, computed with SeqMetrics'
    ``RegressionMetrics``.
    """
    metrics = RegressionMetrics(true, predicted)
    metric_names = ('mse', 'rmse', 'r2', 'r2_score', 'mape', 'mae')
    for name in metric_names:
        value = getattr(metrics, name)()
        print(name, value)

# %%

def regression_plot(
        train_true,
        train_pred,
        test_true,
        test_pred,
        label = 'qe',
        max_xtick_val = None,
        max_ytick_val = None,
        min_xtick_val=None,
        min_ytick_val=None,
        max_ticks = 5,
        train_color = '#9acad4',
        test_color="#fe8977",
        show=False,
        ax=None,
)->plt.Axes:
    """
    Scatter predicted against experimental values for the training and test
    sets, with marginal histograms and the R² of both sets annotated.

    :param train_true, train_pred:
        observed and predicted values of the training set
    :param test_true, test_pred:
        observed and predicted values of the test set
    :param label:
        quantity name used in the axis labels
        ("Experimental {label}" / "Predicted {label}")
    :param max_xtick_val, min_xtick_val:
        bounds passed to ``set_xticklabels``
    :param max_ytick_val, min_ytick_val:
        bounds passed to ``set_yticklabels``
    :param max_ticks:
        number of ticks on each axis
    :param train_color, test_color:
        marker / histogram colors of the two sets
    :param show:
        call ``plt.show`` at the end
    :param ax:
        axes to draw on (passed to ``regplot``)
    :return:
        the axes holding the scatter plot
    """
    TRAIN_RIDGE_LINE_KWS = [{'color': train_color, 'lw': 1.0},
                            {'color': train_color, 'lw': 1.0}]
    TRAIN_HIST_KWS = [{'color': train_color, 'bins': 50},
                      {'color': train_color, 'bins': 50}]

    ax = regplot(train_true, train_pred,
                 marker_size=35,
                 marker_color=train_color,
                 line_color='k',
                 fill_color='k',
                 scatter_kws={'edgecolors': 'black',
                              'linewidth': 0.7,
                              'alpha': 0.9,
                              },
                 label="Training",
                 show=False,
                 ax=ax,
                 )

    axHistx, axHisty = AddMarginalPlots(
        ax,
        ridge=False,
        pad=0.25,
        size=0.7,
        ridge_line_kws=TRAIN_RIDGE_LINE_KWS,
        hist_kws=TRAIN_HIST_KWS
    )(train_true, train_pred)

    train_r2 = RegressionMetrics(train_true, train_pred).r2()
    test_r2 = RegressionMetrics(test_true, test_pred).r2()
    ax.annotate(f'Training $R^2$= {round(train_r2, 2)}',
                xy=(0.95, 0.30),
                xycoords='axes fraction',
                horizontalalignment='right',
                verticalalignment='top',
                fontsize=12, weight="bold")
    ax.annotate(f'Test $R^2$= {round(test_r2, 2)}',
                xy=(0.95, 0.20),
                xycoords='axes fraction',
                horizontalalignment='right', verticalalignment='top',
                fontsize=12, weight="bold")

    ax_ = regplot(test_true, test_pred,
                  marker_size=35,
                  marker_color=test_color,
                  line_style=None,
                  scatter_kws={'edgecolors': 'black',
                               'linewidth': 0.7,
                               'alpha': 0.9,
                               },
                  show=False,
                  label="Test",
                  ax=ax
                  )

    # matplotlib ignores ``fontsize`` whenever ``prop`` is given, so the size
    # must be part of ``prop`` to take effect.
    ax_.legend(prop=dict(weight="bold", size=12))
    TEST_RIDGE_LINE_KWS = [{'color': test_color, 'lw': 1.0},
                           {'color': test_color, 'lw': 1.0}]
    TEST_HIST_KWS = {'color': test_color, 'bins': 50}
    # draw the test-set marginals onto the same marginal axes as the training set
    AddMarginalPlots(
        ax,
        ridge=False,
        pad=0.25,
        size=0.7,
        ridge_line_kws=TEST_RIDGE_LINE_KWS,
        hist_kws=TEST_HIST_KWS
    )(test_true, test_pred, axHistx, axHisty)

    set_xticklabels(
        ax_,
        max_xtick_val=max_xtick_val,
        min_xtick_val=min_xtick_val,
        max_ticks=max_ticks,
    )
    set_yticklabels(
        ax_,
        max_ytick_val=max_ytick_val,
        min_ytick_val=min_ytick_val,
        max_ticks=max_ticks
    )
    ax.set_xlabel(f"Experimental {label}")
    ax.set_ylabel(f"Predicted {label}")

    if show:
        plt.show()
    return ax

# %%


def set_ticks(axes:plt.Axes, which="x", size=12):
    """
    Relabel the current ticks of the ``which`` axis in bold.

    Float tick values are shown rounded to 2 decimals, other values are cast
    to ``int``. The current tick positions are pinned with
    ``set_{which}ticks`` first, because matplotlib only keeps labels attached
    to the intended positions when the locator is fixed.

    :param axes:
        axes to modify
    :param which:
        ``"x"`` or ``"y"``
    :param size:
        font size of the tick labels
    """
    positions = getattr(axes, f"get_{which}ticks")()
    ticks = np.array(positions)

    if 'float' in ticks.dtype.name:
        ticks = np.round(ticks, 2)
    else:
        ticks = ticks.astype(int)

    getattr(axes, f"set_{which}ticks")(positions)
    getattr(axes, f"set_{which}ticklabels")(ticks, size=size, weight="bold")
    return

def set_xticklabels(
        ax:plt.Axes,
        max_ticks:Union[int, Any] = 5,
        dtype = int,
        weight = "bold",
        fontsize:Union[int, float]=12,
        max_xtick_val=None,
        min_xtick_val=None,
):
    """
    Place evenly spaced, formatted tick labels on the x-axis of ``ax``.

    Thin wrapper around ``set_ticklabels`` with ``which="x"``.

    :param ax:
        axes whose x ticks are modified
    :param max_ticks:
        number of ticks to place; if falsy (e.g. None or 0), the existing
        tick positions are kept
    :param dtype:
        type the tick values are cast to before being used as positions/labels
    :param weight:
        font weight of the tick labels
    :param fontsize:
        font size of the tick labels
    :param max_xtick_val:
        maximum value of the ticks; when None, the largest current tick is used
    :param min_xtick_val:
        minimum value of the ticks; when None, the smallest current tick is used
    :return:
        the result of ``set_ticklabels``
    """
    return set_ticklabels(
        ax,
        which="x",
        max_ticks=max_ticks,
        dtype=dtype,
        weight=weight,
        fontsize=fontsize,
        max_tick_val=max_xtick_val,
        min_tick_val=min_xtick_val,
    )


def set_yticklabels(
        ax:plt.Axes,
        max_ticks:Union[int, Any] = 5,
        dtype=int,
        weight="bold",
        fontsize:int=12,
        max_ytick_val = None,
        min_ytick_val = None
):
    """
    Place evenly spaced, formatted tick labels on the y-axis of ``ax``.

    Thin wrapper around ``set_ticklabels`` with ``which="y"``; the parameters
    mirror ``set_xticklabels``.

    :return:
        the result of ``set_ticklabels``
    """
    tick_kwargs = dict(
        max_ticks=max_ticks,
        dtype=dtype,
        weight=weight,
        fontsize=fontsize,
        max_tick_val=max_ytick_val,
        min_tick_val=min_ytick_val,
    )
    return set_ticklabels(ax, "y", **tick_kwargs)


def set_ticklabels(
        ax:plt.Axes,
        which:str = "x",
        max_ticks:int = 5,
        dtype=int,
        weight="bold",
        fontsize:int=12,
        max_tick_val = None,
        min_tick_val = None,
):
    """
    Replace the ticks of the ``which`` axis with evenly spaced, formatted ones.

    :param ax:
        axes to modify
    :param which:
        ``"x"`` or ``"y"``
    :param max_ticks:
        number of ticks spread with ``np.linspace`` between the lower and upper
        bound; if falsy, the current tick positions are kept
    :param dtype:
        type the tick values are cast to (the default ``int`` truncates)
    :param weight:
        font weight of the tick labels
    :param fontsize:
        font size of the tick labels
    :param max_tick_val:
        upper bound of the ticks; None uses the largest current tick
    :param min_tick_val:
        lower bound of the ticks; None uses the smallest current tick
    :return:
        ``ax``, or None (with a warning) if the axis has no ticks
    """
    ticks_ = getattr(ax, f"get_{which}ticks")()
    ticks = np.array(ticks_)
    if len(ticks)<1:
        warnings.warn(f"can not get {which}ticks {ticks_}")
        return

    if max_ticks:
        # compare against None (instead of using ``or``) so that an explicit
        # bound of 0 is honoured rather than replaced by the data extent
        lower = min(ticks) if min_tick_val is None else min_tick_val
        upper = max(ticks) if max_tick_val is None else max_tick_val
        ticks = np.linspace(lower, upper, max_ticks)

    ticks = ticks.astype(dtype)

    # fix the positions first so the labels below stay attached to them
    getattr(ax, f"set_{which}ticks")(ticks)

    getattr(ax, f"set_{which}ticklabels")(ticks, weight=weight, fontsize=fontsize)
    return ax

# %%

def residual_plot(
        train_true,
        train_prediction,
        test_true,
        test_prediction,
        label='',
        train_color = "#fe8977",
        test_color = "#9acad4",
        show:bool = False,
        axis = None,
)->np.ndarray:
    """
    Plot residuals (true - predicted) against predictions for the training and
    test sets, with a horizontal histogram of the residuals next to it.

    :param train_true, train_prediction:
        observed and predicted values of the training set
    :param test_true, test_prediction:
        observed and predicted values of the test set
    :param label:
        quantity name used in the x label ("Predicted {label}")
    :param train_color, test_color:
        marker / histogram colors of the two sets
    :param show:
        call ``plt.show`` at the end
    :param axis:
        pair of axes: ``axis[0]`` for the scatter, ``axis[1]`` for the
        histogram. If None, a new figure with two axes (width ratio 2:1,
        shared y) is created.
    :return:
        the axes pair used for drawing
    """
    train_true = to_1d_array(train_true)
    train_prediction = to_1d_array(train_prediction)
    test_true = to_1d_array(test_true)
    test_prediction = to_1d_array(test_prediction)

    if axis is None:
        _, axis = plt.subplots(1, 2, sharey="all",
                               gridspec_kw={'width_ratios': [2, 1]})

    test_y = test_true.reshape(-1, ) - test_prediction.reshape(-1, )
    train_y = train_true.reshape(-1, ) - train_prediction.reshape(-1, )
    train_hist_kws = dict(bins=20, linewidth=0.7,
                          edgecolor="k", grid=False, color=train_color,
                          orientation='horizontal')
    hist(train_y, show=False, ax=axis[1],
         label="Training", **train_hist_kws)
    plot(train_prediction, train_y, 'o', show=False,
         ax=axis[0],
         color=train_color,
         markerfacecolor=train_color,
         markeredgecolor="black", markeredgewidth=0.7,
         alpha=0.9, label="Training"
         )

    _hist_kws = dict(bins=40, linewidth=0.7,
                     edgecolor="k", grid=False,
                     color=test_color,
                     orientation='horizontal')
    hist(test_y, show=False, ax=axis[1],
         **_hist_kws)

    set_xticklabels(axis[1], 3)

    plot(test_prediction, test_y, 'o', show=False,
         ax=axis[0],
         color=test_color,
         markerfacecolor=test_color,
         markeredgecolor="black", markeredgewidth=0.7,
         ax_kws=dict(
             legend_kws=dict(loc="upper left"),
         ),
         alpha=0.9, label="Test",
         )

    axis[0].set_xlabel(f'Predicted {label}', fontsize=14)
    axis[0].set_ylabel('Residual', fontsize=14)
    set_yticklabels(axis[0], 5)
    axis[0].axhline(0.0, color="black", ls="--")
    # adjust the figure these axes belong to, which need not be the current one
    axis[0].figure.subplots_adjust(wspace=0.15)

    if show:
        plt.show()
    return axis

# %%

def set_rcParams(kwargs:dict = None):
    """
    Apply the bold, Times-New-Roman based matplotlib style used for the figures.

    :param kwargs:
        optional ``rcParams`` entries that override the defaults set here
    """
    settings = {
        'axes.labelsize': '20',
        'axes.labelweight': 'bold',
        'xtick.labelsize': '18',
        'ytick.labelsize': '18',
        'font.weight': 'bold',
        'legend.title_fontsize': '12',
        'axes.titleweight': 'bold',
        'axes.titlesize': '22',
    }

    # On Linux the font is selected through the serif family, with
    # Times New Roman prepended to the existing serif fallback list.
    if sys.platform != "linux":
        settings['font.family'] = "Times New Roman"
    else:
        settings['font.family'] = 'serif'
        settings['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']

    settings.update(kwargs or {})

    for key, value in settings.items():
        plt.rcParams[key] = value

# %%


def bar_pie(
        data:pd.DataFrame,
        ax:plt.Axes=None,
        save:bool = True,
        name:str = '',
        show:bool = True,
        pie_pos = 5,
):
    """
    Horizontal bar chart of mean |SHAP| per feature with an inset pie chart
    showing each class' share of the total.

    :param data:
        DataFrame with the columns ``features``, ``mean_shap``, ``classes``
        and ``colors``; each row is one bar. ``classes`` groups the features
        for the legend and the pie, ``colors`` is the color of each bar.
    :param ax:
        axes for the bar chart; a new 9x9 figure is created if None
    :param save:
        save the figure to ``figures/shap_bar_{name}.png`` (also done when the
        module-level ``SAVE`` flag is True)
    :param name:
        suffix of the saved file name
    :param show:
        call ``tight_layout`` and ``plt.show`` at the end
    :param pie_pos:
        ``loc`` of the inset pie axes, passed to ``inset_axes``
    :return:
        ``(ax, pie1_out, pie2_out)``: the bar-chart axes and the return values
        of the outer-ring and inner ``pie`` calls
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(9, 9))

    sv_bar = data['mean_shap'].copy()
    feature_names = data['features'].tolist()

    # One (class, color) pair per class, in order of first appearance. Taking
    # both from the same rows keeps legend/pie colors aligned with their class
    # even if two classes share a color.
    first_rows = data.drop_duplicates(subset='classes')
    labels = first_rows['classes'].tolist()
    colors = first_rows['colors'].tolist()

    ax.spines[['top', 'right']].set_visible(False)

    ax_ = bar_chart(
        sv_bar,
        [LABEL_MAP[n] if n in LABEL_MAP else n for n in feature_names],
        bar_labels=np.round(sv_bar, 2),
        bar_label_kws={'label_type': 'edge',
                       'fontsize': 10,
                       'weight': 'bold',
                       "fmt": '%.4f',
                       'padding': 1.5
                       },
        show=False,
        sort=True,
        color=data['colors'].to_list(),
        ax=ax
    )
    ax_.spines[['top', 'right']].set_visible(True)
    ax_.set_xlabel(xlabel='mean(|SHAP value|)')
    ax_.set_yticklabels(ax_.get_yticklabels())

    handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in colors]
    ax_.legend(handles, labels, loc='lower right', facecolor="white")
    ax_.xaxis.set_major_locator(plt.MaxNLocator(4))
    ax_.set_facecolor('white')

    seg_colors = tuple(colors)
    # Inner pie uses the ring colors with saturation reduced to 95%.
    rgb = mcolors.to_rgba_array(seg_colors)[:, :-1]
    hsv = mcolors.rgb_to_hsv(rgb)
    hsv[:, 1] = 0.95 * hsv[:, 1]
    interior_colors = mcolors.hsv_to_rgb(hsv)

    # share of the summed mean |SHAP| per class; float dtype so that the
    # in-place division also works for integer inputs
    fractions = np.asarray(
        [data.loc[data['classes'] == label, 'mean_shap'].sum() for label in labels],
        dtype=float,
    )
    fractions /= fractions.sum()

    for label, fraction in zip(labels, fractions):
        print(label, fraction)

    ax2 = inset_axes(ax, width='45%', height='45%',
                     loc=pie_pos)

    # outer circle/ring
    pie1_out = pie(fractions=fractions,
                   colors=seg_colors,
                   wedgeprops=dict(edgecolor="w", width=0.03),
                   radius=1,
                   autopct=None,
                   textprops=dict(fontsize=12),
                   startangle=90,
                   counterclock=False,
                   show=False,
                   ax=ax2)

    # inner pie
    pie2_out = pie(fractions=fractions,
                   colors=interior_colors,
                   autopct='%1.0f%%',
                   textprops=dict(fontsize=24),
                   wedgeprops=dict(edgecolor="w"),
                   pctdistance = 1.3,
                   radius=1 - 2 * 0.03,
                   startangle=90,
                   counterclock=False,
                   ax=ax2,
                   show=False
                   )

    # act on the figure that owns ``ax``, not whichever figure is current
    fig = ax.figure
    if save or SAVE:
        fig.savefig(f"figures/shap_bar_{name}.png",
                    dpi=600,
                    bbox_inches="tight")

    if show:
        fig.tight_layout()
        plt.show()

    return ax, pie1_out, pie2_out
