[seaborn] 04/14: Refactor common setup for boxplot and violinplot

Andreas Tille tille at debian.org
Fri Jan 20 15:00:38 UTC 2017


This is an automated email from the git hooks/post-receive script.

tille pushed a commit to tag v0.2.1
in repository seaborn.

commit 864b5765bea812dcd02cb0103732ac8d2fa2936c
Author: mwaskom <mwaskom at stanford.edu>
Date:   Wed Dec 25 19:56:59 2013 -0800

    Refactor common setup for boxplot and violinplot
---
 seaborn/distributions.py | 275 +++++++++++++++++++++++------------------------
 1 file changed, 132 insertions(+), 143 deletions(-)

diff --git a/seaborn/distributions.py b/seaborn/distributions.py
index 0504af8..159d409 100644
--- a/seaborn/distributions.py
+++ b/seaborn/distributions.py
@@ -6,7 +6,6 @@ import pandas as pd
 import statsmodels.api as sm
 import matplotlib as mpl
 import matplotlib.pyplot as plt
-from six import string_types
 from six.moves import range
 import moss
 
@@ -14,6 +13,108 @@ from seaborn.utils import (color_palette, husl_palette, blend_palette,
                            desaturate, _kde_support)
 
 
+def _box_reshape(vals, groupby, names, order):
+    """Reshape the box/violinplot input options and find plot labels."""
+
+    # Set up default label outputs
+    xlabel, ylabel = None, None
+
+    # If order is provided, make sure it was used correctly
+    if order is not None:
+        # Assure that order is the same length as names, if provided
+        if names is not None:
+            if len(order) != len(names):
+                raise ValueError("`order` must have same length as `names`")
+        # Assure that order is only used with the right inputs
+        is_pd = isinstance(vals, pd.Series) or isinstance(vals, pd.DataFrame)
+        if not is_pd:
+            raise ValueError("`vals` must be a Pandas object to use `order`.")
+
+    # Handle case where data is a wide DataFrame
+    if isinstance(vals, pd.DataFrame):
+        if order is not None:
+            vals = vals[order]
+        if names is None:
+            names = vals.columns.tolist()
+        if vals.columns.name is not None:
+            xlabel = vals.columns.name
+        vals = vals.values.T
+
+    # Handle case where data is a long Series and there is a grouping object
+    elif isinstance(vals, pd.Series) and groupby is not None:
+        groups = pd.groupby(vals, groupby).groups
+        order = sorted(groups) if order is None else order
+        if hasattr(groupby, "name"):
+            if groupby.name is not None:
+                xlabel = groupby.name
+        if vals.name is not None:
+            ylabel = vals.name
+        vals = [vals.reindex(groups[name]) for name in order]
+        if names is None:
+            names = order
+
+    else:
+
+        # Handle case where the input data is an array or there was no groupby
+        if hasattr(vals, 'shape'):
+            if len(vals.shape) == 1:
+                if np.isscalar(vals[0]):
+                    vals = [vals]
+                else:
+                    vals = list(vals)
+            elif len(vals.shape) == 2:
+                nr, nc = vals.shape
+                if nr == 1:
+                    vals = [vals]
+                elif nc == 1:
+                    vals = [vals.ravel()]
+                else:
+                    vals = [vals[:, i] for i in range(nc)]
+            else:
+                error = "Input `vals` can have no more than 2 dimensions"
+                raise ValueError(error)
+
+        # This should catch things like flat lists
+        elif np.isscalar(vals[0]):
+            vals = [vals]
+
+        # By default, just use the plot positions as names
+        if names is None:
+            names = list(range(1, len(vals) + 1))
+        elif hasattr(names, "name"):
+            if names.name is not None:
+                xlabel = names.name
+
+    # Now convert vals to a common representation
+    # The plotting functions will work with a list of arrays
+    # The list allows each array to possibly be of a different length
+    vals = [np.asarray(a, np.float) for a in vals]
+
+    return vals, xlabel, ylabel, names
+
+
+def _box_colors(vals, color):
+    """Find colors to use for boxplots or violinplots."""
+    if color is None:
+        colors = husl_palette(len(vals), l=.7)
+    else:
+        try:
+            color = mpl.colors.colorConverter.to_rgb(color)
+            colors = [color for _ in vals]
+        except ValueError:
+                colors = color_palette(color, len(vals))
+
+    # Desaturate a bit because these are patches
+    colors = [mpl.colors.colorConverter.to_rgb(c) for c in colors]
+    colors = [desaturate(c, .7) for c in colors]
+
+    # Determine the gray color for the lines
+    light_vals = [colorsys.rgb_to_hls(*c)[1] for c in colors]
+    l = min(light_vals) * .6
+    gray = (l, l, l)
+
+    return colors, gray
+
 def boxplot(vals, groupby=None, names=None, join_rm=False, order=None,
             color=None, alpha=None, fliersize=3, linewidth=1.5, widths=.8,
             ax=None, **kwargs):
