Pivot Tables
Pivot tables here work much like the ones in spreadsheets. The difference between pivot tables and the `GroupBy` function: "A pivot table is essentially a multi-dimensional version of GroupBy aggregation." That is, you still split-apply-combine, but both the split and the combine happen across a two-dimensional grid rather than a one-dimensional index.
import numpy as np
import pandas as pd
import seaborn as sns
🛳 Titanic dataset for demonstration
# importing dataset for demonstration
titanic = pd.read_csv('data/titanic.csv')
titanic.head()
 | survived | pclass | sex | age | fare | embarked | who | embark_town | alive | alone |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 3 | male | 22.0 | 7.2500 | S | man | Southampton | no | False |
1 | 1 | 1 | female | 38.0 | 71.2833 | C | woman | Cherbourg | yes | False |
2 | 1 | 3 | female | 26.0 | 7.9250 | S | woman | Southampton | yes | True |
3 | 1 | 1 | female | 35.0 | 53.1000 | S | woman | Southampton | yes | False |
4 | 0 | 3 | male | 35.0 | 8.0500 | S | man | Southampton | no | True |
Essentially:
- group (split) by `sex`,
- select `survived`, and
- apply `mean`
titanic.groupby('sex')['survived'].mean()
sex
female 0.742038
male 0.188908
Name: survived, dtype: float64
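To make the split-apply-combine steps concrete, here is a minimal sketch that performs them by hand and compares the result with `groupby`. The `toy` frame is made-up data standing in for the Titanic columns, not the real dataset.

```python
import pandas as pd

# Hypothetical toy data standing in for the Titanic columns used above.
toy = pd.DataFrame({
    "sex": ["female", "female", "male", "male", "male"],
    "survived": [1, 0, 0, 1, 0],
})

# Split: iterate over one sub-frame per value of `sex`.
# Apply: take the mean of `survived` in each sub-frame.
# Combine: collect the per-group results into a single Series.
manual = pd.Series(
    {key: group["survived"].mean() for key, group in toy.groupby("sex")}
)

# The one-line GroupBy expression does all three steps at once.
grouped = toy.groupby("sex")["survived"].mean()

print(manual)
print(grouped)
```

Both Series should hold the same per-group means; the `groupby` form simply hides the loop.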
Essentially:
- group (split) by `sex` and `pclass`,
- select the `survived` column, and
- apply the `mean` aggregate
titanic.groupby(['sex','pclass'])['survived'].mean()
sex pclass
female 1 0.968085
2 0.921053
3 0.500000
male 1 0.368852
2 0.157407
3 0.135447
Name: survived, dtype: float64
# unstack the result for better presentation
titanic.groupby(['sex','pclass'])['survived'].mean().unstack()
pclass 1 2 3
sex
female 0.968085 0.921053 0.500000
male 0.368852 0.157407 0.135447
**Conclusion:** A two-dimensional GroupBy works, but the code becomes long and hard to read. Pandas has a better tool for this, `pivot_table`.
The two-dimensional GroupBy result above can be produced with the following `pivot_table` call. We use the DataFrame method `.pivot_table()`, whose default `aggfunc` is `'mean'`.
titanic.pivot_table('survived', index='sex', columns='pclass')
pclass 1 2 3
sex
female 0.968085 0.921053 0.500000
male 0.368852 0.157407 0.135447
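As a quick check of the equivalence, the sketch below builds both tables from a small made-up frame (not the Titanic data) and compares their values. The frame and its column values are assumptions chosen only for illustration.

```python
import pandas as pd

# Hypothetical toy data; every (sex, pclass) combination occurs at least once.
toy = pd.DataFrame({
    "sex":      ["female", "female", "female", "male", "male", "male"],
    "pclass":   [1, 3, 3, 1, 1, 3],
    "survived": [1, 1, 0, 0, 1, 0],
})

# Two-dimensional GroupBy, reshaped with unstack().
by_groupby = toy.groupby(["sex", "pclass"])["survived"].mean().unstack()

# The same aggregation expressed as a pivot table (default aggfunc is the mean).
by_pivot = toy.pivot_table("survived", index="sex", columns="pclass")

print(by_pivot)
print((by_groupby.values == by_pivot.values).all())
```

The pivot-table form names the row and column roles explicitly, which tends to read more clearly than chaining `groupby` and `unstack`.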
We can also get the same result by passing `index` and `columns` positionally instead of as keyword arguments:
titanic.pivot_table('survived', 'sex', 'pclass')
pclass 1 2 3
sex
female 0.968085 0.921053 0.500000
male 0.368852 0.157407 0.135447
Suppose we want to group by `age` and `sex` and get the mean of `survived` for each `pclass`. Instead of treating every age value as a separate group, we will bin the ages into groups. To do this, we use the `pd.cut` function to segment the `age` column. First, let's check the `min` and `max` of `age` in our dataset:
print(f"Min Age: {titanic['age'].min()}")
print(f"Max Age: {titanic['age'].max()}")
Min Age: 0.42
Max Age: 80.0
Let's make two age groups, (0, 18] and (18, 80]:
age_group = pd.cut(titanic['age'], [0,18,80])
age_group.head()
0 (18, 80]
1 (18, 80]
2 (18, 80]
3 (18, 80]
4 (18, 80]
Name: age, dtype: category
Categories (2, interval[int64]): [(0, 18] < (18, 80]]
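The interval notation in this output matters: by default `pd.cut` builds right-closed bins. A small sketch with made-up ages shows which side each boundary value falls on, what happens to values outside the bins, and the optional `labels` argument for readable names. The label names `child` and `adult` are illustrative choices, not part of the original analysis.

```python
import pandas as pd

ages = pd.Series([0.42, 18.0, 18.5, 80.0, 90.0])  # made-up ages

# Default bins are right-closed: (0, 18] includes 18, (18, 80] includes 80.
# A value outside every bin (here 90.0) becomes NaN.
groups = pd.cut(ages, [0, 18, 80])
print(groups)

# Optional `labels` replace the interval notation with names (hypothetical labels).
named = pd.cut(ages, [0, 18, 80], labels=["child", "adult"])
print(named)
```

Rows whose group is NaN, whether from an out-of-range value or a missing age, do not contribute to any group when the result is used as a grouping key.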
Now we apply `pivot_table` on `sex` and `age` (through the newly created `age_group`). Everything else stays the same: we compute the mean of `survived` for each `pclass`.
titanic.pivot_table('survived', index=['sex',age_group], columns='pclass')
pclass 1 2 3
sex age
female (0, 18] 0.909091 1.000000 0.511628
(18, 80] 0.972973 0.900000 0.423729
male (0, 18] 0.800000 0.600000 0.215686
(18, 80] 0.375000 0.071429 0.133663
Parameter | Default |
---|---|
values= | None |
index= | None |
columns= | None |
aggfunc= | 'mean' |
margins= | False |
dropna= | True |
margins_name= | 'All' |
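A short sketch of how some of these parameters combine, using a made-up frame rather than the Titanic data: it overrides the default `aggfunc`, turns on `margins`, and renames the margin label through `margins_name`. The label `Total` is an arbitrary choice for illustration.

```python
import pandas as pd

# Hypothetical toy data, not the Titanic dataset.
toy = pd.DataFrame({
    "sex":      ["female", "female", "male", "male", "male"],
    "pclass":   [1, 3, 1, 3, 3],
    "survived": [1, 0, 0, 1, 0],
})

table = toy.pivot_table(
    values="survived",
    index="sex",
    columns="pclass",
    aggfunc="sum",         # override the default 'mean'
    margins=True,          # add a totals row and column
    margins_name="Total",  # rename the default 'All' label
)
print(table)
```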
Suppose we want the `sum` of `survived` and the `mean` of `fare`, for each `sex` and `pclass`:
titanic.pivot_table(index='sex', columns='pclass', aggfunc={'survived': 'sum', 'fare': 'mean'})
# the values keyword is omitted;
# when you specify a mapping for aggfunc, the value columns are determined automatically.
fare survived
pclass 1 2 3 1 2 3
sex
female 106.125798 21.970121 16.118810 91 70 72
male 67.226127 19.741782 12.661633 45 17 47
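The result above has two header rows because a mapping `aggfunc` produces a two-level column index. The sketch below, on a made-up frame, shows that structure and one way to select from it. The data values are assumptions for illustration only.

```python
import pandas as pd

# Hypothetical toy data; note there is no (female, 3) combination.
toy = pd.DataFrame({
    "sex":      ["female", "female", "male", "male"],
    "pclass":   [1, 1, 1, 3],
    "survived": [1, 0, 0, 1],
    "fare":     [80.0, 60.0, 50.0, 8.0],
})

table = toy.pivot_table(
    index="sex",
    columns="pclass",
    aggfunc={"survived": "sum", "fare": "mean"},
)

# The columns form a two-level MultiIndex: (value column, pclass).
print(table.columns.tolist())

# Select every pclass column for one value column...
print(table["fare"])

# ...or a single (value column, pclass) pair.
print(table[("survived", 1)])
```

Combinations that never occur in the data, such as (female, 3) here, show up as NaN in the table.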
Setting `margins=True` adds an `All` row and column that apply the aggregation (here the mean) across each row and column:
titanic.pivot_table('survived', index='sex', columns='pclass', margins=True)
pclass 1 2 3 All
sex
female 0.968085 0.921053 0.500000 0.742038
male 0.368852 0.157407 0.135447 0.188908
All 0.629630 0.472826 0.242363 0.383838
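Note that with a mean aggregation the `All` entries are means over the underlying rows, not averages of the cell values shown in the table. A minimal sketch on made-up, deliberately unbalanced data illustrates the difference.

```python
import pandas as pd

# Hypothetical toy data: in pclass 1 there is one female and three males.
toy = pd.DataFrame({
    "sex":      ["female", "male", "male", "male", "female", "male"],
    "pclass":   [1, 1, 1, 1, 3, 3],
    "survived": [1, 0, 0, 0, 1, 0],
})

table = toy.pivot_table("survived", index="sex", columns="pclass", margins=True)
print(table)

# Margin for pclass 1: mean over all four pclass-1 rows.
print(table.loc["All", 1])

# Averaging the two displayed cells weights the groups equally instead.
print(table.loc[["female", "male"], 1].mean())
```

The grand total in the bottom-right corner is likewise the mean of `survived` over every row.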
Overall, approximately 38% of the people on board survived.
- First, load the births dataset using the Pandas `read_csv` function.
- Then view the first rows with `.head()` to get an initial sense of the dataset.
- To find the total number of rows and columns, use the `.shape` attribute.
births = pd.read_csv('data/births.csv')
print(births.head())
print(births.shape)
year month day gender births
0 1969 1 1.0 F 4046
1 1969 1 1.0 M 4440
2 1969 1 2.0 F 4454
3 1969 1 2.0 M 4548
4 1969 1 3.0 F 4548
(15547, 5)
1️⃣ Finding the `sum` of `births` in each `month`, across each `gender`
births.pivot_table('births', index='month', columns='gender', aggfunc='sum', margins=True)
gender F M All
month
1 6035447 6328750 12364197
2 5634064 5907114 11541178
3 6181613 6497231 12678844
4 5889345 6196546 12085891
5 6145186 6479786 12624972
6 6093026 6428044 12521070
7 6512299 6855257 13367556
8 6600723 6927284 13528007
9 6473029 6779802 13252831
10 6330549 6624401 12954950
11 5956388 6241579 12197967
12 6184154 6472761 12656915
All 74035823 77738555 151774378
Plotting the results
# using matplotlib to draw figure of
# sum of births in each month, across each gender
# magic function (%matplotlib) to make the plot appear and store in notebook
%matplotlib inline
import matplotlib.pyplot as plt
sns.set() # set seaborn styles
births.pivot_table('births', index='month', columns='gender', aggfunc='sum').plot()
plt.ylabel('total births in each month');
2️⃣ Finding the `sum` of `births` in each `decade`, across each `gender`
# adding a decade column
births['decade'] = 10 * (births['year'] // 10 ) # //10 will remove the last digit in year
# creating pivot table for total births, in each decade, along each gender type
print(births.pivot_table('births', index='decade', columns='gender', aggfunc='sum', margins=True))
gender F M All
decade
1960 1753634 1846572 3600206
1970 16263075 17121550 33384625
1980 18310351 19243452 37553803
1990 19479454 20420553 39900007
2000 18229309 19106428 37335737
All 74035823 77738555 151774378
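The decade column relies on integer floor division. A tiny sketch with a few sample years (chosen for illustration) shows the mapping.

```python
import pandas as pd

years = pd.Series([1969, 1970, 1989, 2008])  # sample years for illustration

# Floor-divide by 10 to drop the last digit, then multiply back by 10.
decades = 10 * (years // 10)
print(decades.tolist())  # [1960, 1970, 1980, 2000]
```

Every year from 1960 through 1969 lands in the 1960 bucket, so a decade with only some of its years present in the data will show a correspondingly smaller total.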
Let's plot the totals, this time by year rather than by decade:
# using matplotlib to draw figure of
# sum of births in each year, across each gender
# magic function (%matplotlib) to make the plot appear and store in notebook
%matplotlib inline
sns.set() # set seaborn styles
births.pivot_table('births', index='year', columns='gender', aggfunc='sum').plot()
plt.ylabel('total births per year');