Plotly continuous error band lineplot

For some time I have been using plotly_express.line in order to generate dynamic plots. I personally think that this is a great tool that has an API very similar to the one from seaborn. Unfortunately, I missed the possibility to plot continuous error bands with plotly in the same manner that seaborn does when multiple samples are present for each category. For this purpose I created a modules that does this computation in the same manner that seaborn does. This module actually takes parts from the core of seaborn to compute the standard deviation and the confidence interval. Let’s get into the code to see how it’s performance.

Code implementation in python

In order to use this module we need to import ploty_express, numpy, seaborn, itertools, plotly and pandas

import plotly_express as px
import seaborn as sns
import pandas as pd
import numpy as np
from seaborn.relational import bootstrap
from seaborn.utils import ci as ci_func
import matplotlib.pyplot as plt
import itertools
import plotly.graph_objs as go


def line(data_frame: pd.DataFrame,
         x="timepoint",
         y="signal",
         ci=95,
         estimator="mean",
         n_boot=1000,
         sort=True,
         color="region",
         error_y=None,
         error_y_minus=None,
         line_dash="event",
         **kwargs):
    cat_vars = {"color": color, "line_dash": line_dash}
    cat_vars = {key: cat_vars[key] for key in cat_vars if cat_vars[key] is not None}
    variables = dict(x=x, y=y)
    variables.update(cat_vars)
    data = pd.DataFrame()
    data['x'] = data_frame.get(variables.get('x'))
    data['y'] = data_frame.get(variables.get('y'))
    for key in cat_vars:
        data[key] = data_frame[cat_vars[key]]
    if error_y is not None:
        data["error_y"] = data_frame[error_y]
    if error_y_minus is not None:
        data["error_y_minus"] = data_frame[error_y_minus]
    grouping_vars = list(cat_vars.keys())
    x_l = []
    y_l = []
    low_l = []
    high_l = []
    color_l = []
    dash_l = []
    for sub_vars, sub_data in iter_data(data=data, variables=variables, grouping_vars=grouping_vars):
        if sort:
            sort_vars = ["units", "x", "y"]
            sort_cols = [var for var in sort_vars if var in variables]
            sub_data = sub_data.sort_values(sort_cols)
        sub_data = sub_data.dropna()
        null = pd.Series(index=data.index, dtype=float)
        x_v = sub_data.get("x", null)
        y_v = sub_data.get("y", null)
        if error_y is None:
            x_v, y_v, y_ci = aggregate(vals=y_v, grouper=x_v, estimator=estimator, ci=ci, n_boot=n_boot, sort=sort)
            low_l += list(y_ci["low"].values)
            high_l += list(y_ci["high"].values)
        elif error_y_minus is None:
            low_l += list(y_v + sub_data.get("error_y").values)
            high_l += list(y_v - sub_data.get("error_y").values)
        else:
            low_l += list(y_v + sub_data.get("error_y_minus").values)
            high_l += list(y_v - sub_data.get("error_y").values)
        x_l += list(x_v)
        y_l += list(y_v)
        if sub_vars.get("color"):
            color_l += [sub_vars.get("color") for _ in x_v]
        if sub_vars.get("line_dash"):
            dash_l += [sub_vars.get("line_dash") for _ in x_v]

    agg = {variables.get('x'): x_l, variables.get('y'): y_l, "low": low_l, "high": high_l}
    if color_l:
        agg[variables.get("color")] = color_l
    if dash_l:
        agg[variables.get("line_dash")] = dash_l
    agg = pd.DataFrame(agg)
    fig = px.line(data_frame=agg, x=x, y=y, color=color, line_dash=line_dash, **kwargs)
    traces = fig.data
    for tr in traces:
        name = tr.name
        categories = name.replace(' ', '').split(',')
        tmp_df = get_subdata(agg, categories)
        y_high = list(tmp_df.high.values)
        y_low = list(tmp_df.low.values)
        band = go.Scatter(
            name=name,
            legendgroup=tr.legendgroup,
            x=list(tr.x) + list(tr.x[::-1]),  # x, then x reversed
            y=y_high + y_low[::-1],  # upper, then lower reversed
            fill='toself',
            fillcolor=hex_to_rgba(tr.line.color),
            line=dict(color=hex_to_rgba(tr.line.color)),
            hoverinfo="skip",
            showlegend=False
        )
        fig.add_trace(band)
    return fig

Now we need to define a method that iterates over the categorical classes defined by color and line_dash.

def iter_data(data, variables, grouping_vars=None, reverse=False):
    if grouping_vars is None:
        grouping_vars = []
    elif isinstance(grouping_vars, str):
        grouping_vars = [grouping_vars]
    elif isinstance(grouping_vars, tuple):
        grouping_vars = list(grouping_vars)

    # Always insert faceting variables
    facet_vars = {"col", "row"}
    grouping_vars.extend(
        facet_vars & set(variables) - set(grouping_vars)
    )

    # Reduce to the semantics used in this plot
    grouping_vars = [
        var for var in grouping_vars if var in variables
    ]

    if grouping_vars:

        grouped_data = data.groupby(
            grouping_vars, sort=False, as_index=False
        )

        grouping_keys = []
        for var in grouping_vars:
            grouping_keys.append(var_levels(data, var).get(var, []))

        iter_keys = itertools.product(*grouping_keys)
        if reverse:
            iter_keys = reversed(list(iter_keys))

        for key in iter_keys:

            # Pandas fails with singleton tuple inputs
            pd_key = key[0] if len(key) == 1 else key

            try:
                data_subset = grouped_data.get_group(pd_key)
            except KeyError:
                continue

            sub_vars = dict(zip(grouping_vars, key))

            yield sub_vars, data_subset

    else:

        yield {}, data

