"""
=========================
Exploratory Data Analysis
=========================
"""
import os
import pandas as pd

import matplotlib.pyplot as plt

from easy_mpl import imshow
from easy_mpl.utils import create_subplots

from mne.viz import circular_layout
from mne_connectivity.viz import plot_connectivity_circle

from utils import SAVE
from utils import LABEL_MAP
from utils import set_rcParams
from utils import distribution_plot
from utils import pie_from_series
from utils import merge_uniques
from utils import prepare_data
from utils import print_version_info

# %%

# Print the version info of the packages being used
print_version_info()

# %%

set_rcParams()

# %%
# Load the raw workbook. The path is relative to the current working
# directory and is built from its components so the separator is portable.
fpath = os.path.join("..", "data", "data.xlsx")
df = pd.read_excel(fpath)

# %%

# First five records of the raw table
df.head(5)

# %%
# Last five records of the raw table
df.tail(5)

# %%
# Table dimensions as (rows, columns)
df.shape

# %%
# Column labels
df.columns

# %%
# Column dtypes, non-null counts and memory usage
df.info()

# %%
# Descriptive statistics of the numeric columns
df.describe()

# %%
# Count of missing entries in each column
df.isna().sum()

# %%
# Number of rows that repeat an earlier row
df.duplicated().sum()

# %%
# numerical columns
num_columns = [
    'Surface area', 'Pore volume', 'BandGap (eV)',
    'Au', 'Bi', 'Fe', 'O',
    'Catalyst loading (g/L)', 'Light intensity (W)', 'time (min)',
    'solution pH', 'Ci (mg/L)', 'Cf (mg/L)', 'Efficiency (%)',
]

# %%
# Summary statistics (pandas describe) of each numerical column
for col in num_columns:
    print(f"{col}: {df[col].describe()}")


# Independent copy of just the numerical columns
data_num = df[num_columns].copy()
# %%

# Distribution of each numerical column except the two listed below.
excluded_cols = ('Cf (mg/L)', 'Efficiency (%)')
plot_cols = [c for c in data_num.columns if c not in excluded_cols]

# Filter the columns *before* pairing them with axes, so an excluded column
# can never consume (and leave empty) a subplot, wherever it sits in the list.
fig, axes = create_subplots(len(plot_cols), figsize=(9, 8))
for ax, col in zip(axes.flat, plot_cols):
    distribution_plot(ax=ax, data=data_num[col],
                      box_facecolor='#dcae80',
                      scatter_fc='#1b1b1c',
                      ridge_lc='#1b1b1c',
                      )
    ax.set_xlabel(xlabel=LABEL_MAP.get(col, col), weight='bold', fontsize=14)
    # hide the y tick labels
    ax.set_yticklabels([])
plt.tight_layout()

# plt.savefig("../manuscript/figures/fig2.png", dpi=600,
#             bbox_inches="tight")
plt.show()

# %%

# Distribution of every column in data_num, one subplot per column.
fig, axes = create_subplots(data_num.shape[1], figsize=(9, 8))
for ax, col in zip(axes.flat, data_num.columns):
    distribution_plot(
        ax=ax,
        data=data_num[col],
        box_facecolor='#dcae80',
        scatter_fc='#1b1b1c',
        ridge_lc='#1b1b1c',
    )
    ax.set_xlabel(xlabel=LABEL_MAP.get(col, col), weight='bold', fontsize=14)
    # hide the y tick labels
    ax.set_yticklabels([])
plt.tight_layout()

plt.show()

# %%
# Categorical Data
# ===================

# %%
# categorical columns
cat_columns = ['Catalyst type', 'Anions']

# %%
# Display the unique values of the categorical columns
for col in cat_columns:
    print(f"{col}: {df[col].unique()}")

# %%
# Display the value counts of the categorical columns
for col in cat_columns:
    print(f"{col}: {df[col].value_counts()}")


# Pie chart of catalyst types. The second argument of merge_uniques is
# presumably how many categories are kept before the rest are merged —
# confirm in utils.
merged_series = merge_uniques(df['Catalyst type'], 7)
pie_from_series(merged_series, cmap="coolwarm", show=False, leg_pos=(0.85, 0.7))
# show=False leaves the figure undisplayed; show it here so it does not
# linger and pop up at a later plt.show() call.
plt.show()

# %%

merged_series = merge_uniques(df['Anions'], 5)
pie_from_series(merged_series, cmap="coolwarm", show=False, leg_pos=(0.85, 0.7))
plt.show()

# %%
# Correlation
# ===========

df, cat_enc, an_enc = prepare_data(encoding="ohe", exclude_cf=False)
# One-hot columns are named "<feature>_<index>"; replace them with the
# category labels held by the fitted encoders so the plots are readable.
# (Assumes the index suffix follows the order of categories_[0] — confirm
# in utils.prepare_data.)
ohe_names = {
    f"Catalyst type_{idx}": category
    for idx, category in enumerate(cat_enc.categories_[0])
}
ohe_names.update({
    f"Anions_{idx}": category
    for idx, category in enumerate(an_enc.categories_[0])
})
df = df.rename(columns=ohe_names)
# %%