@@ -21,17 +122,18 @@ def boxplot(vals, groupby=None, names=None, join_rm=False, order=None,
 
     Parameters
     ----------
-    vals : DataFrame, sequence of vectors, or Series.
-        Data for plot. DataFrames are assuemd to be "wide" with each column
-        mapping to a box. Other two-dimensional data is assumed to be a
-        sequence where each item is the data that will go into a box. Can
-        also provide one long Series in conjunction with a grouping element
-        as the `groupy` parameter.
+    vals : DataFrame, Series, 2D array, list of vectors, or vector.
+        Data for plot. DataFrames and 2D arrays are assumed to be "wide" with
+        each column mapping to a box. Lists of data are assumed to have one
+        element per box.  Can also provide one long Series in conjunction with
+        a grouping element as the `groupby` parameter to reshape the data into
+        several boxes. Otherwise 1D data will produce a single box.
     groupby : grouping object
         If `vals` is a Series, this is used to group into boxes by calling
         pd.groupby(vals, groupby).
     names : list of strings, optional
-        Names to plot on x axis, otherwise plots numbers.
+        Names to plot on x axis; otherwise plots numbers. This will override
+        names inferred from Pandas inputs.
     order : list of strings, optional
         If vals is a Pandas object with name information, you can control the
         order of the boxes by providing the box names in your preferred order.
@@ -59,60 +161,14 @@ def boxplot(vals, groupby=None, names=None, join_rm=False, order=None,
     if ax is None:
         ax = plt.gca()
 
-    # Handle case where data is a wide DataFrame
-    if isinstance(vals, pd.DataFrame):
-        if vals.columns.name is not None:
-            xlabel = vals.columns.name
-        else:
-            xlabel = None
-        if order is not None:
-            vals = vals[order]
-        if names is None:
-            names = vals.columns
-        vals = vals.values
-        ylabel = None
-
-    # Handle case where data is a long Series and there is a grouping object
-    elif isinstance(vals, pd.Series) and groupby is not None:
-        if names is None:
-            names = np.sort(pd.unique(groupby))
-        order = names if order is None else order
-        if hasattr(groupby, "name"):
-            xlabel = groupby.name
-        ylabel = vals.name
-        groups = pd.groupby(vals, groupby).groups
-        vals = [vals.reindex(groups[name]) for name in order]
-    else:
-        xlabel = None
-        ylabel = None
+    # Reshape and find labels for the plot
+    vals, xlabel, ylabel, names = _box_reshape(vals, groupby, names, order)
 
     # Draw the boxplot using matplotlib
     boxes = ax.boxplot(vals, patch_artist=True, widths=widths, **kwargs)
-    vals = np.atleast_2d(vals).T
-
-    # Sort out the inner box color
-    if color is None:
-        colors = husl_palette(len(vals), l=.7)
-    else:
-        color_is_color = (not isinstance(color, string_types)
-                          and not isinstance(color, tuple))
-        if color_is_color:
-            colors = color
-        else:
-            try:
-                color = mpl.colors.colorConverter.to_rgb(color)
-                colors = [color for _ in vals]
-            except ValueError:
-                colors = color_palette(color, len(vals))
 
-    # Desaturate a bit because these are patches
-    colors = [mpl.colors.colorConverter.to_rgb(c) for c in colors]
-    colors = [desaturate(c, .7) for c in colors]
-
-    # Determine the gray color for the lines
-    light_vals = [colorsys.rgb_to_hls(*c)[1] for c in colors]
-    l = min(light_vals) * .6
-    gray = (l, l, l)
+    # Find plot colors
+    colors, gray = _box_colors(vals, color)
 
     # Set the new aesthetics
     for i, box in enumerate(boxes["boxes"]):
@@ -142,19 +198,17 @@ def boxplot(vals, groupby=None, names=None, join_rm=False, order=None,
 
     # Draw the joined repeated measures
     if join_rm:
-        x, y = np.arange(1, len(vals.T) + 1), vals.T
+        x, y = np.arange(1, len(np.transpose(vals)) + 1), np.transpose(vals)
         if not vertical:
             x, y = y, x
         ax.plot(x, y, color=gray, alpha=2. / 3)
 
     # Label the axes and ticks
-    if not vertical:
+    if vertical:
+        ax.set_xticklabels(names)
+    else:
+        ax.set_yticklabels(names)
         xlabel, ylabel = ylabel, xlabel
-    if names is not None:
-        if vertical:
-            ax.set_xticklabels(names)
-        else:
-            ax.set_yticklabels(names)
     if xlabel is not None:
         ax.set_xlabel(xlabel)
     if ylabel is not None:
@@ -186,12 +240,12 @@ def violinplot(vals, groupby=None, inner="box", color=None, positions=None,
 
     Parameters
     ----------
-    vals : DataFrame, sequence of vectors, or Series.
-        Data for plot. DataFrames are assuemd to be "wide" with each column
-        mapping to a box. Other two-dimensional data is assumed to be a
-        sequence where each item is the data that will go into a box. Can
-        also provide one long Series in conjunction with a grouping element
-        as the `groupy` parameter.
+    vals : DataFrame, Series, 2D array, list of vectors, or vector.
+        Data for plot. DataFrames and 2D arrays are assumed to be "wide" with
+        each column mapping to a violin. Lists of data are assumed to have one
+        element per violin.  Can also provide one long Series in conjunction
+        with a grouping element as the `groupby` parameter to reshape the data
+        into several violins. Otherwise 1D data will produce a single violin.
     groupby : grouping object
         If `vals` is a Series, this is used to group into boxes by calling
         pd.groupby(vals, groupby).
@@ -202,10 +256,12 @@ def violinplot(vals, groupby=None, inner="box", color=None, positions=None,
     positions : number or sequence of numbers
         Position of first violin or positions of each violin.
     names : list of strings, optional
-        Names to plot on x axis; otherwise plots numbers.
+        Names to plot on x axis; otherwise plots numbers. This will override
+        names inferred from Pandas inputs.
     order : list of strings, optional
         If vals is a Pandas object with name information, you can control the
-        order of the plot by providing the violin names in your preferred order.
+        order of the plot by providing the violin names in your preferred
+        order.
     kernel : {'gau' | 'cos' | 'biw' | 'epa' | 'tri' | 'triw' }
         Code for shape of kernel to fit with.
     bw : {'scott' | 'silverman' | scalar}
@@ -237,78 +293,11 @@ def violinplot(vals, groupby=None, inner="box", color=None, positions=None,
     if ax is None:
         ax = plt.gca()
 
-    # Find existing names
-    if isinstance(vals, pd.DataFrame):
-        if vals.columns.name is not None:
-            xlabel = vals.columns.name
-        else:
-            xlabel = None
-        if order is not None:
-            vals = vals[order]
-        if names is None:
-            names = vals.columns
-        vals = vals.values
-        ylabel = None
-
-    # Possibly perform a group-by to get the batches
-    elif isinstance(vals, pd.Series) and groupby is not None:
-        if names is None:
-            names = np.sort(pd.unique(groupby))
-        order = names if order is None else order
-        if hasattr(groupby, "name"):
-            xlabel = groupby.name
-        ylabel = vals.name
-        groups = pd.groupby(vals, groupby).groups
-        vals = [vals.reindex(groups[name]) for name in order]
-    else:
-        xlabel = None
-        ylabel = None
-
-    # Handle the input data (from pyplot.boxplot)
-    if hasattr(vals, 'shape'):
-        if len(vals.shape) == 1:
-            if hasattr(vals[0], 'shape'):
-                vals = list(vals)
-            else:
-                vals = [vals]
-        elif len(vals.shape) == 2:
-            nr, nc = vals.shape
-            if nr == 1:
-                vals = [vals]
-            elif nc == 1:
-                vals = [vals.ravel()]
-            else:
-                vals = [vals[:, i] for i in range(nc)]
-        else:
-            raise ValueError("Input x can have no more than 2 dimensions")
-    if not hasattr(vals[0], '__len__'):
-        vals = [vals]
-
-    vals = [np.asarray(a, float) for a in vals]
+    # Reshape and find labels for the plot
+    vals, xlabel, ylabel, names = _box_reshape(vals, groupby, names, order)
 
     # Sort out the plot colors
-    if color is None:
-        colors = husl_palette(len(vals), l=.7)
-    else:
-        color_is_color = (not isinstance(color, string_types)
-                          and not isinstance(color, tuple))
-        if color_is_color:
-            colors = color
-        else:
-            try:
-                color = mpl.colors.colorConverter.to_rgb(color)
-                colors = [color for _ in vals]
-            except ValueError:
-                colors = color_palette(color, len(vals))
-
-    # Use somewhat desaturated colors because we're drawing patches
-    colors = [mpl.colors.colorConverter.to_rgb(c) for c in colors]
-    colors = [desaturate(c, .7) for c in colors]
-
-    # Find the shade of gray for lines
-    light_vals = [colorsys.rgb_to_hls(*c)[1] for c in colors]
-    l = min(light_vals) * .6
-    gray = (l, l, l)
+    colors, gray = _box_colors(vals, color)
 
     # Initialize the kwarg dict for the inner plot
     if inner_kws is None:

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/seaborn.git



More information about the debian-science-commits mailing list