They show that we can get the same performance (or even better) on a specific task by distilling the information from BERT into a much smaller BiLSTM neural network.
You can see their results in the table below.
The best performance was achieved using BiLSTM-Soft, which means “soft predictions”, i.e., training on the raw logits and not the “hard” predictions.
The datasets are SST-2 (Stanford Sentiment Treebank 2), QQP (Quora Question Pairs), and MNLI (Multi-Genre Natural Language Inference).
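To make the “soft” vs. “hard” distinction concrete, here is a tiny sketch. The logits are made-up numbers for a binary task, not values from the paper; the point is only that soft targets keep the teacher's confidence while hard labels throw it away.

```python
import numpy as np


def sigmoid(z):
    """Map raw logits to probabilities in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical teacher logits for four examples (one logit per example).
teacher_logits = np.array([4.0, 0.2, -0.1, -3.0])

# Soft targets: the teacher's confidence survives (0.2 and 4.0 stay different).
soft_targets = sigmoid(teacher_logits)

# Hard targets: only the sign of the logit survives (0.2 and 4.0 both become 1).
hard_targets = (teacher_logits > 0).astype(int)

print("soft:", np.round(soft_targets, 3))
print("hard:", hard_targets)
```

A student trained on the soft targets (or on the logits directly) sees that the second and third examples are borderline, which the hard labels hide.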
In this post, I want to distill BERT into a much simpler Logistic Regression model.
Assuming you have a relatively small labeled dataset and a much bigger unlabeled dataset, the general framework for building the model is:

1. Create some baseline on the labeled dataset
2. Build a big model by fine-tuning BERT on the labeled set
3. If you got good results (better than your baseline), calculate the raw logits for your unlabeled set using the big model
4. Train a much smaller model (Logistic Regression) on the now pseudo-labeled set
5. If you got good results, deploy the small model anywhere!

If you’re interested in a more basic tutorial on fine-tuning BERT, please check out my previous post: “BERT to the rescue! A step-by-step tutorial on simple text classification using BERT” (towardsdatascience.com).

I want to solve the same task (IMDB Reviews Sentiment Classification) but with Logistic Regression.
You can find all the code in this notebook.
As before, I’ll use torchnlp to load the data and the excellent PyTorch-Pretrained-BERT to build the model.
There are 25,000 reviews in the train set. We’ll use only 1,000 as a labeled set and another 5,000 as an unlabeled set (I also choose only 1,000 reviews from the test set to speed things up):

    import random as rn  # assumed: `rn` is the standard `random` module
    from torchnlp.datasets import imdb_dataset

    train_data_full, test_data_full = imdb_dataset(train=True, test=True)
    # Shuffle before slicing so each 1,000-review subset mixes both classes
    rn.shuffle(train_data_full)
    rn.shuffle(test_data_full)
    train_data = train_data_full[:1000]
    test_data = test_data_full[:1000]

The first thing we do is create a baseline using logistic regression. We get not-so-great results:

                precision    recall  f1-score   support
    neg                 …         …      0.80       522
    pos                 …         …      0.78       478
    accuracy                             0.79      1000

The next step is to fine-tune BERT. I will skip the code here; you can see it in the notebook, or find a more detailed tutorial in my previous post.
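Going back to the baseline step for a moment: it could look roughly like the sketch below. This is not the notebook's code. The bag-of-words features (`CountVectorizer`) are my assumption, and the toy reviews are hypothetical stand-ins for `train_data` and `test_data`, which I assume are lists of dicts with `text` and `sentiment` keys.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

# Hypothetical toy reviews standing in for the 1,000-review subsets.
train_data = [
    {"text": "great wonderful film", "sentiment": "pos"},
    {"text": "loved it brilliant acting", "sentiment": "pos"},
    {"text": "terrible boring film", "sentiment": "neg"},
    {"text": "awful waste of time", "sentiment": "neg"},
]
test_data = [
    {"text": "brilliant wonderful", "sentiment": "pos"},
    {"text": "boring awful", "sentiment": "neg"},
]

train_texts = [d["text"] for d in train_data]
train_labels = [d["sentiment"] for d in train_data]
test_texts = [d["text"] for d in test_data]
test_labels = [d["sentiment"] for d in test_data]

# Bag-of-words features: fit the vocabulary on the training texts only.
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(train_texts)
X_test = vectorizer.transform(test_texts)

# Plain logistic regression on the word counts.
baseline = LogisticRegression()
baseline.fit(X_train, train_labels)

baseline_preds = baseline.predict(X_test)
print(classification_report(test_labels, baseline_preds))
```

With the real data, the same report call produces a table in the format shown above.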
The result is a trained model called BertBinaryClassifier which uses BERT and then a linear layer to provide the pos/neg classification.
The performance of this model is:

                precision    recall  f1-score   support
    neg                 …         …      0.89       522
    pos                 …         …      0.88       478
    accuracy                             0.89      1000

Much, much better! As I said: Magic :)

Now to the interesting part: we use the unlabeled set and “label” it using our fine-tuned BERT model, then train the Logistic Regression model on this pseudo-labeled set. We get:

                precision    recall  f1-score   support
    neg                 …         …      0.88       522
    pos                 …         …      0.86       478
    accuracy                             0.87      1000

Not as great as the original fine-tuned BERT, but much better than the baseline! Now we are ready to deploy this small model to production and enjoy both good quality and inference speed.
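For reference, the pseudo-labeling and student-training steps can be sketched as follows. This is an illustration under assumptions, not the notebook's code. `teacher_logits` is a hypothetical array standing in for the fine-tuned BERT's outputs on the unlabeled set, and the texts are toy examples. Since scikit-learn's `LogisticRegression` accepts only class labels, one way to feed it soft targets is to duplicate every example with both labels and weight the two copies by the teacher's probabilities, which gives the cross-entropy against the soft targets.

```python
import numpy as np
from scipy.sparse import vstack
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression

# Hypothetical unlabeled reviews and the teacher's raw logits for them.
unlabeled_texts = [
    "great wonderful film",
    "loved it brilliant acting",
    "terrible boring film",
    "awful waste of time",
]
teacher_logits = np.array([3.0, 2.0, -3.0, -2.0])

# Convert logits to the teacher's probability of the positive class.
teacher_probs = 1.0 / (1.0 + np.exp(-teacher_logits))

vectorizer = CountVectorizer()
X = vectorizer.fit_transform(unlabeled_texts)
n = X.shape[0]

# Each example appears twice: once labeled 1 with weight p,
# once labeled 0 with weight (1 - p).
X_soft = vstack([X, X])
y_soft = np.concatenate([np.ones(n, dtype=int), np.zeros(n, dtype=int)])
w_soft = np.concatenate([teacher_probs, 1.0 - teacher_probs])

student = LogisticRegression()
student.fit(X_soft, y_soft, sample_weight=w_soft)

# The student runs without BERT at inference time.
new_texts = ["brilliant wonderful", "boring awful"]
student_preds = student.predict(vectorizer.transform(new_texts))
print(student_preds)
```

A simpler variant is to threshold the logits and train on the resulting hard pseudo-labels, at the cost of losing the teacher's confidence.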
Here’s one more reason to add to “5 Reasons ‘Logistic Regression’ should be the first thing you learn when becoming a Data Scientist” :)