corr = df.corr(method="pearson")

# %%

imshow(corr, colorbar=True, show=False)
plt.tight_layout()
plt.show()

# %%
# Correlations pandas cannot compute (for example for a zero-variance
# column) come out as NaN. Treat them as "no connection" so the chord
# diagram below only receives finite values.

corr = corr.fillna(0.0)

# %%

# Place every variable on the circle, split into two groups at the midpoint.
node_names = corr.columns.tolist()
node_angles = circular_layout(node_names, node_names, start_pos=90,
                              group_boundaries=[0, len(node_names) // 2])

print(node_angles.shape)

# %%

fig, ax = plt.subplots(figsize=(16, 16),
                       facecolor="#EFE9E6",
                       subplot_kw=dict(polar=True))
# show=False so the layout is finalised before the figure is displayed.
fig, ax = plot_connectivity_circle(
    corr.values,
    node_names=node_names,
    node_angles=node_angles,
    fontsize_names=14,
    fontsize_colorbar=14,
    facecolor="#EFE9E6",
    textcolor='black',
    node_edgecolor="white",
    colormap="Spectral",
    colorbar_size=0.5,
    colorbar_pos=(-0.5, 0.5),
    ax=ax,
    show=False)

# fig.savefig("../manuscript/figures/fig3.png", dpi=600, bbox_inches="tight")
fig.tight_layout()
plt.show()

# %%

df, _, _ = prepare_data(encoding="le", exclude_cf=False)

# %%

corr = df.corr(method="pearson")

# %%

imshow(corr, colorbar=True, show=False)
plt.tight_layout()
plt.show()

# %%
# Correlations pandas cannot compute (for example for a zero-variance
# column) come out as NaN. Treat them as "no connection" so the chord
# diagram below only receives finite values.

corr = corr.fillna(0.0)

# %%

# Place every variable on the circle, split into two groups at the midpoint.
node_names = corr.columns.tolist()
node_angles = circular_layout(node_names, node_names, start_pos=90,
                              group_boundaries=[0, len(node_names) // 2])

print(node_angles.shape)

# %%

fig, ax = plt.subplots(figsize=(16, 16),
                       facecolor="#EFE9E6",
                       subplot_kw=dict(polar=True))
# show=False so the layout is finalised before the figure is displayed.
fig, ax = plot_connectivity_circle(
    corr.values,
    node_names=node_names,
    node_angles=node_angles,
    fontsize_names=14,
    fontsize_colorbar=14,
    facecolor="#EFE9E6",
    textcolor='black',
    node_edgecolor="white",
    colormap="Spectral",
    colorbar_size=0.5,
    colorbar_pos=(-0.5, 0.5),
    ax=ax,
    show=False)

# fig.savefig("figures/chord_large_le", dpi=600, bbox_inches="tight")
fig.tight_layout()
plt.show()

# %%

# Re-read the raw workbook; df was replaced by prepare_data output earlier.
df_org = pd.read_excel(fpath)
print(df_org.shape)

# Drop the two categorical columns before correlating.
df_org = df_org.drop(columns=['Catalyst type', 'Anions'])

corr = df_org.corr(method="pearson")

imshow(corr, colorbar=True, show=False)
plt.tight_layout()
plt.show()

# %%
# Correlations pandas cannot compute (for example for a zero-variance
# column) come out as NaN. Treat them as "no connection" so the chord
# diagram below only receives finite values.

corr = corr.fillna(0.0)

# %%

# Place every variable on the circle, split into two groups at the midpoint.
node_names = corr.columns.tolist()
node_angles = circular_layout(node_names, node_names, start_pos=90,
                              group_boundaries=[0, len(node_names) // 2])

print(node_angles.shape)

# %%

fig, ax = plt.subplots(figsize=(16, 16),
                       facecolor="#EFE9E6",
                       subplot_kw=dict(polar=True))
# show=False so the layout is finalised before saving and displaying.
fig, ax = plot_connectivity_circle(
    corr.values,
    node_names=node_names,
    node_angles=node_angles,
    fontsize_names=14,
    fontsize_colorbar=14,
    facecolor="#EFE9E6",
    textcolor='black',
    node_edgecolor="white",
    colormap="Spectral",
    colorbar_size=0.5,
    colorbar_pos=(-0.5, 0.5),
    ax=ax,
    show=False)

# Lay out first so the saved file matches the displayed figure.
fig.tight_layout()
if SAVE:
    # savefig does not create missing directories
    os.makedirs("figures", exist_ok=True)
    fig.savefig("figures/chord_large_org", dpi=600, bbox_inches="tight")
plt.show()
