Python Two-Dimensional Plotting


In this lesson, we will focus on creating scatter plots, one of the most important visualization skills for a data scientist. First, we will introduce the scatter plot, and how to interpret different visual relationships. Second, we will use a scatter plot on the Iris data set to identify the relationships between the different features. Finally, we will explore multiple pairs of features simultaneously by using multiple sub-plots and the Seaborn pairplot function.

Before commencing with a discussion of scatter plots, however, we first apply our standard notebook opening to ensure plots are displayed inline and the necessary modules are imported correctly.


In [1]:
# Set up Notebook

%matplotlib inline

# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Suppress warnings to keep the output uncluttered (note: this silences
# all warnings, not only those raised by Pandas)
import warnings
warnings.filterwarnings("ignore")

# Use default white plot style
sns.set(style="white")

Scatter Plots

Scatter plots are a useful tool to visually explore the relationship between two or more variables (or columns of data) that you may have read from a file. The number of variables used in the plot corresponds to the dimensionality of the plot. For practical purposes, most scatter plots are two-dimensional, so we will focus on two-dimensional scatter plots in this module. In a two-dimensional scatter plot, the two variables are displayed graphically inside a two-dimensional box. The horizontal dimension is typically called the x-axis, while the vertical dimension is called the y-axis. Each point in the data file is placed within this two-dimensional box according to its particular x and y values.

While we use the names x and y for the dimensions in this plot, in practice the x and y values can be any two columns from your data file. Thus, you can use a scatter plot to identify if there is a dependence between any two variables in your data set.

To make a scatter plot in Python, we can either use the plot() method as discussed in the last module, or we can use the scatter() method. Since the scatter() method provides greater flexibility when making scatter plots, we will use it within this module. Throughout this module, we will use the random module within the NumPy library to generate artificial data for plotting purposes. Because no random seed is set, every time you run this IPython notebook, or even just one cell within the notebook, you will get different data and a different plot.
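If repeatable data are preferred, for example when comparing plots between runs, one option is to seed the random number generator. The following is a minimal sketch and is not part of the original notebook: the seed value 42 is arbitrary, and the newer `default_rng` interface is used here in place of the `np.random.uniform` calls that appear in the cells of this lesson.

```python
import numpy as np

# Assumption: a fixed seed is acceptable for the demonstration;
# the value 42 is arbitrary.
rng_a = np.random.default_rng(42)
rng_b = np.random.default_rng(42)

# linspace returns 50 evenly spaced values by default
x = np.linspace(0, 100)

# Two generators built from the same seed produce identical noise,
# so the resulting y values (and any plot of them) are repeatable.
y_a = x + rng_a.uniform(-10, 10, 50)
y_b = x + rng_b.uniform(-10, 10, 50)

print(np.array_equal(y_a, y_b))
```

Leaving the seed out restores the behavior described above, where each run yields different data.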

Positive Correlation

In the first plot, we show a scatter plot where the vertical (y-axis) values tend to increase as the horizontal (x-axis) values increase. This type of dependence is known as a positive correlation.

In [2]:
# Now we create our figure and axes for the plot we will make.
fig, ax = plt.subplots()

# Now we generate something to plot. In this case, we will 
# need data that are randomly sampled from a particular function.

x = np.linspace(0,100)
y = x + np.random.uniform(-10, 10, 50)

ax.scatter(x, y) 

# Set our axis labels
ax.set_xlabel("X Axis", fontsize=14)
ax.set_ylabel("Y Axis", fontsize=14)

# Change the axis limits displayed in our plot
ax.set_xlim(-20, 120)
ax.set_ylim(-20, 120)

# Change the ticks on each axis and the corresponding 
# numerical values that are displayed
ax.set_xticks(np.arange(0, 120, 20))
ax.set_yticks(np.arange(0, 120, 20))
    
ax.set_title("A positive correlation scatter plot!", fontsize=14)

# Clean up final result
sns.despine(offset=2, trim=True)

Negative Correlation

In the second plot, we see the opposite effect, where the vertical values tend to decrease as the horizontal values increase. This type of dependence is known as a negative correlation.


In [3]:
# Now we create our figure and axes for the plot we will make.
fig, ax = plt.subplots()

# Now we generate something to plot. In this case, we will 
# need data that are randomly sampled from a particular function.

x = np.linspace(0,100)
y = 100 - x + np.random.uniform(-10, 10, 50)

ax.scatter(x, y)

# Set our axis labels
ax.set_xlabel("X Axis", fontsize=14)
ax.set_ylabel("Y Axis", fontsize=14)

# Change the axis limits displayed in our plot
ax.set_xlim(-20, 120)
ax.set_ylim(-20, 120)

# Change the ticks on each axis and the corresponding 
# numerical values that are displayed
ax.set_xticks(np.arange(0, 120, 20))
ax.set_yticks(np.arange(0, 120, 20))
    
ax.set_title("A negative correlation scatter plot!", fontsize=14)

sns.despine(offset=2, trim=True)

Null Correlation

In many cases, a scatter plot shows no obvious trend between the two variables being plotted. In this case, we have a null (or no) correlation.


In [4]:
# Now we create our figure and axes for the plot we will make.
fig, ax = plt.subplots()

# Now we generate something to plot. In this case, we will 
# need data that are randomly sampled from a particular function.

x = np.random.uniform(0, 100, 50)
y = np.random.uniform(0, 100, 50)

ax.scatter(x, y)

# Set our axis labels
ax.set_xlabel("X Axis", fontsize=14)
ax.set_ylabel("Y Axis", fontsize=14)

# Change the axis limits displayed in our plot
ax.set_xlim(-20, 120)
ax.set_ylim(-20, 120)

# Change the ticks on each axis and the corresponding 
# numerical values that are displayed
ax.set_xticks(np.arange(0, 120, 20))
ax.set_yticks(np.arange(0, 120, 20))
    
ax.set_title("A null correlation scatter plot!", fontsize=14)

sns.despine(offset=2, trim=True)
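The three patterns shown so far (positive, negative, and null correlation) can also be summarized with a single number. As a rough sketch, the code below regenerates the same kinds of data and computes the Pearson correlation coefficient with NumPy's `corrcoef` function. The helper name `pearson_r` and the seed are illustrative choices, not part of the original notebook.

```python
import numpy as np

# Assumption: seeded only so the numbers are repeatable
rng = np.random.default_rng(0)
x = np.linspace(0, 100)

# Same recipes as the three plots above
y_pos = x + rng.uniform(-10, 10, 50)
y_neg = 100 - x + rng.uniform(-10, 10, 50)
x_null = rng.uniform(0, 100, 50)
y_null = rng.uniform(0, 100, 50)

def pearson_r(a, b):
    """Return the Pearson correlation coefficient of two 1-D arrays."""
    # corrcoef returns a 2x2 matrix; the off-diagonal entry is r
    return np.corrcoef(a, b)[0, 1]

r_pos = pearson_r(x, y_pos)
r_neg = pearson_r(x, y_neg)
r_null = pearson_r(x_null, y_null)

print(f"positive: {r_pos:.2f}, negative: {r_neg:.2f}, null: {r_null:.2f}")
```

Values near +1 or -1 indicate a strong positive or negative linear trend, while values near 0 indicate no linear trend. The coefficient complements the plot rather than replacing it, since it measures only linear relationships.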

Outlier Detection

One final benefit of making a scatter plot is that it can be easy to identify outliers, which are points that differ significantly from the typical trend shown by the majority of the points in the plot. For example, in the following plot, there are two points with low values of the x variable that have abnormally large y values, at least compared to the rest of the data points. In some cases, these points will indicate an error in data collection, while in other cases they may simply reflect a lack of knowledge about a certain part of a problem.


In [5]:
# Now we create our figure and axes for the plot we will make.
fig, ax = plt.subplots()

# First we create 50 linearly spaced x values
x = np.linspace(0,100)

# Second, we create 50 y values that are linearly related to the x values
y = x + np.random.uniform(-10, 10, 50)
 
# Now we change two points to be outliers
y[2] = 60
y[6] = 75

ax.scatter(x, y)

# Set our axis labels
ax.set_xlabel("X Axis", fontsize=14)
ax.set_ylabel("Y Axis", fontsize=14)

# Change the axis limits displayed in our plot
ax.set_xlim(-20, 120)
ax.set_ylim(-20, 120)

# Change the ticks on each axis and the corresponding 
# numerical values that are displayed
ax.set_xticks(np.arange(0, 120, 20))
ax.set_yticks(np.arange(0, 120, 20))
    
ax.set_title("An outlier detection scatter plot!", fontsize=14)

sns.despine(offset=2, trim=True)
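Outliers that are visible by eye can often be flagged programmatically as well. The sketch below is one possible approach, not the method used in this notebook: it assumes the underlying trend is known to be y = x, measures each point's residual from that line, and flags points whose modified z-score (based on the median absolute deviation) exceeds 3.5, a common rule of thumb.

```python
import numpy as np

# Assumption: seeded only so the example is repeatable
rng = np.random.default_rng(1)

# Same recipe as the cell above, including the two planted outliers
x = np.linspace(0, 100)
y = x + rng.uniform(-10, 10, 50)
y[2] = 60
y[6] = 75

# Assumption: the true trend is y = x, so the residual is y - x
residuals = y - x

# Median and median absolute deviation (MAD) are robust to a few
# extreme values, unlike the mean and standard deviation.
center = np.median(residuals)
mad = np.median(np.abs(residuals - center))

# Modified z-score; 0.6745 rescales the MAD to be comparable to a
# standard deviation for normally distributed data.
modified_z = 0.6745 * (residuals - center) / mad

outlier_idx = np.flatnonzero(np.abs(modified_z) > 3.5)
print(outlier_idx)
```

With real data the trend is usually not known in advance, so the residuals would instead be measured against a fitted model.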

The ability to see a trend or spot outlier points visually makes the scatter plot an important tool for a data scientist. In a subsequent notebook, we will explore the use of a scatter plot to compare a model to the actual data. But for now we will focus on making a scatter plot with real data. For this, we will first load the Iris data, and extract two features for comparison.


In [6]:
# Load the Iris Data set, and define our two dimensions
idf = sns.load_dataset('iris')
# as_matrix() was removed from Pandas; to_numpy() returns the same NumPy array
sl = idf['sepal_length'].to_numpy()
pl = idf['petal_length'].to_numpy()

In [7]:
# Now we create our figure and axes for the plot we will make.
fig, ax = plt.subplots()

# Plot data
ax.scatter(sl, pl)

# Set our axis labels and title
ax.set_xlabel("Sepal Length (cm)", fontsize=14)
ax.set_ylabel("Petal Length (cm)", fontsize=14)
ax.set_title("Iris Sepal-Petal Comparison!", fontsize=18)

sns.despine(offset=2, trim=True)

These two features, Sepal Length and Petal Length, display a positive correlation: as one becomes larger, the other tends to become larger as well.


Student Exercise

In the empty Code cell below, write and execute code to make a new scatter plot for the Sepal Length and the Sepal Width features. What type of correlation do these two features display?


In [ ]:
 

Comparing Multiple Data Sets

The scatter plots shown previously have all been rather plain, but we can compare multiple items within the same matplotlib figure. For example, we can compare different sets of data by calling scatter() for each pair of features. Alternatively, if we have a function that we wish to compare to our data, we can use the plot() method to place the function over our data plot.

In either case, the visual comparison can be improved by coloring points (or functions) differently within a scatter plot, based on a specific feature or relationship. For example, we might read three columns from a file containing age, height, and gender. By using a scatter plot of age versus height and coloring the points differently based on the gender, we can explore trends in more than two dimensions.

We employ both of these techniques in the following plot, where we display two different pairs of features (lengths and widths, respectively) for the Iris data.


In [8]:
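# Extract the two width features
# (to_numpy() replaces as_matrix(), which was removed from Pandas)
sw = idf['sepal_width'].to_numpy()
pw = idf['petal_width'].to_numpy()

# x values for the reference line y = x
x = np.arange(1, 8)

In [9]: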
# Now we create our figure and axes for the plot we will make.
fig, ax = plt.subplots()

# Plot the length features and the width features
ax.scatter(sl, pl, color=sns.xkcd_rgb["pale red"], marker='s')
ax.scatter(sw, pw, color=sns.xkcd_rgb["denim blue"], marker='d')
ax.plot(x, x, color='green', linestyle='dashed')


# Set our axis labels and title
ax.set_xlabel("Sepal (cm)", fontsize=14)
ax.set_ylabel("Petal (cm)", fontsize=14)
ax.set_title("Iris Sepal-Petal Comparison!", fontsize=18)

sns.despine(offset=2, trim=True)

In the previous plot, we not only displayed two different types of data that were differentiated by their color, we also changed the type of marker used for each point. The number of color options available for use in a matplotlib figure is quite large. When choosing colors, be sure to keep an eye on the overall design of your plot. Certain colors go better together than others, and you want to ensure viewers focus on the information content of your visualizations. Likewise, there are a number of different marker types you can use within your plots.

For a full description of the options available to the scatter() method, see the appropriate matplotlib documentation.

Finally, when we over-plotted the function y = x, we used the plot() method and specified both the color and the linestyle. The matplotlib documentation provides more information on the linestyle and other options for the plot() method. Notice that abbreviations exist for many of these option combinations. For example, you can specify a green dashed line by using plot(x, x, 'g--').
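To make the abbreviation concrete, the following sketch draws one line with the keyword form and one with the format-string form, then inspects the resulting line objects. The one-unit offset on the second line is only there so both lines remain visible. The variable names are illustrative.

```python
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x = np.arange(1, 8)

# Keyword form: color and linestyle given explicitly.
# plot() returns a list of lines, so we unpack the single entry.
(line_kw,) = ax.plot(x, x, color='green', linestyle='dashed')

# Abbreviated form: 'g' selects green and '--' selects a dashed line
(line_fmt,) = ax.plot(x, x + 1, 'g--')

# Both line objects report the same dash style
print(line_kw.get_linestyle(), line_fmt.get_linestyle())
```

One caveat: in matplotlib, the single-letter color 'g' and the named color 'green' are not guaranteed to be exactly the same shade, so the two forms are best treated as close equivalents rather than identical.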

Labeling Data

When over-plotting multiple data sets or functions, it is generally a good idea to label them so the viewer can quickly understand the differences, and to create a legend that maps the labels to the related points or functions. This can be easily done in matplotlib by adding the label parameter to each plotting command. For example, we can modify our previous three plotting commands to have descriptive labels:

ax.scatter(sl, pl, color=sns.xkcd_rgb["pale red"], marker='s', label='Length')
ax.scatter(sw, pw, color=sns.xkcd_rgb["denim blue"], marker='d', label='Width')
ax.plot(x, x, color='green', linestyle='dashed', label='Function')

The legend() method can be used to label these data within our plot, as shown below.


In [10]:
# Now we create our figure and axes for the plot we will make.
fig, ax = plt.subplots()

# Plot the length features and the width features
ax.scatter(sl, pl, color=sns.xkcd_rgb["pale red"], marker='s', label='Length')
ax.scatter(sw, pw, color=sns.xkcd_rgb["denim blue"], marker='d', label='Width')
ax.plot(x, x, color='green', linestyle='dashed', label='Function')

# Add legend
ax.legend(loc='upper left', fontsize=14)

# Set our axis labels and title
ax.set_xlabel("Sepal (cm)", fontsize=14)
ax.set_ylabel("Petal (cm)", fontsize=14)
ax.set_title("Iris Sepal-Petal Comparison!", fontsize=18)

sns.despine(offset=2, trim=True)

In this plot, we now have multiple data sets over-plotted, with a legend that allows the viewer to distinguish the different plot components. Notice how we specified the location of the legend by using the loc parameter, which in this case we set to 'upper left'. If the legend overlaps the data, you can increase the range of the y-axis (for example, with set_ylim()) to provide room for it. By default, the legend() method will display all plot components that have a label assigned. You can, however, control the behavior of this method in a number of different ways, as detailed in the matplotlib documentation.
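The default behavior can be checked directly. In the sketch below, two scatter calls carry a label and the over-plotted line does not, so only the two labeled components appear in the legend. Random placeholder data stand in for the Iris features here (an assumption made to keep the example self-contained).

```python
import numpy as np
import matplotlib.pyplot as plt

# Assumption: seeded placeholder data, used instead of the Iris columns
rng = np.random.default_rng(0)
fig, ax = plt.subplots()

ax.scatter(rng.uniform(4, 8, 20), rng.uniform(1, 7, 20),
           marker='s', label='Length')
ax.scatter(rng.uniform(2, 4.5, 20), rng.uniform(0, 2.5, 20),
           marker='d', label='Width')

# No label is given, so this line is left out of the legend
ax.plot(np.arange(1, 8), np.arange(1, 8), 'g--')

legend = ax.legend(loc='upper left', fontsize=14)

# Collect the text entries that the legend actually displays
legend_labels = [text.get_text() for text in legend.get_texts()]
print(legend_labels)
```

Omitting a label is therefore a simple way to keep a helper line or reference curve out of the legend.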


Student Exercise

In the empty Code cell below, write and execute code to make a new plot that compares the two Sepal features to each other and over-plots a comparison of the two Petal features. Be sure to update the axes, title, and labels appropriately. What type of correlation does each pair of features display?


In [ ]:
 

Multiple Scatter Plots

We can compare multiple variables (or data columns) by using the subplot functionality within matplotlib. This allows us to make a scatterplot spreadsheet, where different variables are compared in different plots. When creating subplots in matplotlib, we can use the add_subplot() method as described in a previous module or, as in the following code, the subplots() function, which creates the figure and all of its axes in one call. The code generates separate scatter plots for the Sepal and Petal features.

While powerful, this technique can lead to confusion if not done properly. Care should be taken to make sure the axis labels do not overlap the plot ticks or the titles of neighboring plots. You can use the subplots_adjust() method to provide extra horizontal space, vertical space, or both via the wspace and hspace parameters, respectively.


In [11]:
# Define the plot layout
fig, axs = plt.subplots(figsize=(10, 4.0), nrows=2, ncols=1, sharex=True)
fig.subplots_adjust(hspace = 1.0)

# Define our labels
lbl = ['Sepal', 'Petal']

# Plot the length features and the width features
axs[0].scatter(sl, sw, color=sns.xkcd_rgb["pale red"], marker='s')
axs[1].scatter(pl, pw, color=sns.xkcd_rgb["denim blue"], marker='d')

for idx in range(len(axs)):
    # Set our axis labels
    axs[idx].set_xlabel("Length (cm)", fontsize=14)
    axs[idx].set_ylabel("Width (cm)", fontsize=14)
    axs[idx].set_title("Iris {} Comparison".format(lbl[idx]), fontsize=18)

    sns.despine(ax = axs[idx], offset=2, trim=True)
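The same pattern extends to a grid with more than one column, in which case subplots() returns a two-dimensional array of axes. The following sketch uses random placeholder data and generic labels (both assumptions, not part of the original notebook) to show how wspace and hspace are set together and how the axes array can be traversed.

```python
import numpy as np
import matplotlib.pyplot as plt

# Assumption: seeded placeholder data standing in for real columns
rng = np.random.default_rng(0)

# A 2 x 2 grid; axs is a 2-D NumPy array of Axes objects
fig, axs = plt.subplots(figsize=(10, 8), nrows=2, ncols=2)

# wspace controls horizontal spacing, hspace controls vertical spacing
fig.subplots_adjust(wspace=0.3, hspace=0.5)

# flatten() turns the 2-D array into a 1-D sequence for easy looping
for idx, ax in enumerate(axs.flatten()):
    ax.scatter(rng.uniform(0, 10, 30), rng.uniform(0, 10, 30))
    ax.set_xlabel("X (units)", fontsize=14)
    ax.set_ylabel("Y (units)", fontsize=14)
    ax.set_title("Panel {}".format(idx + 1), fontsize=18)
```

Individual panels can also be addressed by row and column, for example `axs[0, 1]` for the top-right plot.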

Student Exercise

In the empty Code cell below, write and execute code to make four scatter plots in two rows with two columns: Sepal Length versus Sepal Width, Petal Length versus Sepal Width, Sepal Length versus Petal Width, and Petal Length versus Petal Width. Be sure to update the axes, title, and labels appropriately, and also to adjust the wspace and hspace parameters in the subplots_adjust() method appropriately. What type of correlation does each pair of features display?


In [ ]:
 

Pair Plots

As the last code example demonstrated, quickly creating and displaying scatter plots of paired features from a data set can be instructive. Yet, doing this can be taxing, especially for larger data sets. Fortunately, the Seaborn module provides a convenience function, called pairplot, to quickly and easily make a visually comparative spreadsheet for a data set. The rows and columns in this visual spreadsheet, which forms a square, correspond to different features. The diagonal elements are generally plotted in a different manner, such as a histogram, since these elements would otherwise have the same feature on both axes.

To use the pairplot function, we simply pass in the Pandas DataFrame of interest, and by default the different numerical features are compared. We demonstrate this in the code example below, which also uses the hue parameter to indicate a categorical feature within the DataFrame that should be used to distinguish (via color) the different types of data. Other parameters can be specified to control the overall appearance of the visualization; see the pairplot documentation for more details.


In [12]:
# Generate the visual spreadsheet
sns.pairplot(idf, hue='species') ;

Ancillary Information

The following links are to additional documentation that you might find helpful in learning this material. Reading these web-accessible documents is completely optional.

  1. Matplotlib tutorial on working with multiple subplots
  2. Matplotlib tutorial on legends
  3. Seaborn tutorial on plotting pairwise relationships
  4. The online documentation for the scatter method in Matplotlib.

© 2017: Robert J. Brunner at the University of Illinois.

This notebook is released under the Creative Commons license CC BY-NC-SA 4.0. Any reproduction, adaptation, distribution, dissemination or making available of this notebook for commercial use is not allowed unless authorized in writing by the copyright holder.