A different way to deploy a Python model over Spark

Separate the prediction method from the rest of the Python class and then implement in Scala

Schaun Wheeler, May 6

Instead of using the whole thing, just take the pieces you need.

A while ago, I wrote a post about how to deploy a Python model over Spark.

The approach was roughly as follows:

Train the model in Python on a sample of the total data.

Collect the test data into groups of arbitrary size — something around 500,000 records seemed to work well for me.

Broadcast the trained model and then use a User Defined Function to call the predict method of the model on each group of records as a whole, rather than on each individual record (which is what Spark will do if you call the UDF on an ungrouped DataFrame).

The method takes advantage of the numpy-enabled optimization that backs scikit-learn and reduces the number of times you have to go through the expensive process of serializing and de-serializing the model object.

I’ve recently adopted a different way of deploying a Python model over Spark that doesn’t require the grouping and then exploding of lots of rows of data.

I still train the model in Python on a sample of the total data, but then I store everything I need to call the predict method in a JSON file, and that file can be read by a Scala function that implements the predict method.

For example, take a scikit-learn RandomForestRegressor.

The predict method is the average of the results of the predict methods of a bunch of individual decision trees, but the implementation of the method itself doesn’t use a tree structure at all.

It encodes everything into a series of lists.

The following code creates a pure Python function that will exactly reproduce the predictions of a trained RandomForestRegressor:

Look at the tree_template string.

We start out with five lists — each as long as there are nodes in the tree.

The features list has as many unique values as there are features upon which the model was trained.

We start at the first value of that list — index zero.

If the value at index zero is, say, 3, then we pull out the value of the feature at index 3 (the fourth feature upon which the model was trained, since the count starts at zero).

We then pull out the index zero value from the thresholds list.

If the value of the selected feature is less than or equal to the value of the corresponding threshold, then we look at the zero-index value of the children_left list.

Otherwise, we look at the zero-index value of the children_right list.

Either way, that value is our new index, and then we start the process over again.

We keep doing this until the value from the children list is the placeholder for “you’ve reached the end of the tree”.

In scikit-learn, that placeholder is -1 in the children_left and children_right lists; the features and thresholds lists hold -2 at those same leaf positions.

At that point, whatever index you’re currently on, you look up the value at that index from the values list.

That’s your prediction.
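The walk described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the original tree_template: the five lists below are hand-made for a hypothetical five-node tree rather than taken from a trained model, and the function name predict_tree is made up for this example. It assumes scikit-learn's convention of marking a leaf with -1 in the children lists and -2 in the features and thresholds lists.

```python
# Hypothetical five-node tree written out by hand (not from a trained model).
# Node 0 splits on feature 1, node 2 splits on feature 0; nodes 1, 3, 4 are leaves.
LEAF = -1        # leaf marker in the children lists
UNDEFINED = -2   # marker in the features/thresholds lists at leaf nodes

features = [1, UNDEFINED, 0, UNDEFINED, UNDEFINED]
thresholds = [0.5, -2.0, 2.0, -2.0, -2.0]
children_left = [1, LEAF, 3, LEAF, LEAF]
children_right = [2, LEAF, 4, LEAF, LEAF]
values = [12.0, 10.0, 15.0, 13.0, 20.0]


def predict_tree(row):
    """Walk the lists from node 0 until the current node is a leaf."""
    node = 0
    while children_left[node] != LEAF:
        # Compare the selected feature of this record to the node's threshold.
        if row[features[node]] <= thresholds[node]:
            node = children_left[node]
        else:
            node = children_right[node]
    # The prediction is the value stored at the leaf we stopped on.
    return values[node]
```

For instance, predict_tree([1.0, 0.9]) goes right at node 0 (0.9 is above 0.5), then left at node 2 (1.0 is at most 2.0), and returns the value stored at node 3.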

So, yes, it’s a lot of data: it is common for a decision tree regressor with a couple dozen features to have around 100,000 nodes.

But the logic of navigating the tree to get the prediction is incredibly simple.

So all you have to do is create a function that contains the lists for each tree, along with the logic for jumping from index to index.

Then take a set of features, get the predictions from each individual tree, and average.

That’s your random forest.
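One way to picture "a function that contains the lists for each tree" is with closures. The sketch below is an assumption about how this could be arranged, not the code from the original post: the dictionary keys, the two hand-made three-node trees, and the names make_tree_function and make_forest_function are all hypothetical. It uses -1 as the leaf marker in the children lists, following scikit-learn.

```python
def make_tree_function(tree):
    """Return a function that predicts from one tree's lists."""
    features = tree["features"]
    thresholds = tree["thresholds"]
    left = tree["children_left"]
    right = tree["children_right"]
    values = tree["values"]

    def predict(row):
        node = 0
        while left[node] != -1:  # -1 marks a leaf in the children lists
            if row[features[node]] <= thresholds[node]:
                node = left[node]
            else:
                node = right[node]
        return values[node]

    return predict


def make_forest_function(trees):
    """Return a function that averages the predictions of every tree."""
    tree_functions = [make_tree_function(tree) for tree in trees]

    def predict(row):
        return sum(f(row) for f in tree_functions) / len(tree_functions)

    return predict


# Two hand-made trees with one split each (illustrative numbers only).
stump_a = {"features": [0, -2, -2], "thresholds": [1.0, -2.0, -2.0],
           "children_left": [1, -1, -1], "children_right": [2, -1, -1],
           "values": [0.0, 10.0, 20.0]}
stump_b = {"features": [1, -2, -2], "thresholds": [5.0, -2.0, -2.0],
           "children_left": [1, -1, -1], "children_right": [2, -1, -1],
           "values": [0.0, 30.0, 50.0]}

forest = make_forest_function([stump_a, stump_b])
```

Calling forest([0.5, 9.0]) takes 10.0 from the first tree and 50.0 from the second and returns their mean.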

It’s easy to dump all of this information to JSON.

The following function does exactly that.

All it requires is a trained RandomForestRegressor object, the list of features in the order they were used in the training, and a filepath to which to dump the JSON file.
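A function with that signature might look roughly like the sketch below. It is not the original code: the function name forest_to_json and the JSON key names are assumptions made for illustration. It relies only on attributes that a fitted scikit-learn RandomForestRegressor exposes (estimators_, and on each estimator tree_.feature, tree_.threshold, tree_.children_left, tree_.children_right and tree_.value), and it assumes a single-output regressor, so each node's value is read from position [0][0].

```python
import json


def forest_to_json(model, feature_names, filepath):
    """Write the lists needed for prediction to a JSON file.

    `model` is assumed to be a fitted single-output RandomForestRegressor;
    the key names used in the output are illustrative, not a standard.
    """
    trees = []
    for estimator in model.estimators_:
        tree = estimator.tree_
        trees.append({
            # Cast to built-in types so the json module can serialize them.
            "features": [int(i) for i in tree.feature],
            "thresholds": [float(t) for t in tree.threshold],
            "children_left": [int(i) for i in tree.children_left],
            "children_right": [int(i) for i in tree.children_right],
            # tree_.value is nested per node; take the single regression value.
            "values": [float(v[0][0]) for v in tree.value],
        })
    payload = {"feature_names": list(feature_names), "trees": trees}
    with open(filepath, "w") as handle:
        json.dump(payload, handle)
    return payload
```

With a trained model this would be called as forest_to_json(rfr, feature_list, "forest.json"), where rfr, feature_list and the file name are placeholders. Each entry under "trees" then holds five equal-length lists, one element per node.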

The output for a single tree will look something like this:

This next snippet of code comes from my colleague Sam Hendley, who has forgotten more about Scala than I will ever know.

It reads in the tree information from the JSON file and implements the prediction logic for each tree and then the averaging for the forest as a whole.

Implementing the predictions in Scala avoids serializing and de-serializing the Python representations of the function; everything can be done directly on the JVM.

And a Random Forest is one of the most complicated use cases here.

In the case of any sort of model that produces coefficients, transforming the predict function to JSON and Scala is even easier.
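To make that concrete, here is a minimal sketch for a plain linear model, where the prediction is the intercept plus the dot product of the coefficients and the features. The function names and JSON keys are hypothetical. With a fitted scikit-learn LinearRegression trained on a single target, the inputs would presumably be model.coef_ and model.intercept_.

```python
import json


def linear_model_to_json(coefficients, intercept, feature_names, filepath):
    """Dump everything a linear predict needs to a JSON file (illustrative keys)."""
    payload = {
        "feature_names": list(feature_names),
        "coefficients": [float(c) for c in coefficients],
        "intercept": float(intercept),
    }
    with open(filepath, "w") as handle:
        json.dump(payload, handle)
    return payload


def predict_linear(payload, row):
    """Intercept plus the sum of coefficient * feature value."""
    weighted = sum(c * x for c, x in zip(payload["coefficients"], row))
    return payload["intercept"] + weighted
```

The same two pieces of logic, reading the JSON and computing a weighted sum, are all that would need to be rewritten on the Scala side.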

