There are a couple of problems:
1. Yellowbrick relies on some Scikit-Learn model semantics that KerasClassifier doesn’t provide.
2. KerasClassifier doesn’t provide alternative training methods, like fit_generator or fit_dataframe, which are hard to give up.
Luckily, we can fix both by writing our own subclass, KerasBatchClassifier.
KerasBatchClassifier fixes issue #1 by setting the _estimator_type and classes_ attributes and adding a diamond dependency on BaseEstimator, and it fixes issue #2 by making fit use fit_generator internally.
The inability to use fit_generator with KerasClassifier is a well-known pain point.
This code is an adaptation (and refinement) of existing solutions from other users, particularly this one.
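As a rough, dependency-free illustration of the two fixes described above (not the author’s KerasBatchClassifier and not Keras’ own API), the sketch below swaps in hypothetical stand-ins. StandInKerasClassifier plays the role of KerasClassifier, StandInModel plays the role of a compiled Keras model with a fit_generator method, and looks_like_classifier mirrors the _estimator_type check that older Scikit-Learn releases used in is_classifier. The BaseEstimator mixin is omitted so the code runs without Scikit-Learn, and the sketch assumes the generator yields (x_batch, y_batch) pairs with plain, hashable labels rather than one-hot arrays.

```python
class StandInModel:
    """Hypothetical stand-in for a compiled Keras model (no Keras needed)."""

    def __init__(self):
        self.batches_seen = 0

    def fit_generator(self, generator, steps_per_epoch):
        # A real model would run gradient updates; this one only counts batches.
        for _ in range(steps_per_epoch):
            next(generator)
            self.batches_seen += 1
        return self


class StandInKerasClassifier:
    """Hypothetical stand-in for KerasClassifier: it stores build_fn but has
    no _estimator_type or classes_ attribute."""

    def __init__(self, build_fn, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params


class KerasBatchClassifier(StandInKerasClassifier):
    # Fix #1: declare Scikit-Learn classifier semantics.
    _estimator_type = "classifier"

    def fit(self, generator, steps_per_epoch):
        """Fix #2: train from a generator of (x_batch, y_batch) pairs."""
        seen_labels = set()

        def recording(batches):
            # Pass batches through unchanged, noting which labels appear.
            for x_batch, y_batch in batches:
                seen_labels.update(y_batch)
                yield x_batch, y_batch

        self.model = self.build_fn()
        self.model.fit_generator(recording(generator), steps_per_epoch=steps_per_epoch)
        # Fix #1, continued: expose the labels seen during training as classes_.
        self.classes_ = sorted(seen_labels)
        return self


def looks_like_classifier(estimator):
    """Hypothetical helper mirroring the attribute check that older
    Scikit-Learn releases performed in is_classifier()."""
    return getattr(estimator, "_estimator_type", None) == "classifier"
```

Wrapping the generator to record labels keeps fit compatible with anything that consumes an iterator, at the cost of only knowing about labels that actually appear in the batches consumed during training.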
With this shim in place, we can move on to the fun part: applying yellowbrick visualizations to our neural network models!
Evaluating classification
We’ll start off by checking out yellowbrick’s classification evaluation plots.
For the purposes of this demo, I trained a very basic CNN on a subset of fruit images from the Google Open Images dataset.
You can get that dataset here, and you can follow along with the code here.
The simplest of the classification evaluation plots is ClassPredictionError, which provides a stacked bar chart of per-class model predictions.
With this chart in hand, we can quickly assess which classes are popular classification targets and which ones are not, and what the most common misclassifications are within a single class.
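ClassPredictionError, and the ConfusionMatrix and ClassificationReport discussed next, all summarize the same table of (actual, predicted) label counts. As a rough, standard-library-only sketch (not Yellowbrick’s implementation), the code below builds that table from made-up fruit labels, reshapes it into stacked-bar data, finds the most common misclassification, and derives per-class precision, recall, F1 score, and support. All function names are hypothetical.

```python
from collections import Counter


def prediction_table(y_true, y_pred):
    """Count (actual, predicted) pairs: the raw data behind all three plots."""
    return Counter(zip(y_true, y_pred))


def stacked_bars(table):
    """One bar per actual class, with one segment per predicted class."""
    bars = {}
    for (actual, predicted), n in table.items():
        bars.setdefault(actual, {})[predicted] = n
    return bars


def most_common_mistake(table):
    """The off-diagonal (actual, predicted) pair with the highest count."""
    mistakes = {pair: n for pair, n in table.items() if pair[0] != pair[1]}
    return max(mistakes, key=mistakes.get) if mistakes else None


def per_class_metrics(table):
    """Precision, recall, F1 score and support for every class."""
    labels = sorted({label for pair in table for label in pair})
    report = {}
    for label in labels:
        tp = table[(label, label)]
        # Other classes that were predicted as `label`.
        fp = sum(n for (a, p), n in table.items() if p == label and a != label)
        # Records of `label` that were predicted as something else.
        fn = sum(n for (a, p), n in table.items() if a == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        # Support is the number of records whose true class is `label`.
        report[label] = {"precision": precision, "recall": recall,
                         "f1": f1, "support": tp + fn}
    return report


# Made-up labels, for illustration only.
y_true = ["apple", "apple", "apple", "pear", "pear", "pear", "banana"]
y_pred = ["apple", "pear", "apple", "pear", "apple", "apple", "banana"]
table = prediction_table(y_true, y_pred)
print(most_common_mistake(table))  # ('pear', 'apple')
```

Each class’s recall (its diagonal count divided by its row total) is a convenient way to rank which classes are most and least accurately predicted.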
However, I personally much prefer the ConfusionMatrix.
This visualization lets us quickly zero in on important properties of the model:
- Which classes are most accurately predicted
- Which classes are least accurately predicted
- Which misclassifications are most common
Finally, there is ClassificationReport.
This provides four essential classification model metrics (precision, recall, F1 score, and support) in an easily digestible visual format.
Evaluating regression
Yellowbrick also packs tools for evaluating regression models.
For this demo, I trained a simple feedforward neural network that attempts to predict the price per day of various homes in the Boston Airbnb dataset on Kaggle.
You can see the code for yourself here.
The basic regression analysis plot is PredictionError, which charts predicted values from the model against ground-truth values from the dataset.
This chart is useful for identifying patterns in the data (and seeing how well the model adapts to them).
For example, by examining the y values in this plot, we can see that users have a strong preference for rentals at multiples of 100.
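To make concrete what a predicted-versus-actual chart summarizes, here is a standard-library-only sketch (not Yellowbrick’s code) that computes two numbers one might read off it: the coefficient of determination R² and an ordinary least-squares line fitted to the (actual, predicted) points. The prices are made up for illustration, and the function names are hypothetical.

```python
def r_squared(actual, predicted):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(actual) / len(actual)
    ss_res = sum((a - p) ** 2 for a, p in zip(actual, predicted))
    ss_tot = sum((a - mean) ** 2 for a in actual)
    return 1 - ss_res / ss_tot


def best_fit_line(actual, predicted):
    """Ordinary least-squares fit of predicted = slope * actual + intercept."""
    n = len(actual)
    mean_a = sum(actual) / n
    mean_p = sum(predicted) / n
    s_aa = sum((a - mean_a) ** 2 for a in actual)
    s_ap = sum((a - mean_a) * (p - mean_p) for a, p in zip(actual, predicted))
    slope = s_ap / s_aa
    return slope, mean_p - slope * mean_a


# Made-up nightly prices, for illustration only.
actual = [100, 200, 300, 400]
predicted = [150, 200, 250, 300]
print(r_squared(actual, predicted))
print(best_fit_line(actual, predicted))  # (0.5, 100.0)
```

A fitted slope below 1, as in this toy example, means predictions are compressed toward the mean: cheap listings are over-predicted and expensive ones under-predicted.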
Then there is ResidualsPlot.
A model residual is the difference between the actual and predicted values for a single record.
By putting all of our residuals on a single plot, we can assess whether or not our model performs better on some parts of the data than on others.
In this case we see that our residuals are larger in magnitude when the predicted value is larger as well — a sign that the model is performing better on smaller values in the dataset than on larger ones.
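The pattern described above (residual magnitude growing with the predicted value) can also be checked numerically. The standard-library-only sketch below takes residuals as actual minus predicted (an assumed sign convention; only the magnitude matters for this check), orders records by predicted value, and compares the mean absolute residual of the lower and upper halves. The function names and prices are hypothetical, for illustration only.

```python
def residuals(actual, predicted):
    """Signed residuals, taken here as actual minus predicted."""
    return [a - p for a, p in zip(actual, predicted)]


def residual_spread_by_half(actual, predicted):
    """Mean absolute residual for the lower and upper halves of the records,
    ordered by predicted value. Needs at least two records."""
    ordered = sorted(zip(predicted, residuals(actual, predicted)))
    half = len(ordered) // 2

    def mean_abs(group):
        return sum(abs(r) for _, r in group) / len(group)

    return mean_abs(ordered[:half]), mean_abs(ordered[half:])


# Made-up prices, for illustration only.
actual = [55, 58, 350, 320]
predicted = [50, 60, 300, 400]
low, high = residual_spread_by_half(actual, predicted)
print(low, high)  # 3.5 65.0
```

A much larger spread in the upper half, as here, is the numeric counterpart of the funnel shape one would see on a residuals plot.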
Conclusion
When working with neural networks, having a library of advanced visualizations you can use to dig into specific properties of your model is essential to your ability to iterate on your model builds quickly and effectively.
In this post we saw how we can leverage yellowbrick with keras to build some of these kinds of graphs.
Hopefully, having read this post, you’re now ready to replace one or two hacky matplotlib code gists lying around with well-maintained, well-architected visualization recipes from this nifty new library.
Interested in learning more about the Python data visualization ecosystem? I recommend watching Jake VanderPlas’s highly entertaining PyCon 2017 talk “The Python Visualization Landscape”.
Interested in learning more about Yellowbrick? In addition to the model evaluation plots showcased here, the library also provides plots for evaluating unsupervised clustering, modeling text, and modeling data features.
Take a look.
Unfortunately there are many more advanced model evaluation plots in the library that didn’t appear here because they aren’t as easily emulated.
But watch this space: I suspect using yellowbrick visualizations with keras models is going to get easier in the future!