Transfer Learning with PyTorch
Gilbert Adjei, Feb 20
When we learn something in our daily lives, similar things become much easier to learn because we apply our existing knowledge to the new task.
Example: When I learned how to ride a bicycle, it became very easy to learn how to ride a motorcycle, because riding the bicycle had taught me that I had to sit and keep my balance, hold the handlebars firmly, and pedal to accelerate.
Using that prior knowledge, I could easily adapt to a motorcycle’s design and the way it is ridden.
And that is the general idea behind transfer learning.
The objectives for this blog post are to:
- Understand the meaning of transfer learning
- Understand the importance of transfer learning
- Implement transfer learning hands-on using PyTorch
Let us begin by defining what transfer learning is all about.
What Is Transfer Learning?
Transfer learning is a machine learning technique in which knowledge gained while training on one type of problem is used to train on other, similar types of problems.
Thus, instead of building your own deep neural network from scratch, which can be a cumbersome task to say the least, you can find an existing neural network that accomplishes a task similar to the one you’re trying to solve, reuse the layers that are essential for pattern detection, and change the fully connected layers to suit your problem.
In practice, it’s rare to have a dataset large enough to train a convolutional network from scratch; instead, it is very common to pre-train a ConvNet on a large dataset (e.g. ImageNet, which contains 1.2 million images in 1000 categories) and then use the ConvNet either as an initialization or as a fixed feature extractor for the task at hand.
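The two options above differ only in which parameters get trained. Here is a minimal sketch that counts trainable parameters in each mode. To keep it runnable without downloads, it uses a tiny hypothetical stand-in backbone (`make_backbone`, with random weights) in place of a real ImageNet-trained ConvNet; `build` and `count_trainable` are illustrative names, not part of any library.

```python
from torch import nn


def make_backbone():
    """Hypothetical stand-in for a ConvNet pre-trained on ImageNet (random weights here)."""
    return nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),  # one 16-value feature vector per image
        nn.Flatten(),
    )


def build(num_classes, fixed_feature_extractor):
    backbone = make_backbone()
    if fixed_feature_extractor:
        # Fixed feature extractor: the pre-trained weights are frozen
        for p in backbone.parameters():
            p.requires_grad = False
    # Otherwise the pre-trained weights are only an initialization and get fine-tuned
    head = nn.Linear(16, num_classes)  # new task-specific layer, always trainable
    return nn.Sequential(backbone, head)


def count_trainable(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(count_trainable(build(2, fixed_feature_extractor=False)))  # 482: backbone + head
print(count_trainable(build(2, fixed_feature_extractor=True)))   # 34: head only
```

With a real pre-trained model the gap is far larger: freezing turns millions of trainable weights into only the handful in the new head, which is what makes the fixed-feature-extractor mode cheap to train.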
Getting Started
In this tutorial, we’ll be using a pre-trained network to build an image classifier for malaria detection.
The data has two classes we’re going to classify: each image is either Parasitized or Uninfected.
The image dataset we are going to use can be downloaded here.
The pre-trained network was trained on ImageNet (1.2 million images in 1000 categories) and is available in torchvision.models, which has 6 different architectures we can use.
The torchvision.models documentation gives a breakdown of each model’s performance as well as its number of layers (indicated by the number attached to the model’s name).
Generally, the larger the number, the better the performance; however, deeper models are more computationally expensive and slow down training.
All these networks use convolutional layers, which exploit patterns and regularities in images.
Training Our Model
If, like me, you don’t have a GPU, you’re still in luck.
You can train your model on the free GPUs Google offers through Google Colab, as I did.
There’s an excellent tutorial on setting up Colab here.
Now, assuming you have set up your GPU machine or Google Colab, let’s get our hands dirty.
We import all the packages and libraries we need for this malaria detection application.
To visualize some of the data, we first specify the path to the directory containing our image dataset.
Note that your path may differ from mine, so check it and adjust it accordingly.
Let’s first see what a Parasitized image looks like.
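A minimal sketch of viewing one image with PIL and matplotlib. It assumes the images are stored in one folder per class; the folder path in the usage note (`cell_images/Parasitized`) and the helper name `show_first_image` are hypothetical, so substitute your own.

```python
import os

import matplotlib.pyplot as plt
from PIL import Image

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def show_first_image(folder):
    """Draw the first image (by file name) in `folder` and return it as an RGB PIL image."""
    names = sorted(f for f in os.listdir(folder) if f.lower().endswith(IMAGE_EXTENSIONS))
    if not names:
        raise FileNotFoundError(f"no images found in {folder}")
    with Image.open(os.path.join(folder, names[0])) as im:
        img = im.convert("RGB")  # loads the pixel data before the file is closed
    plt.imshow(img)
    plt.title(os.path.basename(os.path.normpath(folder)))  # e.g. "Parasitized"
    plt.axis("off")
    return img
```

Usage: call `show_first_image("cell_images/Parasitized")` and then `plt.show()` (in a notebook the figure usually renders on its own).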
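Defining Transformations and Loading in Data
Transformation is a process by which one figure, expression, or function is converted into another. Here, the transformations convert each image (for example by resizing, cropping, rotating, and normalizing it) into the tensor form the network expects.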
Now let’s define a few transformations for the training, validation, and testing data.
Keep in mind that some categories may contain only a limited number of images.
To increase the variety of images the network sees, we perform what is called data augmentation.
During training, we randomly crop, resize, and rotate the images so that for each epoch (one pass through the dataset), the network sees different variations of the same image.
This generally helps the network generalize, which improves accuracy on the validation set.
Note that we don’t perform data augmentation on the validation data; we only resize and center-crop it.
This is because we want the validation data to resemble the data the model will eventually receive as input (out-of-sample/test data).
With the transformations defined, we have to load in the dataset. The easiest way to load image data is with datasets.ImageFolder from torchvision, which takes as input the path to the images and the transforms to apply.
With the ImageFolder loaded, let’s split the data into a 20% validation set and a 10% test set (leaving 70% for training), then pass each split to a DataLoader, which takes a dataset like the one you get from ImageFolder and returns batches of images and their corresponding labels (setting shuffle=True introduces variation across epochs).
Steps for Training the model
The steps we are going to use for our pre-trained model are:
1. Loading in the pre-trained model
2. Freezing the convolutional layers
3. Replacing the fully connected layers with a custom classifier
4. Training the custom classifier for the specific task
We can now load one of the pre-trained models. Here I’m going to use densenet121, which has high accuracy on the ImageNet dataset.
The 121 in its name is the number of layers in the network.
Loading the Pre-trained Model
With our model built, we need to train the classifier.
However, we’re now using a really deep neural network.
If you train it on a CPU as usual, it will take a very long time.
Instead, we’re going to use the GPU to do the calculations.
The linear algebra computations are done in parallel on the GPU, which can make training up to 100x faster, depending on the model and hardware.
It’s also possible to train on multiple GPUs, further decreasing training time.
PyTorch, along with pretty much every other deep learning framework, uses CUDA to efficiently compute the forward and backward passes on the GPU.
In PyTorch, you move your model parameters and other tensors to GPU memory using model.to('cuda') (or model.cuda()).
You can move them back from the GPU with model.cpu(), which you’ll commonly do when you need to operate on the network output outside of PyTorch.
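The usual device idiom, sketched with a small placeholder layer (`nn.Linear(4, 2)` is illustrative, not part of the malaria model): pick the GPU when available, move the model and its inputs to the same device, and bring results back to CPU memory before handing them to non-PyTorch code.

```python
import torch
from torch import nn

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2).to(device)      # moves the parameters to `device`
inputs = torch.randn(3, 4).to(device)   # inputs must live on the same device as the model
outputs = model(inputs)

# detach() drops the autograd graph; cpu() copies the values back to host memory
result = outputs.detach().cpu().tolist()
```

Writing the code against `device` rather than hard-coding `"cuda"` lets the same notebook run on Colab’s GPU and on a CPU-only laptop.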
Freezing the convolutional layers & replacing the fully connected layers with a custom classifier
Freezing the model parameters lets us keep the pre-trained weights of the convolutional layers, whose purpose is feature extraction, unchanged during training.
We then define our fully connected network, which has 1024 input neurons (this number depends on the pre-trained model; densenet121’s feature extractor outputs 1024 features) and a custom hidden layer.
We also define the activation function to use and a dropout layer, which helps avoid overfitting by randomly switching off neurons in a layer, forcing information to be shared among the remaining nodes.
After defining our custom fully connected network, we use it to replace the pre-trained model’s fully connected classifier so that the model suits the problem we want to solve.
Finally, we define the loss function and the optimizer, and prepare the model for training by moving it to the GPU.
Training the custom classifier for the specific task
During training, we iterate through the DataLoader for each epoch.
For each batch, the loss is calculated using the criterion function.
The gradients of the loss with respect to the model parameters are calculated using loss.backward().
optimizer.zero_grad() clears any accumulated gradients; this is necessary because PyTorch adds new gradients to the existing ones on every backward pass rather than overwriting them.
optimizer.step() updates the model parameters using Adam, an adaptive variant of stochastic gradient descent that keeps momentum-like running averages of the gradients.
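The per-batch steps above can be sketched as one epoch of training. `train_one_epoch` is a hypothetical helper name; it assumes the model outputs log-probabilities and that `criterion` is a loss such as `nn.NLLLoss`.

```python
import torch


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one pass over `loader`, updating the model; return the mean training loss."""
    model.train()  # enables dropout
    total_loss, n = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()                # clear gradients left from the previous batch
        log_probs = model(images)            # forward pass
        loss = criterion(log_probs, labels)  # e.g. NLLLoss on log-probabilities
        loss.backward()                      # gradients of the loss w.r.t. the parameters
        optimizer.step()                     # update the trainable parameters
        total_loss += loss.item() * images.size(0)
        n += images.size(0)
    return total_loss / n
```

Weighting each batch loss by its size makes the returned value a true per-image average even when the last batch is smaller.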
To prevent overfitting, we use a powerful technique called early stopping.
The idea behind it is simple: stop training when performance on the validation set begins to degrade.
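A minimal sketch of early stopping with checkpointing, under two assumptions: "performance" is measured by the validation loss, and the best weights are saved as a `state_dict`. The class name `EarlyStopping`, the `patience` value, and the checkpoint path are all hypothetical.

```python
import math

import torch


class EarlyStopping:
    """Stop when validation loss hasn't improved for `patience` epochs; save the best weights."""

    def __init__(self, patience=3, path="checkpoint.pt"):
        self.patience = patience
        self.path = path
        self.best_loss = math.inf
        self.bad_epochs = 0

    def step(self, val_loss, model):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.bad_epochs = 0
            torch.save(model.state_dict(), self.path)  # checkpoint the new best model
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Usage: after each training epoch, compute the validation loss (a loop like the training one, but under `torch.no_grad()` and with `model.eval()`), then `if stopper.step(val_loss, model): break`. The patience gives the loss a few epochs to recover from noise before training is cut off.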
After patiently waiting for the training process to finish and saving checkpoints of the best model parameters, let’s load the checkpoint and test the model’s performance on unseen data (the test data).
Loading the saved model from disk
Testing the loaded model on unseen data
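A sketch of restoring a saved `state_dict` and measuring test accuracy. `load_checkpoint` and `evaluate` are hypothetical helper names; `evaluate` assumes the model outputs log-probabilities (so `argmax` over the class dimension gives the prediction) and that `criterion` is a loss such as `nn.NLLLoss`.

```python
import torch


def load_checkpoint(model, path, device):
    """Load a saved state_dict into `model` (which must have the same architecture)."""
    model.load_state_dict(torch.load(path, map_location=device))
    return model.to(device)


@torch.no_grad()  # no gradients needed for evaluation
def evaluate(model, loader, criterion, device):
    """Return (mean loss, accuracy) of `model` on `loader` without updating weights."""
    model.eval()  # disables dropout
    total_loss, correct, n = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        log_probs = model(images)
        total_loss += criterion(log_probs, labels).item() * images.size(0)
        correct += (log_probs.argmax(dim=1) == labels).sum().item()
        n += images.size(0)
    return total_loss / n, correct / n
```

`map_location` lets a checkpoint saved on a GPU (for example on Colab) be loaded on a CPU-only machine.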
We get an accuracy of 90% on unseen data, which is very impressive for a first attempt.
Now that we have confidence in our model, it’s time to make some predictions and visualize the results.
Predictions made by our model
If you’ve made it this far, clap for yourself.
You’ve built a malaria classifier application that could (with some more work, of course) not only save lives but also help speed up the work of laboratory technicians and health professionals.
Go ahead and play around with the code, and see how you can improve on the test accuracy.
Try making changes to the optimizer, the pre-trained model and the loss function.
You can add more transformations or even add more layers to the fully connected network.
I believe you can surpass the 90% benchmark.
Cheers!