How to do a stack plot with Matplotlib

Stack plots are very useful when you want to compare several data series on the same axes and see how each one contributes to the total.

They also let you eyeball whether those series tend to rise and fall together, which can hint at some kind of correlation between them.
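A stack plot only suggests a relationship visually. If you want a number to back it up, one option is to put the series in a pandas DataFrame and call `DataFrame.corr()`, which computes pairwise Pearson correlation by default. This is a minimal sketch that reuses the 2018 and 2019 figures from the example below; the `sales` and `correlation` names are just for illustration.

```python
import pandas as pd

# Monthly sales for two of the years used in the stack plot example
sales = pd.DataFrame({
    "2018": [2221, 2315, 2455, 2304, 2670, 2181,
             2768, 1897, 2488, 2456, 1915, 2759],
    "2019": [3969, 3009, 3949, 4077, 3228, 3339,
             3565, 3278, 3389, 2422, 3451, 4095],
})

# Pairwise Pearson correlation between the columns (pandas' default method)
correlation = sales.corr()
print(correlation)
```

Values close to 1 mean the two years move together month to month, values close to -1 mean they move in opposite directions, and values near 0 mean there is little linear relationship.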

Here is a simple example of a stack plot, using the Matplotlib library.

import matplotlib.pyplot as plt

# We set our x labels
months = ['Jan', 'Feb', 'Mar',
          'Apr', 'May', 'Jun',
          'Jul', 'Aug', 'Sep',
          'Oct', 'Nov', 'Dec']

# We add some sample data: monthly sales in USD, one list per year
sales_per_month = {
    '2018': [2221, 2315, 2455,
             2304, 2670, 2181,
             2768, 1897, 2488,
             2456, 1915, 2759],
    '2019': [3969, 3009, 3949,
             4077, 3228, 3339,
             3565, 3278, 3389,
             2422, 3451, 4095],
    '2020': [5222, 3875, 5132,
             3872, 4592, 5685,
             4289, 3517, 5243,
             4794, 4693, 4324],
    '2021': [10161, 8268, 3540,
             10256, 10409, 9525,
             10560, 10390, 10432,
             11617, 10323, 15200],
}

# We set our canvas: a single set of axes
fig, axes = plt.subplots(1, 1, figsize=(8, 6))

# We draw a stack plot on the axes, one layer per year
# (the values are passed as a single 2D list, one row per series)
axes.stackplot(months,
               list(sales_per_month.values()),
               labels=list(sales_per_month.keys()))

# We set a title
axes.set_title("Sales per month")

# Change the axis labels
axes.set_xlabel("Month")
axes.set_ylabel("Sales (USD)")

# Add the legend
axes.legend(loc='upper left')

# Adjust the layout so everything fits in the figure
fig.tight_layout()

# Show the plot
plt.show()
Plotting a stacked plot

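As you can see, we use the axes.stackplot() method, which draws a stack plot from a sequence of x values and one or more sequences of y values (lists, NumPy arrays or pandas Series all work).

You can pass each series as its own argument, e.g. axes.stackplot(x, y1, y2, y3, ...), or pass them all at once as a single 2D array-like with one row per series, which is what the example does with the dictionary values.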
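To see the two call forms side by side, here is a minimal sketch with small made-up numbers (`y1`, `y2` and `y3` are toy data, not the sales figures). It also uses `np.cumsum` to show the running totals that form the top edge of each layer when `stackplot` uses its default `baseline="zero"`.

```python
import matplotlib.pyplot as plt
import numpy as np

# Toy data: three series sharing the same x values
x = np.arange(3)
y1 = [1, 2, 3]
y2 = [4, 5, 6]
y3 = [7, 8, 9]

fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8, 3))

# Form 1: one positional argument per series
layers_separate = ax_left.stackplot(x, y1, y2, y3)

# Form 2: a single 2D array-like with one row per series
layers_2d = ax_right.stackplot(x, np.array([y1, y2, y3]))

# Either way, each series becomes one filled layer
print(len(layers_separate), len(layers_2d))  # 3 3

# With baseline="zero", the top of layer i is the cumulative sum of rows 0..i
tops = np.cumsum(np.array([y1, y2, y3]), axis=0)
print(tops[-1])  # [12 15 18]

# Call plt.show() to display both panels when running this as a script
```

The top of the last layer is simply the column-wise total of all series, which is why a stack plot is handy for reading both the parts and the whole at once.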

In this example, we plot monthly sales in US dollars, with one stacked layer per year from 2018 to 2021 so the years can be compared.

Here is the result.

Sales per month per year
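If your data already lives in a pandas DataFrame, a possible alternative is pandas' own `DataFrame.plot.area()`, which draws a stacked area chart on top of Matplotlib (`stacked=True` is its default). This is a minimal sketch using two of the yearly series from the example; the layout with one column per year and the months as the index is an assumption about how you might store the data.

```python
import matplotlib.pyplot as plt
import pandas as pd

months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
          'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

# One column per year, months as the index
sales = pd.DataFrame(
    {
        '2018': [2221, 2315, 2455, 2304, 2670, 2181,
                 2768, 1897, 2488, 2456, 1915, 2759],
        '2019': [3969, 3009, 3949, 4077, 3228, 3339,
                 3565, 3278, 3389, 2422, 3451, 4095],
    },
    index=months,
)

fig, ax = plt.subplots(figsize=(8, 6))

# Each column becomes one stacked, filled layer on the given axes
sales.plot.area(ax=ax, title="Sales per month")
ax.set_xlabel("Month")
ax.set_ylabel("Sales (USD)")
fig.tight_layout()

# Call plt.show() to display the figure when running this as a script
```

The trade-off is convenience versus control: pandas picks the legend and tick labels from the DataFrame, while calling axes.stackplot() directly leaves every detail up to you.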

There you go! You now know how to make stack plots with Matplotlib.

More on plots

If you want to learn more about adding labels, drawing other types of plots and so on, check out the other articles I wrote on the topic:

Matplotlib - The Python You Need