Visualizing Decision Trees
When we use a decision tree, we usually judge it by metrics such as accuracy. But if we want to see how the tree was built and how it makes its decisions, we need to plot the tree that was fitted to the dataset. This blog looks at how to plot a decision tree and at the several ways available in Python to do it.
A Brief Introduction to Decision Trees
A decision tree is a common supervised learning method in machine learning that is used for both classification and regression problems. It is convenient because it doesn’t require feature scaling and is easy to understand. To interpret a fitted tree, we can plot it, see which features and thresholds the model splits on, and adjust the model accordingly.
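The point about feature scaling can be checked with a minimal sketch. It assumes scikit-learn is installed and uses StandardScaler only as an illustrative rescaling; the variable names are hypothetical and not part of the code later in this blog. It fits one tree on the raw iris features and one on standardized features, then compares their predictions on the training data.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Fit one tree on the raw features and one on standardized features.
# random_state is fixed so that ties between equally good splits
# are broken the same way on every run.
raw_tree = DecisionTreeClassifier(random_state=0).fit(X, y)
X_scaled = StandardScaler().fit_transform(X)
scaled_tree = DecisionTreeClassifier(random_state=0).fit(X_scaled, y)

# Compare the predictions of both trees on the data they were fitted on.
same_predictions = np.array_equal(
    raw_tree.predict(X), scaled_tree.predict(X_scaled)
)
print("Predictions identical on the training data:", same_predictions)
```

A tree only compares each feature against a threshold, so a rescaling that keeps the order of the values moves the thresholds without changing which samples fall on which side of a split.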
Various Ways to Visualize the Decision Tree
Python libraries offer several methods to visualize decision trees. Some of them are as follows:
- Visualizing Decision Trees using sklearn's export_text method (plain-text output)
- Visualizing Decision Trees using sklearn's plot_tree method (drawn with Matplotlib)
- Visualizing Decision Trees using Graphviz
- Visualizing Decision Trees using dtreeviz
Let’s look at each of these methods with the help of code.
Code in Python
We will use sklearn's built-in iris dataset, which is loaded with load_iris from sklearn.datasets.
import warnings

import graphviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# NOTE: this import is the pre-2.0 dtreeviz API used by the original code
from dtreeviz.trees import dtreeviz
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

warnings.filterwarnings("ignore")


### This class helps visualise decision trees. It takes the dataset
### dataframe (features first, target in the last column), splits it,
### fits a decision tree model, and visualises the fitted tree.
class DecisionTreeVisualiser:
    # initialiser that stores the dataframe of the dataset along with
    # the parameters required by a decision tree model
    def __init__(self, dataset_df, min_samples_leaf=1, min_samples_split=2,
                 criterion="gini", splitter="best", max_depth=None,
                 max_features=None, min_weight_fraction_leaf=0.0,
                 max_leaf_nodes=None):
        self.criterion = criterion
        self.max_depth = max_depth
        self.max_features = max_features
        self.splitter = splitter
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_leaf_nodes = max_leaf_nodes
        self.dataset_df = dataset_df
        self.X_train, self.Y_train = None, None
        self.X_test, self.Y_test = None, None

    # returns the first num_entries rows of the dataframe
    # (default 5) so that the dataset can be inspected
    def visualiseDataset(self, num_entries=5):
        return self.dataset_df.head(num_entries)

    # splits the dataset into x% training data and z% testing data
    # such that x + z <= 100; by default x = 80 and z = 20.
    # shuffle=True shuffles the rows first; set it to False
    # if you don't want to shuffle the data.
    def train_test_split(self, train_size=80, test_size=20, shuffle=True):
        if shuffle:
            self.dataset_df = self.dataset_df.sample(frac=1)
        dataset_val = self.dataset_df.values
        n_rows = len(dataset_val)
        split = int((train_size / 100) * n_rows)
        end = split + int((test_size / 100) * n_rows)
        # the last column is the target; it is cast back to int because
        # .values turns a mixed float/int dataframe into a float array
        # (this assumes integer class labels, as in the iris dataset)
        self.X_train = dataset_val[:split, :-1]
        self.Y_train = dataset_val[:split, -1].astype(int)
        self.X_test = dataset_val[split:end, :-1]
        self.Y_test = dataset_val[split:end, -1].astype(int)

    # fits the model, prints the train and test scores, and displays the
    # decision tree with the chosen plotting method. The function takes
    # the plotting method as well as the feature names and class names
    # of the dataset as input.
    def fit_and_predict_and_PlotDT(self, plot_method="export_text",
                                   feature_names=None, cls_names=None):
        # decision tree classifier object initialised
        dt = DecisionTreeClassifier(
            criterion=self.criterion, max_depth=self.max_depth,
            max_features=self.max_features, splitter=self.splitter,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            min_weight_fraction_leaf=self.min_weight_fraction_leaf,
            max_leaf_nodes=self.max_leaf_nodes)
        # fitting the training dataset
        dt.fit(self.X_train, self.Y_train)
        # print the scores on both splits
        print("The score on the training set is: ", dt.score(self.X_train, self.Y_train))
        print("The score on the testing set is: ", dt.score(self.X_test, self.Y_test))
        # the branch that matches plot_method is executed
        if plot_method == "export_text":
            text_representation = tree.export_text(dt, feature_names=feature_names)
            print(text_representation)
        elif plot_method == "plot_tree":
            fig = plt.figure(figsize=(25, 20))
            tree.plot_tree(dt, feature_names=feature_names,
                           class_names=cls_names, filled=True)
            fig.savefig("decision_tree.png")
        elif plot_method == "graphviz":
            dot_data = tree.export_graphviz(dt, out_file=None,
                                            feature_names=feature_names,
                                            class_names=cls_names, filled=True)
            graph = graphviz.Source(dot_data, format="png")
            # writes decision_tree_graphviz.png next to the DOT source
            graph.render("decision_tree_graphviz")
        elif plot_method == "dtreeviz":
            # stack the rows of both splits (np.c_ would try to join columns)
            viz = dtreeviz(dt,
                           np.concatenate([self.X_train, self.X_test]),
                           np.concatenate([self.Y_train, self.Y_test]),
                           target_name="target",
                           feature_names=feature_names,
                           class_names=list(cls_names))
            viz.save("decision_tree_dtreeviz.svg")


# loading the iris dataset
iris_dataset = load_iris()
# store the data in the form of a dataframe
df = pd.DataFrame(iris_dataset.data, columns=iris_dataset.feature_names)
# add the target column to the dataframe
df['target'] = iris_dataset.target
# create the decision tree visualiser object
dtv = DecisionTreeVisualiser(df)
# call the train_test_split function
dtv.train_test_split()
# fit, predict, and plot the decision tree
dtv.fit_and_predict_and_PlotDT(plot_method="graphviz",
                               feature_names=iris_dataset.feature_names,
                               cls_names=list(iris_dataset.target_names))
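The class above splits the data by hand. As an alternative, here is a minimal sketch of the same 80/20 split using sklearn's own train_test_split helper. This is a swapped-in technique and not the method used above; stratify and the random_state value are assumptions added here so that each class is equally represented and the run is repeatable.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# 80% training data, 20% testing data; stratify=y keeps the class
# proportions the same in both splits.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
train_score = clf.score(X_train, y_train)
test_score = clf.score(X_test, y_test)

print("Training rows:", X_train.shape[0], "| testing rows:", X_test.shape[0])
print("Samples per class in the test set:", np.bincount(y_test))
print("The score on the training set is:", train_score)
print("The score on the testing set is:", test_score)
```

Stratifying matters on small datasets such as iris, where a plain shuffle can leave one class under-represented in the test set.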
Output Using graphviz method
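To see what the graphviz method works from, the sketch below prints the DOT source that export_graphviz returns. It is a minimal example under assumptions: max_depth=2 and random_state=0 are chosen here only to keep the output short and repeatable, and no rendering is done, so the graphviz package is not needed.

```python
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
# A shallow tree keeps the DOT source short enough to read.
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# With out_file=None, export_graphviz returns the DOT source as a string.
dot_source = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
)

# Show the beginning of the graph description.
print(dot_source[:400])
```

The same string is what graphviz.Source receives in the class above. If it is saved to a file such as tree.dot, the Graphviz command-line tool can also render it with dot -Tpng tree.dot -o tree.png.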
Output Using plot_tree method
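The plot_tree method can also be tried on its own. The following is a minimal sketch, assuming Matplotlib is installed; the Agg backend, max_depth=2, and the figure size are choices made here for illustration so that the script runs without a display and the figure stays small.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, no display needed

import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(8, 6))
# plot_tree draws the tree on the given axes and returns the
# annotation artists that make up the node boxes.
annotations = tree.plot_tree(
    clf,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
    ax=ax,
)
print("Annotation boxes drawn:", len(annotations))
print("Nodes in the fitted tree:", clf.tree_.node_count)
plt.close(fig)
```

Calling fig.savefig with a file name before plt.close writes the figure to disk, as the class above does.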
Output Using export_text method
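Since export_text produces plain text, it is the easiest method to try without any plotting library. The sketch below is a minimal example; max_depth=2 and random_state=0 are assumptions made here to keep the printed rules short and repeatable.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# export_text returns the decision rules as an indented string,
# one line per split condition or leaf.
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

Each level of indentation in the printed rules is one level of depth in the tree, and lines that start with "class:" are the leaves.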
Output Using dtreeviz method
Frequently Asked Questions
Q1. What are the advantages of decision trees?
Ans. The decision tree is a commonly used model because it is easy to understand, quick to train, and does not require feature scaling.
Q2. Why is dtreeviz used?
Ans. dtreeviz is a Python library for visualizing decision trees. It gives a more explanatory view than a plain node diagram because it includes plots of the data at each node.
Q3. How to tune the decision trees to get better results?
Ans. To get better results, one can tune hyperparameters such as max_depth and the number of leaf nodes (max_leaf_nodes in sklearn). Another option is to change the split criterion, for example from gini to entropy.
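Instead of trying these values by hand, the search can be automated. The sketch below uses sklearn's GridSearchCV, which is a swapped-in technique and not something the code above does; the candidate values in param_grid and the 5-fold cross-validation are assumptions chosen only for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Candidate values for the parameters mentioned in the answer above.
param_grid = {
    "max_depth": [2, 3, 4, None],
    "max_leaf_nodes": [None, 4, 8],
    "criterion": ["gini", "entropy"],
}

# Every combination is scored with 5-fold cross-validation.
search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=5)
search.fit(X, y)

print("Best parameters:", search.best_params_)
print("Best cross-validated accuracy:", search.best_score_)
```

After fitting, search.best_estimator_ holds a tree refitted on all the data with the best parameters, and it can be passed to any of the plotting methods above.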
This article gave a brief explanation of how we can visualize decision trees using various methods available in Python libraries. We looked at the graphviz method, the export_text method, the plot_tree method, and the dtreeviz method. To dive deeper into machine learning, check out our industry-level courses on Coding Ninjas.