# -*- coding: utf-8 -*-
"""
DEPRECATED: Use seaborn instead
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import ubelt as ub
import six
import warnings
from six.moves import zip_longest
from . import mpl_core
__all__ = ['multi_plot']
# import xdev # NOQA
# @xdev.profile # NOQA
def multi_plot(xdata=None, ydata=None, xydata=None, **kwargs):
    r"""
    Plot multiple lines, bars, etc. on a single axes.

    One function call that concisely describes all of the most commonly
    used parameters needed when plotting a bar / line chart. This is
    especially useful when multiple plots are needed in the same domain.

    Args:
        xdata (List[ndarray] | Dict[str, ndarray] | ndarray | str):
            x-coordinate data common to all y-coordinate values or xdata for
            each line/bar in ydata. If ydata is a dict, this may also be the
            key of the ydata entry that holds the shared x-coordinates.
            If unspecified, ``range(len(longest line))`` is used.
            Mutually exclusive with xydata.

        ydata (List[ndarray] | Dict[str, ndarray] | ndarray):
            y-coordinate values for each line/bar to plot. Can also
            be just a single ndarray of scalar values. Non-finite values
            are dropped before drawing. Mutually exclusive with xydata.

        xydata (Dict[str, Tuple[ndarray, ndarray]]):
            mapping from labels to a tuple of xdata and ydata for each line.

        **kwargs:

            fnum (int):
                figure number to draw on

            pnum (Tuple[int, int, int]):
                plot number to draw on within the figure, e.g. (1, 1, 1)

            ax (matplotlib.axes.Axes):
                existing axes to draw on. If unspecified, a figure is
                looked up / created from fnum and pnum.

            doclf (bool, default=False):
                forwarded to the figure helper when a figure is created.

            rcParams (dict):
                overrides applied to a copy of ``matplotlib.rcParams`` when
                looking up defaults (the global rcParams are not modified).

            label (List|Dict):
                if you specified ydata as a List[ndarray] this is the label for
                each line in that list. Note this is unnecessary if you specify
                input as a dictionary mapping labels to lines.

            color (str|List|Dict):
                either a special color code ('distinct' or the name of a
                matplotlib colormap), a single color, or a color for
                each item in ydata. In the latter case, this should be
                specified as either a list or a dict depending on how ydata
                was specified.

            marker (str|List|Dict):
                type of matplotlib marker to use at every data point. Can be
                specified for all lines jointly or for each line independently.
                The special values 'distinct' and 'cycle' generate one marker
                per line.

            transpose (bool, default=False):
                swaps x and y data.

            kind (str, default='plot'):
                The kind of plot. Can either be 'plot' or 'bar'.
                We parse these other kwargs if:
                    if kind='plot':
                        spread
                    if kind='bar':
                        stacked, width

            Misc:
                use_legend (bool): ...
                legend_loc (str):
                    one of 'best', 'upper right', 'upper left', 'lower left',
                    'lower right', 'right', 'center left', 'center right',
                    'lower center', or 'upper center'.

            Layout:
                xlabel (str): label for x-axis
                ylabel (str): label for y-axis
                title (str): title for the axes
                figtitle (str): title for the figure
                xscale (str): can be one of [linear, log, logit, symlog]
                yscale (str): can be one of [linear, log, logit, symlog]
                xlim (Tuple[float, float]): low and high x-limit of axes
                ylim (Tuple[float, float]): low and high y-limit of axes
                xmin (float): low x-limit of axes, mutex with xlim
                xmax (float): high x-limit of axes, mutex with xlim
                ymin (float): low y-limit of axes, mutex with ylim
                ymax (float): high y-limit of axes, mutex with ylim
                titlesize (float): ...
                legendsize (float): ...
                labelsize (float): ...

                Any of xmin / xmax / ymin / ymax may also be the string
                'data', in which case the limit is taken from the plotted
                data on that axis.

            Grid:
                gridlinewidth (float): ...
                gridlinestyle (str): ...

            Ticks:
                num_xticks (int): number of x ticks
                num_yticks (int): number of y ticks
                tickwidth (float): ...
                ticklength (float): ...
                ticksize (float): ...
                xticklabels (list): list of x-tick labels, overrides num_xticks
                yticklabels (list): list of y-tick labels, overrides num_yticks
                xtick_rotation (float): xtick rotation in degrees
                ytick_rotation (float): ytick rotation in degrees

            Data:
                spread (List | Dict): Plots a spread around plot lines usually
                    indicating standard deviation
                markersize (float|List|Dict): marker size for all or each plot
                markeredgewidth (float|List|Dict): marker edge width for all or each plot
                linewidth (float|List|Dict): line width for all or each plot
                linestyle (str|List|Dict): line style for all or each plot

    Note:
        any plot_kw key can be a scalar (corresponding to all ydatas),
        a list if ydata was specified as a list, or a dict if ydata was
        specified as a dict.

        plot_kw_keys = ['label', 'color', 'marker', 'markersize',
            'markeredgewidth', 'linewidth', 'linestyle']

    Note:
        In general this should be deprecated in favor of using seaborn

    Returns:
        matplotlib.axes.Axes: ax : the axes that was drawn on

    Raises:
        ValueError: if xydata is combined with xdata / ydata, if xydata is
            not a dict, if both label and label_list are given, if a
            per-line kwarg is a dict but ydata is not, or if xlim / ylim is
            combined with the corresponding min / max kwargs.

    References:
        matplotlib.org/examples/api/barchart_demo.html

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> # The new way to use multi_plot is to pass ydata as a dict of lists
        >>> ydata = {
        >>>     'spamΣ': [1, 1, 2, 3, 5, 8, 13],
        >>>     'eggs': [3, 3, 3, 3, 3, np.nan, np.nan],
        >>>     'jamµ': [5, 3, np.nan, 1, 2, np.nan, np.nan],
        >>>     'pram': [4, 2, np.nan, 0, 0, np.nan, 1],
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, title='ΣΣΣµµµ',
        >>>                        xlabel='\nfdsΣΣΣµµµ', linestyle='--')
        >>> kwplot.show_if_requested()

    Example:
        >>> # Old way to use multi_plot is a list of lists
        >>> import kwplot
        >>> kwplot.autompl()
        >>> xdata = [1, 2, 3, 4, 5]
        >>> ydata_list = [[1, 2, 3, 4, 5], [3, 3, 3, 3, 3], [5, 4, np.nan, 2, 1], [4, 3, np.nan, 1, 0]]
        >>> kwargs = {'label': ['spamΣ', 'eggs', 'jamµ', 'pram'], 'linestyle': '-'}
        >>> #ax = multi_plot(xdata, ydata_list, title='$\phi_1(\\vec{x})$', xlabel='\nfds', **kwargs)
        >>> ax = multi_plot(xdata, ydata_list, title='ΣΣΣµµµ', xlabel='\nfdsΣΣΣµµµ', **kwargs)
        >>> kwplot.show_if_requested()

    Example:
        >>> # Simple way to use multi_plot is to pass xdata and ydata exactly
        >>> # like you would use plt.plot
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ax = multi_plot([1, 2, 3], [4, 5, 6], fnum=4, label='foo')
        >>> kwplot.show_if_requested()

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> xydata = {'a': ([0, 1, 2], [0, 1, 2]), 'b': ([0, 2, 4], [2, 1, 0])}
        >>> ax = kwplot.multi_plot(xydata=xydata, fnum=4)
        >>> kwplot.show_if_requested()

    Example:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ydata = {'a': [0, 1, 2], 'b': [1, 2, 1], 'c': [4, 4, 4, 3, 2]}
        >>> kwargs = {
        >>>     'spread': {'a': [.2, .3, .1], 'b': .2},
        >>>     'xlim': (-1, 5),
        >>>     'xticklabels': ['foo', 'bar'],
        >>>     'xtick_rotation': 90,
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, fnum=4, **kwargs)
        >>> kwplot.show_if_requested()

    Ignore:
        >>> import kwplot
        >>> kwplot.autompl()
        >>> ydata = {
        >>>     str(i): np.random.rand(100) + i for i in range(30)
        >>> }
        >>> ax = kwplot.multi_plot(ydata=ydata, fnum=1, doclf=True)
        >>> kwplot.show_if_requested()
    """
    import matplotlib as mpl
    from matplotlib import pyplot as plt

    # Initial integration with mpl rcParams standards
    mplrc = mpl.rcParams
    if 'rcParams' in kwargs:
        # Work on a copy so per-call overrides never touch the global state
        mplrc = mplrc.copy()
        mplrc.update(kwargs['rcParams'])

    # ---------------------------------------------------------------
    # Normalize the (xdata, ydata, xydata) inputs into parallel lists
    # ---------------------------------------------------------------
    if xydata is not None:
        if xdata is not None or ydata is not None:
            raise ValueError('Cannot specify xydata with xdata or ydata')
        if not isinstance(xydata, dict):
            raise ValueError('Only supports xydata as Dict at the moment')
        xdata = ub.odict((k, np.array(xy[0])) for k, xy in xydata.items())
        ydata = ub.odict((k, np.array(xy[1])) for k, xy in xydata.items())

    if 'label' in kwargs and 'label_list' in kwargs:
        raise ValueError('Specify either label or label_list')

    if isinstance(ydata, dict):
        # Case where ydata is a dictionary
        if isinstance(xdata, six.string_types):
            # Special-er case where xdata is specified in ydata.
            # Keep ykeys as an ordered list: a set would make the line order
            # (and hence labels / colors) nondeterministic.
            xkey = xdata
            ykeys = [k for k in ydata.keys() if k != xkey]
            xdata = ydata[xkey]
        else:
            ykeys = list(ydata.keys())
        # Normalize input into ydata_list
        ydata_list = list(ub.take(ydata, ykeys))
        # Dictionary keys double as the default labels
        default_label_list = kwargs.pop('label', ykeys)
        kwargs['label_list'] = kwargs.get('label_list', default_label_list)
    else:
        # ydata should be a List[ndarray] or an ndarray
        ydata_list = ydata
        ykeys = None

    # allow ydata_list to be passed without a container
    if is_list_of_scalars(ydata_list):
        ydata_list = [np.array(ydata_list)]

    if xdata is None:
        # Default x-coordinates index into the longest line
        # (the ``or [0]`` keeps an empty ydata from crashing ``max``)
        xdata = list(range(max([len(yd) for yd in ydata_list] or [0])))

    num_lines = len(ydata_list)

    # Transform xdata into xdata_list
    if isinstance(xdata, dict):
        xdata_list = [np.array(xdata[k], copy=True) for k in ykeys]
    elif is_list_of_lists(xdata):
        xdata_list = [np.array(xd, copy=True) for xd in xdata]
    else:
        xdata_list = [np.array(xdata, copy=True)] * num_lines

    fnum = mpl_core.ensure_fnum(kwargs.get('fnum', None))
    pnum = kwargs.get('pnum', None)
    kind = kwargs.get('kind', 'plot')
    transpose = kwargs.get('transpose', False)

    def parsekw_list(key, kwargs, num_lines=num_lines, ykeys=ykeys,
                     default=ub.NoParam):
        """
        Return properties that corresponds with ydata_list.

        Searches kwargs for several keys based on the base key and finds either
        a scalar, list, or dict and coerces this into a list of properties that
        corresponds with the ydata_list. Returns None if the key is absent.
        """
        if key in kwargs:
            val_list = kwargs[key]
        elif key + '_list' in kwargs:
            val_list = kwargs[key + '_list']
        elif key + 's' in kwargs:
            # hack, multiple ways to do something
            warnings.warn('*s depricated, just use kwarg {}'.format(key))
            val_list = kwargs[key + 's']
        else:
            val_list = None

        if val_list is not None:
            if isinstance(val_list, dict):
                # Extract properly ordered dictionary values
                if ykeys is None:
                    raise ValueError(
                        'Kwarg {!r} was a dict, but ydata was not'.format(key))
                if default is ub.NoParam:
                    val_list = [val_list[k] for k in ykeys]
                else:
                    val_list = [val_list.get(k, default) for k in ykeys]
            if not isinstance(val_list, list):
                # Coerce a scalar value into a list
                val_list = [val_list] * num_lines
        return val_list

    # ---------------------------------------------------------------
    # Resolve special marker / linestyle / color codes
    # ---------------------------------------------------------------
    if kind == 'plot':
        if 'marker' not in kwargs:
            kwargs['marker'] = 'distinct'

        if isinstance(kwargs['marker'], six.string_types):
            if kwargs['marker'] == 'distinct':
                kwargs['marker'] = mpl_core.distinct_markers(num_lines)
            elif kwargs['marker'] == 'cycle':
                # Note the length of marker and linestyle cycles should be
                # relatively prime.
                # https://matplotlib.org/api/markers_api.html
                marker_cycle = ['.', '*', 'x']
                kwargs['marker'] = [marker_cycle[i % len(marker_cycle)]
                                    for i in range(num_lines)]

        if 'linestyle' not in kwargs:
            kwargs['linestyle'] = mplrc['lines.linestyle']

        if isinstance(kwargs['linestyle'], six.string_types):
            if kwargs['linestyle'] == 'cycle':
                # https://matplotlib.org/gallery/lines_bars_and_markers/line_styles_reference.html
                linestyle_cycle = ['solid', 'dashed', 'dashdot', 'dotted']
                kwargs['linestyle'] = [linestyle_cycle[i % len(linestyle_cycle)]
                                       for i in range(num_lines)]

    if 'color' not in kwargs:
        kwargs['color'] = 'distinct'

    if isinstance(kwargs['color'], six.string_types):
        if kwargs['color'] == 'distinct':
            kwargs['color'] = mpl_core.distinct_colors(num_lines, randomize=0)
        else:
            try:
                cm = plt.get_cmap(kwargs['color'])
            except (ValueError, KeyError):
                # Not a colormap name: treat the string as one plain color
                # (e.g. 'red') shared by every line.
                pass
            else:
                kwargs['color'] = [cm(i / num_lines) for i in range(num_lines)]

    # ---------------------------------------------------------------
    # Parse out per-line arguments
    # ---------------------------------------------------------------
    # Parse out arguments to ax.plot
    plot_kw_keys = ['label', 'color', 'marker', 'markersize',
                    'markeredgewidth', 'linewidth', 'linestyle', 'alpha']
    # hackish / extra args that dont directly get passed to plt.plot
    extra_plot_kw_keys = ['spread_alpha', 'autolabel', 'edgecolor', 'fill']
    plot_kw_keys += extra_plot_kw_keys

    plot_list_kw = {}
    for key in plot_kw_keys:
        vals = parsekw_list(key, kwargs)
        if vals is not None:
            plot_list_kw[key] = vals

    if kind == 'plot':
        if 'spread_alpha' not in plot_list_kw:
            plot_list_kw['spread_alpha'] = [.2] * num_lines

    if kind == 'bar':
        # Remove non-bar kwargs
        for key in ['markeredgewidth', 'linewidth', 'marker', 'markersize', 'linestyle']:
            plot_list_kw.pop(key, None)

        stacked = kwargs.get('stacked', False)
        # barh measures bar thickness with `height` instead of `width`
        width_key = 'height' if transpose else 'width'
        if 'width_list' in kwargs:
            plot_list_kw[width_key] = kwargs['width_list']
        else:
            width = kwargs.get('width', .9)
            if not stacked:
                # Side-by-side bars share the slot of a single bar
                width = width / max(num_lines, 1)
            plot_list_kw[width_key] = [width] * num_lines

    spread_list = parsekw_list('spread', kwargs, default=None)

    # nest into a list of dicts for each line in the multiplot
    valid_keys = [k for k in plot_list_kw if k not in extra_plot_kw_keys]
    valid_vals = [plot_list_kw[k] for k in valid_keys]
    plot_kw_list = [dict(zip(valid_keys, vals)) for vals in zip(*valid_vals)]

    extra_kw_keys = [k for k in extra_plot_kw_keys if k in plot_list_kw]
    extra_kw_vals = [plot_list_kw[k] for k in extra_kw_keys]
    extra_kw_list = [dict(zip(extra_kw_keys, vals)) for vals in zip(*extra_kw_vals)]

    # Get passed in axes or setup a new figure
    ax = kwargs.get('ax', None)
    if ax is None:
        # NOTE: This is slow, can we speed it up?
        doclf = kwargs.get('doclf', False)
        fig = mpl_core.figure(fnum=fnum, pnum=pnum, docla=False, doclf=doclf)
        ax = fig.gca()
    else:
        plt.sca(ax)

    # +---------------
    # Draw plot lines
    ydata_list = [np.array(yd) for yd in ydata_list]

    if transpose:
        if kind == 'bar':
            plot_func = ax.barh
        else:
            _draw = getattr(ax, kind)  # usually ax.plot

            def plot_func(_x, _y, **kw):
                return _draw(_y, _x, **kw)
        # Bands must also be drawn along the swapped axis
        fill_func = ax.fill_betweenx
    else:
        plot_func = getattr(ax, kind)  # usually ax.plot
        fill_func = ax.fill_between

    # Coordinates of the last drawn line; used below as a reference for tick
    # dtype and tick positions.
    last_xdata = np.array([])
    last_ydata = np.array([])

    if len(ydata_list) > 0:
        _iter = enumerate(zip_longest(xdata_list, ydata_list, plot_kw_list, extra_kw_list))
        for count, (_xdata, _ydata, plot_kw, extra_kw) in _iter:
            if extra_kw is None:
                extra_kw = {}

            # Truncate both coordinates to a common length
            _ydata = _ydata[0:len(_xdata)]
            _xdata = _xdata[0:len(_ydata)]

            # Only finite y-values are drawn
            ymask = np.isfinite(_ydata)
            ydata_ = _ydata.compress(ymask)
            xdata_ = _xdata.compress(ymask)
            # Un-shifted finite positions (bars may be offset below); fills
            # and spreads must use these so they line up with ydata_.
            xfinite = xdata_

            if kind == 'bar' and not stacked:
                # Plot bars side by side. Read the width from this line's
                # own kwargs so that `width_list` is honored as well.
                line_width = plot_kw[width_key]
                baseoffset = (line_width * num_lines) / 2
                lineoffset = (line_width * count)
                offset = baseoffset - lineoffset  # Fixeme for more histogram bars
                xdata_ = xdata_ - offset

            objs = plot_func(xdata_, ydata_, **plot_kw)

            if kind == 'bar':
                if 'edgecolor' in extra_kw:
                    for rect in objs:
                        rect.set_edgecolor(extra_kw['edgecolor'])
                if extra_kw.get('autolabel', False):
                    # FIXME: probably a more canonical way to include bar
                    # autolabeling with transpose support, but this is a hack
                    # that works for now
                    for rect in objs:
                        if transpose:
                            numlbl = rect_width = rect.get_width()
                            xpos = rect_width + ((_xdata.max() - _xdata.min()) * .005)
                            ypos = rect.get_y() + rect.get_height() / 2.
                            ha, va = 'left', 'center'
                        else:
                            numlbl = rect_height = rect.get_height()
                            xpos = rect.get_x() + rect.get_width() / 2.
                            ypos = 1.05 * rect_height
                            ha, va = 'center', 'bottom'
                        barlbl = '%.3f' % (numlbl,)
                        ax.text(xpos, ypos, barlbl, ha=ha, va=va)

            if kind == 'plot' and extra_kw.get('fill', False):
                fill_func(xfinite, ydata_, alpha=plot_kw.get('alpha', 1.0),
                          color=plot_kw.get('color', None))

            if spread_list is not None:
                # Plots a spread around plot lines usually indicating standard
                # deviation
                _spread = spread_list[count]
                if _spread is not None:
                    ydata_ave = np.array(ydata_)
                    if ub.iterable(_spread):
                        y_data_dev = np.array(_spread)[0:len(ymask)]
                        if len(y_data_dev) == len(ymask):
                            # Drop the entries that belong to non-finite
                            # y-values so the band matches the drawn points
                            y_data_dev = y_data_dev.compress(ymask)
                    else:
                        y_data_dev = np.array([_spread] * len(ydata_))
                    y_data_max = ydata_ave + y_data_dev
                    y_data_min = ydata_ave - y_data_dev
                    spread_alpha = extra_kw.get('spread_alpha', .2)
                    fill_func(xfinite, y_data_min, y_data_max, alpha=spread_alpha,
                              color=plot_kw.get('color', None))

            last_xdata = np.array(_xdata)
            last_ydata = _ydata

    ydata = last_ydata  # HACK
    xdata = last_xdata  # HACK
    # L________________

    if transpose:
        # From here on "x" and "y" refer to the axes as displayed
        xdata, ydata = ydata, xdata

        # Hack / Fix any transpose issues
        def transpose_key(key):
            swaps = [('x', 'y'), ('y', 'x'), ('num_x', 'num_y'), ('num_y', 'num_x')]
            for old, new in swaps:
                if key.startswith(old):
                    return new + key[len(old):]
            return key
        kwargs = {transpose_key(key): val for key, val in kwargs.items()}

    # Setup axes labeling
    title = kwargs.get('title', None)
    xlabel = kwargs.get('xlabel', '')
    ylabel = kwargs.get('ylabel', '')

    def none_or_unicode(text):
        return None if text is None else ub.ensure_unicode(text)

    xlabel = none_or_unicode(xlabel)
    ylabel = none_or_unicode(ylabel)
    title = none_or_unicode(title)

    titlesize = kwargs.get('titlesize', mplrc['axes.titlesize'])
    labelsize = kwargs.get('labelsize', mplrc['axes.labelsize'])
    legendsize = kwargs.get('legendsize', mplrc['legend.fontsize'])
    xticksize = kwargs.get('ticksize', mplrc['xtick.labelsize'])
    yticksize = kwargs.get('ticksize', mplrc['ytick.labelsize'])
    family = kwargs.get('fontfamily', mplrc['font.family'])

    tickformat = kwargs.get('tickformat', None)
    ytickformat = kwargs.get('ytickformat', tickformat)
    xtickformat = kwargs.get('xtickformat', tickformat)

    # 'DejaVu Sans','Verdana', 'Arial'
    weight = kwargs.get('fontweight', None)
    if weight is None:
        weight = 'normal'

    labelkw = {
        'fontproperties': mpl.font_manager.FontProperties(
            weight=weight,
            family=family, size=labelsize)
    }
    ax.set_xlabel(xlabel, **labelkw)
    ax.set_ylabel(ylabel, **labelkw)

    tick_fontprop = mpl.font_manager.FontProperties(family=family,
                                                    weight=weight)
    # NOTE: This is slow, can we speed it up?
    for ticklabel in ax.get_xticklabels():
        ticklabel.set_fontproperties(tick_fontprop)
    for ticklabel in ax.get_yticklabels():
        ticklabel.set_fontproperties(tick_fontprop)
    if xticksize is not None:
        for ticklabel in ax.get_xticklabels():
            ticklabel.set_fontsize(xticksize)
    if yticksize is not None:
        for ticklabel in ax.get_yticklabels():
            ticklabel.set_fontsize(yticksize)

    if xtickformat is not None:
        # mpl.ticker.StrMethodFormatter  # new style
        # mpl.ticker.FormatStrFormatter  # old style
        ax.xaxis.set_major_formatter(mpl.ticker.FormatStrFormatter(xtickformat))
    if ytickformat is not None:
        ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter(ytickformat))

    # Only forward the tick params that were actually specified
    tick_kw = {
        'width': kwargs.get('tickwidth', None),
        'length': kwargs.get('ticklength', None),
    }
    tick_kw = {k: v for k, v in tick_kw.items() if v is not None}
    ax.xaxis.set_tick_params(**tick_kw)
    ax.yaxis.set_tick_params(**tick_kw)

    # Setup axes limits: expand xlim / ylim into the min / max kwargs
    for axis in 'xy':
        lim = kwargs.get(axis + 'lim', None)
        if lim is not None:
            if (axis + 'min') in kwargs or (axis + 'max') in kwargs:
                raise ValueError(
                    'use {0}max, {0}min instead of {0}lim'.format(axis))
            kwargs[axis + 'min'] = lim[0]
            kwargs[axis + 'max'] = lim[1]

    xmin = kwargs.get('xmin', ax.get_xlim()[0])
    xmax = kwargs.get('xmax', ax.get_xlim()[1])
    ymin = kwargs.get('ymin', ax.get_ylim()[0])
    ymax = kwargs.get('ymax', ax.get_ylim()[1])

    # A limit of 'data' means: use the extent of the data shown on that axis
    text_type = six.text_type
    xaxis_lines = ydata_list if transpose else xdata_list
    yaxis_lines = xdata_list if transpose else ydata_list
    if text_type(xmax) == 'data':
        xmax = max([np.nanmax(d) for d in xaxis_lines])
    if text_type(xmin) == 'data':
        xmin = min([np.nanmin(d) for d in xaxis_lines])
    if text_type(ymax) == 'data':
        ymax = max([np.nanmax(d) for d in yaxis_lines])
    if text_type(ymin) == 'data':
        ymin = min([np.nanmin(d) for d in yaxis_lines])

    # Setup axes ticks
    num_xticks = kwargs.get('num_xticks', None)
    num_yticks = kwargs.get('num_yticks', None)

    if num_xticks is not None:
        if xdata.dtype.kind == 'i':
            # Integer data gets integer tick positions
            xticks = np.linspace(np.ceil(xmin), np.floor(xmax),
                                 num_xticks).astype(np.int32)
        else:
            xticks = np.linspace((xmin), (xmax), num_xticks)
        ax.set_xticks(xticks)

    if num_yticks is not None:
        if ydata.dtype.kind == 'i':
            yticks = np.linspace(np.ceil(ymin), np.floor(ymax),
                                 num_yticks).astype(np.int32)
        else:
            yticks = np.linspace((ymin), (ymax), num_yticks)
        ax.set_yticks(yticks)

    force_xticks = kwargs.get('force_xticks', None)
    if force_xticks is not None:
        xticks = np.array(sorted(ax.get_xticks().tolist() + list(force_xticks)))
        ax.set_xticks(xticks)

    yticklabels = kwargs.get('yticklabels', None)
    if yticklabels is not None:
        # Hack ONLY WORKS WHEN TRANSPOSE = True
        # Overrides num_yticks
        missing_labels = max(len(ydata) - len(yticklabels), 0)
        yticklabels_ = list(yticklabels) + [''] * missing_labels
        ax.set_yticks(ydata)
        ax.set_yticklabels(yticklabels_)

    xticklabels = kwargs.get('xticklabels', None)
    if xticklabels is not None:
        # Overrides num_xticks
        missing_labels = max(len(xdata) - len(xticklabels), 0)
        xticklabels_ = list(xticklabels) + [''] * missing_labels
        ax.set_xticks(xdata)
        ax.set_xticklabels(xticklabels_)

    xticks = kwargs.get('xticks', None)
    if xticks is not None:
        ax.set_xticks(xticks)

    yticks = kwargs.get('yticks', None)
    if yticks is not None:
        ax.set_yticks(yticks)

    xtick_rotation = kwargs.get('xtick_rotation', None)
    if xtick_rotation is not None:
        for lbl in ax.get_xticklabels():
            lbl.set_rotation(xtick_rotation)

    ytick_rotation = kwargs.get('ytick_rotation', None)
    if ytick_rotation is not None:
        for lbl in ax.get_yticklabels():
            lbl.set_rotation(ytick_rotation)

    # Axis padding
    xpad = kwargs.get('xpad', None)
    ypad = kwargs.get('ypad', None)
    xpad_factor = kwargs.get('xpad_factor', None)
    ypad_factor = kwargs.get('ypad_factor', None)
    if xpad is None and xpad_factor is not None:
        xpad = (xmax - xmin) * xpad_factor
    if ypad is None and ypad_factor is not None:
        ypad = (ymax - ymin) * ypad_factor
    xpad = 0 if xpad is None else xpad
    ypad = 0 if ypad is None else ypad
    ypad_high = kwargs.get('ypad_high', ypad)
    ypad_low = kwargs.get('ypad_low', ypad)
    xpad_high = kwargs.get('xpad_high', xpad)
    xpad_low = kwargs.get('xpad_low', xpad)
    xmin, xmax = (xmin - xpad_low), (xmax + xpad_high)
    ymin, ymax = (ymin - ypad_low), (ymax + ypad_high)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    xscale = kwargs.get('xscale', None)
    yscale = kwargs.get('yscale', None)
    if yscale is not None:
        ax.set_yscale(yscale)
    if xscale is not None:
        ax.set_xscale(xscale)

    gridlinestyle = kwargs.get('gridlinestyle', None)
    gridlinewidth = kwargs.get('gridlinewidth', None)
    gridlines = ax.get_xgridlines() + ax.get_ygridlines()
    if gridlinestyle:
        for line in gridlines:
            line.set_linestyle(gridlinestyle)
    if gridlinewidth:
        for line in gridlines:
            line.set_linewidth(gridlinewidth)

    # Setup title
    if title is not None:
        titlekw = {
            'fontproperties': mpl.font_manager.FontProperties(
                family=family,
                weight=weight,
                size=titlesize)
        }
        ax.set_title(title, **titlekw)

    # A legend is drawn by default only when labels are available
    use_legend = kwargs.get('use_legend', 'label' in valid_keys)
    legend_loc = kwargs.get('legend_loc', mplrc['legend.loc'])
    legend_alpha = kwargs.get('legend_alpha', mplrc['legend.framealpha'])
    if use_legend:
        legendkw = {
            'alpha': legend_alpha,
            'fontproperties': mpl.font_manager.FontProperties(
                family=family,
                weight=weight,
                size=legendsize)
        }
        mpl_core.legend(loc=legend_loc, ax=ax, **legendkw)

    figtitle = kwargs.get('figtitle', None)
    if figtitle is not None:
        # mplrc['figure.titlesize'] TODO?
        mpl_core.set_figtitle(figtitle, fontfamily=family, fontweight=weight,
                              size=kwargs.get('figtitlesize'))

    # TODO: return better info
    return ax
def is_listlike(data):
    """
    Check whether ``data`` is one of the supported sequence containers.

    Accepts ``list``, ``tuple``, ``numpy.ndarray`` and, when pandas can be
    imported, ``pandas.Series``. The object must additionally expose both
    ``__getitem__`` and ``__len__``.

    Args:
        data (object): the object to inspect

    Returns:
        bool: True if ``data`` is a supported list-like container
    """
    accepted_types = (list, np.ndarray, tuple)
    try:
        import pandas as pd
        accepted_types = accepted_types + (pd.Series,)
    except Exception:
        # pandas is optional; fall back to the builtin / numpy containers
        pass

    if not isinstance(data, accepted_types):
        return False
    return hasattr(data, '__getitem__') and hasattr(data, '__len__')
def is_list_of_scalars(data):
    """
    Check whether ``data`` is a non-empty list-like whose first item is not
    itself list-like.

    Only the first item is inspected. It is fetched by iteration rather than
    with ``data[0]`` because ``data[0]`` on a ``pandas.Series`` is a label
    lookup, which raises ``KeyError`` when the index has no label ``0``.

    Args:
        data (object): the object to inspect

    Returns:
        bool: True if ``data`` is list-like (see ``is_listlike``), non-empty,
            and its first item is not list-like.
    """
    if not is_listlike(data) or len(data) == 0:
        return False
    first_item = next(iter(data))
    return not is_listlike(first_item)
def is_list_of_lists(data):
    """
    Check whether ``data`` is a non-empty list-like whose first item is
    itself list-like.

    Only the first item is inspected. It is fetched by iteration rather than
    with ``data[0]`` because ``data[0]`` on a ``pandas.Series`` is a label
    lookup, which raises ``KeyError`` when the index has no label ``0``.

    Args:
        data (object): the object to inspect

    Returns:
        bool: True if ``data`` is list-like (see ``is_listlike``), non-empty,
            and its first item is also list-like.
    """
    if not is_listlike(data) or len(data) == 0:
        return False
    first_item = next(iter(data))
    return is_listlike(first_item)