def var_levels(data, variables):
    _var_levels = dict()
    if isinstance(variables, str):
        variables = [variables]
    for var in variables:
        try:
            _var_levels[var] = list(data.get(var).unique())
            _var_levels[var].sort()
        except AttributeError:
            pass
    return _var_levels

Once we have created the data groups for each category, we can aggregate them to compute the standard deviation or the confidence interval with bootstrap. This will define the error band to be plotted.

def aggregate(vals, grouper, estimator="mean", ci="sd", n_boot=1000, seed=None, sort=True):
    """Compute an estimate and confidence interval using grouper."""
    func = estimator

    # Define a "null" CI for when we only have one value
    null_ci = pd.Series(index=["low", "high"], dtype=float)

    # Function to bootstrap in the context of a pandas group by
    def bootstrapped_cis(vals):

        if len(vals) <= 1:
            return null_ci

        boots = bootstrap(vals, func=func, n_boot=n_boot, seed=seed)
        cis = ci_func(boots, ci)
        return pd.Series(cis, ["low", "high"])

    # Group and get the aggregation estimate
    grouped = vals.groupby(grouper, sort=sort)
    est = grouped.agg(func)

    # Exit early if we don't want a confidence interval
    if ci is None:
        return est.index, est, None

    # Compute the error bar extents
    if ci == "sd":
        sd = grouped.std()
        cis = pd.DataFrame(np.c_[est - sd, est + sd],
                           index=est.index,
                           columns=["low", "high"]).stack()
    else:
        cis = grouped.apply(bootstrapped_cis)

    # Unpack the CIs into "wide" format for plotting
    if cis.notnull().any():
        cis = cis.unstack().reindex(est.index)
    else:
        cis = None

    return est.index, est, cis

At the end of our line method we iterate over the data in the generated plotly figure in order to get the trace names, and the line color, such that the error bans have the same color but with a different transparency. To transform the colors we convert the hex color-code defined by plotly to RGBA. First we get the data corresponding to each trace, and then we compute the color.

def get_subdata(df: pd.DataFrame, categories: list):
    if isinstance(categories, str):
        categories = [categories]
    for cat in categories:
        col = df.columns[df.isin([cat]).any()]
        if len(col) != 1:
            raise ValueError(f"Can not infer column name from categorical value {cat}, Two columns should not have "
                             f"the same element in both columns for column name inference to succeed")
        else:
            col = col[0]
        tmp_df = df[df[col] == cat]
        categories.remove(cat)
        if not categories:
            return tmp_df
        else:
            return get_subdata(tmp_df, categories)
def hex_to_rgba(color: str, transparency=0.2):
    hex = color.lstrip('#')
    hlen = len(hex)
    rgba = tuple(int(hex[i:i+hlen//3], 16) for i in range(0, hlen, hlen//3)) + (transparency,)
    return f'rgba{str(rgba)}'

This addition of new traces to the created figure is not really elegant, but it is the only solution I could find without having to modify the core of plotly-express.

Let’s see now how the plot might look with an example.

fmri = sns.load_dataset('fmri')
figure = line(data_frame=fmri, x="timepoint", y="signal", color="region", line_dash="event")
figure.write_html("plot.html")
figure = sns.lineplot(data=fmri, x="timepoint", y="signal", hue="region", style="event")
plt.savefig("plot.png")

We can see now that the results are very similar.

In addition, if you have already computed the errors and are included in the dataframe as columns, you can pass this parameter in order to override the aggregation functionality. Let’s see this in the same example as before, but now we aggregate the results manually before passing it to the plotting function. You can also pass error_y_minus to the plot function in case the error band should be asymmetric.

fmri = sns.load_dataset('fmri')
fmri = fmri.groupby(["region", "event", "timepoint"], as_index=False).agg(["mean", "std"]).reset_index()
fmri.columns = ["region", "event", "timepoint", "signal", "std"]

figure = line(data_frame=fmri, x="timepoint", y="signal", color="region", line_dash="event", error_y="std")
figure.write_html("plot_agg.html")
figure = sns.lineplot(data=fmri, x="timepoint", y="signal", hue="region", style="event", ci="std")
plt.savefig("plot_agg.png")

Finally, we can profile the performance of each method to see how much time it takes to plot the results using each implementation. You can see that our custom implementation (with error bands) takes nearly 4 times longer than plotly (without error bands) and around 3 times longer than seaborn. This is due to the fact that we are iterating over the already created figure and updating the traces of the figure in order to generate the error bands. The plotting time is still very small and not really noticeable at this scale. The plotting time would be definitely reduced if such implementation would be implemented at the core level of plotly-express instead of having to update the figure.

from timeit import timeit
print(timeit("custom lineplot (sec):", """line(data_frame=fmri, x="timepoint", y="signal", color="region", line_dash="event", ci="sd")""", globals=globals(), number=1))
print(timeit("plotly lineplot (sec):", """px.line(data_frame=fmri, x="timepoint", y="signal", color="region", line_dash="event")""", globals=globals(), number=1))
print(timeit("seaborn lineplot (sec):", """sns.lineplot(data=fmri, x="timepoint", y="signal", hue="region", style="event", ci="sd")""", globals=globals(), number=1))
custom lineplot (sec): 0.18508239000038884
plotly lineplot (sec): 0.03941967600076168
seaborn lineplot (sec): 0.06663737300004868
Written by

Leonardo Ayala

I am physicist, turned chemist, turned teacher, turned data scientist, ... and I like to write, mainly about science, but also about many random facts that I find interesting.