Note
Go to the end to download the full example code.
utility functions
import os
import sys
import time
import warnings
from typing import Union, Any
from multiprocessing import cpu_count
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from easy_mpl import pie
from easy_mpl import hist
from easy_mpl import plot
from easy_mpl import regplot
from easy_mpl import bar_chart
from easy_mpl.utils import to_1d_array
from easy_mpl.utils import despine_axes
from easy_mpl.utils import AddMarginalPlots
from easy_mpl.utils import make_cols_from_cmap
import matplotlib.colors as mcolors
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import seaborn as sns
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import LabelEncoder
from SeqMetrics import RegressionMetrics
from ai4water.utils.utils import get_version_info
# Module-wide switch: when True, ``bar_pie`` saves its figure even if it was
# called with ``save=False`` (it checks ``save or SAVE``).
SAVE = False

# Maps raw feature/column names to the human readable labels used on plots.
# Names that are missing from this mapping are shown unchanged by ``bar_pie``.
LABEL_MAP = {
    'time (min)': "Time (min)",
    'qe': "Removal Efficiency (%)",
    'O': "Oxygen (%)",
    'Au': "Gold (%)",
    'Fe': "Iron (%)",
    'Bi': "Bismuth (%)",
    'solution pH': 'Solution pH',
    'Surface area': 'Surf. Area (m2/g)',
}

# Make sure the output folder for saved figures exists. ``exist_ok=True``
# replaces the check-then-create pattern, which can fail if another process
# creates the folder between the check and the call to ``makedirs``.
os.makedirs('figures', exist_ok=True)
def hardware_info() -> dict:
    """
    Collect basic information about the machine running the script.

    Returns
    -------
    dict
        a dictionary with the following keys

        - ``tot_cpus``: number of CPUs reported by ``multiprocessing.cpu_count``
        - ``avail_cpus``: number of CPUs the current process is allowed to
          run on (``os.sched_getaffinity``). On platforms which do not
          provide ``os.sched_getaffinity`` this falls back to ``tot_cpus``.
        - ``mem_gib``: physical memory in GiB computed from ``os.sysconf``,
          or ``None`` when it can not be queried on this platform.
    """
    tot_cpus = cpu_count()

    # os.sched_getaffinity is not available on every platform
    if hasattr(os, 'sched_getaffinity'):
        avail_cpus = len(os.sched_getaffinity(0))
    else:
        avail_cpus = tot_cpus

    try:
        # page size (bytes) * number of physical pages, e.g. 4015976448
        mem_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
    except (AttributeError, ValueError, OSError):
        # os.sysconf is missing or does not know these names on this platform
        mem_gib = None
    else:
        mem_gib = mem_bytes / (1024. ** 3)  # e.g. 3.74

    return dict(
        tot_cpus=tot_cpus,
        avail_cpus=avail_cpus,
        mem_gib=mem_gib,
    )
def print_version_info(
        include_run_time: bool = True,
        include_hardware_info: bool = True
):
    """
    Print the versions of the packages in use, one ``key value`` pair per line.

    :param include_run_time:
        if True, the current time (``time.asctime``) is printed as well
    :param include_hardware_info:
        if True, the output of ``hardware_info`` is printed as well
    """
    report = version_info()

    if include_run_time:
        report['Script Executed on: '] = time.asctime()

    if include_hardware_info:
        report.update(hardware_info())

    for entry, value in report.items():
        print(entry, value)
def version_info() -> dict:
    """
    Return the version info of the packages being used.

    The dictionary comes from ``get_version_info``; the version of
    ``bnlearn`` is added to it when that package can be imported.
    """
    versions = get_version_info()

    try:
        import bnlearn
        versions['bnlearn'] = bnlearn.__version__
    except ImportError:
        # bnlearn is optional (ModuleNotFoundError is a subclass of ImportError)
        pass

    return versions
