Adjustments are made to nodes only on an as-needed basis (when there are non-zero residuals). When no adjustment is needed, the shortcut connections apply the identity function, passing information unchanged to subsequent layers.
This shortens the effective network when possible, allowing ResNets to have deep architectures while behaving more like shallow neural networks.
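The shortcut idea can be sketched in a few lines of plain Python. This is a toy stand-in for the real convolutional layers, not fastai's or PyTorch's implementation; all names here are mine:

```python
def residual_block(x, layer):
    """Output = layer(x) + x: the shortcut adds x back in unchanged.

    If the layer learns a zero residual (outputs all zeros), the block
    reduces to the identity function and effectively disappears.
    """
    return [f + xi for f, xi in zip(layer(x), x)]

# A "layer" that has learned a zero residual: the block is the identity.
zero_layer = lambda x: [0.0] * len(x)
print(residual_block([1.0, 2.0, 3.0], zero_layer))    # [1.0, 2.0, 3.0]

# A layer with non-zero residuals adjusts the input instead.
adjust_layer = lambda x: [0.5 * xi for xi in x]
print(residual_block([1.0, 2.0, 3.0], adjust_layer))  # [1.5, 3.0, 4.5]
```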
The 34 in resnet34 just refers to the number of layers.
Anand Saha gives a great, more in-depth explanation here.
Finding a learning rate

I’m going to find a learning rate for gradient descent to make sure that my neural network converges reasonably quickly without missing the optimal error.
For a refresher on the learning rate, check out Jeremy Jordan’s post on choosing a learning rate.
The learning rate finder suggests a learning rate of 5.
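In fastai this is done with learn.lr_find(), which trains briefly while increasing the learning rate exponentially, and learn.recorder.plot() to inspect the resulting loss curve; you then pick a rate where the loss is still dropping steeply. Below is a toy re-implementation of the sweep on a quadratic loss instead of a real network (all names are mine, and for simplicity it picks the literal loss minimum, whereas the real finder looks for the steepest descent):

```python
def lr_range_test(grad, w0=10.0, lr_min=1e-4, lr_max=10.0, steps=50):
    """Take one gradient step at each of a range of exponentially
    increasing learning rates, recording the loss after each step."""
    results = []
    for i in range(steps):
        lr = lr_min * (lr_max / lr_min) ** (i / (steps - 1))
        w = w0 - lr * grad(w0)       # one step from the same start point
        results.append((lr, w * w))  # loss L(w) = w^2
    return results

# Toy loss L(w) = w^2, gradient 2w: sweep rates and keep the best one.
curve = lr_range_test(lambda w: 2 * w)
best_lr, best_loss = min(curve, key=lambda t: t[1])
print(best_lr)  # a rate near 0.5 wins on this toy problem
```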
With this, we can train the model.
Training

Results of training on validation set

I ran my model for 20 epochs.
What’s cool about this fitting method is that the learning rate decreases with each epoch, allowing us to get closer and closer to the optimum.
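The decreasing rate comes from the fitting method's schedule (presumably fastai's fit_one_cycle or a cosine-annealed fit). Here's a toy illustration of why a decaying rate helps gradient descent settle into the optimum; this is plain Python on a quadratic loss, nothing fastai-specific, and the 0.4 and 1.0 rates are arbitrary choices of mine:

```python
import math

def train(epochs, lr_schedule, w=10.0):
    """Gradient descent on the toy loss L(w) = w^2 (gradient 2w),
    with a learning rate that can change every epoch."""
    for epoch in range(epochs):
        w -= lr_schedule(epoch, epochs) * 2 * w
    return w

def cosine(epoch, epochs, lr_max=0.4):
    """Cosine annealing: start near lr_max, decay smoothly toward zero."""
    return lr_max * (1 + math.cos(math.pi * epoch / epochs)) / 2

w_final = train(20, cosine)            # settles very close to the optimum w = 0
w_fixed = train(20, lambda e, n: 1.0)  # too large a fixed rate never settles
```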
At 6%, the validation error looks super good… let’s see how it performs on the test data, though.
First, we can take a look at which images were most incorrectly classified.
Visualizing most incorrect images

The images here that the recycler performed poorly on were actually degraded. It looks like the photos received too much exposure or something, so this actually isn’t a fault with the model!

This model often confused plastic for glass and metal for glass.
The list of most confused images is below.
Make new predictions on test data

To see how this model really performs, we need to make predictions on test data.
First, I’ll make predictions on the test data using the learner. learner.predict() only predicts on a single image, while learner.get_preds() predicts on a set of images.
I highly recommend reading the documentation to learn more about predict() and get_preds().
The ds_type argument in get_preds(ds_type) takes a DatasetType argument. Example values are DatasetType.Valid and DatasetType.Test. I mention this because I made the mistake of passing in actual data (learn.test_ds), which gave me the wrong output and took embarrassingly long to debug. Don’t make this mistake! Don’t pass in data; pass in the dataset type!

These are the predicted probabilities for each image.
This tensor has 365 rows — one for each image — and 6 columns — one for each material category.
Now I’m going to convert the probabilities in the tensor above to a vector of predicted class names.
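That conversion is just an argmax over each row of the probability tensor, mapped through the class list. A sketch with toy numbers (these are TrashNet's six categories in alphabetical order, which is how fastai sorts classes; the probabilities are made up, not my actual predictions):

```python
# The six TrashNet material categories, sorted alphabetically.
classes = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# Toy probability tensor: 3 images x 6 categories (each row sums to 1).
probs = [
    [0.05, 0.70, 0.10, 0.05, 0.05, 0.05],  # most likely glass
    [0.80, 0.05, 0.05, 0.05, 0.03, 0.02],  # most likely cardboard
    [0.05, 0.30, 0.05, 0.05, 0.50, 0.05],  # most likely plastic
]

def argmax(row):
    """Index of the largest value in a row."""
    return max(range(len(row)), key=row.__getitem__)

pred_labels = [classes[argmax(row)] for row in probs]
print(pred_labels)  # ['glass', 'cardboard', 'plastic']
```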
These are the predicted labels of all the images! Let’s check if the first image is actually glass.
It is! Next, I’ll get the actual labels from the test dataset.
It looks like the first five predictions match up! How does this model perform overall? We can use a confusion matrix to find out.
Test confusion matrix

Confusion matrix array

I’m going to make this matrix a little bit prettier:

Again, the model seems to have confused metal for glass and plastic for glass.
With more time, I’m sure further investigation could help reduce these mistakes.
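A confusion matrix is also easy to build by hand: count how often each (actual, predicted) pair occurs, so correct predictions land on the diagonal. A sketch with made-up labels (not my real test results, which came from fastai):

```python
from collections import Counter

classes = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# Made-up actual and predicted labels for eight test images.
actual    = ['glass', 'glass', 'metal', 'plastic', 'paper', 'metal', 'glass', 'trash']
predicted = ['glass', 'glass', 'glass', 'glass',   'paper', 'metal', 'glass', 'trash']

# Rows are actual classes, columns are predicted classes.
counts = Counter(zip(actual, predicted))
matrix = [[counts[(a, p)] for p in classes] for a in classes]

for cls, row in zip(classes, matrix):
    print(f'{cls:>9}: {row}')

# Diagonal cells are correct predictions, so accuracy is their share.
accuracy = sum(matrix[i][i] for i in range(len(classes))) / len(actual)
print(accuracy)  # 0.75
```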
I ended up achieving an accuracy of 92.1% on the test data, which is pretty great: the original creators of the TrashNet dataset achieved a test accuracy of 63% with a support vector machine on a 70–30 train-test split (they trained a neural network as well, for a test accuracy of 27%).
Next steps

If I had more time, I’d go back and reduce classification error for glass in particular.
I’d also delete photos from the dataset that are overexposed since those images are just bad data.
This was just a quick and dirty mini-project to show that it’s pretty easy to train an image classification model, but it’s amazing how quickly you can create a state-of-the-art model using the fastai library.
If you have an application you’re interested in but don’t think you have the machine learning chops, this should be encouraging for you.
Here’s the GitHub repo for this project.
Thanks to James Dellinger for this blog post about classifying bluejays.
For more information about recycling, check out this FiveThirtyEight post.