"""
===============
Experiments
===============
"""

import os

import matplotlib.pyplot as plt

from ai4water.experiments import MLRegressionExperiments

from utils import prepare_data, SAVE
from utils import set_rcParams
from utils import print_version_info

# %%
print_version_info()  # utils helper; presumably reports the library versions used for this run

# %%

set_rcParams()  # utils helper that configures matplotlib rcParams (defaults defined in utils)
# %%

data, _, _ = prepare_data()  # prepare_data returns three values; only the first (tabular data) is used here

# %%

_columns = data.columns.tolist()  # all columns but the last are inputs; the last is the target
comparisons = MLRegressionExperiments(
    input_features=_columns[:-1],
    output_features=_columns[-1:],
    split_random=True,
    seed=1575,  # fixed seed for the experiment setup
    verbosity=0,
    show=False,
)

# %%

models_to_compare = [  # regressors evaluated side by side, in this order
    'XGBRegressor',
    'AdaBoostRegressor',
    'LinearSVR',
    'BaggingRegressor',
    'DecisionTreeRegressor',
    'HistGradientBoostingRegressor',
    'ExtraTreesRegressor',
    'ExtraTreeRegressor',
    'LinearRegression',
    'KNeighborsRegressor',
    'RandomForestRegressor',
    'SGDRegressor',
    'SVR',
    'LassoCV',
    'RidgeCV',
]
comparisons.fit(data=data, run_type="dry_run", include=models_to_compare)

# %%

set_rcParams({'xtick.labelsize': '14', 'ytick.labelsize': '14'})  # larger tick labels for this figure
_ = comparisons.compare_errors(
    'r2',
    data=data,
    colors=('#063970', '#e28743'),
    figsize=(8, 8),
)
fig = plt.gcf()  # the figure compare_errors just drew on
# Panel labels, positioned in figure coordinates along the top edge.
for label_x, label_text in ((0.35, 'a)'), (0.7, 'b)')):
    fig.text(label_x, 0.95, label_text, ha='center', fontsize=14)
plt.tight_layout()
# To export this figure (manuscript Fig. S1), run:
# plt.savefig("../manuscript/figures/figS1.png", dpi=600, bbox_inches="tight")
plt.show()

# %%

def _compare_with_cutoff(metric, cutoff_type, cutoff_val):
    """Plot ``metric`` for all fitted models, filtered by a cutoff, and show it.

    Parameters
    ----------
    metric : str
        Name of the error metric passed to ``compare_errors``.
    cutoff_type : str
        Cutoff direction passed through unchanged (``'greater'`` / ``'less'``).
    cutoff_val : float
        Threshold value passed through unchanged.

    Returns
    -------
    Whatever ``comparisons.compare_errors`` returns.
    """
    outcome = comparisons.compare_errors(
        metric,
        data=data,
        figsize=(8, 10),
        cutoff_type=cutoff_type,
        cutoff_val=cutoff_val,
        colors=('#063970', '#e28743'),
    )
    plt.tight_layout()
    plt.show()
    return outcome


r2_score = _compare_with_cutoff('r2_score', 'greater', 0.5)
# %%

rmse = _compare_with_cutoff('rmse', 'less', 1000)

# %%

figure = comparisons.taylor_plot(
    data=data,
    figsize=(8, 8),
    # presumably the models that passed the r2_score cutoff above
    # (index of the object returned by compare_errors) — confirm
    include=r2_score.index.tolist(),
    leg_kws={'facecolor': 'white',
             'edgecolor': 'black', 'bbox_to_anchor': (1.0, 0.9),
             'fontsize': 10, 'labelspacing': 1.0, 'ncol': 1,
             },
)
figure.axes[0].axis['left'].label.set_text('')  # blank out the left-axis label of the first panel
figure.axes[0].set_title('Training')
# Apply the layout *before* saving so the exported PNG matches the displayed
# figure; saving first would write the un-adjusted layout.
plt.tight_layout()
if SAVE:
    # savefig does not create missing directories.
    os.makedirs("results/figures", exist_ok=True)
    plt.savefig("results/figures/exp_taylor.png", dpi=600, bbox_inches="tight")
plt.show()