def distribution_plot(ax,
                      data,
                      scatter_fc='#045568',
                      box_facecolor='#e6e6e6',
                      ridge_lc = '#045568',
                      width=0.8,
                      add_hist=True,
                      add_ridge=True):
    """
    Draw a horizontal box plot of ``data`` with a jittered strip plot on top
    of it and, optionally, a marginal plot created with ``AddMarginalPlots``.

    :param ax:
        axes on which the box plot and the strip plot are drawn
    :param data:
        values to plot. When ``add_hist`` or ``add_ridge`` is True,
        ``data.values`` is accessed, so presumably a pandas Series/DataFrame
        is expected -- TODO confirm against callers.
    :param scatter_fc:
        color of the points of the strip plot
    :param box_facecolor:
        face color of the box; also used as the color of the marginal
        histogram and of the fill (``fill_kws``)
    :param ridge_lc:
        color passed as ``ridge_line_kws``
    :param width:
        width of the box
    :param add_hist:
        forwarded as ``hist`` to ``AddMarginalPlots``
    :param add_ridge:
        forwarded as ``ridge`` to ``AddMarginalPlots``
    :return:
        the axes returned by ``sns.stripplot``
    """
    sns.boxplot(orient='h', data=data, saturation=1, showfliers=False,
                width=width, boxprops={'zorder': 3, 'facecolor': box_facecolor}, ax=ax,
                )
    # NOTE(review): the slice below starts at the current number of
    # collections, so it is always empty here and the loop never shifts
    # anything. It looks like the loop was meant to run after
    # ``sns.stripplot`` in order to offset the points -- confirm the intent
    # before moving or removing it, since doing so changes the figure.
    old_len_collections = len(ax.collections)
    for dots in ax.collections[old_len_collections:]:
        dots.set_offsets(dots.get_offsets() + np.array([0, 0.12]))
    ax = sns.stripplot(orient='h', x=data,
                       edgecolor="gray",
                       #linewidth=0.1,
                       alpha=0.5,
                       c=scatter_fc,
                       size=1.5,
                       ax=ax,
                       jitter=0.2)
    # hide the top spine only
    despine_axes(ax, keep=['bottom', 'left', 'right'])
    if add_hist or add_ridge:
        aa = AddMarginalPlots(ax=ax,
                              hist=add_hist,
                              ridge=add_ridge,
                              hist_kws=dict(bins=20, color=box_facecolor),
                              fill_kws=dict(color=box_facecolor),
                              ridge_line_kws=dict(color=ridge_lc))
        # the divider is set by hand because ``add_ax_marg_x`` is called
        # directly instead of calling the ``AddMarginalPlots`` object
        aa.divider = make_axes_locatable(ax)
        axHistx = aa.add_ax_marg_x(data.values.reshape(-1,), hist_kws=aa.HIST_KWS[0], ax=None)
        # the marginal axes does not show its own x tick labels
        plt.setp(axHistx.get_xticklabels(),visible=False)
    return ax
def merge_uniques(
        series: pd.Series,
        n_to_keep: int = 5,
        replace_with="Rest"
):
    """
    Keep the ``n_to_keep`` most frequent values of ``series`` and replace
    every other value with ``replace_with``.

    :param series:
        the series whose less frequent values are to be merged
    :param n_to_keep:
        number of values (in the order given by ``value_counts``) to keep
    :param replace_with:
        the value which replaces all the remaining values
    :return:
        a new series as returned by ``series.replace``
    """
    ranked_values = series.value_counts().index
    to_merge = [
        value for rank, value in enumerate(ranked_values) if rank >= n_to_keep
    ]
    return series.replace(to_merge, replace_with)
def pie_from_series(
        data: pd.Series,
        cmap="tab20",
        label_percent: bool = True,
        n_to_merge: int = None,
        leg_pos=None,
        show=True,
        fontsize=14
):
    """
    Draw a pie chart of the value counts of ``data`` with a legend.

    :param data:
        series whose unique values become the wedges of the pie
    :param cmap:
        name of the colormap from which the wedge colors are taken
    :param label_percent:
        if True, legend entries read ``value: xx.xx %`` otherwise
        ``value (n=count)``
    :param n_to_merge:
        currently not used inside this function
    :param leg_pos:
        ``bbox_to_anchor`` of the legend, defaults to ``(1.1, 1.)``
    :param show:
        whether to call ``plt.tight_layout`` and ``plt.show``
    :param fontsize:
        font size of the legend
    """
    counts: pd.Series = data.value_counts()
    categories = counts.index.tolist()
    n_per_category = counts.values

    wedge_colors = make_cols_from_cmap(cm=cmap, num_cols=len(n_per_category))
    shares = 100. * n_per_category / n_per_category.sum()

    wedges, _texts = pie(fractions=shares, autopct=None,
                         colors=wedge_colors, show=False)

    if label_percent:
        template, numbers = '{0}: {1:1.2f} %', shares
    else:
        template, numbers = '{0} (n={1:4})', n_per_category
    legend_labels = [
        template.format(category, number)
        for category, number in zip(categories, numbers)
    ]

    # legend entries are ordered from the largest to the smallest count
    ranked = sorted(zip(wedges, legend_labels, n_per_category),
                    key=lambda entry: entry[2],
                    reverse=True)
    wedges, legend_labels, _ = zip(*ranked)

    anchor = leg_pos or (1.1, 1.)
    plt.legend(wedges, legend_labels, bbox_to_anchor=anchor,
               fontsize=fontsize)

    if show:
        plt.tight_layout()
        plt.show()
    return
