"""Utilities for matplotlib."""
from typing import Any, Callable, Iterable, Mapping, Tuple, List, Optional, Union
from functools import wraps
import os
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError("Matplotlib not installed. Try installing with pip:"
"\n\t$ pip install matplotlib\n"
"\nor reinstall carabiner with matplotlib:\n"
"\n\t$ pip install carabiner[mpl]\n")
else:
from matplotlib import axes, cycler, figure, rcParams, legend
import matplotlib
import numpy as np
from pandas import DataFrame
from tqdm.auto import tqdm
from ..cast import cast
from ..utils import TOL_PALETTES, colorblind_palette as utils_colorblind_palette, print_err
# Type alias for the (figure, axes) pair returned by the plotting helpers below.
TFigAx = Tuple[figure.Figure, axes.Axes]
# Public re-export of the palette helper imported from ``..utils``.
colorblind_palette = utils_colorblind_palette

# Package-wide Matplotlib defaults, applied as a side effect of importing this
# module: a colorblind-friendly color cycle, a preferred sans-serif font
# fallback order, and larger title, axis-label and tick-label text.
rcParams.update({
    "axes.prop_cycle": cycler(color=colorblind_palette()),
    "font.sans-serif": [
        "Helvetica",    # available on MacOS
        "Arial",        # available on Windows
        "Nimbus Sans",  # widely available, free
        "DejaVu Sans",  # available on Linux, mpl default
        "FreeSans",     # widely available, free
        "sans-serif",   # system default
    ],
    "axes.titlesize": 14.,
    "axes.labelsize": 14.,
    "xtick.labelsize": 12.,
    "ytick.labelsize": 12.,
})
[docs]
def set_plot_palette(palette: Union[str, Iterable[str]]) -> Tuple[str]:
    """Set the default Matplotlib color cycle.

    Parameters
    ----------
    palette : str | Iterable[str]
        Either the name of a built-in palette (a key of ``TOL_PALETTES``,
        resolved with ``colorblind_palette(name=...)``), or an iterable of
        Matplotlib color specifications.

    Returns
    -------
    Sequence of colors
        The colors installed in ``rcParams["axes.prop_cycle"]``: the value
        returned by ``colorblind_palette`` for a named palette, the original
        list/tuple if one was given, or a tuple built from any other iterable.

    Raises
    ------
    ValueError
        If ``palette`` is a string that does not name a built-in palette, or
        if an explicitly given palette contains no colors.
    """
    if isinstance(palette, str):
        if palette not in TOL_PALETTES:
            raise ValueError(f"Palette named '{palette}' not built-in. Try one of {', '.join(TOL_PALETTES)}.")
        palette = colorblind_palette(name=palette)
    else:
        if not isinstance(palette, (list, tuple)):
            # Materialize once: a one-shot iterator (e.g. a generator) would
            # otherwise be consumed by ``cycler`` and handed back empty.
            palette = tuple(palette)
        if len(palette) == 0:
            # An empty color cycle only fails later, at plot time.
            raise ValueError("Palette must contain at least one color.")
    rcParams["axes.prop_cycle"] = cycler(color=palette)
    return palette
[docs]
def set_plot_font(font: str, category: str = "sans-serif") -> List[str]:
    """Move ``font`` to the front of a Matplotlib font-family fallback list.

    Any existing occurrence of ``font`` in ``rcParams["font.<category>"]`` is
    removed before it is prepended, so the list holds no duplicates. This does
    not change ``rcParams["font.family"]``, so the font is only used when that
    family is the active one.

    Parameters
    ----------
    font : str
        Font name to prefer.
    category : str, optional
        Font family whose fallback list is edited: one of "sans-serif",
        "serif", "monospace", "cursive" or "fantasy". Default: "sans-serif".

    Returns
    -------
    list of str
        The updated ``rcParams["font.<category>"]`` list.

    Raises
    ------
    ValueError
        If ``category`` is not one of the supported families.
    """
    # Every generic family Matplotlib keeps a ``font.<family>`` list for.
    valid_categories = ("sans-serif", "serif", "monospace", "cursive", "fantasy")
    if category not in valid_categories:
        raise ValueError(f"Invalid font category: {category}. Try one of {', '.join(valid_categories)}.")
    key = f"font.{category}"
    rcParams[key] = [font] + [f for f in rcParams[key] if f != font]
    return rcParams[key]
