Building Tree Models with TensorFlow Decision Forests

Building Tree Models with TensorFlow Decision Forests
Photo by Vincent Yuan @USA / Unsplash

1 Introduction

Both deep learning and traditional machine learning have excellent frameworks in their respective fields. For example, when building neural networks, TensorFlow and PyTorch are the common choices. The traditional tree model still performs very well in dealing with tabular data in real work. However, for a long time, deep learning frameworks did not have APIs to build tree models, until the emergence of TensorFlow Decision Forests.

TensorFlow Decision Forests provides a series of APIs to build tree-based models, such as Classification and Regression Trees (CART), Random Forest, Gradient Boosted Trees, etc. With TensorFlow Decision Forests, we can construct tree models using a paradigm similar to building neural networks. This article will explore it!

2 Obtain Data

As usual, import dependencies and download data. We use a tabular dataset to predict the species of penguins.

import tensorflow_decision_forests as tfdf

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math

# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv

# Load a dataset into a Pandas Dataframe.
dataset_df = pd.read_csv("/tmp/penguins.csv")

# Display the first 3 examples.
dataset_df.head(3)
species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex year
Adelie Torgersen 39.1 18.7 181.0 3750.0 male 2007
Adelie Torgersen 39.5 17.4 186.0 3800.0 female 2007
Adelie Torgersen 40.3 18.0 195.0 3250.0 female 2007

Specify the label field and convert the label category to integer data.

label = "species"

classes = dataset_df[label].unique().tolist()
print(f"Label classes: {classes}")

dataset_df[label] = dataset_df[label].map(classes.index)

3 Split and Process Data

Split the data into training set and test set:

def split_dataset(dataset, test_ratio=0.30):
  """Splits a panda dataframe in two."""
  test_indices = np.random.rand(len(dataset)) < test_ratio
  return dataset[~test_indices], dataset[test_indices]


train_ds_pd, test_ds_pd = split_dataset(dataset_df)
print("{} examples in training, {} examples for testing.".format(
    len(train_ds_pd), len(test_ds_pd)))

Then convert the Pandas DataFrame to the TensorFlow Dataset, which makes it easier to simplify the subsequent program and improve efficiency.

train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=label)
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=label)

Then data can be injected into the model for training. Unlike traditional machine learning frameworks, TensorFlow Decision Forests has the following advantages in implementing tree models:

  • Automatically handles numeric and categorical variables, eliminating the need to encode categorical variables and normalize numeric variables. The algorithm can also handle missing data well!
  • Hyperparameters are basically similar to those in other frameworks and default parameters give decent results in most cases.
  • Before training, there is no need to compile the model, and there is no need for a validation set during training. The validation set is only used to display performance metrics.

Note that this does not mean that using TensorFlow Decision Forests can eliminate all feature engineering, but it does save a lot of time.

4 Modeling

Construction of a random forest:

# Build a random forest
model = tfdf.keras.RandomForestModel()

# Train the model
model.fit(x=train_ds)

# Evaluate the model
model.compile(metrics=["accuracy"])
evaluation = model.evaluate(test_ds, return_dict=True)
print()

for name, value in evaluation.items():
  print(f"{name}: {value:.4f}")

Output:

1/1 [==============================] - 1s 706ms/step - loss: 0.0000e+00 - accuracy: 0.9608

loss: 0.0000
accuracy: 0.9608

5 Visualize Tree Model

TensorFlow Decision Forests provides a native API for visualizing trees. Here we select a tree from the forest for display.

with open("plot.html", "w") as f:
  f.write(tfdf.model_plotter.plot_model(model, tree_idx=0, max_depth=3))

from IPython.display import IFrame
IFrame(src='./plot.html', width=700, height=600)
Visualized Tree

There's also a lot of useful information in model.summary(), such as input features, feature importance, node information, etc. (limited space, not elaborating one by one). During the training process, the accuracy and loss can also be visualized:

import matplotlib.pyplot as plt

logs = model.make_inspector().training_logs()

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot([log.num_trees for log in logs], [log.evaluation.accuracy for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("Accuracy (out-of-bag)")

plt.subplot(1, 2, 2)
plt.plot([log.num_trees for log in logs], [log.evaluation.loss for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("Logloss (out-of-bag)")

plt.show()
Training Logs

6 Conclusion

TensorFlow Decision Forests strengthens the TensorFlow ecosystem and provides new ideas for data scientists to model tabular data. It is still in its early stages (v0.2.3), but it already has many available high-quality APIs. More functionalities can be found in the documentations. Hope this sharing will help you in your work. Welcome to leave a comment for discussion!