def _ohe_column(df: pd.DataFrame, col_name: str) -> tuple:
    """
    One-hot encode the column ``col_name`` of ``df`` in place.

    The encoded columns are added as ``{col_name}_0``, ``{col_name}_1``, ...
    and the original column is removed from ``df``.

    :return:
        a tuple of the (modified) dataframe, the names of the added columns
        and the fitted ``OneHotEncoder``
    """
    assert isinstance(col_name, str)

    encoder = OneHotEncoder(sparse_output=False)
    column_2d = df[col_name].values.reshape(-1, 1)
    encoded = encoder.fit_transform(column_2d)

    n_new = encoded.shape[-1]
    new_names = [f"{col_name}_{k}" for k in range(n_new)]
    df[new_names] = encoded

    # the raw categorical column is no longer needed
    del df[col_name]
    return df, new_names, encoder
def prepare_data(
        encoding: str = 'le',
        exclude_cf: bool = True,
) -> tuple:
    """
    Read the data from ``../data/data.xlsx`` and encode its categorical
    columns ``Catalyst type`` and ``Anions``.

    :param encoding:
        ``"ohe"`` for one-hot encoding; any other value results in
        label encoding
    :param exclude_cf:
        if True, the column ``Cf (mg/L)`` is removed from the data
    :return:
        a tuple of the dataframe (with the target ``Efficiency (%)`` as last
        column), the encoder of ``Catalyst type`` and the encoder of ``Anions``
    """
    df = pd.read_excel("../data/data.xlsx")

    # shorter names for the catalysts, applied one after the other
    catalyst_renames = (
        ('commercial TiO2', 'TiO2'),
        ('0.25 wt% Au-BFO', '0.25 wt%'),
        ('0.5 wt% Au-BFO', '0.5 wt%'),
        ('1 wt% Au-BFO', '1 wt%'),
        ('2 wt% Au-BFO', '2 wt%'),
        ('no catalyst', 'None'),
        ('pure BFO', 'Pure BFO'),
    )
    for old_name, new_name in catalyst_renames:
        df['Catalyst type'] = df['Catalyst type'].replace(old_name, new_name)

    use_ohe = encoding == "ohe"

    no_anion_name = 'No Anions' if use_ohe else 'None'
    df['Anions'] = df['Anions'].replace('Without Anions', no_anion_name)

    if use_ohe:
        data, _, ct_encoder = _ohe_column(df, 'Catalyst type')
        data, _, anion_encoder = _ohe_column(data, 'Anions')
    else:
        data, ct_encoder = le_column(df, 'Catalyst type')
        data, anion_encoder = le_column(data, 'Anions')

    if exclude_cf:
        # Cf must not be used as input feature
        data.pop('Cf (mg/L)')

    # moving target to last
    data['Efficiency (%)'] = data.pop('Efficiency (%)')

    return data, ct_encoder, anion_encoder
def le_column(df: pd.DataFrame, col_name) -> tuple:
    """
    Label encode the column ``col_name`` of ``df`` in place.

    :return:
        a tuple of the (modified) dataframe and the fitted ``LabelEncoder``
    """
    label_encoder = LabelEncoder()
    encoded = label_encoder.fit_transform(df[col_name])
    df[col_name] = encoded
    return df, label_encoder
def evaluate_model(true, predicted):
    """
    Print a fixed set of regression metrics, one ``name value`` pair per line.

    :param true:
        passed as first argument to ``RegressionMetrics``
    :param predicted:
        passed as second argument to ``RegressionMetrics``
    """
    calculator = RegressionMetrics(true, predicted)
    metric_names = ('mse', 'rmse', 'r2', 'r2_score', 'mape', 'mae')
    for metric_name in metric_names:
        score = getattr(calculator, metric_name)()
        print(metric_name, score)