[docs]
def grid(
    nrow: int = 1,
    ncol: int = 1,
    panel_size: float = 3.,
    aspect_ratio: float = 1.,
    layout: str = 'constrained',
    sharex: Union[str, bool] = False,
    sharey: Union[str, bool] = False,
    hide_shared_ticks: bool = False,
    square: bool = False,
    *args, **kwargs
) -> TFigAx:
    """Create a figure and a set of subplots with sensible defaults.

    Additional arguments are passed to `matplotlib.pyplot.subplots()`.

    Parameters
    ----------
    nrow : int, optional
        Number of rows. Default: 1.
    ncol : int, optional
        Number of columns. Default: 1.
    panel_size : float, optional
        Height of each panel, in inches. Default: 3.
    aspect_ratio : float, optional
        Ratio of panel width over height. Default: 1 (square).
    layout : str, optional
        Matplotlib `fig` layout engine. Default: "constrained".
    sharex, sharey : bool or str, optional
        Passed to `matplotlib.pyplot.subplots()` to share axis scales
        between panels. Default: `False`.
    hide_shared_ticks : bool, optional
        Whether to keep the tick labels that Matplotlib hides on inner panels
        when axes are shared via `sharex` or `sharey`. If `False`, the tick
        labels are shown on every panel. Default: `False`.
    square : bool, optional
        Whether to force each panel to an equal data aspect ratio.
        Default: `False`.

    Returns
    -------
    tuple
        Pair of `figure.Figure` and the axes returned by
        `matplotlib.pyplot.subplots()` (a single `axes.Axes` or an array of
        them, depending on `squeeze`).
    """
    figsize = (
        ncol * panel_size * aspect_ratio,
        nrow * panel_size,
    )
    fig, ax_array = plt.subplots(
        nrow,
        ncol,
        figsize=figsize,
        layout=layout,
        sharex=sharex,
        sharey=sharey,
        *args, **kwargs
    )
    all_axes = fig.axes
    if square:
        for ax in all_axes:
            ax.set(aspect="equal")
    # Matplotlib hides tick labels on inner panels when scales are shared;
    # restore them on every panel unless the caller asked to keep them hidden.
    if not hide_shared_ticks:
        if sharex:
            for ax in all_axes:
                ax.xaxis.set_tick_params(labelbottom=True)
        if sharey:
            for ax in all_axes:
                ax.yaxis.set_tick_params(labelleft=True)
    return fig, ax_array
[docs]
def add_legend(
    ax: axes.Axes,
    **kwargs
) -> legend.Legend:
    """Attach a legend just outside the right-hand edge of ``ax``.

    By default the legend is anchored at the middle of the axes' right edge
    (``loc='center left'``, ``bbox_to_anchor=(1, .5)``) so it stays clear of
    the data. Every keyword argument, including ``loc`` and
    ``bbox_to_anchor``, is forwarded to `matplotlib.axes.Axes.legend()` and
    takes precedence over these defaults.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to add a legend to.

    Returns
    -------
    matplotlib.legend.Legend
    """
    # Later entries win, so caller-supplied keywords override the defaults.
    return ax.legend(**{"loc": 'center left', "bbox_to_anchor": (1, .5), **kwargs})
[docs]
def _scattergrid_group_label(group_name: Any) -> str:
    """Join a pandas group key into one legend label, e.g. ``('a', 1)`` -> ``'a:1'``."""
    # Multi-key groupbys yield tuples, but a single-key groupby may yield a
    # bare scalar depending on the pandas version; joining a bare string
    # directly would split it into its characters.
    keys = group_name if isinstance(group_name, tuple) else (group_name,)
    return ":".join(map(str, keys))


