How to make good graphs with Matplotlib using Python


The main job of a data scientist is to transform chaotic data into actionable decisions.

In other words, to explain complex data in a way that everyone understands.

Your readers will have neither the time nor the energy to figure out what's going on in your graph.

This is why you have to do your best to make self-explanatory graphs.

Here are 4 rules that are essential if you want to make outstanding graphs:

  1. Keep it simple
  2. Keep the variables simple to understand
  3. Don't over-engineer it
  4. Would Grandma understand?
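To make rule 2 a little more concrete: raw column names such as `QUANTITY` are fine in code but not on a chart. A minimal sketch, assuming a hypothetical `READABLE_LABELS` mapping and a hypothetical `readable()` helper (neither comes from the article), keeps the translation from raw names to plain-language labels in one place:

```python
import pandas as pd

# Hypothetical mapping from raw column names to plain-language chart labels
READABLE_LABELS = {
    "TITLE": "Product Title",
    "QUANTITY": "Total Quantity Sold",
}


def readable(name: str) -> str:
    """Return a human-friendly label, falling back to a tidied raw name."""
    return READABLE_LABELS.get(name, name.replace("_", " ").title())


# Tiny made-up data, not the article's dataset
orders = pd.DataFrame({"TITLE": ["Coffee Mug", "Baseball Cap"], "QUANTITY": [3, 5]})

# Rename the columns once, so every chart built from this frame reads well
orders_readable = orders.rename(columns=readable)
print(list(orders_readable.columns))  # ['Product Title', 'Total Quantity Sold']
```

Columns missing from the mapping still get a passable label, e.g. a hypothetical `ORDER_TOTAL` column becomes "Order Total".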

Let me show you what I mean by that.

The example

You are an e-commerce owner.

From your online store, you can export your orders data.

That data contains the quantity, the product title, the product category, the product vendor, the order total amount and the client email.
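If you want to follow along without downloading the CSV, a small hand-made stand-in with the same kind of columns can help. Only `TITLE` and `QUANTITY` appear in the article's code; `CATEGORY`, `VENDOR`, `TOTAL` and `EMAIL` are assumed names, and all values are made up for illustration:

```python
import pandas as pd

# Made-up orders, one row per order.
# TITLE and QUANTITY match the article's code; the other column names are assumptions.
orders = pd.DataFrame(
    {
        "QUANTITY": [2, 1, 4, 1],
        "TITLE": ["Coffee Mug", "Baseball Cap", "Coffee Mug", "Tote Bag"],
        "CATEGORY": ["Kitchen", "Apparel", "Kitchen", "Accessories"],
        "VENDOR": ["Acme", "Capco", "Acme", "Bagworks"],
        "TOTAL": [24.0, 15.0, 48.0, 20.0],
        "EMAIL": ["ann@example.com", "bob@example.com", "bob@example.com", "cy@example.com"],
    }
)

# Check the shape and column types before plotting anything
print(orders.shape)  # (4, 6)
print(orders.dtypes)
```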

As an e-commerce owner, here are some questions you might ask:

  1. What is the best-selling product?
  2. What is the average client lifetime value (LTV)?
  3. And many more.
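Before plotting anything, it helps to answer the first two questions as plain numbers. A minimal sketch on made-up data, assuming each row is one order and that `EMAIL` and `TOTAL` are the column names (only `TITLE` and `QUANTITY` come from the article). LTV here uses a simple historical definition: total spent per client, averaged over clients.

```python
import pandas as pd

# Made-up orders, one row per order; EMAIL and TOTAL are assumed column names
orders = pd.DataFrame(
    {
        "TITLE": ["Coffee Mug", "Baseball Cap", "Coffee Mug", "Tote Bag"],
        "QUANTITY": [2, 1, 4, 1],
        "TOTAL": [24.0, 15.0, 48.0, 20.0],
        "EMAIL": ["ann@example.com", "bob@example.com", "bob@example.com", "cy@example.com"],
    }
)

# Q1: the best-selling product is the title with the largest summed quantity
units_per_product = orders.groupby("TITLE")["QUANTITY"].sum()
best_seller = units_per_product.idxmax()

# Q2: simple historical LTV = total spent per client, averaged over all clients
spend_per_client = orders.groupby("EMAIL")["TOTAL"].sum()
average_ltv = spend_per_client.mean()

print(best_seller)            # Coffee Mug
print(round(average_ltv, 2))  # 35.67
```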

Let's take the first question as an example.

What is the best-selling product?

  1. What you shouldn't do
# To work with dataframes
import pandas as pd

# To plot stuff
import matplotlib.pyplot as plt

# We read a sample DataFrame
df = pd.read_csv("https://thepythonyouneed.nyc3.cdn."
                 "digitaloceanspaces.com/orders-example.csv",
                 parse_dates=True)

# We group by product title and sum up the quantity,
# sort by values in descending order,
# and plot every single product as a bar
df.groupby("TITLE")["QUANTITY"].sum().sort_values(ascending=False).plot(kind="bar")

# We show the plot
plt.show()
A type of graph that gives a great overview but is way too crowded to have any impact

Here, your graph is tiring to read. First, it's too small; second, it has no title; third, do we really care which products rank below the top 10?
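Each of these three problems can be addressed even without leaving pandas' own `.plot()` wrapper, through its `figsize`, `title` and `rot` arguments. A minimal sketch on made-up data (the 15 generic product names stand in for the article's CSV):

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up orders: 15 products, two orders each
orders = pd.DataFrame(
    {
        "TITLE": [f"Product {i}" for i in range(15)] * 2,
        "QUANTITY": list(range(1, 31)),
    }
)

top_10 = (
    orders.groupby("TITLE")["QUANTITY"]
    .sum()
    .sort_values(ascending=False)
    .head(10)  # fix 3: keep only the top 10
)

ax = top_10.plot(
    kind="bar",
    figsize=(8, 6),                        # fix 1: a bigger canvas
    title="Top 10 best-selling products",  # fix 2: a title
    rot=45,                                # tilt the product names
)
ax.set_xlabel("Product Title")
ax.set_ylabel("Total Quantity Sold")
plt.tight_layout()
plt.show()
```

This is handy for quick checks; the Matplotlib version gives finer control over details such as label alignment and value annotations.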

  2. What you should do instead

# To work with dataframes
import pandas as pd

# To plot stuff
import matplotlib.pyplot as plt

# We read a sample DataFrame
df = pd.read_csv("https://thepythonyouneed.nyc3.cdn."
                 "digitaloceanspaces.com/orders-example.csv",
                 parse_dates=True)

# We group by product title and sum up the quantity
# We then sort by values in descending order
# Then we select the top 10
df_top_10_sold = (df.groupby("TITLE")["QUANTITY"].sum()
                  .sort_values(ascending=False)
                  .iloc[:10])

# We create our canvas on which we are going to plot
fig, axes = plt.subplots(1, 1, figsize=(8, 6))

# We plot the product name vs the quantity sold
axes.bar(x=df_top_10_sold.index, height=df_top_10_sold)

# We pin one tick per bar, then rotate the labels so long titles don't overlap
axes.set_xticks(range(len(df_top_10_sold)))
axes.set_xticklabels(df_top_10_sold.index, rotation=45, ha="right")

# We change the labels & title
axes.set_ylabel("Total Quantity Sold")
axes.set_xlabel("Product Title")
axes.set_title("Top 10 best-selling products")

# We write the amount sold on top of every bar, centered on the bar
for i, v in enumerate(df_top_10_sold):
    axes.text(i, v, str(v), ha="center", va="bottom", color="black")

# We show the grid
axes.grid()

# We tidy the layout
fig.tight_layout()

# We show the plot
plt.show()
A type of graph that is not crowded and simple to understand

There you have it: a quick look at what it takes to make clearer graphs.
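If you make this kind of chart often, the same ideas can be wrapped in a small helper. The sketch below is a hypothetical `plot_top_n` function, not the article's code, and it swaps in a different technique: horizontal bars (`barh`), which leave room for long product titles without rotating them. It uses `Axes.bar_label`, which requires Matplotlib 3.4 or later, and made-up data in the usage lines.

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_top_n(values: pd.Series, n: int = 10, title: str = "", xlabel: str = ""):
    """Hypothetical helper: horizontal bar chart of the n largest values.

    Requires Matplotlib >= 3.4 for Axes.bar_label.
    """
    # Ascending order so the largest value ends up as the top bar
    top = values.nlargest(n).sort_values()
    fig, ax = plt.subplots(figsize=(8, 6))
    bars = ax.barh(top.index.astype(str), top.to_numpy())
    ax.bar_label(bars, padding=3)  # write each value at the end of its bar
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    fig.tight_layout()
    return fig, ax


# Usage with made-up quantities per product
units = pd.Series({"Coffee Mug": 6, "Baseball Cap": 1, "Tote Bag": 3, "Sticker": 9})
fig, ax = plot_top_n(units, n=3, title="Top 3 best-selling products",
                     xlabel="Total Quantity Sold")
plt.show()
```

Returning `fig` and `ax` keeps the helper simple while still letting you tweak the chart or save it with `fig.savefig(...)`.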

More on Matplotlib

If you liked what you've just read and want to learn more about the Matplotlib library (e.g., how to add labels or create different types of plots), check out the other articles I wrote on the topic here:

Matplotlib - The Python You Need