def regression_plot(
        train_true,
        train_pred,
        test_true,
        test_pred,
        label='qe',
        max_xtick_val=None,
        max_ytick_val=None,
        min_xtick_val=None,
        min_ytick_val=None,
        max_ticks=5,
        train_color='#9acad4',
        test_color="#fe8977",
        show=False,
        ax=None,
) -> plt.Axes:
    """
    Draw true vs. predicted values of training and test data on the same
    axes together with marginal histograms and the R2 of both sets.

    :param train_true: true values of training data
    :param train_pred: predicted values of training data
    :param test_true: true values of test data
    :param test_pred: predicted values of test data
    :param label:
        used in the axis labels ``Experimental {label}`` and
        ``Predicted {label}``
    :param max_xtick_val: forwarded to ``set_xticklabels``
    :param max_ytick_val: forwarded to ``set_yticklabels``
    :param min_xtick_val: forwarded to ``set_xticklabels``
    :param min_ytick_val: forwarded to ``set_yticklabels``
    :param max_ticks: forwarded to both ``set_xticklabels`` and ``set_yticklabels``
    :param train_color: color of markers and histograms of training data
    :param test_color: color of markers and histograms of test data
    :param show: whether to call ``plt.show`` at the end
    :param ax: axes handed over to ``regplot``
    :return: the axes returned by the ``regplot`` call for training data
    """

    def _marker_style() -> dict:
        # a fresh dictionary for every call so that nothing is shared
        # between the two scatter plots
        return {'edgecolors': 'black',
                'linewidth': 0.7,
                'alpha': 0.9,
                }

    def _ridge_style(color) -> list:
        return [{'color': color, 'lw': 1.0} for _ in range(2)]

    # ---- training data ----
    ax = regplot(
        train_true, train_pred,
        marker_size=35,
        marker_color=train_color,
        line_color='k',
        fill_color='k',
        scatter_kws=_marker_style(),
        label="Training",
        show=False,
        ax=ax,
    )

    train_marginals = AddMarginalPlots(
        ax,
        ridge=False,
        pad=0.25,
        size=0.7,
        ridge_line_kws=_ridge_style(train_color),
        hist_kws=[{'color': train_color, 'bins': 50} for _ in range(2)],
    )
    axHistx, axHisty = train_marginals(train_true, train_pred)

    # ---- R2 of both sets written inside the axes ----
    train_r2 = RegressionMetrics(train_true, train_pred).r2()
    test_r2 = RegressionMetrics(test_true, test_pred).r2()

    annotations = (
        (f'Training $R^2$= {round(train_r2, 2)}', 0.30),
        (f'Test $R^2$= {round(test_r2, 2)}', 0.20),
    )
    for text, y_position in annotations:
        ax.annotate(text,
                    xy=(0.95, y_position),
                    xycoords='axes fraction',
                    horizontalalignment='right',
                    verticalalignment='top',
                    fontsize=12, weight="bold")

    # ---- test data, drawn on the same axes without regression line ----
    ax_ = regplot(
        test_true, test_pred,
        marker_size=35,
        marker_color=test_color,
        line_style=None,
        scatter_kws=_marker_style(),
        show=False,
        label="Test",
        ax=ax,
    )
    ax_.legend(fontsize=12, prop=dict(weight="bold"))

    # the marginal axes of training data are reused for test data
    test_marginals = AddMarginalPlots(
        ax,
        ridge=False,
        pad=0.25,
        size=0.7,
        ridge_line_kws=_ridge_style(test_color),
        hist_kws={'color': test_color, 'bins': 50},
    )
    test_marginals(test_true, test_pred, axHistx, axHisty)

    # ---- ticks and labels ----
    set_xticklabels(
        ax_,
        max_xtick_val=max_xtick_val,
        min_xtick_val=min_xtick_val,
        max_ticks=max_ticks,
    )
    set_yticklabels(
        ax_,
        max_ytick_val=max_ytick_val,
        min_ytick_val=min_ytick_val,
        max_ticks=max_ticks,
    )

    ax.set_xlabel(f"Experimental {label}")
    ax.set_ylabel(f"Predicted {label}")

    if show:
        plt.show()

    return ax
