Customer Churn Prediction

Anyone About to Leave? Customer Churn Prediction at a Telecommunications Provider in Python

This article shows how to implement a customer churn prediction model using Python and scikit-learn. Furthermore, we will use permutation feature importance to gain insight into the relationship between input variables and model predictions.

Figure: Predicted churn probabilities

Telecommunications service providers face considerable pressure to expand and retain their subscriber base. One of the most significant cost factors is customers canceling their contracts. Therefore, innovative service providers have learned to use machine learning to predict which of their customers will tend to cancel their contracts. Those providers who understand which customers tend to churn can take appropriate countermeasures early on to retain them. The prerequisite for this is that the provider can identify churn candidates among their customer base. In the following, we will take a look at how this works.

The rest of this article is structured as follows. First, we take a look at the business problem. We will also talk about permutation feature importance, an excellent technique for identifying the features that matter most to our machine learning model. Then we turn to the coding part and implement a churn prediction model in Python. This part includes a final look at the most critical features in our model.

What’s the Business Case?

The cost of persuading a new customer to sign a contract is many times higher than the cost of retaining an existing customer. According to industry experts, it is at least four times more expensive to win a new customer than to keep an existing one. Providers that can identify churn candidates in advance and manage to retain them can significantly reduce costs.

A crucial point is whether the provider succeeds in getting the churn candidates to stay. Sometimes it may be enough to contact the churn candidate and inquire about customer satisfaction. In other cases, this may not be enough, and the provider needs to increase the service value, for example, by offering free services or granting a discount. However, actions should be well thought out, as they can also have a negative effect. For example, if a customer hardly ever uses their contract, a call from the provider may even increase the desire to cancel it. Here again, machine learning can help to assess cases individually and identify the optimal anti-churn action.

About Permutation Feature Importance

Feature importance is a helpful technique to understand the contribution of input variables (features) to a predictive model. The results from this technique can be as valuable as the predictions themselves, as they can help us understand the business context better. For example, let’s say we have trained a model that predicts which of our customers will likely churn. Wouldn’t it be interesting to know why specific customers are more likely to churn than others? Permutation feature importance can help us answer this question. It shuffles the values of one feature at a time and measures how much the model's score drops, which gives us a ranking of the input variables by their usefulness to the model. The ranking can validate assumptions about the business context and point to relationships in the data that are worth investigating, although a high importance alone does not prove a causal relation.
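
To make the idea more concrete, here is a minimal sketch of permutation importance written by hand. It is not the code used later in this article (there we rely on scikit-learn's built-in function). The synthetic data from make_classification and the helper name manual_permutation_importance are assumptions made purely for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (assumption): with shuffle=False the first two
# columns are informative and the remaining three are pure noise
X, y = make_classification(n_samples=1000, n_features=5, n_informative=2,
                           n_redundant=0, shuffle=False, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = RandomForestClassifier(n_estimators=50, random_state=0)
model.fit(X_train, y_train)


def manual_permutation_importance(model, X, y, n_repeats=10, seed=0):
    """Mean drop in accuracy after shuffling each column (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    baseline = model.score(X, y)
    importances = np.zeros(X.shape[1])
    for col in range(X.shape[1]):
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # Shuffling breaks the link between this feature and the label
            X_perm[:, col] = rng.permutation(X_perm[:, col])
            drops.append(baseline - model.score(X_perm, y))
        importances[col] = np.mean(drops)
    return importances


importances = manual_permutation_importance(model, X_test, y_test)
ranking = np.argsort(importances)[::-1]  # most important feature first
print(importances.round(3), ranking)
```

A feature whose shuffling barely changes the score contributes little to the predictions, while a large drop indicates that the model relies on it.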

One of the most significant advantages of traditional prediction models, such as a decision tree, compared to neural networks is their interpretability. Neural networks are often treated as black boxes because it is hard to understand the relationship between input and model predictions. In traditional models, on the other hand, we can calculate the importance of the features and use it to interpret the model and optimize its performance, for example, by removing features that are not important. This is one of the reasons why it is a good idea to start with a simple model and move on to more complex models once you understand the data.

Implementing a Customer Churn Prediction Model in Python

In the following, we will implement a customer churn prediction model. We will train a random forest (decision forest) model on a data set from Kaggle and optimize it using grid search. The data contains customer-level information for a telecom provider and a binary prediction label that indicates which customers canceled their contracts and which did not. Finally, we will calculate the feature importance to understand how the model works.


Before we start the coding part, make sure that you have set up your Python 3 environment and required packages. If you don’t have an environment set up yet, you can follow this tutorial to set up the Anaconda environment.

Also, make sure you install all required packages. In this tutorial, we will be working with the following standard packages: NumPy, pandas, Matplotlib, and seaborn.

In addition, we will be using the machine learning library scikit-learn.

You can install packages using console commands:

  • pip install <package name>
  • conda install <package name> (if you are using the Anaconda package manager)

Step #1 Loading the Customer Churn Data

We begin by loading a customer churn dataset from Kaggle. After completing the download, put the dataset under the file path of your choice, but don’t forget to adjust the file path variable in the code. If you are working with the Kaggle Python environment, you can also directly save the dataset into your Kaggle project.

The dataset contains 3333 records and the following attributes.

  • Churn: 1 if customer cancelled service, 0 if not. This will be the prediction label.
  • AccountWeeks: number of weeks customer has had active account
  • ContractRenewal: 1 if customer recently renewed contract, 0 if not
  • DataPlan: 1 if customer has data plan, 0 if not
  • DataUsage: gigabytes of monthly data usage
  • CustServCalls: number of calls into customer service
  • DayMins: average daytime minutes per month
  • DayCalls: average number of daytime calls
  • MonthlyCharge: average monthly bill
  • OverageFee: largest overage fee in last 12 months
  • RoamMins: average number of roaming minutes

The following code will load the data from your local folder into your Anaconda Python project:

import numpy as np 
import pandas as pd 
import math
from pandas.plotting import register_matplotlib_converters
import matplotlib.pyplot as plt 
import matplotlib.colors as mcolors
import matplotlib.dates as mdates 

from sklearn.metrics import confusion_matrix, classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.inspection import permutation_importance
import seaborn as sns

# set file path
filepath = "data/Churn-prediction/"

# Load the dataset (it is split into train and test data in Step #3)
train_df = pd.read_csv(filepath + 'telecom_churn.csv')

Step #2 Exploring the Data

Before we begin with the preprocessing, we will quickly explore the data. For this purpose, we will create histograms for the different attributes in our data.

# Create histograms for all columns, separated by prediction label value
df_plot = train_df.copy()

class_columnname = 'Churn'

list_length = df_plot.shape[1]
ncols = 4
nrows = math.ceil(list_length / ncols)

fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=False, figsize=(15, 12))
fig.subplots_adjust(hspace=0.5, wspace=0.5)
axes = axes.flatten()
for i, featurename in enumerate(df_plot.columns):
    ax = axes[i]
    # Blue: customers who stayed (Churn == 0)
    y0 = df_plot[df_plot[class_columnname] == 0][featurename]
    ax.hist(y0, color='blue', label=f'{featurename}-No{class_columnname}', bins='auto', edgecolor='w')
    # Red: customers who canceled (Churn == 1)
    y1 = df_plot[df_plot[class_columnname] == 1][featurename]
    ax.hist(y1, color='red', alpha=0.8, label=f'{featurename}-{class_columnname}', bins='auto', edgecolor='w')
    ax.set_title(featurename)
    ax.tick_params(axis="x", rotation=30, labelsize=10, length=0)

# Hide the subplot panels that are not needed
for ax in axes[list_length:]:
    ax.set_visible(False)
plt.show()

Figure: Histograms of the churn prediction dataset separated by prediction label (red = churn, blue = no churn)

We can see that the distribution of several attributes is roughly bell-shaped and resembles a normal distribution, for example, OverageFee, DayMins, and DayCalls. However, the distribution of the prediction label is unbalanced. Naturally, this is because more customers remain with their contract (prediction label class = 0) than cancel it (prediction label class = 1).
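
If you want to put a number on the imbalance, a quick check along the following lines can help. This is only a sketch: toy_df is a made-up stand-in with 20 rows, so the shares it prints are not those of the Kaggle dataset. With the real data, you would run the same call on train_df.

```python
import pandas as pd

# Toy stand-in for the label column (assumption): 17 loyal customers, 3 churners
toy_df = pd.DataFrame({"Churn": [0] * 17 + [1] * 3})

# Share of each label class; normalize=True returns fractions instead of counts
class_share = toy_df["Churn"].value_counts(normalize=True)
print(class_share)
```

The smaller the share of class 1, the less informative plain accuracy becomes, because a model that always predicts "no churn" already scores the share of class 0.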

Step #3 Data Preprocessing

The next step is to preprocess the data. For the sake of keeping this tutorial simple, I have reduced this part to a minimum. For example, apart from the class_weight='balanced' setting in the model definition further below, I do not treat the unbalanced label classes (for instance, by resampling). In a real business context, however, this would be appropriate to improve the model performance. The imbalanced data is also a reason why I chose a decision forest as the model type. Compared to other traditional models such as logistic regression, decision forests can handle unbalanced data relatively well.
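
The model definition later in this article passes class_weight='balanced' to the classifier. As a small illustration of what that setting does, the following sketch computes the balanced class weights for a made-up label vector (y_toy is an assumption, not the real data). scikit-learn weights each class by n_samples / (n_classes * class_count), so the rare class receives a larger weight.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Made-up labels (assumption): 17 customers who stayed, 3 who churned
y_toy = np.array([0] * 17 + [1] * 3)

weights = compute_class_weight(class_weight="balanced",
                               classes=np.array([0, 1]), y=y_toy)
# n_samples / (n_classes * count): 20 / (2 * 17) and 20 / (2 * 3)
print(dict(zip([0, 1], weights.round(3))))
```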

The following code splits the data into train data (x_train) and test data (x_test) and creates the corresponding label datasets (y_train, y_test). The train ratio is 0.7, resulting in 2333 records in the training dataset and 1000 records in the test dataset.
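
The record counts follow directly from how train_test_split rounds the sizes. The short sketch below reproduces them with placeholder arrays of the same length as the dataset (X_dummy and y_dummy are assumptions and contain no real data).

```python
import numpy as np
from sklearn.model_selection import train_test_split

n_records = 3333                      # dataset size stated above
X_dummy = np.zeros((n_records, 1))    # placeholder features (assumption)
y_dummy = np.zeros(n_records)         # placeholder labels (assumption)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_dummy, y_dummy, train_size=0.7, random_state=0)

# floor(0.7 * 3333) records go to the train set, the rest to the test set
print(len(X_tr), len(X_te))
```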

# Create Training Dataset
x_df = train_df[train_df.columns[train_df.columns.isin(['AccountWeeks', 'ContractRenewal', 'DataPlan','DataUsage', 'CustServCalls', 'DayCalls', 'MonthlyCharge', 'OverageFee', 'RoamMins'])]].copy()
y_df = train_df['Churn'].copy()

# Split the data into x_train and y_train data sets
x_train, x_test, y_train, y_test = train_test_split(x_df, y_df, train_size=0.7, random_state=0)

Step #4 Training the Model with Grid Search

Now comes the exciting part. We will train decision forests for a whole series of 36 hyperparameter combinations and then choose the best-performing model. The technique used in this process is called hyperparameter tuning (more specifically, grid search), and I have recently published a separate article on this topic.
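
Where does the number 36 come from? It is simply the product of the number of values tested per hyperparameter. The following sketch counts the combinations with scikit-learn's ParameterGrid, using the grid values from this article. The variable names n_candidates and total_fits are my own.

```python
from sklearn.model_selection import ParameterGrid

param_grid = {
    "max_depth": [2, 4, 8, 16],
    "n_estimators": [64, 128, 256],
    "min_samples_split": [5, 20, 30],
}

n_candidates = len(ParameterGrid(param_grid))  # 4 * 3 * 3 combinations
cv_folds = 5
total_fits = n_candidates * cv_folds           # each candidate is fitted once per fold
print(n_candidates, total_fits)
```

With 5-fold cross-validation, each of the 36 candidates is fitted five times, and the winning configuration is refitted once more on the complete training data.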

The following code defines the parameters that the grid search will test (max_depth, n_estimators, and min_samples_split). Then the code runs the grid search and trains the decision forests. Finally, we print out the model ranking along with model parameters.

# Define parameters
max_depth = [2, 4, 8, 16]
n_estimators = [64, 128, 256]
min_samples_split = [5, 20, 30]

param_grid = dict(max_depth=max_depth, n_estimators=n_estimators, min_samples_split=min_samples_split)

# Build the grid search; the parameter lists belong in param_grid, not in the estimator
dfrst = RandomForestClassifier(class_weight='balanced')
grid = GridSearchCV(estimator=dfrst, param_grid=param_grid, cv=5)
grid_results = grid.fit(x_train, y_train)

# Summarize the results in a readable format
results_df = pd.DataFrame(grid_results.cv_results_)
results_df.sort_values(by=['rank_test_score'], ascending=True, inplace=True)

# Reduce the results to selected columns
results_filtered = results_df[results_df.columns[results_df.columns.isin(['param_max_depth', 'param_min_samples_split', 'param_n_estimators','std_fit_time', 'rank_test_score', 'std_test_score', 'mean_test_score'])]].copy()
print(results_filtered)

Figure: Model ranking created with grid search

The best-performing model is model number 29, which achieves a mean cross-validated accuracy of 92.7%. Its hyperparameters are as follows:

  • max_depth = 16
  • min_samples_split = 5
  • n_estimators = 256
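
Instead of reading the winner off the ranking table, you can also ask the fitted grid search object directly. The sketch below shows the relevant attributes on a deliberately tiny example. The synthetic data and the reduced grid (small_grid) are assumptions chosen so that the snippet runs in a few seconds, and the values it prints will differ from the results above.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in data and a reduced grid (assumptions, for speed)
X, y = make_classification(n_samples=300, n_features=6, random_state=0)
small_grid = {"max_depth": [2, 4], "n_estimators": [10, 20]}

search = GridSearchCV(RandomForestClassifier(random_state=0), small_grid, cv=3)
search.fit(X, y)

print(search.best_params_)            # hyperparameters of the top-ranked model
print(round(search.best_score_, 3))   # its mean cross-validated accuracy
best_model = search.best_estimator_   # refitted on all data passed to fit()
```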

We will proceed with this model. So what does this model tell us?

We can gain an overview of the distributions of our customers according to their churn probability. Just use the following code:

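# Extract the best decision forest found by the grid search
best_clf = grid_results.best_estimator_

# Predict churn probabilities for the test data (column 1 = probability of churn)
y_pred_prob = best_clf.predict_proba(x_test)
churnproba = y_pred_prob[:, 1]

# Plot the distribution of the predicted churn probabilities
plt.hist(churnproba, color='blue', label='Churn', bins='auto')
plt.xlabel('Predicted churn probability'); plt.ylabel('Number of customers')
plt.show()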

Figure: Customer base according to predicted churn probability

With the default threshold of 0.5, the model treats customers with a churn probability greater than 0.5 as churn candidates. They are further to the right in the diagram. The customers on the far left, whose churn probability is close to zero, are the ones we have to worry about least.
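
The 0.5 threshold is a convention rather than a law. If missing a churn candidate is more costly than contacting a loyal customer, a lower threshold may be worth considering. The following sketch uses made-up probabilities (churn_proba) and a hypothetical alternative threshold of 0.4 to show how the number of flagged customers changes.

```python
import numpy as np

# Made-up churn probabilities for six customers (assumption)
churn_proba = np.array([0.05, 0.20, 0.45, 0.55, 0.80, 0.95])

default_flags = churn_proba > 0.5    # the default decision rule
cautious_flags = churn_proba > 0.4   # hypothetical lower threshold

print(default_flags.sum(), cautious_flags.sum())
```

Lowering the threshold flags more customers, which typically raises recall at the expense of precision.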

Step #5 Best Model Performance Insights

Let’s take a more detailed look at the performance of the best model by calculating the confusion matrix:

# Extract the best decision forest 
best_clf = grid_results.best_estimator_
y_pred = best_clf.predict(x_test)

# Create a confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)

# Create heatmap from the confusion matrix
class_names=[False, True] 
tick_marks = [0.5, 1.5]
fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(pd.DataFrame(cnf_matrix), annot=True, cmap="Blues", fmt='g')
plt.title('Confusion matrix')
plt.ylabel('Actual label'); plt.xlabel('Predicted label')
plt.yticks(tick_marks, class_names); plt.xticks(tick_marks, class_names)

Of the 1000 customers in the test dataset, our model correctly classified 100 customers as churn candidates. For 832 customers, the model correctly predicted that they are unlikely to churn. In 30 cases, the model falsely classified customers as churn candidates, and 38 churn candidates were missed and falsely classified as non-churn candidates. The result is a model accuracy of 93.2% (based on a 0.5 threshold).
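
Because the classes are unbalanced, accuracy alone can look better than the model really is. As a small worked example, the following sketch derives accuracy, precision, recall, and the F1 score from the four confusion matrix counts reported above. The variable names are my own.

```python
# Counts taken from the confusion matrix described above
tp = 100   # churners correctly identified
tn = 832   # loyal customers correctly identified
fp = 30    # loyal customers falsely flagged as churners
fn = 38    # churners that were missed

total = tp + tn + fp + fn
accuracy = (tp + tn) / total
precision = tp / (tp + fp)   # share of flagged customers who actually churn
recall = tp / (tp + fn)      # share of actual churners that were caught
f1 = 2 * precision * recall / (precision + recall)

# Accuracy of a trivial model that always predicts "no churn"
baseline_accuracy = (tn + fp) / total

print(f"accuracy={accuracy:.3f} precision={precision:.3f} "
      f"recall={recall:.3f} f1={f1:.3f} baseline={baseline_accuracy:.3f}")
```

The model catches roughly 72% of the actual churners, and about 77% of the customers it flags really do churn. Compared to the 86.2% accuracy of always predicting "no churn", the 93.2% accuracy is a clear but less dramatic improvement than the raw number suggests.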

Step #6 Permutation Feature Importance

Now that we have trained a model that gives good results, we want to understand the importance of the features for the model. With the following code, we calculate the Feature Importance. Then we visualize the results in a barplot.

# Calculate the permutation feature importance on the test data
r = permutation_importance(best_clf, x_test, y_test, n_repeats=30, random_state=0)

# Collect the mean importance per feature and sort in descending order
data_im = pd.DataFrame(r.importances_mean, columns=['feature_permutation_score'])
data_im['feature_names'] = x_test.columns
data_im = data_im.sort_values('feature_permutation_score', ascending=False)

# Plot the bar chart
fig, ax = plt.subplots(figsize=(16, 5))
sns.barplot(y='feature_names', x='feature_permutation_score', data=data_im, palette='nipy_spectral', ax=ax)
ax.set_title("Random Forest Permutation Feature Importances")
plt.show()

As we can see, the most important features are the monthly charge (MonthlyCharge), data usage (DataUsage), and customer service calls (CustServCalls). Of particular interest is the importance of customer service calls, as this could indicate that customers who come into contact with customer service have negative experiences. The ranking alone cannot confirm this, but it shows how feature importance can provide the starting point for deeper analysis.
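
Before reading too much into a ranking, it can be worth checking how stable the scores are across the repeated shuffles. The result object returned by permutation_importance also contains the standard deviation per feature. The sketch below prints both values on synthetic stand-in data (an assumption, so the output says nothing about the churn dataset) and applies a simple rule of thumb: a feature counts as clearly useful if its mean importance exceeds two standard deviations.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (assumption)
X, y = make_classification(n_samples=600, n_features=5, n_informative=2,
                           n_redundant=0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_train, y_train)

result = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=0)

# Report features from most to least important, with their spread
for i in result.importances_mean.argsort()[::-1]:
    mean, std = result.importances_mean[i], result.importances_std[i]
    clearly_useful = mean - 2 * std > 0
    print(f"feature_{i}: {mean:.3f} +/- {std:.3f} (clearly useful: {clearly_useful})")
```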


In this article, we have implemented a churn prediction model using Python and scikit-learn. We have calculated the permutation feature importance to analyze which features contribute to the performance of our model. You have learned that permutation feature importance can give data scientists new insights into the context of a prediction model. The technique is, therefore, often a good starting point for further investigation.

If you liked this article, show your appreciation by leaving a comment. Cheers


  • Hi, I am Florian, a Zurich-based consultant for AI and Data. Since the completion of my Ph.D. in 2017, I have been working on the design and implementation of ML use cases in the Swiss financial sector. I started this blog in 2020 with the goal in mind to share my experiences and create a place where you can find key concepts of machine learning and materials that will allow you to kick-start your own Python projects.
