'''
High-level functions for creating and saving plots
'''
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
import spellbook as sb
# At some point all the plotting stuff can be organised into proper classes.
# The advantage would be that accessors to elements can be made available in a
# more organised way. Something like
#
# fig = sb.plot.plot_grid_2D(nrows=2, ncols=5,
# data=data, xs=features, ys=target, relative='true')
# sb.plot2D.heatmap_set_annotations_fontsize(ax=fig.get_axes()[7],
# fontsize='x-small')
#
# would not be necessary anymore and could be replaced with the likes of
#
# corrs = sb.plot.grid_2D(nrows=2, ncols=5,
# data=data, xs=features, ys=target, relative='true')
# corrs.plots[iRow][iCol].set_annotation_fontsize('x-small')
#
# class PlotBase:
# # - fig
# # - ax
# # - save()
# pass
#
# class Plot(PlotBase):
# # - grid(1,1)
# # - plot: instance of type Barchart, Histogram, Heatmap, ...
# pass
#
# class PlotGrid(PlotBase):
# # - grid(n,m)
# # - axs: [[...], [...]] nested list so that each of them can be accessed
# # - plots: [[...], [...]] nested list of types Barchart, Histogram, ...
# pass
[docs]def save(fig: mpl.figure.Figure,
filename: str,
dpi: int = 200):
'''
Save a plot to a file
Args:
fig: The figure to plot
filename: The filename under which to save the plot
dpi: Optional resolution
'''
fig.savefig(filename, dpi=dpi)
print("Saved plot to file '{}'".format(filename))
[docs]def plot_1D(data: pd.DataFrame,
x: str,
xlabel: str = None,
fontsize: float = 12.0,
figure_args: dict = {},
barchart_args: dict = {},
histogram_args: dict = {},
histplot_args: dict = {},
statsbox_args: dict = {}
) -> mpl.figure.Figure:
'''
Create a single univariate plot
The type of the variable (*categorical* or *continuous*) is determined
automatically and either :func:`spellbook.plot1D.barchart` or
:func:`spellbook.plot1D.histogram` is called.
Args:
data (:class:`pandas.DataFrame`): The dataset to plot
x: Name of the variable to plot
xlabel: *Optional*. Title of the x-axis. If unspecified or set to
``None``, the name of the variable, as specified by **x**, will
be used.
fontsize: *Optional*. Baseline fontsize for all elements. This is
probably the fontsize that ``medium`` corresponds to?
figure_args: *Optional*. Arguments for the creation of the
:class:`matplotlib.figure.Figure` with
:func:`matplotlib.pyplot.figure`
barchart_args: *Optional*. Arguments passed on to
:func:`spellbook.plot1D.barchart` for categorical data
histogram_args: *Optional*. Arguments passed on to
:func:`spellbook.plot1D.histogram` for continuous data
histplot_args: *Optional*. Arguments for :func:`seaborn.histplot`,
which is used to draw the plot
statsbox_args: *Optional*. Arguments passed on by
`spellbook.plot1D.histogram` to :func:`spellbook.plotutils.statsbox`
Returns:
The figure containing the plot
'''
if fontsize:
tmp_fontsize = plt.rcParams['font.size']
plt.rcParams['font.size'] = fontsize
fig = plt.figure(tight_layout=True, **figure_args)
grid = mpl.gridspec.GridSpec(nrows=1, ncols=1, wspace=0.0, hspace=0.0)
kind = sb.plotutils.get_data_kind(data[x])
if kind == 'cat':
sb.plot1D.barchart(data, x=x, fig=fig, grid=grid, gridindex=0,
xlabel=xlabel,
histplot_args=histplot_args,
**barchart_args)
elif kind == 'ord':
print('plot1D/ord: TODO / IMPLEMENTATION MISSING')
elif kind == 'cont':
sb.plot1D.histogram(data=data, x=x, fig=fig, grid=grid, gridindex=0,
xlabel=xlabel,
histplot_args=histplot_args,
statsbox_args=statsbox_args,
**histogram_args)
if fontsize: plt.rcParams['font.size'] = tmp_fontsize
return(fig)
[docs]def plot_grid_1D(nrows: int, ncols: int,
data: pd.DataFrame,
target: str = None,
features: List[str] = None,
xlabels: Union[str, List[str]] = None,
fontsize: float = 12.0,
figure_args: dict = {},
stats: Union[bool, List[bool]] = True,
stats_align: Union[str, List[str]] = None,
binwidths: Union[float, List[float]] = None,
histogram_args: dict = {}
) -> mpl.figure.Figure:
'''
Create a grid of univariate plots
The type / visual representation of each variable is determined automatically
via :func:`spellbook.plotutils.get_data_kind`. Categorical variables are shown as
barcharts and continuous variables are shown as univariate / 1D histograms.
Summary statistics boxes can be shown for the histograms.
.. image:: ../images/plot_grid_1D.png
:width: 800px
:align: center
Args:
nrows: Number of rows
ncols: Number of columns
data (:class:`pandas.DataFrame`): The dataset to plot
target: *Optional*. The name of the target variable. If specified,
the target variable will be plotted first and highlighted by
plotting it in orange. Either **target** or **features** has to
be specified.
features: *Optional*. List with the names of the feature
variables. If specified, the feature variables will be plotted after
the target variable. Either **target** or **features** has to
be specified.
xlabels: *Optional*. The titles of the x-axes. If unspecified or set to
``None``, the names of the variables, as specified by **target**
and **features** will be used.
fontsize: *Optional*. Baseline fontsize for all elements. This is
probably the fontsize that ``medium`` corresponds to?
figure_args: *Optional*. Arguments for the creation of the returned
:class:`matplotlib.figure.Figure` with
:func:`matplotlib.pyplot.figure`
stats: *Optional*. Bool or list of bools that
indicate if statistics boxes are shown in each plot
stats_align: *Optional*. List of alignment strings, one for
each plot
binwidths: *Optional*. Float or list of floats
that indicate the binwidth in each plot
histogram_args: *Optional*. Dictionary of parameters and values
that are passed to :func:`spellbook.plot1D.histogram`
Returns:
Figure containing the grid of plots
Example:
.. code:: python
import pandas as pd
import spellbook as sb
data = pd.read_csv('dataset.csv')
plot_vars = sb.plot.plot_grid_1D(2, 4, data,
target='z', features=['x', 'y'],
stats=True, stats_align=['tl', 'br', 'tr'])
'''
vars = []
if target: vars = [target]
if features: vars += features
assert len(vars) > 0
if xlabels is None or isinstance(xlabels, str):
xlabels = [xlabels] * len(vars)
if isinstance(stats, bool):
stats = [stats] * len(vars)
if not type(binwidths) is list:
binwidths = [binwidths] * len(vars)
if fontsize:
tmp_fontsize = plt.rcParams['font.size']
plt.rcParams['font.size'] = fontsize
fig = plt.figure(figsize=(3*ncols, 3*nrows), **figure_args)
grid = mpl.gridspec.GridSpec(nrows=nrows, ncols=ncols)
for irows in range(nrows):
for icols in range(ncols):
i = irows*ncols + icols # vars index
if i >= len(vars):
# ax = plt.Subplot(fig, grid[i])
# ax.axis("off") # blank grid cell, without axes
continue
binwidth = binwidths[i] if len(binwidths)>i else None
# statistics box
stat = stats[i] if len(stats)>i else True
if stat:
stats_alignment = stats_align[i] \
if (stats_align is not None and len(stats_align)>i) \
else "tr"
histplot_args = {'binwidth': binwidth}
if target and i==0: histplot_args['color'] = 'C1'
# categorical variable
if sb.plotutils.get_data_kind(data[vars[i]]) == 'cat':
sb.plot1D.barchart(data=data, x=vars[i],
fig=fig, grid=grid, gridindex=i,
xlabel=xlabels[i],
histplot_args=histplot_args)
# ordinal variable
elif sb.plotutils.get_data_kind(data[vars[i]]) == 'ord':
sb.plotutils.not_yet_implemented(fig, grid, i,
'plot.plot_grid_1D() / ord')
# plot = sns.histplot(data=data, x=vars[i], ax=axes[irows][icols], binwidth=binwidth)
# # statistics box
# if stat:
# sb.plotutils.statsbox(ax=axes[irows][icols],
# text=data[vars[i]].describe(percentiles=[]).round(2).to_string(),
# alignment=stats_alignment)
# continuous variable
else:
statsbox_args = {}
if stat: statsbox_args = {'alignment': stats_alignment}
sb.plot1D.histogram(data=data, x=vars[i],
fig=fig, grid=grid, gridindex=i,
xlabel=xlabels[i],
show_stats=stat,
histplot_args=histplot_args,
statsbox_args=statsbox_args,
**histogram_args)
fig.tight_layout()
if fontsize: plt.rcParams['font.size'] = tmp_fontsize
return(fig)
[docs]def plot_2D(data: pd.DataFrame, x: str, y: str,
relative: bool = False,
fontsize: float = 12.0,
figure_args: dict = {},
heatmap_args: dict = {},
violinplot_args: dict = {},
cathist_args: dict = {},
scatterplot_args: dict = {}
) -> mpl.figure.Figure:
'''
Create a single bivariate/correlation plot
The types of the variables (*categorical* or *continuous*) are determined
automatically and the corresponding 2D plotting function is called:
- *x* is ``categorical`` and *y* is ``categorical``:
:func:`spellbook.plot2D.heatmap`
- *x* is ``categorical`` and *y* is ``continuous``:
:func:`spellbook.plot2D.violinplot`
- *x* is ``continuous`` and *y* is ``categorical``:
:func:`spellbook.plot2D.categorical_histogram`
- *x* is ``continuous`` and *y* is ``continuous``:
:func:`spellbook.plot2D.scatterplot`
Args:
data (:class:`pandas.DataFrame`): The dataset to plot
x: Name of the variable to plot on the x-axis
y: Name of the variable to plot on the y-axis
relative: Optional, whether or not the heatmaps drawn with
:func:`spellbook.plot2D.heatmap` should be normalised or not
- ``True``: heatmap will be column-normalised
(``normalisation = norm-col``)
- ``False``: heatmap will be show absolute numbers
(``normalisation = count``)
fontsize: *Optional*. Baseline fontsize for all elements.
This is probably the fontsize that ``medium`` corresponds to?
figure_args: *Optional*. Arguments for the creation of the
:class:`matplotlib.figure.Figure` with
:func:`matplotlib.pyplot.figure`
heatmap_args: *Optional*. Arguments passed on to
:func:`spellbook.plot2D.heatmap` for correlations between
a *categorical* variable on the x-axis and
a *categorical* variable on the y-axis
violinplot_args: *Optional*. Arguments passed on to
:func:`spellbook.plot2D.violinplot` for correlations between
a *categorical* variable on the x-axis and
a *continuous* variable on the y-axis
cathist_args: *Optional*. Arguments passed on to
:func:`spellbook.plot2D.categorical_histogram` for correlations
between
a *continuous* variable on the x-axis and
a *categorical* variable on the y-axis
scatterplot_args: *Optional*. Arguments passed on to
:func:`spellbook.plot2D.scatterplot` for correlations between
a *continuous* variable on the x-axis and
a *continuous* variable on the y-axis
Returns:
The figure containing the plot
Examples:
- simple example
.. code:: python
fig = sb.plot.plot_2D(data=data, x='age', y=target, fontsize=14.0)
- advanced example
The target variable has two categories and therefore, two histograms
will be stacked on top of each other. Via the *histogram_args* parameter,
a list of two dictionaries is passed on to
:func:`spellbook.plot2D.categorical_histogram` - one dictionary for each
of the two categories. Each one of the dictionaries is then passed on to
:func:`spellbook.plot1D.histogram`.
.. code:: python
fig = sb.plot.plot_2D(
data=data, x='age', y=target, fontsize=11.0,
cathist_args = {
'histogram_args': [
dict(
show_stats=True,
statsbox_args = {'alignment': 'bl'}
),
dict(
show_stats=True,
statsbox_args = {
'y': 0.96,
'text_args': {
# RGBA white with 50% alpha/opacity
'backgroundcolor': (1.0, 1.0, 1.0, 0.5)
}
}
)
]
})
'''
if fontsize:
tmp_fontsize = plt.rcParams['font.size']
plt.rcParams['font.size'] = fontsize
fig = plt.figure(tight_layout=True, **figure_args)
grid = mpl.gridspec.GridSpec(nrows=1, ncols=1, wspace=0.0, hspace=0.0)
x_kind = sb.plotutils.get_data_kind(data[x])
y_kind = sb.plotutils.get_data_kind(data[y])
if x_kind == 'cat' and y_kind == 'cat':
normalisation = 'norm-col' if relative else 'count'
sb.plot2D.heatmap(data=data, x=x, y=y,
fig=fig, grid=grid, gridindex=0,
normalisation=normalisation, **heatmap_args)
elif x_kind == 'cat' and y_kind == 'cont':
sb.plot2D.violinplot(data=data, x=x, y=y,
fig=fig, grid=grid, gridindex=0,
**violinplot_args)
elif x_kind == 'cont' and y_kind == 'cat':
sb.plot2D.categorical_histogram(data=data, x=x, y=y,
fig=fig, grid=grid, gridindex=0,
**cathist_args)
else:
ax = plt.Subplot(fig, grid[0])
sb.plot2D.scatterplot(data=data, x=x, y=y, ax=ax,
**scatterplot_args)
fig.add_subplot(ax)
if fontsize: plt.rcParams['font.size'] = tmp_fontsize
return(fig)
[docs]def plot_grid_2D(nrows: int,
ncols: int,
data: pd.DataFrame,
xs: List[str],
ys: List[str],
relative: bool = False,
fontsize: float = 12.0,
figure_args: dict = {},
heatmap_args: dict = {},
violinplot_args: dict = {},
cathist_args: dict = {},
scatterplot_args: dict = {}
) -> mpl.figure.Figure:
'''
Create a grid of bivariate/correlation plots
Args:
nrows: Number of rows
ncols: Number of columns
data (:class:`pandas.DataFrame`): The dataset to plot
xs: Names of the variables to plot on the x-axis
ys: Names of the variables to plot on the y-axis
relative: *Optional*. Whether or not the heatmaps drawn with
:func:`spellbook.plot2D.heatmap` should be normalised or not
- ``True``: heatmap will be column-normalised
(``normalisation = norm-col``)
- ``False``: heatmap will be show absolute numbers
(``normalisation = count``)
fontsize: *Optional*. Baseline fontsize for all elements.
This is probably the fontsize that ``medium`` corresponds to?
figure_args: *Optional*. Arguments for the creation of the
:class:`matplotlib.figure.Figure` with
:func:`matplotlib.pyplot.figure`
heatmap_args: *Optional*. Arguments passed on to
:func:`spellbook.plot2D.heatmap` for correlations between
a *categorical* variable on the x-axis and
a *categorical* variable on the y-axis
violinplot_args: *Optional*. Arguments passed on to
:func:`spellbook.plot2D.violinplot` for correlations between
a *categorical* variable on the x-axis and
a *continuous* variable on the y-axis
cathist_args: *Optional*. Arguments passed on to
:func:`spellbook.plot2D.categorical_histogram` for correlations
between
a *continuous* variable on the x-axis and
a *categorical* variable on the y-axis
scatterplot_args: *Optional*. Arguments passed on to
:func:`spellbook.plot2D.scatterplot` for correlations between
a *continuous* variable on the x-axis and
a *continuous* variable on the y-axis
Returns:
The figure containing the grid of plot
'''
if not (type(ys) is list): ys = [ys for i in range(len(xs))]
assert(len(xs) == len(ys))
assert(nrows * ncols >= len(xs))
if fontsize:
tmp_fontsize = plt.rcParams['font.size']
plt.rcParams['font.size'] = fontsize
fig = plt.figure(figsize=(3*ncols, 3*nrows), tight_layout=True,
**figure_args)
grid = mpl.gridspec.GridSpec(nrows=nrows, ncols=ncols)
for irows in range(nrows):
for icols in range(ncols):
i = irows*ncols + icols
if i >= len(xs):
ax = plt.Subplot(fig, grid[i])
ax.axis("off")
continue
x_kind = sb.plotutils.get_data_kind(data[xs[i]])
y_kind = sb.plotutils.get_data_kind(data[ys[i]])
if x_kind == 'cat' and y_kind == 'cat':
normalisation = 'norm-col' if relative else 'count'
sb.plot2D.heatmap(data=data, x=xs[i], y=ys[i],
fig=fig, grid=grid, gridindex=i,
normalisation=normalisation,
**heatmap_args)
elif x_kind == 'cat' and y_kind == 'cont':
sb.plot2D.violinplot(data=data, x=xs[i], y=ys[i],
fig=fig, grid=grid, gridindex=0,
**violinplot_args)
elif x_kind == 'cont' and y_kind == 'cat':
sb.plot2D.categorical_histogram(data=data, x=xs[i], y=ys[i],
fig=fig, grid=grid, gridindex=i,
**cathist_args)
else:
ax = plt.Subplot(fig, grid[0])
sb.plot2D.scatterplot(data=data, x=x, y=y, ax=ax,
**scatterplot_args)
if fontsize: plt.rcParams['font.size'] = tmp_fontsize
return(fig)
[docs]def pairplot(data: pd.DataFrame,
xs: List[str],
ys: List[str] = None,
fontsize: float = 12.0,
histplot_args: dict = {},
) -> mpl.figure.Figure:
'''
Create a pairplot
.. image:: ../images/pairplot-3x5.png
:width: 700px
:align: center
The plot does not need to contain the same variables or number of variables
in x and y. It can be rectangular with any number of rows and any number of
columns. The subplots with the same variable in x and y are detected
automatically, no matter where they are located in the pairplot, and instead
of a 2D/bivariate/correlation plot, the appropriate 1D/univariate
distribution is shown. This behaviour allows to split a full and possibly
large pairplot for all variables into arbitrarily-sized separate smaller
pieces.
The visual representation of the distributions and correlations is chosen
automatically depending on the type of random variables (categorical,
ordinal, continuous).
Args:
data (:class:`pandas.DataFrame`): The dataset to plot
xs: Names of the variables to plot on the x-axis
ys: *Optional*. Names of the variables to plot on the y-axes.
If not specified, the same variables will be shown on the x-axes
and the y-axes.
fontsize: *Optional*. Baseline fontsize for all elements.
This is probably the fontsize that ``medium`` corresponds to?
histplot_args: *Optional*. Arguments for :func:`seaborn.histplot`,
which is used to draw the histograms
Returns:
The figure containing the grid of plots
'''
if ys is None: ys = xs
assert isinstance(data, pd.DataFrame)
assert isinstance(xs, list)
assert isinstance(ys, list)
assert all(isinstance(x, str) for x in xs)
assert all(isinstance(y, str) for y in ys)
ncols = len(xs)
nrows = len(ys)
if fontsize:
tmp_fontsize = plt.rcParams['font.size']
plt.rcParams['font.size'] = fontsize
fig = plt.figure(figsize=(3*ncols, 3*nrows))
grid = mpl.gridspec.GridSpec(nrows=nrows, ncols=ncols)
for iy in range(nrows):
for ix in range(ncols):
i = iy*ncols + ix
x_kind = sb.plotutils.get_data_kind(data[xs[ix]])
y_kind = sb.plotutils.get_data_kind(data[ys[iy]])
histplot_args = {}
if xs[ix] == ys[iy]: # on-diagonal: univariate plots
histplot_args['color'] = 'C1' # 'gray' # pink
if x_kind == 'cat':
sb.plot1D.barchart(data, x=xs[ix],
fig=fig, grid=grid, gridindex=i,
histplot_args=histplot_args)
elif x_kind == 'ord':
sb.plotutils.not_yet_implemented(fig, grid, i,
'plot.pairplot()\non-diagonal / ord')
elif x_kind == 'cont':
sb.plot1D.histogram(data=data, x=xs[ix],
fig=fig, grid=grid, gridindex=i,
show_stats=False,
histplot_args=histplot_args)
else:
assert False # ax.axis("off")
else: # off-diagonal: bivariate/correlation plots
if x_kind == 'cat' and y_kind == 'cat':
sb.plot2D.heatmap(data=data, x=xs[ix], y=ys[iy],
fig=fig, grid=grid, gridindex=i,
normalisation='norm-col')
elif x_kind == 'cat' and y_kind == 'cont':
sb.plot2D.violinplot(data=data, x=xs[ix], y=ys[iy],
fig=fig, grid=grid, gridindex=i)
elif x_kind == 'cont' and y_kind == 'cat':
sb.plot2D.categorical_histogram(data=data, x=xs[ix], y=ys[iy],
fig=fig, grid=grid, gridindex=i,
histplot_args=histplot_args)
elif x_kind == 'cont' and y_kind == 'cont':
ax = plt.Subplot(fig, grid[i])
sb.plot2D.scatterplot(data=data, x=xs[ix], y=ys[iy], ax=ax,
show_lineplot=False,
show_stats=False)
fig.add_subplot(ax)
fig.tight_layout()
if fontsize: plt.rcParams['font.size'] = tmp_fontsize
return(fig)
[docs]def plot_confusion_matrix(confusion_matrix: tf.Tensor,
class_names: List[str],
class_ids: List[int] = None,
normalisation: str = 'count',
crop: bool = True,
figsize: Tuple[float, float] = (5.8, 5.3),
fontsize: float = None,
fontsize_annotations: Union[str, float] = None
) -> mpl.figure.Figure:
'''
Create a confusion matrix heatmap plot
.. image:: ../images/confusion-matrix-absolute.png
:width: 600px
:align: center
Both the absolute frequencies as well as the relative frequencies, either
normalised by the true labels, the predictedlabels or their combinations,
can be shown. The desired behaviour is specified with the parameter
``normalisation``.
Args:
confusion_matrix (:class:`tf.Tensor`): The confusion
matrix
class_names: List of the class names
class_ids: Optional, list of IDs for each target class.
These IDs are shown on the x-axis and, together with the class
names, on the y-axis.
normalisation: Optional, indicates if the absolute or relative
frequencies should be plotted
- ``count``: Numbers of datapoints
- ``norm-all``: Percentages normalised across all combinations of
the true and the predicted classes/labels
- ``norm-true``: Percentages normalised across the true labels
- ``norm-pred``: Percentages normalised across the predicted classes
figsize: Optional, size (width, height) of the figure in inches
crop: Plots with *normalisation* set to ``norm-true``/``norm-pred``
do not include the *SUM* row/column, respectively. When *crop* is
set to
- ``True``, the excluded *SUM* row/column is removed from the
heatmap matrix, thus making it occupy a larger portion of the plot
- ``False``, the excluded *SUM* row/column is kept empty but still
included in the heatmap matrix, so as to make each cell appear
in the same position as with normalisation set to ``count`` or
``norm-all``
fontsize: *Optional*. Baseline fontsize for all elements.
This is probably the fontsize that ``medium`` corresponds to?
fontsize_annotations: *Optional*. Fontsize for the annotations.
As specified in :meth:`matplotlib.text.Text.set_fontsize`.
Returns:
The figure containing the plot
See also:
:func:`tf.math.confusion_matrix`
'''
if fontsize:
tmp_fontsize = plt.rcParams['font.size']
plt.rcParams['font.size'] = fontsize
fig = plt.figure(figsize=figsize)
grid = mpl.gridspec.GridSpec(nrows=1, ncols=1)
sb.plot2D.heatmap(data = confusion_matrix.numpy(),
x = 'predicted labels',
y = 'true labels',
fig = fig,
grid = grid,
gridindex = 0,
normalisation = normalisation,
crop = crop,
xlabels = class_ids,
ylabels = ['{} - {}'.format(id, name)
for id, name in zip(class_ids, class_names)],
ylabels_horizontal = True,
heatmap_args = dict(square=True))
if fontsize_annotations:
for child in fig.get_children():
if isinstance(child, mpl.axes.Axes):
sb.plot2D.heatmap_set_annotations_fontsize(
child, fontsize_annotations)
if fontsize: plt.rcParams['font.size'] = tmp_fontsize
fig.tight_layout()
return(fig)
[docs]def parallel_coordinates(
data: pd.DataFrame,
features: List[str],
target: str,
categories: Dict[str, Dict[int, str]],
fontsize: float = None,
shift: float = 0.3
) -> mpl.figure.Figure:
'''
Parallel coordinates plot
.. image:: /images/parallel-coordinates.png
:align: center
:height: 250px
Based on `Parallel Coordinates in Matplotlib
<https://benalexkeen.com/parallel-coordinates-in-matplotlib/>`_,
but extended to also support categorical variables.
For categorical variables, a random uniform shift is applied to spread
the lines in the vicinity of the respective classes. This way, there
is an indication for the composition of the datapoints in a particular
class/category in terms of the target labels/classes. Furthermore, the
shift interval is sized according to the number of datapoints in the
respective class/category in order to give an impression for how many
datapoints there are in that class.
Args:
data (:class:`pandas.DataFrame`): The dataset to plot
features: The names of the feature variables
target: The name of the target variable
categories: Dictionary holding the category codes/indices and names
as returned by :func:`spellbook.input.encode_categories`
fontsize: *Optional*. Baseline fontsize for all elements.
This is probably the fontsize that ``medium`` corresponds to?
shift: *Optional*. The half-size of the interval for uniformely
shifting categorical variables
.. todo:: Support more than the 10 colours included in *Matplotlib*'s
tableau colours
'''
df = data.copy() # get local copy to protect original data
target_name = target.replace('_codes', '').replace('_norm', '')
features.append(target) #
feature_names = [f.replace('_codes', '').replace('_norm', '')
for f in features]
x = [i for i, _ in enumerate(features)]
# https://matplotlib.org/stable/gallery/color/named_colors.html
colours = list(mpl.colors.TABLEAU_COLORS)
if fontsize:
tmp_fontsize = plt.rcParams['font.size']
plt.rcParams['font.size'] = fontsize
fig, axs = plt.subplots(1, len(features)-1, sharey=False,
figsize = (1.5*len(features), 5))
cat_mins = {}
cat_maxs = {}
for i, feature in enumerate(features):
features[i] = feature + '_y'
if feature_names[i] in categories:
ncats = len(categories[feature_names[i]])
# counts / absolute frequencies of each class
counts = df[feature].value_counts()
# empty shifting vector, one entry per row/datapoint
shifts = np.zeros(shape=len(df[feature]))
# calculate shift for each class
for index, value in enumerate(df[feature].values):
# larger shift if class is more frequent
# -> thickness of line bundle indicates how many datapoints
# fall in that particular class
shifts[index] = shift * (counts[value]-1)/max(5,len(df)-1)
# apply uniform shifts
df[feature+'_y'] = 1.0 / (ncats-1) \
* np.random.uniform(low = df[feature]-shifts,
high = df[feature]+shifts)
cat_mins[feature+'_y'] = -shift/(ncats-1)
cat_maxs[feature+'_y'] = (ncats-1+shift)/(ncats-1)
else:
df[feature+'_y'] = df[feature]
cat_min = np.amin(list(cat_mins.values()))
cat_max = np.amax(list(cat_maxs.values()))
# Get min, max and range for each column
# Normalize the data for each column
min_max_range = {}
for i, feature in enumerate(features):
min_val = df[feature].min()
max_val = df[feature].max()
val_range = np.ptp(df[feature])
min_max_range[feature] = [min_val, max_val, val_range]
if feature_names[i] in categories:
y = (df[feature]-cat_mins[feature]) / (cat_maxs[feature]-cat_mins[feature])
else:
if max_val > 0.0:
y = df[feature] / max_val
else:
y = df[feature]
df[feature] = cat_min + (cat_max-cat_min) * y
# Plot each row
for i, ax in enumerate(axs):
for idx, row in df.iterrows():
category = int(row[target])
ax.plot(x, row[features], colours[category])
ax.set_xlim([i, i+1])
# Set the tick positions and labels on y axis for each plot
# Tick positions based on normalised data
# Tick labels are based on original data
def set_ticks_for_axis(dim, ax, ticks):
min_val, max_val, val_range = min_max_range[features[dim]]
feature = feature_names[dim]
if feature in categories:
tick_labels = categories[feature].values()
ncats = len(categories[feature])
ticks = [(cat_index+shift) / (ncats-1+2*shift)
for cat_index in categories[feature].keys()]
else:
if max_val > 0.0:
step = max_val / float(ticks-1)
norm_step = 1.0 / float(ticks-1)
tick_labels = [round(step * i, 2) for i in range(ticks)]
ticks = [round(norm_step * i, 2) for i in range(ticks)]
else:
step = 1.0 / float(ticks-1)
norm_step = 1.0 / float(ticks-1)
tick_labels = [round(step * i, 2) for i in range(ticks)]
ticks = [round(norm_step * i, 2) for i in range(ticks)]
ticks = [cat_min + (cat_max-cat_min)*tick for tick in ticks]
ax.set_yticks(ticks)
ax.set_yticklabels(tick_labels, backgroundcolor=(1,1,1, 0.7)) # fontweight='bold')
ax.set_ylim(bottom=-shift-0.08, top=1+shift+0.08)
for dim, ax in enumerate(axs):
ax.xaxis.set_major_locator(mpl.ticker.FixedLocator([dim]))
set_ticks_for_axis(dim, ax, ticks=6)
ax.set_xticklabels([feature_names[dim]])
# Move the final axis' ticks to the right-hand side
ax = axs[-1].twinx()
dim = len(axs)
ax.xaxis.set_major_locator(mpl.ticker.FixedLocator([x[-2], x[-1]]))
set_ticks_for_axis(dim, ax, ticks=6)
ax.set_xticklabels([feature_names[-2], feature_names[-1]])
ncats = len(categories[target_name])
ax.set_ylim(bottom=-shift-0.05, top=1+shift+0.05)
fig.tight_layout()
fig.subplots_adjust(wspace=0) # remove space between subplots
if fontsize: plt.rcParams['font.size'] = tmp_fontsize
return fig