def set_ticks(axes: plt.Axes, which="x", size=12):
    """
    Re-label the existing ticks of ``axes`` with bold labels.

    Float ticks are rounded to two decimals, all other ticks are converted
    to integers.

    :param axes: the axes whose tick labels are set
    :param which: ``"x"`` or ``"y"``
    :param size: font size of the tick labels
    """
    tick_values = np.array(getattr(axes, f"get_{which}ticks")())

    is_float = 'float' in tick_values.dtype.name
    tick_values = np.round(tick_values, 2) if is_float else tick_values.astype(int)

    apply_labels = getattr(axes, f"set_{which}ticklabels")
    apply_labels(tick_values, size=size, weight="bold")
    return
def set_xticklabels(
        ax: plt.Axes,
        max_ticks: Union[int, Any] = 5,
        dtype=int,
        weight="bold",
        fontsize: Union[int, float] = 12,
        max_xtick_val=None,
        min_xtick_val=None,
):
    """
    Set the ticks and tick labels of the x-axis of ``ax``.

    This is a thin wrapper around ``set_ticklabels`` with ``which="x"``.

    :param ax:
        the axes whose x ticks are modified
    :param max_ticks:
        maximum number of ticks, if not set, all the default ticks will be used
    :param dtype:
        the ticks are converted to this type
    :param weight:
        font weight of the tick labels
    :param fontsize:
        font size of the tick labels
    :param max_xtick_val:
        maximum value of tick
    :param min_xtick_val:
        minimum value of tick
    :return:
        whatever ``set_ticklabels`` returns
    """
    return set_ticklabels(
        ax=ax,
        which="x",
        max_ticks=max_ticks,
        dtype=dtype,
        weight=weight,
        fontsize=fontsize,
        max_tick_val=max_xtick_val,
        min_tick_val=min_xtick_val,
    )
def set_yticklabels(
        ax: plt.Axes,
        max_ticks: Union[int, Any] = 5,
        dtype=int,
        weight="bold",
        fontsize: int = 12,
        max_ytick_val=None,
        min_ytick_val=None
):
    """
    Set the ticks and tick labels of the y-axis of ``ax``.

    This is a thin wrapper around ``set_ticklabels`` with ``which="y"``.

    :param ax:
        the axes whose y ticks are modified
    :param max_ticks:
        maximum number of ticks, if not set, all the default ticks will be used
    :param dtype:
        the ticks are converted to this type
    :param weight:
        font weight of the tick labels
    :param fontsize:
        font size of the tick labels
    :param max_ytick_val:
        maximum value of tick
    :param min_ytick_val:
        minimum value of tick
    :return:
        whatever ``set_ticklabels`` returns
    """
    return set_ticklabels(
        ax=ax,
        which="y",
        max_ticks=max_ticks,
        dtype=dtype,
        weight=weight,
        fontsize=fontsize,
        max_tick_val=max_ytick_val,
        min_tick_val=min_ytick_val,
    )
def set_ticklabels(
        ax: plt.Axes,
        which: str = "x",
        max_ticks: int = 5,
        dtype=int,
        weight="bold",
        fontsize: int = 12,
        max_tick_val=None,
        min_tick_val=None,
):
    """
    Set the ticks and the tick labels of one axis of ``ax``.

    :param ax:
        the axes whose ticks are modified
    :param which:
        ``"x"`` or ``"y"``
    :param max_ticks:
        number of evenly spaced ticks to put between the lower and the upper
        limit. If it is ``None`` (or 0) the existing ticks are kept.
    :param dtype:
        the ticks are converted to this type before they are set
    :param weight:
        font weight of the tick labels
    :param fontsize:
        font size of the tick labels
    :param max_tick_val:
        value of the last tick. If ``None``, the largest existing tick is
        used. Only used when ``max_ticks`` is set.
    :param min_tick_val:
        value of the first tick. If ``None``, the smallest existing tick is
        used. Only used when ``max_ticks`` is set.
    :return:
        ``ax``, or ``None`` (after a warning) if ``ax`` has no ticks
    """
    existing_ticks = getattr(ax, f"get_{which}ticks")()
    ticks = np.array(existing_ticks)

    if len(ticks) < 1:
        warnings.warn(f"can not get {which}ticks {existing_ticks}")
        return None

    if max_ticks:
        # Compare with None explicitly. A limit of 0 is a valid value but it
        # is falsy, so ``min_tick_val or min(ticks)`` would silently ignore it.
        lower = min(ticks) if min_tick_val is None else min_tick_val
        upper = max(ticks) if max_tick_val is None else max_tick_val
        ticks = np.linspace(lower, upper, max_ticks)

    ticks = ticks.astype(dtype)

    # set the tick positions first so that the labels match the positions
    getattr(ax, f"set_{which}ticks")(ticks)
    getattr(ax, f"set_{which}ticklabels")(ticks, weight=weight, fontsize=fontsize)
    return ax
