Customer Churn Prediction – Understanding Models with Feature Permutation Importance using Python

One of the primary goals of many service companies is to build solid and long-lasting relationships with their customers. Customers who churn cost a lot of money. Modern service companies that understand which customers tend to end the business relationship can take appropriate countermeasures early on to retain them. The prerequisite for this is that the provider can identify churn candidates among their customer base, which is a typical task for a machine learning model. Once we have trained such a model, we can use a technique called Permutation Feature Importance to understand the relationships between model predictions and features. Analyzing trained machine learning models can provide exciting insights into market events and business logic. This tutorial will show how Permutation Feature Importance works and how to implement it.

Customer churn prediction is a compelling use case for machine learning. It is particularly effective when combined with feature permutation importance.

The rest of this article is structured as follows. First, we look at the business problem and introduce permutation feature importance, an excellent technique to identify the essential features of our machine learning model. Then we turn to the coding part and implement a churn prediction model in Python. We will train a classification model and select the most promising parameters using hyperparameter tuning. We then use this model to predict the churn probabilities of a set of test customers. Finally, we create a ranking of features according to their impact on model performance. In this way, this article demonstrates how we can use permutation feature importance to gain insight into the relationship between input variables and model predictions.

Churn Prediction – What’s the Business Case?

A company’s effort to persuade a new customer to sign a contract is many times higher than the costs incurred in retaining existing customers. According to industry experts, it is four times more expensive to win a new customer than keep an existing one. Providers that can identify churn candidates in advance and manage to retain them can significantly reduce costs.

A crucial point is whether the provider succeeds in getting the churn candidates to stay. Sometimes it may be enough to contact the churn candidate and inquire about customer satisfaction. In other cases, this may not be enough, and the provider needs to increase the service value, for example, by offering free services or granting a discount. However, actions should be well thought out, as they can also have a negative effect. For instance, if a customer hardly ever uses their contract, a call from the provider may even increase the desire to cancel it. Machine learning can help assess cases individually and identify the optimal anti-churn action.

About Permutation Feature Importance

Feature importance is a helpful technique to understand the contribution of input variables (features) to a predictive model. The results from this technique can be as valuable as the predictions themselves, as they can help us understand the business context better. For example, let’s say we have trained a model that predicts which of our customers will likely churn. Wouldn’t it be interesting to know why specific customers are more likely to churn than others? Permutation feature importance can help us answer this question by providing us with a ranking of the input variables in our model by their usefulness. The technique shuffles the values of one feature at a time and measures how much the model performance drops as a result. The ranking can validate assumptions about the business context and point to relationships in the data that deserve closer investigation. Note that it reflects what the model relies on, not necessarily what causes churn.
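To make the idea more tangible, here is a minimal sketch of the shuffling procedure written by hand. It is not the code used later in this tutorial. The synthetic dataset, the helper name manual_permutation_importance, and all parameter values are assumptions chosen purely for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (assumption): with shuffle=False, the first two
# columns are informative and the remaining three are pure noise
X, y = make_classification(n_samples=600, n_features=5, n_informative=2,
                           n_redundant=0, shuffle=False, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = RandomForestClassifier(n_estimators=50, random_state=0)
model.fit(X_train, y_train)


def manual_permutation_importance(model, X, y, n_repeats=10, seed=0):
    """Mean drop in accuracy per feature after shuffling that feature's column."""
    rng = np.random.default_rng(seed)
    baseline = model.score(X, y)  # accuracy on the unshuffled data
    importances = np.zeros(X.shape[1])
    for col in range(X.shape[1]):
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # Shuffling breaks the link between this feature and the label
            X_perm[:, col] = rng.permutation(X_perm[:, col])
            drops.append(baseline - model.score(X_perm, y))
        importances[col] = np.mean(drops)
    return importances


importances = manual_permutation_importance(model, X_test, y_test)
for col, value in enumerate(importances):
    print(f"feature_{col}: {value:.3f}")
```

A large drop means the model depends heavily on that feature, while a value near zero suggests the model barely uses it.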

One of the most significant advantages of traditional prediction models, such as a decision tree, compared to neural networks is their interpretability. Neural networks are black boxes because it is tough to understand the relationship between input and model predictions. In traditional models, on the other hand, we can calculate the importance of the features and use it to interpret the model and optimize its performance, for example, by removing features from the model that are not important. We, therefore, start with a simple model first and move on to more complex models once we understand the data.

Implementing a Customer Churn Prediction Model in Python

In the following, we will implement a customer churn prediction model. We will train a decision forest model (a random forest) on a data set from Kaggle and optimize it using grid search. The data contains customer-level information for a telecom provider and a binary prediction label that indicates which customers canceled their contracts and which did not. Finally, we will calculate the feature importance to understand how the model works.

The code is available on the GitHub repository.


Before starting the coding part, make sure that you have set up your Python 3 environment and required packages.

Python Environment

If you don’t have an environment set up yet, you can follow this tutorial to set up the Anaconda environment.

Python Packages

Make sure you install all required packages. In this tutorial, we will be working with the following packages: 

  • Pandas
  • NumPy
  • Matplotlib
  • Seaborn

In addition, we will be using the machine learning library Scikit-learn.

You can install packages using console commands:

  • pip install <package name>
  • conda install <package name> (if you are using the Anaconda package manager)

Step #1 Loading the Customer Churn Data

We begin by loading a customer churn dataset from Kaggle. After completing the download, put the dataset under the file path of your choice, but don’t forget to adjust the file path variable in the code. If you work with the Kaggle Python environment, you can also directly save the dataset into your Kaggle project.

The dataset contains 3333 records and the following attributes.

  • Churn: the prediction label, 1 if the customer canceled service, 0 if not
  • AccountWeeks: number of weeks the customer has had an active account
  • ContractRenewal: 1 if customer recently renewed contract, 0 if not
  • DataPlan: 1 if the customer has a data plan, 0 if not
  • DataUsage: gigabytes of monthly data usage
  • CustServCalls: number of calls into customer service
  • DayMins: average daytime minutes per month
  • DayCalls: average number of daytime calls
  • MonthlyCharge: average monthly bill
  • OverageFee: the largest overage fee in the last 12 months
  • RoamMins: average number of roaming minutes

The following code will load the data from your local folder into your anaconda Python project:

import numpy as np 
import pandas as pd 
import math
from pandas.plotting import register_matplotlib_converters
import matplotlib.pyplot as plt 
import matplotlib.colors as mcolors
import matplotlib.dates as mdates 

from sklearn.metrics import confusion_matrix, classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.inspection import permutation_importance
import seaborn as sns

# set file path
filepath = "data/Churn-prediction/"

# Load train and test datasets
train_df = pd.read_csv(filepath + 'telecom_churn.csv')

Step #2 Exploring the Data

Before we begin with the preprocessing, we will quickly explore the data. For this purpose, we will create histograms for the different attributes in our data.

# Create pairwise plots of the feature columns, colored by prediction label
df_plot = train_df.copy()
sns.pairplot(df_plot, hue="Churn", height=2.5, palette='muted')

Histograms of the churn prediction dataset separated by prediction label (red = churn, blue = no churn)

We can see that the data distribution for several attributes looks quite good and resembles a normal distribution, for example, for OverageFee, DayMins, and DayCalls. However, the distribution for the prediction label is unbalanced. Naturally, this is because more customers remain with their contract (prediction label class = 0) than cancel it (prediction label class = 1).
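The imbalance can also be quantified instead of read off a plot. The following sketch uses a tiny, made-up DataFrame named toy_df as a stand-in for the real data (an assumption for illustration only). With the real dataset, you would apply the same calls to train_df.

```python
import pandas as pd

# Toy stand-in for the churn table (assumption): 1 = churned, 0 = stayed
toy_df = pd.DataFrame({"Churn": [0, 0, 0, 0, 0, 0, 1, 0, 0, 1]})

# Absolute and relative class frequencies of the prediction label
class_counts = toy_df["Churn"].value_counts()
class_shares = toy_df["Churn"].value_counts(normalize=True)

print(class_counts)
print(class_shares)
```

The relative shares are a useful reference point later on: a model that always predicts "no churn" would already reach an accuracy equal to the share of the majority class.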

Step #3 Data Preprocessing

The next step is to preprocess the data. I have reduced this part to a minimum to keep this tutorial simple. For example, I do not resample the unbalanced label classes; the only countermeasure is the class_weight='balanced' setting of the classifier used below. A more thorough treatment would be appropriate to improve the model performance in a real business context. The imbalanced data is also why I chose a decision forest as a model type. Compared to traditional models such as logistic regression, decision forests can handle unbalanced data relatively well.

The following code splits the data into training data (x_train) and test data (x_test) and creates the respective datasets, which only contain the label class (y_train, y_test). The training share is 0.7, resulting in 2333 records in the training dataset and 1000 in the test dataset.

# Select the feature columns (x) and the prediction label (y)
feature_columns = ['AccountWeeks', 'ContractRenewal', 'DataPlan', 'DataUsage', 'CustServCalls',
                   'DayCalls', 'MonthlyCharge', 'OverageFee', 'RoamMins']
x_df = train_df.loc[:, train_df.columns.isin(feature_columns)].copy()
y_df = train_df['Churn'].copy()

# Split the data into training and test sets
x_train, x_test, y_train, y_test = train_test_split(x_df, y_df, train_size=0.7, random_state=0)
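Because the label classes are unbalanced, one optional variation is a stratified split, which keeps the churn share identical in the training and test data. This tutorial does not use it; the sketch below only illustrates the stratify parameter of train_test_split on made-up data (X_toy and y_toy are assumptions).

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy stand-in data (assumption): 100 customers, 20 % of them churned
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(100, 3))
y_toy = np.array([1] * 20 + [0] * 80)

# stratify=y_toy preserves the churn share in both partitions
X_tr, X_te, y_tr, y_te = train_test_split(
    X_toy, y_toy, train_size=0.7, stratify=y_toy, random_state=0
)

print(f"churn share train: {y_tr.mean():.3f}")
print(f"churn share test:  {y_te.mean():.3f}")
```

Without stratification, the churn share in the two partitions can differ by chance, which matters more the smaller the minority class is.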

Step #4 Training the Model with Grid Search

Now comes the exciting part. We will train a series of 36 decision forests, one for each combination of hyperparameters, and then choose the best-performing model. The technique used in this process is called hyperparameter tuning (more specifically, grid search), and I have recently published a separate article on this topic.

The following code defines the parameters that the grid search will test (max_depth, n_estimators, and min_samples_split). Then the code runs the grid search and trains the decision forests. Finally, we print out the model ranking along with model parameters.

# Define parameters
max_depth = [2, 4, 8, 16]
n_estimators = [64, 128, 256]
min_samples_split = [5, 20, 30]

param_grid = dict(max_depth=max_depth, n_estimators=n_estimators, min_samples_split=min_samples_split)

# Build the grid search; the hyperparameter values are supplied by param_grid
dfrst = RandomForestClassifier(class_weight='balanced')
grid = GridSearchCV(estimator=dfrst, param_grid=param_grid, cv=5)
grid_results = grid.fit(x_train, y_train)

# Summarize the results in a readable format
results_df = pd.DataFrame(grid_results.cv_results_)
results_df.sort_values(by=['rank_test_score'], ascending=True, inplace=True)

# Reduce the results to selected columns and print the ranking
results_filtered = results_df[results_df.columns[results_df.columns.isin(['param_max_depth', 'param_min_samples_split', 'param_n_estimators','std_fit_time', 'rank_test_score', 'std_test_score', 'mean_test_score'])]].copy()
print(results_filtered)

Model ranking created with grid search
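If you want to double-check how many models such a grid produces, scikit-learn's ParameterGrid can enumerate the combinations. The short sketch below reuses the parameter values from above; the variable names n_candidates and n_folds are my own additions for illustration.

```python
from sklearn.model_selection import ParameterGrid

param_grid = {
    "max_depth": [2, 4, 8, 16],
    "n_estimators": [64, 128, 256],
    "min_samples_split": [5, 20, 30],
}

# 4 * 3 * 3 hyperparameter combinations, each evaluated on 5 folds
n_candidates = len(ParameterGrid(param_grid))
n_folds = 5

print(f"candidate models: {n_candidates}")
print(f"cross-validation fits: {n_candidates * n_folds}")
```

Keeping an eye on these numbers helps to estimate the runtime before extending the grid, since every additional value multiplies the number of fits.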

The best-performing model is model number 29, which achieves a mean cross-validation accuracy of 92.7 %. Its hyperparameters are as follows:

  • max_depth = 16
  • min_samples_split = 5
  • n_estimators = 256

We will proceed with this model. So what does this model tell us?

We can gain an overview of the distributions of our customers according to their churn probability. Just use the following code:

# Extract the best model found by the grid search
best_clf = grid_results.best_estimator_

# Predict the churn probabilities of the test customers
y_pred_prob = best_clf.predict_proba(x_test)
churnproba = y_pred_prob[:, 1]

# Create a histogram of the predicted churn probabilities
sns.histplot(churnproba, bins=20)

Customer base according to their predicted churn probability

Customers who tend to churn have a predicted churn probability greater than 0.5 and appear further to the right in the diagram. At this threshold, the customers on the left (< 0.5) are not classified as churn candidates, so we don’t have to worry about them for now.
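The 0.5 threshold is a convention rather than a fixed rule. As a small sketch, the following code flags churn candidates for two different thresholds. The probability values and the helper flag_churn_candidates are made up for illustration; in the tutorial, you would pass churnproba instead.

```python
import numpy as np

# Hypothetical churn probabilities, standing in for churnproba (assumption)
churn_proba_example = np.array([0.05, 0.48, 0.51, 0.92, 0.30, 0.75])


def flag_churn_candidates(probabilities, threshold=0.5):
    """Return a boolean mask of customers whose churn probability exceeds the threshold."""
    return probabilities > threshold


default_mask = flag_churn_candidates(churn_proba_example)
cautious_mask = flag_churn_candidates(churn_proba_example, threshold=0.3)

print("flagged at 0.5:", np.flatnonzero(default_mask))
print("flagged at 0.3:", np.flatnonzero(cautious_mask))
```

A lower threshold catches more potential churners at the price of contacting more customers who would have stayed anyway, so the right value depends on the cost of the retention measure.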

Step #5 Best Model Performance Insights

Let’s take a more detailed look at the performance of the best model. We do this by calculating the confusion matrix.

If you want to learn more about measuring the performance of classification models, check out this tutorial.

# Extract the best decision forest 
best_clf = grid_results.best_estimator_
y_pred = best_clf.predict(x_test)

# Create a confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)

# Create heatmap from the confusion matrix
class_names=[False, True] 
tick_marks = [0.5, 1.5]
fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(pd.DataFrame(cnf_matrix), annot=True, cmap="Blues", fmt='g')
plt.title('Confusion matrix')
plt.ylabel('Actual label'); plt.xlabel('Predicted label')
plt.yticks(tick_marks, class_names); plt.xticks(tick_marks, class_names)

Confusion matrix of the best model's churn predictions on the test dataset

Of the 1000 customers in the test dataset, our model correctly classified 100 as churn candidates. For 832 customers, the model correctly predicted that they are unlikely to churn. In 30 cases, the model falsely classified customers as churn candidates, and 38 churners were missed and falsely classified as non-churn candidates. The result is a model accuracy of 93.2 % (based on a 0.5 threshold).
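Since the classes are unbalanced, accuracy alone can be misleading, so it is worth deriving precision and recall from the same four numbers. The following sketch only does the arithmetic on the counts quoted above; the variable names are my own.

```python
# Counts taken from the confusion matrix described above
true_positives = 100   # churners correctly flagged
true_negatives = 832   # loyal customers correctly identified
false_positives = 30   # loyal customers wrongly flagged as churners
false_negatives = 38   # churners the model missed

total = true_positives + true_negatives + false_positives + false_negatives
accuracy = (true_positives + true_negatives) / total
precision = true_positives / (true_positives + false_positives)
recall = true_positives / (true_positives + false_negatives)

print(f"accuracy:  {accuracy:.3f}")   # 0.932
print(f"precision: {precision:.3f}")  # 0.769
print(f"recall:    {recall:.3f}")     # 0.725
```

In other words, roughly three out of four flagged customers are actual churners, and the model finds a bit more than 72 % of all churners in the test data.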

Step #6 Permutation Feature Importance

Now that we have trained a model that gives good results, we want to understand the importance of the model’s features. With the following code, we calculate the Feature Importance. Then we visualize the results in a barplot.

# Calculate the permutation feature importance on the test data
r = permutation_importance(best_clf, x_test, y_test, n_repeats=30, random_state=0)

# Collect the mean importance per feature and sort in descending order
data_im = pd.DataFrame(r.importances_mean, columns=['feature_permutation_score'])
data_im['feature_names'] = x_test.columns
data_im = data_im.sort_values('feature_permutation_score', ascending=False)

# Plot the bar chart
fig, ax = plt.subplots(figsize=(16, 5))
sns.barplot(y='feature_names', x='feature_permutation_score', data=data_im, palette='nipy_spectral', ax=ax)
ax.set_title("Random Forest Feature Importances")

The feature ranking can provide the starting point for deeper analysis. As we can see, the most important features are the monthly charge (MonthlyCharge), data usage (DataUsage), and customer service calls (CustServCalls). Of particular interest is the importance of customer service calls, as this could indicate that customers who contact customer service have had negative experiences.
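Before reading too much into a ranking, it can help to check how stable the scores are across the repeated shuffles. Besides importances_mean, the result of permutation_importance also provides importances_std. The sketch below keeps only features whose mean importance exceeds two standard deviations. It runs on synthetic data (an assumption for illustration), and the two-standard-deviation rule is a rough heuristic, not a formal test.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (assumption): columns 0-1 informative, 2-4 noise
X, y = make_classification(n_samples=600, n_features=5, n_informative=2,
                           n_redundant=0, shuffle=False, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = RandomForestClassifier(n_estimators=50, random_state=0)
model.fit(X_train, y_train)

result = permutation_importance(model, X_test, y_test, n_repeats=30, random_state=0)

# Walk through the features from most to least important and keep those
# whose mean importance is more than two standard deviations above zero
robust_features = [
    i for i in result.importances_mean.argsort()[::-1]
    if result.importances_mean[i] - 2 * result.importances_std[i] > 0
]

for i in robust_features:
    print(f"feature_{i}: {result.importances_mean[i]:.3f} "
          f"+/- {result.importances_std[i]:.3f}")
```

Applied to the churn model, the same check would use best_clf, x_test, and y_test, and the feature names could be looked up in x_test.columns.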


This article has shown how to implement a churn prediction model using Python and the machine learning library scikit-learn. We have calculated the permutation feature importance to analyze which features contribute to the performance of our model. You have learned that permutation feature importance can provide data scientists with new insights into the context of a prediction model. Therefore, the technique is often a good starting point for further investigations.

If you liked this article, show your appreciation by leaving a comment. Cheers

And if you want to learn more about text mining and customer satisfaction, you might want to take a look at my recent blog about sentiment analysis:


  • Hi, I am Florian, a Zurich-based consultant for AI and Data. Since the completion of my Ph.D. in 2017, I have been working on the design and implementation of ML use cases in the Swiss financial sector. I started this blog in 2020 with the goal in mind to share my experiences and create a place where you can find key concepts of machine learning and materials that will allow you to kick-start your own Python projects.