def _scattergrid_log_hist(column: Any, n_bins: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Return ``(values, bin_edges)`` for a log-x histogram, or ``None`` if nothing is plottable.

    Only strictly positive, finite values can be drawn on a log axis; the
    ``n_bins`` bins are geometrically spaced between their minimum and maximum.
    """
    values = column[column > 0].to_numpy()
    values = values[np.isfinite(values)]
    if values.size == 0:
        return None
    values_min, values_max = values.min(), values.max()
    # A single distinct value would give zero-width bins; widen the range so
    # the edges are strictly increasing.
    hist_max = values_min + 1. if values_min == values_max else values_max
    # ``n_bins + 1`` edges delimit ``n_bins`` bins, matching ``bins=n_bins``
    # in the linear case.
    return values, np.geomspace(values_min, hist_max, num=n_bins + 1)


def scattergrid(
    df: DataFrame,
    grid_columns: Union[str, Iterable[str]],
    grid_rows: Optional[Union[str, Iterable[str]]] = None,
    grouping: Optional[Union[str, Iterable[str]]] = None,
    log: Optional[Union[str, Iterable[str]]] = None,
    n_bins: int = 40,
    scatter_opts: Optional[Mapping[str, Any]] = None,
    hist_opts: Optional[Mapping[str, Any]] = None,
    legend_opts: Optional[Mapping[str, Any]] = None,
    *args, **kwargs
) -> TFigAx:
    """Create a scatter plot to compare sets of variables in a Pandas DataFrame.

    Similar to `pandas.plotting.scatter_matrix`, but with larger panels and
    control over which variables are log-scaled. Panels where the row and
    column variable are the same show a histogram instead of a scatter plot.

    Additional arguments are passed to `grid()`.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot.
    grid_columns : str | Iterable[str]
        Data columns laid out across the columns of the grid; each becomes
        the x-axis of the panels in its grid column. Names not in `df` are
        ignored.
    grid_rows : str | Iterable[str], optional
        Data columns laid out down the rows of the grid; each becomes the
        y-axis of the panels in its grid row. If not provided, uses
        `grid_columns` for all pair-wise comparisons.
    grouping : str | Iterable[str], optional
        If provided, use these columns of `df` to make groups and plot each
        data group in a different color, with a legend on every panel.
    log : str | Iterable[str], optional
        If provided, plot these columns of `df` on a log scale. Histograms of
        log-scaled columns only include strictly positive, finite values.
    n_bins : int, optional
        Number of bins for the histograms on the diagonal of the scatter
        grid. Default: 40.
    scatter_opts : dict, optional
        Extra keyword arguments to pass to the Matplotlib scatter plots.
        These override the defaults, including the per-group edge color.
    hist_opts : dict, optional
        Extra keyword arguments to pass to the Matplotlib histogram plots.
        These override the defaults, including the per-group color.
    legend_opts : dict, optional
        Extra keyword arguments to pass to the Matplotlib legend.

    Returns
    -------
    tuple
        Pair of `figure.Figure` and the 2-D array of `axes.Axes` objects.

    Raises
    ------
    KeyError
        If none of the requested grid columns, grid rows, or grouping
        columns are in `df`.
    """
    grid_columns = [name for name in cast(grid_columns, to=list) if name in df]
    grid_rows = grid_rows or grid_columns
    grid_rows = [name for name in cast(grid_rows, to=list) if name in df]
    if len(grid_columns) == 0 or len(grid_rows) == 0:
        raise KeyError(f"No requested grid columns/rows were in the DataFrame ({', '.join(map(str, df))})!")
    all_names = sorted(set(grid_columns + grid_rows))
    log = log or []
    log = [name for name in cast(log, to=list) if name in all_names]

    if grouping is None:
        # One constant-valued dummy column gives a single group, so the
        # plotting loop below needs no special case.
        grouping = "__group__"
        grouped = df.assign(__group__=grouping).groupby(grouping)
        dummy_group = True
    else:
        requested = cast(grouping, to=list)
        grouping = sorted(set(g for g in requested if g in df))
        if len(grouping) == 0:
            raise KeyError(f"No columns in grouping ({', '.join(map(str, requested))}) were in the DataFrame ({', '.join(map(str, df))})!")
        grouped = df.groupby(grouping)
        dummy_group = False

    _scatter_opts = {"s": 15., "facecolor": "none", "linewidth": .5}
    _scatter_opts.update(scatter_opts or {})
    _hist_opts = {
        # Opaque when only one group is drawn; translucent so groups overlap visibly.
        "alpha": 1. if dummy_group else .7,
        "linewidth": 1.,
        "histtype": "stepfilled",
        "edgecolor": "lightgrey",
    }
    _hist_opts.update(hist_opts or {})
    _legend_opts = dict(legend_opts or {})

    fig, ax_grid = grid(
        nrow=len(grid_rows),
        ncol=len(grid_columns),
        squeeze=False,
        *args, **kwargs
    )
    for axrow, grid_row_name in zip(tqdm(ax_grid), grid_rows):
        for ax, grid_col_name in zip(axrow, grid_columns):
            make_histogram = grid_row_name == grid_col_name
            xscale = "log" if grid_col_name in log else "linear"
            yscale = "log" if (grid_row_name in log and not make_histogram) else "linear"
            if not make_histogram:
                ylabel = grid_row_name
            elif _hist_opts.get("density", False):
                ylabel = "Density"
            else:
                ylabel = "Frequency"
            for i, (group_name, group_df) in enumerate(grouped):
                # Explicit per-group color, so a group's histogram and scatter
                # points match even when some groups are skipped below.
                color = f"C{i}"
                labels = {} if dummy_group else {"label": _scattergrid_group_label(group_name)}
                if make_histogram:
                    if xscale == "log":
                        log_hist = _scattergrid_log_hist(group_df[grid_col_name], n_bins)
                        if log_hist is None:
                            continue
                        hist_data, bins = log_hist
                    else:
                        hist_data, bins = group_df[grid_col_name], n_bins
                    try:
                        ax.hist(
                            hist_data,
                            bins=bins,
                            **{"color": color, **labels, **_hist_opts},
                        )
                    except ValueError as e:  # Usually some problem with value ranges, but shouldn't prevent plotting
                        print_err(e)
                else:
                    ax.scatter(
                        x=group_df[grid_col_name],
                        y=group_df[grid_row_name],
                        **{"edgecolor": color, **labels, **_scatter_opts},
                    )
            ax.set(
                xlabel=grid_col_name,
                ylabel=ylabel,
                xscale=xscale,
                yscale=yscale,
            )
            if not dummy_group:
                add_legend(ax, **_legend_opts)
    return fig, ax_grid
[docs]
def figsaver(
    output_dir: str = ".",
    prefix: Optional[str] = None,
    dpi: int = 300,
    format: Union[str, Iterable[str]] = "png",
) -> Callable[[figure.Figure, str, Optional[DataFrame]], None]:
    """Create a function to save figures in a predefined location.

    The output directory, including any missing parent directories, is
    created when this factory is called.

    Parameters
    ----------
    output_dir : str, optional
        Directory to save figures. Default: ".".
    prefix : str, optional
        Prefix for filenames. Default: no prefix.
    dpi : int, optional
        Resolution of saved figures. Default: 300.
    format : str or Iterable, optional
        File format(s) of figures. Default: "png".

    Returns
    -------
    Callable
        A function taking a Figure, a name, and optionally a Pandas
        DataFrame as arguments. Saves as {dir}/{prefix}{name}.{format}
        for each format. If a DataFrame is provided, it is saved as
        {dir}/{prefix}{name}.csv.
    """
    prefix = prefix or ""
    if output_dir:
        # Unlike ``os.mkdir``, this creates missing parents, and ``exist_ok``
        # avoids a race between an existence check and creation. An empty
        # ``output_dir`` means the current directory and needs no creation.
        os.makedirs(output_dir, exist_ok=True)
    # Materialize once, so a one-shot iterable (e.g. a generator) is not
    # exhausted after the first figure is saved.
    formats = [format] if isinstance(format, str) else list(format)

    def _figsave(
        fig: figure.Figure,
        name: str,
        df: Optional[DataFrame] = None
    ) -> None:
        """Save ``fig`` once per configured format, plus ``df`` as CSV if given.

        Parameters
        ----------
        fig : matplotlib.figure.Figure
            Figure to save, with a tight bounding box at the configured dpi.
        name : str
            Base filename, placed after the configured prefix.
        df : pandas.DataFrame, optional
            Data to save alongside the figure, without its index.
        """
        for _format in formats:
            figname = os.path.join(output_dir, f"{prefix}{name}.{_format}")
            print_err(f"Saving plot at {figname}")
            fig.savefig(
                figname,
                dpi=dpi,
                bbox_inches="tight",
            )
        if isinstance(df, DataFrame):
            dataname = os.path.join(output_dir, f"{prefix}{name}.csv")
            print_err(f"Saving data at {dataname}")
            df.to_csv(dataname, index=False)
        return None

    return _figsave