def residual_plot(
        train_true,
        train_prediction,
        test_true,
        test_prediction,
        label='',
        train_color="#fe8977",
        test_color="#9acad4",
        show: bool = False,
        axis=None,
) -> np.ndarray:
    """
    Plot the residuals (true - prediction) of training and test data.

    The residuals are drawn against the predictions on ``axis[0]`` and as
    horizontal histograms on ``axis[1]``.

    :param train_true: true values of training data
    :param train_prediction: predicted values of training data
    :param test_true: true values of test data
    :param test_prediction: predicted values of test data
    :param label: used in the x-axis label ``Predicted {label}``
    :param train_color: color used for training data
    :param test_color: color used for test data
    :param show: whether to call ``plt.show`` at the end
    :param axis:
        two axes which are accessed as ``axis[0]`` and ``axis[1]``. If not
        given, a new figure with two axes sharing the y-axis is created.
    :return: ``axis``
    """
    train_true, train_prediction, test_true, test_prediction = (
        to_1d_array(values)
        for values in (train_true, train_prediction, test_true, test_prediction)
    )

    if axis is None:
        _, axis = plt.subplots(1, 2, sharey="all",
                               gridspec_kw={'width_ratios': [2, 1]})

    scatter_ax, hist_ax = axis[0], axis[1]

    test_residuals = test_true.reshape(-1, ) - test_prediction.reshape(-1, )
    train_residuals = train_true.reshape(-1, ) - train_prediction.reshape(-1, )

    def _hist_style(color, bins) -> dict:
        return dict(bins=bins, linewidth=0.7,
                    edgecolor="k", grid=False, color=color,
                    orientation='horizontal')

    # ---- training data ----
    hist(train_residuals, show=False, ax=hist_ax,
         label="Training", **_hist_style(train_color, 20))

    plot(train_prediction, train_residuals, 'o', show=False,
         ax=scatter_ax,
         color=train_color,
         markerfacecolor=train_color,
         markeredgecolor="black", markeredgewidth=0.7,
         alpha=0.9, label="Training",
         )

    # ---- test data ----
    hist(test_residuals, show=False, ax=hist_ax,
         **_hist_style(test_color, 40))

    set_xticklabels(hist_ax, 3)

    plot(test_prediction, test_residuals, 'o', show=False,
         ax=scatter_ax,
         color=test_color,
         markerfacecolor=test_color,
         markeredgecolor="black", markeredgewidth=0.7,
         ax_kws=dict(
             legend_kws=dict(loc="upper left"),
         ),
         alpha=0.9, label="Test",
         )

    # ---- labels, ticks and the zero line ----
    scatter_ax.set_xlabel(f'Predicted {label}', fontsize=14)
    scatter_ax.set_ylabel('Residual', fontsize=14)

    set_yticklabels(scatter_ax, 5)

    scatter_ax.axhline(0.0, color="black", ls="--")
    plt.subplots_adjust(wspace=0.15)

    if show:
        plt.show()

    return axis
def set_rcParams(kwargs: dict = None):
    """
    Set the matplotlib rcParams used for the figures.

    :param kwargs:
        rcParams which are added to, or override, the defaults defined
        inside this function
    """
    defaults = {
        'axes.labelsize': '20',
        'axes.labelweight': 'bold',
        'xtick.labelsize': '18',
        'ytick.labelsize': '18',
        'font.weight': 'bold',
        'legend.title_fontsize': '12',
        'axes.titleweight': 'bold',
        'axes.titlesize': '22',
    }

    if sys.platform == "linux":
        # on linux, Times New Roman is put first in the list of serif fonts
        fonts = {
            'font.family': 'serif',
            'font.serif': ['Times New Roman'] + plt.rcParams['font.serif'],
        }
    else:
        fonts = {'font.family': "Times New Roman"}

    # user supplied values have the highest priority
    settings = {**defaults, **fonts}
    if kwargs:
        settings.update(kwargs)

    for param, value in settings.items():
        plt.rcParams[param] = value
    return
