Decision Tree Classification in Python

In this tutorial, I am going to implement algorithms for decision tree classification. I am going to train a simple decision tree and two decision tree ensembles (RandomForest and XGBoost), and these models will be compared with 10-fold cross-validation. I am using the Titanic data set from kaggle, which will be preprocessed and visualized before it is used for training.

Decision tree algorithms were among the first solutions used to aid decision support systems (expert systems). A decision tree is constructed as a number of if-then rules that build a hierarchical tree, shaped more like a pyramid. A decision tree is created with recursive binary splitting, from the root node down to the final predictions. We want the most important features at the top of the tree, as this makes it faster to reach a satisfactory result.
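
To make the splitting criterion concrete, here is a minimal sketch of how a single split could be chosen with Gini impurity. This is illustrative code of my own, not the tutorial's code and not an API that scikit-learn exposes; the library does this internally in optimized form.

# A minimal sketch of choosing one split with Gini impurity (illustrative only)
import numpy as np

def gini(y):
    # Gini impurity: 1 - sum of squared class proportions
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(x, y):
    # Try every threshold on one feature and keep the split with the
    # lowest weighted Gini impurity of the two child nodes
    best_threshold, best_score = None, np.inf
    for threshold in np.unique(x):
        left, right = y[x <= threshold], y[x > threshold]
        if len(left) == 0 or len(right) == 0:
            continue
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
        if score < best_score:
            best_threshold, best_score = threshold, score
    return best_threshold, best_score

# Tiny made-up example: a feature like Sex (0 = female, 1 = male) against survival
x = np.array([0, 0, 1, 1, 1, 0, 1, 1])
y = np.array([1, 1, 0, 0, 1, 1, 0, 0])
print(best_split(x, y)) # threshold 0 (female vs male), weighted Gini about 0.2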

Decision trees are easy to understand and explain, and they can be used for binary classification problems as well as multiclass problems. Decision trees can be biased if the data set is not balanced, and they can be unstable, as small variations in the input data might generate a different tree.
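
One common way to counter the imbalance problem (not used in this tutorial) is to weight the classes inversely to their frequency, which scikit-learn supports directly:

# A sketch: class weighting to counter an unbalanced data set (not used in this tutorial)
from sklearn.tree import DecisionTreeClassifier

# 'balanced' weights each class inversely proportional to its frequency in the data
model = DecisionTreeClassifier(class_weight='balanced')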

Decision tree ensemble methods combine multiple decision trees to improve prediction performance. Ensemble methods can implement bagging or boosting. Bagging means that multiple trees are created on subsets of the input data, and the result of such a model is the average prediction over all trees. Boosting is a technique where trees are created sequentially, and each new tree tries to minimize the loss/error of the previous trees. Random Forest is an example of an ensemble method that uses bagging, and XGBoost is an example of an ensemble method that uses boosting.
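
As a hedged sketch of the difference, both strategies are available as generic estimators in scikit-learn; the data below is made up for illustration and is not part of this tutorial.

# Bagging versus boosting on made-up data (a sketch, not part of this tutorial)
import numpy as np
from sklearn.ensemble import BaggingClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier

rng = np.random.RandomState(0)
X = rng.rand(200, 4)
y = (X[:, 0] + X[:, 1] > 1).astype(int)

# Bagging: 50 trees, each trained on a bootstrap sample, predictions are averaged
bagging = BaggingClassifier(DecisionTreeClassifier(), n_estimators=50, random_state=0)
# Boosting: trees are added sequentially, each one fit to the errors of the ensemble so far
boosting = GradientBoostingClassifier(n_estimators=50, random_state=0)

for model in (bagging, boosting):
    model.fit(X, y)
    print(type(model).__name__, model.score(X, y))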

Data set and libraries

I am going to use the Titanic dataset (download it) from kaggle.com; you need to register to be able to download the data set. The data set consists of a training set and a test set, where the test set is used if you want to make a submission. The data set includes data about passengers on Titanic and a boolean target value that indicates if the passenger survived or not. I am using the following libraries: pandas, joblib, numpy, matplotlib, csv, xgboost, graphviz and scikit-learn.

Data preparation

You can open the train.csv file with Excel or OpenOffice Calc, or investigate it on kaggle. Some columns in the data set include a lot of unique values, like PassengerId, Name, Age, Ticket, Fare and Cabin. Columns with a lot of unique values might be removed or reconstructed. I decided to remove PassengerId, Name and Ticket; Cabin is reconstructed to indicate if the passenger has a cabin or not. You might be able to improve the accuracy by reconstructing Age and Fare. Some of the columns include null (NaN) values, and string values need to be converted to numbers. The following method in a module called common (common.py) is used to prepare the data set.

# Preprocess data
def preprocess_data(ds):
    # Get passenger ids (should not be part of the dataset)
    ids = ds['PassengerId']
    # Set cabin to a boolean value (no, yes)
    cabins = ds['Cabin'].copy()
    for i in range(len(cabins)):
        # A missing cabin is NaN, which has type float
        if type(cabins.loc[i]) == float:
            cabins.loc[i] = 0
        else:
            cabins.loc[i] = 1
    # Update the cabin column in the data set
    ds['Cabin'] = cabins
    # Remove null (NaN) values from the data set
    median_fare = ds['Fare'].median()
    mean_age = ds['Age'].mean()
    ds['Fare'] = ds['Fare'].fillna(median_fare)
    ds['Age'] = ds['Age'].fillna(mean_age)
    ds['Embarked'] = ds['Embarked'].fillna('S')
    # Map string values to numbers (to be able to train and test models)
    ds['Sex'] = ds['Sex'].map({'female': 0, 'male': 1})
    ds['Embarked'] = ds['Embarked'].map({'Q': 0, 'C': 1, 'S': 2})
    # Drop columns
    ds = ds.drop(columns=['PassengerId', 'Name', 'Ticket'])
    # Return ids and data set
    return ids, ds
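
For reference, a minimal usage example of the method above, with the same paths and module namespace as in the rest of this tutorial:

# Example usage of common.preprocess_data
import pandas
import annytab.decision_trees.common as common

ds = pandas.read_csv('C:\\DATA\\Python-data\\titanic\\train.csv')
ids, ds = common.preprocess_data(ds)
print(ds.head()) # PassengerId, Name and Ticket are gone, strings are mapped to numbers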

Visualize data set

The following module is used to visualize the data set. The output from the visualization process is shown below the code.

# Import libraries
import pandas
import joblib
import math
import numpy as np
import matplotlib.pyplot as plt
import annytab.decision_trees.common as common

# Visualize data set
def visualize_dataset(ds):
    # Print first 10 rows in data set
    print('--- First 10 rows ---\n')
    #pandas.set_option('display.max_columns', 12)
    print(ds[0:10])
    # Print the shape
    print('\n--- Shape of data set ---\n')
    print(ds.shape)
    # Print class distribution
    print('\n--- Class distribution ---\n')
    print(ds.groupby('Survived').size())
    # Group data set
    survivors = ds[ds.Survived == True]
    non_survivors = ds[ds.Survived == False]
    # Create a figure
    figure = plt.figure(figsize = (12, 8))
    figure.suptitle('Survivors and Non-survivors on Titanic', fontsize=16)
    # Create a default grid
    plt.rc('axes', facecolor='#ececec', edgecolor='none', axisbelow=True, grid=True)
    plt.rc('grid', color='w', linestyle='solid')
    # Add spacing between subplots
    plt.subplots_adjust(top = 0.9, bottom=0.1, hspace=0.3, wspace=0.4)
    # Plot by Pclass (1)
    plt.subplot(2, 4, 1) # 2 rows and 4 columns
    survivors_data = survivors.groupby('Pclass').size().values
    non_survivors_data = non_survivors.groupby('Pclass').size().values
    plt.bar(range(len(survivors_data)), survivors_data, label='Survivors', alpha=0.5, color='g')
    plt.bar(range(len(non_survivors_data)), non_survivors_data, bottom=survivors_data, label='Non-Survivors', alpha=0.5, color='r')
    plt.xticks([0,1,2], [1, 2, 3])
    plt.ylabel('Count')
    plt.title('Pclass')
    plt.legend(loc='upper left')
    # Plot by Gender (2)
    plt.subplot(2, 4, 2) # 2 rows and 4 columns
    survivors_data = survivors.groupby('Sex').size().values
    non_survivors_data = non_survivors.groupby('Sex').size().values
    plt.bar(range(len(survivors_data)), survivors_data, label='Survivors', alpha=0.5, color='g')
    plt.bar(range(len(non_survivors_data)), non_survivors_data, bottom=survivors_data, label='Non-Survivors', alpha=0.5, color='r')
    plt.xticks([0,1], ['Female', 'Male'])
    plt.ylabel('Count')
    plt.title('Gender')
    plt.legend(loc='upper left')
    # Plot by Age (3)
    plt.subplot(2, 4, 3) # 2 rows and 4 columns
    survivors_data = survivors.groupby(['AgeGroup']).size().values
    non_survivors_data = non_survivors.groupby(['AgeGroup']).size().values
    plt.bar(range(len(survivors_data)), survivors_data, label='Survivors', alpha=0.5, color='g')
    plt.bar(range(len(non_survivors_data)), non_survivors_data, bottom=survivors_data, label='Non-Survivors', alpha=0.5, color='r')
    plt.xticks([0,1,2,3,4,5,6,7], ['0-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', '70-79'], rotation=40, horizontalalignment='right')
    plt.ylabel('Count')
    plt.title('Age')
    plt.legend(loc='upper left')
    # Plot by SibSp (4)
    plt.subplot(2, 4, 4) # 2 rows and 4 columns
    survivors_data = np.append(survivors.groupby('SibSp').size().values, np.array([0,0])) # Make sure that arrays have same length
    non_survivors_data = non_survivors.groupby('SibSp').size().values
    plt.bar(range(len(survivors_data)), survivors_data, label='Survivors', alpha=0.5, color='g')
    plt.bar(range(len(non_survivors_data)), non_survivors_data, bottom=survivors_data, label='Non-Survivors', alpha=0.5, color='r')
    plt.ylabel('Count')
    plt.title('Number of siblings/spouses')
    plt.legend(loc='upper left')
    # Plot by Parch (5)
    plt.subplot(2, 4, 5) # 2 rows and 4 columns
    survivors_data = np.append(survivors.groupby('Parch').size().values, np.array([0,0])) # Make sure that arrays have same length
    non_survivors_data = non_survivors.groupby('Parch').size().values
    plt.bar(range(len(survivors_data)), survivors_data, label='Survivors', alpha=0.5, color='g')
    plt.bar(range(len(non_survivors_data)), non_survivors_data, bottom=survivors_data, label='Non-Survivors', alpha=0.5, color='r')
    plt.ylabel('Count')
    plt.title('Number of parents/children')
    plt.legend(loc='upper left')
    # Plot by Fare (6)
    plt.subplot(2, 4, 6) # 2 rows and 4 columns
    survivors_data = survivors.groupby(['FareGroup']).size().values
    non_survivors_data = non_survivors.groupby(['FareGroup']).size().values
    plt.bar(range(len(survivors_data)), survivors_data, label='Survivors', alpha=0.5, color='g')
    plt.bar(range(len(non_survivors_data)), non_survivors_data, bottom=survivors_data, label='Non-Survivors', alpha=0.5, color='r')
    plt.xticks([0,1,2,3,4,5], ['0-99', '100-199', '200-299', '300-399', '400-499', '500-599'], rotation=40, horizontalalignment='right')
    plt.ylabel('Count')
    plt.title('Fare')
    plt.legend(loc='upper left')
    # Plot by Cabin (7)
    plt.subplot(2, 4, 7) # 2 rows and 4 columns
    survivors_data = survivors.groupby('Cabin').size().values
    non_survivors_data = non_survivors.groupby('Cabin').size().values
    plt.bar(range(len(survivors_data)), survivors_data, label='Survivors', alpha=0.5, color='g')
    plt.bar(range(len(non_survivors_data)), non_survivors_data, bottom=survivors_data, label='Non-Survivors', alpha=0.5, color='r')
    plt.xticks([0,1], ['No', 'Yes'])
    plt.ylabel('Count')
    plt.title('Cabin')
    plt.legend(loc='upper left')
    # Plot by Embarked (8)
    plt.subplot(2, 4, 8) # 2 rows and 4 columns
    survivors_data = survivors.groupby('Embarked').size().values
    non_survivors_data = non_survivors.groupby('Embarked').size().values
    plt.bar(range(len(survivors_data)), survivors_data, label='Survivors', alpha=0.5, color='g')
    plt.bar(range(len(non_survivors_data)), non_survivors_data, bottom=survivors_data, label='Non-Survivors', alpha=0.5, color='r')
    plt.xticks([0,1,2], ['Q', 'C', 'S'])
    plt.ylabel('Count')
    plt.title('Embarked')
    plt.legend(loc='upper left')
    # Show or save the figure
    #plt.show()
    plt.savefig('C:\\DATA\\Python-data\\titanic\\plots\\bar-charts.png')

# The main entry point for this module
def main():
    # Load data set (includes header values)
    ds = pandas.read_csv('C:\\DATA\\Python-data\\titanic\\train.csv')
    # Preprocess data
    ids, ds = common.preprocess_data(ds)
    # Create age groups
    ds['AgeGroup'] = pandas.cut(ds.Age, range(0, 81, 10), right=False, labels=['0-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', '70-79'])
    # Create fare groups
    ds['FareGroup'] = pandas.cut(ds.Fare, range(0, 601, 100), right=False, labels=['0-99', '100-199', '200-299', '300-399', '400-499', '500-599'])
    # Visualize data set
    visualize_dataset(ds)

# Tell python to run main method
if __name__ == "__main__": main()
--- First 10 rows ---

   Survived  Pclass  Sex        Age  ...  Cabin  Embarked  AgeGroup  FareGroup
0         0       3    1  22.000000  ...      0         2     20-29       0-99
1         1       1    0  38.000000  ...      1         1     30-39       0-99
2         1       3    0  26.000000  ...      0         2     20-29       0-99
3         1       1    0  35.000000  ...      1         2     30-39       0-99
4         0       3    1  35.000000  ...      0         2     30-39       0-99
5         0       3    1  29.699118  ...      0         0     20-29       0-99
6         0       1    1  54.000000  ...      1         2     50-59       0-99
7         0       3    1   2.000000  ...      0         2       0-9       0-99
8         1       3    0  27.000000  ...      0         2     20-29       0-99
9         1       2    0  14.000000  ...      0         1     10-19       0-99

[10 rows x 11 columns]

--- Shape of data set ---

(891, 11)

--- Class distribution ---

Survived
0    549
1    342
dtype: int64
[Figure: stacked bar charts of survivors and non-survivors on Titanic, by Pclass, Gender, Age, SibSp, Parch, Fare, Cabin and Embarked]

Baseline performance

The data set is not perfectly balanced, as there are 549 non-survivors and 342 survivors; a possible measure to get better results is to create a better balance in the data set. The probability of making a correct prediction by always guessing non-survivor is 61.62 % (549/891), and our model must perform better than this.
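
If you want to verify the baseline in code, a majority-class classifier can be evaluated with the same 10-fold cross-validation. This is a sketch; DummyClassifier is a standard scikit-learn estimator but is not part of the tutorial modules.

# Baseline check with a majority-class classifier (a sketch)
import pandas
import sklearn.dummy
import sklearn.model_selection
import annytab.decision_trees.common as common

ds = pandas.read_csv('C:\\DATA\\Python-data\\titanic\\train.csv')
ids, ds = common.preprocess_data(ds)
X = ds.values[:,1:9] # Data
Y = ds.values[:,0] # Survived

# Always predicts the majority class (non-survivor)
baseline = sklearn.dummy.DummyClassifier(strategy='most_frequent')
scores = sklearn.model_selection.cross_val_score(baseline, X, Y, cv=10)
print('Baseline accuracy: {0:.2f}'.format(scores.mean() * 100.0)) # about 61.62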

Python module

The following module is used for training, evaluation and submission. I am using three models, each of which has a lot of hyperparameters that can be adjusted. All of the project files are stored in annytab/decision_trees, and the namespace for our common module is therefore annytab.decision_trees.

# Import libraries
import pandas
import joblib
import csv
import numpy as np
import sklearn.model_selection
import sklearn.tree
import sklearn.ensemble
import sklearn.metrics
import xgboost
import graphviz
import matplotlib.pyplot as plt
import annytab.decision_trees.common as common

# Train and evaluate
def train_and_evaluate():
    # Load train data set (includes header values)
    ds = pandas.read_csv('C:\\DATA\\Python-data\\titanic\\train.csv')
    # Preprocess data
    ids, ds = common.preprocess_data(ds)
    # Slice data set in values and target (2D-array)
    X = ds.values[:,1:9] # Data
    Y = ds.values[:,0] # Survived
    # Create models
    models = []
    models.append(('DecisionTree', sklearn.tree.DecisionTreeClassifier(
        criterion='gini', splitter='best', max_depth=None, min_samples_split=5, min_samples_leaf=1,
        min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None,
        min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, presort=False)))
    models.append(('RandomForest', sklearn.ensemble.RandomForestClassifier(
        n_estimators=100, criterion='gini', max_depth=None, min_samples_split=5,
        min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features='auto',
        max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None,
        bootstrap=True, oob_score=False, n_jobs=None, random_state=None, verbose=0,
        warm_start=False, class_weight=None)))
    models.append(('XGBoost', xgboost.XGBClassifier(
        booster='gbtree', max_depth=6, min_child_weight=1, learning_rate=0.1, n_estimators=500, verbosity=0, objective='binary:logistic',
        gamma=0, max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1, reg_alpha=0, reg_lambda=0,
        scale_pos_weight=1, seed=0, missing=None)))
    # Loop models
    for name, model in models:
        # Train the model on the whole data set
        model.fit(X, Y)
        # Save the model (make sure that the folder exists)
        joblib.dump(model, 'C:\\DATA\\Python-data\\titanic\\models\\' + name + '.jbl')
        # Evaluate on training data
        print('\n--- ' + name + ' ---')
        print('\nTraining data')
        predictions = model.predict(X)
        accuracy = sklearn.metrics.accuracy_score(Y, predictions)
        print('Accuracy: {0:.2f}'.format(accuracy * 100.0))
        print('Classification Report:')
        print(sklearn.metrics.classification_report(Y, predictions))
        print('Confusion Matrix:')
        print(sklearn.metrics.confusion_matrix(Y, predictions))
        # Evaluate with 10-fold CV
        print('\n10-fold CV')
        predictions = sklearn.model_selection.cross_val_predict(model, X, Y, cv=10)
        accuracy = sklearn.metrics.accuracy_score(Y, predictions)
        print('Accuracy: {0:.2f}'.format(accuracy * 100.0))
        print('Classification Report:')
        print(sklearn.metrics.classification_report(Y, predictions))
        print('Confusion Matrix:')
        print(sklearn.metrics.confusion_matrix(Y, predictions))

# Predict and submit
def predict_and_submit():
    # Load test data set (includes header values)
    ds = pandas.read_csv('C:\\DATA\\Python-data\\titanic\\test.csv')
    # Preprocess data
    ids, ds = common.preprocess_data(ds)
    # Slice data set in values (2D-array), test set does not have target values
    X = ds.values[:,0:8] # Data
    # Load the best model
    model = joblib.load('C:\\DATA\\Python-data\\titanic\\models\\RandomForest.jbl')
    # Make predictions
    predictions = model.predict(X)
    # Save predictions to a csv file
    file = open('C:\\DATA\\Python-data\\titanic\\submission.csv', 'w', newline='')
    writer = csv.writer(file, delimiter=',')
    writer.writerow(('PassengerId', 'Survived'))
    for i in range(len(predictions)):
        writer.writerow((ids[i], predictions[i].astype(int)))
    file.close()
    # Print success
    print('Successfully created submission.csv!')

# Plot models
def plot_models():
    # Load models
    decision_tree_model = joblib.load('C:\\DATA\\Python-data\\titanic\\models\\DecisionTree.jbl')
    random_forest_model = joblib.load('C:\\DATA\\Python-data\\titanic\\models\\RandomForest.jbl')
    xgboost_model = joblib.load('C:\\DATA\\Python-data\\titanic\\models\\XGBoost.jbl')
    # Names
    feature_names = ['Pclass', 'Gender', 'Age', 'SibSp', 'Parch', 'Fare', 'Cabin', 'Embarked']
    class_names = ['Died', 'Survived']
    # Save decision tree model to an image
    source = graphviz.Source(sklearn.tree.export_graphviz(decision_tree_model, out_file=None, feature_names=feature_names, class_names=class_names, filled=True))
    source.render('C:\\DATA\\Python-data\\titanic\\plots\\decision-tree', format='png', view=False)
    # Save one tree (number 9) from the random forest model to an image
    source = graphviz.Source(sklearn.tree.export_graphviz(random_forest_model.estimators_[8], out_file=None, filled=True))
    source.render('C:\\DATA\\Python-data\\titanic\\plots\\random-forest', format='png', view=False)
    # Save the first tree of the xgboost model to an image
    xgboost_model.get_booster().feature_names = feature_names
    xgboost.plot_tree(xgboost_model, num_trees=0)
    figure = plt.gcf()
    figure.set_size_inches(100, 50)
    plt.savefig('C:\\DATA\\Python-data\\titanic\\plots\\xgboost.png')

# The main entry point for this module
def main():
    # Train and evaluate
    #train_and_evaluate()
    # Predict and submit
    #predict_and_submit()
    # Plot models
    plot_models()

# Tell python to run main method
if __name__ == "__main__": main()

Training and evaluation

A for loop is used to train and evaluate the models, and each model is saved to a file. The output from the training and evaluation process is shown below.
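
If you prefer per-fold accuracies over the pooled predictions from cross_val_predict, cross_val_score can be used inside the same loop. A sketch, assuming model, X and Y as defined in the module above:

# Optional: per-fold accuracy instead of pooled predictions (a sketch)
scores = sklearn.model_selection.cross_val_score(model, X, Y, cv=10, scoring='accuracy')
print('Folds: ' + ', '.join('{0:.2f}'.format(s * 100.0) for s in scores))
print('Mean: {0:.2f} (+/- {1:.2f})'.format(scores.mean() * 100.0, scores.std() * 100.0))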

--- DecisionTree ---

Training data
Accuracy: 94.84
Classification Report:
              precision    recall  f1-score   support

         0.0       0.94      0.98      0.96       549
         1.0       0.96      0.90      0.93       342

    accuracy                           0.95       891
   macro avg       0.95      0.94      0.94       891
weighted avg       0.95      0.95      0.95       891

Confusion Matrix:
[[536  13]
 [ 33 309]]

10-fold CV
Accuracy: 78.90
Classification Report:
              precision    recall  f1-score   support

         0.0       0.82      0.85      0.83       549
         1.0       0.74      0.70      0.72       342

    accuracy                           0.79       891
   macro avg       0.78      0.77      0.77       891
weighted avg       0.79      0.79      0.79       891

Confusion Matrix:
[[464  85]
 [103 239]]

--- RandomForest ---

Training data
Accuracy: 94.84
Classification Report:
              precision    recall  f1-score   support

         0.0       0.94      0.98      0.96       549
         1.0       0.97      0.90      0.93       342

    accuracy                           0.95       891
   macro avg       0.95      0.94      0.94       891
weighted avg       0.95      0.95      0.95       891

Confusion Matrix:
[[538  11]
 [ 35 307]]

10-fold CV
Accuracy: 82.27
Classification Report:
              precision    recall  f1-score   support

         0.0       0.84      0.88      0.86       549
         1.0       0.79      0.73      0.76       342

    accuracy                           0.82       891
   macro avg       0.82      0.81      0.81       891
weighted avg       0.82      0.82      0.82       891

Confusion Matrix:
[[483  66]
 [ 92 250]]

--- XGBoost ---

Training data
Accuracy: 98.20
Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.99      0.99       549
         1.0       0.99      0.97      0.98       342

    accuracy                           0.98       891
   macro avg       0.98      0.98      0.98       891
weighted avg       0.98      0.98      0.98       891

Confusion Matrix:
[[544   5]
 [ 11 331]]

10-fold CV
Accuracy: 81.37
Classification Report:
              precision    recall  f1-score   support

         0.0       0.84      0.87      0.85       549
         1.0       0.77      0.73      0.75       342

    accuracy                           0.81       891
   macro avg       0.80      0.80      0.80       891
weighted avg       0.81      0.81      0.81       891

Confusion Matrix:
[[475  74]
 [ 92 250]]

Submission

I created a submission file by using the RandomForest model (the model loaded in predict_and_submit above) and uploaded the file to kaggle. My accuracy score was 0.73250, which is not much better than the baseline performance.

Plot trees

You will need to unpack or install Graphviz in order to plot models in Python. You also need to add the bin folder (C:\Program Files\Graphviz\bin) to the Path environment variable. I load all the models and save the plots as PNG files; you can save them as PDFs or in other formats.
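
If you prefer not to edit the environment variables by hand, the bin folder can also be appended to PATH for the current process. A sketch, assuming the default Windows install location:

# Make the Graphviz executables visible to this process only (Windows path assumed)
import os
os.environ['PATH'] += os.pathsep + 'C:\\Program Files\\Graphviz\\bin'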

[Figure: plotted decision tree]