def bar_pie(
        data:pd.DataFrame,
        ax:plt.Axes=None,
        save:bool = True,
        name:str = '',
        show:bool = True,
        pie_pos = 5,
):
    """
    Draw a bar chart of mean absolute SHAP values with an inset pie chart
    showing the share of every class of features.

    :param data:
        dataframe with the columns ``mean_shap``, ``colors``, ``features``
        and ``classes``
    :param ax:
        axes for the bar chart. If not given, a new figure of size (9, 9)
        is created.
    :param save:
        the figure is saved as ``figures/shap_bar_{name}.png`` if either
        this argument or the module level ``SAVE`` is True
    :param name:
        used in the name of the saved file
    :param show:
        whether to call ``plt.tight_layout`` and ``plt.show``
    :param pie_pos:
        passed as ``loc`` to ``inset_axes`` i.e. location of the inset pie
    :return:
        a tuple of ``ax`` and the outputs of the two ``pie`` calls
        (outer ring, inner pie)
    """
    if ax is None:
        f, ax = plt.subplots(figsize=(9, 9))
    sv_bar = data['mean_shap'].copy()
    colors = data['colors'].unique()
    feature_names = data['features'].tolist()
    ax.spines[['top', 'right']].set_visible(False)
    # feature names are replaced with their labels from LABEL_MAP (if any)
    ax_ = bar_chart(
        sv_bar,
        [LABEL_MAP[n] if n in LABEL_MAP else n for n in feature_names],
        bar_labels=np.round(sv_bar, 2),
        bar_label_kws={'label_type': 'edge',
                       'fontsize': 10,
                       'weight': 'bold',
                       "fmt": '%.4f',
                       'padding': 1.5
                       },
        show=False,
        sort=True,
        color=data['colors'].to_list(),
        ax=ax
    )
    ax_.spines[['top', 'right']].set_visible(True)
    ax_.set_xlabel(xlabel='mean(|SHAP value|)')
    # ax.set_xticklabels(ax.get_xticks().astype(float))
    ax_.set_yticklabels(ax_.get_yticklabels())
    labels = data['classes'].unique()
    # NOTE(review): the i-th unique color is paired with the i-th unique
    # class, which presumably assumes one color per class appearing in the
    # same order in ``data`` -- verify against callers
    handles = [plt.Rectangle((0, 0), 1, 1,
                             color=colors[idx]) for idx, l in enumerate(labels)]
    ax_.legend(handles, labels, loc='lower right', facecolor="white")
    ax_.xaxis.set_major_locator(plt.MaxNLocator(4))
    ax_.set_facecolor('white')
    seg_colors = tuple(colors)
    # Scale the saturation of seg_colors to 95% for the interior segments
    rgb = mcolors.to_rgba_array(seg_colors)[:, :-1]
    hsv = mcolors.rgb_to_hsv(rgb)
    hsv[:, 1] = 0.95 * hsv[:, 1]
    interior_colors = mcolors.hsv_to_rgb(hsv)
    labels = data['classes'].unique().tolist()
    # share of every class = sum of mean_shap of its features / total sum
    fractions = []
    for label in labels:
        fractions.append(data.loc[data['classes'] == label]['mean_shap'].sum())
    fractions = np.array(fractions)
    fractions /= fractions.sum()
    for label, fraction in zip(labels, fractions):
        print(label, fraction)
    ax2 = inset_axes(ax, width='45%', height='45%',
                     loc=pie_pos)
    # outer circle/ring
    pie1_out = pie(fractions=fractions,
                   colors=seg_colors,
                   #labels=labels,
                   wedgeprops=dict(edgecolor="w", width=0.03),
                   radius=1,
                   autopct=None,
                   textprops=dict(fontsize=12),
                   startangle=90,
                   counterclock=False,
                   show=False,
                   ax=ax2)
    # inner pie, its radius leaves a gap of 0.03 to the outer ring
    pie2_out = pie(fractions=fractions,
                   colors=interior_colors,
                   autopct='%1.0f%%',
                   textprops=dict(fontsize=24),
                   wedgeprops=dict(edgecolor="w"),
                   pctdistance = 1.3,
                   radius=1 - 2 * 0.03,
                   startangle=90,
                   counterclock=False,
                   ax=ax2,
                   show=False
                   )
    if save or SAVE:
        plt.savefig(f"figures/shap_bar_{name}.png",
                    dpi=600,
                    bbox_inches="tight")
    if show:
        print('showing')
        plt.tight_layout()
        plt.show()
    return ax, pie1_out, pie2_out
Total running time of the script: (0 minutes 0.009 seconds)